import math
import matplotlib.pyplot as plt
import pprint
import random
import time


############################################################
# model gas particles as random walks
############################################################


def mean(lines):
    """Return the arithmetic mean of the values in ``lines``.

    ``lines`` must work with both ``sum()`` and ``len()``. An empty
    sequence raises ``ZeroDivisionError``.
    """
    total = sum(lines)
    count = len(lines)
    return total / count


def split_xy(points):
    """Split an iterable of ``(x, y)`` pairs into two parallel lists.

    Returns a tuple ``(xs, ys)``. ``points`` is traversed twice, once for
    each coordinate.
    """
    xs = []
    for x, _ in points:
        xs.append(x)

    ys = []
    for _, y in points:
        ys.append(y)

    return xs, ys


class Vector(list):
    """A mutable 2-D vector stored as a two-element list.

    Arithmetic accepts any indexable operand of length two (another
    Vector, a tuple, a list). Equality only holds between Vectors.
    """

    def __init__(self, x, y):
        super().__init__([x, y])
        assert len(self) == 2

    def __str__(self):
        return "<{}, {}>".format(self[0], self[1])

    def __eq__(self, other):
        # A plain list with the same contents is deliberately not equal.
        if not isinstance(other, Vector):
            return False
        return super().__eq__(other)

    def __ne__(self, other):
        # list defines its own __ne__, so it must be overridden to stay
        # consistent with __eq__ above.
        is_equal = self == other
        return not is_equal

    def __add__(self, other):
        """Return a new Vector that is the component-wise sum."""
        assert len(other) == 2
        x = self[0] + other[0]
        y = self[1] + other[1]
        return Vector(x, y)

    def __iadd__(self, other):
        """Add ``other`` to this Vector in place and return it."""
        assert len(other) == 2
        for axis in (0, 1):
            self[axis] += other[axis]
        return self

    def __floordiv__(self, k):
        """Return a new Vector with both components floor-divided by ``k``."""
        assert isinstance(k, (int, float))
        x, y = self[0] // k, self[1] // k
        return Vector(x, y)

    def distance(self, other):
        """Return the Euclidean distance from this Vector to ``other``."""
        assert len(other) == 2
        dx = self[0] - other[0]
        dy = self[1] - other[1]
        squared = dx**2 + dy**2
        return squared ** 0.5


# print(Vector(2, 3) + [4, 0])
# print(Vector(15, 27) // 5)
# print(Vector(0, 0) == Vector(1, 1) + Vector(-1, -1))
# print(Vector(0, 0) == [0, 0])
# print(Vector(0, 0) != [0, 0])


class Walker:
    """Base interface for a particle; subclasses provide ``step``."""

    def step(self):
        """Return the next displacement. Must be overridden."""
        raise NotImplementedError()

    def reset(self):
        """Discard any per-walker state; the base class keeps none."""
        return None


class ContinuousWalker(Walker):
    """Walker whose every step has length 0.5 in a random direction."""

    def step(self):
        """Return a ``(dx, dy)`` tuple pointing at a uniformly random angle."""
        # Alternative model: independent offsets per axis, i.e.
        # random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)
        step_length = 0.5
        angle = random.uniform(0, 2 * math.pi)
        dx = step_length * math.cos(angle)
        dy = step_length * math.sin(angle)
        return dx, dy


class InertialWalker(ContinuousWalker):
    """Walker that repeats the same step until ``reset`` is called."""

    def __init__(self):
        self.reset()

    def __repr__(self):
        return f"<InertialWalker step={self._last_step}>"

    __str__ = __repr__

    def reset(self):
        """Forget the remembered step so the next ``step`` picks a new one."""
        self._last_step = None

    def step(self):
        """Return the remembered step, drawing a fresh one if none is stored."""
        remembered = self._last_step
        if remembered is None:
            remembered = self._last_step = super().step()
        return remembered


class Field:
    """A square region holding particles, tracking collisions and wall hits.

    Particles are used as dictionary keys mapping to their locations, and
    must provide ``step()`` and ``reset()``. Locations must support ``+``,
    indexing, and ``distance``.
    """

    PARTICLE_SIZE = 1

    def __init__(self, size):
        assert size > 0
        self.collisions = 0
        self.wall_hits = 0
        self._xlim = (0, size)
        self._ylim = (0, size)
        self._particle_locs = {}

    def add_particle(self, particle, loc):
        """Register ``particle`` at ``loc``; each particle may be added once."""
        assert particle not in self._particle_locs, "duplicate particle"
        self._particle_locs[particle] = loc

    def get_loc(self, particle):
        """Return the current location of a registered ``particle``."""
        assert particle in self._particle_locs, "particle not in field"
        return self._particle_locs[particle]

    def _check_collision(self, particle, new_loc):
        """Return True if ``new_loc`` is closer than PARTICLE_SIZE to any
        other particle, recording the collision and resetting ``particle``.
        """
        collided = any(
            other != particle
            and new_loc.distance(other_loc) < self.PARTICLE_SIZE
            for other, other_loc in self._particle_locs.items()
        )
        if collided:
            self.collisions += 1
            particle.reset()
        return collided

    def move_particle(self, particle):
        """Advance ``particle`` by one step unless that step is blocked.

        A step is blocked by a collision with another particle or by
        leaving the field. In both cases the particle stays where it is,
        is reset, and the matching counter is incremented.
        """
        assert particle in self._particle_locs, "particle not in field"
        candidate = self._particle_locs[particle] + particle.step()

        # blocked by another particle
        if self._check_collision(particle, candidate):
            return

        # the walls themselves count as inside the field
        x_low, x_high = self._xlim
        y_low, y_high = self._ylim
        inside = (
            x_low <= candidate[0] <= x_high
            and y_low <= candidate[1] <= y_high
        )
        if inside:
            self._particle_locs[particle] = candidate
            return

        # blocked by a wall
        self.wall_hits += 1
        particle.reset()


class EfficientField(Field):
    """Field that buckets particles into a square grid of cells.

    Collision checks only look at particles in the 3x3 block of cells
    around the moving particle instead of scanning every particle.
    """

    def __init__(self, size):
        super().__init__(size)
        self._cell_size = 2 * self.PARTICLE_SIZE
        # Locations may lie anywhere in the closed range [0, size], so the
        # largest cell index is floor(size / cell_size). Allocate one more
        # cell than that index so a particle sitting exactly on the upper
        # wall, or inside a trailing partial cell, always has a bucket.
        self._cell_count = int(size // self._cell_size) + 1
        self._grid = {
            (x, y): set()
            for x in range(self._cell_count)
            for y in range(self._cell_count)
        }

    def _grid_coord(self, location):
        """Return the integer ``(column, row)`` of the cell holding ``location``."""
        return (
            int(location[0] // self._cell_size),
            int(location[1] // self._cell_size),
        )

    def add_particle(self, particle, loc):
        """Register ``particle`` at ``loc`` and file it under its grid cell."""
        super().add_particle(particle, loc)
        self._grid[self._grid_coord(loc)].add(particle)

    def _check_collision(self, particle, new_loc):
        """Return True if ``new_loc`` collides with a nearby particle.

        Only the cell currently holding ``particle`` and its eight
        neighbors are searched.
        NOTE(review): this assumes a single step is no longer than
        PARTICLE_SIZE, so that anything within PARTICLE_SIZE of ``new_loc``
        lies inside that 3x3 block; confirm for new walker types.
        """
        col, row = self._grid_coord(self.get_loc(particle))
        for d_col in (-1, 0, 1):
            for d_row in (-1, 0, 1):
                bucket = self._grid.get((col + d_col, row + d_row))
                if bucket is None:
                    # neighbor cell lies outside the field
                    continue
                for other in bucket:
                    if other == particle:
                        continue
                    if new_loc.distance(self.get_loc(other)) < self.PARTICLE_SIZE:
                        self.collisions += 1
                        particle.reset()
                        return True
        return False

    def move_particle(self, particle):
        """Move ``particle`` as the base class does, then re-file it in the grid."""
        old_cell = self._grid_coord(self.get_loc(particle))
        super().move_particle(particle)
        new_cell = self._grid_coord(self.get_loc(particle))
        # The grid only needs touching when the particle crossed a boundary.
        if new_cell != old_cell:
            self._grid[old_cell].remove(particle)
            self._grid[new_cell].add(particle)


def _inspect_simulation_state(
    field, particles, title, size, plot=False, grid=False
):
    if grid:
        print()
        pprint.pprint(field._grid)

    if plot:
        particles_x, particles_y = split_xy([field.get_loc(p) for p in particles])
        plt.figure()
        plt.scatter(particles_x, particles_y, s=1000 / field._xlim[1])
        plt.title(title)
        plt.xlabel("X position")
        plt.ylabel("Y position")
        plt.xlim(0, size)
        plt.ylim(0, size)
        plt.grid()


def simulate_gas(
    field_size, num_particles, num_steps, walker_type, field_type,
    plot=False, grid=False,
):
    """Run a bounded gas simulation and return ``(wall_hits, collisions)``.

    ``num_particles`` walkers of ``walker_type`` are placed at uniformly
    random locations in a ``field_type(field_size)`` field. Every particle
    is then moved once per step for ``num_steps`` steps. The state is
    inspected before and after the run according to ``plot`` and ``grid``.
    """
    # make bounded field of gas particles
    field = field_type(field_size)
    particles = []
    for _ in range(num_particles):
        x = random.uniform(0, field_size)
        y = random.uniform(0, field_size)
        walker = walker_type()
        field.add_particle(walker, Vector(x, y))
        particles.append(walker)

    # plot starting positions
    start_title = f"Particle start locations, {walker_type.__name__}"
    _inspect_simulation_state(
        field, particles, start_title, field_size, plot, grid,
    )

    # run simulation, reporting progress every tenth step
    for step_number in range(num_steps):
        if step_number % 10 == 0:
            print(f"completed {step_number} steps")
        for walker in particles:
            field.move_particle(walker)

    # plot final positions
    end_title = (
        f"Particle end locations after {num_steps} steps, {walker_type.__name__}"
    )
    _inspect_simulation_state(
        field, particles, end_title, field_size, plot, grid,
    )

    return field.wall_hits, field.collisions


# print(simulate_gas(10, 20, 100, InertialWalker, Field, plot=True))
# print(simulate_gas(10, 20, 100, InertialWalker, EfficientField, plot=True, grid=True))

# print(simulate_gas(50, 100, 100, InertialWalker, Field, plot=True))
# print(simulate_gas(50, 1000, 100, InertialWalker, Field, plot=True))
# print(simulate_gas(500, 10_000, 100, InertialWalker, Field, plot=True))

# print(simulate_gas(50, 100, 100, InertialWalker, EfficientField, plot=True))
# print(simulate_gas(50, 1000, 100, InertialWalker, EfficientField, plot=True))
# print(simulate_gas(500, 10_000, 100, InertialWalker, EfficientField, plot=True))

# Display any figures created by the simulation runs above (all currently
# commented out). As a top-level statement this executes whenever the
# module is run or imported.
plt.show()


############################################################
# model graphs using classes
############################################################


class Digraph:
    """
    A weighted directed graph, mapping each node to outgoing edges and
    their weights. Nodes are strs.
    """

    def __init__(self, nodes=()):
        """Create a graph containing ``nodes`` and no edges."""
        self._edges = {}
        for node in nodes:
            self.add_node(node)

    def __str__(self):
        """Return one ``source: target(weight), ...`` line per node, sorted."""
        lines = []
        for source, targets in self._edges.items():
            # Join outside the f-string: reusing the enclosing quote type
            # inside a replacement field is a SyntaxError before Python 3.12.
            outgoing_edges = ", ".join(
                f"{target}({weight})" for target, weight in targets.items()
            )
            lines.append(f"{source}: {outgoing_edges}")
        lines.sort()
        return "\n".join(lines)

    def add_node(self, node):
        """Add ``node`` with no edges; raise ValueError if it already exists."""
        if node in self._edges:
            raise ValueError(f"duplicate node {node}")
        self._edges[node] = {}

    def add_edge(self, src, dest, weight=1):
        """Add or overwrite the edge ``src -> dest``, creating missing nodes."""
        if src not in self._edges:
            self.add_node(src)
        if dest not in self._edges:
            self.add_node(dest)
        self._edges[src][dest] = weight

    def has_node(self, node):
        """Return True if ``node`` is in the graph."""
        return node in self._edges

    def all_nodes(self):
        """Return a list of every node, in insertion order."""
        return list(self._edges)

    def outgoing_edges_of(self, node):
        """Return a copy of the ``{target: weight}`` dict for ``node``."""
        return self._edges[node].copy()

    def children_of(self, node):
        """Return a list of the nodes reachable by one edge from ``node``."""
        return list(self._edges[node])


class Graph(Digraph):
    """An undirected graph implemented as pairs of directed edges."""

    def add_edge(self, node1, node2, weight=1):
        """Connect ``node1`` and ``node2`` in both directions with ``weight``."""
        for src, dest in ((node1, node2), (node2, node1)):
            super().add_edge(src, dest, weight)


def build_test_graph(graph_type):
    """Build a small fixed eight-node graph (A through H) of ``graph_type``.

    ``graph_type`` is called with no arguments and must provide
    ``add_edge(src, dest)``.
    """
    edge_list = [
        ("A", "B"), ("A", "C"), ("A", "D"),
        ("B", "E"),
        ("C", "E"), ("C", "F"), ("C", "G"),
        ("E", "H"), ("F", "H"), ("G", "H"),
    ]
    graph = graph_type()
    for src, dest in edge_list:
        graph.add_edge(src, dest)
    return graph


# print()
# print(build_test_graph(Digraph))
# print()
# print(build_test_graph(Graph))


############################################################
# adapt previous bfs implementation from lecture 12
############################################################


def bfs_graph(graph, start, goal):
    """Return a path with the fewest edges from ``start`` to ``goal``.

    ``graph`` must provide ``children_of(node)``. The path is a list of
    nodes including both endpoints, or ``None`` when ``goal`` cannot be
    reached.
    """
    if start == goal:
        return [start]

    # Paths are expanded in the order they were discovered; ``cursor``
    # marks the next one to expand.
    paths = [[start]]
    seen = {start}
    cursor = 0

    while cursor < len(paths):
        path = paths[cursor]
        cursor += 1
        for child in graph.children_of(path[-1]):
            if child in seen:
                continue
            seen.add(child)

            extended = path + [child]
            if child == goal:
                return extended
            paths.append(extended)

    return None


def demo_bfs_graph():
    """Print the BFS route from Boston to Phoenix before and after adding
    a direct Providence-to-Phoenix flight.
    """
    routes = [
        ("Boston", "Providence"),
        ("Boston", "New York"),
        ("Providence", "Boston"),
        ("Providence", "New York"),
        ("New York", "Chicago"),
        ("Chicago", "Denver"),
        ("Chicago", "Phoenix"),
        ("Denver", "New York"),
        ("Denver", "Phoenix"),
        ("Los Angeles", "Boston"),
    ]
    flights = Digraph()
    for origin, destination in routes:
        flights.add_edge(origin, destination)

    first_route = bfs_graph(flights, "Boston", "Phoenix")
    print()
    print(first_route)

    # add new flight from Providence to Phoenix
    flights.add_edge("Providence", "Phoenix")

    # find a shorter path than before
    second_route = bfs_graph(flights, "Boston", "Phoenix")
    print()
    print(second_route)


# demo_bfs_graph()


############################################################
# adapt previous implementation of Dijkstra's algorithm from lecture 14
############################################################


def remove_min(queue):
    """Remove and return the smallest item of the list ``queue``.

    When several items tie for smallest, the earliest one is removed.
    An empty ``queue`` raises ``ValueError``.
    """
    smallest_index = min(range(len(queue)), key=queue.__getitem__)
    return queue.pop(smallest_index)


def find_node(queue, node):
    """Return the index of the first ``(cost, node)`` entry for ``node``.

    Returns ``None`` when ``node`` is not in ``queue``.
    """
    for position, (_, queued_node) in enumerate(queue):
        if queued_node == node:
            return position
    return None


def update_node(queue, node, cost):
    """Record ``cost`` for ``node`` in ``queue`` if it is an improvement.

    A node that is not queued yet is appended. A queued node is replaced
    only when ``cost`` is strictly lower than its current cost. Returns
    True when ``queue`` was changed and False otherwise.
    """
    position = find_node(queue, node)
    if position is None:
        queue.append((cost, node))
        return True

    queued_cost, _ = queue[position]
    if cost < queued_cost:
        queue[position] = (cost, node)
        return True

    return False


def trace_predecessors(node, start, predecessors):
    """Return the path from ``start`` to ``node`` as a list of nodes.

    ``predecessors`` maps each node to the node before it. The chain is
    followed backward from ``node`` until ``start`` is reached.
    """
    backward = [node]
    while backward[-1] != start:
        backward.append(predecessors[backward[-1]])
    return backward[::-1]


def dijkstra_predecessors(graph, start, goal):
    """Find a lowest-cost path from ``start`` to ``goal``.

    ``graph`` must provide ``outgoing_edges_of(node)`` returning a
    ``{neighbor: weight}`` dict. Returns ``(cost, path)`` where ``path``
    is a list of nodes, or ``None`` when ``goal`` is unreachable.
    """
    frontier = [(0, start)]
    settled = set()
    came_from = {start: None}

    while frontier:
        cost_so_far, current = remove_min(frontier)
        settled.add(current)
        if current == goal:
            return cost_so_far, trace_predecessors(current, start, came_from)

        for neighbor, weight in graph.outgoing_edges_of(current).items():
            if neighbor in settled:
                continue
            # only remember the predecessor when the route is an improvement
            if update_node(frontier, neighbor, cost_so_far + weight):
                came_from[neighbor] = current

    return None


def demo_dijkstras():
    """Print the cheapest route from Boston to Phoenix on a weighted
    flight graph.
    """
    weighted_routes = [
        ("Boston", "Providence", 1),
        ("Boston", "New York", 2),
        ("Providence", "Boston", 1),
        ("Providence", "New York", 2),
        ("New York", "Chicago", 3),
        ("Chicago", "Phoenix", 5),
        ("Chicago", "Denver", 3),
        ("Denver", "New York", 4),
        ("Denver", "Phoenix", 1),
    ]
    flights = Digraph()
    for origin, destination, cost in weighted_routes:
        flights.add_edge(origin, destination, cost)

    result = dijkstra_predecessors(flights, "Boston", "Phoenix")
    print()
    print(result)


# demo_dijkstras()


############################################################
# apply Dijkstra's algorithm to OpenStreetMap data
############################################################


# OpenStreetMap: open-source map data
# https://www.openstreetmap.org/
# https://www.openstreetmap.org/about

# OSMnx: a library for accessing OSM data via NetworkX interfaces
# https://osmnx.readthedocs.io/en/stable/getting-started.html
# https://osmnx.readthedocs.io/en/stable/user-reference.html

# NetworkX: a library for modeling graphs
# https://networkx.org/documentation/stable/index.html
# https://networkx.org/documentation/stable/reference/index.html


import networkx as nx
import osmnx as ox


def graph_from_osm(places, network_type="drive"):
    """Download a street network from OpenStreetMap and convert it.

    Parameters:
        places: place name, or list of place names, passed to
            ``ox.graph_from_place``.
        network_type: OSMnx network type, such as "drive" or "walk".

    Returns a tuple ``(graph, node_coords, G)``: a ``Digraph`` whose nodes
    are OSM node ids as strs and whose weights are edge lengths,
    a dict mapping each node to its ``(lon, lat)`` pair, and the original
    OSMnx graph. The downloaded network is also plotted as a side effect.
    """
    G = ox.graph_from_place(places, network_type=network_type)
    ox.plot.plot_graph(G)

    # map nodes to (lon, lat) coordinates: "x" is longitude, "y" is latitude
    node_coords = {}
    for node_id, attr in G.nodes.data():
        # print(f"Node {node_id}: {attr}")
        node_coords[str(node_id)] = attr["x"], attr["y"]

    # Each edge has attributes like:
    #   "length": in meters
    #   "highway": road type
    #   "oneway": True/False
    #   "geometry": shapely LineString (optional)
    # Note: If the graph is directed (network_type="drive"), edges
    # represent one-way streets and may have multiple edges between the
    # same pair of nodes. Digraph stores a single weight per node pair, so
    # keep the shortest of any parallel edges; otherwise whichever edge is
    # seen last would win, even if it is longer.
    shortest = {}
    for u, v, attr in G.edges.data():
        # print(f"Edge from {u} to {v}: {attr}")
        pair = (str(u), str(v))
        dist = attr["length"]
        if pair not in shortest or dist < shortest[pair]:
            shortest[pair] = dist

    # convert nx graph into our graph format
    graph = Digraph()
    for (src, dest), dist in shortest.items():
        graph.add_edge(src, dest, dist)

    return graph, node_coords, G


def visualize_graph(graph, node_coords, color="C0"):
    """Plot every edge of ``graph`` as a thin line segment.

    ``node_coords`` maps each node to an ``(x, y)`` pair. All segments go
    into a single ``plt.plot`` call, separated by ``None`` entries so that
    they are not joined to each other.
    """
    xs = []
    ys = []
    for source in graph.all_nodes():
        start = node_coords[source]
        for target in graph.children_of(source):
            end = node_coords[target]
            xs.extend((start[0], end[0], None))
            ys.extend((start[1], end[1], None))
    plt.plot(xs, ys, linewidth=1, color=color)


def visualize_path(path, node_coords, color="C1", linestyle="-"):
    """Plot ``path`` (a list of nodes) as one thick connected line.

    ``node_coords`` maps each node to an ``(x, y)`` pair.
    """
    xs = []
    ys = []
    for current, following in zip(path, path[1:]):
        start = node_coords[current]
        end = node_coords[following]
        xs.extend((start[0], end[0]))
        ys.extend((start[1], end[1]))
    plt.plot(xs, ys, linewidth=3, color=color, linestyle=linestyle)


# places = ["Cambridge, MA, USA", "Somerville, MA, USA"]
# graph, coords, G = graph_from_osm(places, network_type="drive")
# visualize_graph(graph, coords)


# Alias used by demo_osm_pathfinding below, so the search function it calls
# can be swapped by rebinding this one name.
dijkstra = dijkstra_predecessors


def demo_osm_pathfinding():
    """Find and plot a shortest driving route between two street addresses.

    Downloads the Cambridge and Somerville street network, geocodes both
    addresses, snaps each to its nearest graph node, and runs ``dijkstra``
    between them. Prints the total distance and draws the route over the
    network, or prints a message if the two nodes are not connected.
    """
    places = ["Cambridge, MA, USA", "Somerville, MA, USA"]
    graph, coords, G = graph_from_osm(places, network_type="drive")
    # graph, coords, G = graph_from_osm(places, network_type="walk")

    # address as a string
    # home = "330 Mt Auburn St, Cambridge, MA 02138"
    # home = "795 Massachusetts Ave, Cambridge, MA 02139"
    home = "93 Highland Ave, Somerville, MA 02143"
    work = "32 Vassar St, Cambridge, MA 02139"

    # get nodes from latitude and longitude; geocode returns (lat, lon)
    # while nearest_nodes takes X=lon, Y=lat
    loc_home = ox.geocoder.geocode(home)
    loc_work = ox.geocoder.geocode(work)
    node_home = str(ox.distance.nearest_nodes(G, X=loc_home[1], Y=loc_home[0]))
    node_work = str(ox.distance.nearest_nodes(G, X=loc_work[1], Y=loc_work[0]))

    # run pathfinding algorithm; it returns None when no route exists, which
    # cannot be unpacked into (distance, path)
    result = dijkstra(graph, node_home, node_work)
    if result is None:
        print("no path found between the two addresses")
        return

    distance, path = result
    print(f"total path distance is {distance}")
    visualize_graph(graph, coords)
    visualize_path(path, coords)
    plt.show()


# demo_osm_pathfinding()
