import heapq
import matplotlib
import matplotlib.pyplot as plt
from math import radians, sin, cos, sqrt, atan2


############################################################
# Classes Example: Method Resolution Order
############################################################

class A:
    """Root of the MRO diamond: seeds ``value`` with 'A'."""

    def __init__(self):
        self.value = 'A'

    def show(self):
        """Return the accumulated ``value`` string unchanged."""
        return self.value


class B(A):
    """Appends 'B' after the next ``__init__`` in the MRO has run."""

    def __init__(self):
        super().__init__()
        self.value = self.value + 'B'

class C(A):
    """Appends 'C' to ``value`` and overrides ``show`` to add a trailing '!'."""

    def __init__(self):
        super().__init__()
        self.value = self.value + 'C'

    def show(self):
        """Return the accumulated value followed by '!'."""
        return self.value + '!'

class D(B, C):
    """Bottom of the diamond.

    The MRO is D -> B -> C -> A, so the cooperative ``super().__init__()``
    chain produces ``value == 'ACBD'``.
    """

    def __init__(self):
        super().__init__()
        self.value = self.value + 'D'

# D's MRO is D -> B -> C -> A: the super().__init__() chain builds "ACBD",
# and show() resolves to C.show, so the print below would output "ACBD!".
obj = D()
# print(obj.show())

############################################################
# Classes Frame Example
############################################################

class A:
    """Base class for the call-frame demo; stores ``value = 'A'``."""

    def __init__(self):
        self.value = 'A'

    def method1(self, msg):
        """Print ``msg`` to stdout."""
        print(msg)


class B(A):
    """Overrides ``value`` with 'B'; ``method1`` prints a fixed greeting three times."""

    def __init__(self):
        super().__init__()
        self.value = 'B'

    def method1(self, msg=None):
        # ``msg`` is ignored: each of the three calls delegates to the next
        # method1 in the MRO with the same fixed text.
        greeting = "Hello from B"
        for _ in range(3):
            super().method1(greeting)


class C(B):
    """Overrides ``value`` with 'C' and adds a ``cost`` attribute."""

    def __init__(self):
        super().__init__()
        self.value = 'C'
        self.cost = 100

    def method1(self, msg=None):
        # Forwards to B.method1, which ignores the text it receives.
        super().method1("Hello from C")

# Paste this example into Python Tutor (pythontutor.com) to step through the call frames.
# obj = C()
# obj.method1("Hello from main")



############################################################
# EXCEPTIONS
############################################################

class DataError(Exception):
    """Raised when input file content is empty or cannot be parsed."""

def prepare_filename(filename):
    """Normalize a filename by stripping surrounding whitespace and lower-casing it.

    Prints "Filename prepared" when normalization succeeds.

    Raises:
        AttributeError: if ``filename`` lacks ``strip``/``lower`` (e.g. None),
            re-raised with a clearer message and the original error chained.
    """
    try:
        # Guard only the operation that can fail.
        new_filename = filename.strip().lower()
    except AttributeError as err:
        raise AttributeError("Invalid filename format") from err
    else:
        # Runs only when the try body raised nothing. Returning here (not
        # inside ``try``) is what makes this success branch reachable.
        print("Filename prepared")
        return new_filename

def read_and_parse(filename):
    """Open ``filename`` (normalized by prepare_filename) and parse it as an int.

    Raises:
        DataError: "File is empty" when the content is blank, or
            "Parsing failed" when a ValueError occurs in the try body
            (normally from ``int()``).
        Other exceptions (e.g. from ``open`` or prepare_filename) propagate.

    The ``finally`` clause always prints a message and closes the file if
    it was opened.
    """
    handle = None
    try:
        handle = open(prepare_filename(filename), "r")
        content = handle.read()
        if not content.strip():
            raise DataError("File is empty")
        return int(content.strip())
    except ValueError:
        print("Inner except: Invalid number format.")
        raise DataError("Parsing failed")
    finally:
        print("Inner finally: Closing file.")
        if handle:
            handle.close()

def process_data():
    """Parse "missing.txt" via read_and_parse and report the outcome.

    Returns the parsed number on success. On any exception, prints a message
    and returns None. "Cleanup complete." is printed in every case.
    """
    try:
        value = read_and_parse("missing.txt")
        print("Processed number:", value + 10)
        return value
    except DataError as err:
        print("Outer except (DataError):", err)
    except Exception:
        print("Some error happened")
    finally:
        print("Cleanup complete.")
    return None

def app():
    """Top-level driver: run process_data and report whether it raised."""
    try:
        outcome = process_data()
        print("App result:", outcome)
    except Exception:
        print("App error")
    else:
        # Only reached when neither process_data nor the print raised.
        print("All went fine!")

# app()

########################################
# Recursion and object frames
########################################

class Node:
    """Tree node holding a ``value`` and a list of child nodes."""

    def __init__(self, name, value, children=None):
        self.name = name
        self.value = value
        # A fresh list per instance when no (or empty) children are given.
        self.children = children or []

    def total_value(self, max_depth):
        """Return this node's value plus its descendants' values, ``max_depth`` levels deep.

        When ``max_depth`` is not positive, only this node's own value is
        returned and a notice is printed.
        """
        if not max_depth > 0:
            print("Max depth reached - skipping anything below")
            return self.value
        running = self.value
        for child in self.children:
            running += child.total_value(max_depth - 1)
        return running

# Sample tree: root(1) -> A(5) -> {A1(2), A2(3)}, and root -> B(10).
a1, a2 = Node("A1", 2), Node("A2", 3)
a = Node("A", 5, [a1, a2])
b = Node("B", 10)
root = Node("root", 1, [a, b])

# print(root.total_value(2))  # whole tree: 1 + 5 + 2 + 3 + 10 = 21

########################################
# Asserts
########################################

def average_per_subject(records):
    """Return the mean grade for each subject across all records.

    Each record is ``[name, grades, subjects]`` where ``grades[i]`` is the
    grade obtained in ``subjects[i]``.

    Returns:
        dict mapping subject -> average grade (empty dict for no records).

    Raises:
        AssertionError: if a record's grades and subjects differ in length or
            it has no grades. These are ``assert`` checks, so they are skipped
            under ``python -O``.
    """
    subject_totals = {}
    subject_counts = {}
    for _, grades, subjects in records:
        assert len(grades) == len(subjects), "They should be the same length"
        assert len(grades) > 0, "There should be at least one subject with a grade"
        for subject, grade in zip(subjects, grades):
            subject_totals[subject] = subject_totals.get(subject, 0) + grade
            subject_counts[subject] = subject_counts.get(subject, 0) + 1

    return {
        subject: subject_totals[subject] / subject_counts[subject]
        for subject in subject_totals
    }


def average_performance(records):
    """Return the mean of the per-record grade averages.

    Each record is ``[name, grades, subjects]``. Every record's own average is
    computed first, and those averages are then averaged with equal weight.

    Returns:
        float average, or 0.0 when no record contributes.

    Raises:
        AssertionError: if a record's grades and subjects differ in length or
            it has no grades (skipped under ``python -O``).
    """
    total_sum = 0
    total_count = 0
    for _, grades, subjects in records:
        assert len(grades) == len(subjects), "They should be the same length"
        assert len(grades) > 0, "There should be at least one subject with a grade"
        # Guard on ``grades`` (the divisor), not ``subjects``. When asserts are
        # stripped by -O, a record with subjects but no grades is then skipped
        # instead of raising ZeroDivisionError.
        if grades:
            total_sum += sum(grades) / len(grades)
            total_count += 1

    return total_sum / total_count if total_count else 0.0


# Sample records: [name parts, grades, subjects]; grades[i] belongs to subjects[i].
# The last record has three subjects but no grades, so it fails the
# "same length" assertion in average_per_subject / average_performance.
grades = [
    [['peter', 'parker'], [10.0, 5.0, 8.5], ['math', 'science', 'english']],
    [['bruce', 'wayne'], [10.0, 8.0, 7.4], ['math', 'science', 'english']],
    [['pan', 'peter'], [], ['math', 'science', 'english']]
]

# print("\nAverage per subject:")
# for subject, avg in average_per_subject(grades).items():
#     print(f"  {subject}: {avg:.2f}")

# print(f"\nOverall total average: {average_performance(grades):.2f}")



########################################
# Representing graphs
########################################

class Node:
    """A generic graph node identified solely by its ``name`` string.

    Equality, ordering and hashing all go through ``name``. As a result, a
    Node compares equal to (and hashes like) the plain string of its name, so
    dicts keyed by Node can also be queried with the name string.
    """

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        # Unambiguous form for debugging; subclasses report their own type name.
        return f"{type(self).__name__}({self.name!r})"

    def __eq__(self, other):
        # str(other) makes Node("A") == "A" and Node("A") == Node("A") both True.
        return self.name == str(other)

    def __lt__(self, other):
        # Lets heap entries that contain Nodes (e.g. equal-cost paths) be ordered.
        return self.name < str(other)

    def __hash__(self):
        # Must agree with __eq__: equal names give equal hashes.
        return hash(self.name)


class MapNode(Node):
    """A Node that also carries a 2-tuple of coordinates.

    The city data in this file passes ``(longitude, latitude)``, e.g.
    ``(-71.0589, 42.3601)``.

    Raises:
        ValueError: if ``coords`` is not a tuple of length 2.
    """

    def __init__(self, name, coords):
        super().__init__(name)
        is_pair = isinstance(coords, tuple) and len(coords) == 2
        if not is_pair:
            raise ValueError("coords must be a tuple (x, y)")
        self.coords = coords


class SimpleDigraph:
    """Minimal weighted directed graph; nodes are used directly as dict keys.

    Unlike Digraph, it performs no type or duplicate checks on nodes.
    """

    def __init__(self, nodes=()):
        self._edges = {}  # dict: node -> dict(node -> weight)
        for node in nodes:
            self.add_node(node)

    def add_node(self, node):
        # Re-adding an existing node resets its outgoing edges to empty.
        self._edges[node] = {}

    def get_node(self, id):
        """Return the graph's stored key equal to ``id``, or None if absent.

        The stored key (not ``id`` itself) is returned, so a lookup by an
        equal stand-in, such as a Node's name string, yields the stored Node.
        """
        if id not in self._edges:
            return None
        return next((node for node in self._edges if node is id or node == id), id)

    def add_edge(self, src, dest, weight=1):
        """Add (or overwrite) the directed edge ``src -> dest``.

        Raises:
            KeyError: if either endpoint is not in the graph. Both endpoints
                are checked up front so an edge to an unknown ``dest`` can't
                be stored under a ``None`` key.
        """
        for endpoint in (src, dest):
            if endpoint not in self._edges:
                raise KeyError(f"Unknown node: {endpoint!r}")
        self._edges[self.get_node(src)][self.get_node(dest)] = weight

    def get_all_nodes(self):
        return list(self._edges.keys())

    def outgoing_edges_of(self, node):
        """Return a shallow copy of ``node``'s {neighbor: weight} map."""
        return self._edges[node].copy()

    def children_of(self, node):
        return list(self._edges[node].keys())


class Digraph:
    """Weighted directed graph whose keys are Node objects.

    Adjacency is stored as ``{source Node: {destination Node: weight}}``.
    Endpoints passed to ``add_edge``/``get_node`` may be Nodes or name strings.
    """

    def __init__(self, nodes=()):
        # Adjacency map: Node -> {Node: weight}
        self._edges = {}
        for node in nodes:
            self.add_node(node)

    def add_node(self, node):
        """Register ``node`` with no outgoing edges.

        Raises:
            TypeError: if ``node`` is not a Node.
            ValueError: if an equal node is already present.
        """
        if not isinstance(node, Node):
            raise TypeError("add_node expects a Node instance")
        if node in self._edges:
            raise ValueError(f"Duplicate node: {node.name}")
        self._edges[node] = {}

    def get_node(self, id):
        """Resolve a node name (or a Node) to the Node stored in this graph.

        Raises:
            TypeError: if ``id`` is neither a str nor a Node.
            ValueError: if no stored node has a matching name.
        """
        if not isinstance(id, (str, Node)):
            raise TypeError(f"Expected Node or str, got {type(id).__name__}")
        found = next((n for n in self._edges if n.name == id), None)
        if found is None:
            raise ValueError(f"Unknown node name: '{id}'")
        return found

    def add_edge(self, src, dest, weight=1):
        """Add (or overwrite) the directed edge ``src -> dest``; names are accepted."""
        source, target = self.get_node(src), self.get_node(dest)
        self._edges[source][target] = weight

    def get_all_nodes(self):
        """Return a new list of every node, in insertion order."""
        return [*self._edges]

    def outgoing_edges_of(self, node):
        """Return a shallow copy of ``node``'s {neighbor: weight} map."""
        return dict(self._edges[node])

    def children_of(self, node):
        """Return a list of ``node``'s direct successors."""
        return [*self._edges[node]]

    def __str__(self):
        # One line per source node, sorted by name: "src: dst(w), dst(w)".
        lines = []
        for src in sorted(self._edges, key=lambda n: n.name):
            described = ", ".join(
                f"{dst.name}({w})" for dst, w in self._edges[src].items()
            )
            lines.append(f"{src}: {described}")
        return "\n".join(lines)


class Graph(Digraph):
    """Undirected weighted graph: every edge is stored in both directions."""

    def add_edge(self, node1, node2, weight=1):
        """Connect ``node1`` and ``node2`` symmetrically with the same weight."""
        for a, b in ((node1, node2), (node2, node1)):
            super().add_edge(a, b, weight)


class MapGraph(Graph):
    """Undirected graph of MapNode instances with geographic helpers.

    ``haversine_distance`` reads node ``coords`` as ``(longitude, latitude)``
    in degrees. Edge weights added via ``add_edge`` default to the Euclidean
    distance between the raw coordinate pairs (see ``distance``). That is in
    coordinate units (degrees for lon/lat data), not kilometres.
    """

    EARTH_RADIUS_M = 6_371_000  # mean Earth radius in meters

    def __init__(self, nodes=(), max_speed_kmh=100):
        super().__init__(nodes)
        # NOTE(review): stored but not read by any method of this class.
        self.max_speed_kmh = max_speed_kmh

    def add_node(self, node):
        """Add ``node``; raises TypeError unless it is a MapNode."""
        if not isinstance(node, MapNode):
            raise TypeError("MapGraph expects MapNode instances")
        super().add_node(node)

    def _resolve_mapnode(self, node):
        """Resolve a name or node via get_node and ensure the result is a MapNode."""
        resolved = self.get_node(node)
        if not isinstance(resolved, MapNode):
            raise TypeError("MapGraph operations require MapNode instances but got: " + str(type(resolved)))
        return resolved

    def _coords(self, node):
        return self._resolve_mapnode(node).coords

    def haversine_distance(self, node1, node2):
        """Return the great-circle distance between two nodes in meters.

        Uses the haversine formula with ``EARTH_RADIUS_M``; coords are read as
        ``(longitude, latitude)`` in degrees.
        """
        lon1, lat1 = self._coords(node1)
        lon2, lat2 = self._coords(node2)
        phi1, phi2 = radians(lat1), radians(lat2)
        dphi = radians(lat2 - lat1)
        dlambda = radians(lon2 - lon1)
        a = sin(dphi / 2)**2 + cos(phi1) * cos(phi2) * sin(dlambda / 2)**2
        c = 2 * atan2(sqrt(a), sqrt(1 - a))
        return MapGraph.EARTH_RADIUS_M * c

    def distance(self, node1, node2):
        """Return the Euclidean distance between two nodes in raw coordinate space."""
        x1, y1 = self._coords(node1)
        x2, y2 = self._coords(node2)
        return sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

    def dist_great_circle(self, node1, node2):
        """Return the great-circle distance between nodes in kilometers."""
        return self.haversine_distance(node1, node2) / 1000

    def add_edge(self, node1, node2, weight=None):
        """Add an undirected edge between two MapNodes (given as nodes or names).

        Args:
            weight: explicit edge weight. When None (the default), the weight
                is ``self.distance(node1, node2)``, the Euclidean distance in
                coordinate units, not great-circle kilometres.
        """
        node1_resolved = self._resolve_mapnode(node1)
        node2_resolved = self._resolve_mapnode(node2)
        if weight is None:
            weight = self.distance(node1_resolved, node2_resolved)
        super().add_edge(node1_resolved, node2_resolved, weight)


########################################
# Example graphs
########################################


def build_city_nodes_graph():
    """Build the demo MapGraph of 26 US cities connected by undirected edges.

    Node coords are (longitude, latitude). Edge weights come from
    MapGraph.add_edge's default, which computes them from the endpoints'
    coordinates.
    """
    city_coords = {
        "Boston": (-71.0589, 42.3601),
        "Providence": (-71.4128, 41.8240),
        "New York": (-74.0060, 40.7128),
        "Chicago": (-87.6298, 41.8781),
        "Denver": (-104.9903, 39.7392),
        "Pittsburgh": (-79.9959, 40.4406),
        "Salt Lake City": (-111.8910, 40.7608),
        "San Francisco": (-122.4194, 37.7749),
        "Houston": (-95.3698, 29.7604),
        "Bozeman": (-111.0429, 45.6770),
        "Seattle": (-122.3321, 47.6062),
        "Minneapolis": (-93.2650, 44.9778),
        "Los Angeles": (-118.2437, 34.0522),
        "Atlanta": (-84.3880, 33.7490),
        "Miami": (-80.1918, 25.7617),
        "Philadelphia": (-75.1652, 39.9526),
        "Phoenix": (-112.0740, 33.4484),
        "San Diego": (-117.1611, 32.7157),
        "Dallas": (-96.7970, 32.7767),
        "Portland": (-122.6587, 45.5122),
        "Las Vegas": (-115.1398, 36.1699),
        "Austin": (-97.7431, 30.2672),
        "Nashville": (-86.7816, 36.1627),
        "Indianapolis": (-86.1581, 39.7684),
        "Charlotte": (-80.8431, 35.2271),
        "Cleveland": (-81.6944, 41.4993)
    }

    # One MapNode per city; the graph starts with no edges.
    g = MapGraph(MapNode(name, coords) for name, coords in city_coords.items())

    # Edges listed per source city, in the exact order they are added; that
    # order becomes each node's adjacency iteration order. Providence->Boston
    # repeats an earlier undirected edge and simply rewrites the same weight.
    neighbors = {
        'Boston': ['Cleveland', 'Providence', 'New York'],
        'Cleveland': ['Minneapolis', 'Chicago', 'Pittsburgh'],
        'Providence': ['Boston', 'New York'],
        'New York': ['Cleveland', 'Pittsburgh', 'Philadelphia'],
        'Philadelphia': ['Indianapolis', 'Charlotte'],
        'Pittsburgh': ['Indianapolis'],
        'Chicago': ['Minneapolis', 'Denver', 'Indianapolis'],
        'Charlotte': ['Indianapolis', 'Nashville', 'Atlanta'],
        'Indianapolis': ['Denver', 'Nashville'],
        'Minneapolis': ['Bozeman'],
        'Atlanta': ['Dallas', 'Miami', 'Houston'],
        'Miami': ['Houston'],
        'Dallas': ['Denver', 'Las Vegas', 'Phoenix', 'San Diego', 'Austin', 'Houston'],
        'Houston': ['Austin'],
        'Austin': ['San Diego'],
        'Bozeman': ['Seattle', 'Salt Lake City'],
        'Nashville': ['Denver', 'Dallas'],
        'Denver': ['Bozeman', 'Salt Lake City', 'Las Vegas'],
        'Salt Lake City': ['Seattle', 'Portland', 'San Francisco'],
        'Las Vegas': ['San Francisco', 'Los Angeles', 'Phoenix'],
        'Phoenix': ['San Diego'],
        'Seattle': ['Portland'],
        'Portland': ['San Francisco'],
        'San Francisco': ['Los Angeles'],
        'Los Angeles': ['San Diego'],
    }
    for city, others in neighbors.items():
        for other in others:
            g.add_edge(city, other)

    return g



def visualize_graph(graph, display_names=False, color = "C0"):
    """Draw every edge of ``graph`` as a straight segment using matplotlib.

    Nodes must expose ``coords`` (x, y) and ``name``, as MapNode does.
    Undirected edges are stored in both directions and are therefore drawn
    twice; this only overdraws the same segment.

    Args:
        graph: graph providing ``get_all_nodes`` and ``outgoing_edges_of``.
        display_names: if true, label each node with its name.
        color: matplotlib color for the edge lines.

    Returns:
        The value returned by ``plt.plot`` for the edge polyline.
    """
    xs, ys = [], []
    for n1 in graph.get_all_nodes():
        if display_names:
            plt.text(n1.coords[0], n1.coords[1], n1.name)
        # Use the ``graph`` argument, not a module-level graph.
        for n2 in graph.outgoing_edges_of(n1):
            # A trailing None breaks the polyline so each edge is a separate segment.
            xs.extend((n1.coords[0], n2.coords[0], None))
            ys.extend((n1.coords[1], n2.coords[1], None))

    return plt.plot(xs, ys, linewidth=1, color=color)

def visualize_path(path, line_width = 10, color = "C1", linestyle='-'):
    """Plot consecutive nodes of ``path`` as one polyline.

    ``path`` is an ordered iterable of nodes exposing ``coords`` (x, y).
    Each consecutive pair contributes one segment.
    """
    xs, ys = [], []
    previous = None
    for node in path:
        if previous:
            xs.extend((previous.coords[0], node.coords[0]))
            ys.extend((previous.coords[1], node.coords[1]))
        previous = node
    plt.plot(xs, ys, linewidth=line_width, color=color, linestyle=linestyle)

def dijkstra_heap(graph, start, goal, visualize = False, pause=0.5):
    """Find a least-cost path from ``start`` to ``goal`` with Dijkstra's algorithm.

    ``start``/``goal`` are resolved through ``graph.get_node``. Heap entries
    are ``(cost, path)``. A node is finalized the first time it is popped, and
    later stale entries for it are skipped. Progress is traced to stdout. When
    ``visualize`` is true each step is drawn with matplotlib (nodes then need
    ``coords``), pausing ``pause`` seconds per step.

    Returns:
        ``(cost, path)`` for the goal, or None if it is unreachable.
    """
    source = graph.get_node(start)
    target = graph.get_node(goal)

    def finished_set():
        # Trace helper: the finalized nodes rendered as "{A, B, ...}".
        return "{" + ", ".join(str(node) for node in done) + "}"

    frontier = [(0, [source])]  # a one-element list is already a valid heap
    done = set()

    while frontier:
        dist, route = heapq.heappop(frontier)
        here = route[-1]

        if here in done:
            print(f"Skipping {here!s} with cost {dist}. Finished set: {finished_set()}")
            continue
        done.add(here)
        print(f"✅ Finished {here!s} with cost {dist}. Finished set: {finished_set()}")

        if visualize:
            plt.clf()
            visualize_graph(graph, True)
            visualize_path(route)
            plt.title("Dijkstra")
            for marker in (source, target):
                plt.plot(marker.coords[0], marker.coords[1], 'o', color='red', markersize=10)
            plt.draw()
            plt.pause(pause)

        if here == target:
            return dist, route

        for nbr, w in graph.outgoing_edges_of(here).items():
            if nbr in done:
                continue
            reach = dist + w
            heapq.heappush(frontier, (reach, route + [nbr]))
            print(f"  ➕ Added path {here!s} → {nbr!s} (total cost {reach})")

        # Debug trace of the remaining frontier.
        snapshot = [(c, [str(n) for n in p]) for c, p in frontier]
        print(f"Current queue: {snapshot}\n")

    return None

# small_graph = Graph()
# small_graph.add_node(Node("A"))
# small_graph.add_node(Node("B"))
# small_graph.add_node(Node("C"))
# small_graph.add_node(Node("D"))
# small_graph.add_edge("A", "B", 2)
# small_graph.add_edge("A", "C", 4)
# small_graph.add_edge("B", "C", 1)
# small_graph.add_edge("B", "D", 7)
# small_graph.add_edge("C", "D", 3)

# result = dijkstra_heap(small_graph, "A", "D", visualize=False)

# if result is not None:
#     cost, path = result
#     print("Shortest Path Found:")
#     print(" → ".join(str(node) for node in path))
#     print(f"Total Cost: {cost:.2f}")
# else:
#     print("No path found between the specified nodes.")


continue_processing = False  # toggled by the key handler; True lets the caller proceed
_key_handler = None  # (figure, connection id) of the handler registered most recently


def pause_until_key_pressed(actually_pause_now = True):
    """Wait, keeping the matplotlib window responsive, until a key is pressed.

    Registers a key-press handler on the current figure.
    - Any key toggles ``continue_processing``.
    - Toggling it to False pauses inside the handler until the next key press.
    - 'q' exits the program.

    Only one handler stays connected; the previous one is disconnected first.
    With two handlers, a single key press would toggle the flag twice, which
    cancels out and leaves the pause stuck.

    Args:
        actually_pause_now: if False, only register the handler and return
            immediately.
    """
    global continue_processing, _key_handler
    import sys  # used by the 'q' handler; the file does not import it at module level

    def on_key(event):
        global continue_processing
        if event.key == 'q':
            sys.exit()
        continue_processing = not continue_processing
        # Toggled off: stay here, pumping GUI events, until the next key press.
        while not continue_processing:
            plt.pause(0.1)  # Yield control to the event loop

    continue_processing = not actually_pause_now
    fig = plt.gcf()
    if _key_handler is not None:
        old_fig, old_cid = _key_handler
        old_fig.canvas.mpl_disconnect(old_cid)
    _key_handler = (fig, fig.canvas.mpl_connect('key_press_event', on_key))
    while not continue_processing:
        plt.pause(0.1)  # Yield control to the event loop



def astar_heap(graph, start, goal, visualize=False, pause=0.5):
    """Find a least-cost path from ``start`` to ``goal`` with A* search.

    Heap entries are ``(f_score, g_score, path)``, where:
    - ``g_score`` is the accumulated edge weight from ``start``;
    - ``f_score = g_score + h``, with the heuristic ``h = graph.distance(node, goal)``.

    A node is finalized the first time it is popped. The result is optimal
    only if ``h`` never overestimates the remaining cost and is consistent.
    That presumably holds for MapGraph, whose default edge weights use the
    same ``distance`` metric; TODO confirm for custom weights.

    When ``visualize`` is true each step is drawn with matplotlib, and the
    final state waits for a key press.

    Returns:
        ``(g_score, path)`` for the goal, or None if it is unreachable.
    """
    start_node = graph.get_node(start)
    goal_node = graph.get_node(goal)

    g_score = 0  # cost from start to current node
    h_score = graph.distance(start_node, goal_node)  # heuristic estimate to the goal
    queue = [(g_score + h_score, g_score, [start_node])]
    visited = set()

    while queue:
        # Pop the path with the smallest f_score.
        f_score, g_score, path = heapq.heappop(queue)
        current_node = path[-1]
        # Recompute h for *this* node so the trace shows the real f = g + h.
        h_score = graph.distance(current_node, goal_node)
        print(f"A*: Exploring {current_node!s} with g_score {g_score} + h_score {h_score} = f_score {f_score}")

        # Skip stale entries for nodes that were already finalized.
        if current_node in visited:
            continue
        visited.add(current_node)
        # print(f"✅ Finished {current_node!s} with f_score {f_score}. Finished set: {{"
        #       f"{', '.join(str(node) for node in visited)}}}")

        if visualize:
            plt.clf()
            visualize_graph(graph, True)
            visualize_path(path)
            plt.title("A*")
            if current_node == goal_node:
                pause_until_key_pressed()
            plt.plot(start_node.coords[0], start_node.coords[1], 'o', color='red', markersize=10)
            plt.plot(goal_node.coords[0], goal_node.coords[1], 'o', color='red', markersize=10)
            # Dotted line from the goal to the current node: the straight-line heuristic.
            visualize_path([goal_node, current_node], 4, "red", linestyle=':')
            plt.draw()
            plt.pause(pause)

        if current_node == goal_node:
            return g_score, path

        for neighbor, edge_weight in graph.outgoing_edges_of(current_node).items():
            if neighbor in visited:
                continue
            new_g_score = g_score + edge_weight
            new_h_score = graph.distance(neighbor, goal_node)
            new_f_score = new_g_score + new_h_score
            heapq.heappush(queue, (new_f_score, new_g_score, path + [neighbor]))
            # print(f"  ➕ Added path {current_node!s} → {neighbor!s} (f_score={new_f_score:.2f}, g={new_g_score:.2f}, h={new_h_score:.2f})")

    return None

########################################
# Test Dijkstra and A* on city graph
########################################

def test_dijkstra():
    """Interactive demo: animate Dijkstra on the module-level graph ``g``.

    Draws the graph with red start/end markers in figure 1, waits for a key
    press, then runs dijkstra_heap from ``start`` to ``end`` with
    visualization enabled.
    """
    plt.ion()
    plt.figure(1)
    visualize_graph(g, True)
    for endpoint in (start_node, end_node):
        plt.plot(endpoint.coords[0], endpoint.coords[1], 'o', color='red', markersize=10)
    plt.title("Dijkstra")
    plt.show(block=False)
    plt.draw()
    pause_until_key_pressed()
    dijkstra_heap(g, start, end, visualize=True)

def test_astar():
    """Interactive demo: animate A* on the module-level graph ``g``.

    Clears the current figure, draws the graph with red start/end markers,
    waits for a key press, then runs astar_heap from ``start`` to ``end``
    with visualization enabled.
    """
    plt.ion()
    plt.clf()
    visualize_graph(g, True)
    for endpoint in (start_node, end_node):
        plt.plot(endpoint.coords[0], endpoint.coords[1], 'o', color='red', markersize=10)
    plt.title("A*")
    plt.show(block=False)
    plt.draw()
    pause_until_key_pressed()
    astar_heap(g, start, end, visualize=True)

# Shared demo state used by test_dijkstra / test_astar and the examples below.
g = build_city_nodes_graph()
start, end = 'Boston', 'Phoenix'
start_node, end_node = g.get_node(start), g.get_node(end)

# test_dijkstra()
# test_astar()

############################################################
# Alternative Dijkstra's Algorithm Implementation and compute all paths
############################################################

def dijkstra_heap_all(graph, start):
    """Compute single-source shortest paths from ``start`` to every node.

    Assumes non-negative edge weights, as Dijkstra's algorithm requires.

    Returns:
        ``(distances, predecessors)``:
        - ``distances`` maps node -> best cost (``inf`` if unreachable);
        - ``predecessors`` maps node -> previous node on its best path
          (None for the source and for unreachable nodes).
    """
    source = graph.get_node(start)
    every_node = graph.get_all_nodes()

    best = dict.fromkeys(every_node, float('inf'))
    best[source] = 0
    parent = dict.fromkeys(every_node)

    heap = [(0, source)]
    settled = set()
    while heap:
        dist, node = heapq.heappop(heap)
        if node in settled:
            continue  # stale entry: a cheaper one was already processed
        settled.add(node)
        for nbr, w in graph.outgoing_edges_of(node).items():
            candidate = dist + w
            if candidate < best[nbr]:
                best[nbr] = candidate
                parent[nbr] = node
                heapq.heappush(heap, (candidate, nbr))

    return best, parent

def reconstruct_path(predecessors, start_node, goal_node):
    """Rebuild the start->goal path by walking predecessor links back from the goal.

    Returns the path as a list, or None if the walk does not end at
    ``start_node`` (i.e. the goal was not reached from it).
    """
    backwards = []
    node = goal_node
    while node is not None:
        backwards.append(node)
        node = predecessors[node]
    path = backwards[::-1]
    if path[0] != start_node:
        return None
    return path


def print_dijkstra_result(graph, start):
    """Print a table of shortest distances and paths from ``start`` to every node."""
    distances, predecessors = dijkstra_heap_all(graph, start)
    origin = graph.get_node(start)
    unreachable = float('inf')

    print(f"\n🔹 Shortest paths from node {origin}:\n")
    print(f"{'Destination':<15}{'Distance':<10}{'Path'}")
    print("-" * 50)

    for node, dist in distances.items():
        label = f"{str(node):<15}"
        if dist != unreachable:
            route = reconstruct_path(predecessors, origin, node)
            print(f"{label}{dist:<10.2f}" + " → ".join(str(n) for n in route))
        else:
            print(f"{label}{'∞':<10}No path")

def dijkstra_all_pairs(graph):
    """Compute (and print) shortest distances and paths between every pair of nodes.

    Runs ``dijkstra_heap_all`` once per source node and prints each source's
    table via ``print_dijkstra_result``.

    Returns:
        dict mapping each source node to its ``(distances, predecessors)``
        tuple, as returned by ``dijkstra_heap_all``.
    """
    all_pairs = {}
    for start in graph.get_all_nodes():
        all_pairs[start] = dijkstra_heap_all(graph, start)
        # NOTE: print_dijkstra_result recomputes the same table internally.
        print_dijkstra_result(graph, start.name)
    return all_pairs

# Example usage:
# print_dijkstra_result(g, start)
# dijkstra_all_pairs(g)

############################################################
# Prim's Algorithm
############################################################

def prim(graph, start):
    """Grow a minimum spanning tree from ``start`` with Prim's algorithm.

    Intended for undirected graphs, where each edge is present in both
    directions.

    Returns:
        ``(total_cost, predecessors)``: the sum of chosen edge weights, and a
        map from each node to its tree parent (None for the root and for any
        node never reached).
    """
    root = graph.get_node(start)
    parent = dict.fromkeys(graph.get_all_nodes())
    total = 0

    # Heap entries: (edge weight, tree-side endpoint, candidate node).
    candidates = [(0, None, root)]
    in_tree = set()
    while candidates:
        w, via, node = heapq.heappop(candidates)
        if node in in_tree:
            continue
        in_tree.add(node)
        total += w
        if via is not None:
            parent[node] = via
        for nbr, edge_w in graph.outgoing_edges_of(node).items():
            heapq.heappush(candidates, (edge_w, node, nbr))

    return total, parent

def print_prim_result(graph, start):
    """Print each MST edge found by ``prim`` and the tree's total cost."""
    total_cost, predecessors = prim(graph, start)
    root = graph.get_node(start)
    rule = "-" * 50

    print(f"\n🌲 Minimum Spanning Tree starting from {root}:")
    print(rule)

    for child, parent in predecessors.items():
        if parent is None:
            continue  # the root, or a node outside the tree
        weight = graph.outgoing_edges_of(parent)[child]
        print(f"{parent} — {child} (weight {weight})")

    print(rule)
    print(f"Total MST cost: {total_cost:.2f}")


# Example call:
# print_prim_result(g, start)

############################################################
# Floyd Warshall Algorithm
############################################################

def floyd_warshall(graph):
    """
    Compute shortest paths between all pairs of nodes using the Floyd–Warshall algorithm.

    Assumes the graph has no negative-weight cycles.

    Returns:
        distances[(u, v)] = shortest distance from u to v (inf if unreachable,
            0 for u == v)
        next_node[(u, v)] = next node on the shortest path from u to v
            (u itself when u == v, None if v is unreachable from u)
    """
    nodes = list(graph.get_all_nodes())
    inf = float('inf')
    distances = {}
    next_node = {}

    # Every node reaches itself at cost 0 via the trivial path [u]; everything
    # else starts unreachable.
    for u in nodes:
        for v in nodes:
            if u == v:
                distances[(u, v)] = 0
                next_node[(u, v)] = u
            else:
                distances[(u, v)] = inf
                next_node[(u, v)] = None

    # Direct edges, kept only when they improve on the initial value. This
    # stops a positive self-loop from overwriting the 0 diagonal.
    for u in nodes:
        for v, weight in graph.outgoing_edges_of(u).items():
            if weight < distances.get((u, v), inf):
                distances[(u, v)] = weight
                next_node[(u, v)] = v

    # Relax every pair (i, j) through each intermediate node k.
    for k in nodes:
        for i in nodes:
            d_ik = distances[(i, k)]
            if d_ik == inf:
                continue  # i cannot reach k, so k cannot improve any (i, j)
            for j in nodes:
                through_k = d_ik + distances[(k, j)]
                if through_k < distances[(i, j)]:
                    distances[(i, j)] = through_k
                    next_node[(i, j)] = next_node[(i, k)]

    return distances, next_node


def reconstruct_fw_path(next_node, start, goal):
    """
    Reconstruct the path from start to goal using the Floyd–Warshall next_node table.

    Returns the list of nodes from ``start`` to ``goal``, or None if the table
    has no successor for some step along the way.
    """
    if next_node[(start, goal)] is None:
        return None

    path = [start]
    current = start
    while current != goal:
        current = next_node[(current, goal)]
        if current is None:
            return None
        path.append(current)
    return path


def floyd_warshall_all_paths(graph):
    """
    Return {(u, v): {"distance": d, "path": [...] or None}} for every ordered node pair.

    Combines floyd_warshall's distance/next-node tables with reconstruct_fw_path.
    """
    distances, next_node = floyd_warshall(graph)
    nodes = list(graph.get_all_nodes())
    return {
        (i, j): {
            "distance": distances[(i, j)],
            "path": reconstruct_fw_path(next_node, i, j),
        }
        for i in nodes
        for j in nodes
    }

# results = floyd_warshall_all_paths(g)

# for (u, v), info in results.items():
#     if info["path"] is None or info["distance"] == float("inf"):
#         print(f"{u} → {v}: no path")
#     else:
#         path_str = " → ".join(str(n) for n in info["path"])
#         print(f"{u} → {v}: {info['distance']:.2f} via {path_str}")
