import math
import matplotlib.pyplot as plt
import numpy as np
import random


############################################################
# random seed
############################################################


def demo_random_seed():
    """Print three random draws per trial to show that a seed fixes the sequence.

    For each seed in (1, 2, 3) the generator is re-seeded at the start of each
    of three trials, so all three trials of a given seed print the same
    numbers. Reseeds the module-level ``random`` generator as a side effect.
    """
    for seed in (1, 2, 3):
        print()
        for trial in range(3):
            # Re-seeding restarts the pseudo-random sequence from the same point.
            random.seed(seed)
            draws = [random.random() for _ in range(3)]
            print()
            print(f"{seed = }, {trial = }")
            print(*draws, sep="\n")


# demo_random_seed()


# https://docs.python.org/3/library/random.html
# https://en.wikipedia.org/wiki/Mersenne_Twister


#############################################################
# confidence intervals
#############################################################



def throw_darts(num_darts):
    """Return the fraction of ``num_darts`` random darts that hit the unit circle.

    Each dart is a point drawn uniformly from the square [-1, 1] x [-1, 1]
    using the module-level ``random`` generator; a dart counts as a hit when
    its distance from the origin is at most 1.

    Raises ZeroDivisionError when ``num_darts`` is 0.
    """

    def lands_inside():
        # x is drawn before y so the random stream is consumed in a fixed order.
        x = random.uniform(-1, 1)
        y = random.uniform(-1, 1)
        return math.sqrt(x**2 + y**2) <= 1

    hits = sum(lands_inside() for _ in range(num_darts))
    return hits / num_darts


def estimate_pi(num_samples):
    """Return a Monte Carlo estimate of pi based on ``num_samples`` darts.

    The unit circle has area pi and the enclosing square [-1, 1] x [-1, 1]
    has area 4, so the hit fraction from ``throw_darts`` approximates pi / 4.
    """
    inside_fraction = throw_darts(num_samples)
    return 4 * inside_fraction


# print()
# print(f"estimate: {estimate_pi(1_000):.8f}")
# print(f"estimate: {estimate_pi(10_000):.8f}")
# print(f"estimate: {estimate_pi(100_000):.8f}")
# print(f"estimate: {estimate_pi(1_000_000):.8f}")


def plot_confidence_intervals(experiment):
    """Plot estimates of pi with approximate 95% confidence intervals.

    ``experiment`` is called once for each sample size (100 up to 1,000,000)
    with the sample size as its only argument. The interval computation
    assumes the returned estimate equals 4 times an observed Bernoulli success
    fraction, as ``estimate_pi`` produces; other experiments would need a
    different variance formula.

    A new figure is created with the estimates, symmetric error bars of
    +/- 2 standard errors, and a red reference line at the true value of pi.
    The figure is not shown here; the caller is expected to call ``plt.show()``.
    """
    sample_sizes = [100, 1000, 10_000, 100_000, 1_000_000]
    estimates = [experiment(num_samples) for num_samples in sample_sizes]

    # each dart has a p = pi/4 probability of landing in circle
    # each dart's distribution is called Bernoulli
    #   https://en.wikipedia.org/wiki/Bernoulli_distribution
    # variance of Bernoulli distribution is p * (1 - p)
    # the estimate is 4 * (hit fraction), so its per-dart variance is
    #   16 * p * (1 - p)
    # standard error is sqrt(variance / num darts)
    half_widths = []
    for estimate, num_samples in zip(estimates, sample_sizes):
        p = estimate / 4
        var = p * (1 - p) * 16
        sem = math.sqrt(var / num_samples)
        # +/- 2 standard errors is roughly a 95% interval
        half_widths.append(2 * sem)

    # errorbar() expects non-negative error magnitudes; a 1-D sequence means
    # symmetric bars. The previous version built a 2-row list with
    # ``[[...]] * 2``, which aliases one inner list twice, so the negative
    # lower bounds it wrote were silently overwritten by the upper bounds.
    plt.figure()
    plt.axhline(y=math.pi, color="red")
    plt.errorbar(sample_sizes, estimates, yerr=half_widths, linewidth=3)
    plt.title("95% confidence intervals on estimating $\\pi$")
    plt.xlabel("Sample size")
    plt.xscale("log")
    plt.grid()


# plot_confidence_intervals(estimate_pi)


#############################################################
# statistical significance
#############################################################


# Observed values for the beer group, apparently one number per participant
# (the quantity measured and its units are not stated in this file).
beer_group = [
    27, 20, 21, 26, 27, 31, 24, 21, 20, 19, 23, 24,
    28, 19, 24, 29, 18, 20, 17, 31, 20, 25, 28, 21, 27,
]

# Observed values for the water group, in the same form as beer_group.
water_group = [
    21, 22, 15, 12, 21, 16, 19, 15, 22,
    24, 19, 23, 13, 22, 20, 24, 18, 20,
]


def study_results(beer, water, verbose=False):
    """Return the mean of ``beer`` minus the mean of ``water``.

    Both arguments are non-empty sequences of numbers (an empty one raises
    ZeroDivisionError). When ``verbose`` is true, a summary with the group
    sizes, both means and their difference is printed as well.
    """
    beer_avg = sum(beer) / len(beer)
    water_avg = sum(water) / len(water)
    gap = beer_avg - water_avg

    if not verbose:
        return gap

    report = (
        "Study results",
        f"  Number of beer participants:  {len(beer)}",
        f"  Number of water participants: {len(water)}",
        f"  Beer group mean:     {beer_avg:.4f}",
        f"  Water group mean:    {water_avg:.4f}",
        f"  Difference in means: {gap:.4f}",
    )
    for line in report:
        print(line)
    return gap


# print()
# study_results(beer_group, water_group, verbose=True)


def permutation_test(beer, water, num_permutations, seed=0):
    """Estimate a one-sided p-value for the beer/water difference in means.

    The two groups are pooled and reshuffled ``num_permutations`` times; each
    shuffle is split back into groups of the original sizes. The p-value is
    the fraction of shuffles whose difference in means is at least as large
    as the observed one (ties count).

    Reseeds the module-level ``random`` generator with ``seed``, prints a
    summary, and returns the p-value. Raises ZeroDivisionError when
    ``num_permutations`` is 0.
    """
    random.seed(seed)

    observed_diff = study_results(beer, water)
    pool = beer + water  # fresh list, so shuffling never touches the inputs
    n_beer = len(beer)
    n_water = len(water)

    def shuffled_diff():
        # Relabel participants at random, keeping the original group sizes.
        random.shuffle(pool)
        return sum(pool[:n_beer]) / n_beer - sum(pool[n_beer:]) / n_water

    at_least_as_large = sum(
        shuffled_diff() >= observed_diff for _ in range(num_permutations)
    )
    p_value = at_least_as_large / num_permutations

    print("Permutation test")
    print(f"  Observed difference: {observed_diff:.4f}")
    print(f"  Number of sampled permutations:       {num_permutations:,}")
    print(f"  Number of *more extreme* differences: {at_least_as_large:,}")
    print(f"  p-value: {p_value:.6f}")

    return p_value


def run_permutations(beer, water):
    """Run ``permutation_test`` with increasing numbers of permutations.

    Each run uses the fixed seed 6100 and is preceded by a blank line.
    """
    # TODO: translate beer values down, so results are less significant

    permutation_counts = (1000, 10_000, 100_000, 500_000)
    for how_many in permutation_counts:
        print()
        permutation_test(beer, water, how_many, seed=6100)
        # TODO: try different seeds


# run_permutations(beer_group, water_group)


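#############################################################
# cancer cluster simulation
#############################################################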



def find_prob(num_regions, region, cases_per_sim, multiple, num_sims):
    """Return the fraction of simulations in which ``region`` stays at or below the threshold.

    Each simulation assigns ``cases_per_sim`` cases to ``num_regions`` regions
    uniformly at random (module-level ``random`` generator). The threshold is
    ``multiple`` times the expected number of cases per region. Note that the
    return value is one minus the fraction of simulations in which the region
    exceeded the threshold.

    Raises ZeroDivisionError when ``num_regions`` or ``num_sims`` is 0.
    """
    expected_per_region = cases_per_sim / num_regions
    cutoff = expected_per_region * multiple
    last_region = num_regions - 1

    exceed_count = 0
    for _ in range(num_sims):
        # tally how many cases land in each region
        tally = [0] * num_regions
        for _ in range(cases_per_sim):
            tally[random.randint(0, last_region)] += 1

        # a True comparison adds 1, a False one adds 0
        exceed_count += tally[region] > cutoff

    return 1 - exceed_count / num_sims


def simulate_cancer_clusters():
    """Simulate how often one region shows 30% more cases than expected by chance.

    Three years of 36,000 cases per year are spread at random over
    10_000 // 10 = 1000 regions, and region 111 is checked against 1.3 times
    the expected count using ``find_prob``. Results are printed; nothing is
    returned. Reseeds the module-level ``random`` generator with 0.
    """
    # scenario
    years = 3
    yearly_cases = 36_000
    total_cases = yearly_cases * years

    region_area = 10
    state_area = 10_000
    region_count = state_area // region_area
    excess_factor = 1.3  # because 30% more cases than expected
    watched_region = 111

    # simulation settings
    random.seed(0)
    trial_count = 1
    # trial_count = 10
    sims_per_trial = 20

    print("Cancer cluster simulation")
    probs = []
    for t in range(trial_count):
        print("  Starting trial", t)
        outcome = find_prob(
            region_count, watched_region, total_cases,
            excess_factor, sims_per_trial,
        )
        probs.append(outcome)
        # find_prob returns the "not exceeded" fraction, so invert it here
        print(f"    Est. prob. of being a random event = {1 - outcome:.4f}")
    if trial_count > 1:
        print(f"  Standard deviation of trials = {np.std(probs):.4f}")


# print()
# simulate_cancer_clusters()


############################################################
# show all plots
############################################################


# Display whatever figures the demos above have created. With every demo call
# commented out, no figure exists when this line is reached.
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
plt.show()
