# api/state.py
"""Module-level singletons: task registry, generation tracker, thread pool.

Moved verbatim from review_server.py.
Only addition: event_loop reference for SSE broadcast (Phase 9).
"""

import asyncio
import threading
import time
import uuid as _uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

# ── Constants ──────────────────────────────────────────────────────
PROJECT_ROOT = Path(__file__).resolve().parent.parent
EDITORS_DIR = PROJECT_ROOT / "editors"

# ── Default project (set by lifespan in main.py) ──────────────────
# None until the lifespan handler assigns it.
default_project: "str | None" = None

# ── Event loop reference (set by lifespan, used by _broadcast_task_event for SSE) ──
# None until the lifespan handler assigns it; _broadcast_task_event skips SSE
# while this is None or the loop has been closed.
event_loop: "asyncio.AbstractEventLoop | None" = None


# ── Generation Tracker ─────────────────────────────────────────────
class GenerationTracker:
    """Prevents duplicate generation submissions. Fast UX-layer rejection.

    Records which shot IDs currently have a generation claimed. A claim
    expires automatically after ``TIMEOUT_SECS``, so a job that never calls
    :meth:`finish` cannot block its shot forever. Every method takes an
    internal lock before touching the claim table.

    NOTE(review): claims carry no owner token. If a claim expires and is
    re-taken, the original job's late ``finish()`` also releases the newer
    claim.
    """

    TIMEOUT_SECS = 300  # 5 minutes

    def __init__(self):
        # shot_id -> claim time from time.monotonic(). A monotonic clock means
        # wall-clock steps (NTP corrections, manual changes) can neither expire
        # claims early nor keep them alive too long.
        self._active: dict[str, float] = {}
        self._lock = threading.Lock()

    def _is_fresh(self, started: float, now: float) -> bool:
        """Return True if a claim taken at *started* is still within the timeout at *now*."""
        return now - started <= self.TIMEOUT_SECS

    def try_start(self, shot_id: str) -> bool:
        """Atomically try to claim a shot for generation.

        Args:
            shot_id: Shot to claim.

        Returns:
            True if the claim was taken, which includes replacing an expired
            claim. False if an unexpired claim already exists.
        """
        now = time.monotonic()
        with self._lock:
            started = self._active.get(shot_id)
            if started is not None and self._is_fresh(started, now):
                return False
            self._active[shot_id] = now
            return True

    def finish(self, shot_id: str) -> None:
        """Release a shot after generation completes (success or failure).

        Does nothing if the shot is not currently claimed.
        """
        with self._lock:
            self._active.pop(shot_id, None)

    def is_active(self, shot_id: str) -> bool:
        """Return True if *shot_id* holds an unexpired claim.

        An expired claim found here is removed as a side effect.
        """
        with self._lock:
            started = self._active.get(shot_id)
            if started is None:
                return False
            if self._is_fresh(started, time.monotonic()):
                return True
            # Stale claim: drop it so dead entries don't accumulate.
            del self._active[shot_id]
            return False


# Module-level singletons.
gen_tracker: GenerationTracker = GenerationTracker()  # duplicate-submission guard
thread_pool: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4)  # workers for submit_task

# ── Task Registry ──────────────────────────────────────────────────
# task_id -> task record dict. Functions in this module touch it only while
# holding task_lock.
task_registry: dict = {}
task_lock = threading.Lock()


def submit_task(entity_id, action, fn, *args, metadata=None, **kwargs):
    """Run ``fn(*args, **kwargs)`` on the shared thread pool with task tracking.

    A registry entry with status ``"running"`` is created before the job is
    queued. When the job ends, the worker sets the status to ``"complete"``
    (storing ``result``) or ``"failed"`` (storing ``error``) and emits a
    matching SSE event.

    Args:
        entity_id: Identifier stored on the registry entry and echoed in SSE events.
        action: Label stored on the registry entry and echoed in SSE events.
        fn: Callable executed in a worker thread.
        *args, **kwargs: Forwarded to ``fn``.
        metadata: Optional dict stored on the entry (``{}`` when omitted).

    Returns:
        The 8-character hex task id, returned immediately.

    Raises:
        RuntimeError: If the thread pool refuses new work (e.g. it has been
            shut down). The entry is marked ``"failed"`` before re-raising, so
            it does not stay ``"running"`` forever.
    """
    with task_lock:
        # 8 hex chars is only 32 random bits; regenerate on the rare collision
        # so a new task can never overwrite an existing entry.
        task_id = _uuid.uuid4().hex[:8]
        while task_id in task_registry:
            task_id = _uuid.uuid4().hex[:8]
        task_registry[task_id] = {
            "task_id": task_id,
            "entity_id": entity_id,
            "action": action,
            "status": "running",
            "started": time.time(),
            "result": None,
            "error": None,
            "metadata": metadata or {},
        }

    def _update_entry(**fields):
        # Tolerate the entry having been removed elsewhere instead of raising
        # KeyError inside the worker thread.
        with task_lock:
            entry = task_registry.get(task_id)
            if entry is not None:
                entry.update(fields)

    def _wrapper():
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            import traceback
            traceback.print_exc()
            # Exceptions without a message (e.g. a bare AssertionError) would
            # otherwise record an empty error string.
            message = str(e) or type(e).__name__
            _update_entry(status="failed", error=message)
            _broadcast_task_event("task:failed", task_id, entity_id, action, error=message)
        else:
            _update_entry(status="complete", result=result)
            _broadcast_task_event("task:complete", task_id, entity_id, action)

    try:
        thread_pool.submit(_wrapper)
    except RuntimeError as e:
        # The pool no longer accepts work: the job will never run.
        _update_entry(status="failed", error=str(e) or type(e).__name__)
        raise
    return task_id


def prune_task_registry(max_age_secs: float = 600, max_tasks: int = 50):
    """Remove finished tasks older than 10 minutes, keeping at most 50 tasks.

    Only ``"complete"``/``"failed"`` entries are ever removed; running tasks
    are always kept. A finished task is removed if it started more than
    *max_age_secs* ago, or if it is among the oldest finished tasks that must
    go to bring the registry down to *max_tasks* entries. The registry can
    stay above *max_tasks* only when more than that many tasks are running.

    NOTE(review): age is measured from ``started``, so a long-running task
    may be pruned soon after it completes, before a poller reads its result.

    Args:
        max_age_secs: Finished tasks started longer ago than this are dropped.
            Defaults to 600 (10 minutes).
        max_tasks: Target upper bound on registry size. Defaults to 50.
    """
    cutoff = time.time() - max_age_secs
    with task_lock:
        # Oldest first, so both the age rule and the size cap evict the oldest
        # finished tasks. Running tasks are excluded up front, so they never
        # use up the eviction budget.
        finished = sorted(
            (tid for tid, task in task_registry.items()
             if task["status"] in ("complete", "failed")),
            key=lambda tid: task_registry[tid]["started"],
        )
        for tid in finished:
            if task_registry[tid]["started"] >= cutoff and len(task_registry) <= max_tasks:
                # This task and every newer one are within the age limit, and
                # the cap is already met, so nothing else needs removing.
                break
            del task_registry[tid]


def _broadcast_task_event(event_type, task_id, entity_id, action, error=None, status=None):
    """Best-effort push of a task event to SSE subscribers.

    Builds a payload and schedules ``broadcast(event_type, data)`` on
    ``event_loop`` via :func:`asyncio.run_coroutine_threadsafe`, without
    waiting for it to run. Does nothing if no loop is registered or the loop
    is closed. Any exception is swallowed, because SSE must never crash the
    calling task.

    Args:
        event_type: SSE event name, e.g. ``"task:complete"``.
        task_id, entity_id, action: Always included in the payload.
        error: Added to the payload as ``"error"`` when truthy.
        status: Added to the payload as ``"status"`` when truthy.
    """
    # Snapshot the global once, so a concurrent reassignment cannot slip in
    # between the check and the scheduling call.
    loop = event_loop
    if loop is None or loop.is_closed():
        return  # Loop gone (server reloaded) or not yet set — skip silently
    coro = None
    try:
        from .sse import broadcast
        data = {"task_id": task_id, "entity_id": entity_id, "action": action}
        if error:
            data["error"] = error
        if status:
            data["status"] = status
        coro = broadcast(event_type, data)
        asyncio.run_coroutine_threadsafe(coro, loop)
    except Exception:
        # SSE is nice-to-have, never crash the task. If scheduling failed
        # (e.g. the loop closed after the check above), the coroutine never
        # ran; close it to avoid a "coroutine was never awaited" warning.
        if asyncio.iscoroutine(coro):
            coro.close()
