"""Derivation manifest — per-episode + project-level provenance/freshness records.

A small functional module that reads/writes the derivation manifest JSON and
computes staleness by comparing RECORDED shas against their upstreams. It does
NOT rebuild artifacts or call any pipeline stage — producers (Phase 3) stamp it,
consumers (Phase 4 + downstream D3 rebuild) read it. The only live (non-manifest)
inputs it consults are the current episode script sha (camera_tested freshness)
and the live plan's per-shot `source_text_hash` values (scenes and board
freshness).

Source-key contract (every producer stamps these; `freshness`/`recompute_health`
key on exactly these — MAJOR-4):
    camera_tested.source   = {"script_sha": <md5 of ep script>}
    plan.source            = {"camera_tested_content_sha": <ep camera_tested.content_sha>,
                              "bible_content_sha": <project _bible.content_sha>}
    coverage_passes.source = {"plan_structural_sha": <ep plan.structural_sha>}
    scenes.source          = {"plan_structural_sha": <ep plan.structural_sha>}
The scenes entry additionally records `shot_script_spans`
({scene_id: {shot_id: source_text_hash}}), which is checked against the live plan.

The bible is series-shared (project-level), so it gets its OWN project-level
record (`_bible.json`), NOT a per-episode stage entry (MAJOR-3).
"""
from __future__ import annotations

import json
import os
import tempfile
from datetime import datetime, timezone
from pathlib import Path

from recoil.core.paths import ProjectPaths
from recoil.pipeline._lib import derivation_sha, episode_script, plan_loader


# ── Paths ──────────────────────────────────────────────────────────

def manifest_path(project: str, episode: int) -> Path:
    """Return the per-episode manifest location: `<derivation_dir>/ep_NNN.json`."""
    derivation_dir = ProjectPaths.for_project(project).derivation_dir
    return derivation_dir / f"ep_{episode:03d}.json"


def bible_manifest_path(project: str) -> Path:
    """Return the project-level bible record location: `<derivation_dir>/_bible.json`."""
    derivation_dir = ProjectPaths.for_project(project).derivation_dir
    return derivation_dir / "_bible.json"


# ── Atomic write (same pattern as Phase 1 _save_json) ──────────────

def _atomic_write_json(path: Path, data: dict) -> None:
    """ATOMIC JSON write — tmp + os.replace, mkdir parents first.

    The payload is serialized BEFORE any file is created, so a serialization
    error leaves the filesystem untouched. The temp file is created in the
    target's own directory so `os.replace` is a same-filesystem rename. On ANY
    failure, including KeyboardInterrupt/SystemExit, the temp file is removed
    and the exception re-raised, so no orphaned `tmp*.json` is left beside the
    manifests.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # default=str: values json can't encode natively (e.g. Path) are stringified.
    content = json.dumps(data, indent=2, default=str)
    fd, tmp = tempfile.mkstemp(dir=str(path.parent), suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
            f.flush()
            # Force bytes to disk before the rename makes them visible.
            os.fsync(f.fileno())
        os.replace(tmp, str(path))
    except BaseException:
        # BaseException (not Exception) so an interrupt mid-write still cleans up.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise


# ── Per-episode manifest ───────────────────────────────────────────

def _skeleton(episode: int) -> dict:
    """Return a brand-new, empty per-episode manifest (schema_version 1).

    Fresh nested dicts are built on every call, so callers may mutate the
    result without affecting later skeletons.
    """
    execution = dict(boards={}, locks={})
    return dict(
        episode=f"ep_{episode:03d}",
        schema_version=1,
        inputs={},
        stages={},
        execution=execution,
        health={},
    )


def load(project: str, episode: int) -> dict:
    """Return the parsed per-episode manifest, or a fresh skeleton when no file exists.

    Raises json.JSONDecodeError if the file exists but is not valid JSON.
    """
    path = manifest_path(project, episode)
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    return _skeleton(episode)


def save(project: str, episode: int, manifest: dict) -> None:
    """Persist `manifest` as this episode's manifest file (atomic tmp + os.replace)."""
    target = manifest_path(project, episode)
    _atomic_write_json(target, manifest)


def stamp_stage(
    project: str,
    episode: int,
    stage: str,
    *,
    structural_sha=None,
    content_sha,
    source: dict,
    kind: str,
    builder: str,
    model=None,
    via=None,
    built_at: str | None = None,
    extra: dict | None = None,
) -> None:
    """Record (or replace) the entry for `stage` under `stages`, then save.

    The entry always carries kind/content_sha/structural_sha/source/built_at/
    builder/model/via. Keys in `extra` (stage-specific data) are merged in
    last, so an `extra` key with the same name overrides the core field.

    `built_at` may be passed in by the caller; if omitted (or empty) it is
    computed as an ISO8601 UTC timestamp. `source` should follow the module's
    source-key contract, because `freshness`/`recompute_health` compare
    exactly those keys.
    """
    manifest = load(project, episode)
    # Tolerate a manifest whose `stages` key is absent or null (e.g. written
    # by an older/other producer) instead of raising KeyError/TypeError.
    stages = manifest.get("stages")
    if stages is None:
        stages = manifest["stages"] = {}
    stages[stage] = {
        "kind": kind,
        "content_sha": content_sha,
        "structural_sha": structural_sha,
        "source": source,
        "built_at": built_at or datetime.now(timezone.utc).isoformat(),
        "builder": builder,
        "model": model,
        "via": via,
        **(extra or {}),
    }
    save(project, episode, manifest)


# ── Project-level bible record (series-shared, MAJOR-3) ─────────────

def load_bible(project: str) -> dict:
    """Return the parsed project-level `_bible.json` record, or {} when it does not exist."""
    record_path = bible_manifest_path(project)
    if record_path.exists():
        return json.loads(record_path.read_text(encoding="utf-8"))
    return {}


def stamp_bible(project: str, *, content_sha, builder, built_at=None, model=None) -> None:
    """ATOMIC write of the project-level bible record (`_bible.json`).

    Overwrites the whole record. `built_at` is optional: when omitted (or
    empty) it defaults to the current ISO8601 UTC timestamp, matching
    `stamp_stage`. Existing callers that pass it explicitly are unaffected.
    """
    _atomic_write_json(
        bible_manifest_path(project),
        {
            "content_sha": content_sha,
            "builder": builder,
            "built_at": built_at or datetime.now(timezone.utc).isoformat(),
            "model": model,
        },
    )


# ── Health + freshness (recorded-sha comparisons + live script/plan checks) ──

# Health flags owned by recompute_health: on every run it drops these from the
# stored flags and re-derives them; any other flag key is carried over as-is.
_RECOMPUTE_OWNED_FLAG_KEYS = {
    # camera_tested
    "camera_tested.stale_vs_script",
    # plan
    "plan.stale_vs_camera_tested",
    "plan.stale_vs_bible",
    # coverage_passes
    "coverage_passes.stale_vs_plan",
    # scenes
    "scenes.missing_shots",
    "scenes.overlaps",
    "scenes.stale_vs_plan",
    "scenes.stale_vs_script_spans",
}


def _camera_tested_script_fresh(project: str, episode: int, stage: dict) -> bool:
    """True iff the camera_tested entry's stamped `source.script_sha` equals the
    current episode script sha.

    A missing/empty stamp, or a script file that cannot be found
    (FileNotFoundError), counts as stale (False).
    """
    source = (stage or {}).get("source") or {}
    recorded_sha = source.get("script_sha")
    if not recorded_sha:
        return False
    try:
        live_sha = episode_script.episode_script_sha(project, episode)
    except FileNotFoundError:
        return False
    return recorded_sha == live_sha


def _load_live_plan(project: str, episode: int):
    """Load the episode's live plan file (`<plans_dir>/ep_NNN_plan.json`) via plan_loader."""
    plans_dir = ProjectPaths.for_project(project).plans_dir
    plan_file = plans_dir / f"ep_{episode:03d}_plan.json"
    return plan_loader.load_plan(plan_file)


def _scenes_script_spans_fresh(project: str, episode: int, stage: dict) -> bool:
    """True iff every shot span recorded on the scenes entry still matches the live plan.

    `stage["shot_script_spans"]` is read as ``{scene_id: {shot_id: hash}}``, and
    each recorded hash is compared with the live plan's per-shot
    `source_text_hash`. Fails closed (returns False) when:
      * the spans map is missing, empty, or not a dict;
      * the live plan cannot be loaded or walked;
      * a scene's value is not a dict;
      * a recorded hash is None;
      * a shot is absent from the live plan or has no live hash;
      * any hash differs.
    NOTE(review): a scene whose inner dict is empty contributes no checks.
    """
    stored = (stage or {}).get("shot_script_spans")
    if not isinstance(stored, dict) or not stored:
        return False
    try:
        # Shared helper (also used by board_freshness). It stays inside the try
        # so that a plan which loads but cannot be walked also fails closed.
        live_map = _live_shot_script_spans(project, episode)
    except Exception:
        return False
    for shot_spans in stored.values():
        if not isinstance(shot_spans, dict):
            return False
        for shot_id, stored_hash in shot_spans.items():
            # .get() -> None covers both "shot missing" and "shot unhashed".
            live_hash = live_map.get(shot_id)
            if stored_hash is None or live_hash is None or live_hash != stored_hash:
                return False
    return True


def _live_shot_script_spans(project: str, episode: int) -> dict[str, str | None]:
    """Map each live-plan shot_id to its `source_text_hash` (None when the shot has none).

    Whatever `_load_live_plan` raises propagates; callers decide how to fail.
    """
    spans: dict[str, str | None] = {}
    for shot in _load_live_plan(project, episode).shots:
        raw = shot.raw or {}
        spans[shot.shot_id] = raw.get("source_text_hash")
    return spans


def recompute_health(project: str, episode: int) -> dict:
    """Re-derive the episode's `health` block, persist it, and return it.

    Result shape: ``{"complete_chain": bool, "flags": {...}}``.

    * `complete_chain` is True when the camera_tested, plan and coverage_passes
      stage entries and the project bible record are all present (non-empty).
      It reflects presence only, not freshness.
    * `flags`: every key in `_RECOMPUTE_OWNED_FLAG_KEYS` is discarded and
      re-derived from the current entries. Any other stored flag (e.g. an abort
      reason written by `stamp_flag`) is carried over unchanged.

    Staleness checks: camera_tested vs the live script sha; plan vs the
    camera_tested and bible content shas; coverage_passes and scenes vs
    plan.structural_sha; scenes' recorded shot spans vs the live plan.
    """
    manifest = load(project, episode)
    recorded = manifest.get("stages", {})
    bible = load_bible(project)

    ct_entry = recorded.get("camera_tested") or {}
    plan_entry = recorded.get("plan") or {}
    cov_entry = recorded.get("coverage_passes") or {}
    scenes_entry = recorded.get("scenes") or {}

    previous = (manifest.get("health") or {}).get("flags") or {}
    flags: dict = {}
    for name, value in previous.items():
        if name not in _RECOMPUTE_OWNED_FLAG_KEYS:
            flags[name] = value

    if scenes_entry:
        # Surface the gap diagnostics the scenes producer recorded (truthy only).
        for field, flag_name in (
            ("missing_shots", "scenes.missing_shots"),
            ("overlaps", "scenes.overlaps"),
        ):
            diagnostic = scenes_entry.get(field)
            if diagnostic:
                flags[flag_name] = diagnostic

    if ct_entry and not _camera_tested_script_fresh(project, episode, ct_entry):
        flags["camera_tested.stale_vs_script"] = True

    if plan_entry:
        plan_src = plan_entry.get("source") or {}
        if plan_src.get("camera_tested_content_sha") != ct_entry.get("content_sha"):
            flags["plan.stale_vs_camera_tested"] = True
        if plan_src.get("bible_content_sha") != bible.get("content_sha"):
            flags["plan.stale_vs_bible"] = True

    current_structural = plan_entry.get("structural_sha")
    if cov_entry:
        cov_src = cov_entry.get("source") or {}
        if cov_src.get("plan_structural_sha") != current_structural:
            flags["coverage_passes.stale_vs_plan"] = True
    if scenes_entry:
        scenes_src = scenes_entry.get("source") or {}
        if scenes_src.get("plan_structural_sha") != current_structural:
            flags["scenes.stale_vs_plan"] = True
        if not _scenes_script_spans_fresh(project, episode, scenes_entry):
            flags["scenes.stale_vs_script_spans"] = True

    health = {
        "complete_chain": all((ct_entry, bible, plan_entry, cov_entry)),
        "flags": flags,
    }
    manifest["health"] = health
    save(project, episode, manifest)
    return health


def stamp_flag(project: str, episode: int, flag: str, value) -> None:
    """Set one health flag WITHOUT recomputing the chain, then save.

    Intended to persist an abort reason (e.g. 'location.unresolved') when a
    stage fails before its artifact/stage entry is written. `health` and
    `health.flags` are created if missing. Atomic via save().
    """
    manifest = load(project, episode)
    flags = manifest.setdefault("health", {}).setdefault("flags", {})
    flags[flag] = value
    save(project, episode, manifest)


def stamp_board(
    project: str,
    episode: int,
    shotset_hash: str,
    record: dict,
    *,
    manifest: dict | None = None,
) -> None:
    """Store `record` as the board-approval entry for `shotset_hash`, then save.

    Writes `execution.boards[shotset_hash]` in the per-episode manifest, the
    SSOT for "is this shot-set's board approved" (D2/L2). Entries are keyed by
    shot-set identity, so they survive regroup/relabel. When `manifest` is
    supplied, that dict is mutated in place and saved instead of re-reading the
    file. Atomic via save().
    """
    # Single-writer assumption: load->mutate->save is atomic at file-replace
    # only, NOT against a concurrent writer of the same manifest (last-writer
    # wins). Board approval is single-writer/per-episode-serial, so no lock.
    target = manifest if manifest is not None else load(project, episode)
    boards = target.setdefault("execution", {}).setdefault("boards", {})
    boards[shotset_hash] = record
    save(project, episode, target)


def get_board(project: str, episode: int, shotset_hash: str) -> dict | None:
    """Return the board record for `shotset_hash` from `execution.boards`, or None if absent.

    A missing or null `execution` / `boards` value is treated as "no boards"
    rather than raising, matching how `board_freshness` reads the same map.
    """
    execution = load(project, episode).get("execution") or {}
    return (execution.get("boards") or {}).get(shotset_hash)


def board_freshness(project: str, episode: int) -> list[tuple[str, bool, str | None]]:
    """Report freshness for each record in `execution.boards`.

    Returns one `(shotset_hash, is_fresh, broken_at)` tuple per board record,
    in stored order; `broken_at` is None when the board is fresh. Boards are not
    derivation stages, so the upstream chain is checked first with
    `freshness(..., "scenes")`. If that chain is stale, every board is reported
    stale with the upstream stage name. Otherwise each record's
    `content_freshness_sha` is recomputed from the live plan's span hashes for
    its `covered_shot_ids`. Anything unprovable fails closed as
    `(hash, False, "board")`: a missing sha or covered ids, an unreadable live
    plan, or a covered shot that is missing or unhashed in the live plan.
    """
    boards = (load(project, episode).get("execution") or {}).get("boards") or {}
    if not boards:
        return []

    upstream_ok, upstream_stage = freshness(project, episode, "scenes")
    if not upstream_ok:
        reason = upstream_stage or "scenes"
        return [(board_key, False, reason) for board_key in boards]

    try:
        live_spans = _live_shot_script_spans(project, episode)
    except Exception:
        live_spans = None

    def _judge(record) -> tuple[bool, str | None]:
        # Verdict for one board record: (is_fresh, broken_at).
        entry = record or {}
        expected_sha = entry.get("content_freshness_sha")
        covered_ids = entry.get("covered_shot_ids")
        if not expected_sha or not covered_ids or live_spans is None:
            return False, "board"
        spans = {sid: live_spans.get(sid) for sid in covered_ids}
        if any(span_hash is None for span_hash in spans.values()):
            return False, "board"
        if derivation_sha.board_content_freshness_sha(spans) != expected_sha:
            return False, "board"
        return True, None

    return [(board_key, *_judge(record)) for board_key, record in boards.items()]


def clear_flag(project: str, episode: int, flag: str) -> None:
    """Remove a previously-stamped health flag; no-op (and no write) if it is absent.

    Used by the rederive success path to clear a stale 'location.unresolved'
    once the plan re-derives cleanly. Atomic via save().
    """
    manifest = load(project, episode)
    current = (manifest.get("health") or {}).get("flags") or {}
    if flag not in current:
        return
    del current[flag]
    manifest.setdefault("health", {})["flags"] = current
    save(project, episode, manifest)


def _chain_for(stage: str) -> list[str]:
    """Return the dependency-chain prefix (ending in `stage`) that `freshness` walks.

    camera_tested is the chain root: its script_sha upstream is external and
    is not recomputable here. plan depends on camera_tested. coverage_passes
    and scenes are plan-siblings that both depend on plan. Any other stage
    name yields just `[stage]`. A new list is returned on every call.
    """
    upstream = {
        "plan": ("camera_tested",),
        "coverage_passes": ("camera_tested", "plan"),
        "scenes": ("camera_tested", "plan"),
    }
    return [*upstream.get(stage, ()), stage]


def freshness(project: str, episode: int, stage: str) -> tuple[bool, str | None]:
    """Check whether `stage` and every upstream link on its chain is fresh.

    Walks `_chain_for(stage)` (camera_tested → plan → coverage_passes/scenes)
    in order. Returns `(False, <first_stale_stage>)` at the first link whose
    recorded source no longer matches its upstream, else `(True, None)`. A
    missing entry for the REQUESTED stage itself returns `(False, stage)`.

    Per-link comparisons:
      * camera_tested: stamped script_sha vs the live episode script sha.
      * plan: stamped camera_tested/bible content shas vs the manifest's
        camera_tested entry and the project `_bible.json` record.
      * coverage_passes: stamped plan_structural_sha vs plan.structural_sha.
      * scenes: as coverage_passes, plus recorded shot spans vs the live plan.
      * any other stage name: treated as fresh (nothing to compare).
    """
    recorded = load(project, episode).get("stages", {})
    ct_entry = recorded.get("camera_tested") or {}
    plan_entry = recorded.get("plan") or {}

    def _is_fresh(link: str) -> bool:
        entry = recorded.get(link) or {}
        stamped = entry.get("source") or {}
        if link == "camera_tested":
            return _camera_tested_script_fresh(project, episode, entry)
        if link == "plan":
            if stamped.get("camera_tested_content_sha") != ct_entry.get("content_sha"):
                return False
            # Bible record is read lazily, only when a plan link is checked.
            return stamped.get("bible_content_sha") == load_bible(project).get("content_sha")
        if link in ("coverage_passes", "scenes"):
            if stamped.get("plan_structural_sha") != plan_entry.get("structural_sha"):
                return False
            if link == "coverage_passes":
                return True
            return _scenes_script_spans_fresh(project, episode, entry)
        return True

    for link in _chain_for(stage):
        # Only the requested stage must exist; missing upstreams fail their own link check.
        if (link == stage and link not in recorded) or not _is_fresh(link):
            return (False, link)
    return (True, None)


__all__ = [
    # paths
    "manifest_path",
    "bible_manifest_path",
    # per-episode manifest
    "load",
    "save",
    "stamp_stage",
    # project-level bible record
    "load_bible",
    "stamp_bible",
    # health + flags
    "recompute_health",
    "stamp_flag",
    "clear_flag",
    # execution SSOT (boards)
    "stamp_board",
    "get_board",
    "board_freshness",
    # stage-chain freshness
    "freshness",
]
