import random
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def run_multi_agent_episode(env, max_steps=100):
    """
    Drive ``env`` for ``max_steps`` ticks, letting every blob choose a move.

    The episode starts with ``env.reset()``. On each tick, every blob in
    ``env.blobs`` is asked for an action against the latest observation. All
    actions are collected before ``env.step`` applies them together. The
    observation returned by ``step`` feeds the next tick. Nothing is returned.
    """
    obs = env.reset()
    for _ in range(max_steps):
        # Gather the whole joint action first so every agent sees the same obs.
        joint_action = [agent.get_action(obs) for agent in env.blobs]
        obs = env.step(joint_action)


class Space:
    """Interface for a set of values that can be sampled and tested for membership."""

    def sample(self):
        """Return one random element of the space (subclasses must override)."""
        raise NotImplementedError

    def contains(self, element):
        """Return whether ``element`` belongs to the space (subclasses must override)."""
        raise NotImplementedError


class ContinuousGrid(Space):
    """Closed rectangle [0, width] x [0, height] of real-valued (x, y) points."""

    def __init__(self, width, height):
        self.width, self.height = width, height

    def sample(self):
        """Return a uniformly random (x, y) point inside the rectangle."""
        # x is drawn before y.
        px = random.uniform(0, self.width)
        py = random.uniform(0, self.height)
        return (px, py)

    def contains(self, element):
        """Return True when the (x, y) pair ``element`` lies in the rectangle, edges included."""
        px, py = element
        inside_x = 0 <= px <= self.width
        return inside_x and 0 <= py <= self.height


class UnitCircle(Space):
    """Planar directions: samples are points on the unit circle.

    NOTE(review): ``contains`` accepts every point with norm <= 1 (+1e-6),
    i.e. the closed unit *disk*, not only the circle that ``sample`` draws
    from. Confirm whether shorter moves are meant to be valid actions.
    """

    def sample(self):
        """Return (cos t, sin t) for an angle t drawn uniformly from [0, 2*pi]."""
        theta = random.uniform(0, 2 * math.pi)
        return math.cos(theta), math.sin(theta)

    def contains(self, element):
        """Return True when the (x, y) pair ``element`` has Euclidean norm at most 1 + 1e-6."""
        tolerance = 1e-6
        x, y = element
        length = math.sqrt(x**2 + y**2)
        return length - 1 <= tolerance


class Blob:
    """A circular agent with a ``position`` (x, y) and a ``radius``.

    Equality (``==``) is by value: the same position and the same radius.
    Ordering (``<`` / ``>``) is by area. Both return ``NotImplemented`` for
    non-Blob operands, so Python can fall back or raise ``TypeError``.
    """

    # Value-based __eq__ on a mutable object: keep instances unhashable.
    # Python already does this implicitly when __eq__ is defined; stated here
    # so nobody adds a hash that changes as the blob moves or grows.
    __hash__ = None

    def __init__(self, position, radius):
        self.action_space = UnitCircle()
        self.position = position
        self.radius = radius

    def __eq__(self, other):
        if not isinstance(other, Blob):
            return NotImplemented
        return self.position == other.position and self.radius == other.radius

    def __gt__(self, other):
        if not isinstance(other, Blob):
            return NotImplemented
        return self.get_area() > other.get_area()

    def __lt__(self, other):
        if not isinstance(other, Blob):
            return NotImplemented
        return self.get_area() < other.get_area()

    def get_action(self, observation):
        """Return a random move sampled from ``self.action_space``; ``observation`` is ignored."""
        return self.action_space.sample()

    def get_area(self):
        """Return the disc area, pi * radius**2."""
        return math.pi * (self.radius**2)

    def _contains(self, other):
        """Return True if ``other`` lies entirely inside this blob (touching edges allowed)."""
        x1, y1 = self.position
        x2, y2 = other.position
        distance = math.hypot(x1 - x2, y1 - y2)
        return distance + other.radius <= self.radius

    def attempt_eat(self, other):
        """Absorb ``other`` if it is fully inside this blob and strictly smaller by area.

        On success the combined area is kept, so the radius grows to
        sqrt((A_self + A_other) / pi), and True is returned. ``other`` itself
        is not modified or removed; dropping it is the caller's job.
        """
        if self._contains(other) and self > other:
            new_area = self.get_area() + other.get_area()
            self.radius = math.sqrt(new_area / math.pi)
            return True
        return False

    def get_color(self):
        """Return the matplotlib color name used when rendering this blob."""
        return "skyblue"


class AggressiveBlob(Blob):
    """A blob that steers toward the nearest strictly smaller blob it observes."""

    def get_action(self, observation):
        """Return a unit step toward the closest strictly smaller blob in ``observation``.

        Falls back to the random move of ``Blob.get_action`` in two cases:
        no smaller blob is visible, or the closest one sits exactly on this
        blob's position, so no direction can be derived.
        """
        x1, y1 = self.position
        closest_blob = None
        closest_distance = float("inf")
        for blob in observation:
            # Identity check: Blob.__eq__ compares by value, and only this
            # very object should be skipped. Non-smaller blobs are not prey.
            if blob is self or not (self > blob):
                continue
            x2, y2 = blob.position
            distance = math.hypot(x2 - x1, y2 - y1)
            if distance < closest_distance:
                closest_distance = distance
                closest_blob = blob

        if closest_blob is not None:
            x2, y2 = closest_blob.position
            dx, dy = x2 - x1, y2 - y1
            norm = math.hypot(dx, dy)
            if norm > 0:
                return (dx / norm, dy / norm)
        # Otherwise move randomly.
        return super().get_action(observation)

    def get_color(self):
        """Return the matplotlib color name used when rendering this blob."""
        return "red"


class AgarioEnv:
    """Arena in which blobs move inside a continuous grid and eat each other.

    ``step`` appends a snapshot (position, radius, color per blob) to
    ``history`` *before* applying the moves, and ``render_to_gif`` replays
    those snapshots.
    """

    def __init__(self, width, height, blobs):
        self.space = ContinuousGrid(width, height)
        # This list object is mutated in place as blobs get eaten.
        self.blobs = blobs
        self.history = []

    def reset(self):
        """Clear the recorded history and return the current blobs as the observation."""
        self.history = []
        return self.blobs

    def step(self, actions):
        """Advance the simulation by one tick.

        Args:
            actions: one ``(dx, dy)`` move per blob, index-aligned with
                ``self.blobs`` as it was when the actions were chosen.

        Returns:
            ``self.blobs`` (the same list object) after moving and eating.
        """
        self.history.append(
            [(b.position, b.radius, b.get_color()) for b in self.blobs]
        )

        # Bind every blob to its own action before anything is removed.
        # Re-reading self.blobs[i] after pops shifted the list: a blob that ate
        # a lower-indexed neighbour was processed twice, and every later blob
        # was handed a neighbour's action. Processing order stays last-to-first.
        schedule = [
            (self.blobs[i], actions[i]) for i in reversed(range(len(self.blobs)))
        ]
        # ids are stable here because `schedule` keeps every blob alive.
        eaten_ids = set()
        for blob, (dx, dy) in schedule:
            if id(blob) in eaten_ids:
                continue  # eaten earlier this tick; it no longer acts
            x, y = blob.position
            new_pos = (x + dx, y + dy)
            if self.space.contains(new_pos):  # moves that leave the arena are dropped
                blob.position = new_pos

            # Try every other live blob, last-listed first. attempt_eat grows
            # the eater, so later candidates are checked against the new size.
            for other in reversed(self.blobs):
                if other is not blob and blob.attempt_eat(other):
                    eaten_ids.add(id(other))
            if eaten_ids:
                # Remove by identity (Blob.__eq__ is value-based) and in place,
                # so external references to the list stay valid.
                self.blobs[:] = [b for b in self.blobs if id(b) not in eaten_ids]
        return self.blobs

    def render_to_gif(self, filename="agario_sim.gif"):
        """Write the recorded history as an animated GIF (pillow writer, 8 fps).

        Raises:
            ValueError: if no step has been recorded yet.
        """
        if not self.history:
            raise ValueError("No history to render; call step() at least once.")

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.set_xlim(0, self.space.width)
        ax.set_ylim(0, self.space.height)
        ax.set_aspect("equal")

        def update(frame):
            # ax.clear() wipes limits and settings, so re-apply them every frame.
            ax.clear()
            ax.set_xlim(0, self.space.width)
            ax.set_ylim(0, self.space.height)
            ax.set_aspect("equal")
            ax.set_title(f"Step {frame + 1}")
            for (x, y), r, color in self.history[frame]:
                ax.add_patch(plt.Circle((x, y), r, color=color, alpha=0.6))
            return []

        try:
            ani = animation.FuncAnimation(
                fig, update, frames=len(self.history), blit=False
            )
            ani.save(filename, writer="pillow", fps=8)
        finally:
            # Release the figure even if encoding or writing the file fails.
            plt.close(fig)


# --- Example usage ---
if __name__ == "__main__":
    width, height = 100, 100

    def _random_blob(kind):
        # Random draws happen in the order x, then y, then radius.
        return kind(
            position=(random.uniform(0, width), random.uniform(0, height)),
            radius=random.uniform(0.3, 1.5),
        )

    # 30 chasers first, then 30 random walkers.
    blobs = [_random_blob(AggressiveBlob) for _ in range(30)]
    blobs += [_random_blob(Blob) for _ in range(30)]

    env = AgarioEnv(width, height, blobs)
    run_multi_agent_episode(env, max_steps=200)
    env.render_to_gif("agario_sim.gif")
