import matplotlib.pyplot as plt
import random
import time


############################################################
# designing __eq__() when attributes contain objects of the same type
############################################################


class Animal:
    """An animal with an age (shown in years by __str__) and an optional name."""

    def __init__(self, age, name=None):
        self.age, self.name = age, name

    def __str__(self):
        return f"<Animal {self.name} {self.age} years>"

    def get_age_diff(self, other):
        """
        Return the non-negative difference between the ages of self and other.

        other must be an Animal (checked with an assert).
        """
        assert isinstance(other, Animal)
        gap = other.age - self.age
        return abs(gap)


class Rabbit(Animal):
    """An Animal with two recorded parents and a unique sequential tag.

    A Rabbit is either a founder (both parents None) or the offspring of two
    Rabbits, created with ``parent1 + parent2``. Two Rabbits compare equal
    when they have the same parents, in either order.
    """

    next_tag = 1  # class-wide counter: the tag given to the next Rabbit

    def __init__(self, age, parent1=None, parent2=None):
        # Either both parents are given or neither is.
        assert (
            parent1 is None and parent2 is None
            or parent1 is not None and parent2 is not None
        )
        super().__init__(age)
        # A tuple rather than a set: a set would collapse r + r to one parent.
        self.parents = parent1, parent2
        self.tag = Rabbit.next_tag
        Rabbit.next_tag += 1

    def __str__(self):
        return f"<Rabbit {self.tag:>03}>"

    __repr__ = __str__

    def __add__(self, other):
        """Make a new Rabbit offspring of self and other."""
        return Rabbit(0, self, other)

    def _parent_tags(self):
        """Return the parents' tags as a 2-tuple, using None for a missing parent."""
        return tuple(None if p is None else p.tag for p in self.parents)

    def __eq__(self, other):
        """Return True if and only if self and other have the same parents.

        Parents are compared by tag, which __init__ makes unique per Rabbit.
        This avoids recursing into __eq__ through the whole ancestry, and it
        also handles founders whose parents are None. Parent order is
        ignored, so ``a + b == b + a``.
        """
        if not isinstance(other, Rabbit):
            return NotImplemented
        mine = self._parent_tags()
        theirs = other._parent_tags()
        return mine == theirs or mine == theirs[::-1]

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None (unhashable).
        # Hash the unordered set of parent tags so equal Rabbits hash equally.
        return hash(frozenset(self._parent_tags()))

    # Earlier attempts at __eq__, kept for reference. Neither is installed
    # as __eq__; each failed when it was.

    def __eq__v1(self, other):
        # As __eq__, comparing the parent tuples called __eq__ recursively on
        # self.parents[0] and other.parents[0]. It raised AttributeError once
        # a Rabbit was compared with None (None has no .parents).
        # A plain `return self.parents == other.parents` also misses swapped parents.
        return (
            self.parents == other.parents
            or self.parents == other.parents[::-1]
        )

    def __eq__v2(self, other):
        # This almost works, but fails when self's or other's parents are
        # None (None has no .tag).
        self_parent_tags = [p.tag for p in self.parents]
        other_parent_tags = [p.tag for p in other.parents]
        return (
            self_parent_tags == other_parent_tags
            or self_parent_tags == other_parent_tags[::-1]
        )


def demo_rabbit_equality():
    """Build a small Rabbit family tree and print equality checks.

    Each output line has the same text an f"{expr = }" debug print would give:
    the expression, " = ", then repr() of its value. The expected result of
    each comparison is noted beside it.
    """

    def show(label, value):
        # Equivalent to print(f"{<label> = }"): the "=" form uses repr() by default.
        print(f"{label} = {value!r}")

    # Founders (no parents).
    print()
    r1 = Rabbit(age=3)
    r2 = Rabbit(age=4)
    r3 = Rabbit(age=5)
    for label, rabbit in (("r1", r1), ("r2", r2), ("r3", r3)):
        show(label, rabbit)

    # First offspring.
    print()
    show("r1.parents", r1.parents)
    r4 = r1 + r2
    show("r4", r4)
    show("r4.parents", r4.parents)

    # The same two parents, added in opposite orders.
    print()
    r5 = r3 + r4
    show("r5", r5)
    show("r5.parents", r5.parents)

    print()
    r6 = r4 + r3
    show("r6", r6)
    show("r6.parents", r6.parents)

    print()
    show("(r5 == r6)", r5 == r6)  # should be True
    show("(r4 == r6)", r4 == r6)  # should be False
    show("(r1 == r2)", r1 == r2)  # should be True

    # `+` is left-associative: r8 = (r1 + r2) + r3, so its first parent is
    # a new, unnamed offspring of r1 and r2.
    print()
    r8 = r1 + r2 + r3
    show("r8", r8)
    show("r8.parents", r8.parents)
    show("r8.parents[0].parents", r8.parents[0].parents)


# demo_rabbit_equality()


############################################################
# model random walks with classes and inheritance
############################################################


def mean(vals):
    """Return the arithmetic mean of the numbers in *vals*.

    *vals* may be any finite iterable (list, tuple, generator, ...). It is
    materialized once, so generators (which have no len()) work too.

    Raises:
        ZeroDivisionError: if *vals* is empty (the mean is undefined).
    """
    values = list(vals)
    return sum(values) / len(values)


def split_xy(points):
    """Split (x, y) pairs into a list of xs and a list of ys.

    *points* may be any iterable of 2-element items (tuples, lists,
    Vectors). It is consumed exactly once, so a generator gives the same
    result as a list instead of silently producing an empty y list.
    """
    pairs = list(points)
    return [x for x, _ in pairs], [y for _, y in pairs]


class Walker:
    """A lattice walker that takes one unit step in a random direction.

    ``location`` starts as the given Vector. Each move rebinds it to
    ``location + step`` instead of updating it in place.
    """

    def __init__(self, start):
        assert isinstance(start, Vector)
        self.location = start

    def __str__(self):
        return f"Walker @ {self.location}"

    def step(self):
        """Return a random unit displacement: up, down, right or left."""
        directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        return random.choice(directions)

    def move(self):
        """Advance by one step."""
        here = self.location
        self.location = here + self.step()


class Vector(list):
    """A 2-D vector stored as the two-element list ``[x, y]``.

    Because it subclasses list, it supports indexing, unpacking
    (``x, y = v``) and equality with plain lists.
    """

    def __init__(self, x, y):
        super().__init__((x, y))
        assert len(self) == 2

    def __str__(self):
        return f"<{self[0]}, {self[1]}>"

    def __add__(self, other):
        """Return a new Vector equal to self + other (any length-2 sequence)."""
        assert len(other) == 2
        return Vector(self[0] + other[0], self[1] + other[1])

    def __iadd__(self, other):
        """Make ``v += other`` mean vector addition and return a new Vector.

        Without this override, ``+=`` uses list.__iadd__ (extend), which
        appends other's items in place and turns ``[x, y]`` into
        ``[x, y, dx, dy]``. Returning a new object instead of mutating also
        leaves other references to the old Vector (e.g. a start point)
        unchanged.
        """
        return self + other

    def distance(self, other):
        """Return the Euclidean distance from self to other (a length-2 sequence)."""
        assert len(other) == 2
        dx = self[0] - other[0]
        dy = self[1] - other[1]
        return (dx**2 + dy**2) ** 0.5


def perform_walk(start, num_steps):
    """Walk a Walker from *start* for num_steps moves.

    Returns the list of locations visited: the start location first, then
    the location after each move (num_steps + 1 entries for a non-negative
    num_steps).
    """
    walker = Walker(start)

    def advance():
        walker.move()
        return walker.location

    return [walker.location] + [advance() for _ in range(num_steps)]


def plot_walk(num_steps):
    """Plot one random walk of num_steps steps starting at the origin.

    On a new figure with fixed [-50, 50] axes, draws the path in orange,
    marks the start with a red circle and marks the end with a blue cross.
    """
    path = perform_walk(Vector(0, 0), num_steps)
    xs, ys = split_xy(path)
    start_x, start_y = xs[0], ys[0]
    end_x, end_y = xs[-1], ys[-1]

    plt.figure()
    plt.plot(xs, ys, color="orange")
    plt.plot(start_x, start_y, color="red", marker="o")
    end_marker = dict(marker="x", markersize=10, markeredgewidth=3)
    plt.plot(end_x, end_y, color="xkcd:bright blue", **end_marker)
    plt.title(f"Random walk of {num_steps} uniform steps")
    plt.xlabel("X position")
    plt.ylabel("Y position")
    plt.xlim(-50, 50)
    plt.ylim(-50, 50)
    plt.grid()


# plot_walk(num_steps=2000)


########################################


class Walker:
    """Abstract 2-D walker; subclasses choose the displacement in step().

    ``location`` starts as the given Vector and is rebound to a new value on
    every move.
    """

    def __init__(self, start):
        assert isinstance(start, Vector)
        self.location = start

    def __str__(self):
        return f"Walker @ {self.location}"

    def step(self):
        """Return the next (dx, dy) displacement; subclasses must override."""
        raise NotImplementedError()

    def move(self):
        """Advance the walker by one step.

        Uses ``location + step`` rather than ``location += step``. Vector
        subclasses list, so unless it overrides __iadd__, ``+=`` would run
        list's in-place extend: ``[x, y]`` would grow to ``[x, y, dx, dy]``
        and the Vector passed as start would be mutated.
        """
        self.location = self.location + self.step()


class UniformWalker(Walker):
    """Walker whose four unit moves (right, left, up, down) are equally likely."""

    _MOVES = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def step(self):
        return random.choice(self._MOVES)


class LeftBiasWalker(Walker):
    """Walker that drifts left.

    Horizontal moves are +0.9 or -1.1, so the expected x displacement per
    step is (0.9 - 1.1) / 4 = -0.05. Vertical moves are unbiased.
    """

    _MOVES = ((0.9, 0), (-1.1, 0), (0, 1), (0, -1))

    def step(self):
        return random.choice(self._MOVES)


class RightBiasWalker(Walker):
    """Walker that drifts right.

    Horizontal moves are +1.1 or -0.9, so the expected x displacement per
    step is (1.1 - 0.9) / 4 = +0.05. Vertical moves are unbiased.
    """

    _MOVES = ((1.1, 0), (-0.9, 0), (0, 1), (0, -1))

    def step(self):
        return random.choice(self._MOVES)


def collect_walk_ends(num_walks, num_steps, walker_type):
    """Run num_walks independent walks and return their final locations.

    Each walk uses a new walker_type instance starting at a fresh
    Vector(0, 0) and makes num_steps moves.
    """

    def run_one_walk():
        walker = walker_type(start=Vector(0, 0))
        for _ in range(num_steps):
            walker.move()
        return walker.location

    return [run_one_walk() for _ in range(num_walks)]


def plot_walk_ends(num_walks, num_steps, walker_type=UniformWalker):
    """Scatter-plot where num_walks walks of num_steps steps each end up.

    On a new figure with fixed [-100, 100] axes, marks the origin with a red
    circle, each end point with an "x", and the mean end point with a cyan
    star.
    """
    ends = collect_walk_ends(num_walks, num_steps, walker_type)
    xs, ys = split_xy(ends)

    plt.figure()
    plt.plot([0], [0], color="red", marker="o", markersize=15)  # the origin
    plt.scatter(xs, ys, marker="x")
    centroid_x, centroid_y = mean(xs), mean(ys)
    plt.plot(centroid_x, centroid_y, color="cyan", marker="*", markersize=10)
    plt.title(f"Ending locations after {num_steps:,d} steps, {walker_type.__name__}")
    plt.xlabel("X position")
    plt.ylabel("Y position")
    plt.xlim(-100, 100)
    plt.ylim(-100, 100)
    plt.grid()


# plot_walk_ends(num_walks=100, num_steps=1000)
# plot_walk_ends(num_walks=100, num_steps=1000, walker_type=LeftBiasWalker)
# plot_walk_ends(num_walks=100, num_steps=1000, walker_type=RightBiasWalker)


def average_distance(num_walks, num_steps, walker_type=UniformWalker):
    """Return the mean straight-line distance from the origin at which walks end."""
    ends = collect_walk_ends(num_walks, num_steps, walker_type)
    distances_from_origin = [end.distance([0, 0]) for end in ends]
    return mean(distances_from_origin)


# plot_walk_ends(num_walks=100, num_steps=100)
# plot_walk_ends(num_walks=100, num_steps=1000)
# plot_walk_ends(num_walks=100, num_steps=10000)

# print(average_distance(num_walks=100, num_steps=100))
# print(average_distance(num_walks=100, num_steps=1000))
# print(average_distance(num_walks=100, num_steps=10000))


def plot_walk_distances(max_steps, walker_type=UniformWalker, verbose=False, num_walks=100):
    """Plot the average end distance from the origin against walk length.

    Walk lengths are 10, 20, 30, ... up to max_steps (max_steps is included
    when it is a multiple of 10). Each point averages num_walks walks.

    Args:
        max_steps: longest walk length to simulate.
        walker_type: Walker subclass instantiated for each walk.
        verbose: if True, print progress at every walk length divisible by 100.
        num_walks: number of walks averaged per point (was hard-coded to 100).
    """
    steps_range = range(10, max_steps + 1, 10)

    distances = []
    for num_steps in steps_range:
        distances.append(average_distance(num_walks, num_steps, walker_type))
        if verbose and num_steps % 100 == 0:
            print(f"processed {num_steps = }")

    plt.figure()
    plt.plot(steps_range, distances)
    plt.title(f"Ending distances versus number of steps, {walker_type.__name__}")
    plt.xlabel("Number of steps")
    plt.ylabel("Average distance from origin")
    plt.grid()


# plot_walk_distances(max_steps=1000, verbose=True)
# plot_walk_distances(max_steps=1000, verbose=True, walker_type=LeftBiasWalker)
# plot_walk_distances(max_steps=1000, verbose=True, walker_type=RightBiasWalker)


############################################################
# model gas particles as random walks
############################################################


class Walker:
    """Base class for gas-particle walkers.

    Subclasses override step() to return the next (dx, dy) displacement.
    reset() is a hook for clearing per-walker state; the base class keeps
    none, so here it does nothing.
    """

    def step(self):
        """Return the next displacement; subclasses must override."""
        raise NotImplementedError()

    def reset(self):
        """Clear per-walker state (a no-op for the base class)."""
        return None


class ContinuousWalker(Walker):
    """Walker whose displacement is uniform over the square [-0.5, 0.5] x [-0.5, 0.5]."""

    def step(self):
        dx = random.uniform(-0.5, 0.5)
        dy = random.uniform(-0.5, 0.5)
        return dx, dy


class InertialWalker(ContinuousWalker):
    """A ContinuousWalker that keeps repeating its first random step.

    The first call to step() draws a random displacement and caches it.
    Later calls return that same displacement until reset() clears it.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget the cached step so the next step() draws a new one."""
        self._last_step = None

    def step(self):
        cached = self._last_step
        if cached is None:
            cached = super().step()
            self._last_step = cached
        return cached


class Field:
    """A square box [0, size] x [0, size] holding gas particles (walkers).

    Tracks each particle's location and counts wall hits. A move that would
    leave the box is mirrored back inside, counted in ``wall_hits``, and
    followed by ``particle.reset()`` so the particle can pick a new direction.

    Attributes:
        wall_hits: number of moves that crossed at least one wall.
        collisions: particle-particle collisions. These are not modeled, so
            the counter stays 0; it exists because simulate_gas() reports it.
    """

    def __init__(self, size):
        assert size > 0
        self._particle_locs = {}
        self._xlim = 0, size
        self._ylim = 0, size
        self.wall_hits = 0
        self.collisions = 0

    def add_particle(self, particle, loc):
        """Place *particle* (a hashable walker) at *loc* (e.g. a Vector)."""
        assert particle not in self._particle_locs, "duplicate particle"
        self._particle_locs[particle] = loc

    def get_loc(self, particle):
        """Return the current location of *particle*."""
        assert particle in self._particle_locs, "particle not in field"
        return self._particle_locs[particle]

    def move_particle(self, particle):
        """Move *particle* one step, bouncing it off the walls of the box.

        The new location is ``old + particle.step()``, which Vector.__add__
        returns as a fresh object, and it replaces the old entry. ``+=`` is
        avoided because Vector subclasses list, and unless Vector overrides
        __iadd__, list's in-place ``+=`` would extend the coordinates
        instead of adding them.

        Any coordinate outside the box is mirrored across that wall and then
        clamped, in case the step was longer than the box. The move counts
        as one wall hit, even at a corner, and the particle is reset().
        Assumes locations support ``+`` returning a new mutable 2-element
        sequence (such as Vector).
        """
        assert particle in self._particle_locs, "particle not in field"
        loc = self._particle_locs[particle] + particle.step()
        hit_wall = False
        for axis, (low, high) in enumerate((self._xlim, self._ylim)):
            coord = loc[axis]
            if coord < low:
                coord = 2 * low - coord
            elif coord > high:
                coord = 2 * high - coord
            else:
                continue
            loc[axis] = min(max(coord, low), high)
            hit_wall = True
        if hit_wall:
            self.wall_hits += 1
            particle.reset()
        self._particle_locs[particle] = loc


def simulate_gas(field_size, num_particles, num_steps, walker_type, plot=False):
    """Simulate num_particles walkers in a square field and report its counters.

    Each particle (a new walker_type instance) starts at a uniformly random
    point in [0, field_size] x [0, field_size]. On each of num_steps steps,
    every particle is moved once, in creation order. If plot is True, final
    positions are scatter-plotted on a new figure.

    Returns:
        (field.wall_hits, field.collisions) as reported by the Field.
    """
    field = Field(field_size)

    # particle -> plain-list copy of its start point; iteration over this
    # dict also fixes the order in which particles are moved.
    start_locs = {}
    for _ in range(num_particles):
        x0 = random.uniform(0, field_size)
        y0 = random.uniform(0, field_size)
        start = Vector(x0, y0)
        particle = walker_type()
        field.add_particle(particle, start)
        start_locs[particle] = list(start)

    for _ in range(num_steps):
        for particle in start_locs:
            field.move_particle(particle)

    if plot:
        final_locs = [field.get_loc(particle) for particle in start_locs]
        xs, ys = split_xy(final_locs)
        plt.figure()
        plt.scatter(xs, ys)
        plt.title(f"Particle end locations after {num_steps} steps, {walker_type.__name__}")
        plt.xlabel("X position")
        plt.ylabel("Y position")
        plt.xlim(0, field_size)
        plt.ylim(0, field_size)
        plt.grid()

    return field.wall_hits, field.collisions


# print(simulate_gas(50, 100, 100, ContinuousWalker, plot=True))
# print(simulate_gas(50, 100, 1000, ContinuousWalker, plot=True))
# print(simulate_gas(50, 100, 100, InertialWalker, plot=True))
# print(simulate_gas(50, 100, 1000, InertialWalker, plot=True))


############################################################
# check ideal gas law in two dimensions
############################################################


def vary_num_particles(num_trials):
    """Plot estimated pressure against particle count on log-log axes.

    For each of 2, 4, ..., 256 InertialWalker particles in a 50 x 50 field,
    runs num_trials 100-step simulations. Pressure is estimated as
    mean(wall hits) * walk length / perimeter.
    """
    field_size = 50
    num_steps = 100
    perimeter = 4 * field_size
    particle_counts = [2**k for k in range(1, 9)]

    def estimate_pressure(num_particles):
        hit_counts = []
        for _ in range(num_trials):
            wall_hits, _ = simulate_gas(
                field_size, num_particles, num_steps, InertialWalker
            )
            hit_counts.append(wall_hits)
        return mean(hit_counts) * num_steps / perimeter

    pressures = [estimate_pressure(count) for count in particle_counts]

    plt.figure()
    plt.loglog(particle_counts, pressures, marker="o")
    plt.title("Pressure vs Number of Particles")
    plt.xlabel("# particles")
    plt.ylabel("# wall hits $\\times$ walk length / perimeter")
    plt.grid()


# vary_num_particles(num_trials=30)


def vary_temperature(num_trials):
    """Plot estimated pressure against a temperature proxy on log-log axes.

    Varies the walk length (num_steps = 20, 40, ..., 2560) and uses
    num_steps**2 as the temperature proxy. Each point averages num_trials
    simulations of 100 InertialWalker particles in a 50 x 50 field.
    Pressure is estimated as mean(wall hits) * walk length / perimeter.
    """
    field_size = 50
    num_particles = 100
    step_counts = [10 * 2**k for k in range(1, 9)]
    pressures = []
    for num_steps in step_counts:
        trials = []
        for _ in range(num_trials):
            wall_hits, _ = simulate_gas(
                field_size, num_particles, num_steps, InertialWalker
            )
            trials.append(wall_hits)
        # The perimeter of the square field is 4 * field_size.
        pressures.append(mean(trials) * num_steps / (4 * field_size))
    # Temperature proxy: the square of the walk length.
    temperatures = [num_steps**2 for num_steps in step_counts]

    plt.figure()
    plt.loglog(temperatures, pressures, marker="o")
    plt.title("Pressure vs Temperature")
    plt.xlabel("$(\\text{# steps})^2$")
    plt.ylabel("# wall hits $\\times$ walk length / perimeter")
    plt.grid()


# vary_temperature(num_trials=30)


def vary_area(num_trials):
    """Plot estimated pressure, and pressure x area, against field area.

    Field sizes are 20, 40, ..., 20480 (10 * 2**k for k = 1..11), with
    area = size**2. Each point averages num_trials 100-step simulations of
    100 InertialWalker particles. Pressure is estimated as
    mean(wall hits) * walk length / perimeter. The second figure shows
    pressure x area, which the ideal gas law predicts is constant.
    """
    field_sizes = [10 * 2**k for k in range(1, 12)]
    num_particles = 100
    num_steps = 100
    pressures = []
    for field_size in field_sizes:
        trials = []
        for _ in range(num_trials):
            wall_hits, _ = simulate_gas(
                field_size, num_particles, num_steps, InertialWalker
            )
            trials.append(wall_hits)
        # The perimeter of the square field is 4 * field_size.
        pressures.append(mean(trials) * num_steps / (4 * field_size))
    areas = [size**2 for size in field_sizes]

    plt.figure()
    plt.loglog(areas, pressures, marker="o")
    plt.title("Pressure vs Area")
    plt.xlabel("$(\\text{field size})^2$")
    plt.ylabel("# wall hits $\\times$ walk length / perimeter")
    plt.grid()

    plt.figure()
    plt.loglog(areas, [p * a for p, a in zip(pressures, areas)], marker="o")
    plt.title("Pressure $\\times$ Area")
    plt.xlabel("$(\\text{field size})^2$")
    plt.ylim(10, max(areas) * max(pressures))
    plt.grid()


# vary_area(num_trials=30)


############################################################
# show all plots
############################################################


# Display every figure created by whichever demo calls above are uncommented.
plt.show()
