"""run_episode -- Episode-level meta-op for Phase 3 production cutover.

Orchestrates run_shot across all shots in an episode with:
- Style Anchor pre-pass (3 attempts, abort on failure unless --no-style-anchor)
- Coverage grouping (primary shots first, then siblings in parallel)
- Parallel execution via asyncio.Semaphore
- Budget enforcement (shared BudgetGuard across all shot threads)
- SIGTERM handling (loop.call_soon_threadsafe for thread safety)
- Abort/resume (state dump to runs/<run_id>.json)
- Morning triage summary

INVARIANT: Uses asyncio.run() as top-level entry point.
INVARIANT: run_shot() calls are dispatched via asyncio.to_thread().
"""

from __future__ import annotations

import asyncio
import json
import logging
import signal
import sys
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any

from recoil.pipeline._lib.coverage_context import (
    CoveragePassContext,
    EpisodeResult,
    OpResult,
    StopOnReview,
    SKIP_ON_RESUME,
    RETRY_ON_RESUME,
)
from recoil.pipeline._lib.budget_manager import BudgetGuard
from recoil.pipeline._lib.run_shot import run_shot
from recoil.pipeline._lib import ops_log
from recoil.core.paths import ProjectPaths as CoreProjectPaths

logger = logging.getLogger(__name__)

STYLE_ANCHOR_ATTEMPTS = 3


# ---------------------------------------------------------------------------
# ID generation
# ---------------------------------------------------------------------------

def _make_run_id() -> str:
    """Generate a run id: 'run_' + uuid7 hex prefix."""
    if sys.version_info < (3, 14):
        raise RuntimeError("uuid.uuid7() requires Python 3.14+. See spec D10.")
    return "run_" + uuid.uuid7().hex[:12]


# ---------------------------------------------------------------------------
# State persistence
# ---------------------------------------------------------------------------

def _state_dir(paths) -> Path:
    """Return ``<visual_state_dir>/runs`` for the project, creating it on demand.

    *paths* only needs a ``project_root`` attribute; the visual-state
    directory itself is resolved through ``CoreProjectPaths``.
    """
    core_paths = CoreProjectPaths.from_root(paths.project_root)
    runs_dir = core_paths.visual_state_dir / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    return runs_dir


def _save_run_state(
    paths,
    run_id: str,
    shot_results: list[OpResult],
    budget_guard: BudgetGuard,
    episode_id: str,
    aborted: bool = False,
    abort_reason: str | None = None,
    style_anchors: dict[str, Path] | None = None,
) -> Path:
    """Serialize run state to ``runs/<run_id>.json`` so the run can be resumed.

    The file records run/episode ids, abort status and reason, the style
    anchor paths (as strings), total budget spent, and one entry per shot
    result.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    the target. If the process is killed mid-write, this leaves either the
    previous state file or the new one, never a truncated file that would
    make the run impossible to resume.

    Returns:
        Path of the written state file.
    """
    state = {
        "run_id": run_id,
        "episode_id": episode_id,
        "aborted": aborted,
        "abort_reason": abort_reason,
        "style_anchors": {k: str(v) for k, v in (style_anchors or {}).items()},
        "budget_spent": budget_guard.spent,
        "shots": [
            {
                "shot_id": r.shot_id,
                "status": r.status,
                "op_id": r.op_id,
                "output_path": r.output_path,
                "cost_usd": r.cost_usd,
                "attempts": r.attempts,
                "failure_mode": r.failure_mode,
                "validation_notes": r.validation_notes,
                "review_queue_id": r.review_queue_id,
            }
            for r in shot_results
        ],
    }
    # Serialize before touching the filesystem so a serialization error
    # leaves any existing state file untouched.
    payload = json.dumps(state, indent=2)

    state_path = _state_dir(paths) / f"{run_id}.json"
    tmp_path = state_path.with_name(f"{state_path.name}.tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    # Path.replace overwrites the target; on POSIX it is an atomic rename.
    tmp_path.replace(state_path)
    return state_path


def _load_run_state(paths, run_id: str) -> dict | None:
    """Return the persisted state dict for *run_id*, or ``None`` when none was saved.

    A state file that exists but is not valid JSON raises
    ``json.JSONDecodeError``; it is not treated as missing.
    """
    candidate = _state_dir(paths) / f"{run_id}.json"
    if candidate.exists():
        return json.loads(candidate.read_text())
    return None


# ---------------------------------------------------------------------------
# Coverage grouping
# ---------------------------------------------------------------------------

def _group_by_coverage_pass(shots: list[dict]) -> list[dict]:
    """Group shots by coverage_pass_id. Auto-derive from scene_index + shot_type if absent.

    Returns a list of coverage groups, each containing:
      - coverage_pass_id: str
      - primary: dict (shot)
      - siblings: list[dict] (shots)
    """
    # Assign coverage_pass_id if missing
    for shot in shots:
        if "coverage_pass_id" not in shot:
            scene = shot.get("scene_index", shot.get("scene_id", "SC00"))
            shot_type = shot.get("shot_type", "primary")
            shot["coverage_pass_id"] = f"{scene}_COVERAGE_{shot_type}"

    # Group by coverage_pass_id
    groups: dict[str, list[dict]] = defaultdict(list)
    for shot in shots:
        groups[shot["coverage_pass_id"]].append(shot)

    result = []
    for pass_id, group_shots in groups.items():
        # First shot in each group is primary; rest are siblings
        primary = group_shots[0]
        siblings = group_shots[1:]
        result.append({
            "coverage_pass_id": pass_id,
            "primary": primary,
            "siblings": siblings,
        })

    return result


# ---------------------------------------------------------------------------
# Style Anchor
# ---------------------------------------------------------------------------

def _group_shots_by_scene(shots: list[dict]) -> dict[str, list[dict]]:
    """Group shots by scene key using fallback chain: scene_index -> location -> 'episode'."""
    groups: dict[str, list[dict]] = defaultdict(list)
    for shot in shots:
        scene_key = shot.get("scene_index") or shot.get("scene_id")
        if scene_key is None:
            scene_key = shot.get("location")
        if scene_key is None:
            scene_key = "episode"
        scene_key = str(scene_key)
        groups[scene_key].append(shot)
    return dict(groups)


def _generate_style_anchor(
    step_runner,
    store,
    paths,
    budget_guard: BudgetGuard,
    model: str,
    episode_id: str,
    shots: list[dict],
    run_id: str,
    scene_key: str = "episode",
) -> Path | None:
    """Generate (or reuse) the style-anchor keyframe for one scene.

    The anchor lives at
    ``<visual_state_dir>/style_anchor_<episode_id>_<scene_key>.jpg``. Any
    ``/`` or ``\\`` in *scene_key* is replaced with ``_`` so a location-style
    key cannot turn the file name into a nested path. If that file already
    exists, it is returned without rendering.

    Otherwise a 9:16 keyframe shot is built from the first shot in *shots*.
    It uses that shot's ``prompt``, or a generic portrait prompt when the
    key is absent, plus its ``identity_refs`` and ``scene_ref_path``. The
    shot is rendered with ``run_shot`` up to ``STYLE_ANCHOR_ATTEMPTS``
    times, and the first successful output is copied to the anchor path.

    Returns:
        The anchor path on success. ``None`` if every attempt failed, or if
        copying a successful output raised (no further attempts are made in
        that case).
    """
    import shutil

    safe_scene_key = str(scene_key).replace("/", "_").replace("\\", "_")
    anchor_path = (
        CoreProjectPaths.from_root(paths.project_root).visual_state_dir
        / f"style_anchor_{episode_id}_{safe_scene_key}.jpg"
    )
    if anchor_path.exists():
        logger.info("Style anchor already exists: %s", anchor_path)
        return anchor_path

    # Build a style anchor shot dict from the first available shot
    first_shot = shots[0] if shots else {}
    anchor_shot = {
        "shot_id": f"{episode_id}_STYLE_ANCHOR",
        "prompt": first_shot.get("prompt", "Full body portrait, neutral pose, clean background"),
        "pipeline": "keyframe",
        "identity_refs": first_shot.get("identity_refs"),
        "scene_ref_path": first_shot.get("scene_ref_path"),
        "aspect_ratio": "9:16",
        "episode_id": episode_id,
    }

    for attempt in range(1, STYLE_ANCHOR_ATTEMPTS + 1):
        logger.info("Style anchor attempt %d/%d", attempt, STYLE_ANCHOR_ATTEMPTS)
        result = run_shot(
            shot=anchor_shot,
            store=store,
            paths=paths,
            budget_guard=budget_guard,
            model=model,
            step_runner=step_runner,
            run_id=run_id,
        )
        if result.status != "ok" or not result.output_path:
            logger.warning(
                "Style anchor attempt %d/%d did not succeed (status=%s)",
                attempt, STYLE_ANCHOR_ATTEMPTS, result.status,
            )
            continue

        try:
            src = Path(result.output_path)
            if not src.exists():
                # output_path might be relative -- try under project root
                src = paths.project_root / src
            if src.exists():
                anchor_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src, anchor_path)
                return anchor_path
        except Exception as e:
            logger.warning("Could not copy style anchor: %s", e)
            return None

        # The render reported success but its file is nowhere to be found.
        # Say so explicitly, since the next attempt re-renders (and re-spends).
        logger.warning(
            "Style anchor attempt %d/%d reported output %s, but no such file "
            "exists (checked as given and under the project root)",
            attempt, STYLE_ANCHOR_ATTEMPTS, result.output_path,
        )

    logger.error("Style anchor failed after %d attempts", STYLE_ANCHOR_ATTEMPTS)
    return None


# ---------------------------------------------------------------------------
# run_episode (async internals)
# ---------------------------------------------------------------------------

async def _run_episode_async(
    project: str,
    episode_id: str,
    model: str,
    budget_usd: float,
    concurrency: int,
    stop_on_review: StopOnReview,
    resume_run_id: str | None,
    no_style_anchor: bool,
    step_runner,
    store,
    paths,
    shot_plan: list[dict],
) -> EpisodeResult:
    """Async internals for run_episode.

    Phases:
      1. Resume: when *resume_run_id* is given, shot statuses from its saved
         state are loaded. Shots whose status is in ``SKIP_ON_RESUME`` are
         not re-rendered.
      2. Style anchors: unless *no_style_anchor* is set, one anchor is
         generated per scene group. Generation stops early on an abort
         request or once $0.50 has been spent. A single-scene episode with
         no anchor is aborted with reason ``"style_anchor_failed"``.
      3. Coverage groups, in order: the primary shot first, then its
         siblings concurrently. Each ``run_shot`` runs in a worker thread
         and at most *concurrency* run at once.

    Abort requests (SIGTERM, the *stop_on_review* policy, budget exhaustion)
    set ``abort_event``. Shots already running finish. Shots that reach
    their start afterwards are recorded as ``"crashed"`` with a note, and
    groups not yet reached are not recorded.

    Run state is saved to ``runs/<run_id>.json`` on every exit path. An
    unexpected exception or cancellation is saved with reason ``"crashed"``
    and then re-raised, so the run can still be resumed. *project* is not
    read in this function.
    """
    run_id = resume_run_id or _make_run_id()
    budget_guard = BudgetGuard(limit_usd=budget_usd, label=f"episode_{episode_id}")
    semaphore = asyncio.Semaphore(concurrency)
    abort_event = asyncio.Event()
    shot_results: list[OpResult] = []
    # Defined before the try so the crash handler can persist it.
    style_anchors: dict[str, Path] = {}
    review_queue_count = 0

    # SIGTERM handler -- must use call_soon_threadsafe for asyncio.Event
    loop = asyncio.get_running_loop()
    original_sigterm = signal.getsignal(signal.SIGTERM)

    def _handle_sigterm(signum, frame):
        logger.warning("SIGTERM received -- aborting after in-flight shots complete")
        loop.call_soon_threadsafe(abort_event.set)

    signal.signal(signal.SIGTERM, _handle_sigterm)

    try:
        # ── Resume: load previous state ───────────────────────────────
        previous_results: dict[str, str] = {}  # {shot_id: status}
        if resume_run_id:
            prev_state = _load_run_state(paths, resume_run_id)
            if prev_state:
                for sr in prev_state.get("shots", []):
                    previous_results[sr["shot_id"]] = sr["status"]
                logger.info("Resuming run %s: %d previous results loaded",
                            run_id, len(previous_results))

        # ── Style Anchors (per-scene) ────────────────────────────────
        if not no_style_anchor:
            scene_groups = _group_shots_by_scene(shot_plan)
            # budget_guard is fresh, so anchors are the only spend so far.
            style_anchor_cost_ceiling = 0.50

            for scene_key, scene_shots in scene_groups.items():
                if abort_event.is_set():
                    logger.warning(
                        "Abort requested -- skipping remaining style anchors "
                        "after %d scenes", len(style_anchors),
                    )
                    break
                if budget_guard.spent >= style_anchor_cost_ceiling:
                    logger.warning(
                        "Style anchor cost ceiling ($%.2f) reached after %d scenes",
                        style_anchor_cost_ceiling, len(style_anchors),
                    )
                    break

                # Run off the event loop (module invariant: run_shot goes
                # through asyncio.to_thread). A blocking call here would keep
                # the SIGTERM handler's call_soon_threadsafe callback from
                # running until the whole pre-pass finished.
                anchor_path = await asyncio.to_thread(
                    _generate_style_anchor,
                    step_runner=step_runner,
                    store=store,
                    paths=paths,
                    budget_guard=budget_guard,
                    model=model,
                    episode_id=episode_id,
                    shots=scene_shots,
                    run_id=run_id,
                    scene_key=scene_key,
                )
                if anchor_path is not None:
                    style_anchors[scene_key] = anchor_path
                else:
                    logger.warning("Style anchor failed for scene '%s'", scene_key)

            # Single-scene episode: abort on anchor failure. Skipped when an
            # abort request cut the pre-pass short, so the recorded reason
            # reflects that abort rather than an anchor failure.
            if (
                not style_anchors
                and len(scene_groups) == 1
                and not abort_event.is_set()
            ):
                _save_run_state(
                    paths, run_id, shot_results, budget_guard, episode_id,
                    aborted=True, abort_reason="style_anchor_failed",
                    style_anchors=style_anchors,
                )
                return EpisodeResult(
                    run_id=run_id,
                    episode_id=episode_id,
                    total_shots=len(shot_plan),
                    completed=0,
                    by_status={},
                    total_cost_usd=budget_guard.spent,
                    budget_remaining_usd=budget_guard.remaining,
                    aborted=True,
                    abort_reason="style_anchor_failed",
                    style_anchors={},
                    review_queue_count=0,
                    shot_results=shot_results,
                )

        # ── Group shots by coverage pass ──────────────────────────────
        coverage_groups = _group_by_coverage_pass(shot_plan)

        # ── Run shots ─────────────────────────────────────────────────
        async def _run_one_shot(
            shot: dict,
            coverage_ctx: CoveragePassContext | None = None,
        ) -> OpResult:
            nonlocal review_queue_count

            if abort_event.is_set():
                return OpResult(
                    status="crashed", shot_id=shot["shot_id"],
                    op_id=ops_log.make_op_id(),
                    failure_mode=None,
                    validation_notes=["Aborted before start"],
                )

            shot_id = shot["shot_id"]

            # Resume: skip completed shots
            prev_status = previous_results.get(shot_id)
            if prev_status and prev_status in SKIP_ON_RESUME:
                logger.info("Skipping %s (previous status: %s)", shot_id, prev_status)
                return OpResult(
                    status=prev_status, shot_id=shot_id,
                    op_id=ops_log.make_op_id(),
                )

            async with semaphore:
                if abort_event.is_set():
                    return OpResult(
                        status="crashed", shot_id=shot_id,
                        op_id=ops_log.make_op_id(),
                        failure_mode=None,
                        validation_notes=["Aborted while waiting for slot"],
                    )

                # Same key chain as _group_shots_by_scene (which keyed the
                # anchors), so each shot finds its own scene's anchor.
                scene_value = shot.get("scene_index") or shot.get("scene_id")
                if scene_value is None:
                    scene_value = shot.get("location")
                if scene_value is None:
                    scene_value = "episode"
                shot_style_anchor = style_anchors.get(str(scene_value))
                if shot_style_anchor is None:
                    shot_style_anchor = style_anchors.get("episode")

                result = await asyncio.to_thread(
                    run_shot,
                    shot=shot,
                    store=store,
                    paths=paths,
                    budget_guard=budget_guard,
                    model=model,
                    step_runner=step_runner,
                    run_id=run_id,
                    style_anchor_path=shot_style_anchor,
                    coverage_context=coverage_ctx,
                    stop_on_review=stop_on_review,
                )

            # Track review queue entries
            if result.review_queue_id:
                review_queue_count += 1

            # Check StopOnReview policy
            if stop_on_review == StopOnReview.ON_ANY_REVIEW and result.review_queue_id:
                logger.warning("StopOnReview.ON_ANY_REVIEW: aborting after %s", shot_id)
                loop.call_soon_threadsafe(abort_event.set)
            elif stop_on_review == StopOnReview.ON_HARD_FAIL and result.status in (
                "icu_escalated", "crashed",
            ):
                logger.warning("StopOnReview.ON_HARD_FAIL: aborting after %s (%s)",
                               shot_id, result.status)
                loop.call_soon_threadsafe(abort_event.set)

            # Budget exhaustion check
            if budget_guard.remaining <= 0:
                logger.warning("Budget exhausted -- aborting after in-flight complete")
                loop.call_soon_threadsafe(abort_event.set)

            return result

        # Process coverage groups in order
        for group in coverage_groups:
            if abort_event.is_set():
                break

            pass_id = group["coverage_pass_id"]
            primary = group["primary"]
            siblings = group["siblings"]
            all_shot_ids = [primary["shot_id"]] + [s["shot_id"] for s in siblings]

            # Build coverage context for primary
            primary_ctx = CoveragePassContext(
                coverage_pass_id=pass_id,
                sibling_shot_ids=all_shot_ids,
                this_shot_role="primary",
                completed_siblings={},
                pass_min_success=max(0, len(all_shot_ids) - 1),
            )

            # Run primary shot first
            primary_result = await _run_one_shot(primary, primary_ctx)
            shot_results.append(primary_result)

            if abort_event.is_set():
                continue

            # Run siblings in parallel
            if siblings:
                completed_so_far = {primary["shot_id"]: primary_result.status}
                tasks = []
                for sib in siblings:
                    sib_ctx = CoveragePassContext(
                        coverage_pass_id=pass_id,
                        sibling_shot_ids=all_shot_ids,
                        this_shot_role=sib.get("shot_type", "coverage"),
                        completed_siblings=dict(completed_so_far),
                        pass_min_success=max(0, len(all_shot_ids) - 1),
                    )
                    tasks.append(_run_one_shot(sib, sib_ctx))

                sib_results = await asyncio.gather(*tasks)
                shot_results.extend(sib_results)

        # ── Build episode result ──────────────────────────────────────
        by_status: dict[str, int] = defaultdict(int)
        for r in shot_results:
            by_status[r.status] += 1

        aborted = abort_event.is_set()
        abort_reason: str | None = None
        if aborted:
            if budget_guard.remaining <= 0:
                abort_reason = "budget_exhausted"
            else:
                abort_reason = "sigterm_or_stop_on_review"

        result = EpisodeResult(
            run_id=run_id,
            episode_id=episode_id,
            total_shots=len(shot_plan),
            completed=len(shot_results),
            by_status=dict(by_status),
            total_cost_usd=budget_guard.spent,
            budget_remaining_usd=budget_guard.remaining,
            aborted=aborted,
            abort_reason=abort_reason,
            style_anchors=style_anchors,
            review_queue_count=review_queue_count,
            shot_results=shot_results,
        )

        # Save run state
        _save_run_state(
            paths, run_id, shot_results, budget_guard, episode_id,
            aborted=aborted, abort_reason=abort_reason,
            style_anchors=style_anchors,
        )

        return result

    except BaseException:
        # Unexpected error or cancellation (e.g. Ctrl-C cancelling the task):
        # persist whatever finished so the run can be resumed via
        # resume_run_id, then propagate the original exception.
        try:
            _save_run_state(
                paths, run_id, shot_results, budget_guard, episode_id,
                aborted=True, abort_reason="crashed",
                style_anchors=style_anchors,
            )
        except Exception:
            logger.exception("Could not save run state for %s", run_id)
        raise

    finally:
        # Restore original signal handler
        signal.signal(signal.SIGTERM, original_sigterm)


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------

def run_episode(
    project: str,
    episode_id: str,
    model: str,
    budget_usd: float,
    concurrency: int = 3,
    stop_on_review: StopOnReview = StopOnReview.NEVER,
    resume_run_id: str | None = None,
    no_style_anchor: bool = False,
    step_runner=None,
    store=None,
    paths=None,
    shot_plan: list[dict] | None = None,
) -> EpisodeResult:
    """Run all shots in an episode with full orchestration.

    This is the top-level entry point for overnight autonomous rendering.
    Uses asyncio.run() internally -- safe because this is the top-level call.

    Args:
        project: Project name (e.g. 'the-afterimage')
        episode_id: Episode ID (e.g. 'EP001'; the 'EP' prefix is matched
            case-insensitively when deriving the episode number)
        model: Model ID for generation
        budget_usd: Total budget for the episode
        concurrency: Max concurrent shots (default 3); must be >= 1
        stop_on_review: StopOnReview enum controlling abort behavior
        resume_run_id: If set, resume a previous run
        no_style_anchor: Skip style anchor generation
        step_runner: StepRunner instance; passed through to run_shot
            unchanged (this function does not create one)
        store: ExecutionStore instance; passed through to run_shot
            unchanged (this function does not create one)
        paths: ProjectPaths instance (created via
            ``ProjectPaths.for_episode`` if None)
        shot_plan: List of shot dicts. If None, loaded from
            ``<plans_dir>/ep_<episode_id lowercased>_plan.json`` (either a
            list, or a dict with a ``"shots"`` list). A missing file gives
            an empty plan and logs a warning.

    Returns:
        EpisodeResult with morning triage summary.

    Raises:
        ValueError: if *concurrency* < 1 (a zero-permit semaphore would make
            every shot wait forever), or if *paths* must be derived and
            *episode_id* has no integer part after the 'EP' prefix.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    # Lazy-create paths if not provided
    if paths is None:
        from recoil.execution.step_types import ProjectPaths
        ep_num = int(episode_id.lower().removeprefix("ep"))
        paths = ProjectPaths.for_episode(project, ep_num)

    # Load shot plan if not provided
    if shot_plan is None:
        plan_path = paths.plans_dir / f"ep_{episode_id.lower()}_plan.json"
        if plan_path.exists():
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            loaded = json.loads(plan_path.read_text(encoding="utf-8"))
            shot_plan = loaded.get("shots", []) if isinstance(loaded, dict) else loaded
        else:
            logger.warning(
                "No shot plan found at %s -- running with an empty plan", plan_path,
            )
            shot_plan = []

    return asyncio.run(
        _run_episode_async(
            project=project,
            episode_id=episode_id,
            model=model,
            budget_usd=budget_usd,
            concurrency=concurrency,
            stop_on_review=stop_on_review,
            resume_run_id=resume_run_id,
            no_style_anchor=no_style_anchor,
            step_runner=step_runner,
            store=store,
            paths=paths,
            shot_plan=shot_plan,
        )
    )
