import matplotlib.pyplot as plt
import time


############################################################
# list vs dict operation runtimes
############################################################


def list_create(size):
    """
    Build and return the list [0, 1, ..., size - 1].

    The list is grown one append at a time on purpose: this function is
    the operation timed by the list creation benchmark.
    """
    items = []
    for value in range(size):
        items.append(value)
    return items


def dict_create(size):
    """
    Build and return the dict {0: 0, 1: 1, ..., size - 1: size - 1}.

    Entries are added one assignment at a time, mirroring how
    list_create() grows its list.
    """
    mapping = {}
    for key in range(size):
        mapping[key] = key
    return mapping


def list_insert(container):
    """
    Insert "ITEM" into container 100 times, mutating it in place.

    The insertion index is half of the list's length on entry; it is
    computed once and is not updated as the list grows.
    """
    midpoint = len(container) // 2
    # Other positions to experiment with:
    #   container.insert(0, "ITEM")    -> at the front
    #   container.insert(-1, "ITEM")   -> just before the last item
    for _ in range(100):
        container.insert(midpoint, "ITEM")


def dict_insert(container):
    """
    Add the 100 keys "ITEM0" ... "ITEM99" to container in place, each
    mapped to the value "ITEM".
    """
    for index in range(100):
        key = f"ITEM{index}"
        container[key] = "ITEM"


def list_delete(container):
    """
    Remove 100 items from container in place, always popping at the
    same index.

    The index is a quarter of the list's length on entry and is computed
    once. list.pop raises IndexError if the list becomes too short for
    that index before all 100 pops are done.
    """
    quarter = len(container) // 4
    # Other positions to experiment with:
    #   container.pop(0)    -> from the front
    #   container.pop(-1)   -> from the back
    for _ in range(100):
        container.pop(quarter)


def dict_delete(container):
    """
    Remove 100 consecutive integer keys from container in place.

    The first key removed is a quarter of the dict's length on entry.
    dict.pop is called without a default, so a missing key raises
    KeyError.
    """
    first_key = len(container) // 4
    for key in range(first_key, first_key + 100):
        container.pop(key)


def sample_creation_times(obj_type, sizes, num_trials):
    """
    Measure the average time to create a container of each given size.

    Parameters:
        obj_type: the built-in type to benchmark, either list or dict.
        sizes: iterable of container sizes to create.
        num_trials: number of timed creations per size (at least 1).

    Returns a list with one average runtime in seconds per entry of
    sizes, in the same order.

    Raises ValueError if obj_type is not list or dict, or if num_trials
    is less than 1.
    """
    # validate up front so bad arguments fail with a clear message
    # instead of a ZeroDivisionError or an unbound local variable
    if num_trials < 1:
        raise ValueError(f"num_trials must be at least 1, got {num_trials!r}")

    if obj_type == list:
        operation = list_create
    elif obj_type == dict:
        operation = dict_create
    else:
        raise ValueError(
            f"unsupported obj_type {obj_type!r}; expected list or dict"
        )

    averages = []
    for size in sizes:
        times = []
        for _ in range(num_trials):
            # perf_counter() is the highest-resolution clock available
            # and, unlike time.time(), never jumps with system clock
            # adjustments
            start = time.perf_counter()
            operation(size)
            end = time.perf_counter()
            times.append(end - start)
        averages.append(sum(times) / len(times))
    return averages


def compare_list_dict_creation():
    """
    Plot average creation time against container size for both lists
    and dicts, on linear axes.

    Each size is sampled with 100 trials per container type. Opens a
    matplotlib figure and blocks in plt.show() until it is closed.
    """
    sizes = [200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200, 102400]
    list_creation_times = sample_creation_times(list, sizes, num_trials=100)
    # sample dicts too, so the plot really compares the two containers
    dict_creation_times = sample_creation_times(dict, sizes, num_trials=100)

    plt.figure()
    plt.plot(sizes, list_creation_times, marker="o", label="list")
    plt.plot(sizes, dict_creation_times, marker="o", label="dict")
    plt.title("Container Creation Times")
    plt.xlabel("Container size")
    plt.ylabel("Average runtime (sec)")
    plt.legend()
    plt.grid()
    plt.show()


# compare_list_dict_creation()


def sample_mutation_times(obj_type, op, sizes, num_trials):
    """
    Measure the average time to mutate a container of each given size.

    Parameters:
        obj_type: the built-in type to benchmark, either list or dict.
        op: which mutation to time, either "insert" or "delete".
        sizes: iterable of initial container sizes.
        num_trials: number of timed mutations per size (at least 1).

    A fresh container is created for every trial, and only the mutation
    call is timed, not the creation.

    Returns a list with one average runtime in seconds per entry of
    sizes, in the same order.

    Raises ValueError if obj_type is not list or dict, if op is not
    "insert" or "delete", or if num_trials is less than 1.
    """
    # validate up front so bad arguments fail with a clear message
    # instead of a ZeroDivisionError or an unbound local variable
    if obj_type not in (list, dict):
        raise ValueError(
            f"unsupported obj_type {obj_type!r}; expected list or dict"
        )
    if op not in ("insert", "delete"):
        raise ValueError(
            f"unsupported op {op!r}; expected 'insert' or 'delete'"
        )
    if num_trials < 1:
        raise ValueError(f"num_trials must be at least 1, got {num_trials!r}")

    if obj_type == list:
        create = list_create
        mutate = list_insert if op == "insert" else list_delete
    else:
        create = dict_create
        mutate = dict_insert if op == "insert" else dict_delete

    averages = []
    for size in sizes:
        times = []
        for _ in range(num_trials):
            # build the container outside the timed region
            container = create(size)
            # perf_counter() is the highest-resolution clock available
            # and, unlike time.time(), never jumps with system clock
            # adjustments
            start = time.perf_counter()
            mutate(container)
            end = time.perf_counter()
            times.append(end - start)
        averages.append(sum(times) / len(times))
    return averages


def compare_list_dict_insertion():
    """
    Plot average insertion time against container size for lists and
    dicts, on log-log axes.

    Each size is sampled with 100 trials per container type. Opens a
    matplotlib figure and blocks in plt.show() until it is closed.
    """
    # 200, 400, ..., 102400: ten sizes, doubling each step
    sizes = [200 * 2**k for k in range(10)]

    # sample both container types before any plotting starts
    series = [
        ("list", sample_mutation_times(list, "insert", sizes, num_trials=100)),
        ("dict", sample_mutation_times(dict, "insert", sizes, num_trials=100)),
    ]

    plt.figure()
    for label, runtimes in series:
        plt.loglog(sizes, runtimes, marker="o", label=label)
    plt.title("Container Insertion Times")
    plt.xlabel("Container size")
    plt.ylabel("Average runtime (sec)")
    plt.legend()
    plt.grid()
    plt.show()


def compare_list_dict_deletion():
    """
    Plot average deletion time against container size for lists and
    dicts, on log-log axes.

    Each size is sampled with 100 trials per container type. Opens a
    matplotlib figure and blocks in plt.show() until it is closed.
    """
    # 200, 400, ..., 102400: ten sizes, doubling each step
    sizes = [200 * 2**k for k in range(10)]

    # sample both container types before any plotting starts
    series = [
        ("list", sample_mutation_times(list, "delete", sizes, num_trials=100)),
        ("dict", sample_mutation_times(dict, "delete", sizes, num_trials=100)),
    ]

    plt.figure()
    for label, runtimes in series:
        plt.loglog(sizes, runtimes, marker="o", label=label)
    plt.title("Container Deletion Times")
    plt.xlabel("Container size")
    plt.ylabel("Average runtime (sec)")
    plt.legend()
    plt.grid()
    plt.show()


# compare_list_dict_insertion()
# compare_list_dict_deletion()


############################################################
# sets: syntax and operations
############################################################


def demo_sets():
    """
    Print two sample sets, then the results of the set operators
    | (union), & (intersection), - (difference) and < (proper subset).
    """
    fruits = {"apple", "watermelon", "tomato", "squash"}
    vegetables = {"carrot", "squash", "broccoli", "tomato"}
    print(fruits)
    print(vegetables)

    # items that are in either set
    in_either = fruits | vegetables
    print(in_either)

    # items that are in both sets
    in_both = fruits & vegetables
    print(in_both)

    # items that are in fruits but not in vegetables
    only_fruits = fruits - vegetables
    print(only_fruits)

    # proper subset test
    is_proper_subset = {"carrot", "squash"} < vegetables
    print(is_proper_subset)


# print()
# demo_sets()


def get_primes():
    """
    Return the set of all primes up to 100, using the Sieve of
    Eratosthenes method.
    """
    upper = 100

    # 1 is not prime, so the candidates start at 2
    candidates = set(range(2, upper + 1))

    # every composite number up to 100 has a prime factor of at most
    # 10, so sieving with 2, 3, 5 and 7 is enough
    for n in [2, 3, 5, 7]:
        # remove the proper multiples of n only; starting at 2 * n
        # keeps the prime n itself in the result
        candidates -= set(range(2 * n, upper + 1, n))
    return candidates


# print()
# print(get_primes())


# EXERCISE: Generalize get_primes() to accept an upper bound, and find
# all primes up to that bound.


############################################################
# graph and display helpers
############################################################


def neighbors(graph, node):
    """
    Return the neighbors stored for node in the adjacency mapping
    graph, or an empty list if node has no entry.
    """
    if node in graph:
        return graph[node]
    return []


def path_to_string(path):
    """
    Convert a list of nodes to a str representation, with the nodes
    joined by " --> ". A path of None becomes "<empty path>".
    """
    if path is None:
        return "<empty path>"
    return " --> ".join(map(str, path))


def pathlist_to_string(pathlist):
    """
    Convert a list of paths to a str representation: each path is
    rendered with path_to_string(), the results are separated by ", "
    and the whole is wrapped in square brackets. Used to inspect
    intermediate states of graph search.
    """
    rendered = ", ".join(path_to_string(path) for path in pathlist)
    return f"[{rendered}]"


############################################################
# finding paths in trees using breadth-first search
############################################################


def bfs_tree(graph, start, goal):
    """
    Breadth-first search for goal, starting from start, printing a
    trace of every frontier, node and neighbor considered.

    Returns True as soon as goal is found (including when start is
    goal), or False once no frontier is left to expand.

    No record of visited nodes is kept, so every neighbor is expanded
    each time it is seen.
    """
    if start == goal:
        return True

    frontier = [start]

    # one iteration per frontier; stops after the last reachable one
    while frontier:
        print(f"Current frontier: {frontier}")

        # nodes found while expanding this frontier form the next one
        upcoming = []
        for node in frontier:
            print(f"  Current BFS node: {node}")

            for child in neighbors(graph, node):
                print(f"    Considering next node: {child}")

                # done as soon as the goal shows up
                if child == goal:
                    return True

                upcoming.append(child)

        frontier = upcoming

    # ran out of frontiers without ever seeing the goal
    return False


def demo_bfs_tree():
    """
    Run bfs_tree() on a small filesystem tree for two goals, then move
    the "downloads" folder to the front of "home" and run the same two
    searches again. Each result is printed, followed by a blank line.
    """
    filesystem = {
        "home": ["school", "personal", "downloads"],
        "school": ["spring25", "fall25", "spring26"],
        "spring26": ["6.100", "8.02"],
        "personal": ["photos", "bills"],
        "downloads": ["ps1.zip", "ps1", "beach.jpg"],
        "ps1": ["pset.py", "test.py"],
    }
    goals = ("pset.py", "beach.jpg")

    for goal in goals:
        print(bfs_tree(filesystem, "home", goal))
        print()

    # make "downloads" first child of "home": it is currently the last
    # child, so pop it off the end and re-insert it at the front
    children_of_home = filesystem["home"]
    last_child = children_of_home.pop()
    children_of_home.insert(0, last_child)

    # doesn't change number of frontiers explored, but explore less of
    # last frontier
    for goal in goals:
        print(bfs_tree(filesystem, "home", goal))
        print()

    # QUESTION: How does bfs_tree() behave when it's called with a goal
    # that is not in the filesystem tree, e.g., goal="ps7"?


# demo_bfs_tree()


def bfs_tree_path(graph, start, goal):
    """
    Breadth-first search that returns the PATH from start to goal,
    printing a trace of every frontier, path and neighbor considered.

    Returns a list of nodes beginning with start and ending with goal
    ([start] when start is goal), or None once no frontier is left to
    expand.

    No record of visited nodes is kept, so every neighbor is expanded
    each time it is seen.
    """
    # trivial case: the path is just the start node
    if start == goal:
        return [start]

    # each frontier holds PATHS (lists of nodes), not bare nodes
    frontier = [[start]]

    while frontier:
        print(f"Current frontier: {pathlist_to_string(frontier)}")

        # extended PATHS found while expanding this frontier
        upcoming = []
        for path in frontier:
            print(f"  Current BFS path: {path_to_string(path)}")

            # only the last node of a path gets expanded; every
            # neighbor yields a NEW PATH, one node longer
            tail = path[-1]
            for child in neighbors(graph, tail):
                print(f"    Considering next node: {child}")
                extended = [*path, child]

                # reached the goal: this PATH is the answer
                if child == goal:
                    return extended

                upcoming.append(extended)

        frontier = upcoming

    return None


# bfs_tree = bfs_tree_path
# demo_bfs_tree()


############################################################
# finding paths in general graphs using breadth-first search
############################################################


def bfs_graph(graph, start, goal):
    """
    Breadth-first search for a path from start to goal in a general
    graph, printing a trace of every frontier, path and neighbor
    considered.

    A VISITED SET records every node already seen, and a seen node is
    never added to a path again.

    Returns a list of nodes beginning with start and ending with goal
    ([start] when start is goal), or None once no frontier is left to
    expand.
    """
    if start == goal:
        return [start]

    frontier = [[start]]
    # nodes ALREADY SEEN; the start node counts as seen from the outset
    seen = {start}

    while frontier:
        print(f"Current frontier: {pathlist_to_string(frontier)}")

        upcoming = []
        for path in frontier:
            print(f"  Current BFS path: {path_to_string(path)}")

            tail = path[-1]
            for child in neighbors(graph, tail):
                print(f"    Considering next node: {child}")

                if child in seen:
                    # never expand to a node ALREADY SEEN
                    print(f"      AVOID revisiting {child}")
                else:
                    # first sighting: RECORD AS VISITED, then extend
                    seen.add(child)

                    extended = [*path, child]
                    if child == goal:
                        return extended
                    upcoming.append(extended)

        frontier = upcoming

    return None


def demo_bfs_graph():
    """
    Run bfs_graph() from "Boston" to "Phoenix" on a small flight map
    three times: as given, after reversing Chicago's neighbor order,
    and after adding a Providence-to-Phoenix flight. Each result is
    printed, followed by a blank line.
    """
    flights = {
        "Boston": ["Providence", "New York"],
        "Providence": ["Boston", "New York"],
        "New York": ["Chicago"],
        "Chicago": ["Denver", "Phoenix"],
        "Denver": ["New York", "Phoenix"],
        "Los Angeles": ["Boston"],
    }

    def search_and_report():
        # same search every time; only the flight map changes
        print(bfs_graph(flights, "Boston", "Phoenix"))
        print()

    search_and_report()

    # change the ordering in which we explore Chicago's neighbors;
    # shortest path solution does not change
    flights["Chicago"].reverse()
    search_and_report()

    # add new flight from Providence to Phoenix; find a shorter path
    # than before
    flights["Providence"].append("Phoenix")
    search_and_report()


# demo_bfs_graph()
