#!/usr/bin/env python3

from collections import Counter
import random
from typing import Dict, List
import matplotlib.pyplot as plt
import numpy as np

############################################################
# Functions as parameters
############################################################

def filtered_list(orig_list, test):
    """Return a new list of the elements of orig_list for which test(element) is truthy.

    The input is not modified and the original element order is preserved.
    """
    return [item for item in orig_list if test(item)]

def test_odd(x):
    return x%2 == 1
def test_even(x):
    return x%2 == 0

# Sample data for the filtering examples: the integers 1 through 5.
l = list(range(1, 6))

# print(filtered_list(l, test_odd))
# print(filtered_list(l, test_even))

# # The code below is equivalent, but uses lambda to create the test function inline, without naming it
# print(filtered_list(l, lambda x: x%2 == 1))


def add_all(n, fns):
    """Call every function in fns with n and return the sum of the results.

    Returns 0 when fns is empty. Results are accumulated left to right.
    """
    total = 0
    for fn in fns:
        total = total + fn(n)
    return total

def f(x):
    """Return x cubed."""
    return x ** 3


def g(x):
    """Return x doubled."""
    return x * 2


def h(x):
    """Return x multiplied by -2."""
    return -2 * x


# The function objects themselves (not their results), stored in a list.
list_of_functions = [f, g, h]

# print (add_all(2, list_of_functions))

############################################################
# Function as output
############################################################


def make_specialized_filtering_function(test):
    """Build and return a one-argument filtering function bound to `test`.

    The returned function takes a list and returns a new list of the elements
    that `filter(test, ...)` keeps (so `test=None` keeps the truthy elements).
    """
    def specialized(orig_list):
        return list(filter(test, orig_list))

    return specialized

# example usage:
# Keep only strictly positive values. NOTE: this rebinds the module-level name
# `f`, replacing the cube function defined earlier (list_of_functions still
# holds that original function object).
f = make_specialized_filtering_function(lambda value: value > 0)

# print(f([-2, 0, 3, 5, -1]))  # Output: [3, 5]

############################################################
# Comprehension
############################################################

# Odd numbers below 20, their squares, and the same squares built in one pass.
l = [k for k in range(20) if k % 2 == 1]
l2 = [k * k for k in l]
l3 = [k ** 2 for k in range(20) if k % 2 == 1]

# print(l3)

########################################
# Warm-ups:
# 1. Write a list comprehension that pairs each name in a list with its
#    length as a tuple, producing, e.g.:
#       [("Jean-Luc", 8), ("William", 7), ...]
# 2. Combine first and last names into full-name strings.
#    a. Don't worry at first about whether to include a space between
#       first and last names.
#    b. Challenge: Adjust your code to handle "Worf" and "Data".
# 3. Print all full-name strings, each on its own line, using a
#    one-line comprehension.


first_names = [
    "Jean-Luc", "William", "Geordi", "Worf", "Beverly", "Deanna", "Data",
]
last_names = [
    "Picard", "Riker", "La Forge", ", son of Mogh", "Crusher", "Troi", "",
]
num_names = len(first_names)
# Names are paired by position, so both lists must have the same length.
assert len(last_names) == num_names


# Task 1
# Comprehension version: pair each first name with its length.
names_with_lengths = [(first, len(first)) for first in first_names]

# The same list built with an explicit loop for comparison; this rebinds
# names_with_lengths to an equal list, so the final value is unchanged.
names_with_lengths = []
for name in first_names:
    names_with_lengths.append((name, len(name)))

# print(names_with_lengths)


# Task 2
# Naive version: always joins with exactly one space, which gives
# "Worf , son of Mogh" and a trailing space in "Data ".
full_names = [" ".join((first_names[i], last_names[i])) for i in range(num_names)]
# print(full_names)


# Handle special last names
def combine(first, last):
    """Join a first and last name, inserting a space only when appropriate.

    No separator is used when `last` is empty (e.g. "Data") or begins with a
    comma (e.g. ", son of Mogh"); otherwise a single space separates them.
    """
    needs_space = bool(last) and last[0] != ","
    return first + (" " if needs_space else "") + last

# Apply combine() pairwise across the two equal-length name lists.
full_names = list(map(combine, first_names, last_names))
# print(full_names)

# More compact form using a conditional expression: A if condition else B
# Also, using zip() to avoid explicit indexing
full_names = [
    first + ("" if not last or last[0] == "," else " ") + last
    for first, last in zip(first_names, last_names)
]
# print(full_names)


# Task 3
# [print(name) for name in full_names]

############################################################
# RANDOM WALK SIMULATION CODE
############################################################

import random
from math import hypot
import matplotlib.pyplot as plt

# ---------- Locations (tuples) ----------
def make_location(x, y):
    """Create a location, represented as an immutable (x, y) tuple."""
    return x, y

def move(loc, dx, dy):
    """Return a new location offset from loc by (dx, dy); loc itself is not changed."""
    old_x, old_y = loc
    return old_x + dx, old_y + dy

def get_x(loc):
    """Return the x coordinate (first element) of a location."""
    x_coord = loc[0]
    return x_coord

def get_y(loc):
    """Return the y coordinate (second element) of a location."""
    y_coord = loc[1]
    return y_coord

def dist(loc_a, loc_b):
    """Return the Euclidean (straight-line) distance between two locations."""
    delta_x = loc_a[0] - loc_b[0]
    delta_y = loc_a[1] - loc_b[1]
    return hypot(delta_x, delta_y)

def loc_str(loc):
    """Format a location as '<x, y>', converting each coordinate with str()."""
    return f"<{get_x(loc)!s}, {get_y(loc)!s}>"



# ---------- Field (dict of {drunk_id: location}) ----------
def make_field():
    """Return a new, empty field: a dict mapping drunk ids to their locations."""
    return dict()

def add_drunk(field, drunk_id, loc):
    """Place a new drunk at loc in field (mutates field in place).

    Raises:
        ValueError: if drunk_id is already present in field.
    """
    if drunk_id not in field:
        field[drunk_id] = loc
        return
    raise ValueError("Duplicate drunk")

def get_loc(field, drunk_id):
    """Return the current location of drunk_id in field.

    Raises:
        ValueError: if drunk_id is not in field.
    """
    if drunk_id in field:
        return field[drunk_id]
    raise ValueError("Drunk not in field")

def move_drunk(field, drunk_id, step_fn):
    """Advance drunk_id by one step obtained from step_fn(), updating field in place.

    step_fn must return a (dx, dy) pair. The membership check happens before
    step_fn is called.

    Raises:
        ValueError: if drunk_id is not in field.
    """
    if drunk_id not in field:
        raise ValueError("Drunk not in field")
    dx, dy = step_fn()
    old_loc = field[drunk_id]
    field[drunk_id] = move(old_loc, dx, dy)

# ---------- “Drunk” step functions ----------
def usual_step():
    """Return a unit step in one of the four compass directions, chosen uniformly."""
    directions = ((0, 1), (0, -1), (1, 0), (-1, 0))
    return random.choice(directions)

def masochist_step():
    """Return a random step that drifts toward +y.

    North steps have length 1.1 and south steps 0.9; east/west steps are unit
    length. Each of the four options is equally likely.
    """
    options = ((0.0, 1.1), (0.0, -0.9), (1.0, 0.0), (-1.0, 0.0))
    return random.choice(options)

def liberal_step():
    """Return a random step that drifts toward -x.

    East steps have length 0.9 and west steps 1.1; north/south steps are unit
    length. Each of the four options is equally likely.
    """
    options = ((0.0, 1.0), (0.0, -1.0), (0.9, 0.0), (-1.1, 0.0))
    return random.choice(options)

def conservative_step():
    """Return a random step that drifts toward +x.

    East steps have length 1.1 and west steps 0.9; north/south steps are unit
    length. Each of the four options is equally likely.
    """
    options = ((0.0, 1.0), (0.0, -1.0), (1.1, 0.0), (-0.9, 0.0))
    return random.choice(options)

def liberal_masochist_step():
    """Return one step taken by either liberal_step or masochist_step, each with probability 1/2."""
    # Pick which tendency governs this step, then take that step.
    chosen_step = random.choice((liberal_step, masochist_step))
    return chosen_step()

def corner_step():
    """Return a diagonal step toward one of the four corners, chosen uniformly.

    0.71 approximates 1/sqrt(2), so each step has length about 1.004
    (0.71 * sqrt(2)) rather than exactly 1.
    """
    d = 0.71
    return random.choice([(d, d), (d, -d), (-d, d), (-d, -d)])

def continuous_step():
    """Return (dx, dy) with each component drawn independently via random.uniform(-1, 1)."""
    dx = random.uniform(-1, 1)
    dy = random.uniform(-1, 1)
    return (dx, dy)

# ---------- Simulation helpers ----------
def walk(field, drunk_id, step_fn, num_steps):
    """Move drunk_id num_steps times using step_fn and return the net displacement.

    field is updated in place. The return value is the straight-line distance
    between the drunk's location before and after the walk. Raises ValueError
    if drunk_id is not in field.
    """
    origin = get_loc(field, drunk_id)
    for _step in range(num_steps):
        move_drunk(field, drunk_id, step_fn)
    final = get_loc(field, drunk_id)
    return dist(origin, final)

def walk_with_viz(field, drunk_id, step_fn, num_steps, pause=0.0001, axis_max=30, clr='black'):
    """Same as walk(), but animates the path on the current matplotlib axes.

    Draws a small marker at the drunk's starting location, then extends a
    single line as each step is taken, pausing `pause` seconds after every
    step so the figure can refresh. Both axes are fixed to
    [-axis_max, axis_max].

    Args:
        field: dict of {drunk_id: location}; updated in place.
        drunk_id: key of the drunk to move. ValueError if absent.
        step_fn: zero-argument function returning a (dx, dy) step.
        num_steps: number of steps to take.
        pause: seconds passed to plt.pause after each step.
        axis_max: half-width of the fixed plotting window.
        clr: matplotlib colour for the marker and the path.

    Returns:
        Straight-line distance between the start and end locations.
    """
    start = get_loc(field, drunk_id)
    xs = [get_x(start)]
    ys = [get_y(start)]

    plt.xlim(-axis_max, axis_max)
    plt.ylim(-axis_max, axis_max)
    # Mark the walk's actual starting point.
    plt.plot(xs[0], ys[0], 'o', color=clr, markersize=3)

    # Create one line artist and extend it in place; re-plotting the whole
    # path on every step would add a new artist per step (quadratic cost).
    (path_line,) = plt.plot(xs, ys, linewidth=1, color=clr)

    for _ in range(num_steps):
        move_drunk(field, drunk_id, step_fn)
        loc = get_loc(field, drunk_id)
        xs.append(get_x(loc))
        ys.append(get_y(loc))
        path_line.set_data(xs, ys)
        # plt.pause redraws the (now stale) figure, shows it without blocking
        # and runs the GUI event loop for `pause` seconds.
        plt.pause(pause)

    return dist(start, get_loc(field, drunk_id))

# random.seed(250)
# f = make_field()
# origin = make_location(0, 0)
# steps=350
# add_drunk(f, "Homer1", origin)
# print("Starting simulating Homer:")
# print("Distance:", walk_with_viz(f, "Homer1", usual_step, steps, clr='blue'))

# print("Starting simulating Bart:")
# add_drunk(f, "Bart1", origin)
# print("Distance:", walk_with_viz(f, "Bart1", liberal_step, steps, clr='green'))

# add_drunk(f, "Homer2", origin)
# print("Starting simulating Homer:")
# print("Distance:", walk_with_viz(f, "Homer2", usual_step, steps, clr='cornflowerblue'))

# print("Starting simulating Bart:")
# add_drunk(f, "Bart2", origin)
# print("Distance:", walk_with_viz(f, "Bart2", liberal_step, steps, clr='springgreen'))

# add_drunk(f, "Homer3", origin)
# print("Starting simulating Homer:")
# print("Distance:", walk_with_viz(f, "Homer3", usual_step, steps, clr='navy'))

# print("Starting simulating Bart:")
# add_drunk(f, "Bart3", origin)
# print("Distance:", walk_with_viz(f, "Bart3", liberal_step, steps, clr='darkgreen'))


def sim_walks(num_steps, num_trials, step_fn):
    """Run num_trials independent walks of num_steps steps each.

    Every trial uses a fresh field holding one drunk ("Homer") placed at the
    origin. Returns the list of final distances from the origin, each rounded
    to one decimal place.
    """
    start = make_location(0, 0)

    def one_trial():
        # A brand-new field per trial keeps the walks independent.
        field = make_field()
        add_drunk(field, "Homer", start)
        return round(walk(field, "Homer", step_fn, num_steps), 1)

    return [one_trial() for _ in range(num_trials)]

def drunk_test(walk_lengths, num_trials, step_fn, step_name="step_fn"):
    """Simulate walks of each length in walk_lengths and print summary statistics.

    For every length, runs sim_walks(length, num_trials, step_fn) and prints
    a header line, the mean final distance (rounded to 4 places), and the
    maximum and minimum final distances.
    """
    for length in walk_lengths:
        results = sim_walks(length, num_trials, step_fn)
        print(step_name, "random walk of", length, "steps")
        mean_dist = round(sum(results) / len(results), 4)
        print(" Mean =", mean_dist)
        print(" Max =", max(results), "Min =", min(results))


# random.seed(0)
# drunk_test((10, 100, 1000, 10000), 100, usual_step, step_name="Usual")


############################################################
# PageRank via Random Walk Simulation
############################################################

def teleport(nodes: List[str]) -> str:
    """Jump to a node chosen uniformly at random from nodes.

    Raises IndexError (from random.choice) if nodes is empty.
    """
    destination = random.choice(nodes)
    return destination


def step_from(current: str, graph: Dict[str, List[str]], nodes: List[str], damping: float) -> str:
    """Take one step of the PageRank random walk from `current`.

    With probability `damping`, follow a uniformly chosen out-link of
    `current`. Otherwise, or when `current` has no out-links, teleport to a
    uniformly chosen node from `nodes`. graph is only looked up when a link is
    actually to be followed.
    """
    if random.random() < damping:
        outs = graph[current]
        if outs:
            return random.choice(outs)
    # Either the walker chose to teleport, or it is stuck on a dangling node.
    return teleport(nodes)


def simulate_pagerank_random_walk(
    graph: Dict[str, List[str]],
    steps: int = 200_000,
    damping: float = 0.85,
    seed: int = 42,
) -> Dict[str, float]:
    """Estimate PageRank by simulating one long random walk (Monte Carlo).

    Reseeds the global `random` module with `seed`, starts at a uniformly
    random node (not counted as a visit), and takes `steps` steps via
    step_from(). Each node's rank is the fraction of steps that landed on it.

    Args:
        graph: adjacency lists mapping a node to the nodes it links to. Nodes
            that only appear as link targets are treated as dangling (no
            out-links). The mapping is not modified.
        steps: number of walk steps to simulate.
        damping: probability of following a link instead of teleporting.
        seed: value passed to random.seed() for reproducibility.

    Returns:
        Mapping from every node (in sorted order) to its estimated rank,
        normalized to sum to 1. If steps <= 0, every node gets 1/len(nodes).
        An empty graph yields an empty dict.
    """
    random.seed(seed)

    # Collect every node: link sources plus anything that appears as a target.
    node_set = set(graph)
    for outs in graph.values():
        node_set.update(outs)
    nodes = sorted(node_set)
    if not nodes:
        # Nothing to rank; random.choice below would fail on an empty list.
        return {}

    # Local adjacency map so the caller's graph is left untouched; target-only
    # nodes get an empty (dangling) out-list so step_from can look them up.
    adjacency = {u: graph.get(u, []) for u in nodes}

    current = random.choice(nodes)
    visits = Counter()
    for _ in range(steps):
        current = step_from(current, adjacency, nodes, damping)
        visits[current] += 1

    counted_steps = max(steps, 0)
    if counted_steps == 0:
        return {u: 1.0 / len(nodes) for u in nodes}

    # Visit frequencies; renormalize to absorb any floating-point drift.
    ranks = {u: visits[u] / counted_steps for u in nodes}
    total = sum(ranks.values())
    if total > 0:
        for u in ranks:
            ranks[u] /= total
    return ranks


def pretty_print_ranks(ranks: Dict[str, float]) -> None:
    """Print one "node  rank" line per node, highest rank first, then the total.

    Node names are left-aligned to the longest name and ranks use 6 decimal
    places; ties keep the mapping's iteration order. An empty mapping prints
    only the "(Sum = ...)" line.
    """
    items = sorted(ranks.items(), key=lambda kv: kv[1], reverse=True)
    # default=0 keeps an empty mapping from crashing max().
    width = max((len(node) for node, _ in items), default=0)
    for node, rank in items:
        print(f"{node:<{width}}  {rank:.6f}")
    print(f"(Sum = {sum(ranks.values()):.6f})")


def example_graph() -> Dict[str, List[str]]:
    """Return a fresh small directed graph for demonstrating the PageRank simulation.

    F has no incoming links and a single outgoing link; G has no outgoing
    links (a dangling node).
    """
    return dict(
        A=["B", "C"],
        B=["C"],
        C=["A", "D", "E", "G"],
        D=["C", "G"],
        E=["D"],
        F=["C"],  # F points to only one node
        G=[],     # G is a dangling node
    )


# graph = example_graph()
# ranks = simulate_pagerank_random_walk(
#     graph=graph,
#     steps=1_000_000,
#     damping=0.85,
#     seed=123
# )
# print("Estimated PageRank (Random Walk):")
# pretty_print_ranks(ranks)
