import math
import random
import matplotlib.pyplot as plt


############################################################
# bisection search
############################################################


# exhaustive enumeration example from lec02

# total_tickets = 1000
# for alyssa in range(total_tickets + 1):
#     ben = alyssa - 20
#     cindy = alyssa * 2
#     if ben < 0:
#         continue
#     if alyssa + ben + cindy == total_tickets:
#         print("Alyssa:", alyssa, "tickets")
#         print("Ben:   ", ben, "tickets")
#         print("Cindy: ", cindy, "tickets")
#         break


########################################


# EXERCISE: Rewrite the above code as a bisection search function
# that takes in the total number of tickets.


# given helper function to compute ben and cindy's tickets
def compute_total_tickets(alyssa):
    """Derive everyone's ticket counts from Alyssa's count.

    Ben holds 20 fewer tickets than Alyssa and Cindy holds twice as many.

    Args:
        alyssa: Number of tickets assigned to Alyssa.

    Returns:
        A list ``[alyssa, ben, cindy, total]`` where ``total`` is the sum of
        the three individual counts. ``ben`` is negative when ``alyssa`` is
        below 20; no validation is done here.
    """
    ben, cindy = alyssa - 20, alyssa * 2
    return [alyssa, ben, cindy, alyssa + ben + cindy]


# Solution:


def find_ticket_distribution(total_tickets, compute=None):
    """Find the ticket split that adds up to ``total_tickets`` by bisection.

    The search bisects over Alyssa's ticket count in ``[0, total_tickets]``,
    the same range the exhaustive enumeration scans. Bisection is only valid
    when the total reported by the helper strictly increases with Alyssa's
    count, which is the case for ``compute_total_tickets`` (total is
    ``4 * alyssa - 20``). A helper whose total decreases would need the two
    bound updates swapped.

    Args:
        total_tickets: Target number of tickets to distribute.
        compute: Optional helper mapping Alyssa's count to
            ``[alyssa, ben, cindy, total]``. Defaults to the module-level
            ``compute_total_tickets``.

    Returns:
        A tuple ``(alyssa, ben, cindy)`` whose entries sum to
        ``total_tickets``, or ``None`` when no whole-number split with a
        non-negative count for Ben exists.
    """
    if compute is None:
        compute = compute_total_tickets

    low, high = 0, total_tickets
    while low <= high:
        mid = (low + high) // 2
        alyssa, ben, cindy, total = compute(mid)
        if total < total_tickets:
            # Too few tickets handed out: Alyssa must hold more.
            low = mid + 1
        elif total > total_tickets:
            # Too many tickets handed out: Alyssa must hold fewer.
            high = mid - 1
        else:
            # The total is strictly increasing, so this is the only match.
            # Like the enumeration, reject it if Ben would go negative.
            return (alyssa, ben, cindy) if ben >= 0 else None
    return None


# Driver: run the search for a 1000-ticket pool and report the outcome.
total_tickets = 1000
result = find_ticket_distribution(total_tickets)

if not result:
    # A falsy result (e.g. None) means the search returned no distribution.
    print("No solution found")
else:
    alyssa, ben, cindy = result
    print(f"Alyssa: {alyssa} tickets")
    print(f"Ben:    {ben} tickets")
    print(f"Cindy:  {cindy} tickets")


########################################


# EXERCISE: How would you alter the search code if compute_total_tickets()
# was as follows:

# def compute_total_tickets(alyssa):
#     ben = 500 - alyssa
#     cindy = 600 - 2 * alyssa
#     total = alyssa + ben + cindy
#     return [alyssa, ben, cindy, total]


############################################################
# random walk CLT
############################################################


def random_walk_gaussian_end_distance(num_steps, step_stdev=1.0):
    """Simulate a 2-D Gaussian random walk and return its end-to-end distance.

    Every step adds an independent ``random.gauss(0, step_stdev)`` draw to
    the x coordinate and another to the y coordinate.

    Args:
        num_steps: Number of steps to take.
        step_stdev: Standard deviation of each per-axis step.

    Returns:
        The Euclidean distance (not its square) from the origin to the final
        position; ``0.0`` when ``num_steps`` is 0.
    """
    x_total = 0.0
    y_total = 0.0
    for _ in range(num_steps):
        # Draw x before y so seeded runs consume the random stream in a
        # fixed, reproducible order.
        x_total += random.gauss(0, step_stdev)
        y_total += random.gauss(0, step_stdev)
    return math.hypot(x_total, y_total)


def random_walk_gaussian_int_distance_squared(num_steps, step_stdev=1.0):
    """Accumulate the squared length of every step of a 2-D Gaussian walk.

    Each step draws an x and a y displacement from
    ``random.gauss(0, step_stdev)``; the squared length of that single step
    (not of the running position) is added to the total.

    Args:
        num_steps: Number of steps to simulate.
        step_stdev: Standard deviation of each per-axis displacement.

    Returns:
        The sum over all steps of ``dx**2 + dy**2``; ``0.0`` for zero steps.
    """
    running_sum = 0.0
    for _ in range(num_steps):
        # The generator is consumed left to right, so x is drawn before y.
        offset_x, offset_y = (random.gauss(0, step_stdev) for _axis in range(2))
        running_sum = running_sum + (offset_x**2 + offset_y**2)
    return running_sum


def compute_squared_distributions(num_steps, num_trials, fn, step_stdev=1.0):
    """Collect repeated samples of ``fn`` for each requested walk length.

    Args:
        num_steps: Iterable of walk lengths to simulate.
        num_trials: Number of samples to draw per walk length.
        fn: Callable invoked as ``fn(steps, step_stdev)`` that returns one
            sample value.
        step_stdev: Forwarded unchanged to ``fn`` on every call.

    Returns:
        A list with one inner list per entry of ``num_steps`` (in the same
        order), each holding ``num_trials`` samples.
    """
    samples_per_length = []
    for steps in num_steps:
        trial_samples = [fn(steps, step_stdev) for _ in range(num_trials)]
        samples_per_length.append(trial_samples)
        # Progress report, printed once a walk length has been fully sampled.
        print(f"Finished {steps} steps")
    return samples_per_length


def plot_distributions(
    num_steps,
    all_distributions,
    xlabel="Sum of squared step locations",
    bins=50,
):
    """Overlay one normalized histogram per walk length and show the figure.

    Args:
        num_steps: Walk lengths, used only for the legend labels.
        all_distributions: One sequence of samples per entry of
            ``num_steps``, in the same order. ``zip`` pairs them, so extra
            entries in the longer argument are ignored.
        xlabel: Label for the x axis. The default matches the original
            hard-coded text; pass a different label when plotting another
            quantity (for example end-to-end distances).
        bins: Number of histogram bins for every distribution.
    """
    plt.figure(figsize=(10, 6))

    for steps, data in zip(num_steps, all_distributions):
        # density=True normalizes each histogram so walk lengths with
        # different spreads can share one axis; alpha keeps overlaps visible.
        plt.hist(data, bins=bins, density=True, alpha=0.6, label=f"n = {steps}")

    plt.xlabel(xlabel)
    plt.ylabel("Density")
    plt.legend()
    plt.show()


# Walk lengths to compare, and how many walks to simulate for each length.
num_steps = [10, 50, 200, 500, 1000]
num_trials = 5000

# Alternative experiment (disabled): distribution of end-to-end distances.
# all_end_distributions = compute_squared_distributions(num_steps, num_trials, random_walk_gaussian_end_distance)
# plot_distributions(num_steps, all_end_distributions)

# Active experiment: distribution of the accumulated squared step lengths.
all_int_distributions = compute_squared_distributions(
    num_steps=num_steps,
    num_trials=num_trials,
    fn=random_walk_gaussian_int_distance_squared,
)
plot_distributions(num_steps=num_steps, all_distributions=all_int_distributions)
