"""
execution_plan.py — Layer 2: Dynamic execution state for shot generation.

Separate from Layer 1 (static plan from ingest_pipeline.py).
Layer 1 = narrative intent (what to render). Layer 2 = execution state
(what has been rendered, what's in-flight, what failed).

Features:
  - Shot status state machine (previs → keyframe → video → complete)
  - Session-based crash recovery (UUID tracking, orphan detection)
  - SQLite + WAL backend via ExecutionStore (ADR-R03)
  - Scene-level checkpointing
  - Cost accumulation per shot

ADRs: H04 (Two-Layer Plan/Log), H05 (Session-Based Crash Recovery), R03 (SQLite Migration)
"""

import json
import logging
import os
import time
import uuid
from pathlib import Path
from typing import Optional

from recoil.execution.execution_store import ExecutionStore, SHOT_VALID_STATUSES, VALID_TRANSITIONS, InvalidTransitionError
from recoil.core.model_profiles import get_model

logger = logging.getLogger(__name__)


class ExecutionPlan:
    """Layer 2: Dynamic execution plan for an episode.

    Thin wrapper around ExecutionStore (SQLite + WAL).
    Manages shot-level state through the generation lifecycle.
    Supports crash recovery via session_id tracking and orphan detection.
    Storage and status-transition validation are both delegated to
    ExecutionStore; this class adds episode scoping and stamps shot updates
    with its own session_id.
    """

    def __init__(self, episode_id: str, manifest_path: Path | None = None,
                 store: ExecutionStore | None = None):
        """Bind the plan to an episode and start a fresh session.

        Args:
            episode_id: Episode this plan manages, e.g. "EP001".
            manifest_path: Stored on the instance; not read by any method here.
            store: Backing ExecutionStore. A default ExecutionStore() is
                created only when this is None.
        """
        self.episode_id = episode_id
        # A fresh UUID per instance: in-flight shots carrying a different
        # session_id are what detect_orphans() reports.
        self.session_id = str(uuid.uuid4())
        self._manifest_path = manifest_path
        self._store = store if store is not None else ExecutionStore()

    def generate_from_plan(self, plan: dict) -> dict:
        """Initialize execution records for every shot in a plan.

        Creates one record per shot in ``previs_pending`` with zero cost and
        zero attempts, stamped with this instance's session_id. Pipeline and
        model come from ``_determine_pipeline`` / ``_determine_model`` applied
        to each shot's ``routing_data``.

        Args:
            plan: Parsed episode plan dict; shots are read from ``plan["shots"]``.

        Returns:
            Summary dict with episode_id, session_id and total_shots.
        """
        batch = []
        for shot in plan.get("shots", []):
            # routing_data may be absent or explicitly null in the plan JSON;
            # treat both as "no routing hints" instead of crashing on None.get.
            routing = shot.get("routing_data") or {}
            pipeline = _determine_pipeline(routing)
            model = _determine_model(pipeline, routing)

            batch.append({
                "shot_id": shot.get("shot_id", ""),
                "episode_id": self.episode_id,
                "pipeline": pipeline,
                "model": model,
                "status": "previs_pending",
                "job_id": None,
                "session_id": self.session_id,
                "gate_results": {},
                "cost_incurred": 0.0,
                "output_path": None,
                "error_message": None,
                "attempts": 0,
                "max_attempts": 3,
            })

        self._store.insert_shots_batch(batch)

        logger.info(
            "Execution plan generated: %d shots for %s",
            len(batch), self.episode_id,
        )
        return {
            "episode_id": self.episode_id,
            "session_id": self.session_id,
            "total_shots": len(batch),
        }

    def load(self, path: Path | None = None) -> dict:
        """Load the execution plan for this episode.

        If ``path`` points at an existing legacy JSON file, that file is
        migrated into SQLite (see ``_migrate_from_json``) and its parsed
        contents are returned. Otherwise a compatibility dict is built from
        the store: ``{"episode_id", "session_id", "shots": {shot_id: record}}``.
        """
        if path and path.exists():
            return self._migrate_from_json(path)

        shots = self._store.get_shots_by_episode(self.episode_id)
        return {
            "episode_id": self.episode_id,
            "session_id": self.session_id,
            "shots": {s["shot_id"]: s for s in shots},
        }

    def _migrate_from_json(self, path: Path) -> dict:
        """One-time migration: load a JSON execution plan into SQLite.

        Expects ``data["shots"]`` to map shot_id -> shot record. Each record
        gets its shot_id set and an episode_id default before batch insert.
        The legacy file is then renamed to ``*.json.migrated``.

        Returns:
            The parsed JSON dict. If the file cannot be read or parsed,
            ``{"episode_id": ..., "shots": {}}`` is returned instead. A
            failed rename after a successful insert is logged but still
            returns the migrated data, because the shots are already in SQLite.
        """
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError) as e:
            logger.warning("Could not migrate JSON %s: %s", path, e)
            return {"episode_id": self.episode_id, "shots": {}}

        logger.info("Migrating JSON execution plan to SQLite: %s", path)
        batch = []
        for shot_id, shot_data in data.get("shots", {}).items():
            shot_data["shot_id"] = shot_id
            shot_data.setdefault("episode_id", self.episode_id)
            batch.append(shot_data)
        if batch:
            self._store.insert_shots_batch(batch)

        # Rename the legacy file so the next load() reads from SQLite.
        migrated = path.with_suffix(".json.migrated")
        try:
            path.rename(migrated)
        except OSError as e:
            logger.warning("Could not migrate JSON %s: %s", path, e)
        else:
            logger.info("Legacy JSON renamed to %s", migrated)
        return data

    def save(self, path: Path | None = None) -> None:
        """No-op — SQLite auto-persists via WAL.

        Kept for backward compatibility with callers that expect save().
        """

    def checkpoint(self) -> None:
        """Forward a checkpoint request to ExecutionStore.checkpoint().

        Unlike save(), this is not a no-op: it calls into the store.
        Kept for backward compatibility with existing callers.
        """
        self._store.checkpoint()

    # ── Shot State Management ────────────────────────────────────

    def get_shot_status(self, shot_id: str) -> str:
        """Get the current status of a shot."""
        return self._store.get_shot_status(shot_id)

    def update_shot(self, shot_id: str, **fields) -> None:
        """Update shot fields. State machine validation is enforced by ExecutionStore.

        Common fields: status, job_id, output_path, error_message,
                       cost_incurred, gate_results

        session_id is always overwritten with this instance's session_id,
        so the update claims the shot for the current session.

        Raises:
            InvalidTransitionError: If the status transition is not allowed.
        """
        fields["session_id"] = self.session_id
        self._store.update_shot(shot_id, **fields)

    def get_shot(self, shot_id: str) -> dict | None:
        """Get the full shot record, or None if the store has none."""
        return self._store.get_shot(shot_id)

    def get_shots_by_status(self, status: str) -> list[str]:
        """Get shot IDs with a specific status, restricted to this episode."""
        # The store query is not episode-scoped, so filter here.
        return [
            s["shot_id"] for s in self._store.get_shots_by_status(status)
            if s.get("episode_id") == self.episode_id
        ]

    def total_cost(self) -> float:
        """Sum of costs for all shots in this episode."""
        return self._store.total_cost(self.episode_id)

    def summary(self) -> dict:
        """Status summary for the episode, annotated with this session_id."""
        s = self._store.summary(self.episode_id)
        s["session_id"] = self.session_id
        return s

    # ── Crash Recovery ───────────────────────────────────────────

    def detect_orphans(self) -> list[dict]:
        """Detect orphaned shots — in-flight but from a different session."""
        return self._store.detect_orphans(self.session_id)

    def recover_orphan(self, shot_id: str) -> None:
        """Re-claim an orphaned shot with the current session_id."""
        self._store.recover_orphan(shot_id, self.session_id)
        logger.info("Recovered orphan %s for session %s", shot_id, self.session_id[:8])


# ── Pipeline Routing Helpers ─────────────────────────────────────

def _determine_pipeline(routing_data: dict) -> str:
    """Determine sub-pipeline from routing data."""
    has_dialogue = routing_data.get("has_dialogue", False)
    is_env = routing_data.get("is_env_only", False)
    num_chars = routing_data.get("num_characters", 0)
    camera = routing_data.get("camera_complexity", "static")
    needs_match_cut = routing_data.get("narrative_requires_match_cut", False)

    if is_env:
        return "still"
    if needs_match_cut:
        return "i2v"
    if has_dialogue and num_chars >= 2:
        return "t2v"
    if camera in ("tracking", "crane", "steadicam", "dolly"):
        return "t2v"
    return "still"


def _determine_model(pipeline: str, routing_data: dict) -> str:
    """Resolve the target model for a sub-pipeline.

    Args:
        pipeline: Sub-pipeline name, e.g. the result of ``_determine_pipeline``.
        routing_data: Shot routing data. Not consulted by the current rules;
            accepted so callers can pass it unchanged.

    Returns:
        The value of ``get_model(profile, kind)`` for the pipeline's
        (profile, kind) pair. Unknown pipelines fall back to
        ("production", "image"), the same pair used for "still".
    """
    # (profile, kind) arguments for get_model, keyed by pipeline. Only the
    # selected entry is resolved, so there is one get_model call per shot
    # instead of five. A problem resolving an unrelated profile also cannot
    # affect this lookup.
    profiles = {
        "still": ("production", "image"),
        "i2v": ("i2v", "video"),
        "t2v": ("t2v_default", "video"),
        "multi_shot": ("multi_shot", "video"),
    }
    profile, kind = profiles.get(pipeline, profiles["still"])
    return get_model(profile, kind)


# ── Shot Composer Utilities ──────────────────────────────────────


def compute_insert_id(episode_id: str, after_shot_id: str, before_shot_id: str | None) -> str:
    """Compute a new shot_id for insertion between two existing shots.

    Uses alphabetic suffix insertion on the SH number of ``after_shot_id``:
      - after=SH03,  before=SH04  -> SH03A
      - after=SH03A, before=SH03B -> SH03AA  (no single letter fits, go one level deeper)
      - after=SH03A, before=SH04  -> SH03B
      - after=SH03A, before=None  -> SH03AA  (end insertion always appends 'A')

    Suffix ordering is decided by plain Python string comparison of the
    letter suffixes. NOTE(review): this assumes the project's shot sort
    order (e.g. get_sort_float) orders suffixes the same way — confirm.

    Args:
        episode_id: Episode prefix used to build the returned ID, e.g. "EP001".
            The EPnnn part of the input shot IDs is matched but not compared
            against this value.
        after_shot_id: Full shot_id to insert after, e.g. "EP001_SH03"
        before_shot_id: Full shot_id of the next shot, or None if inserting at end

    Returns:
        New shot_id string, e.g. "EP001_SH03A". The numeric part is
        zero-padded to at least two digits.

    Raises:
        ValueError: If either shot_id does not start with ``EPnnn_SHnn`` /
            ``EPnnn_SHnnn``, or if no suffix fits strictly between the two
            shots (the caller should renumber the scene first).
    """
    import re as _re

    # Parse the after shot: group(1) = numeric base, group(2) = letter suffix.
    # re.match anchors only at the start, so trailing text after the suffix is ignored.
    m = _re.match(r"EP\d{3}_SH(\d{2,3})([A-Z]*)", after_shot_id)
    if not m:
        raise ValueError(f"Cannot parse after_shot_id: {after_shot_id}")
    after_base = int(m.group(1))
    after_suffix = m.group(2)

    if before_shot_id is None:
        # Inserting at end — just append 'A' to after's suffix
        new_suffix = after_suffix + "A"
        return f"{episode_id}_SH{after_base:02d}{new_suffix}"

    # Parse the before shot
    m2 = _re.match(r"EP\d{3}_SH(\d{2,3})([A-Z]*)", before_shot_id)
    if not m2:
        raise ValueError(f"Cannot parse before_shot_id: {before_shot_id}")
    before_base = int(m2.group(1))
    before_suffix = m2.group(2)

    # Case 1: Different base numbers (SH03 → SH04). Every suffix on after's
    # base that sorts above after_suffix is assumed free, so before_suffix is
    # not consulted here.
    if after_base != before_base:
        if after_suffix == "":
            new_suffix = "A"
        elif after_suffix[-1] == "Z":
            # 'Z' cannot be incremented; appending 'A' still sorts after it.
            new_suffix = after_suffix + "A"
        else:
            # Increment the last letter: A→B, B→C, etc.
            new_suffix = after_suffix[:-1] + chr(ord(after_suffix[-1]) + 1)
        return f"{episode_id}_SH{after_base:02d}{new_suffix}"

    # Case 2: Same base, different suffixes — find the next available letter
    # e.g. after=SH03A, before=SH03B → try SH03AA (go deeper)
    # e.g. after=SH03, before=SH03B → try SH03A
    if len(after_suffix) < len(before_suffix) or after_suffix < before_suffix:
        # Try incrementing after's last suffix char
        if after_suffix == "":
            # Nothing sorts strictly between "" and "A".
            if before_suffix == "A":
                raise ValueError(
                    f"No room between {after_shot_id} and {before_shot_id} — "
                    f"renumber the scene first"
                )
            new_suffix = "A"
        else:
            last_char = after_suffix[-1]
            if last_char == "Z":
                # Collisions with before_shot_id on this path are caught by
                # the final guard below.
                new_suffix = after_suffix + "A"
            else:
                next_char = chr(ord(last_char) + 1)
                candidate = after_suffix[:-1] + next_char
                if before_suffix and candidate >= before_suffix:
                    # Would collide — go deeper
                    new_suffix = after_suffix + "A"
                else:
                    new_suffix = candidate
    else:
        # Reached only when after_suffix is not shorter than before_suffix and
        # after_suffix >= before_suffix, i.e. equal or out-of-order inputs.
        # Fall back to going deeper.
        new_suffix = after_suffix + "A"

    new_id = f"{episode_id}_SH{after_base:02d}{new_suffix}"

    # Final collision guard — if we generated the same ID as before_shot_id,
    # there's no room (e.g. inserting between SH01A and SH01AA)
    if before_shot_id and new_id == before_shot_id:
        raise ValueError(
            f"No room between {after_shot_id} and {before_shot_id} — "
            f"renumber the scene first"
        )

    return new_id


def insert_composed_shot(
    project: str,
    episode_id: str,
    after_shot_id: str,
    shot_record: dict,
    store: 'ExecutionStore | None' = None,
) -> str:
    """Insert a composed shot into both the plan file and execution store.

    The shot is inserted into the execution store first. If updating the
    plan file then fails, the store insert is rolled back via
    ``store.delete_shot`` and the error is re-raised. The plan file is
    rewritten through a temporary sibling file and ``os.replace``, so a
    crash mid-write cannot leave a truncated plan behind.

    Args:
        project: Project name (e.g. "starsend-test")
        episode_id: Episode ID (e.g. "EP001")
        after_shot_id: Shot ID to insert after. If it is not found in the
            plan, the shot is appended at the end and a warning is logged.
        shot_record: Full ShotRecord dict (must include shot_id)
        store: ExecutionStore instance (creates one if None)

    Returns:
        The shot_id of the inserted shot.

    Raises:
        KeyError: If ``shot_record`` has no "shot_id".
        ValueError: If ``episode_id`` is not of the form "EP<number>".
    """
    from recoil.core.paths import ProjectPaths
    from recoil.execution.execution_store import ExecutionStore as _ES

    if store is None:
        store = _ES(project=project)

    shot_id = shot_record["shot_id"]
    ep_num = int(episode_id.replace("EP", ""))

    # Resolve the plan path before touching the store. A failure here then
    # cannot leave an un-rolled-back store insert behind.
    plans_dir = ProjectPaths.for_project(project).plans_dir
    plan_path = plans_dir / f"ep_{ep_num:03d}_plan.json"

    # 1. Insert into execution store FIRST (a single delete_shot undoes it)
    routing = shot_record.get("routing_data") or {}
    pipeline = _determine_pipeline(routing)
    model = _determine_model(pipeline, routing)

    store.insert_shot({
        "shot_id": shot_id,
        "episode_id": episode_id,
        "pipeline": pipeline,
        "model": model,
        "status": "previs_pending",
    })

    # 2. Insert into plan file
    try:
        if plan_path.exists():
            plan = json.loads(plan_path.read_text(encoding="utf-8"))
            shots = plan.get("shots", [])

            insert_idx = next(
                (i + 1 for i, s in enumerate(shots) if s.get("shot_id") == after_shot_id),
                None,
            )
            if insert_idx is None:
                logger.warning(
                    "after_shot_id %s not found in %s — appending %s at end",
                    after_shot_id, plan_path, shot_id,
                )
                insert_idx = len(shots)

            shots.insert(insert_idx, shot_record)
            plan["shots"] = shots
            plan["total_shots"] = len(shots)

            # Write-then-rename so readers never observe a half-written plan.
            tmp_path = plan_path.with_name(plan_path.name + ".tmp")
            tmp_path.write_text(json.dumps(plan, indent=2, default=str), encoding="utf-8")
            os.replace(tmp_path, plan_path)
            logger.info("Inserted %s into plan at index %d", shot_id, insert_idx)
        else:
            logger.warning("Plan file %s does not exist — shot %s inserted to store only", plan_path, shot_id)
    except Exception:
        store.delete_shot(shot_id)  # Rollback
        raise

    logger.info("Inserted composed shot %s into execution store", shot_id)
    return shot_id


def renumber_scene(project: str, episode_id: str, scene_index: int) -> dict:
    """Renumber all shots in a scene with clean sequential IDs.

    Reassigns shot_ids preserving narrative order. Use after heavy insertion
    makes IDs ugly (e.g. SH03ABBA). Scene shots are ordered by
    ``get_sort_float`` and numbered consecutively from the base number of the
    first one.

    The store is opened and all records to rename are read before the plan
    file is touched. The plan is then rewritten atomically (temp file +
    ``os.replace``), and finally the store rows are renamed.
    NOTE(review): the store rename is delete-all then insert-all with no
    transaction visible here; a failure between the two loops would lose rows.

    Args:
        project: Project name
        episode_id: Episode ID (e.g. "EP001")
        scene_index: Scene index to renumber

    Returns:
        Dict mapping old_shot_id → new_shot_id. Empty if the plan file is
        missing, the scene has no shots, the first shot_id is unparseable, or
        the scene is already cleanly numbered.

    Raises:
        ValueError: If a new ID would collide with a shot in another scene.
    """
    import re as _re

    from recoil.core.paths import ProjectPaths
    from recoil.execution.execution_store import ExecutionStore as _ES
    from recoil.pipeline._lib.render_schema import get_sort_float

    ep_num = int(episode_id.replace("EP", ""))
    plans_dir = ProjectPaths.for_project(project).plans_dir
    plan_path = plans_dir / f"ep_{ep_num:03d}_plan.json"

    if not plan_path.exists():
        return {}

    plan = json.loads(plan_path.read_text(encoding="utf-8"))
    shots = plan.get("shots", [])

    # Scene shots in their current narrative order. These are the same dict
    # objects as in `shots`, so renaming them below updates the plan in place.
    scene_shots = sorted(
        (s for s in shots if s.get("scene_index") == scene_index),
        key=lambda s: get_sort_float(s["shot_id"]),
    )
    if not scene_shots:
        return {}

    first_m = _re.match(r"EP\d{3}_SH(\d{2,3})", scene_shots[0]["shot_id"])
    if not first_m:
        return {}
    start_num = int(first_m.group(1))

    rename_map = {}
    for offset, shot in enumerate(scene_shots):
        old_id = shot["shot_id"]
        new_id = f"{episode_id}_SH{start_num + offset:02d}"
        if old_id != new_id:
            rename_map[old_id] = new_id

    if not rename_map:
        return {}  # Already clean

    # Check for collisions with shots in other scenes
    other_scene_ids = {s["shot_id"] for s in shots if s.get("scene_index") != scene_index}
    for new_id in rename_map.values():
        if new_id in other_scene_ids:
            raise ValueError(f"Renumber collision: {new_id} already exists in another scene")

    # Open the store and snapshot the rows to rename BEFORE rewriting the plan.
    # A store failure then leaves the plan file untouched.
    store = _ES(project=project)
    renamed_records = []
    for old_id, new_id in rename_map.items():
        record = store.get_shot(old_id)
        if record:
            record["shot_id"] = new_id
            renamed_records.append((old_id, record))

    # Apply renaming to plan (scene shots only, so same-ID shots elsewhere are untouched)
    for shot in scene_shots:
        shot["shot_id"] = rename_map.get(shot["shot_id"], shot["shot_id"])

    plan["shots"] = shots
    plan["total_shots"] = len(shots)
    tmp_path = plan_path.with_name(plan_path.name + ".tmp")
    tmp_path.write_text(json.dumps(plan, indent=2, default=str), encoding="utf-8")
    os.replace(tmp_path, plan_path)

    # Apply renaming to execution store. Delete every old row before inserting
    # any new one, so swaps (SH03 ↔ SH04) never hit an existing ID.
    for old_id, _ in renamed_records:
        store.delete_shot(old_id)
    for _, record in renamed_records:
        store.insert_shot(record)

    logger.info("Renumbered scene %d: %d shots renamed", scene_index, len(rename_map))
    return rename_map
