"""
take_provenance.py — Full reproduction recipe per take.

Provenance records are stored both:
1. In-line on the take record in ExecutionStore (as a 'provenance' key)
2. As a batch manifest JSON at projects/{project}/output/manifests/

The in-line record enables per-shot reproduction from the Console.
The manifest enables batch-level analysis and editorial handoff.
"""

import hashlib
import json
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import Optional

from recoil.pipeline._lib.take_keys import TakeNumberMissingError, read_take_number
from orchestrator.production_types import ProvenanceRecord

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


def _file_hash(path: Path) -> str:
    """Return a truncated SHA-256 fingerprint of a file for provenance tracking.

    Args:
        path: File to fingerprint.

    Returns:
        ``"sha256:"`` followed by the first 16 hex characters of the digest,
        or ``""`` when ``path`` is not an existing regular file (missing
        path, broken symlink, or directory).
    """
    # is_file() rather than exists(): a directory "exists" but cannot be
    # opened for reading, so hashing it would raise instead of returning "".
    if not path.is_file():
        return ""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Stream in fixed-size chunks so large files are never read into
        # memory in one piece.
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return f"sha256:{digest.hexdigest()[:16]}"


def _ref_entry(role: str, path: Path) -> dict:
    """Describe one reference image as a provenance entry.

    Args:
        role: Role label stored under the ``"role"`` key.
        path: Location of the reference image.

    Returns:
        Dict with the keys ``"role"``, ``"path"`` (``str(path)``) and
        ``"hash"`` (the value returned by ``_file_hash(path)``).
    """
    entry = {"role": role}
    entry["path"] = str(path)
    entry["hash"] = _file_hash(path)
    return entry


class ProvenanceWriter:
    """Writes take provenance records to ExecutionStore and manifest files.

    Thread-safe: all writes go through ExecutionStore's lock or atomic file ops.

    NOTE(review): ``write_to_store`` performs a read-modify-write
    (``get_shot`` followed by ``update_shot``). Whether that pair is atomic
    depends on the store implementation, which is not visible here — confirm
    before relying on the thread-safety statement for concurrent writers to
    the same shot.
    """

    def __init__(self, project: str, output_root: Path):
        """Remember the project and make sure the manifests directory exists.

        Args:
            project: Project name (e.g. "tartarus").
            output_root: Project output root (projects/{project}/output/).
        """
        self.project = project
        self._output_root = output_root
        self._manifests_dir = output_root / "manifests"
        # Created eagerly so write_manifest() can place its temp file there.
        self._manifests_dir.mkdir(parents=True, exist_ok=True)

    def build_record(
        self,
        shot_id: str,
        episode_id: str,
        take_number: int,
        step_result,
        prompt: str = "",
        negative_prompt: str = "",
        model: str = "",
        phase: str = "keyframe",
        refs_used: Optional[list[dict]] = None,
        params: Optional[dict] = None,
        gate_results: Optional[dict] = None,
        parent_take: Optional[str] = None,
        change_reason: Optional[str] = None,
        inputs_snapshot: Optional[dict] = None,
    ) -> ProvenanceRecord:
        """Build a ProvenanceRecord from a StepResult and generation context.

        Args:
            shot_id: Shot identifier.
            episode_id: Episode identifier.
            take_number: Take number for this shot.
            step_result: StepResult from StepRunner. Its ``cost_usd``,
                ``gate_verdict`` and ``model`` attributes are read here.
            prompt: Full generation prompt.
            negative_prompt: Negative prompt if used.
            model: Model ID used; falls back to ``step_result.model`` when empty.
            phase: Generation phase (previs/keyframe/video).
            refs_used: List of ref dicts [{role, path, hash}].
            params: Model-specific generation params.
            gate_results: Gate verdict results dict. When given (and
                non-empty) it takes precedence over ``step_result.gate_verdict``.
            parent_take: Parent take ID for retries/re-runs.
            change_reason: Why this take was generated (retry reason).
            inputs_snapshot: Full inputs snapshot from StepRunner. Only used
                to fill in ``refs_used`` / ``prompt`` when those are empty.

        Returns:
            The populated ProvenanceRecord; its take ID is
            ``"{shot_id}_T{take_number}"``.
        """
        take_id = f"{shot_id}_T{take_number}"

        # The generation cost is recorded as the total and again under the
        # phase name; the gate cost, when a verdict exists, is kept separate.
        generation_cost = step_result.cost_usd
        cost_dict = {"total": generation_cost}
        cost_dict[phase] = generation_cost
        if step_result.gate_verdict:
            cost_dict["gates"] = step_result.gate_verdict.cost

        # Explicit gate_results win; otherwise derive a single-entry dict
        # from the step's own gate verdict.
        gates = {}
        if gate_results:
            gates = gate_results
        elif step_result.gate_verdict:
            gv = step_result.gate_verdict
            gates[gv.gate_name] = {
                "passed": gv.passed,
                "reason": gv.reason,
                "cost": gv.cost,
                "deferred": gv.deferred,
            }

        record = ProvenanceRecord(
            take_id=take_id,
            shot_id=shot_id,
            episode_id=episode_id,
            project=self.project,
            attempt=take_number,
            phase=phase,
            model=model or step_result.model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            params=params or {},
            refs_used=refs_used or [],
            gates=gates,
            cost=cost_dict,
            parent_take=parent_take,
            change_reason=change_reason,
        )

        # Backfill from the inputs snapshot only where the caller supplied
        # nothing, so explicit arguments are never overwritten.
        if inputs_snapshot:
            if not record.refs_used and "refs_sent" in inputs_snapshot:
                record.refs_used = [
                    {"role": r.get("type", ""), "path": r.get("url", ""), "hash": ""}
                    for r in inputs_snapshot.get("refs_sent", [])
                    if r.get("sent_to_model")
                ]
            if not record.prompt and "prompt" in inputs_snapshot:
                record.prompt = inputs_snapshot["prompt"]

        return record

    def write_to_store(self, store, record: ProvenanceRecord) -> None:
        """Attach provenance to the shot's take record in ExecutionStore.

        Finds the take whose number equals ``record.attempt`` and sets its
        'provenance' key. If no take matches, a new take entry holding only
        the take number and the provenance is appended. If the shot itself is
        not found, a warning is logged and nothing is written.

        Args:
            store: Store exposing ``get_shot(shot_id)`` and
                ``update_shot(shot_id, takes=...)``.
            record: Provenance record to persist.
        """
        shot = store.get_shot(record.shot_id)
        if not shot:
            logger.warning("Cannot write provenance: shot %s not found", record.shot_id)
            return

        # "or []" also covers a shot whose "takes" key is present but None.
        takes = shot.get("takes") or []
        take_num = record.attempt
        payload = record.to_dict()

        for take in takes:
            try:
                take_n = read_take_number(take)
            except TakeNumberMissingError:
                # Takes without a readable number cannot match; skip them.
                continue
            if take_n == take_num:
                take["provenance"] = payload
                break
        else:
            logger.warning(
                "Take %d not found for shot %s — appending provenance as new take",
                take_num, record.shot_id,
            )
            takes.append({"take_number": take_num, "provenance": payload})

        store.update_shot(record.shot_id, takes=takes)

    def write_manifest(self, episode_id: str, records: list[ProvenanceRecord]) -> Path:
        """Write a batch manifest JSON with all provenance records for an episode.

        The manifest is written to a temp file in the manifests directory and
        then moved into place with ``os.replace`` so readers never observe a
        partially written file.

        Args:
            episode_id: Episode identifier; used in the manifest file name.
            records: Provenance records to include.

        Returns:
            Path to the written manifest file.

        Raises:
            Any exception raised while serializing or replacing the file is
            re-raised after the temp file has been removed.
        """
        manifest_path = self._manifests_dir / f"{episode_id}_manifest.json"

        # A record's total may be missing or None; count both as zero so a
        # single unpriced take cannot abort the whole manifest.
        total_cost = sum((r.cost.get("total") or 0) for r in records)

        manifest = {
            "episode_id": episode_id,
            "project": self.project,
            "generated_at": time.time(),
            "total_takes": len(records),
            "total_cost": round(total_cost, 4),
            "takes": [r.to_dict() for r in records],
        }

        # Atomic write: temp file in the same directory, then rename.
        fd, tmp = tempfile.mkstemp(
            dir=str(self._manifests_dir), prefix=".manifest_", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(manifest, f, indent=2, default=str)
            os.replace(tmp, str(manifest_path))
        except BaseException:
            # BaseException so an interrupt does not leave a stray temp file
            # behind; the original error is always re-raised.
            try:
                os.unlink(tmp)
            except OSError:
                pass
            raise

        logger.info("Manifest written: %s (%d takes)", manifest_path, len(records))
        return manifest_path

    def create_selects_symlinks(self, episode_id: str, store) -> int:
        """Create symlinks in selects/ for approved takes.

        A shot is linked when its status is "approved" or "video_complete",
        it has an ``output_path``, and that path exists under the project
        directory. An existing entry at the link location is replaced.

        NOTE(review): links are named after the target's file name, so two
        shots whose outputs share a file name replace each other's link while
        both are counted.

        Args:
            episode_id: Episode whose shots are scanned; also the name of the
                sub-directory created under ``selects/``.
            store: Store exposing ``get_shots_by_episode(episode_id)``.

        Returns:
            Number of symlinks created.
        """
        # Imported inside the method, as in the original code, so the module
        # import itself does not depend on recoil.core.paths.
        from recoil.core.paths import projects_root

        selects_dir = self._output_root / "selects" / episode_id
        selects_dir.mkdir(parents=True, exist_ok=True)

        # Loop-invariant: every output_path is resolved against this directory.
        project_dir = projects_root() / self.project

        count = 0
        for shot in store.get_shots_by_episode(episode_id):
            if shot.get("status") not in ("approved", "video_complete"):
                continue
            output_path = shot.get("output_path")
            if not output_path:
                continue

            abs_path = project_dir / output_path
            if not abs_path.exists():
                continue

            link_path = selects_dir / abs_path.name
            # Removes a regular file or a (possibly dangling) symlink and is
            # a no-op when nothing is there.
            link_path.unlink(missing_ok=True)
            link_path.symlink_to(abs_path)
            count += 1

        return count


class ProvenanceReader:
    """Reads historical provenance from ExecutionStore to inform retry strategy selection."""

    @staticmethod
    def get_prior_strategies(shot_id: str, store) -> list[str]:
        """Return strategy names previously tried for this shot, ordered oldest-first.

        Checks two sources, in this order:
        1. shot["strategies_tried"] — written by production_loop retry iterations
        2. take["provenance"]["lineage"]["change_reason"], falling back to
           take["provenance"]["change_reason"] — from ProvenanceWriter records
        Deduplicates while preserving first-seen order. Empty or missing
        values are ignored.

        Args:
            shot_id: Shot identifier passed to ``store.get_shot``.
            store: Store exposing ``get_shot(shot_id)``; may be falsy/None.

        Returns:
            List of strategy names; empty when the store is falsy or the
            shot is not found.
        """
        shot = store.get_shot(shot_id) if store else None
        if not shot:
            return []

        seen: set[str] = set()
        strategies: list[str] = []

        def _remember(name) -> None:
            # Record a strategy once, keeping the position of its first use.
            if name and name not in seen:
                seen.add(name)
                strategies.append(name)

        # Source 1: direct strategies_tried list on shot record (authoritative)
        for name in shot.get("strategies_tried") or []:
            _remember(name)

        # Source 2: provenance lineage on individual takes. "or []" mirrors
        # the handling above so a present-but-None "takes" key is tolerated.
        for take in shot.get("takes") or []:
            provenance = take.get("provenance") or {}
            lineage = provenance.get("lineage") or {}
            _remember(lineage.get("change_reason") or provenance.get("change_reason"))

        return strategies
