import matplotlib.pyplot as plt
from pprint import pprint
import random


############################################################
# overview: traversing through a graph
############################################################


# We ended yesterday's lecture with a couple of examples of graphs
# implemented as dictionaries. What we didn't do, however, was
# demonstrate what actions we could perform with graphs and how we
# might implement them. The next few lectures will focus on strategies
# for traversing graphs and finding paths between nodes. As a warm-up,
# this pre-lecture code will demonstrate a very simple strategy for
# exploring graphs, using a random walk.


############################################################
# building a grid graph
############################################################


# To begin, we'll need to build a graph. The following function
# constructs a graph consisting of a grid of nodes with a specified
# number of rows and columns, similar to the following structure.

#   o--o--o--o
#   |  |  |  |
#   o--o--o--o
#   |  |  |  |
#   o--o--o--o
#   |  |  |  |
#   o--o--o--o

# However, we also connect all the nodes on the top edge with those on
# the bottom edge, and similarly for left and right edges. Thus, the
# graph is essentially a grid that "wraps around" on itself, and the
# connectivity is like in the game Pac-Man, where going off the right
# edge brings you back to the left edge.


def make_grid_graph(rows, cols):
    """Build a wrap-around grid graph as an adjacency dictionary.

    Each node is a `(row, col)` tuple. It maps to a list of its four
    neighbors in the order up, down, left, right, where indices wrap
    around modulo `rows` and `cols` respectively. On very small grids
    the same neighbor can therefore appear more than once in a list.
    """
    adjacency = {}
    for row in range(rows):
        # the vertical neighbors only depend on the row
        above = (row - 1) % rows
        below = (row + 1) % rows
        for col in range(cols):
            before = (col - 1) % cols
            after = (col + 1) % cols
            adjacency[(row, col)] = [
                (above, col),
                (below, col),
                (row, before),
                (row, after),
            ]
    return adjacency


# Print the following graphs to make sure you understand their
# dictionary representation in Python. What are the nodes, and why are
# these valid keys for a Python dictionary? Also, do you notice anything
# curious about the first graph's representation?


# pprint(make_grid_graph(2, 2))
# pprint(make_grid_graph(4, 4))


############################################################
# exploring the graph through a random walk
############################################################


# Now we'll build a function to explore the graph using a random walk.
# Given a graph size specified as a (rows, cols) tuple, we'll make a
# grid graph, pick a random start node, and randomly traverse edges in
# sequence. The question we'll ask is: after a certain number of steps,
# how much of the graph has been explored? We could answer this in two
# ways: how many of all the nodes we've seen, or how many of all the
# edges we've traversed. To keep things simpler, we'll count nodes.

# Hence, the function also sets up a dictionary `explored` that maps
# each node to a boolean indicating whether we've seen that node. To
# begin with, the initial start node is considered explored.

# Our goal is to plot the percentage of explored nodes over a specified
# number of steps. Thus, we need to determine the number of explored
# states at each time step. To do so, we're calling the `sum()` function
# on all the values of the `explored` dictionary. This might be a little
# unexpected, but let's break it down:
#   + `explored.values()` represents a collection of all the values in
#     the dictionary, which are either `True` or `False`.
#   + The `sum()` function can iterate over this collection.
#   + Python can interpret `True` as `1` and `False` as `0`.
# Thus, we're counting the number of `True` explored nodes at each time
# step.

# Finally, the function below converts the number of nodes explored to
# percentages and plots those results against the step number. Run the
# first set of calls to `random_explore()` below.


def random_explore(graph_size, num_steps):
    """Run a random walk on a grid graph and plot how much gets explored.

    Args:
        graph_size: a `(rows, cols)` tuple passed to `make_grid_graph()`.
        num_steps: the number of random edge traversals to simulate.

    Returns:
        None. As a side effect, a new matplotlib figure is created that
        shows the percentage of nodes explored after each step. The
        figure is not displayed here; call `plt.show()` afterwards.
    """
    # make a grid graph
    rows, cols = graph_size
    graph = make_grid_graph(rows, cols)

    # set all nodes as being unexplored
    explored = {node: False for node in graph}

    # choose a random start node and mark it as explored
    state = random.choice(list(graph))
    explored[state] = True
    num_states_explored = [1]

    # simulate random walk and track how many states are explored
    for _ in range(num_steps):
        state = random.choice(graph[state])
        explored[state] = True
        # `True` counts as 1 and `False` as 0, so this counts explored nodes
        num_states_explored.append(sum(explored.values()))

    # convert number of states explored to percentages
    num_states = len(graph)
    percentage_explored = [
        100 * count / num_states for count in num_states_explored
    ]

    # the start node is recorded as step 0, so there is one more data
    # point than there are steps
    steps = range(num_steps + 1)

    # plot
    plt.figure()
    plt.plot(steps, percentage_explored)
    plt.title(f"Random walk on a {rows}x{cols} grid graph")
    plt.xlabel("Number of steps")
    plt.ylabel("Percentage of states explored")
    plt.ylim(0, 110)
    plt.grid()


# random_explore((4, 4), num_steps=100)
# random_explore((4, 4), num_steps=200)
# random_explore((4, 4), num_steps=400)
# plt.show()


# It seems that for a 4x4 graph with 16 states, we are reliably covering
# the entire graph well within 100 steps, usually. This doesn't sound so
# bad, but we can note that 100 is already about 6 times the number of
# states. What if we try larger graphs, and run random walks with a
# proportionally scaled number of steps as well (8 times the number of
# states in the calls below)?


# random_explore((5, 5), num_steps=200)
# random_explore((10, 10), num_steps=800)
# random_explore((20, 20), num_steps=3200)
# plt.show()


# You should see that it gets increasingly hard to cover the entire
# graph using a random walk method. This is related to the fact that the
# average total displacement at the end of a random walk scales slower
# than the number of steps. Hence, random walks are not a very efficient
# way to explore graphs. In class tomorrow and next week, we'll discuss
# more sophisticated and systematic methods and their associated
# properties.
