import matplotlib.pyplot as plt
import numpy as np
import random


#############################################################
# evaluating goodness of fit
#############################################################


def read_data(filename):
    """Read two comma-separated numeric columns from a text file.

    Each non-blank line must contain exactly one comma separating an x value
    and a y value, both parseable by ``float``. Blank (whitespace-only)
    lines, such as a stray empty line at the end of the file, are skipped.

    Args:
        filename: Path of the file to read.

    Returns:
        A two-element list ``[x_vals, y_vals]`` of equal-length float lists.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            fields, or a field is not a valid float.
    """
    x_vals, y_vals = [], []
    with open(filename) as file:
        for line in file:
            # Skip empty lines instead of failing to unpack a single field.
            if not line.strip():
                continue
            x_str, y_str = line.split(",")
            x_vals.append(float(x_str))
            y_vals.append(float(y_str))
    return [x_vals, y_vals]


def fit_data(filename, deg=1, plot=True):
    """Fit a least-squares polynomial of degree ``deg`` to data in ``filename``.

    The data are loaded with ``read_data`` and fitted with ``np.polyfit``.
    When ``plot`` is true, a new figure shows the raw points as a red scatter
    together with the fitted curve evaluated on ``np.linspace``'s default
    grid across the x-axis range. Both axis ranges always include 0, and each
    limit is scaled by 1.1 to leave a margin.

    Returns:
        ``[x_vals, y_vals, model]``, where ``model`` is the coefficient array
        returned by ``np.polyfit``.
    """
    x_vals, y_vals = read_data(filename)
    coeffs = np.polyfit(x_vals, y_vals, deg)

    if not plot:
        return [x_vals, y_vals, coeffs]

    # Axis limits include the origin and are stretched by 10%.
    x_bounds = [min(0, min(x_vals)) * 1.1, max(0, max(x_vals)) * 1.1]
    y_bounds = [min(0, min(y_vals)) * 1.1, max(0, max(y_vals)) * 1.1]
    curve_x = np.linspace(x_bounds[0], x_bounds[1])
    curve_y = np.polyval(coeffs, curve_x)

    plt.figure()
    plt.scatter(x_vals, y_vals, color="r")
    plt.plot(curve_x, curve_y, marker="")
    plt.title(f"Data from {filename}, fit degree = {deg}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(x_bounds)
    plt.ylim(y_bounds)
    plt.grid()

    return [x_vals, y_vals, coeffs]


# fit_data("mystery1.csv", deg=1)
# fit_data("mystery2.csv", deg=2)
# fit_data("mystery2.csv", deg=1)
# fit_data("mystery2-dense.csv", deg=2)
# fit_data("mystery2-dense.csv", deg=1)


def sse(y_vals, y_preds):
    """Return the sum of squared errors between observations and predictions.

    Args:
        y_vals: Observed values (a sized sequence such as a list or array).
        y_preds: Predicted values, one per observation.

    Returns:
        The sum of ``(y - p) ** 2`` over paired elements; 0 for empty inputs.

    Raises:
        ValueError: If the two sequences have different lengths, rather than
            silently ignoring unmatched predictions.
    """
    if len(y_vals) != len(y_preds):
        raise ValueError(
            f"length mismatch: {len(y_vals)} values vs "
            f"{len(y_preds)} predictions"
        )
    return sum((y - p) ** 2 for y, p in zip(y_vals, y_preds))


def mse(y_vals, y_preds):
    """Return the mean squared error, i.e. ``sse`` divided by the number of observations.

    Empty input raises ZeroDivisionError because there is no mean to take.
    """
    total_error = sse(y_vals, y_preds)
    return total_error / len(y_vals)


def variance(y_vals):
    """Return the population variance of ``y_vals``.

    This is computed as the MSE of a constant predictor that always guesses
    the mean. It is the average squared deviation from the mean, dividing by
    n rather than n - 1.
    """
    n = len(y_vals)
    mean_prediction = [sum(y_vals) / n] * n
    return mse(y_vals, mean_prediction)


def r_squared(y_vals, y_preds):
    """Return the coefficient of determination R^2 for ``y_preds``.

    The formula is R^2 = 1 - MSE(y_vals, y_preds) / variance(y_vals). A
    perfect fit gives 1, predicting the mean gives 0, and worse models give
    negative values. When ``y_vals`` is constant (zero variance), the
    division fails or yields inf/nan, depending on the operand types.
    """
    unexplained = mse(y_vals, y_preds)
    total = variance(y_vals)
    return 1 - unexplained / total


def eval_fit(filename, deg=1, metric=sse, verbose=True):
    """Fit a degree-``deg`` polynomial to ``filename`` and score it with ``metric``.

    The model is fitted with ``fit_data`` (without plotting) and evaluated on
    the same points it was fitted to, so the score is an in-sample measure.

    Args:
        filename: File readable by ``read_data``.
        deg: Polynomial degree passed through to ``fit_data``.
        metric: Callable ``metric(y_vals, y_preds)``, typically ``sse``,
            ``mse`` or ``r_squared``.
        verbose: If true, print the metric's name and value.

    Returns:
        ``[x_vals, y_vals, model, goodness]``.
    """
    x_vals, y_vals, model = fit_data(filename, deg=deg, plot=False)
    y_preds = np.polyval(model, x_vals)
    goodness = metric(y_vals, y_preds)

    if verbose:
        if metric is sse:
            metric_name = "SSE"
        elif metric is mse:
            metric_name = "MSE"
        elif metric is r_squared:
            metric_name = "R^2"
        else:
            # Any other callable previously left metric_name unbound
            # (UnboundLocalError); fall back to its own name instead.
            metric_name = getattr(metric, "__name__", repr(metric))
        print(metric_name, goodness)

    return [x_vals, y_vals, model, goodness]


def eval_metrics():
    """Demo driver that scores the mystery datasets under each listed metric.

    Only ``sse`` is listed and every ``eval_fit`` call below is commented
    out, so as written this prints one blank line per metric. To run the
    comparison, uncomment the calls and optionally add more metrics to the
    list.
    """
    metrics_to_compare = [sse]
    for metric in metrics_to_compare:
        print()  # blank separator before each metric's results
        # eval_fit("mystery1.csv", deg=0, metric=metric)
        # eval_fit("mystery1.csv", deg=1, metric=metric)
        # eval_fit("mystery2.csv", deg=2, metric=metric)
        # eval_fit("mystery2.csv", deg=1, metric=metric)
        # eval_fit("mystery2-dense.csv", deg=2, metric=metric)
        # eval_fit("mystery2-dense.csv", deg=1, metric=metric)
        # eval_fit("mystery2-double.csv", deg=2, metric=metric)
        # eval_fit("mystery2-double.csv", deg=1, metric=metric)


# eval_metrics()


#############################################################
# training and validation
#############################################################


def vary_degree(filename, degrees, plot=False):
    """Fit polynomials of several degrees to one dataset and print R^2 for each.

    Each degree is fitted via ``eval_fit`` and scored with ``r_squared`` on
    the same points it was fitted to. With ``plot=True`` all fitted curves
    are overlaid on a single new figure, each labelled with its degree and
    R^2.
    """
    if plot:
        plt.figure()

    print()
    for deg in degrees:
        # The names `deg` and `r2` matter: the `=` f-string specifiers
        # below print them literally.
        x_vals, y_vals, model, r2 = eval_fit(
            filename, deg=deg, metric=r_squared, verbose=False
        )
        print(f"{deg = }, {r2 = }")
        if not plot:
            continue

        # Axis limits include the origin and are stretched by 10%.
        x_bounds = [min(0, min(x_vals)) * 1.1, max(0, max(x_vals)) * 1.1]
        y_bounds = [min(0, min(y_vals)) * 1.1, max(0, max(y_vals)) * 1.1]
        curve_x = np.linspace(x_bounds[0], x_bounds[1])
        curve_y = np.polyval(model, curve_x)

        plt.scatter(x_vals, y_vals, color="r")
        plt.plot(curve_x, curve_y, marker="", label=f"{deg=}, {r2=:.5f}")
        plt.xlim(x_bounds)
        plt.ylim(y_bounds)

    if not plot:
        return
    plt.title(f"Fitting polynomials on data from {filename}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid()
    plt.legend()


# vary_degree("mystery2.csv", degrees=range(10))
# vary_degree("mystery2-dense.csv", degrees=range(10))

# vary_degree("mystery2.csv", degrees=range(5), plot=True)
# vary_degree("mystery2.csv", degrees=range(1, 11, 3), plot=True)


def split_data(x_vals, y_vals, method="alternate", seed=None):
    """Split paired data into a training half and a validation half.

    Two strategies are available:

    * ``"alternate"`` (default): even-indexed points go to training and
      odd-indexed points go to validation. This split is deterministic.
    * ``"random"``: shuffle the (x, y) pairs together, then split at the
      midpoint. Training gets ``len // 2`` points and validation gets the
      rest.

    Args:
        x_vals: Sequence of x values (sliceable, e.g. a list).
        y_vals: Sequence of y values, assumed to be the same length as
            ``x_vals``.
        method: Either ``"alternate"`` or ``"random"``.
        seed: Used only by ``"random"``. If given, shuffling uses a private
            ``random.Random(seed)``, so results are reproducible without
            reseeding the global generator. If None, the module-level
            ``random`` generator is used.

    Returns:
        ``[[x_train, y_train], [x_validate, y_validate]]``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == "alternate":
        training = [x_vals[::2], y_vals[::2]]
        validation = [x_vals[1::2], y_vals[1::2]]
        return [training, validation]

    if method == "random":
        # Shuffle x and y together so each pair stays matched.
        pairs = list(zip(x_vals, y_vals))
        if seed is None:
            random.shuffle(pairs)
        else:
            random.Random(seed).shuffle(pairs)

        mid = len(pairs) // 2
        train_pairs, validate_pairs = pairs[:mid], pairs[mid:]

        # Regroup the pairs into separate x and y lists.
        x_train = [x for x, _ in train_pairs]
        y_train = [y for _, y in train_pairs]
        x_validate = [x for x, _ in validate_pairs]
        y_validate = [y for _, y in validate_pairs]
        return [[x_train, y_train], [x_validate, y_validate]]

    raise ValueError(f"unknown split method: {method!r}")


def train_validate(filename, degrees):
    """Compare polynomial degrees by fitting on one half of the data and
    scoring on the other.

    The data are split with ``split_data``. Each model is fitted only on the
    training half, and its R^2 is computed only on the validation half. All
    fitted curves are overlaid on one new figure whose legend reports each
    degree's validation R^2. Training points are drawn as red dots and
    validation points as red crosses.
    """
    x_vals, y_vals = read_data(filename)
    [x_train, y_train], [x_validate, y_validate] = split_data(x_vals, y_vals)

    plt.figure()
    # Limits cover the full dataset, always include the origin, and are
    # stretched by 10%.
    xlim = [min(0, min(x_vals)) * 1.1, max(0, max(x_vals)) * 1.1]
    ylim = [min(0, min(y_vals)) * 1.1, max(0, max(y_vals)) * 1.1]
    # The evaluation grid does not depend on the degree, so build it once.
    x_model = np.linspace(xlim[0], xlim[1])

    # Draw the data once rather than once per degree, and label each half
    # so the legend can tell training points from validation points.
    plt.scatter(x_train, y_train, color="r", label="training")
    plt.scatter(
        x_validate, y_validate, color="r", marker="x", label="validation"
    )

    for deg in degrees:
        # KEY DIFFERENCE: fit model only on training data set
        model = np.polyfit(x_train, y_train, deg)
        y_model = np.polyval(model, x_model)

        # KEY DIFFERENCE: evaluate model only on validation data set
        y_pred = np.polyval(model, x_validate)
        r2 = r_squared(y_validate, y_pred)

        plt.plot(x_model, y_model, marker="", label=f"{deg=}, {r2=:.5f}")

    plt.title(f"Fitting polynomials on data from {filename}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.grid()
    plt.legend()


# train_validate("mystery2.csv", [1, 2, 3, 4, 5])
# train_validate("mystery2-dense.csv", [1, 2, 3, 4, 5])
# train_validate("mystery2-dense.csv", range(1, 11, 3))


############################################################
# random seed
############################################################


def demo_random_seed():
    """Show that reseeding ``random`` replays the same pseudo-random sequence.

    For each seed in 1, 2 and 3, the generator is reseeded at the start of
    each of three trials, and three ``random.random()`` draws are printed.
    As a result, every trial under the same seed prints identical numbers.
    """
    for seed in (1, 2, 3):
        print()
        for trial in range(3):
            # Reseeding restarts the generator from the same internal state.
            random.seed(seed)
            print()
            print(f"{seed = }, {trial = }")
            for _ in range(3):
                print(random.random())


# demo_random_seed()


# https://docs.python.org/3/library/random.html
# https://en.wikipedia.org/wiki/Mersenne_Twister


############################################################
# show all plots
############################################################


if __name__ == "__main__":
    # Show the figures only when this file runs as a script. Importing it to
    # reuse the fitting helpers then does not open (or block on) plot windows.
    plt.show()
