import random
import numpy as np
import matplotlib.pyplot as plt
import time
import sys


# Global matplotlib defaults applied to every figure drawn by this file.
plt.rcParams.update({
    'lines.linewidth': 4,     # line width
    'axes.titlesize': 16,     # font size for titles
    'axes.labelsize': 16,     # font size for labels on axes
    'xtick.labelsize': 12,    # size of numbers on x-axis
    'ytick.labelsize': 12,    # size of numbers on y-axis
    'xtick.major.size': 7,    # size of ticks on x-axis
    'ytick.major.size': 7,    # size of ticks on y-axis
    'legend.numpoints': 1,    # numpoints for legend
})


############################################################
# inheritance: method reuse and overriding/shadowing
############################################################


class Animal:
    """Base class for creatures that have an age and an optional name."""

    def __init__(self, age, name=None):
        """Store the age and (optionally) the name of this Animal."""
        self.name = name
        self.age = age

    def __str__(self):
        return f'<Animal "{self.name}" {self.age} years>'

    def get_age_diff(self, other):
        """
        Return how far apart in age self and other are, as a
        non-negative number. other must be an Animal (asserted).
        """
        assert isinstance(other, Animal)
        gap = self.age - other.age
        return abs(gap)

    def speak(self):
        """
        Placeholder for printing what self would say. The base class
        has nothing to say, so it always raises NotImplementedError;
        subclasses are expected to override this method.
        """
        raise NotImplementedError(
            f"Animal {self} does not have speak()ing ability yet"
        )


def demo_animal():
    """
    Demonstrate basic Animal behavior: printing, ad-hoc attributes,
    get_age_diff(), and the unimplemented speak().

    Two lines below raise on purpose (see their comments); execution
    stops at the first one unless it is commented out.
    """
    creature = Animal(5, name="Morg")
    critter = Animal(2, name="Florg")
    print(creature)
    print(critter)

    # avoid creating new data attributes on-the-fly:
    # only critter gets a size attribute here, creature does not
    critter.size = "tiny"
    print(critter.size)
    print(creature.size)  # raises AttributeError

    # these are not identical method calls!
    # they just happen to give the same result
    print(creature.get_age_diff(critter))
    print(critter.get_age_diff(creature))

    creature.speak()  # raises NotImplementedError


# demo_animal()


class Cat(Animal):
    """A Cat is an Animal that says "Meow" or something like that."""

    def __str__(self):
        return f"<Cat {self.name} {self.age} years>"

    def speak(self):
        """Print something a cat would say."""
        print("Meow")

    def confuse(self, other):
        """
        Given two Cats, make them both speak, swap their names, then
        print and return a string of the form:
            "<name> and <name> are <#> years apart."

        The names in the message are the names after the swap.
        """
        self.speak()
        other.speak()
        # tuple assignment swaps the two names without a temporary
        self.name, other.name = other.name, self.name
        message = (
            f"{self.name} and {other.name} are "
            f"{self.get_age_diff(other)} years apart."
        )
        # keep printing for existing callers, and also return the text
        # so the documented "return a string" contract actually holds
        print(message)
        return message


def demo_cats():
    """
    Demonstrate Cat: print two cats, let one confuse() the other,
    then print both again to show the effect on their names.
    """
    lion = Cat(5, "Fluffy")
    tiger = Cat(8, "Furry")
    print(lion)
    print(tiger)

    # confuse() makes both cats speak and swaps their names
    lion.confuse(tiger)
    print(lion)
    print(tiger)


# demo_cats()


############################################################
# inheritance chains and super()
############################################################


class Person(Animal):
    """A Person is an Animal that keeps a set of Person friends."""

    def __init__(self, name, age):  # note the order of the parameters!
        # let the parent class store the attributes every Animal has
        super().__init__(age, name)
        # Animal.__init__(self, age, name)  # equivalent to line above

        # Person-specific info: friends starts out empty
        self.friends = set()

        # super() reinterprets self as an object of type Animal
        # print(Animal.__init__)
        # print(self.__init__)
        # print(super().__init__)
        # print()

    def __str__(self):
        return f'<Person "{self.name}" {self.age} years old>'

    # enables __str__()-like printing of objects within collections
    # __repr__ = __str__

    def speak(self):
        """Print a short self-introduction."""
        print(f"Hello. My name is {self.name}.")

    def add_friend(self, other):
        """Make self and other (a Person) friends of each other."""
        assert isinstance(other, Person)
        # record the friendship in both directions, self's side first
        for person, friend in ((self, other), (other, self)):
            person.friends.add(friend)


def demo_person_friends():
    """
    Demonstrate Person: introduce one person and show how add_friend()
    records a friendship in both people's friends sets.
    """
    inigo = Person("Inigo Montoya", 35)
    fezzik = Person("Fezzik", 41)
    westley = Person("Dread Pirate Roberts", 25)
    # NOTE(review): this local name shadows the builtin max() inside
    # this function; it is only used by the commented-out lines below
    max = Person("Miracle Max", 39)

    inigo.speak()
    inigo.add_friend(fezzik)
    inigo.add_friend(westley)
    print()
    print(inigo.friends)
    print(fezzik.friends)
    print()

    # use add_friend() to ensure mutual friendship;
    # adding to the set directly only updates one side
    # inigo.friends.add(max)
    # print(inigo.friends)
    # print(max.friends)
    # print()


# demo_person_friends()


class Student(Person):
    """A Student is a Person who may also have a major (course)."""

    def __init__(self, name, age, major=None):
        super().__init__(name, age)
        self.major = major

    def __str__(self):
        # The pieces are joined outside the f-string: reusing the outer
        # quote character inside a replacement field only parses on
        # Python 3.12+, and is a SyntaxError on earlier versions.
        details = ", ".join([
            f'"{self.name}"',
            f"{self.age} years old",
            f"Course {self.major}",
        ])
        return f"<Student {details}>"

    def speak(self):
        """
        Introduce yourself and then a friend chosen at random.

        A student with no friends only gives the self-introduction;
        random.choice() would raise IndexError on an empty list.
        """
        super().speak()
        if self.friends:
            # sets are not indexable, so copy to a list before choosing
            friend = random.choice(list(self.friends))
            if isinstance(friend, Student):
                print(f"My friend {friend.name} is Course {friend.major}.")
            else:
                print(f"My friend {friend.name} isn't a student.")
        print()


def demo_student_speak():
    """
    Demonstrate Student.speak(): give one student a mix of Student and
    non-Student friends and let them speak several times.
    """
    alex = Student("Alex", 20, "6-4")
    kelly = Student("Kelly", 22, "CMS")
    sam = Student("Sam", 18, "2-A")
    # NOTE(review): this local name shadows the builtin max() inside
    # this function
    max = Person("Miracle Max", 39)

    print(alex)
    alex.add_friend(kelly)
    alex.add_friend(sam)
    alex.add_friend(max)
    # speak() picks a friend at random each time, so repeat it
    for _ in range(6):
        alex.speak()


# demo_student_speak()


############################################################
# designing __eq__() when attributes contain objects of the same type
############################################################


class Rabbit(Animal):
    """
    A Rabbit is an Animal with a numeric id tag and a pair of parents.

    Equality is meant to depend only on the parents (see __eq__()).
    Three candidate implementations are kept side by side; __eq__()
    selects one of them.
    """

    # class-level counter: the id that the next new Rabbit will receive
    next_tag = 1

    def __init__(self, age, parent1=None, parent2=None):
        """Create a Rabbit with the given age and (optional) parents."""
        super().__init__(age)
        self.parents = (parent1, parent2)
        # take the current tag, then advance the shared counter
        self.id = Rabbit.next_tag
        Rabbit.next_tag += 1

    def __str__(self):
        # id is right-aligned and zero-padded to three characters
        return f"<Rabbit {self.id:>03}>"

    # use the same text when a Rabbit is shown inside a collection
    __repr__ = __str__

    def __add__(self, other):
        """Make a new Rabbit offspring of self and other."""
        return Rabbit(0, self, other)

    def __eq__(self, other):
        """Return True if and only if self and other have the same parents."""
        # switch between the versions below by moving the comment marks
        return self.__eq__v1(other)
        # return self.__eq__v2(other)
        # return self.__eq__v3(other)

    # The three helper names below start with two underscores but do not
    # end with two, so Python name-mangles them inside the class body;
    # definitions and calls are mangled the same way.

    def __eq__v1(self, other):
        """Compare the parents tuples directly, in either order."""
        print(self, other)

        # this ends up recursively calling __eq__()
        # on self.parents[0] and other.parents[0],
        # doesn't work if one is a Rabbit and the other is None
        return (
            self.parents == other.parents
            or self.parents == other.parents[::-1]
        )

    def __eq__v2(self, other):
        """Compare the ids of the parents, in either order."""
        # this almost works, but not when self's or other's parents are None
        self_parent_ids = [p.id for p in self.parents]
        other_parent_ids = [p.id for p in other.parents]
        return (
            self_parent_ids == other_parent_ids
            or self_parent_ids == other_parent_ids[::-1]
        )

    def __eq__v3(self, other):
        """Compare parent ids in either order, using None for a missing parent."""
        # a parent that is not a Rabbit (e.g. None) is represented by None
        get_parent_id = lambda x: x.id if isinstance(x, Rabbit) else None
        self_parent_ids = [get_parent_id(p) for p in self.parents]
        other_parent_ids = [get_parent_id(p) for p in other.parents]
        return (
            self_parent_ids == other_parent_ids
            or self_parent_ids == other_parent_ids[::-1]
        )


def demo_rabbit_equality():
    """
    Demonstrate Rabbit: create rabbits, breed them with +, and compare
    them with ==. The trailing comments give the intended results of
    the comparisons; what actually happens depends on which version
    Rabbit.__eq__() currently dispatches to.
    """
    r1 = Rabbit(age=3)
    r2 = Rabbit(age=4)
    r3 = Rabbit(age=5)
    print(f"{r1 = }")
    print(f"{r2 = }")
    print(f"{r3 = }")
    print()

    print(f"{r1.parents = }")
    r4 = r1 + r2
    print(f"{r4 = }")
    print(f"{r4.parents = }")
    print()

    r5 = r3 + r4
    print(f"{r5 = }")
    print(f"{r5.parents = }")
    print()

    # same two parents as r5, given in the opposite order
    r6 = r4 + r3
    print(f"{r6 = }")
    print(f"{r6.parents = }")
    print()

    print(f"{(r5 == r6) = }")  # should be True
    print(f"{(r4 == r6) = }")  # should be False
    print(f"{(r1 == r2) = }")  # should be True
    print()

    r7 = r1 + r2 + r3  # what are r7's parents?
    print(f"{r7 = }")
    print(f"{r7.parents = }")
    print(f"{r7.parents[0].parents = }")
    print()


# demo_rabbit_equality()


############################################################
# model random walks with classes and inheritance
############################################################


class Location:
    """A point in the plane given by an x and a y coordinate."""

    def __init__(self, x, y):
        """x and y are numbers"""
        self.x, self.y = x, y

    def move(self, delta_x, delta_y):
        """
        Return a new Location shifted from self by delta_x and delta_y
        (both numbers). self itself is not modified.
        """
        new_x = self.x + delta_x
        new_y = self.y + delta_y
        return Location(new_x, new_y)

    def get_x(self):
        """Return the x coordinate."""
        return self.x

    def get_y(self):
        """Return the y coordinate."""
        return self.y

    def dist_from(self, other):
        """Return the straight-line distance between self and other."""
        x_dist = self.x - other.get_x()
        y_dist = self.y - other.get_y()
        sum_of_squares = x_dist**2 + y_dist**2
        return sum_of_squares**0.5

    def __str__(self):
        pieces = ('<', str(self.x), ', ', str(self.y), '>')
        return ''.join(pieces)


class Field:
    """Keeps track of where each drunk placed in the field currently is."""

    def __init__(self):
        # maps each drunk to its current location
        self.drunk_locs = {}

    def add_drunk(self, drunk, loc):
        """Place drunk at loc; raise ValueError if it is already present."""
        if drunk in self.drunk_locs:
            raise ValueError('Duplicate drunk')
        self.drunk_locs[drunk] = loc

    def get_loc(self, drunk):
        """Return drunk's current location; raise ValueError if absent."""
        if drunk in self.drunk_locs:
            return self.drunk_locs[drunk]
        raise ValueError('Drunk not in field')

    def move_drunk(self, drunk):
        """
        Ask drunk for one step and move it by that step.
        Raises ValueError if drunk is not in the field.
        """
        if drunk not in self.drunk_locs:
            raise ValueError('Drunk not in field')
        delta_x, delta_y = drunk.take_step()
        # the location's move() returns the new location
        moved = self.drunk_locs[drunk].move(delta_x, delta_y)
        self.drunk_locs[drunk] = moved

    def move_drunk_n_steps(self, drunk, steps):
        """
        Move drunk steps times and return the distance between where
        it started and where it ended up.
        Raises ValueError if drunk is not in the field.
        """
        if drunk not in self.drunk_locs:
            raise ValueError('Drunk not in field')
        start = self.get_loc(drunk)
        for _ in range(steps):
            self.move_drunk(drunk)
        end = self.get_loc(drunk)
        return start.dist_from(end)



class Drunk:
    """Base class for random walkers; stores only an optional name."""

    def __init__(self, name=None):
        """Assumes name is a str, or None for an unnamed drunk"""
        self.name = name

    def __str__(self):
        # Test the name, not self: self is never None, so the original
        # check always returned self.name and str() of an unnamed drunk
        # raised TypeError because __str__ returned None.
        if self.name is not None:
            return self.name
        return 'Anonymous'


class UsualDrunk(Drunk):
    """A drunk whose four possible steps all have length one."""

    def take_step(self):
        """Return one of the four unit steps, chosen at random."""
        up, down = (0, 1), (0, -1)
        right, left = (1, 0), (-1, 0)
        return random.choice([up, down, right, left])

class MasochistDrunk(Drunk):
    """A drunk whose +y step (1.1) is longer than its -y step (0.9)."""

    def take_step(self):
        """Return one of the four biased steps, chosen at random."""
        up, down = (0.0, 1.1), (0.0, -0.9)
        right, left = (1.0, 0.0), (-1.0, 0.0)
        return random.choice([up, down, right, left])


class LiberalDrunk(Drunk):
    """A drunk whose -x step (1.1) is longer than its +x step (0.9)."""

    def take_step(self):
        """Return one of the four biased steps, chosen at random."""
        up, down = (0.0, 1.0), (0.0, -1.0)
        right, left = (0.9, 0.0), (-1.1, 0.0)
        return random.choice([up, down, right, left])


class ConservativeDrunk(Drunk):
    """A drunk whose +x step (1.1) is longer than its -x step (0.9)."""

    def take_step(self):
        """Return one of the four biased steps, chosen at random."""
        up, down = (0.0, 1.0), (0.0, -1.0)
        right, left = (1.1, 0.0), (-0.9, 0.0)
        return random.choice([up, down, right, left])


class LiberalMasochistDrunk(MasochistDrunk):
    """
    A drunk that, on each step, is equally likely to use the step set
    written out below or the step set of MasochistDrunk.
    """

    def take_step(self):
        """Flip a coin to pick the step set, then pick a step from it."""
        use_own_steps = random.choice([True, False])
        if not use_own_steps:
            return MasochistDrunk.take_step(self)
        up, down = (0.0, 1.0), (0.0, -1.0)
        right, left = (0.9, 0.0), (-1.1, 0.0)
        return random.choice([up, down, right, left])


class CornerDrunk(Drunk):
    """A drunk that only steps diagonally, 0.71 along each axis."""

    def take_step(self):
        """Return one of the four diagonal steps, chosen at random."""
        offsets = (0.71, -0.71)
        # x varies slowest, so the order is (+,+), (+,-), (-,+), (-,-)
        diagonals = [(dx, dy) for dx in offsets for dy in offsets]
        return random.choice(diagonals)


############################################################
# simulate random walks
############################################################



def sim_walks(num_steps, num_trials, d_class):
    """Assumes num_steps an int >= 0, num_trials an int > 0,
         d_class a subclass of Drunk
       Runs num_trials independent walks of num_steps steps, each in
       a fresh Field and starting from the origin.
       Returns a list with the final distance from the start for each
       trial, rounded to one decimal place"""
    walker = d_class('Homer')
    start = Location(0, 0)

    def one_trial():
        # every trial gets its own empty field
        field = Field()
        field.add_drunk(walker, start)
        return round(field.move_drunk_n_steps(walker, num_steps), 1)

    return [one_trial() for _ in range(num_trials)]


def drunk_test(walk_lengths, num_trials, d_class):
    """Assumes walk_lengths a sequence of ints >= 0
         num_trials an int > 0, d_class a subclass of Drunk
       For each walk length, simulates num_trials walks with
       sim_walks and prints the mean, max and min final distance"""
    for num_steps in walk_lengths:
        distances = sim_walks(num_steps, num_trials, d_class)
        print(d_class.__name__, 'random walk of', num_steps, 'steps')
        mean = round(sum(distances)/len(distances), 4)
        print(' Mean =', mean)
        farthest, nearest = max(distances), min(distances)
        print(' Max =', farthest, 'Min =', nearest)


# random.seed(0)
# drunk_test((1, 2, 10, 100, 1000, 10000), 100, UsualDrunk)


def plot_drunk_test(walk_lengths, num_trials, d_class):
    """Assumes walk_lengths a sequence of ints >= 0
         num_trials an int > 0, d_class a subclass of Drunk
       Plots the average distance for each walk length and
         the sqrt of each walk length, on log-log axes.
       The plot title reports the actual number of trials used."""
    means = []
    for wl in walk_lengths:
        distances = sim_walks(wl, num_trials, d_class)
        means.append(sum(distances) / len(distances))
    plt.plot(walk_lengths, means, label='Distance')
    # reference curve: square root of the number of steps
    roots = [wl**0.5 for wl in walk_lengths]
    plt.plot(walk_lengths, roots, '--', label='Sqrt of steps')
    plt.semilogy()
    plt.semilogx()
    plt.xlabel('Steps Taken')
    plt.ylabel('Distance from Origin')
    # use num_trials instead of a hard-coded "100" so the title is
    # correct for any number of trials
    plt.title(f'Mean Distance from Origin\n{num_trials} Trials')
    plt.grid()
    plt.legend()
    plt.show()


# walk_lengths = []
# for i in range(1, 6):
#     walk_lengths.append(10**i)
# plot_drunk_test(walk_lengths, 100, UsualDrunk)


def sim_all(drunk_kinds, walk_lengths, num_trials):
    """Runs drunk_test for each class in drunk_kinds.
       The random generator is re-seeded with 1 before each class, so
       every class starts from the same generator state.
       Note: a later definition in this file reuses the name sim_all."""
    for kind in drunk_kinds:
        random.seed(1)
        drunk_test(walk_lengths, num_trials, kind)


# sim_all((UsualDrunk, MasochistDrunk), (1000, 10000), 100)


########################################
# visualize random walks
########################################


class RandomStyleGenerator:
    """Hands out random plot format strings without immediate repeats."""

    def __init__(self):
        self.colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        self.lines = ['-', ':', '--', '-.']
        # (color, line) pairs handed out since the last reset
        self.used_styles = set()

    def get_random_plot_style(self):
        """
        Returns a random plot style (a color code followed by a line
        style code). A style is not repeated until every combination
        has been handed out, at which point the record of used styles
        is cleared and all styles become available again.
        """
        every_style = [(color, line)
                       for color in self.colors
                       for line in self.lines]
        unused = [style for style in every_style
                  if style not in self.used_styles]
        if not unused:
            # All styles have been used, start over
            self.used_styles.clear()
            unused = every_style
        pick = random.choice(unused)
        self.used_styles.add(pick)
        color, line = pick
        return color + line


def sim_drunk(num_trials, d_class, walk_lengths):
    """For each number of steps in walk_lengths, simulates num_trials
       walks of a d_class drunk with sim_walks, printing a progress
       line first.
       Returns a list of the mean final distances, one per walk length"""
    averages = []
    for num_steps in walk_lengths:
        print('Starting simulation of', num_steps, 'steps')
        dists = sim_walks(num_steps, num_trials, d_class)
        averages.append(sum(dists) / len(dists))
    return averages


def sim_all(drunk_kinds, walk_lengths, num_trials):
    """Simulates every class in drunk_kinds over walk_lengths and plots
       the mean final distance against the number of steps, one curve
       per class, each drawn in a randomly chosen style.
       Note: this reuses the name of an earlier sim_all in this file."""
    styles = RandomStyleGenerator()
    for d_class in drunk_kinds:
        # the style is drawn before simulating, as both use random
        line_style = styles.get_random_plot_style()
        print('Starting simulation of', d_class.__name__)
        mean_dists = sim_drunk(num_trials, d_class, walk_lengths)
        plt.plot(walk_lengths, mean_dists, line_style,
                 label=d_class.__name__)
    plt.title(f'Mean Distance from Origin ({num_trials} trials)')
    plt.xlabel('Number of Steps')
    plt.ylabel('Distance from Origin')
    plt.grid()
    plt.legend(loc='best')
    plt.show()


# random.seed(0)
# num_steps = (10, 100, 1000, 10000)
# sim_all(
#     (UsualDrunk, MasochistDrunk, LiberalDrunk, LiberalMasochistDrunk),
#     num_steps, 100
# )


########################################
# visualize final locations (not just final distance)
########################################


def sim_final_locs(num_steps, num_trials, d_class):
    """Runs num_trials walks of num_steps steps with one d_class drunk,
       each in a fresh Field and starting from the origin.
       Returns a list of the final locations, one per trial"""
    walker = d_class()
    final_locs = []
    for _ in range(num_trials):
        field = Field()
        field.add_drunk(walker, Location(0, 0))
        # the returned distance is not needed here, only the end point
        field.move_drunk_n_steps(walker, num_steps)
        final_locs.append(field.get_loc(walker))
    return final_locs


def plot_locs(drunk_kinds, num_steps, num_trials):
    """For each class in drunk_kinds, simulates num_trials walks of
       num_steps steps and scatter-plots where the walks ended.
       Each legend entry reports the rounded mean of the absolute
       coordinates and the rounded mean of the signed coordinates."""
    styles = RandomStyleGenerator()
    for d_class in drunk_kinds:
        locs = sim_final_locs(num_steps, num_trials, d_class)
        x_vals = np.array([loc.get_x() for loc in locs])
        y_vals = np.array([loc.get_y() for loc in locs])
        count_x, count_y = len(x_vals), len(y_vals)
        mean_x = round(sum(x_vals)/count_x)
        mean_y = round(sum(y_vals)/count_y)
        abs_mean_x = round(sum(abs(x_vals))/count_x)
        abs_mean_y = round(sum(abs(y_vals))/count_y)
        legend_text = (
            d_class.__name__ + ':'
            + ' mean abs dist = <' + str(abs_mean_x) + ', ' + str(abs_mean_y) + '>\n'
            + 'mean dist = <' + str(mean_x) + ', ' + str(mean_y) + '>'
        )
        # only the color character of the style is used for the points
        style = styles.get_random_plot_style()
        plt.scatter(x_vals, y_vals, color=style[0], label=legend_text)
    plt.title('Location at End of Walks (' + str(num_steps) + ' steps)')
    plt.ylim(-1000, 1000)
    plt.xlim(-1000, 1000)
    plt.xlabel('Steps East/West of Origin')
    plt.ylabel('Steps North/South of Origin')
    plt.grid()
    plt.legend(loc='lower center', fontsize='large')
    plt.show()


# random.seed(1)
# plot_locs((UsualDrunk, MasochistDrunk, LiberalMasochistDrunk), 10000, 100)


########################################
# trace random walks with teleportation
########################################


class StyleIterator:
    """Cycles through a fixed sequence of styles, wrapping at the end."""

    def __init__(self, styles):
        self.styles = styles
        # position of the style that next_style() will return next
        self.index = 0

    def next_style(self):
        """Return the current style and advance, wrapping to the start."""
        current = self.styles[self.index]
        self.index = (self.index + 1) % len(self.styles)
        return current


class OddField(Field):
    """A Field with wormholes: landing on one relocates the drunk."""

    def __init__(self, num_holes=1000, x_range=100, y_range=100):
        """Create num_holes wormholes at random integer positions.

        Entrances and exits both have x within [-x_range, x_range] and
        y within [-y_range, y_range]. A repeated entrance position
        overwrites the earlier wormhole at that position.
        """
        super().__init__()
        # maps an entrance (x, y) tuple to the exit Location
        self.wormhole_locs = {}
        for _ in range(num_holes):
            entrance = (random.randint(-x_range, x_range),
                        random.randint(-y_range, y_range))
            exit_x = random.randint(-x_range, x_range)
            exit_y = random.randint(-y_range, y_range)
            self.wormhole_locs[entrance] = Location(exit_x, exit_y)

    def move_drunk(self, drunk):
        """Move drunk one step, then send it through a wormhole if it
        landed exactly on an entrance."""
        super().move_drunk(drunk)
        landed = self.drunk_locs[drunk]
        spot = (landed.get_x(), landed.get_y())
        if spot in self.wormhole_locs:
            self.drunk_locs[drunk] = self.wormhole_locs[spot]


def trace_walk(field_kinds, num_steps):
    """For each field class in field_kinds, walks one UsualDrunk for
       num_steps steps from the origin and plots every location visited
       after each step, using a different marker style per field class."""
    styles = StyleIterator(('b+', 'r^', 'ko'))
    for f_class in field_kinds:
        walker = UsualDrunk()
        field = f_class()
        field.add_drunk(walker, Location(0, 0))
        x_vals, y_vals = [], []
        for _ in range(num_steps):
            field.move_drunk(walker)
            # record where the drunk is after this step
            spot = field.get_loc(walker)
            x_vals.append(spot.get_x())
            y_vals.append(spot.get_y())
        marker_style = styles.next_style()
        plt.plot(x_vals, y_vals, marker_style, label=f_class.__name__)
    plt.title('Spots Visited on Walk (' + str(num_steps) + ' steps)')
    plt.xlabel('Steps East/West of Origin')
    plt.ylabel('Steps North/South of Origin')
    plt.grid()
    plt.legend(loc='best')
    plt.show()


# random.seed(0)
# trace_walk((Field, OddField), 1000)


############################################################
# modeling gas particles as taking random walks
############################################################


class InertialDrunk(Drunk):
    """A drunk that keeps repeating its last step until it is reset."""

    def __init__(self, name):
        super().__init__(name)
        # the step to repeat; None means "pick a new one next time"
        self.last_step = None

    def reset_last_step(self):
        """Forget the stored step so the next take_step() picks a new one."""
        self.last_step = None

    def take_step(self, new_dir=False):
        """
        Return the stored step if there is one; otherwise draw a new
        random step, store it, and return it. Each coordinate of a new
        step has a random sign and a magnitude from random.random().
        new_dir is accepted but not used.
        """
        if self.last_step is not None:
            return self.last_step

        def random_delta():
            # first draw decides the sign, second draw is the magnitude
            positive = random.random() < 0.5
            magnitude = random.random()
            return magnitude if positive else -magnitude

        delta_x = random_delta()
        delta_y = random_delta()
        self.last_step = (delta_x, delta_y)
        return self.last_step


class ContinuousDrunk(InertialDrunk):
    """A drunk that draws a fresh random step on every call."""

    def take_step(self):
        """
        Return a new random (delta_x, delta_y) step. Each coordinate
        has a random sign and a magnitude from random.random(). The
        stored last_step is neither read nor updated.
        """
        def signed_fraction():
            # first draw decides the sign, second draw is the magnitude
            if random.random() < 0.5:
                return random.random()
            return -random.random()

        delta_x = signed_fraction()
        delta_y = signed_fraction()
        return (delta_x, delta_y)



########################################
# simulate with multiple drunks, don't allow collisions
# include boundaries on size of field
########################################


class FieldMulti:
    """Designed for large fields with small number of particles.

    A bounded field holding several drunks. A step is rejected when it
    would bring the drunk within distance 1.0 of another drunk or take
    it outside the limits; rejected steps are counted as collisions or
    wall hits. Every drunk in the field is compared on each move.
    """

    def __init__(self, x_lim, y_lim):
        """The field spans [-x_lim, x_lim] by [-y_lim, y_lim]."""
        self.drunk_locs = {}
        self.x_lim = x_lim
        self.y_lim = y_lim
        self.wall_hits, self.collisions = 0, 0

    def add_drunk(self, drunk, loc):
        """Place drunk at loc; raise ValueError if it is already present."""
        if drunk in self.drunk_locs:
            raise ValueError('Duplicate drunk')
        self.drunk_locs[drunk] = loc

    def get_loc(self, drunk):
        """Return drunk's current location; raise ValueError if absent."""
        if drunk not in self.drunk_locs:
            raise ValueError('Drunk not in field')
        return self.drunk_locs[drunk]

    def get_lims(self):
        """Return the (x_lim, y_lim) tuple."""
        return (self.x_lim, self.y_lim)

    def get_wall_hits(self):
        """Return the number of steps rejected for leaving the field."""
        return self.wall_hits

    def get_collisions(self):
        """Return the number of steps rejected for crowding another drunk."""
        return self.collisions

    def move_drunk(self, drunk):
        """
        Try to move drunk by one step.

        The step is rejected (drunk stays put) if it would land within
        1.0 of another drunk, or outside the field limits. A collision
        takes precedence over a wall hit, and a step counts as at most
        one wall hit. Raises ValueError if drunk is not in the field.
        """
        if drunk not in self.drunk_locs:
            raise ValueError('Drunk not in field')
        x_dist, y_dist = drunk.take_step()
        new_loc = self.drunk_locs[drunk].move(x_dist, y_dist)

        # check that step would not bring too close to other drunks
        for d in self.drunk_locs:
            if d != drunk and self.drunk_locs[d].dist_from(new_loc) < 1.0:
                self.collisions += 1
                # the drunk that was bumped into picks a new direction;
                # the moving drunk keeps its step (behavior unchanged)
                d.reset_last_step()
                return  # particle will not move

        # check if would hit wall
        x, y = new_loc.get_x(), new_loc.get_y()
        outside_x = x < -self.x_lim or x > self.x_lim
        outside_y = y < -self.y_lim or y > self.y_lim
        if outside_x or outside_y:
            self.wall_hits += 1
            # Reset the MOVING drunk so it bounces on its next move.
            # The original reset the leftover loop variable d, i.e.
            # whichever drunk happened to be iterated last.
            drunk.reset_last_step()
            return

        # neither a collision nor a wall hit: take the step
        self.drunk_locs[drunk] = new_loc


class FieldMultiManyParticles(FieldMulti):
    """Better performance for a large number of particles.

    Keeps a grid of integer cells, each holding the drunks currently in
    it, so that a move only compares against drunks in nearby cells
    rather than against every drunk in the field.
    """

    def __init__(self, x_lim, y_lim):
        super().__init__(x_lim, y_lim)
        # Used so that collisions can be detected by looking only at
        # nearby cells. Initializes a 2D grid (one cell wider than the
        # field on each side) where each cell stores a list of drunks.
        self.drunks_by_loc = {
            (x, y): []
            for x in range(-(x_lim + 1), x_lim + 2)
            for y in range(-(y_lim + 1), y_lim + 2)
        }

    def add_drunk(self, drunk, loc):
        """Place drunk at loc; raise ValueError if it is already present."""
        if drunk in self.drunk_locs:
            raise ValueError('Duplicate drunk')
        # Truncate to the integer cell exactly as move_drunk() does, so
        # a start location with float coordinates is filed in the cell
        # that move_drunk() will later remove it from. Looking the cell
        # up first leaves the field untouched if it is off the grid.
        cell = self.drunks_by_loc[(int(loc.get_x()), int(loc.get_y()))]
        self.drunk_locs[drunk] = loc
        cell.append(drunk)

    def move_drunk(self, drunk):
        """
        Try to move drunk by one step.

        The step is rejected (drunk stays put) if it would land within
        1.0 of a drunk found in the neighboring cells, or outside the
        field limits. A collision takes precedence over a wall hit.
        Raises ValueError if drunk is not in the field.
        """
        if drunk not in self.drunk_locs:
            raise ValueError('Drunk not in field')
        old_loc = self.drunk_locs[drunk]
        old_cell = (int(old_loc.get_x()), int(old_loc.get_y()))
        x_dist, y_dist = drunk.take_step()
        new_loc = old_loc.move(x_dist, y_dist)
        x_val, y_val = new_loc.get_x(), new_loc.get_y()

        # Limit area to search for collisions
        # Find relevant columns
        d_col = int(x_val)
        relevant_cols = [d_col]
        if d_col < self.x_lim: # not at right edge
            relevant_cols.append(d_col + 1)
        if d_col > -self.x_lim: # not at left edge
            relevant_cols.append(d_col - 1)

        # Find relevant rows
        d_row = int(y_val)
        relevant_rows = [d_row]
        if d_row < self.y_lim: # not at top edge
            relevant_rows.append(d_row + 1)
        if d_row > -self.y_lim: # not at bottom edge
            relevant_rows.append(d_row - 1)

        # Same cell order as before: the drunk's own row first, then
        # the row above, then the row below
        relevant_cells = [(col, row)
                          for row in relevant_rows
                          for col in relevant_cols]

        # a cell outside the grid simply contributes no neighbors
        possible_neighbors = [
            d
            for cell in relevant_cells
            for d in self.drunks_by_loc.get(cell, ())
            if d != drunk
        ]

        # To see neighborhood it is exploring for first move
        # Comment out when running a full simulation
        # print(f'relevant_cols = {relevant_cols}')
        # print(f'relevant_cells = {relevant_cells}')
        # neighbors = []
        # for d in possible_neighbors:
        #     neighbors.append(d.__str__())
        # print(f'Possible neighbors of {drunk.__str__()}: {neighbors}')
        # sys.exit(0)
        # End material to be commented out

        # check that not too close to other drunks
        for d in possible_neighbors:
            if self.drunk_locs[d].dist_from(new_loc) < 1.0:
                # the drunk that was bumped into picks a new direction;
                # the moving drunk keeps its step (behavior unchanged)
                d.reset_last_step()
                self.collisions += 1
                return  # particle will not move

        # Check if would hit wall
        outside_x = x_val < -self.x_lim or x_val > self.x_lim
        outside_y = y_val < -self.y_lim or y_val > self.y_lim
        if outside_x or outside_y:
            self.wall_hits += 1
            # Reset the MOVING drunk so it bounces on its next move.
            # The original reset the leftover loop variable d, which
            # was some other drunk or, with no neighbors, unbound.
            drunk.reset_last_step()
            return

        # neither a collision nor a wall hit: take the step and move
        # the drunk from its old grid cell to its new one
        self.drunk_locs[drunk] = new_loc
        self.drunks_by_loc[old_cell].remove(drunk)
        self.drunks_by_loc[(d_col, d_row)].append(drunk)


def walk_multi(f, ds, num_steps, starts):
    """Assumes: f a Field, ds a collection of Drunks in f (it is
          iterated once per step and once more at the end),
          num_steps an int >= 0,
          starts a dict mapping each drunk to its starting location
       Moves each d in ds num_steps times, one step per drunk per
       round, and returns a list with, for each d in ds, the distance
       between its starting location and its final location."""
    for _ in range(num_steps):
        for d in ds:
            f.move_drunk(d)
    return [starts[d].dist_from(f.get_loc(d)) for d in ds]


def sim_walks_multi(num_steps, num_trials, d_class, num_drunks,
                    boundaries, opt_space=False, verbose=False):
    """Assumes num_steps an int >= 0, num_trials an int > 0,
         d_class a subclass of Drunk, num_drunks an int,
         boundaries an int
       Simulates num_trials walks of num_steps steps each, with
       num_drunks drunks placed at random integer positions in a square
       field of half-width boundaries. Uses FieldMulti when opt_space
       is true and FieldMultiManyParticles otherwise.
       Returns five lists with one entry per trial: mean, max, and min
       distances, wall hits, and collisions"""
    starts = {}
    range_of_choices = list(range(-boundaries, boundaries + 1))
    mean_dists, max_dists, min_dists = [], [], []
    wall_hits, cols = [], []
    field_class = FieldMulti if opt_space else FieldMultiManyParticles
    for t in range(num_trials):
        f = field_class(boundaries, boundaries)
        for i in range(num_drunks):
            particle = d_class('Homer' + str(i))
            # x is drawn before y
            start_x = random.choice(range_of_choices)
            start_y = random.choice(range_of_choices)
            start = Location(start_x, start_y)
            f.add_drunk(particle, start)
            starts[particle] = start
        distances = walk_multi(f, f.drunk_locs, num_steps, starts)
        mean_dists.append(sum(distances) / len(distances))
        max_dists.append(max(distances))
        min_dists.append(min(distances))
        wall_hits.append(f.get_wall_hits())
        cols.append(f.get_collisions())
        if verbose:
            print(f'Mean distance for trial {t} = {mean_dists[-1]}')
    return mean_dists, max_dists, min_dists, wall_hits, cols


def drunk_test_multi(walk_lengths, num_trials, d_class, num_particles,
                     boundaries, opt_space=False, verbose=False):
    """Assumes walk_lengths a sequence of ints >= 0
         num_trials an int > 0, d_class a subclass of Drunk,
         num_particles and boundaries non-empty sequences of ints
       For every combination of walk length, particle count and field
       size, runs sim_walks_multi with num_trials walks and prints
       a summary of the results.
       Returns the mean distance, mean wall hits and mean collisions
       of the LAST combination only"""
    for l in walk_lengths:
        for num in num_particles:
            for d in boundaries:
                print(f'particles = {num}, size = {d:,}, steps = {l:,}')
                results = sim_walks_multi(
                    l, num_trials, d_class, num, d,
                    verbose=verbose, opt_space=opt_space)
                mean_dists, max_dists, min_dists, wall_hits, cols = results
                # averages over the trials of this combination
                mean_d = sum(mean_dists) / len(mean_dists)
                mean_wh = sum(wall_hits) / len(wall_hits)
                mean_col = sum(cols)/len(cols)
                # extremes over all trials, rounded for printing
                max_d = round(max(max_dists))
                min_d = round(min(min_dists))
                print(f' Distance: Max = {max_d:,}, Min = {min_d:,},',
                      f'Mean = {round(mean_d):,}')
                print(f' Mean wall hits = {round(mean_wh):,}')
                if mean_col != 0:
                    print(f' Mean collisions = {round(mean_col):,}')
    return mean_d, mean_wh, mean_col


def sim_particles():
    """
    Driver for the gas-particle experiments.

    As written, this only selects the particle class; each experiment
    below is commented out and is enabled by uncommenting its section.
    """
    Particle = InertialDrunk
    # Particle = ContinuousDrunk

    # Vary field size
    # random.seed(1)
    # num_particles = (1,)
    # sizes = (10, 20, 50, 100, 1000, 10000)
    # lengths = (500,)
    # num_trials = 50
    # drunk_test_multi(lengths, num_trials, Particle, num_particles, sizes, opt_space=True)


    # Vary walk length (can be thought of as a velocity, which is related to temperature)
    # random.seed(1)
    # num_particles = (1,)
    # sizes = (50,)
    # lengths = (200, 400, 600, 800, 1000, 1200)
    # lengths = [2**p for p in range(5, 11)]
    # num_trials = 50
    # drunk_test_multi(lengths, num_trials, Particle, num_particles, sizes, opt_space=True)

    # # Vary number of particles per trial, corresponds to density
    # random.seed(1)
    # sizes = (50,)
    # lengths = (200,)
    # num_trials = 50
    # num_particles = [50, 100, 200, 300, 400]
    # wall_hits, collisions = [], []
    # for n in num_particles:
    #     mean_d, mean_wh, mean_col = drunk_test_multi(
    #         lengths, num_trials, Particle, (n,),
    #         sizes, opt_space=False, verbose=False)
    #     wall_hits.append(mean_wh)
    #     collisions.append(mean_col)


    # # Plot results from simulation varying number of particles
    # plt.figure()
    # plt.plot(num_particles, wall_hits,'o-', label='Data points')
    # plt.title('Pressure vs. Number of Particles')
    # plt.xlabel('Number of Particles')
    # plt.ylabel('Number of Wall Hits')

    # # Fit a model to predict number of wall hits
    # model = np.polyfit(num_particles, wall_hits, 1)
    # plt.plot(num_particles, np.polyval(model, num_particles), 'k',
    #          label='Linear model')

    # plt.grid()
    # plt.legend()

    # plt.figure()
    # plt.plot(num_particles, collisions,'o-')
    # plt.title('Particle Interactions vs. Number of Particles')
    # plt.xlabel('Number of Particles')
    # plt.ylabel('Number of Collisions')
    # plt.grid()
    # plt.legend()

    # plt.show()


# sim_particles()
