import math
import matplotlib.pyplot as plt
import random


############################################################
# wrap-up functions and mutation
############################################################


def demo_return_behavior(items):
    """
    Teaching demo of how ``return`` works and how a list argument is shared
    with the caller.

    Only the first ``return`` below ever executes. Every statement after it
    is unreachable and is kept as an alternative variant to illustrate (see
    the comment above each one).

    Parameters
    ----------
    items : list
        Needs at least two elements, because index 1 is assigned (otherwise
        IndexError). Mutated in place: ``items[1]`` becomes 200, and the
        caller sees this change.

    Returns
    -------
    list
        A new list. It holds the contents of ``items`` as they were before
        the ``items[1] = 200`` assignment, followed by "a", "b", "c".
    """
    print(f"in function, {items = }")
    # ``+`` builds a brand-new list; this line does not modify ``items``
    new_items = items + ["a", "b", "c"]
    # index assignment mutates the caller's list object in place
    items[1] = 200

    # return an expression -- this is the return that actually runs
    return new_items
    # (unreachable) returning a freshly computed expression
    return new_items + ["a", "b", "c"]

    # (unreachable) empty return --> returns None
    return

    # (unreachable) no return at all: falling off the end --> returns None
    print("last statement")

    # return a mutating expression --> returns None, because list.append
    # modifies the list in place and itself evaluates to None
    # return new_items.append("XYZ")


# print()
# num_list = [1, 2, 3]
# result = demo_return_behavior(num_list)
# print(f"in global, function call evaluates to: {result}", )
# print(f"in global: {num_list = }")


def make_sequence(length, initial=[1, 1]):
    """
    Build a Fibonacci-style sequence (each new term is the sum of the previous
    two) by appending ``length - 2`` terms to ``initial``.

    WARNING: this is the deliberately buggy first version, kept to
    demonstrate the mutable-default-argument pitfall.

    The default ``[1, 1]`` is a single list object, created once when this
    ``def`` executes. The loop appends to it, so each call that relies on the
    default keeps growing that same list: a second ``make_sequence(length=5)``
    returns 8 terms, not 5. A list passed in as ``initial`` is likewise
    modified in place and returned.

    This definition is shadowed by the later ``make_sequence`` definitions in
    this file, so it is not the one bound to the name after import.
    """
    # appends to whatever list ``initial`` refers to -- including the shared default
    for _ in range(length - 2):
        initial.append(initial[-2] + initial[-1])
    # returns that same list object, not a copy
    return initial


def make_sequence(length, initial=None):
    """
    Build a Fibonacci-style sequence in which each new term is the sum of the
    two terms before it.

    This version uses ``None`` as the default sentinel, so every call starts
    from a fresh ``[1, 1]`` instead of sharing one default list. The seeds are
    also copied, so a list the caller passes in is never modified.

    Parameters
    ----------
    length : int
        Desired number of terms when starting from two seed terms. The loop
        appends ``length - 2`` new terms. Values below 2 truncate the seeds:
        ``length=1`` gives only the first seed, and ``length<=0`` gives ``[]``.
    initial : iterable of numbers, optional
        Seed terms; at least two are needed when ``length > 2``. Defaults to
        ``[1, 1]``.

    Returns
    -------
    list
        A new list holding the seeds followed by the generated terms.
    """
    # copy so the caller's list (if any) is left untouched
    sequence = [1, 1] if initial is None else list(initial)
    if length < 2:
        # fewer terms than seeds were requested: return only that many
        return sequence[:max(length, 0)]
    for _ in range(length - 2):
        sequence.append(sequence[-2] + sequence[-1])
    return sequence


def make_sequence(length, initial=(1, 1)):
    """
    Build a Fibonacci-style sequence in which each new term is the sum of the
    two terms before it.

    This version always works on a fresh copy of the seeds. The default is an
    immutable tuple, so no shared mutable default exists at all. The seeds
    are copied into a new list, so neither the default nor a caller's list is
    ever modified. Any iterable of seeds (list, tuple, ...) is accepted.

    Parameters
    ----------
    length : int
        Desired number of terms when starting from two seed terms. The loop
        appends ``length - 2`` new terms. Values below 2 truncate the seeds:
        ``length=1`` gives only the first seed, and ``length<=0`` gives ``[]``.
    initial : iterable of numbers, default (1, 1)
        Seed terms; at least two are needed when ``length > 2``.

    Returns
    -------
    list
        A new list holding the seeds followed by the generated terms.
    """
    sequence = list(initial)
    if length < 2:
        # fewer terms than seeds were requested: return only that many
        return sequence[:max(length, 0)]
    for _ in range(length - 2):
        sequence.append(sequence[-2] + sequence[-1])
    return sequence


# print()
# print(make_sequence(length=5))
# print(make_sequence(length=5, initial=[2, 2]))
# print(make_sequence(length=5))


############################################################
# stochastic programs
############################################################


# https://docs.python.org/3/library/random.html
# docs.python.org > Library reference > Numeric and Mathematical Modules > random


def demo_random_module():
    """
    Print a handful of draws from several ``random`` module functions.

    Three groups are printed, each followed by a blank line:
      1. five ints from ``random.randint(3, 7)``;
      2. five floats from ``random.random()``;
      3. five single picks via ``random.choice`` and then five distinct
         pairs via ``random.sample(..., k=2)``, from a fixed list of words.
    """
    def show_five(draw):
        # call ``draw`` five times, printing each result on its own line
        for _ in range(5):
            print(draw())

    favorite_things = ["raindrops", "whiskers", "kettles", "mittens"]

    # random integers: both endpoints are possible
    show_five(lambda: random.randint(3, 7))
    print()

    # random floats in [0.0, 1.0)
    show_five(random.random)
    print()

    # random selection from a group: one item, then two distinct items
    show_five(lambda: random.choice(favorite_things))
    show_five(lambda: random.sample(favorite_things, k=2))
    print()


# print()
# demo_random_module()


def sample_and_plot(distro, num_samples, show_plots=False):
    """
    Draw ``num_samples`` random values from a named distribution, print the
    first ten of them in sorted order, and optionally histogram all of them.

    Parameters
    ----------
    distro : str
        "uniform" draws with ``random.uniform(3, 7)``; "normal" draws with
        ``random.gauss(mu=5, sigma=2)``.
    num_samples : int
        Number of values to draw.
    show_plots : bool
        If True, open a new matplotlib figure with a 50-bin histogram. This
        function does not call ``plt.show()``; the caller decides when to
        display the figures.

    Raises
    ------
    ValueError
        If ``distro`` is not one of the supported names.
    """
    samplers = {
        "uniform": lambda: random.uniform(3, 7),
        "normal": lambda: random.gauss(mu=5, sigma=2),
    }
    # validate up front: an unknown name has no sampler to call
    try:
        draw = samplers[distro]
    except KeyError:
        raise ValueError(
            f"unknown distro {distro!r}; expected one of {sorted(samplers)}"
        ) from None

    # get samples
    samples = [draw() for _ in range(num_samples)]

    # print the first few
    for val in sorted(samples[:10]):
        print(val)
    print()

    # plot a histogram if requested
    if show_plots:
        plt.figure()
        plt.hist(samples, bins=50)
        plt.title(f"Samples from a {distro.title()} Distribution")
        plt.xlabel("Value")
        plt.ylabel("Frequency")
        plt.grid()


def demo_sampling_from_distributions(show_plots=False):
    """
    Run ``sample_and_plot`` with 100,000 samples for "uniform" and then for
    "normal".

    When ``show_plots`` is True, both histogram figures are displayed
    together at the end with a single ``plt.show()``.
    """
    for distro_name in ("uniform", "normal"):
        sample_and_plot(distro_name, num_samples=100_000, show_plots=show_plots)
    if not show_plots:
        return
    plt.show()


# print()
# demo_sampling_from_distributions()


def roll_die():
    """Return one roll of a fair six-sided die: an int from 1 to 6 inclusive."""
    # randrange excludes its upper bound, hence 7
    return random.randrange(1, 7)


def roll_dice(num_dice):
    """Roll ``num_dice`` dice with ``roll_die`` and return the faces, in roll order."""
    return [roll_die() for _ in range(num_dice)]


def simulate_dice(target, num_trials):
    """
    Estimate by simulation the chance that rolling ``len(target)`` dice gives
    exactly the sequence ``target``, and print it next to the exact value.

    Parameters
    ----------
    target : list of int
        Desired faces, in roll order. Compared with ``==`` against the list
        returned by ``roll_dice``. The printed exact probability assumes
        every entry is a valid face (1-6).
    num_trials : int
        Number of simulated rolls; must be positive.

    Returns
    -------
    float
        The estimated probability (fraction of trials that matched).

    Raises
    ------
    ValueError
        If ``num_trials`` is not positive (the estimate would divide by zero).
    """
    if num_trials <= 0:
        raise ValueError(f"num_trials must be positive, got {num_trials}")
    print(f"simulate_dice() for {num_trials:,d} trials")
    num_dice = len(target)
    hits = sum(1 for _ in range(num_trials) if roll_dice(num_dice) == target)
    estimated_prob = hits / num_trials
    # each independent die matches its target face with probability 1/6
    actual_prob = 1 / 6 ** num_dice
    print(f"{estimated_prob = :.8f}")
    print(f"{actual_prob =    :.8f}")
    print()
    return estimated_prob


# print()
# simulate_dice([1] * 2, 1000)
# simulate_dice([1] * 5, 1000)


############################################################
# estimating deterministic quantities
############################################################


def throw_darts(num_darts):
    """
    Sample points uniformly within the 2x2 square centered on the origin.
    Report the proportion (float) of points that fall within the unit circle
    inside the square.

    The circle has area pi and the square has area 4, so the proportion
    approximates pi / 4.

    Raises
    ------
    ValueError
        If ``num_darts`` is not positive (the proportion would divide by zero).
    """
    if num_darts <= 0:
        raise ValueError(f"num_darts must be positive, got {num_darts}")
    in_circle = 0
    for _ in range(num_darts):
        x = random.uniform(-1, 1)
        y = random.uniform(-1, 1)
        # compare the squared distance with 1**2; a square root is unnecessary
        if x * x + y * y <= 1:
            in_circle += 1
    return in_circle / num_darts


def estimate_pi(num_samples):
    """
    Monte Carlo estimate of pi.

    ``throw_darts`` returns the fraction of random points in the 2x2 square
    that land inside the unit circle. That fraction approximates
    (circle area) / (square area) = pi / 4, so scaling by 4 recovers pi.
    """
    fraction_inside = throw_darts(num_samples)
    return fraction_inside * 4


# print()
# print(f"estimate: {estimate_pi(1_000):.8f}")
# print(f"estimate: {estimate_pi(10_000):.8f}")
# print(f"estimate: {estimate_pi(100_000):.8f}")
# print(f"estimate: {estimate_pi(1_000_000):.8f}")
# print(f"estimate: {estimate_pi(10_000_000):.8f}")
# print(f"actual:   {math.pi:.8f}")


def _default_curve(x):
    """Curve integrated when no ``func`` is given: y = (x - 3)**2 + 1."""
    return (x - 3) ** 2 + 1


def _curve_max(func, min_x, max_x, delta=1e-3):
    """Return the largest of 0 and ``func`` sampled every ``delta`` across [min_x, max_x]."""
    best = 0
    num_steps = round((max_x - min_x) / delta)
    for step in range(num_steps + 1):
        # clamp: rounding num_steps up could otherwise step past max_x
        x = min(min_x + delta * step, max_x)
        best = max(best, func(x))
    # always check the right endpoint exactly
    return max(best, func(max_x))


def estimate_integral(min_x, max_x, num_darts=100_000, func=None):
    """
    Estimate the integral of ``func`` over [min_x, max_x] by Monte Carlo
    "dart throwing".

    The bounding rectangle spans [min_x, max_x] horizontally and [0, max_y]
    vertically. ``max_y`` is the largest value of ``func`` found on a grid of
    step 1e-3 across the domain. The estimate is the fraction of uniformly
    random points that land on or under the curve, times the rectangle's
    area.

    Parameters
    ----------
    min_x, max_x : float
        Integration bounds, with min_x <= max_x.
    num_darts : int
        Number of random points; must be positive.
    func : callable, optional
        Function of one number to integrate. It is assumed non-negative on the
        domain. Its true maximum should be close to the largest grid value,
        otherwise the estimate is biased low. Defaults to
        ``(x - 3) ** 2 + 1``.

    Raises
    ------
    ValueError
        If ``num_darts`` is not positive.
    """
    if num_darts <= 0:
        raise ValueError(f"num_darts must be positive, got {num_darts}")
    if func is None:
        func = _default_curve

    # assume the function is non-negative over the domain, so y starts at 0
    min_y = 0
    max_y = _curve_max(func, min_x, max_x)

    # throw darts within rectangle bounded by [min_x, max_x] and [min_y, max_y]
    under_curve = 0
    for _ in range(num_darts):
        x = random.uniform(min_x, max_x)
        y = random.uniform(min_y, max_y)
        if y <= func(x):
            under_curve += 1

    total_area = (max_x - min_x) * (max_y - min_y)
    return under_curve / num_darts * total_area


# print()
# print(estimate_integral(1, 5))
#   y    = (x - 3) ** 2 + 1
#        = x**2     - 6*x    + 10
#   Y    = 1/3*x**3 - 3*x**2 + 10*x + c
#   Y(5) = 1/3*125  - 3*25   + 50   + c     = 16.667 + c
#   Y(1) = 1/3      - 3      + 10   + c     =  7.333 + c
#   Y(5) - Y(1)                             = 9.333


############################################################
# birthday problem
############################################################


def simulate_birthdays(num_people, num_days=365):
    """
    Return a list of ``num_people`` random birthdays. Each birthday is an int
    day-of-year drawn uniformly from 1 to ``num_days``.

    ``random.randint`` includes both endpoints, so ``randint(1, num_days)``
    yields exactly ``num_days`` possible days. The default of 365 is the
    classic birthday-problem model, which ignores leap days.
    """
    return [random.randint(1, num_days) for _ in range(num_people)]


# EXERCISE: implement the has_overlaps() helper (reference solution below)
def has_overlaps(dates):
    """
    Return True if any value occurs more than once in ``dates``, else False.

    Works on any iterable of hashable values and stops at the first repeat.
    An empty input has no overlaps.
    """
    seen = set()
    for date in dates:
        if date in seen:
            return True
        seen.add(date)
    return False


def simulate_birthday_groups(num_people, num_trials=10_000):
    """
    Estimate the probability that at least two of ``num_people`` share a
    birthday.

    Runs ``num_trials`` simulated groups and returns the fraction of them for
    which ``has_overlaps`` reports a repeat.

    Raises
    ------
    ValueError
        If ``num_trials`` is not positive (the fraction would divide by zero).
    """
    if num_trials <= 0:
        raise ValueError(f"num_trials must be positive, got {num_trials}")
    overlaps = sum(
        1
        for _ in range(num_trials)
        if has_overlaps(simulate_birthdays(num_people))
    )
    return overlaps / num_trials


def run_birthday_sims():
    """
    Print the simulated shared-birthday probability for group sizes that are
    powers of two, from 2 up to 128, followed by a blank line.
    """
    group_sizes = [2 ** power for power in range(1, 8)]
    for num_people in group_sizes:
        prob = simulate_birthday_groups(num_people)
        print(f"estimated prob of a shared birthday with {num_people:3d} people: {prob:.5f}")
    print()


# print()
# run_birthday_sims()
