############################################################
# graph and display helpers
############################################################


def neighbors(graph, node):
    """Return the adjacency list stored for `node`, or a new empty list if it has none."""
    if node in graph:
        return graph[node]
    return []


def path_to_string(path):
    """
    Render a path (a list of nodes) as "A-->B-->C".

    A path of None is rendered as the string "None".
    """
    if path is None:
        return "None"
    return "-->".join(map(str, path))


def pathlist_to_string(queue):
    """
    Render a list of paths as a bracketed, comma-separated string.

    `queue` is a list of paths (each a list of nodes, or None). Every path
    is formatted with path_to_string, e.g. [["A", "B"], ["C"]] becomes
    "[A-->B, C]".
    """
    return "[" + ", ".join(path_to_string(path) for path in queue) + "]"


############################################################
# Dijkstra's on graphs with zero- and negative-weight edges
############################################################


# This is the same code as in Lecture 14. The purpose is to show that it
# works equally well on graphs with zero-weight edges, and also in the
# situation where negative-weight edges are limited to edges going out of
# the start node.


def remove_min(queue):
    """
    Remove and return the smallest entry of the list `queue`.

    Entries are compared with the usual ordering, so (cost, node) tuples
    are ordered by cost first and then by node. If several entries tie,
    the earliest one is removed. Raises ValueError on an empty queue.
    """
    smallest = min(queue)
    queue.remove(smallest)
    return smallest


def find_node(queue, node):
    """
    Return the index of the entry for `node` in `queue`, or None if absent.

    `queue` is a list of (cost, node) tuples; the first matching entry wins.
    """
    for idx, (_, queue_node) in enumerate(queue):
        if queue_node == node:
            return idx
    return None


def update_node(queue, node, cost):
    """
    Add `node` to `queue` with `cost`, or lower its cost if already queued.

    `queue` is a list of (cost, node) tuples. A node is only appended when
    find_node reports it absent, so each node keeps a single entry. A trace
    of the decision and the resulting queue is printed.

    Returns:
        True if the queue changed (node added, or its cost lowered);
        False if the existing entry's cost was already <= `cost`.
    """
    idx = find_node(queue, node)

    if idx is None:
        print(f"    Adding {node!r} to queue")
        queue.append((cost, node))
        updated = True

    else:
        # entries are (cost, node); only the stored cost matters here
        old_cost, _ = queue[idx]
        if cost < old_cost:
            print(f"    Updating {node!r} on queue")
            queue[idx] = cost, node
            updated = True
        else:
            print(f"    No change to queue, new cost ({cost}) >= old cost ({old_cost})")
            updated = False

    print(f"    Current queue: {queue}")
    return updated


def trace_predecessors(node, start, predecessors):
    """
    Return the path [start, ..., node] by walking `predecessors` backward.

    `predecessors` maps each node to the node it was reached from. Walking
    stops at `start`. A KeyError is raised if some node on the chain has no
    entry in `predecessors`.
    """
    backward = [node]
    while backward[-1] != start:
        backward.append(predecessors[backward[-1]])
    return backward[::-1]


def dijkstra_predecessors(graph, start, goal):
    """
    Run Dijkstra's algorithm from `start` until `goal` is finished.

    `graph` maps a node to a list of (neighbor, weight) pairs. A trace of
    the queue, the finished set and every edge relaxation is printed.

    Returns:
        (cost, path) with the cost `goal` was finished at and the path
        rebuilt from the recorded predecessors, or None if the queue runs
        out before `goal` is finished.
    """
    frontier = [(0, start)]
    done = set()
    came_from = {start: None}

    while frontier:
        print(f"Current queue: {frontier}")

        dist, current = remove_min(frontier)
        done.add(current)
        print(f"  Finished {current!r} with cost {dist}.")
        print(f"  Finished set: {done}")

        if current == goal:
            return dist, trace_predecessors(current, start, came_from)

        print(f"Expanding to neighbors from {current}")
        for adjacent, edge_weight in neighbors(graph, current):
            # finished nodes are never relaxed again
            if adjacent in done:
                continue
            print(f"  Processing {current!r}-->{adjacent!r} with weight {edge_weight}")
            # only remember the predecessor when the queue entry improved
            if update_node(frontier, adjacent, dist + edge_weight):
                came_from[adjacent] = current

        print()

    return None


# Weighted adjacency lists: node -> [(neighbor, edge weight), ...].
# F and G have no outgoing edges, so they do not appear as keys.
test_graph = dict(
    A=[("B", 2), ("C", 2)],
    B=[("D", 3)],
    C=[("E", 3)],
    D=[("F", 2), ("G", 4), ("H", 2)],
    E=[("H", 0)],  # zero-weight edge; THIS LINE IS DIFFERENT FROM LECTURE
    H=[("G", 1)],
)


# the zero-weight edge E-->H puts H on the same frontier as E
# print(dijkstra_predecessors(test_graph, "A", "G"))

# negative edges going out of the start/source node don't affect
# correctness of Dijkstra's
# test_graph["A"] = [("B", -3), ("C", -2)]
# print(dijkstra_predecessors(test_graph, "A", "G"))

# adding a zero-weight edge B-->C results in a shorter path to C, and
# hence a shorter path to G
# test_graph["B"] += [("C", 0)]
# print(dijkstra_predecessors(test_graph, "A", "G"))


############################################################
# recursive BFS
############################################################


def bfs_helper(graph, start, goal, depth):
    """
    Compute the BFS frontier at a certain depth if the goal hasn't been
    found, or return the frontier where a goal was previously found.

    Args:
        graph: dict mapping a node to a list of neighbor nodes.
        start: node the search begins from.
        goal: node being searched for.
        depth: number of BFS levels to expand. Values <= 0 hit the base
            case, which considers only `start` itself.

    Return a tuple consisting of:
        found (bool): Whether the goal has been found.
        frontier (list): The current frontier if not `found`, or the
            frontier from the depth when `found` first became True. The
            frontier is a list of paths, each a list of nodes beginning
            with `start`. When `found`, the path ending at `goal` is the
            last one in the list.
        visited (set): The set of nodes discovered so far.
    """
    print(f"calling bfs_helper() at {depth = }")

    # base case; `<=` rather than `==` so a negative depth cannot recurse
    # forever
    if depth <= 0:
        print(f"{depth = }: reached base case")
        return start == goal, [[start]], {start}

    # recursive case: first obtain the frontier one level shallower
    found, prev_frontier, visited = bfs_helper(graph, start, goal, depth - 1)

    # pass up solution if already reached goal in earlier frontier
    if found:
        print(f"{depth = }: returning solution from lower depth")
        return found, prev_frontier, visited

    # build up current frontier: extend every path by one edge to a node
    # that has not been discovered yet
    frontier = []
    for path in prev_frontier:
        node = path[-1]
        for next_node in neighbors(graph, node):
            if next_node not in visited:
                visited.add(next_node)
                frontier.append(path + [next_node])
                # stop immediately; the goal's path is the last entry of
                # the partially built frontier
                if next_node == goal:
                    print(f"{depth = }: found goal")
                    return True, frontier, visited

    # did not find goal yet in current frontier
    print(f"{depth = }: goal not yet found")
    return False, frontier, visited


def count_nodes(graph):
    """
    Return the number of distinct nodes mentioned in `graph`.

    Both the dict keys and every entry of every adjacency list are counted.
    The graph is expected to hold plain neighbor nodes. With (node, weight)
    adjacency pairs, the pairs themselves would be counted.
    """
    # keys, unioned with the contents of every adjacency list
    return len(set(graph).union(*graph.values()))


def bfs(graph, start, goal):
    """
    Return a path with the fewest edges from `start` to `goal`, or None.

    The search runs recursive BFS via bfs_helper. It expands up to one level
    fewer than the number of distinct nodes in `graph`, which is the most
    edges a shortest path can need.
    """
    # Clamp at 0: an empty graph would otherwise give a depth of -1, and the
    # recursion never reaches its base case from a negative depth.
    max_depth = max(count_nodes(graph) - 1, 0)
    found, frontier, _ = bfs_helper(graph, start, goal, max_depth)
    if found:
        for path in frontier:
            if path[-1] == goal:
                return path
    return None


# Directory tree: folder -> list of the entries it directly contains.
filesystem = dict(
    home=["school", "personal", "downloads"],
    school=["spring25", "fall25", "spring26"],
    spring26=["6.100", "8.02"],
    personal=["photos", "bills"],
    downloads=["ps1.zip", "ps1", "beach.jpg"],
    ps1=["pset.py", "test.py"],
)

# Directed routes: city -> cities reachable by one direct flight.
flights = dict(
    [
        ("Boston", ["Providence", "New York"]),
        ("Providence", ["Boston"]),
        ("New York", ["Chicago", "Phoenix"]),
        ("Chicago", ["Providence", "Denver"]),
        ("Phoenix", ["Chicago"]),
        ("Denver", ["New York", "Phoenix"]),
    ]
)


# print()
# print(bfs(filesystem, "home", "pset.py"))
# print()
# print(bfs(flights, "Boston", "Phoenix"))
# print()
# print(bfs(flights, "Boston", "Los Angeles"))


############################################################
# iterative deepening
############################################################


def dfs_helper(graph, start, goal, depth, visited, best_depth=None):
    """
    Depth-limited DFS from `start`; return a path to `goal` or None.

    Args:
        graph: dict mapping a node to a list of neighbor nodes.
        start: node to search from.
        goal: node to search for.
        depth: maximum number of edges the returned path may use
            (float("inf") for an unlimited search). Values <= 0 only
            check `start` itself.
        visited: set updated in place with every node explored. Nodes that
            are already in it when the top-level call begins are never
            entered.
        best_depth: internal bookkeeping shared by the recursion. It maps
            each explored node to the largest remaining depth it has been
            explored with. Leave as None when calling.

    Returns:
        A list of nodes from `start` to `goal` using at most `depth` edges,
        or None if there is no such path.

    A node is re-entered only when it is reached with strictly more
    remaining depth than before. With a plain "skip if visited" rule, a
    node first reached by a long path (with little depth left) would block
    a shorter route through it. After a full search, `visited` holds
    exactly the nodes within `depth` edges of `start`.
    """
    print(f"  calling dfs_helper() on node {start}")
    if best_depth is None:
        # top-level call: nodes the caller pre-marked are treated as fully
        # explored, so they are never entered (matches a plain visited check)
        best_depth = dict.fromkeys(visited, float("inf"))
    visited.add(start)
    best_depth[start] = depth

    # original base case
    if start == goal:
        return [start]

    # another base case: terminate if reach our allowed depth
    if depth <= 0:
        return None

    # recursive case: run dfs on each neighbor with decremented depth
    for next_node in neighbors(graph, start):
        # skip neighbors already explored with at least as much remaining
        # depth; exploring them again would reach nothing new
        if best_depth.get(next_node, -1) >= depth - 1:
            continue
        path_to_goal = dfs_helper(
            graph, next_node, goal, depth - 1, visited, best_depth
        )
        if path_to_goal is not None:
            return [start] + path_to_goal

    return None


def dfs(graph, start, goal, depth=float("inf")):
    """
    Run dfs_helper from `start` with a fresh visited set.

    Returns:
        (path, visited): the path dfs_helper found (or None), and the set of
        nodes it explored.
    """
    explored = set()
    found_path = dfs_helper(graph, start, goal, depth, explored)
    return found_path, explored


def iterative_deepening(graph, start, goal):
    """
    Search for `goal` with depth-limited DFS at depths 0, 1, 2, ...

    Returns:
        The first path found, or None once a search at some depth explores
        exactly the same set of nodes as the search one level shallower.
    """
    # `depth` keeps its name because the trace prints it via `{depth = }`
    depth = 0
    reached_before = set()
    while True:
        print()
        print(f"running dfs() with {depth = }")
        path, reached = dfs(graph, start, goal, depth)

        # a deeper search that explores no new nodes ends the loop
        if reached == reached_before:
            return None

        if path is not None:
            return path

        reached_before = reached
        depth += 1


# print()
# print(iterative_deepening(filesystem, "home", "pset.py"))
# print()
# print(iterative_deepening(flights, "Boston", "Phoenix"))
# print()
# print(iterative_deepening(flights, "Boston", "Los Angeles"))
