import numpy as np
import matplotlib.pyplot as plt
############################################################
# TASK 1: UNDERSTANDING KEY CONCEPTS
############################################################

# Presentation: https://docs.google.com/presentation/d/10j8hVtzN-l4JtiaV6tY60dXX3rRy5PwOQNK26h4w5yE/edit?usp=sharing

# Applets
# - Linear Regression: https://mathlets.org/mathlets/linear-regression/
# - Confidence Intervals: https://mathlets.org/mathlets/confidence-intervals/
# - Hypothesis Testing: https://www.stapplet.com/power.html

# Videos if you are still confused on previous topics:
# - Z Score, Confidence Interval, and Margin of Error: https://www.youtube.com/watch?v=DT-fPG0Hff8
# - Random walks: https://youtu.be/stgYW6M5o4k?si=WqGAkux5KWD1bzEx
# - OOP: https://youtu.be/JeznW_7DlB0?si=IpBOWv4lsZL1CUl7

import numpy as np
import matplotlib.pyplot as plt

###########################################################################
# Train, validate, test 
###########################################################################

# Training / Validation / Test Split
#   - **Training data** is used to fit the model parameters.
#   - **Validation data** is used to tune hyperparameters and select the best 
#       model structure.
#   - **Test data** evaluates how well the final model performs on unseen data.

# Overfitting and Generalizability
#   - A model overfits when it memorizes the training data but fails to 
#       generalize to new data.
#   - A generalizable model performs well on both training and test sets.

### Curve Fitting and Loss Functions
# Objective / Loss Functions
#   - Measure how well a model fits the data.
#   - Example: Least Squares Error or Mean Squared Error for linear regression.
#       - MSE = LSE / n

### Why Square the Error?
#   - Squaring emphasizes larger errors and makes every contribution non-negative.
#   - Positive and negative deviations don't cancel each other out.
#   - Downside: the model becomes more sensitive to outliers.

### Linear Regression
#   - Fits a line (or polynomial) to data using a loss function like Least Squares.

#   - **`np.polyfit(x, y, n)`**
#       - Fits a degree `n` polynomial to the `x, y` data.
#   What is `np.polyfit()` often used in conjunction with?
#   - `np.linspace(a, b, n)`
#       - Creates array of n evenly spaced numbers from a to b (inclusive).
#   - `np.polyval(p, x)`
#       - Evaluates the polynomial `p` (represented by a 1-D array of its 
#           coefficients in the order of descending degree) at a value `x`.
#   - `np.mean(x)`, `np.std(x)`, `np.var(x)`
#       - Evaluate the mean, standard deviation, and variance of all of the 
#           elements of `x` respectively.

### Why Not Use Mean-Squared Error To Determine the Goodness of A Curve Fit?
#   - Mean squared error is useful for comparing two different models on
#       the same data.
#   - On its own, it is hard to tell whether an MSE value is good or not –
#       it has no upper bound and depends on the scale (units) of y.

### Coefficient of Determination (R^2)
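#   - R^2 = 1 - (sum (y_i - p_i)^2) / (sum (y_i - mu)^2), where p_i is the
#       model's prediction for point i and mu is the mean of the y_i.
#   - R² = 1: Model explains all variability in the data.
#   - R² = 0: Model is no better than predicting the mean.
#   - R² = 0.5: Model explains 50% of the variability in y; R² alone says
#       nothing about how large the individual prediction errors are.
#   - R² < 0: Model is worse than predicting the mean (possible on
#       validation/test data).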


# Now, let's try fitting a few models:

# degree 1 vs degree 2 vs...degree n
def curve_fitting_model(x=None, y=None, m=1, n=10):
    """Fit and plot least-squares polynomials of every degree from m to n.

    When ``x`` / ``y`` are omitted, synthetic data is used: 20 evenly spaced
    x values on [0, 10] and y = 3x + 7 plus Normal(0, 5) noise.

    Args:
        x: 1-D sequence of x values, or None for the synthetic grid.
        y: 1-D sequence of y values with the same length as ``x``, or None
            to generate noisy linear data from ``x``.
        m: Lowest polynomial degree to fit (inclusive).
        n: Highest polynomial degree to fit (inclusive).

    Returns:
        list of numpy arrays: entry ``k`` holds the coefficients of the
        degree ``m + k`` fit, highest power first (the layout ``np.polyval``
        expects).

    Raises:
        ValueError: if ``x`` is empty or ``x`` and ``y`` differ in shape.
    """
    # Reseed the global RNG so the synthetic noise is reproducible.
    np.random.seed(0)

    # Convert to arrays up front: with a plain list, `3 * x` would repeat the
    # list instead of scaling it, and `x.shape` would not exist.
    x = np.linspace(0, 10, 20) if x is None else np.asarray(x, dtype=float)
    if y is None:
        y = 3 * x + 7 + np.random.normal(0, 5, size=x.shape)
    y = np.asarray(y, dtype=float)
    if x.size == 0:
        raise ValueError("curve_fitting_model needs at least one data point")
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )

    # Sort points so the data are drawn left to right.
    order = np.argsort(x)
    x_sorted = x[order]
    y_sorted = y[order]

    # Evaluate every fit on a dense grid so high-degree curves look smooth
    # rather than being drawn as straight segments between data points.
    x_dense = np.linspace(x_sorted[0], x_sorted[-1], 200)

    plt.scatter(x_sorted, y_sorted, label='Data')

    coefs_degs = []
    for deg in range(m, n + 1):
        coefs = np.polyfit(x_sorted, y_sorted, deg)
        coefs_degs.append(coefs)
        plt.plot(x_dense, np.polyval(coefs, x_dense),
                 label=f'deg {deg}', linestyle='-')

    plt.legend()
    plt.title(f"Curve Fitting: Degrees {m} to {n}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid(True)
    plt.show()

    return coefs_degs

# curve_fitting_model(x=None, y=None, m=3, n=10)


def mse_and_r2s(x_sorted, y_sorted, coefs_degs, m=1, n=10):
    """Score fitted polynomials on a data set with MSE and R^2.

    Args:
        x_sorted: 1-D array-like of x values (order does not affect the
            scores; the name reflects how callers pass the data).
        y_sorted: 1-D array-like of y values, same length as ``x_sorted``.
        coefs_degs: sequence where ``coefs_degs[i - m]`` holds the polynomial
            coefficients (highest power first) for degree ``i``.
        m: Degree of the first entry in ``coefs_degs``.
        n: Degree of the last entry to score (inclusive).

    Returns:
        (mses, r2s): two lists of floats, one entry per degree m..n.

    Raises:
        ValueError: if the data set is empty or x and y differ in shape.
    """
    x = np.asarray(x_sorted, dtype=float)
    y = np.asarray(y_sorted, dtype=float)
    if y.size == 0:
        raise ValueError("cannot score an empty data set")
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )

    # Total sum of squares: the squared error of always predicting the mean.
    ss_tot = np.sum((y - np.mean(y)) ** 2)

    mses = []
    r2s = []
    for deg in range(m, n + 1):
        preds = np.polyval(coefs_degs[deg - m], x)
        ss_res = np.sum((y - preds) ** 2)

        # MSE = (sum of squared errors) / number of points.
        mse = float(ss_res / y.size)

        if ss_tot > 0:
            r2 = float(1 - ss_res / ss_tot)
        else:
            # Constant y makes R^2 = 0/0; use the common convention of
            # 1.0 for a perfect fit and 0.0 otherwise.
            r2 = 1.0 if ss_res == 0 else 0.0

        mses.append(mse)
        r2s.append(r2)
        print(f"deg {deg} has MSE {mse} and R^2 value {r2}")

    return mses, r2s


def run_train_val_test(m=1, n=10, train_frac=0.6, val_frac=0.2):
    """Fit degrees m..n on a training split; pick the best by validation MSE.

    Generates 100 noisy points from y = 3x + 7 on [0, 10] and shuffles
    them. It splits them into train / validation / test sets: the first two
    sizes come from the fractions and the test set gets the remainder. It
    fits every degree on the training set and scores each fit on all three
    sets. It then reports the degree with the lowest validation MSE.

    Args:
        m: Lowest polynomial degree to try (inclusive).
        n: Highest polynomial degree to try (inclusive).
        train_frac: Fraction of points used for training.
        val_frac: Fraction of points used for validation.

    Returns:
        int: the selected degree.

    Raises:
        ValueError: if the fractions are not positive, sum to 1 or more, or
            leave any split empty.
    """
    if train_frac <= 0 or val_frac <= 0 or train_frac + val_frac >= 1:
        raise ValueError(
            "need train_frac > 0, val_frac > 0 and train_frac + val_frac < 1"
        )

    # ----- generate synthetic dataset -----
    np.random.seed(0)
    x = np.linspace(0, 10, 100)
    y = 3 * x + 7 + np.random.normal(0, 5, size=x.shape)

    # ----- shuffle + split -----
    # Shuffle the indices once, then slice them so every point lands in
    # exactly one split.
    perm = np.random.permutation(len(x))
    n_train = round(train_frac * len(x))
    n_val = round(val_frac * len(x))
    split_indices = np.split(perm, [n_train, n_train + n_val])
    if any(len(idx) == 0 for idx in split_indices):
        raise ValueError("fractions leave at least one split empty")

    # Sort each split by x for clean plotting.
    sorted_x_lists = []
    sorted_y_lists = []
    for idx in split_indices:
        x_split, y_split = x[idx], y[idx]
        order = np.argsort(x_split)
        sorted_x_lists.append(x_split[order])
        sorted_y_lists.append(y_split[order])

    # ----- run original function on each split -----
    print("\n=== TRAINING SET ===")
    coefs_degs = curve_fitting_model(sorted_x_lists[0], sorted_y_lists[0], m, n)
    mse_and_r2s(sorted_x_lists[0], sorted_y_lists[0], coefs_degs, m, n)

    print("\n=== VALIDATION SET ===")
    val_mses, _ = mse_and_r2s(sorted_x_lists[1], sorted_y_lists[1], coefs_degs, m, n)
    # Select on data the fits never saw. Training error cannot rise as the
    # degree grows (the models are nested), so it would always favor the
    # most overfit curve.
    best_val_index = int(np.argmin(val_mses))
    best_degree = best_val_index + m
    print("this is best degree: ", best_degree)

    print("\n=== TEST SET ===")
    mse_and_r2s(sorted_x_lists[2], sorted_y_lists[2], coefs_degs, m, n)

    return best_degree


if __name__ == "__main__":
    # Run the train/validation/test demo only when executed as a script, so
    # importing this module does not generate data and open plot windows.
    run_train_val_test(m=1, n=10, train_frac=0.6, val_frac=0.2)


###########################################################################
# Population vs Sample, Central Limit Theorem 
###########################################################################

# Module-level random number generator with a fixed seed; used by
# clt_visualization and confidence_interval_demo.
rng = np.random.default_rng(seed=0)

def clt_visualization():
    """Illustrate the Central Limit Theorem with repeated sampling.

    Draws a synthetic "population" of 100,000 values from Normal(50, 10).
    It then takes 2,000 samples of size 30 (with replacement) and records
    each sample mean. It prints how the spread of those means compares with
    the theoretical standard error sigma / sqrt(n). Finally it plots a
    histogram of the sample means with the population mean marked.

    Uses (and advances) the module-level ``rng``.
    """
    # "True" population: pretend this is everyone
    population = rng.normal(loc=50, scale=10, size=100_000)  # mean 50, sd 10

    # we draw many samples and compute the sample mean each time
    n = 30  # sample size
    num_samples = 2000

    # One row per sample; each row's mean is one sample mean.
    samples = rng.choice(population, size=(num_samples, n), replace=True)
    sample_means = samples.mean(axis=1)

    # Observed spread of the sample means (a sample SD, hence ddof=1).
    empirical_se = np.std(sample_means, ddof=1)

    # Spread of the whole population (a population SD, hence ddof=0).
    # This must come from `population`, not from the sample means.
    pop_sd = np.std(population, ddof=0)
    # CLT: the sample mean's SD is about sigma / sqrt(n).
    theoretical_se = pop_sd / np.sqrt(n)

    print(f"Population mean ≈ {np.mean(population):.2f}")
    print(f"Mean of the sample means ≈ {np.mean(sample_means):.2f}")
    print(f"Population SD ≈ {pop_sd:.2f}")
    print(f"Empirical SE of sample means ≈ {empirical_se:.2f}")
    print(f"Theoretical SE (SD/sqrt(n)) ≈ {theoretical_se:.2f}")

    # code to plot the sample means
    plt.hist(sample_means, bins=30, density=True)
    plt.axvline(np.mean(population), linestyle="--")
    plt.xlabel("Sample mean")
    plt.ylabel("Density")
    plt.title(f"Sampling distribution of mean (n={n})")
    plt.show()


###########################################################################
# Confidence Interval
###########################################################################

def confidence_interval_demo():
    """Build a 95% confidence interval for a mean and check its coverage.

    Draws a synthetic population from Normal(40, 15) and takes one sample of
    50 values. It prints the sample's mean, SD, standard error and 95% CI.
    It then draws 300 fresh samples of the same size and reports how many of
    their intervals contain the population mean (ideally about 95%).

    Uses (and advances) the module-level ``rng``.
    """

    def compute_ci(sample, z=1.96):
        """Return the (lower, upper) z-based CI for the mean of ``sample``."""
        m = np.mean(sample)
        s = np.std(sample, ddof=1)
        se = s / np.sqrt(len(sample))
        return m - z * se, m + z * se

    # population of test scores: mean 40, sd 15
    population = rng.normal(loc=40, scale=15, size=100_000)
    sample = rng.choice(population, size=50, replace=True)
    sample_mean = np.mean(sample)
    sample_sd = np.std(sample, ddof=1)
    n = len(sample)

    # Standard error of the mean. The sample SD stands in for the unknown
    # population SD ("pretending one sample represents the population").
    se = sample_sd / np.sqrt(n)

    # 95% confidence interval: mean +/- 1.96 standard errors.
    z = 1.96
    ci_lower, ci_upper = compute_ci(sample, z)

    print("Confidence Interval ---------------")
    print(f"Sample mean: {sample_mean:.2f}")
    print(f"Sample SD: {sample_sd:.2f}")
    print(f"Standard error: {se:.2f}")
    print(f"95% CI for mean: ({ci_lower:.2f}, {ci_upper:.2f})")

    true_mean = np.mean(population)

    num_intervals = 300
    count_contains = 0

    for _ in range(num_intervals):
        # A fresh sample each time, so every interval is an independent trial.
        new_sample = rng.choice(population, size=n, replace=True)
        lower, upper = compute_ci(new_sample, z)
        if lower <= true_mean <= upper:
            count_contains += 1

    print(f"{count_contains}/{num_intervals} intervals contained the true mean "
        f"≈ {100*count_contains/num_intervals:.1f}%")

###########################################################################
# Statistical significance 
###########################################################################

def coin_flip_demo(*, n_flips=100, observed_heads=65, num_sims=50_000,
                   alpha=0.05, seed=2):
    """Run a simulation-based, one-sided test of whether a coin favors heads.

    H0: the coin is fair (p = 0.5). H1: it is biased towards heads (p > 0.5).
    The demo simulates ``num_sims`` experiments of ``n_flips`` fair flips.
    The p-value estimate is the fraction of simulated experiments whose
    proportion of heads is at least as large as the observed one.

    Args:
        n_flips: Number of flips in the observed experiment.
        observed_heads: Number of heads observed.
        num_sims: Number of simulated experiments under H0.
        alpha: Significance level.
        seed: Seed for this demo's own random generator.

    Returns:
        float: the approximate one-sided p-value.

    Raises:
        ValueError: if ``n_flips`` or ``num_sims`` is not positive, or
            ``observed_heads`` is outside [0, n_flips].
    """
    if n_flips <= 0 or num_sims <= 0:
        raise ValueError("n_flips and num_sims must be positive")
    if not 0 <= observed_heads <= n_flips:
        raise ValueError("observed_heads must be between 0 and n_flips")

    rng = np.random.default_rng(seed)

    # given 65 heads flipped out of 100 what does this say about whether the coin is fair?
    # observed data
    observed_prop = observed_heads / n_flips

    print(f"Observed: {observed_heads} heads out of {n_flips} flips (p̂ = {observed_prop:.2f})")

    # null hypothesis: fair coin, p = 0.5
    # alt hypothesis: biased towards heads
    p0 = 0.5

    # Under H0, the head count of one experiment is Binomial(n_flips, p0):
    # n_flips independent fair flips, counting successes. Draw all simulated
    # experiments in one vectorized call instead of a Python loop.
    sim_heads = rng.binomial(n=n_flips, p=p0, size=num_sims)
    sim_props = sim_heads / n_flips

    # One-sided p-value: probability, under H0, of a proportion at least as
    # far ABOVE 0.5 as the observed one (only the heads direction counts).
    dist_from_null_obs = observed_prop - p0
    dist_from_null_sims = sim_props - p0
    p_value = float(np.mean(dist_from_null_sims >= dist_from_null_obs))

    print(f"Approximate one-sided p-value: {p_value:.4f}")

    # Failing to reject H0 is NOT accepting it. It only means the data are
    # not strong evidence against a fair coin.
    if p_value <= alpha:
        print(f"p <= {alpha} -> reject H0 (coin looks biased)")
    else:
        print(f"p > {alpha} -> fail to reject H0 (no strong evidence of bias)")

    return p_value