import math
import numpy as np
import matplotlib.pyplot as plt
import random


############################################################
# probability distributions
############################################################


def sample_and_plot_uniform(num_samples, discrete=False):
    """Sample a uniform distribution on [2, 11] and plot it against theory.

    Draws ``num_samples`` values with ``random.uniform``, shows them on a new
    figure as a density-normalised 50-bin histogram, and overlays the
    theoretical probability density in red.

    Args:
        num_samples: Number of random values to draw.
        discrete: When true, plot the discrete uniform distribution instead by
            delegating to ``sample_and_plot_uniform_discrete`` (so the flag is
            honoured rather than silently ignored).
    """
    if discrete:
        sample_and_plot_uniform_discrete(num_samples)
        return

    low = 2
    high = 11
    outcomes = [random.uniform(low, high) for _ in range(num_samples)]

    plt.figure()
    plt.hist(outcomes, bins=50, density=True)

    # The density is 1 / (high - low) inside [low, high] and 0 outside; draw
    # it as three flat segments reaching 2 units beyond each end of the range.
    inside_density = 1 / (high - low)
    segments = [
        (low - 2, low, 0),
        (low, high, inside_density),
        (high, high + 2, 0),
    ]
    for start, stop, height in segments:
        x_vals = np.linspace(start, stop)
        y_vals = [height] * len(x_vals)
        plt.plot(x_vals, y_vals, color="red", linewidth=5)

    plt.title("Continuous uniform distribution")
    plt.xlabel("Outcome value")
    plt.ylabel("Probability density")
    plt.grid()


# sample_and_plot_uniform(num_samples=100)
# sample_and_plot_uniform(num_samples=1000)
# sample_and_plot_uniform(num_samples=10_000)


def sample_and_plot_uniform_discrete(num_samples):
    """Sample a discrete uniform distribution on 2..11 and plot it against theory.

    Draws ``num_samples`` integers with ``random.randint`` (both ends
    inclusive), shows their relative frequencies as a histogram on a new
    figure, and overlays the theoretical probability of each outcome as a red
    stem plot.

    Args:
        num_samples: Number of random integers to draw.
    """
    low = 2
    high = 11
    outcomes = [random.randint(low, high) for _ in range(num_samples)]

    plt.figure()
    # One unit-width bin centred on each integer outcome.  With width 1 the
    # density-normalised bar height equals the observed relative frequency,
    # so bars are directly comparable to the stems below.  Default binning
    # (10 bins over the observed min..max) gives 0.9-wide bins that inflate
    # every bar and need not line up with the integers.
    bin_edges = np.arange(low - 0.5, high + 1.5)
    plt.hist(outcomes, bins=bin_edges, density=True)

    # Theoretical mass: equal probability on low..high, zero on the two
    # integers shown on either side of that range.
    x_vals = range(low - 2, high + 3)
    probability = 1 / (high - low + 1)
    y_vals = [probability if low <= x <= high else 0 for x in x_vals]
    plt.stem(x_vals, y_vals, linefmt="r-")

    plt.title("Discrete uniform distribution")
    plt.xlabel("Outcome value")
    plt.ylabel("Probability")
    plt.grid()


# sample_and_plot_uniform_discrete(num_samples=100)
# sample_and_plot_uniform_discrete(num_samples=1000)
# sample_and_plot_uniform_discrete(num_samples=10_000)


def sample_and_plot_normal(num_samples):
    """Sample a normal distribution (mean 5, std. dev. 2) and plot it.

    Draws ``num_samples`` values with ``random.gauss``, shows them on a new
    figure as a density-normalised 50-bin histogram, and overlays the
    theoretical Gaussian density in red.

    Args:
        num_samples: Number of random values to draw.
    """
    # Named mu/sigma so the module-level mean() and stdev() are not shadowed.
    mu = 5
    sigma = 2
    outcomes = [random.gauss(mu, sigma) for _ in range(num_samples)]

    plt.figure()
    plt.hist(outcomes, bins=50, density=True)

    # Evaluate the density over four standard deviations either side of mu.
    grid = np.linspace(mu - 4 * sigma, mu + 4 * sigma)
    peak = 1 / (sigma * math.sqrt(2 * math.pi))
    density = [
        peak * math.exp(-(x - mu) ** 2 / (2 * sigma**2))
        for x in grid
    ]
    plt.plot(grid, density, color="red", linewidth=5)

    plt.title("Normal distribution")
    plt.xlabel("Outcome value")
    plt.ylabel("Probability density")
    plt.grid()


# sample_and_plot_normal(num_samples=100)
# sample_and_plot_normal(num_samples=1000)
# sample_and_plot_normal(num_samples=10_000)


def sample_and_plot_exponential(num_samples):
    """Sample an exponential distribution (rate 0.5) and plot it.

    Draws ``num_samples`` values with ``random.expovariate``, shows them on a
    new figure as a density-normalised 50-bin histogram, and overlays the
    theoretical density ``rate * exp(-rate * x)`` in red.

    Args:
        num_samples: Number of random values to draw.
    """
    rate = 0.5
    outcomes = [random.expovariate(rate) for _ in range(num_samples)]

    plt.figure()
    plt.hist(outcomes, bins=50, density=True)

    # Theoretical curve from 0 out to 10 / rate.
    grid = np.linspace(0, 10 / rate)
    density = [rate * math.exp(-rate * x) for x in grid]
    plt.plot(grid, density, color="red", linewidth=5)

    plt.title("Exponential distribution")
    plt.xlabel("Outcome value")
    plt.ylabel("Probability density")
    plt.grid()


# sample_and_plot_exponential(num_samples=100)
# sample_and_plot_exponential(num_samples=1000)
# sample_and_plot_exponential(num_samples=10_000)


############################################################
# distributions of means
############################################################


def mean(values):
    """Return the arithmetic mean of ``values``.

    ``values`` must support both ``sum()`` and ``len()``; an empty collection
    raises ``ZeroDivisionError``.
    """
    total = sum(values)
    count = len(values)
    return total / count


def variance(values):
    """Return the population variance of ``values``.

    This is the mean of the squared deviations from the mean (divides by
    ``len(values)``, not ``len(values) - 1``).
    """
    mu = mean(values)
    return mean([(item - mu) ** 2 for item in values])


def stdev(values):
    """Return the standard deviation of ``values`` (square root of ``variance``)."""
    spread = variance(values)
    return math.sqrt(spread)


def roll_die():
    """Return one roll of a six-sided die: a random integer from 1 to 6 inclusive."""
    lowest_face, highest_face = 1, 6
    return random.randint(lowest_face, highest_face)


def roll_dice(num_dice):
    """Return a list holding the results of ``num_dice`` calls to ``roll_die``."""
    return [roll_die() for _ in range(num_dice)]


def sample_dice_mean_distro(num_dice, num_trials):
    """Plot the sampled distribution of the mean of ``num_dice`` dice rolls.

    Repeats ``num_trials`` times: roll ``num_dice`` dice and record the mean.
    The recorded means are shown on a new figure as a density-normalised
    histogram.

    Args:
        num_dice: Number of dice averaged in each trial (the sample size).
        num_trials: Number of means to sample.
    """
    samples_of_mean = [mean(roll_dice(num_dice)) for _ in range(num_trials)]

    # The mean can only equal total / num_dice for an integer total between
    # num_dice and 6 * num_dice.  Bin edges are placed halfway between
    # neighbouring attainable values, so no sample ever lands on an edge
    # (where floating-point rounding of the edge would decide its bin and
    # make the histogram jagged).  Every bin spans the same whole number of
    # attainable values, so bar heights stay comparable.
    max_bins = 99  # cap on the number of bins when num_dice is large
    num_totals = 5 * num_dice + 1
    totals_per_bin = -(-num_totals // max_bins)  # ceiling division
    num_bins = -(-num_totals // totals_per_bin)
    bin_edges = (
        num_dice - 0.5 + totals_per_bin * np.arange(num_bins + 1)
    ) / num_dice

    plt.figure()
    plt.hist(samples_of_mean, density=True, bins=bin_edges)

    plt.title(
        f"Sampled distribution of Mean\n"
        f"Sample size = {num_dice}, Num trials = {num_trials:,d}"
    )
    plt.xlabel("Outcome value")
    plt.ylabel("Probability density")
    plt.grid()


# sample_dice_mean_distro(num_dice=1, num_trials=10_000)
# sample_dice_mean_distro(num_dice=2, num_trials=10_000)
# sample_dice_mean_distro(num_dice=3, num_trials=10_000)
# sample_dice_mean_distro(num_dice=5, num_trials=10_000)
# sample_dice_mean_distro(num_dice=10, num_trials=10_000)
# sample_dice_mean_distro(num_dice=20, num_trials=10_000)
# sample_dice_mean_distro(num_dice=100, num_trials=10_000)


def sample_mean_of_distro(distribution, sample_size, num_trials):
    """Plot the sampled distribution of the mean of an arbitrary distribution.

    Repeats ``num_trials`` times: call ``distribution()`` ``sample_size``
    times and record the mean of the results.  The recorded means are shown
    on a new figure as a density-normalised histogram.

    Args:
        distribution: Zero-argument callable returning one random number.
        sample_size: Number of draws averaged in each trial.
        num_trials: Number of means to sample.
    """
    samples_of_mean = [
        mean([distribution() for _ in range(sample_size)])
        for _ in range(num_trials)
    ]

    # Roughly 200 samples per bin.  Clamp to at least one bin: for
    # num_trials < 200 the integer division gives 0, which the histogram
    # rejects as an invalid bin count.
    num_bins = max(1, num_trials // 200)

    plt.figure()
    plt.hist(samples_of_mean, density=True, bins=num_bins)

    plt.title(
        f"Sampled distribution of Mean\n"
        f"Sample size = {sample_size}, Num trials = {num_trials:,d}"
    )
    plt.xlabel("Outcome value")
    plt.ylabel("Probability density")
    plt.grid()


def demo_central_limit_theorem(distribution):
    """Plot the distribution of the sample mean for a series of sample sizes.

    Calls ``sample_mean_of_distro`` once per sample size (1, 2, 3, 5, 10, 20,
    100), each with 10,000 trials, producing one figure per sample size.

    Args:
        distribution: Zero-argument callable returning one random number.
    """
    trials_per_plot = 10_000
    sample_sizes = (1, 2, 3, 5, 10, 20, 100)
    for size in sample_sizes:
        sample_mean_of_distro(distribution, size, trials_per_plot)


def test_uniform():
    """Return one draw from the continuous uniform distribution on [1, 10]."""
    lower, upper = 1, 10
    return random.uniform(lower, upper)


def test_exponential():
    """Return one draw from the exponential distribution with rate 0.5."""
    rate = 0.5
    return random.expovariate(rate)


def test_custom_discrete():
    """Return one value picked uniformly from a fixed multiset of ten integers.

    Repeated entries (2 and 8 appear three times each) are proportionally
    more likely to be returned.
    """
    population = (1, 2, 2, 2, 5, 6, 8, 8, 8, 11)
    return random.choice(population)


# demo_central_limit_theorem(test_uniform)
# demo_central_limit_theorem(test_exponential)
# demo_central_limit_theorem(test_custom_discrete)


############################################################
# show all plots
############################################################


# Display every figure created by whichever demo calls above are uncommented.
plt.show()
