"""
manifest.py — Layer 1: Episode-level generation log and shot state tracking.

Layer 1 is the STATIC generation log — narrative intent. It includes:
  - Shot records with 5 consumer groups (routing, prompt, spatial, asset, audio)
  - Human approvals (previs/keyframe editorial decisions)
  - Review status

The dynamic execution state (Layer 2) lives in execution_plan.py.

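Shot statuses (Layer 1):
  pending → submitted → processing → complete
                                   → failed
  pending → keyframe_pending → keyframe_generated → keyframe_approved → submitted
                                                  → keyframe_rejected → keyframe_pending
  Also accepted: qc_mechanical_failed, qc_semantic_failed.

EpisodeLog.update_shot() validates the status name only; it does not enforce
the transitions shown above.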

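Writes are atomic (temp file + os.replace()), so readers never observe a
partially written log.json. fcntl.flock (POSIX only: macOS/Linux) is taken on
the temp file during the write; it does not serialize concurrent writers, so
when several processes save the same log the last save() wins.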
"""

import fcntl
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Optional

from recoil.core.paths import ProjectPaths


# Every status EpisodeLog.update_shot() accepts, grouped by lifecycle stage.
_MANIFEST_ORCH_VALID_STATUSES = {
    # Generation lifecycle.
    "pending", "submitted", "processing", "complete", "failed",
    # Keyframe approval loop.
    "keyframe_pending", "keyframe_generated",
    "keyframe_approved", "keyframe_rejected",
    # QC failure states.
    "qc_mechanical_failed", "qc_semantic_failed",
}


class EpisodeLog:
    """Read-write ``log.json`` with shot-level state tracking.

    The log lives at ``<output_dir>/log.json``. When ``output_dir`` is not
    given, ``ProjectPaths.for_project().episode_prep_dir(episode)`` is used
    (the v2 ``prep/ep_NNN/`` layout, which replaced ``output/frames/ep_NNN/``).

    Mutating methods only change in-memory state; call :meth:`save` to
    persist. ``save()`` is atomic (temp file + ``os.replace()``), so a crash
    mid-write never leaves a torn ``log.json``. Saves from several processes
    are not merged: each one rewrites the whole file and the last one wins.
    No lock guards the in-memory state.
    """

    def __init__(self, episode: int, output_dir: Optional[Path] = None):
        """Create the log for ``episode`` and load ``log.json`` if it exists.

        Args:
            episode: Episode number.
            output_dir: Directory containing ``log.json``. A ``str`` path is
                accepted as well. When falsy, the project's episode prep
                directory is used.
        """
        self.episode = episode
        # v2 layout: prep/ep_NNN/ (replaces output/frames/ep_NNN/)
        if output_dir:
            # Path() also accepts str, so callers may pass plain strings.
            self._output_dir = Path(output_dir)
        else:
            self._output_dir = ProjectPaths.for_project().episode_prep_dir(episode)
        self._log_path = self._output_dir / "log.json"
        now = time.time()
        self._data: dict = {
            "episode": episode,
            "created_at": now,
            "updated_at": now,
            "shots": {},
            "human_approvals": {},
        }
        # Replace the defaults above with the on-disk log, if any.
        self.load()

    # ── Internal helpers ─────────────────────────────────────────

    @staticmethod
    def _new_shot_record(shot_id: int) -> dict:
        """Return the initial record stored for a newly tracked shot."""
        return {
            "shot_id": shot_id,
            "status": "pending",
            "created_at": time.time(),
        }

    def _shot_ids_with_status(self, *statuses: str) -> list[int]:
        """Return IDs of tracked shots whose status is one of ``statuses``.

        IDs are returned in the shots dict's iteration order. Records that
        have no ``status`` key never match.
        """
        return [
            int(sid)
            for sid, data in self._data.get("shots", {}).items()
            if data.get("status") in statuses
        ]

    # ── Shot status ──────────────────────────────────────────────

    def get_shot_status(self, shot_id: int) -> str:
        """Get the current status of a shot. Returns 'pending' if not tracked."""
        shot_data = self._data.get("shots", {}).get(str(shot_id), {})
        return shot_data.get("status", "pending")

    def update_shot(
        self,
        shot_id: int,
        status: str,
        job_dict: Optional[dict] = None,
        output_path: Optional[str] = None,
        cost: Optional[float] = None,
        error: Optional[str] = None,
        tier: Optional[str] = None,
        pipeline: Optional[str] = None,
        model: Optional[str] = None,
    ) -> None:
        """Update a shot's status and metadata in memory.

        Untracked shots are created first. Optional fields left as ``None``
        keep whatever value the record already had.

        Args:
            shot_id: Shot ID from storyboard.
            status: New status (must be in _MANIFEST_ORCH_VALID_STATUSES).
            job_dict: Serialized Job for crash recovery (stored as ``job``).
            output_path: Path to generated output file.
            cost: Cost of generation.
            error: Error message if failed.
            tier: Complexity tier (simple/standard/complex).
            pipeline: Sub-pipeline used (still/i2v/t2v/multi_shot).
            model: Model ID used.

        Raises:
            ValueError: If ``status`` is not a recognised status name.
        """
        if status not in _MANIFEST_ORCH_VALID_STATUSES:
            raise ValueError(
                f"Invalid status '{status}'. Valid: {sorted(_MANIFEST_ORCH_VALID_STATUSES)}"
            )

        shots = self._data["shots"]
        shot_key = str(shot_id)
        if shot_key not in shots:
            shots[shot_key] = self._new_shot_record(shot_id)

        shot = shots[shot_key]
        shot["status"] = status
        shot["updated_at"] = time.time()

        # Insertion order matches the historical field order in log.json.
        optional_fields = {
            "job": job_dict,
            "output_path": output_path,
            "cost": cost,
            "error": error,
            "tier": tier,
            "pipeline": pipeline,
            "model": model,
        }
        for field_name, value in optional_fields.items():
            if value is not None:
                shot[field_name] = value

        self._data["updated_at"] = time.time()

    def get_pending_shots(self) -> list[int]:
        """Get shot IDs that haven't been started yet."""
        return self._shot_ids_with_status("pending")

    def get_failed_shots(self) -> list[int]:
        """Get shot IDs that failed and may need retry."""
        return self._shot_ids_with_status("failed")

    def get_complete_shots(self) -> list[int]:
        """Get shot IDs that completed successfully."""
        return self._shot_ids_with_status("complete")

    def get_in_progress_shots(self) -> list[int]:
        """Get shot IDs currently submitted or processing."""
        return self._shot_ids_with_status("submitted", "processing")

    def total_cost(self) -> float:
        """Sum of costs for all completed shots.

        A missing or ``null`` cost counts as 0.0.
        """
        return sum(
            (
                data.get("cost") or 0.0
                for data in self._data.get("shots", {}).values()
                if data.get("status") == "complete"
            ),
            0.0,
        )

    def summary(self) -> dict:
        """Summary counts by status (untracked status counts as 'pending')."""
        shots = self._data.get("shots", {})
        counts: dict[str, int] = {}
        for data in shots.values():
            status = data.get("status", "pending")
            counts[status] = counts.get(status, 0) + 1
        return {
            "episode": self.episode,
            "total_shots": len(shots),
            "total_cost": round(self.total_cost(), 4),
            "by_status": counts,
        }

    # ── Persistence ──────────────────────────────────────────────

    def _warn_unloadable(self, reason: object) -> None:
        """Log that ``log.json`` could not be loaded and was ignored."""
        import logging

        logging.getLogger(__name__).warning(
            "Could not load log %s: %s — starting fresh",
            self._log_path,
            reason,
        )

    @staticmethod
    def _normalize_shots(shots: object) -> dict:
        """Coerce a loaded ``shots`` value into a dict keyed by shot_id string.

        Plan logs store shots as a list of records; those are re-keyed by
        ``shot_id`` (falling back to ``id``). List entries that are not
        dicts or carry no usable ID are skipped. Any other non-dict value
        (missing, ``null``, ...) becomes an empty dict.
        """
        if isinstance(shots, dict):
            return shots
        if not isinstance(shots, list):
            return {}
        migrated: dict = {}
        for shot in shots:
            if not isinstance(shot, dict):
                continue
            sid = shot.get("shot_id")
            if sid is None:
                sid = shot.get("id")
            # Skip records without an ID instead of keying them as "None"/"".
            if sid is None or sid == "":
                continue
            migrated[str(sid)] = shot
        return migrated

    def load(self) -> None:
        """Replace in-memory state with ``log.json`` from disk, if it exists.

        The file is fully parsed and validated before ``self._data`` is
        replaced. If it cannot be read, is not valid UTF-8/JSON, or is not a
        JSON object, a warning is logged and the current in-memory state is
        kept (from ``__init__`` that is a fresh, empty log).
        """
        if not self._log_path.exists():
            return
        try:
            raw = self._log_path.read_text(encoding="utf-8")
            data = json.loads(raw)
        except (ValueError, OSError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            self._warn_unloadable(e)
            return
        if not isinstance(data, dict):
            self._warn_unloadable(
                f"expected a JSON object, got {type(data).__name__}"
            )
            return

        data["shots"] = self._normalize_shots(data.get("shots"))
        if not isinstance(data.get("human_approvals"), dict):
            data["human_approvals"] = {}
        self._data = data

    def save(self) -> None:
        """Atomic write: temp file + os.replace().

        The temp file is created next to ``log.json`` and renamed over it,
        so a crash mid-write never leaves a torn log. Data is flushed and
        fsync'd before the rename so the renamed file's contents are on
        disk (the directory entry itself is not fsync'd). The temp file is
        removed if anything fails, including interrupts.
        """
        self._output_dir.mkdir(parents=True, exist_ok=True)
        self._data["updated_at"] = time.time()
        # Serialize first so a serialization error never leaves a temp file.
        content = json.dumps(self._data, indent=2, default=str)

        fd, temp_path = tempfile.mkstemp(
            dir=self._output_dir,
            prefix=".log_",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                # NOTE: this lock is on the unique temp file, so it does not
                # exclude other writers; atomicity comes from os.replace().
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                try:
                    f.write(content)
                    f.flush()
                    os.fsync(f.fileno())
                finally:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            os.replace(temp_path, self._log_path)
        except BaseException:
            Path(temp_path).unlink(missing_ok=True)
            raise

    # ── Human Approvals (Layer 1) ────────────────────────────────

    def _record_approval(self, shot_id: int, stage: str, approved_by: str) -> None:
        """Mark ``stage`` ("previs" or "keyframe") approved for ``shot_id``.

        Writes ``<stage>_approved``, ``<stage>_approved_by`` and
        ``<stage>_approved_at`` into the shot's ``human_approvals`` entry.
        """
        now = time.time()
        approvals = self._data.setdefault("human_approvals", {})
        entry = approvals.setdefault(str(shot_id), {})
        entry[f"{stage}_approved"] = True
        entry[f"{stage}_approved_by"] = approved_by
        entry[f"{stage}_approved_at"] = now
        self._data["updated_at"] = now

    def approve_previs(self, shot_id: int, approved_by: str = "director") -> None:
        """Record previs approval in Layer 1 log.

        When a director approves a previs frame, write the approval to
        log.json so re-generating the execution plan never loses
        editorial decisions.
        """
        self._record_approval(shot_id, "previs", approved_by)

    def approve_keyframe(self, shot_id: int, approved_by: str = "director") -> None:
        """Record keyframe approval in Layer 1 log."""
        self._record_approval(shot_id, "keyframe", approved_by)

    def get_approval(self, shot_id: int) -> dict:
        """Get approval status for a shot.

        Returns the stored dict itself (mutating it mutates the log), or a
        new empty dict when the shot has no approvals.
        """
        approvals = self._data.get("human_approvals", {})
        return approvals.get(str(shot_id), {})

    def get_approved_shots(self) -> list[int]:
        """Get shot IDs that have been approved (previs or keyframe)."""
        approvals = self._data.get("human_approvals", {})
        return [
            int(sid)
            for sid, data in approvals.items()
            if data.get("previs_approved") or data.get("keyframe_approved")
        ]

    def init_shots(self, shot_ids: list[int]) -> None:
        """Start tracking shot IDs as 'pending'; already-tracked shots are kept."""
        shots = self._data["shots"]
        for shot_id in shot_ids:
            shot_key = str(shot_id)
            if shot_key not in shots:
                shots[shot_key] = self._new_shot_record(shot_id)
