"""
Big ideas:
- SD: how noisy individual data points are.
- SE: how noisy our estimate (like the sample mean) is across repeated samples.
- Confidence interval: a range around an estimate, built by a procedure that
  captures the true value a stated % of the time (e.g., 95%) across repeated samples.
- Hypothesis test: “If the null were true, how surprising is my data?”
"""

# Presentation: https://docs.google.com/presentation/d/10j8hVtzN-l4JtiaV6tY60dXX3rRy5PwOQNK26h4w5yE/edit?usp=sharing
# Applets
# - Linear Regression: https://mathlets.org/mathlets/linear-regression/
# - Confidence Intervals: https://mathlets.org/mathlets/confidence-intervals/
# - Hypothesis Testing: https://www.stapplet.com/power.html

# Videos to watch if you are still confused about previous topics:
# - Z Score, Confidence Interval, and Margin of Error: https://www.youtube.com/watch?v=DT-fPG0Hff8
# - Random walks: https://youtu.be/stgYW6M5o4k?si=WqGAkux5KWD1bzEx
# - OOP: https://youtu.be/JeznW_7DlB0?si=IpBOWv4lsZL1CUl7

import numpy as np
import matplotlib.pyplot as plt

"""
Guiding questions:
- What does it mean to "draw a sample" from a population?
- Why does the average of 30 people wiggle less than the value of 1 random person?
- Why is SE ≈ SD / sqrt(n)?
"""

# Seeded generator, so every run of the script reproduces the same numbers.
rng = np.random.default_rng(0)

# Stand-in for "everyone we care about": 100,000 values drawn from a
# normal distribution with mean 50 and SD 10.
population = rng.normal(loc=50, scale=10, size=100_000)

# Experiment settings
n = 30  # individuals in each sample
num_samples = 2000  # number of repeated samples

# TODO 1: Change n to 5, then 200.
#   - How do the empirical SE and SD/sqrt(n) change?
#   - Does the histogram of sample means get wider or narrower?

# One slot per repeated sample; slot i receives the mean of the i-th sample.
sample_means = np.empty(num_samples)
for i in range(num_samples):
    # "Drawing a sample" = picking n individuals at random from the population
    # (replace=False: nobody is picked twice within the same sample).
    sample = rng.choice(population, size=n, replace=False)
    sample_means[i] = sample.mean()

# Empirical SE: the SD of the sample means, i.e. how much the mean bounces
# around from one sample to the next.
empirical_se = sample_means.std(ddof=1)

# Population SD: the spread of the individual values.
pop_sd = population.std(ddof=1)

# Theoretical counterpart: SE(mean) ≈ SD / sqrt(n)
theoretical_se = pop_sd / np.sqrt(n)

print("=== Part 1: SD vs SE and Sampling Distribution ===")
print(f"Population mean ≈ {population.mean():.2f}")
print(f"Population SD   ≈ {pop_sd:.2f}")
print(f"Empirical SE of sample means ≈ {empirical_se:.2f}")
print(f"Theoretical SE (SD/sqrt(n)) ≈ {theoretical_se:.2f}")
print()

# Histogram of the simulated sample means: an empirical picture of the
# sampling distribution. density=True rescales the bars so the total area is 1.
plt.hist(sample_means, bins=30, density=True)
# Dashed vertical line at the mean of the whole population, for reference.
plt.axvline(np.mean(population), linestyle="--", label="True population mean")
# Build the label from n so it stays correct when n is changed for TODO 1
# (it was previously hard-coded as "n=30").
plt.xlabel(f"Sample mean (n={n})")
plt.ylabel("Density")
plt.title("Sampling distribution of the sample mean")
plt.legend()
plt.show()


"""
Goal:
- Connect SE to a confidence interval for the mean.
- See that “95% confidence” is about the long-run success of the method,
  not a magic guarantee for one interval.
"""

# Fresh population for Part 2: 100,000 values, mean 40, SD 15.
population = rng.normal(loc=40, scale=15, size=100_000)
# One "study": 50 individuals picked at random, nobody picked twice.
sample = rng.choice(population, size=50, replace=False)

n = sample.size
sample_mean = sample.mean()
sample_sd = sample.std(ddof=1)
# Standard error of the mean: SD / sqrt(n)
se = sample_sd / np.sqrt(n)

# 95% CI = estimate ± z * SE, using z ≈ 1.96
z = 1.96
margin = z * se
ci_lower, ci_upper = sample_mean - margin, sample_mean + margin

print("=== Part 2: Confidence Intervals ===")
print(f"Sample mean: {sample_mean:.2f}")
print(f"Sample SD:   {sample_sd:.2f}")
print(f"SE (mean):   {se:.2f}")
print(f"95% CI for mean: ({ci_lower:.2f}, {ci_upper:.2f})")
print()

# TODO 2: Change the sample size (size=50 in the rng.choice call above) to 20 and then to 200.
#   - What happens to SE and to the CI width?
#   - How does this connect to SE = SD / sqrt(n)?


def compute_ci(sample, z=1.96):
    """Return the z-based confidence interval ``(lower, upper)`` for the mean.

    The interval is ``mean ± z * SE``, where SE is the sample SD (computed
    with ``ddof=1``) divided by ``sqrt(len(sample))``.

    Args:
        sample: Sequence or 1-D array of numbers.
        z: Multiplier applied to the standard error. The default 1.96 is the
            value this file uses for a 95% interval.

    Returns:
        A ``(lower, upper)`` tuple.
    """
    center = np.mean(sample)
    # Half-width of the interval: z times the standard error of the mean.
    half_width = z * (np.std(sample, ddof=1) / np.sqrt(len(sample)))
    return center - half_width, center + half_width


# The value every interval is checked against: the mean of the whole population.
true_mean = population.mean()

num_intervals = 200
count_contains = 0

for _ in range(num_intervals):
    # Each pass is a new "study": a new sample of 50 and a new CI from it.
    sample = rng.choice(population, size=50, replace=False)
    low, high = compute_ci(sample)
    # A hit adds 1, a miss adds 0.
    count_contains += int(low <= true_mean <= high)

# Share of intervals that captured the true mean, as a percentage.
coverage = 100 * count_contains / num_intervals
print(f"Over {num_intervals} simulated studies:")
print(f"{count_contains} intervals contained the true mean (≈ {coverage:.1f}%)")
print("This is what “95% confidence” means: the *procedure* works ~95% of the time.")
print()


"""
Goal:
- See a p-value as: “If the null were true, how often would I see something
  this extreme or more?”

Setup:
- We observed 65 heads out of 100 flips.
- H0: coin is fair (p = 0.5).
- We simulate many fair-coin experiments and see how often we get
  results as extreme as ours.
"""
# Fresh generator with its own seed, so Part 3's draws do not depend on how
# many random numbers the earlier parts consumed.
rng = np.random.default_rng(2)

n_flips = 100
observed_heads = 65
observed_prop = observed_heads / n_flips

print("=== Part 3: Hypothesis Testing (Coin Example) ===")
print(
    f"Observed: {observed_heads} heads out of {n_flips} flips (p̂ = {observed_prop:.2f})"
)

# Null hypothesis: p0 = 0.5 (fair coin)
p0 = 0.5
num_sims = 50_000

# Simulate all the null-world experiments in one call: each row is one
# experiment of n_flips coin flips under H0 (1 = heads, 0 = tails), so the
# mean of a row is that experiment's proportion of heads.
flips = rng.binomial(n=1, p=p0, size=(num_sims, n_flips))
sim_props = flips.mean(axis=1)


def at_least_as_extreme(sim_dists, obs_dist, tol=1e-9):
    """Return whether each simulated distance is >= the observed distance.

    A plain ``>=`` is unreliable here because the distances are floats. For
    example, ``abs(0.45 - 0.5)`` evaluates to slightly less than
    ``abs(0.55 - 0.5)``, so 45 heads would not count as being as extreme as
    55 heads. Subtracting a small tolerance makes such ties count.

    Args:
        sim_dists: Distance(s) of simulated results from the null value;
            a scalar or a NumPy array.
        obs_dist: Distance of the observed result from the null value.
        tol: Differences smaller than this are treated as ties.

    Returns:
        A boolean (or boolean array, matching ``sim_dists``).
    """
    return sim_dists >= obs_dist - tol


# Two-sided p-value = chance of seeing a proportion this far or farther from 0.5
# In words, the next three lines:
#   1. Measure how far your result is from the null.
#   2. Measure how far each simulated null-world result is from the null.
#   3. Count what fraction of null-world results are at least that far ->
#      that fraction is your estimated p-value.
dist_from_null_obs = abs(observed_prop - p0)
dist_from_null_sims = abs(sim_props - p0)
p_value = np.mean(at_least_as_extreme(dist_from_null_sims, dist_from_null_obs))

print(f"Approximate two-sided p-value: {p_value:.4f}")

alpha = 0.05
if p_value < alpha:
    print(f"p < {alpha} → reject H0 (coin looks biased)")
else:
    print(f"p ≥ {alpha} → fail to reject H0 (no strong evidence of bias)")

print("\nImportant: 'Fail to reject H0' != 'we proved the coin is fair'.")
print("It just means the data are not surprising enough to claim bias at that α.")
print()

# TODO 3: Change observed_heads to 60, then 55, then 80 (out of 100).
#   - How does the p-value change?
#   - For which values would you reject H0 at α = 0.05?
