import math
import matplotlib.pyplot as plt
import random


############################################################
# compare random walks with different step distributions
############################################################


def random_walk_gaussian(num_steps, step_stdev=1.0):
    """Simulate a 2-D random walk with Gaussian-distributed step components.

    The walk starts at the origin. For every step, the x displacement and
    then the y displacement are each drawn with
    ``random.gauss(0, step_stdev)`` and added to the previous position.

    Args:
        num_steps: Number of steps to take.
        step_stdev: Standard deviation passed to ``random.gauss`` for each
            coordinate of a step.

    Returns:
        A list of ``[x, y]`` lists holding every visited position,
        beginning with ``[0.0, 0.0]``; it has one entry per step plus the
        starting point.
    """
    trail = [[0.0, 0.0]]
    for _ in range(num_steps):
        prev_x, prev_y = trail[-1]
        # Draw x before y so the sequence of random numbers consumed per
        # step stays in a fixed order.
        step_x = random.gauss(0, step_stdev)
        step_y = random.gauss(0, step_stdev)
        trail.append([prev_x + step_x, prev_y + step_y])
    return trail


def random_walk_fixed_step_radius(num_steps, step_length=1.0):
    """Simulate a 2-D random walk in which every step has the same length.

    The walk starts at the origin. Each step draws a heading angle with
    ``random.uniform(0, 2 * math.pi)`` and moves ``step_length`` units in
    that direction.

    Args:
        num_steps: Number of steps to take.
        step_length: Distance covered by each step. Defaults to 1.0, which
            gives unit-length steps.

    Returns:
        A list of ``[x, y]`` lists holding every visited position,
        beginning with ``[0.0, 0.0]``; it has one entry per step plus the
        starting point.
    """
    x, y = 0.0, 0.0
    path = [[x, y]]
    for _ in range(num_steps):
        # Named `heading` rather than `dir` to avoid shadowing the builtin.
        heading = random.uniform(0, 2 * math.pi)
        x += step_length * math.cos(heading)
        y += step_length * math.sin(heading)
        path.append([x, y])
    return path


########################################


# One sample walk of each kind, used for the path plot. The length gets its
# own name so that `num_steps` below has a single meaning (the list of walk
# lengths) instead of being rebound from an int to a list.
sample_walk_steps = 2000
path_gauss = random_walk_gaussian(sample_walk_steps)
path_uniform = random_walk_fixed_step_radius(sample_walk_steps)

# Split each path of [x, y] pairs into separate coordinate lists.
x_g = [p[0] for p in path_gauss]
y_g = [p[1] for p in path_gauss]

x_u = [p[0] for p in path_uniform]
y_u = [p[1] for p in path_uniform]

# Walk lengths for which the average end-to-origin distance is estimated.
num_steps = [
    10, 100, 500,
    1000, 2000, 5000,
    10_000, 20_000, 50_000,
]
num_trials = 100  # independent walks averaged per walk length
avg_dist_gauss = []
avg_dist_uniform = []
for steps in num_steps:
    total_g = 0.0
    total_u = 0.0
    for _ in range(num_trials):
        # Only the final position of each walk is needed here.
        path_end_loc_g = random_walk_gaussian(steps)[-1]
        path_end_loc_u = random_walk_fixed_step_radius(steps)[-1]
        # Euclidean distance of the end point from the origin.
        total_g += math.hypot(path_end_loc_g[0], path_end_loc_g[1])
        total_u += math.hypot(path_end_loc_u[0], path_end_loc_u[1])
    avg_dist_gauss.append(total_g / num_trials)
    avg_dist_uniform.append(total_u / num_trials)
    print(f"ran walks with {steps} steps")


########################################


def plot_walk_comparison_axes(num_steps,
                              x_g, y_g, x_u, y_u,
                              avg_dist_gauss, avg_dist_uniform):
    """Show a two-panel comparison figure using matplotlib's Axes interface.

    The left panel draws both walk paths plus a marker at (0, 0); the right
    panel plots the two average-distance series against ``num_steps``.
    The figure is laid out with ``plt.tight_layout()`` and displayed with
    ``plt.show()``.

    Args:
        num_steps: x-values for the right-hand panel.
        x_g, y_g: Coordinates of the path labelled "Gaussian Walk".
        x_u, y_u: Coordinates of the path labelled "Uniform Walk".
        avg_dist_gauss: y-values of the "Gaussian Walk" distance curve.
        avg_dist_uniform: y-values of the "Uniform Walk" distance curve.
    """
    _, (paths_ax, dist_ax) = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: the two trajectories and the origin marker.
    trajectories = (
        (x_g, y_g, "Gaussian Walk"),
        (x_u, y_u, "Uniform Walk"),
    )
    for xs, ys, name in trajectories:
        paths_ax.plot(xs, ys, alpha=0.7, label=name)
    paths_ax.scatter(0, 0, color='black', s=80, zorder=3, label="Origin")
    paths_ax.set_title("Gaussian vs. Uniform Random Walk Paths")
    paths_ax.set_xlabel("x")
    paths_ax.set_ylabel("y")
    paths_ax.legend()

    # Right panel: average distance as a function of the number of steps.
    distance_curves = (
        (avg_dist_gauss, "Gaussian Walk"),
        (avg_dist_uniform, "Uniform Walk"),
    )
    for averages, name in distance_curves:
        dist_ax.plot(num_steps, averages, label=name)
    dist_ax.set_title("Gaussian vs. Uniform Random Walk Distance Comparison")
    dist_ax.set_xlabel("Number of Steps")
    dist_ax.set_ylabel("Average Distance from Origin")
    dist_ax.legend()

    plt.tight_layout()
    plt.show()


########################################


def plot_walk_comparison_subplot(num_steps,
                                 x_g, y_g, x_u, y_u,
                                 avg_dist_gauss, avg_dist_uniform):
    """Show a two-panel comparison figure using the pyplot state interface.

    Subplot 1 draws both walk paths plus a marker at (0, 0); subplot 2
    plots the two average-distance series against ``num_steps``. The
    figure is laid out with ``plt.tight_layout()`` and displayed with
    ``plt.show()``.

    Args:
        num_steps: x-values for the second subplot.
        x_g, y_g: Coordinates of the path labelled "Gaussian Walk".
        x_u, y_u: Coordinates of the path labelled "Uniform Walk".
        avg_dist_gauss: y-values of the "Gaussian Walk" distance curve.
        avg_dist_uniform: y-values of the "Uniform Walk" distance curve.
    """
    plt.figure(figsize=(14, 6))

    # Subplot 1: the two trajectories and the origin marker.
    plt.subplot(1, 2, 1)
    trajectories = (
        (x_g, y_g, "Gaussian Walk"),
        (x_u, y_u, "Uniform Walk"),
    )
    for xs, ys, name in trajectories:
        plt.plot(xs, ys, alpha=0.7, label=name)
    plt.scatter(0, 0, color='black', s=80, zorder=3, label="Origin")
    plt.title("Gaussian vs. Uniform Random Walk Paths")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()

    # Subplot 2: average distance as a function of the number of steps.
    plt.subplot(1, 2, 2)
    distance_curves = (
        (avg_dist_gauss, "Gaussian Walk"),
        (avg_dist_uniform, "Uniform Walk"),
    )
    for averages, name in distance_curves:
        plt.plot(num_steps, averages, label=name)
    plt.title("Gaussian vs. Uniform Random Walk Distance Comparison")
    plt.xlabel("Num Steps")
    plt.ylabel("Average End Distance")
    plt.legend()

    plt.tight_layout()
    plt.show()


# plot_walk_comparison_axes(num_steps, x_g, y_g, x_u, y_u, avg_dist_gauss, avg_dist_uniform)
# plot_walk_comparison_subplot(num_steps, x_g, y_g, x_u, y_u, avg_dist_gauss, avg_dist_uniform)
