import math
import matplotlib.pyplot as plt
import numpy as np
import random


#############################################################
# review central limit theorem
#############################################################


# Tomorrow's lecture will begin by picking up from Lecture 6, when we
# discussed the Central Limit Theorem. Thus, this pre-lecture code
# reviews CLT by considering the following scenario.

# Suppose you have n=20 coins. Normally, coins should be weighted
# evenly, so they have a 0.5 probability of landing heads or tails when
# flipped. Your n coins, however, were minted sloppily, and some may be
# weighted towards heads or tails more. You don't know what those biases
# are, but you believe for each coin j, the probability Pj it would land
# on heads is between 1/3 and 2/3. Thus, Pj is a random variable that
# follows a uniform distribution on the interval [1/3, 2/3], where j
# ranges from 1 through n.

# Now consider the probability P that if you flip these n coins, they
# would all land on heads. This is simply the product of all the Pj, so
# P is also a random variable. We'd like to estimate P, which comes down
# to determining its distribution and considering its mean and variance.

#   P = P1 * P2 * P3 ... * Pn
#   where Pj ~ Uniform(1/3, 2/3) for j = 1, 2, ..., n

# While the Pj variables are independent and identically distributed,
# P is their product rather than their sum, so unfortunately CLT doesn't
# apply directly. However, a useful trick we can do is to take the log of
# both sides.

#   log P = log P1 + log P2 + ... + log Pn

# If we define the random variables Qj = log Pj and Q = log P, then the
# following relationship holds:

#   Q = Q1 + Q2 + ... + Qn

# Thus, CLT says Q is approximately normally distributed, with its mean
# being n times the mean of Qj, and its variance being n times the
# variance of Qj.

#   mean(Q) = n * mean(Qj)
#   var(Q) = n * var(Qj)

# Unfortunately, the mean and variance of Qj are not simply the log of
# the mean and variance of Pj. To calculate them rigorously, we would
# need to go back to the mathematical definitions of mean and variance
# for continuous distributions, and wield integrals like swords and
# daggers, and at the end of the day there's a giant hole in the middle
# of our paper/blackboard/iPad.

# However, we can simulate our way forward! We know how to sample Pj's
# using random.uniform(), so sampling Qj is just taking the log of those
# samples. If we add up n samples of Qj, we will get a sample for Q. And
# if we get many samples of Q, we should see a normal distribution.


def sample_Pj():
    """Return one coin's heads probability Pj, drawn from Uniform(1/3, 2/3)."""
    lower, upper = 1/3, 2/3
    return random.uniform(lower, upper)


def sample_Qj():
    """Return one sample of Qj = log(Pj), using the natural logarithm."""
    pj = sample_Pj()
    return math.log(pj)


def sample_Q(n):
    """Draw one sample of Q = Q1 + Q2 + ... + Qn.

    Each term is an independent call to sample_Qj(), so the result is the
    log of the probability that all n coins land heads for one random draw
    of the coin biases.

    Args:
        n: number of coins (terms summed). n <= 0 gives the empty sum, 0.

    Returns:
        The sum of n samples of Qj.
    """
    # Sum a generator directly instead of building a throwaway list.
    return sum(sample_Qj() for _ in range(n))


def collect_Q_samples(n, num_samples=10_000):
    """Draw many independent samples of Q for n coins.

    Args:
        n: number of coins per sample of Q.
        num_samples: how many samples of Q to draw. Defaults to 10,000,
            the count this function previously hard-coded.

    Returns:
        A list of num_samples values, each one call to sample_Q(n).
    """
    return [sample_Q(n) for _ in range(num_samples)]


def plot_Q_distribution(n):
    """Histogram sampled values of Q = log P for n coins and print statistics.

    Prints, each with a label:
      - the sample mean and (population, ddof=0) variance of Q;
      - exp(mean(Q)), which equals the geometric mean of the sampled P
        values -- it is NOT the mean of P;
      - exp(var(Q)), shown for reference only -- it is NOT the variance of P.

    Args:
        n: number of coins per sample of Q.
    """
    samples = collect_Q_samples(n)
    plt.figure()
    plt.hist(samples, bins=200)
    plt.title("Sampled distribution of Q")
    plt.xlabel("Q = log P")
    plt.ylabel("count")
    plt.grid()

    # Compute each statistic once and reuse it below.
    mean_q = np.mean(samples)
    var_q = np.var(samples)
    print(f"mean(Q)      = {mean_q}")
    print(f"var(Q)       = {var_q}")
    # exp(mean(log P_i)) is exactly the geometric mean of the P samples,
    # which is never larger than their arithmetic mean (AM-GM inequality).
    print(f"exp(mean(Q)) = {math.exp(mean_q)}  (geometric mean of P, not mean(P))")
    # If Q is approximately normal, P is approximately lognormal, and then
    # var(P) ~= (exp(var(Q)) - 1) * exp(2*mean(Q) + var(Q)).
    print(f"exp(var(Q))  = {math.exp(var_q)}  (not var(P))")


# Simulate and plot Q for the n = 20 coins described above.
plot_Q_distribution(20)


# To get the distribution of P from Q, we simply invert the relationship
# Q = log P and get P = e^Q. (Our logs above are implicitly base e, which
# is what math.log() uses as well.) Complete and run the functions below
# to see what P's distribution looks like. (You'll need to use math.exp()
# on each sample of Q.)


def collect_P_samples(n):
    """Draw samples of P = P1 * P2 * ... * Pn for n coins.

    Each sample of Q = log P from collect_Q_samples(n) is mapped back through
    P = e^Q. math.log is the natural log, so math.exp inverts it.

    Args:
        n: number of coins per sample.

    Returns:
        A list of floats, one per sample returned by collect_Q_samples(n).
    """
    return [math.exp(q) for q in collect_Q_samples(n)]


def plot_P_distribution(n):
    """Histogram sampled values of P for n coins and print its mean and variance.

    Mirrors plot_Q_distribution so the printed statistics of P can be
    compared directly with those printed for Q.

    Args:
        n: number of coins per sample of P.
    """
    samples = collect_P_samples(n)
    plt.figure()
    plt.hist(samples, bins=200)
    plt.title("Sampled distribution of P")
    plt.xlabel("P (probability all coins land heads)")
    plt.ylabel("count")
    plt.grid()

    print(f"mean(P) = {np.mean(samples)}")
    print(f"var(P)  = {np.var(samples)}")


# Simulate and plot P for the same n = 20 coins.
plot_P_distribution(20)


# When you each run the code and compute a mean of P's samples, you
# should get similar results. (And if you all set the same random seed
# in the same way, you would get the exact same results.) In class
# tomorrow, we'll talk about how much variation is reasonable to expect
# between your results, and what sort of claims we can make about that.


############################################################
# show all plots
############################################################


# Display every figure created above with plt.figure().
plt.show()
