"""
Sampling from the Stars
=======================

Using real astronomical data from the HYG star catalog (~120,000 stars).

Goal: Estimate population parameters from a sample.

Topics:
- Sampling variability
- Central Limit Theorem  
- Confidence intervals
"""

import pandas as pd                  # for data handling
import numpy as np                   # for numerical operations
import matplotlib.pyplot as plt      # for plotting
from scipy import stats              # for statistical functions

# Seed NumPy's global RNG so every run of the script draws the same "random" samples
np.random.seed(seed=67)

# helper function for plotting
def clean_ax(ax):
    """Hide the top and right border lines (spines) of *ax* for a lighter look."""
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

# =============================================================================
# Load the Data
# =============================================================================

# Load the star database (HYG catalog CSV, one row per star)
stars = pd.read_csv('data/hyg_v42.csv')

# Map apparent magnitude onto a linear "brightness score" where higher means
# brighter. Magnitudes run backwards (smaller = brighter), hence the negative
# slope: mag 0 -> 100, mag 6 -> 40. This is a simple teaching scale, not a
# physical luminosity (real flux is exponential in magnitude).
brightness = 100 - 10 * stars['mag'].values

# Population parameters. np.std's default ddof=0 is the right divisor here
# because we hold the entire population, not a sample of it.
true_mean = np.mean(brightness)
true_std = np.std(brightness)

print(f"Population: {len(stars):,} stars | μ = {true_mean:.1f} | σ = {true_std:.1f}")
print("First 10 brightness values:", brightness[:10])

# =============================================================================
# Sampling Function
# =============================================================================

def take_sample(n=30, data=None, z=2):
    """Draw a simple random sample and summarize it.

    Parameters
    ----------
    n : int
        Sample size. Values are drawn without replacement, so ``n`` must not
        exceed the population size; ``n >= 2`` is needed for a finite ``s``.
    data : array-like, optional
        Population to sample from. Defaults to the module-level
        ``brightness`` array.
    z : float
        Critical value for the interval. The default of 2 gives the
        approximate 95% CI used throughout this tutorial (exact: 1.96).

    Returns
    -------
    tuple
        ``(sample, mean, std, se, ci)``. ``std`` is the sample standard
        deviation *s* (``ddof=1``), ``se = s / sqrt(n)`` is the estimated
        standard error, and ``ci = (mean - z*se, mean + z*se)``.
    """
    population = brightness if data is None else data
    sample = np.random.choice(population, size=n, replace=False)
    mean = np.mean(sample)
    # Bessel's correction (divide by n - 1): the default ddof=0 would
    # understate s, shrink the SE and make the CIs systematically too narrow.
    std = np.std(sample, ddof=1)
    se = std / np.sqrt(n)  # Standard Error
    ci = (mean - z * se, mean + z * se)  # approximate 95% CI when z = 2
    return sample, mean, std, se, ci

# Quick demo: one sample of 30 stars and its summary statistics
sample, mean, std, se, ci = take_sample(n=30)
print(
    f"Sample: n={len(sample)} | x̄ = {mean:.1f} | s = {std:.1f} | "
    f"SE={se:.1f} | 95% CI=({ci[0]:.1f}, {ci[1]:.1f})"
)

# =============================================================================
# The Night Sky
# =============================================================================
# Here's what we're working with - thousands of stars, each with a measurable brightness.

fig, ax = plt.subplots(figsize=(12, 6), facecolor='black')
ax.set_facecolor('black')

# Marker area grows with each star's flux relative to a magnitude-6 star.
# Magnitudes are logarithmic (1 mag ≈ 2.512× in flux), so bright stars get far
# bigger dots; areas are clipped to [0.1, 25] so the brightest don't swamp the plot.
relative_flux = 10 ** ((6 - stars['mag']) / 2.5)
marker_area = np.clip(0.5 * relative_flux, 0.1, 25)
# RA is multiplied by 15 so it lands on the 0–360° axis labelled below
# (presumably the catalog stores RA in hours: 24 h × 15 = 360°).
ax.scatter(stars['ra'] * 15, stars['dec'], s=marker_area, c='white',
           alpha=0.8, edgecolors='none')

# Highlight a few well-known bright stars:
# name -> (RA in degrees, Dec in degrees, label x-offset, label y-offset in points)
landmarks = {
    'Sirius': (101.3, -16.7, 8, -12),
    'Vega': (279.2, 38.8, 8, 5),
    'Betelgeuse': (88.8, 7.4, 8, 5),
}
for star_name, (ra_deg, dec_deg, dx, dy) in landmarks.items():
    ax.scatter(ra_deg, dec_deg, s=80, c='gold', marker='*',
               edgecolors='white', linewidths=0.3, zorder=5)
    ax.annotate(star_name, (ra_deg, dec_deg), xytext=(dx, dy),
                textcoords='offset points', color='gold', fontsize=9,
                fontweight='bold')

# Full-sky limits plus white-on-black labels, ticks and borders
ax.set_xlim(0, 360)
ax.set_ylim(-90, 90)
ax.set_xlabel('Right Ascension (degrees)', color='white', fontsize=12)
ax.set_ylabel('Declination', color='white', fontsize=12)
ax.set_title(f'The Night Sky: {len(stars):,} Observable Stars',
             color='white', fontsize=14, fontweight='bold')
ax.tick_params(colors='white')
for border in ax.spines.values():
    border.set_color('white')
plt.tight_layout()
plt.show()

# =============================================================================
# The Question
# =============================================================================
# What is the average brightness of all these stars?
# We *could* measure every single star... but telescope time is expensive!
# What if we only have time to measure 30 stars?

# Population distribution: every star's brightness score, with the true mean
# (dashed red line) and a shaded ±1σ band around it.
fig, ax = plt.subplots(figsize=(10, 4))

ax.hist(brightness, bins=500, density=True, alpha=0.7,
        color='steelblue', edgecolor='white')

ax.axvline(true_mean, color='red', linewidth=2, linestyle='--',
           label=f'True Mean = {true_mean:.1f}')
band_lo, band_hi = true_mean - true_std, true_mean + true_std
ax.axvspan(band_lo, band_hi, alpha=0.15, color='orange',
           label=f'±1 σ = {true_std:.1f}')

ax.set(xlabel='Luminosity (higher = brighter)', ylabel='Density', xlim=(0, 80),
       title=f'Population: All {len(stars):,} Observable Stars')
ax.legend(fontsize=10, loc='upper left')
clean_ax(ax)
plt.tight_layout()
plt.show()

# Reminder: μ and σ are exactly what a real survey would be trying to estimate.
print(f"(In real life, we wouldn't know μ = {true_mean:.1f} or σ = {true_std:.1f}!)")

# =============================================================================
# Taking a Sample
# =============================================================================
# Sample 30 random stars. Re-run just this section to see sampling variability
# (re-running the whole script re-seeds the RNG and reproduces the same sample).

sample_size = 30

# Draw one sample; the your_* results are reused by later sections
# (the variability plot and the "Your 95% Confidence Interval" plot).
your_sample, your_mean, your_std, your_se, your_ci = take_sample(sample_size)

fig, ax = plt.subplots(figsize=(10, 4))

# Faint population histogram as a backdrop
ax.hist(brightness, bins=500, density=True, alpha=0.3, color='gray', label='Population')

# The sample's own histogram on the same density scale. The label must be an
# f-string so the legend shows the actual n rather than the literal "{sample_size}".
ax.hist(your_sample, bins=15, density=True, alpha=0.7, color='steelblue',
        edgecolor='white', label=f'Your Sample (n={sample_size})')

ax.axvline(true_mean, color='gray', linewidth=2, linestyle='--', label=f'True Mean = {true_mean:.1f}')
ax.axvline(your_mean, color='red', linewidth=2, label=f'Your Estimate = {your_mean:.1f}')
ax.set(xlabel='Luminosity', ylabel='Density', xlim=(0, 80), title='Your Sample vs Population')
ax.legend(fontsize=10, loc='upper left')
clean_ax(ax)
plt.tight_layout()
plt.show()

print(f"Sample: mean = {your_mean:.1f}, std = {your_std:.1f}, error = {your_mean - true_mean:+.1f}")

# =============================================================================
# Sampling Variability
# =============================================================================
# 40 astronomers each sample 30 stars. How much do their estimates vary?

sample_size = 30
n_trials = 40

# Each "astronomer" draws an independent sample and reports only its mean
sample_means = [take_sample(sample_size)[1] for _ in range(n_trials)]

fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(sample_means, range(n_trials), s=60, alpha=0.7, c='steelblue', edgecolors='white')
ax.axvline(true_mean, color='red', linewidth=2, linestyle='--', label=f'True Mean = {true_mean:.1f}')
# Astronomers occupy rows 0..n_trials-1, so row n_trials places your estimate
# just above them; deriving it (and the title) from n_trials keeps the plot
# correct if the number of trials is changed.
ax.scatter([your_mean], [n_trials], s=120, c='orange', marker='*', zorder=5,
           label=f'YOU = {your_mean:.1f}')
ax.set(xlabel='Estimated Mean', ylabel='Astronomer #',
       title=f'{n_trials} Astronomers, {n_trials} Different Estimates')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, axis='x')
clean_ax(ax)
plt.tight_layout()
plt.show()

# The spread of the sample means is an empirical estimate of SE = σ / √n
print(f"Spread of estimates (std): {np.std(sample_means):.1f}  ← This is the Standard Error!")

# =============================================================================
# Central Limit Theorem
# =============================================================================
# Central Limit Theorem: for large enough n the sample mean is approximately
# normally distributed around μ, with Standard Error SE = σ / √n.
# So as n grows, the histograms below get more bell-shaped AND narrower.

sample_sizes = [5, 15, 30, 100]
n_trials = 500
fig, axes = plt.subplots(1, 4, figsize=(12, 3))

# Shared x-range (μ ± 15) for the theory curves and for every panel's limits
x_lo, x_hi = true_mean - 15, true_mean + 15
x = np.linspace(x_lo, x_hi, 100)

for ax, n in zip(axes, sample_sizes):
    # Simulated sampling distribution: the means of n_trials fresh samples of size n
    means = [take_sample(n)[1] for _ in range(n_trials)]
    ax.hist(means, bins=25, density=True, alpha=0.7, color='steelblue', edgecolor='white')

    # CLT prediction: Normal(μ, σ/√n), using the known population σ
    se = true_std / np.sqrt(n)
    ax.axvline(true_mean, color='red', linewidth=2, linestyle='--')
    ax.plot(x, stats.norm.pdf(x, true_mean, se), 'k-', linewidth=2)
    ax.set(title=f'n={n}, SE={se:.1f}', xlabel='Sample Mean', xlim=(x_lo, x_hi))

axes[0].set_ylabel('Density')
fig.suptitle('Central Limit Theorem: Larger Samples → Tighter Clustering', fontweight='bold')
plt.tight_layout()
plt.show()

print("Black curve = theoretical prediction (SE = σ/√n) | Blue = simulation")

# =============================================================================
# Confidence Intervals
# =============================================================================
# Since our estimate is uncertain, we report a **range** instead of a single number.
#
# The 95% CI formula:
#   CI = x̄ ± 2 × SE = x̄ ± 2 × (s / √n)
#
# A 95% confidence interval means: if we repeated this process many times,
# about 95% of our intervals would contain the true mean.
#
# (We use 2 instead of 1.96 for simplicity — it's close enough!)

# =============================================================================
# Where Does "±2" Come From?
# =============================================================================
# The normal distribution has a special property called the 68-95-99.7 rule:
# - 68% of values fall within ±1 standard deviation
# - 95% of values fall within ±2 standard deviations  
# - 99.7% of values fall within ±3 standard deviations
#
# Since we want 95% confidence, we use ±2 standard errors. (The exact normal value is 1.96;
# ±2 actually covers about 95.4%, which is close enough for our purposes.)

# Visualize the 68-95-99.7 rule on a standard normal curve
fig, ax = plt.subplots(figsize=(10, 4))
x = np.linspace(-4, 4, 1000)
y = stats.norm.pdf(x)
ax.plot(x, y, 'k-', linewidth=2)

# (half-width in σ, colour, shading alpha, legend text). The widest band comes
# first so the narrower, more opaque bands are drawn on top of it.
rule_bands = [
    (3, 'orange', 0.15, '99.7% (±3σ)'),
    (2, 'green', 0.2, '95% (±2σ)'),
    (1, 'blue', 0.3, '68% (±1σ)'),
]
for k, colour, shade, text in rule_bands:
    ax.fill_between(x, y, where=(x >= -k) & (x <= k), alpha=shade,
                    color=colour, label=text)

# Dashed boundary lines at -k and +k for each band
for k, colour, _, _ in rule_bands:
    for z in (-k, k):
        ax.axvline(z, color=colour, linestyle='--', linewidth=1, alpha=0.7)

ax.set(xlabel='Standard Deviations from Mean', ylabel='Density',
       title='The 68-95-99.7 Rule', xlim=(-4, 4))
ax.legend(loc='upper right')
clean_ax(ax)
plt.tight_layout()
plt.show()

print("68% within ±1σ | 95% within ±2σ | 99.7% within ±3σ")

# =============================================================================
# Your 95% Confidence Interval
# =============================================================================

ci_low, ci_high = your_ci
fig, ax = plt.subplots(figsize=(10, 2.5))

# The interval as a horizontal bar, with the point estimate marked on it
ax.barh(0, ci_high - ci_low, left=ci_low, height=0.4, color='steelblue',
        alpha=0.6, label='95% CI')
ax.scatter([your_mean], [0], color='steelblue', s=120, zorder=5,
           label=f'Estimate: {your_mean:.1f}')
ax.axvline(true_mean, color='red', linewidth=2, linestyle='--',
           label=f'True Mean: {true_mean:.1f}')

# Numeric endpoints printed just below the bar
for bound in (ci_low, ci_high):
    ax.text(bound, -0.35, f'{bound:.1f}', ha='center')

ax.set(xlim=(true_mean-12, true_mean+12), ylim=(-0.6, 0.5), yticks=[], xlabel='Luminosity')
ax.legend(loc='upper right')

# Green check if the interval captured μ, red cross otherwise
hit = ci_low <= true_mean <= ci_high
if hit:
    verdict, verdict_colour = '✓ CI contains true mean!', 'green'
else:
    verdict, verdict_colour = '✗ CI missed!', 'red'
ax.set_title(verdict, color=verdict_colour, fontweight='bold')
clean_ax(ax)
plt.tight_layout()
plt.show()

# =============================================================================
# Coverage Check
# =============================================================================
# 100 CIs. Expect ~95 to contain the true mean.

fig, ax = plt.subplots(figsize=(10, 7))
n_trials = 100
sample_size = 30

# Draw n_trials independent CIs and count how many capture the true mean
hits = 0
for i in range(n_trials):
    _, mean, _, se, ci = take_sample(sample_size)
    contains = ci[0] <= true_mean <= ci[1]
    hits += int(contains)  # keep the counter a plain Python int
    # Misses are drawn red and thicker so they stand out
    color = 'steelblue' if contains else 'red'
    ax.plot(ci, [i, i], color=color, linewidth=1.5 if contains else 2.5, alpha=0.7)
    ax.scatter([mean], [i], color=color, s=15, zorder=5)

ax.axvline(true_mean, color='black', linewidth=2, linestyle='--', label=f'True Mean = {true_mean:.1f}')
ax.set(xlabel='Luminosity', ylabel='Sample #',
       title=f'{n_trials} CIs: {hits} contain true mean ({n_trials-hits} missed)')
ax.legend(loc='upper right')
clean_ax(ax)
plt.tight_layout()
plt.show()

# hits / n_trials is a fraction; the ':.0%' format multiplies by 100 and
# appends the percent sign (e.g. 0.95 -> "95%").
print(f"Expected ~95%, observed {hits / n_trials:.0%}")

# =============================================================================
# Sample Size and CI Width
# =============================================================================
# Key insight: Larger samples give narrower (more precise) confidence intervals.
#
# The CI width is: 2 × 2 × SE = 4 × (s / √n)
# (The plots below plug in the known population σ for s to show the idealized width.)
#
# Notice the √n in the denominator — to cut your CI width in half, 
# you need 4× as many observations!

# Doubling sample sizes; theoretical 95% CI width = 2 × 2 × SE with SE = σ/√n
sample_sizes = np.array([10, 20, 40, 80, 160, 320, 640, 1280])
ci_widths = 4 * true_std / np.sqrt(sample_sizes)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: how the width shrinks as n grows
ax1.plot(sample_sizes, ci_widths, 'o-', color='steelblue', linewidth=2, markersize=8)
ax1.set(xlabel='Sample Size (n)', ylabel='95% CI Width', title='CI Width vs Sample Size')
clean_ax(ax1)

# Right panel: the intervals themselves, centred on μ, for n growing 4× each row
palette = ['#ff6b6b', '#ffd93d', '#6bcb77', '#4d96ff']
for row, (n, colour) in enumerate(zip([10, 40, 160, 640], palette)):
    width = 4 * true_std / np.sqrt(n)
    half = width / 2
    ax2.barh(row, width, left=true_mean - half, height=0.6, alpha=0.7, color=colour)
    ax2.text(true_mean + half + 0.5, row, f'n={n}', va='center', fontsize=10)

ax2.axvline(true_mean, color='black', linestyle='--', linewidth=2)
ax2.set(xlabel='Luminosity', yticks=[], title='CI Width Comparison',
        xlim=(true_mean-15, true_mean+15))
clean_ax(ax2)
plt.tight_layout()
plt.show()

print("To halve CI width: need 4x the sample size")

# =============================================================================
# Summary
# =============================================================================
# 1. Samples vary — different samples give different estimates
# 2. Standard Error: SE = s / √n
# 3. Larger n → smaller SE → tighter estimates
# 4. 95% CI: x̄ ± 2 × SE
# 5. 95% coverage — expect ~5% of CIs to miss
