#!/usr/bin/env python3
"""Build coverage passes from a shot plan (CLI wrapper).

Reads ep_NNN_plan.json, groups shots into scene-level coverage passes via
orchestrator.coverage_planner, and writes to ep_NNN_passes.json.

With --coverage flag, also runs dramatic intensity analysis and generates
coverage angles (reactions, cutaways, wide safeties) for peak shots.

Usage:
    python3 tools/build_coverage_passes.py --project tartarus --episode 1 --dry-run
    python3 tools/build_coverage_passes.py --project tartarus --episode 1 --lock
    python3 tools/build_coverage_passes.py --project tartarus --episode 1 --coverage --dry-run
    python3 tools/build_coverage_passes.py --project tartarus --episode 1 --coverage --lock
"""

import argparse
import json
import logging
import re
import sys
from pathlib import Path

# Planner-stable prefix matching the format produced by coverage_planner.build_passes:
#   EP{ep:03d}_PASS_{counter:03d}_{side}_{focus}{_format?}
# Used by reconciliation to tell "preserved replacement of an auto pass" from
# "manually-added pass that coexists with an auto pass on the same shots".
_PASS_SLOT_RE = re.compile(r"^(EP\d{3}_PASS_\d{3}_[AB]_[A-Z]+)")


def _pass_slot(pass_id: str) -> str:
    """Return the planner-stable slot prefix from a pass_id.

    Strips an optional trailing `_format` suffix (e.g. `_B`). For pass_ids
    that don't match the planner's pattern (e.g. `EP001_PASS_002_TEST_480P_FAL`),
    returns the full pass_id — guaranteeing manual entries get unique slots
    that never collide with auto-generated ones.
    """
    m = _PASS_SLOT_RE.match(pass_id or "")
    return m.group(1) if m else (pass_id or "")


def _should_preserve(pass_dict: dict) -> bool:
    """Return True if a pass should survive reconciliation on regen.

    Preserved when ANY of:
      1. status in ("edited", "rendered") — explicit Console-side edits.
      2. generation_config.prompt_override is non-empty — hand-written prompt.
      3. generation_config.ref_override is non-empty — hand-picked refs.

    Rules 2-3 catch passes hand-edited in JSON (where status would still be
    'draft' because only the Console writes 'edited'). Without them, custom
    overrides on a draft pass are silently overwritten on regen — that bug
    cost the PASS_017 winch-escape overrides on 2026-04-19.
    """
    if pass_dict.get("status") in ("edited", "rendered"):
        return True
    gc = pass_dict.get("generation_config", {}) or {}
    if gc.get("prompt_override"):
        return True
    if gc.get("ref_override"):
        return True
    return False


# ── Path setup ──
# Put the pipeline directory and its parent on sys.path before the project
# imports that follow this block are executed.
PIPELINE_ROOT = Path(__file__).resolve().parent.parent  # two levels above this file
_RECOIL_ROOT = PIPELINE_ROOT.parent  # recoil/pipeline/../ = recoil/
# Each missing root is inserted at index 0, so when both are added the later
# one (_RECOIL_ROOT) ends up ahead of PIPELINE_ROOT in the search order.
for _p in (str(PIPELINE_ROOT), str(_RECOIL_ROOT)):
    if _p not in sys.path:
        sys.path.insert(0, _p)

from recoil.execution.step_types import ProjectPaths  # noqa: E402
from orchestrator.coverage_planner import build_passes  # noqa: E402
from recoil.pipeline._lib import derivation_manifest  # noqa: E402
from recoil.pipeline._lib.derivation_sha import (  # noqa: E402
    content_sha,
    plan_structural_sha,
)

logger = logging.getLogger(__name__)


def load_plan(project: str, episode: int) -> list[dict]:
    """Load the shot list from the episode's ``ep_NNN_plan.json``.

    Args:
        project: Project name used to resolve the plans directory.
        episode: Episode number; zero-padded to three digits in the filename.

    Returns:
        The value stored under the plan's ``"shots"`` key, or an empty list
        when that key is missing or null.

    Exits the process with status 1 (after printing an ``ERROR:`` line) when
    the plan file is missing, cannot be read, is not valid JSON, or does not
    hold a JSON object at the top level.
    """
    plan_path = (
        ProjectPaths.for_episode(project, episode).plans_dir
        / f"ep_{episode:03d}_plan.json"
    )
    if not plan_path.exists():
        print(f"ERROR: Plan not found at {plan_path}")
        sys.exit(1)
    try:
        plan = json.loads(plan_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {plan_path}: {e}")
        sys.exit(1)
    except (OSError, UnicodeDecodeError) as e:
        print(f"ERROR: Cannot read {plan_path}: {e}")
        sys.exit(1)
    # A top-level list/string/number would otherwise fail on `.get` with an
    # AttributeError traceback instead of a clear message.
    if not isinstance(plan, dict):
        print(
            f"ERROR: Plan at {plan_path} must be a JSON object, "
            f"got {type(plan).__name__}"
        )
        sys.exit(1)
    # `"shots": null` must not leak None to callers that iterate the result.
    return plan.get("shots") or []


def load_overrides(project: str, episode: int) -> dict:
    """Load per-shot overrides from the episode's ``ep_NNN_overrides.json``.

    Loading is best-effort: a missing file yields ``{}`` silently, while an
    unreadable, malformed, or non-object file yields ``{}`` after a logged
    warning, so a broken overrides file no longer disappears without trace.

    Args:
        project: Project name used to resolve the plans directory.
        episode: Episode number; zero-padded to three digits in the filename.

    Returns:
        The parsed overrides mapping, or ``{}`` when none could be loaded.
    """
    ov_path = (
        ProjectPaths.for_episode(project, episode).plans_dir
        / f"ep_{episode:03d}_overrides.json"
    )
    if not ov_path.exists():
        return {}

    log = logging.getLogger(__name__)
    try:
        overrides = json.loads(ov_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError) as exc:
        log.warning("Ignoring unreadable overrides file %s: %s", ov_path, exc)
        return {}
    # Callers look shots up with `.get`, so anything but a mapping is unusable.
    if not isinstance(overrides, dict):
        log.warning(
            "Ignoring overrides file %s: expected a JSON object, got %s",
            ov_path,
            type(overrides).__name__,
        )
        return {}
    return overrides


def merge_overrides(shot: dict, overrides: dict) -> dict:
    """Return *shot* with its entry from *overrides* applied, if it has one.

    The entry is looked up by the shot's ``shot_id`` ("" when absent). With
    no (or an empty) entry the original shot object is returned as-is;
    otherwise the result of ``apply_overrides(shot, entry)`` is returned.
    """
    from recoil.pipeline._lib.previz_context import apply_overrides

    shot_key = shot.get("shot_id", "")
    entry = overrides.get(shot_key, {})
    return apply_overrides(shot, entry) if entry else shot


def _print_summary(passes, shots):
    """Print a human-readable pass summary."""
    total_dur = sum(p.duration_s for p in passes)
    char_passes = [p for p in passes if p.pass_type == "character"]
    env_passes = [p for p in passes if p.pass_type == "env"]
    multi_seg = [p for p in passes if len(p.segments) > 1]

    print(f"\n{'=' * 60}")
    print(f"  {len(passes)} passes from {len(shots)} shots")
    print(f"  {len(char_passes)} character / {len(env_passes)} env")
    print(f"  {len(multi_seg)} multi-segment / {len(passes) - len(multi_seg)} single")
    print(f"  Total duration: {total_dur}s")
    print(f"{'=' * 60}\n")

    for p in passes:
        segs = ", ".join(f"{s.source_shot_id}({s.shot_type})" for s in p.segments)
        cfg = p.generation_config.get("cfg_scale", "?")
        sf = "yes" if p.generation_config.get("start_frame_path") else "no"
        chars = p.character_count
        print(f"  {p.pass_id}: {p.label}")
        print(
            f"    [{p.duration_s}s, {len(p.segments)} segs, {chars} chars, cfg={cfg}, sf={sf}]"
        )
        print(f"    {segs}")


def _print_coverage_report(
    tier_map, score_map, moments, shots, actual_angle_count=None
):
    """Print coverage density analysis report."""
    tier_names = {0: "Valley", 1: "Rising", 2: "Peak", 3: "Climax"}
    tier_counts = {0: 0, 1: 0, 2: 0, 3: 0}
    for t in tier_map.values():
        tier_counts[t] = tier_counts.get(t, 0) + 1

    total_angles = (
        actual_angle_count
        if actual_angle_count is not None
        else sum(len(m.coverage_types) for m in moments)
    )

    print(f"\n{'=' * 60}")
    print("  COVERAGE DENSITY ANALYSIS")
    print(f"{'=' * 60}")
    print(f"  {len(shots)} shots analyzed")
    for tier_num in sorted(tier_counts.keys()):
        name = tier_names.get(tier_num, f"Tier {tier_num}")
        count = tier_counts[tier_num]
        print(f"    Tier {tier_num} ({name}): {count} shots")
    print(f"\n  {len(moments)} coverage moments identified")
    print(f"  {total_angles} coverage angles to generate")
    print(f"{'=' * 60}\n")

    for m in moments:
        tier_label = tier_names.get(m.tier, f"Tier {m.tier}")
        anchor_skeleton = m.anchor_shot.get("prompt_data", {}).get(
            "prompt_skeleton", {}
        )
        emotion = anchor_skeleton.get("emotion_line", "")
        shot_range = (
            f"{m.shot_ids[0]}"
            if len(m.shot_ids) == 1
            else f"{m.shot_ids[0]}-{m.shot_ids[-1]}"
        )

        print(f"  {m.moment_id} ({shot_range}) [{tier_label}] {m.scene_label}")
        print(f"    Anchor: {m.anchor_shot_id}")
        if emotion:
            print(f'    Emotion: "{emotion}"')
        if m.coverage_types:
            print(f"    Coverage: {', '.join(m.coverage_types)}")
        else:
            print("    Coverage: none (hero only)")
        print()


def _generate_coverage_shots(moments, project, all_shots=None):
    """Generate synthetic coverage shot dicts from moments.

    For every moment that lists coverage types, one shot is derived per type
    via ``derive_coverage_shot`` from the moment's anchor shot. Derived shots
    that come back falsy are skipped; the rest are tagged with
    ``_moment_id``, ``_coverage_type`` and ``_anchor_shot_id``.

    Args:
        moments: Objects exposing ``coverage_types``, ``anchor_shot``,
            ``shot_ids``, ``moment_id`` and ``anchor_shot_id``.
        project: Project name; written onto each anchor shot as ``_project``.
        all_shots: Optional full shot list used to resolve ``shot_ids``. Ids
            that cannot be resolved fall back to the anchor shot.

    Returns:
        The list of derived coverage shot dicts, in moment/type order.

    Note:
        The anchor shot dict of each processed moment is mutated in place
        (``_project`` is set on it).
    """
    from recoil.pipeline._lib.prompt_engine import derive_coverage_shot

    # Build shot lookup if full list provided
    shot_lookup = {}
    if all_shots:
        shot_lookup = {s.get("shot_id", ""): s for s in all_shots}

    coverage_shots = []
    for moment in moments:
        if not moment.coverage_types:
            continue

        anchor = moment.anchor_shot
        # Tag the project for editorial priors lookup
        anchor["_project"] = project

        # Collect character IDs from ALL shots in the moment, not just anchor.
        # A dict is used as an ordered set so the ids reach the prompt engine
        # in first-appearance order; a plain set would hand them over in an
        # order that varies with string-hash randomization between runs.
        seen_chars: dict[str, None] = {}
        for sid in moment.shot_ids:
            shot = shot_lookup.get(sid, anchor)
            # `or {}` / `or ""` tolerate explicit nulls in the plan JSON.
            for c in (shot.get("asset_data") or {}).get("characters") or []:
                raw_id = (c.get("char_id") or "") if isinstance(c, dict) else str(c)
                cid = raw_id.upper()
                if cid:
                    seen_chars.setdefault(cid, None)

        for cov_type in moment.coverage_types:
            derived = derive_coverage_shot(
                anchor,
                cov_type,
                # Fresh list per call, as before, in case the callee mutates it.
                all_characters=list(seen_chars),
            )
            if derived:
                derived["_moment_id"] = moment.moment_id
                derived["_coverage_type"] = cov_type
                derived["_anchor_shot_id"] = moment.anchor_shot_id
                coverage_shots.append(derived)

    return coverage_shots


def _segment_shot_ids(pass_dict: dict) -> frozenset:
    """Return the set of ``source_shot_id`` values over a pass dict's segments."""
    return frozenset(
        seg.get("source_shot_id") for seg in pass_dict.get("segments", [])
    )


def _read_passes_file(path: Path) -> list:
    """Parse a passes JSON file; raise OSError/ValueError when it is unusable."""
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError(
            f"expected a JSON array of passes, got {type(data).__name__}"
        )
    return data


def _preservation_reason(pass_dict: dict) -> str:
    """Name the `_should_preserve` rule that keeps an already-preserved pass."""
    status = pass_dict.get("status")
    if status in ("edited", "rendered"):
        return "status=" + status
    gc = pass_dict.get("generation_config", {}) or {}
    return "prompt_override" if gc.get("prompt_override") else "ref_override"


def _mark_orphaned(pass_dict: dict) -> None:
    """Flag a preserved pass as orphaned, adding its warning at most once."""
    pass_dict["_orphaned"] = True
    warnings = pass_dict.get("warnings") or []
    pass_dict["warnings"] = warnings
    # The pass is re-read from disk on every lock, so without this check each
    # lock would stack another identical warning onto it.
    already_flagged = any(
        isinstance(w, dict) and w.get("check") == "orphaned_pass" for w in warnings
    )
    if not already_flagged:
        warnings.append(
            {
                "severity": "WARN",
                "check": "orphaned_pass",
                "message": "Source shots no longer match any generated pass",
            }
        )


def _reconcile_passes(new_passes: list[dict], keepers: list[dict]) -> dict:
    """Merge freshly generated passes with passes preserved from the last lock.

    For each new pass, preserved passes covering the same shot set are split:
      * same planner slot      -> the preserved pass replaces the new one;
      * slot of no new pass on
        those shots            -> kept alongside (manual add / A-B fixture).
    Preserved passes whose shot set matches no new pass are marked orphaned
    and appended at the end.

    Slots are compared against ALL new passes on the same shot set, so when
    several new passes share shots a preserved pass is matched to its own
    slot instead of being emitted as "coexisting" next to the first one while
    its slot is also regenerated.

    Returns:
        Dict with keys ``final`` (merged pass dicts) and ``preserved``,
        ``coexisted``, ``regenerated``, ``orphaned`` (lists of pass_ids).
    """
    # Multiple preserved passes can share the same shot ids (e.g. A/B test
    # fixtures), hence a list per shot set.
    keepers_by_shots: dict[frozenset, list[dict]] = {}
    for kept in keepers:
        keepers_by_shots.setdefault(_segment_shot_ids(kept), []).append(kept)

    new_slots_by_shots: dict[frozenset, set] = {}
    for new_pass in new_passes:
        new_slots_by_shots.setdefault(_segment_shot_ids(new_pass), set()).add(
            _pass_slot(new_pass.get("pass_id", ""))
        )

    final, preserved, coexisted, regenerated, orphaned = [], [], [], [], []
    claimed: set = set()  # (shot set, slot) pairs already replaced
    coexist_emitted: set = set()  # shot sets whose manual adds were emitted

    for new_pass in new_passes:
        shot_ids = _segment_shot_ids(new_pass)
        new_slot = _pass_slot(new_pass.get("pass_id", ""))
        candidates = keepers_by_shots.get(shot_ids, [])

        slot_key = (shot_ids, new_slot)
        replacements = (
            []
            if slot_key in claimed
            else [
                c for c in candidates if _pass_slot(c.get("pass_id", "")) == new_slot
            ]
        )
        if replacements:
            claimed.add(slot_key)
            for kept in replacements:
                final.append(kept)
                preserved.append(kept.get("pass_id"))
        else:
            final.append(new_pass)
            regenerated.append(new_pass.get("pass_id"))

        if candidates and shot_ids not in coexist_emitted:
            coexist_emitted.add(shot_ids)
            sibling_slots = new_slots_by_shots[shot_ids]
            for kept in candidates:
                if _pass_slot(kept.get("pass_id", "")) not in sibling_slots:
                    final.append(kept)
                    coexisted.append(kept.get("pass_id"))

    # Orphaned passes: preserved passes whose source shots no longer match
    # any newly-generated pass.
    for shot_ids, kept_list in keepers_by_shots.items():
        if shot_ids in new_slots_by_shots:
            continue
        for kept in kept_list:
            _mark_orphaned(kept)
            final.append(kept)
            orphaned.append(kept.get("pass_id"))

    return {
        "final": final,
        "preserved": preserved,
        "coexisted": coexisted,
        "regenerated": regenerated,
        "orphaned": orphaned,
    }


def _print_preservation_preview(out_path: Path) -> None:
    """Dry-run helper: list the existing passes a later --lock would preserve."""
    if not out_path.exists():
        return
    try:
        existing = _read_passes_file(out_path)
    except (OSError, ValueError) as exc:
        print(f"\n  WARNING: existing passes at {out_path} are unreadable: {exc}")
        print("  --lock will refuse to overwrite them unless --force is given")
        return
    keepers = [p for p in existing if isinstance(p, dict) and _should_preserve(p)]
    if not keepers:
        return
    print(f"\n  {len(keepers)} passes would be PRESERVED on --lock")
    for p in keepers:
        reason = _preservation_reason(p)
        print(f"    {p.get('pass_id')}: {p.get('label', '')} [{reason}]")


def main():
    """CLI entry point: build coverage passes and preview or lock them.

    Flow:
      1. Load the plan shots, apply per-shot overrides, build passes.
      2. With --coverage, run density analysis and derive coverage shots.
      3. --dry-run: print passes, summary, coverage totals and which existing
         passes a lock would preserve; nothing is written. If --lock is given
         too, --dry-run wins.
      4. --lock: validate (when the validator is importable), reconcile with
         preserved passes from the existing file unless --force, write the
         passes file, stamp the derivation manifest, and write coverage shots.

    Exits with status 1 when neither mode flag is given, when validation
    reports blockers, or when the existing passes file cannot be read and
    --force was not given (overwriting it would discard preserved edits).
    A manifest stamping failure is logged and re-raised.
    """
    parser = argparse.ArgumentParser(description="Build coverage passes from shot plan")
    parser.add_argument("--project", required=True, help="Project name (e.g. tartarus)")
    parser.add_argument("--episode", type=int, required=True, help="Episode number")
    parser.add_argument(
        "--dry-run", action="store_true", help="Preview passes without saving"
    )
    parser.add_argument("--lock", action="store_true", help="Save passes (Script Lock)")
    parser.add_argument(
        "--force",
        action="store_true",
        help="Regenerate all passes, overwriting edited ones",
    )
    parser.add_argument(
        "--coverage", action="store_true", help="Run coverage density analysis"
    )
    args = parser.parse_args()

    if not args.dry_run and not args.lock:
        print("ERROR: Must specify --dry-run or --lock")
        sys.exit(1)

    shots = load_plan(args.project, args.episode)
    overrides = load_overrides(args.project, args.episode)
    shots = [merge_overrides(s, overrides) for s in shots]

    # Standard pass building (always runs)
    passes = build_passes(shots, args.project, args.episode)

    # Coverage density analysis (optional)
    moments = []
    coverage_shots = []
    if args.coverage:
        from orchestrator.coverage_density import analyze_coverage_density

        # Load tier overrides if they exist (best-effort, but not silent)
        tier_overrides_path = (
            ProjectPaths.for_episode(args.project, args.episode).plans_dir
            / f"ep_{args.episode:03d}_tier_overrides.json"
        )
        tier_overrides = {}
        if tier_overrides_path.exists():
            try:
                tier_overrides = json.loads(
                    tier_overrides_path.read_text(encoding="utf-8")
                )
            except (json.JSONDecodeError, OSError) as exc:
                logger.warning(
                    "Ignoring unreadable tier overrides %s: %s",
                    tier_overrides_path,
                    exc,
                )

        tier_map, score_map, moments = analyze_coverage_density(shots, tier_overrides)

        # Generated up front so the dry-run report shows the real angle count.
        coverage_shots = _generate_coverage_shots(
            moments, args.project, all_shots=shots
        )

        if args.dry_run:
            _print_coverage_report(
                tier_map,
                score_map,
                moments,
                shots,
                actual_angle_count=len(coverage_shots),
            )
            if coverage_shots:
                print(f"  COVERAGE SHOT PREVIEW ({len(coverage_shots)} angles):")
                print(f"  {'─' * 56}")
                for cs in coverage_shots:
                    prompt_data = cs.get("prompt_data") or {}
                    skeleton = prompt_data.get("prompt_skeleton", {})
                    subj = skeleton.get("subject_line", "")[:50]
                    emotion = skeleton.get("emotion_line", "")[:40]
                    # A preview must not crash on a shot missing a field.
                    print(
                        f"    {cs.get('shot_id', '?')}: "
                        f"{prompt_data.get('shot_type', '?')}"
                    )
                    print(f"      Subject: {subj}")
                    if emotion:
                        print(f"      Emotion: {emotion}")
                print()

    out_dir = ProjectPaths.for_episode(args.project, args.episode).coverage_passes_dir
    out_path = out_dir / f"ep_{args.episode:03d}_passes.json"

    if args.dry_run:
        print(json.dumps([p.to_dict() for p in passes], indent=2))
        _print_summary(passes, shots)
        if args.coverage:
            total_angles = len(coverage_shots)
            print(f"\n  Coverage: {total_angles} angles")
            print(f"  Hero shots: {len(shots)}")
            print(f"  TOTAL: {len(shots) + total_angles} shots")
        # Preservation preview — same predicate as the lock path. It has to
        # run here: the lock branch below is never reached in dry-run mode.
        if not args.force:
            _print_preservation_preview(out_path)
        return

    # ── Lock path (args.lock is set and args.dry_run is not) ──

    # Validate before locking
    try:
        from orchestrator.coverage_validator import validate_all_passes, Severity

        results = validate_all_passes(passes)
        blockers = [r for r in results if r.severity == Severity.BLOCK]
        warnings = [r for r in results if r.severity == Severity.WARN]

        if blockers:
            print(f"\nBLOCKED — {len(blockers)} validation errors prevent lock:\n")
            for r in blockers:
                print(f"  BLOCK: [{r.pass_id}] {r.message}")
            if warnings:
                print(f"\n  ({len(warnings)} warnings also found)")
            sys.exit(1)

        if warnings:
            print(f"\n{len(warnings)} warnings (non-blocking):")
            for r in warnings:
                print(f"  WARN: [{r.pass_id}] {r.message}")
            print()
    except ImportError:
        pass  # Validator not yet built — allow lock without validation

    out_dir.mkdir(parents=True, exist_ok=True)

    final_passes_data = [p.to_dict() for p in passes]

    # Preservation predicate is module-level (`_should_preserve`).
    if not args.force and out_path.exists():
        try:
            existing = _read_passes_file(out_path)
        except (OSError, ValueError) as exc:
            # Carrying on would overwrite a file whose hand edits we could
            # not inspect; make the operator decide explicitly.
            print(f"ERROR: Cannot read existing passes at {out_path}: {exc}")
            print(
                "  Refusing to overwrite it (preserved edits could be lost). "
                "Fix the file or re-run with --force."
            )
            sys.exit(1)

        keepers = [p for p in existing if isinstance(p, dict) and _should_preserve(p)]
        if keepers:
            outcome = _reconcile_passes(final_passes_data, keepers)
            final_passes_data = outcome["final"]
            preserved = outcome["preserved"]
            coexisted = outcome["coexisted"]
            regenerated = outcome["regenerated"]
            orphaned = outcome["orphaned"]

            print("\n  Reconciliation:")
            print(f"    Preserved (slot replacement): {len(preserved)}")
            print(f"    Coexisted (manual add, same shots): {len(coexisted)}")
            print(f"    Regenerated (draft): {len(regenerated)}")
            print(f"    Orphaned: {len(orphaned)}")
            for pid in preserved:
                print(f"      PRESERVED: {pid}")
            for pid in coexisted:
                print(f"      COEXISTED: {pid}")
            for pid in orphaned:
                print(f"      ORPHANED: {pid}")

    out_path.write_text(json.dumps(final_passes_data, indent=2), encoding="utf-8")
    print(f"LOCKED: {len(final_passes_data)} passes written to {out_path}")

    # Stamp the coverage_passes manifest entry — the producer side of the
    # D3 staleness guard (C1, REC-164 Phase 3). The Phase-4 guard compares
    # the current plan's structural_sha against this recorded value, so we
    # hash the SAME full plan dict the guard reads. Best-effort-but-loud:
    # a manifest write failure FAILs the lock (MINOR-7).
    try:
        plan_path = (
            ProjectPaths.for_episode(args.project, args.episode).plans_dir
            / f"ep_{args.episode:03d}_plan.json"
        )
        plan_dict = json.loads(plan_path.read_text(encoding="utf-8"))
        psha = plan_structural_sha(plan_dict)
        covers_shots = sorted({
            seg.get("source_shot_id")
            for p in final_passes_data
            for seg in p.get("segments", [])
            if seg.get("source_shot_id")
        })
        derivation_manifest.stamp_stage(
            args.project,
            args.episode,
            "coverage_passes",
            kind="derived",
            structural_sha=None,
            content_sha=content_sha({"passes": final_passes_data}),
            source={"plan_structural_sha": psha},
            builder="build_coverage_passes --lock",
            extra={"covers_shots": covers_shots},
        )
    except Exception:
        logger.error(
            "Failed to stamp coverage_passes manifest for ep_%03d",
            args.episode,
            exc_info=True,
        )
        raise

    # Write coverage shots if --coverage
    if args.coverage and coverage_shots:
        cov_path = out_dir / f"ep_{args.episode:03d}_coverage.json"
        cov_data = []
        for cs in coverage_shots:
            # Strip internal metadata before saving, then re-add the three
            # provenance keys that are meant to be persisted.
            entry = {k: v for k, v in cs.items() if not k.startswith("_")}
            entry["_moment_id"] = cs.get("_moment_id", "")
            entry["_coverage_type"] = cs.get("_coverage_type", "")
            entry["_anchor_shot_id"] = cs.get("_anchor_shot_id", "")
            cov_data.append(entry)
        cov_path.write_text(json.dumps(cov_data, indent=2), encoding="utf-8")
        print(
            f"LOCKED: {len(coverage_shots)} coverage angles written to {cov_path}"
        )

    _print_summary(passes, shots)


if __name__ == "__main__":
    main()
