import math
import matplotlib.pyplot as plt
import random


############################################################
# random walks
############################################################


def apply_step(loc, step_fn):
    """Return the location reached by taking one step from ``loc``.

    ``step_fn`` is called exactly once; its result is indexed at 0 and 1 and
    added component-wise to ``loc``. A new ``[x, y]`` list is returned and
    ``loc`` itself is left untouched.
    """
    delta = step_fn()
    new_x = loc[0] + delta[0]
    new_y = loc[1] + delta[1]
    return [new_x, new_y]


def perform_walk(start, num_steps, step_fn):
    """Simulate a random walk and return every location visited.

    Args:
        start: Starting location, an indexable ``[x, y]`` pair. It is stored
            as-is (not copied) as the first entry of the returned history.
        num_steps: Number of steps to take. The history has ``num_steps + 1``
            entries; a zero or negative value yields just ``[start]``.
        step_fn: Zero-argument callable returning one ``[dx, dy]`` step.

    Returns:
        A list of locations. Every entry after the first is a new ``[x, y]``
        list produced by ``apply_step`` from the previous location.
    """
    history = [start]
    for _ in range(num_steps):
        # Each step is taken from the most recently reached location.
        history.append(apply_step(history[-1], step_fn))
    return history


def uniform_step():
    """Return one unit move chosen uniformly from right, left, up and down."""
    moves = ([1, 0], [-1, 0], [0, 1], [0, -1])
    return random.choice(moves)


def left_bias_step():
    """Return one of four axis-aligned moves with the horizontal ones skewed left.

    The +x move has length 0.9 and the -x move has length 1.1, so the mean
    x displacement per step is (0.9 - 1.1) / 4 = -0.05.
    """
    moves = ([0.9, 0], [-1.1, 0], [0, 1], [0, -1])
    return random.choice(moves)


def right_bias_step():
    """Return one of four axis-aligned moves with the horizontal ones skewed right.

    The +x move has length 1.1 and the -x move has length 0.9, so the mean
    x displacement per step is (1.1 - 0.9) / 4 = +0.05.
    """
    moves = ([1.1, 0], [-0.9, 0], [0, 1], [0, -1])
    return random.choice(moves)


def corner_step():
    """Return a unit-length diagonal move toward one of the four quadrants.

    Each component has magnitude 1/sqrt(2), so the step length is 1.
    """
    # Kept as 1 / 2**0.5 (not 2**-0.5) so the float value is bit-identical.
    diag = 1 / (2 ** 0.5)
    moves = (
        [diag, diag],
        [-diag, diag],
        [-diag, -diag],
        [diag, -diag],
    )
    return random.choice(moves)


def continuous_step():
    """Return a step whose x and y components are each uniform on [-1, 1].

    The x component is drawn before the y component.
    """
    dx = random.uniform(-1, 1)
    dy = random.uniform(-1, 1)
    return [dx, dy]


# print(perform_walk([0, 0], 10, uniform_step))
# print(perform_walk([0, 0], 10, left_bias_step))
# print(perform_walk([0, 0], 10, right_bias_step))
# print(perform_walk([0, 0], 10, corner_step))
# print(perform_walk([0, 0], 10, continuous_step))


def print_walks():
    """Print one 10-step walk from the origin for every step function.

    Each location is printed as two right-aligned columns with four decimal
    places. A blank line follows each walk.
    """
    step_fns = (
        uniform_step,
        left_bias_step,
        right_bias_step,
        corner_step,
        continuous_step,
    )
    for fn in step_fns:
        walk = perform_walk([0, 0], 10, fn)
        for loc in walk:
            print(f"{loc[0]:10.4f}  {loc[1]:10.4f}")
        print()


# print_walks()


############################################################
# plotting a random walk
############################################################


def split_xy(points):
    """Split a collection of ``[x, y]`` points into separate coordinate lists.

    Args:
        points: Iterable of indexable points. Only indices 0 and 1 of each
            point are read. One-shot iterables (e.g. generators) are accepted.

    Returns:
        ``[xpos, ypos]``: a two-element list holding all x coordinates and all
        y coordinates, in input order. It can be unpacked directly:
        ``xpos, ypos = split_xy(points)``.
    """
    # Materialize once so both passes below see the same data.
    points = list(points)
    xpos = [p[0] for p in points]
    ypos = [p[1] for p in points]
    return [xpos, ypos]


def plot_walk(num_steps, step_fn=uniform_step):
    """Draw a single random walk from the origin and show the figure.

    The path is drawn in orange, the start point as a red circle and the end
    point as a bright-blue cross. This function calls ``plt.show()`` itself
    before returning.
    """
    xpos, ypos = split_xy(perform_walk([0, 0], num_steps, step_fn))

    plt.figure()
    plt.plot(xpos, ypos, color="orange")
    # Start marker.
    plt.plot(xpos[0], ypos[0], color="red", marker="o")
    # End marker: a bold cross so it stands out against the path.
    end_style = dict(
        color="xkcd:bright blue",
        marker="x",
        markersize=10,
        markeredgewidth=3,
    )
    plt.plot(xpos[-1], ypos[-1], **end_style)
    plt.title(f"Random walk of {num_steps} steps, {step_fn.__name__}")
    plt.xlabel("X position")
    plt.ylabel("Y position")
    plt.grid()
    plt.show()


# plot_walk(num_steps=100)
# plot_walk(num_steps=100, step_fn=left_bias_step)
# plot_walk(num_steps=100, step_fn=right_bias_step)
# plot_walk(num_steps=100, step_fn=corner_step)
# plot_walk(num_steps=100, step_fn=continuous_step)


############################################################
# plot end locations
############################################################


def collect_walk_ends(num_walks, num_steps, step_fn):
    """Run ``num_walks`` independent walks from the origin and return their end points.

    Every walk has ``num_steps`` steps produced by ``step_fn``. The result
    holds one final location per walk, in the order the walks were run.
    """
    return [
        perform_walk([0, 0], num_steps, step_fn)[-1]
        for _ in range(num_walks)
    ]


def mean(vals):
    """Return the arithmetic mean of a sized collection of numbers.

    The collection must support both ``sum`` and ``len``. An empty collection
    raises ``ZeroDivisionError``.
    """
    total = sum(vals)
    count = len(vals)
    return total / count


def plot_walk_ends(num_walks, num_steps, step_fn=uniform_step, axis_limit=100):
    """Scatter-plot where ``num_walks`` independent walks finish.

    A new figure is created showing the origin (large red circle), each walk's
    end point (blue crosses) and the mean end point (cyan star). This function
    does not call ``plt.show()``.

    Args:
        num_walks: Number of independent walks to simulate.
        num_steps: Steps per walk.
        step_fn: Zero-argument callable producing one ``[dx, dy]`` step.
        axis_limit: Half-width of the square view centred on the origin.
            The default of 100 reproduces the previous fixed window. Pass
            ``None`` to let matplotlib autoscale. This is useful for long
            walks whose end points may lie outside +/-100 and would otherwise
            be out of view.
    """
    end_locs = collect_walk_ends(num_walks, num_steps, step_fn)
    xpos, ypos = split_xy(end_locs)

    plt.figure()
    plt.plot([0], [0], color="red", marker="o", markersize=15)  # plot origin
    plt.scatter(xpos, ypos, marker="x")
    # Mean of all end locations.
    plt.plot(mean(xpos), mean(ypos), color="cyan", marker="*", markersize=10)
    plt.title(f"Ending locations after {num_steps} steps, {step_fn.__name__}")
    plt.xlabel("X position")
    plt.ylabel("Y position")
    if axis_limit is not None:
        plt.xlim(-axis_limit, axis_limit)
        plt.ylim(-axis_limit, axis_limit)
    plt.grid()


# plot_walk_ends(num_walks=100, num_steps=1000)
# plot_walk_ends(num_walks=100, num_steps=1000, step_fn=left_bias_step)
# plot_walk_ends(num_walks=100, num_steps=1000, step_fn=right_bias_step)


############################################################
# plot end distances
############################################################


def get_distance(loc1, loc2):
    """Return the Euclidean distance between two 2-D points.

    Args:
        loc1: First point, unpackable as ``(x, y)``.
        loc2: Second point, unpackable as ``(x, y)``.

    Returns:
        The straight-line distance as a float.
    """
    x1, y1 = loc1
    x2, y2 = loc2
    # math.hypot avoids the intermediate overflow that squaring large floats
    # would cause (e.g. (1e200) ** 2 raises OverflowError).
    return math.hypot(x2 - x1, y2 - y1)


def average_distance(num_walks, num_steps, step_fn=uniform_step):
    """Return the mean distance from the origin at which ``num_walks`` walks end.

    Each walk starts at ``[0, 0]`` and takes ``num_steps`` steps from
    ``step_fn``. The straight-line distance of every end point from the origin
    is averaged with ``mean``.
    """
    origin = [0, 0]
    ends = collect_walk_ends(num_walks, num_steps, step_fn)
    return mean([get_distance(origin, end) for end in ends])


# plot_walk_ends(num_walks=100, num_steps=100)
# plot_walk_ends(num_walks=100, num_steps=1000)
# plot_walk_ends(num_walks=100, num_steps=10000)
# print(average_distance(num_walks=100, num_steps=100))
# print(average_distance(num_walks=100, num_steps=1000))
# print(average_distance(num_walks=100, num_steps=10000))


def plot_walk_distances(max_steps, step_fn=uniform_step, verbose=False,
                        num_walks=100, step_size=10):
    """Plot average end distance from the origin against walk length.

    The function samples walk lengths ``step_size, 2*step_size, ...`` up to
    and including ``max_steps``. For each length it computes
    ``average_distance`` over ``num_walks`` walks, then plots the resulting
    curve on a new figure. It does not call ``plt.show()``.

    Args:
        max_steps: Largest walk length to include (inclusive).
        step_fn: Zero-argument callable producing one ``[dx, dy]`` step.
        verbose: If true, print a progress line whenever the current walk
            length is a multiple of 100.
        num_walks: Walks averaged per data point (previously fixed at 100).
        step_size: Spacing between sampled walk lengths, and also the
            shortest length sampled (previously fixed at 10).
    """
    steps_range = range(step_size, max_steps + 1, step_size)

    distances = []
    for num_steps in steps_range:
        distances.append(average_distance(num_walks, num_steps, step_fn))
        if verbose and num_steps % 100 == 0:
            print(f"processed {num_steps = }")

    plt.figure()
    plt.plot(steps_range, distances)
    plt.title(f"Ending distances versus number of steps, {step_fn.__name__}")
    plt.xlabel("Number of steps")
    plt.ylabel("Average distance from origin")
    plt.grid()


# plot_walk_distances(max_steps=1000, verbose=True)
# plot_walk_distances(max_steps=1000, verbose=True, step_fn=left_bias_step)
# plot_walk_distances(max_steps=1000, verbose=True, step_fn=corner_step)


############################################################
# show all plots
############################################################


# plt.show()
