#!/usr/bin/env python3
"""Honor-rate probe — does gpt-image-2 honor the derived screen-direction on boards?

Measures whether the board model places two-subject shots per the REC-180-derived
screen_direction. The Ep1 two-subject shots span a useful natural spread of expected placements
(foreground side + camera_side vary across SH31–39 — see --manifest), so the
production boards give a built-in compliance signal without a synthetic arm. NOTE:
these are OTS shots, so the discriminant is foreground-side + depth, NOT a clean
binary L/R split — it's a real check, not a perfectly balanced 50/50 control. A
separately generated flipped-instruction arm (deferred) would add discriminating
power if the natural spread proves weak.

This tool owns the VERIFIABLE, FREE half:
  --manifest   resolve the two-subject shots + their derived directions + the exact
               EXPECTED positions production scores against (via the live
               spatial_compliance helpers) -> honor_rate_probe_manifest.json + table.
  --score      score one already-generated board panel image against a shot's
               instruction via the live run_spatial_compliance(); honored = no
               position/depth VIOLATION flag (a solo panel always carries an INFO
               continuity flag, so overall severity != PASS even when honored).

The PAID half (generating the boards) goes through the normal pipeline on Studio:
  python3 recoil/pipeline/cli/generate.py --project tartarus --episode 1 --storyboard <batch>
(the --manifest output prints the batch selectors + the per-shot score commands).
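--manifest and the default dry-run --generate spend nothing; --generate --live
dispatches the paid boards. --score needs GEMINI_API_KEY (~$0.01/panel) and runs
on Studio where the boards + key live; --score-all needs the same key.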

Thresholds (per PROBE_SPEC): >=75% honored -> text-prompt enforcement is enough;
40–74% -> reference-conditioning; <40% -> placement-conditioning.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

_REPO_ROOT = (
    Path(__file__).resolve().parents[3]
)  # .../CLAUDE_PROJECTS — bootstrap `recoil.*`
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from recoil.core import paths as core_paths  # noqa: E402
from recoil.pipeline._lib.spatial_compliance import (  # noqa: E402
    _build_character_anchor,
    _get_expected_positions,
    run_spatial_compliance,
)

# Manifest output file: <parents[2] of this file>/docs/honor_rate_probe_manifest.json.
# The generate-results file (honor_rate_probe_generate.json) is written next to it.
_OUT = Path(__file__).resolve().parents[2] / "docs" / "honor_rate_probe_manifest.json"


def _project_root(project: str) -> Path:
    """Return the directory of *project* under the configured projects root.

    The projects root comes from ``core_paths.projects_root()`` and has a
    leading ``~`` expanded before the project name is appended.
    """
    projects_base = Path(core_paths.projects_root()).expanduser()
    return projects_base / project


def _load_plan(project: str, episode: int) -> dict:
    """Load and parse the visual plan JSON for one episode of *project*.

    The plan is read from
    ``<project root>/_pipeline/state/visual/plans/ep_NNN_plan.json`` where
    ``NNN`` is the zero-padded episode number.

    Raises:
        FileNotFoundError: if the plan file does not exist.
    """
    p = (
        _project_root(project)
        / "_pipeline"
        / "state"
        / "visual"
        / "plans"
        / f"ep_{episode:03d}_plan.json"
    )
    if not p.is_file():
        raise FileNotFoundError(f"plan not found: {p}")
    # JSON is UTF-8 by spec; decode explicitly so parsing does not depend on
    # the host's locale encoding.
    return json.loads(p.read_text(encoding="utf-8"))


def _load_bible(project: str) -> dict:
    """Load and parse the project's global bible JSON.

    The bible is read from
    ``<project root>/_pipeline/state/visual/global_bible.json``.

    Raises:
        FileNotFoundError: if the bible file does not exist.
    """
    p = _project_root(project) / "_pipeline" / "state" / "visual" / "global_bible.json"
    if not p.is_file():
        raise FileNotFoundError(f"global bible not found: {p}")
    # JSON is UTF-8 by spec; decode explicitly so parsing does not depend on
    # the host's locale encoding.
    return json.loads(p.read_text(encoding="utf-8"))


def _shots(plan: dict) -> list[dict]:
    """Return the plan's shot list.

    Looks under ``"shots"`` first, then ``"shot_records"``; the first key with
    a truthy value wins. Returns an empty list when neither has one.
    """
    for key in ("shots", "shot_records"):
        found = plan.get(key)
        if found:
            return found
    return []


def _char_ids(shot: dict) -> list[str]:
    """Return the sorted, de-duplicated character ids referenced by *shot*.

    Reads ``shot["asset_data"]["characters"]``; entries that are not dicts or
    that have no truthy ``char_id`` are ignored. Missing or null containers
    yield an empty list.
    """
    asset_data = shot.get("asset_data") or {}
    unique_ids = set()
    for entry in asset_data.get("characters") or []:
        if not isinstance(entry, dict):
            continue
        char_id = entry.get("char_id")
        if char_id:
            unique_ids.add(char_id)
    return sorted(unique_ids)


def _two_subject_shots(plan: dict) -> list[dict]:
    """Return the plan's shots that reference at least two distinct character ids."""
    selected = []
    for shot in _shots(plan):
        if len(_char_ids(shot)) >= 2:
            selected.append(shot)
    return selected


def build_manifest(project: str, episode: int) -> dict:
    """Assemble the probe manifest for every two-subject shot of an episode.

    For each shot with two or more distinct character ids this records the
    shot's identifiers, its spatial fields (screen_direction, camera_side,
    axis_segment_id), a per-character anchor from ``_build_character_anchor``
    and the expected positions from ``_get_expected_positions``.

    Returns:
        A dict holding the project/episode, the shot count, the shot ids
        grouped by ``screen_direction`` (``"?"`` when unset), the decision
        thresholds, the board model label and the per-shot entries.
    """
    plan = _load_plan(project, episode)
    bible = _load_bible(project)

    shot_entries = []
    for shot in _two_subject_shots(plan):
        spatial = shot.get("spatial_data") or {}
        prompt = shot.get("prompt_data") or {}
        cast = (shot.get("asset_data") or {}).get("characters") or []

        ids = _char_ids(shot)
        anchors = {}
        for index, member in enumerate(cast):
            if not isinstance(member, dict):
                continue
            # Characters without an id still get a stable positional key.
            anchor_key = member.get("char_id") or f"c{index}"
            anchors[anchor_key] = _build_character_anchor(member, bible)
        expected = _get_expected_positions(shot, bible)

        shot_entries.append(
            {
                "shot_id": shot.get("shot_id"),
                "scene_index": shot.get("scene_index"),
                "shot_type": prompt.get("shot_type"),
                "screen_direction": spatial.get("screen_direction"),
                "camera_side": spatial.get("camera_side"),
                "axis_segment_id": spatial.get("axis_segment_id"),
                "characters": ids,
                "character_anchors": anchors,
                "expected_positions": expected,
            }
        )

    # Bucket shot ids by their derived screen_direction ("?" when unset).
    split: dict[str, list[str]] = {}
    for entry in shot_entries:
        bucket_key = entry["screen_direction"] or "?"
        if bucket_key not in split:
            split[bucket_key] = []
        split[bucket_key].append(entry["shot_id"])

    thresholds = {
        ">=0.75": "text-prompt enforcement sufficient",
        "0.40-0.74": "reference-conditioning",
        "<0.40": "placement-conditioning",
    }
    return {
        "project": project,
        "episode": episode,
        "two_subject_shot_count": len(shot_entries),
        "direction_split": split,
        "thresholds": thresholds,
        "board_model": "gpt-image-2 (via image.storyboard role) @ quality:high size:half",
        "shots": shot_entries,
    }


def cmd_manifest(args) -> int:
    """Write the probe manifest JSON and print a human-readable summary.

    Reads ``args.project`` and ``args.episode``, writes the manifest built by
    ``build_manifest`` to ``_OUT``, prints one table row per two-subject shot
    and then the follow-up generate / score commands.

    Returns:
        Always 0.
    """
    m = build_manifest(args.project, args.episode)
    # Create the output directory if needed so a missing docs/ folder does not
    # fail the run after the manifest has already been computed.
    _OUT.parent.mkdir(parents=True, exist_ok=True)
    _OUT.write_text(json.dumps(m, indent=2))
    print(f"=== Honor-rate probe manifest — {args.project} ep{args.episode} ===")
    print(
        f"two-subject shots: {m['two_subject_shot_count']}  |  model: {m['board_model']}"
    )
    print(
        f"expected-placement spread (natural check, not a clean 50/50 control): {m['direction_split']}"
    )
    print(f"{'shot':<14}{'dir':<16}{'side':<6}{'type':<6}expected positions")
    for e in m["shots"]:
        exp = ", ".join(
            f"{k}={v.get('position')}"
            for k, v in (e["expected_positions"] or {}).items()
        )
        # shot_id comes from shot.get("shot_id") and may be None; None does not
        # accept the "<14" format spec (TypeError), so stringify it like the
        # other columns.
        print(
            f"{str(e['shot_id']):<14}{str(e['screen_direction']):<16}{str(e['camera_side']):<6}"
            f"{str(e['shot_type']):<6}{exp}"
        )
    print(f"\nmanifest -> {_OUT}")
    print(
        "\nNEXT (paid, on Studio): generate the boards for the batch(es) covering these shots:"
    )
    print(
        "  python3 recoil/pipeline/cli/generate.py --project {p} --episode {e} --storyboard <batch>".format(
            p=args.project, e=args.episode
        )
    )
    print("then score each two-subject panel:")
    print(
        "  python3 recoil/pipeline/tools/honor_rate_probe.py --score <panel.png> --shot <SHOT_ID> "
        "--project {p} --episode {e}".format(p=args.project, e=args.episode)
    )
    return 0


def cmd_score(args) -> int:
    """Score one already-generated panel image against a plan shot.

    Looks up ``args.shot`` in the episode plan, runs ``run_spatial_compliance``
    on the bytes of ``args.score`` and prints a JSON report to stdout.

    Returns:
        0 when the panel is honored (no VIOLATION flags), 2 when it is not,
        1 when the shot id is not in the plan.
    """
    plan = _load_plan(args.project, args.episode)
    bible = _load_bible(args.project)

    shot = None
    for candidate in _shots(plan):
        if candidate.get("shot_id") == args.shot:
            shot = candidate
            break
    if shot is None:
        print(f"shot {args.shot} not in plan", file=sys.stderr)
        return 1

    panel_bytes = Path(args.score).read_bytes()
    result = run_spatial_compliance(panel_bytes, shot, bible)
    flags = result.get("flags") or []

    # "Honored" means the board placed the characters as instructed. Scoring a
    # lone panel has no previous shot, so the scorer attaches an INFO continuity
    # flag (and a CENTER->side miss is only POSITION_DRIFT/INFO). Real failures
    # are the VIOLATION flags — PROMPT_MISMATCH (wrong horizontal side),
    # DEPTH_INVERSION (OTS FG/BG swapped), MISSING_CHARACTER — so the verdict
    # keys off VIOLATION rather than the overall severity.
    violations = [flag for flag in flags if flag.get("severity") == "VIOLATION"]
    honored = len(violations) == 0

    spatial = shot.get("spatial_data") or {}
    report = {
        "shot": args.shot,
        "screen_direction": spatial.get("screen_direction"),
        "expected_positions": _get_expected_positions(shot, bible),
        "honored": honored,
        "violations": violations,
        "all_flags": flags,
        "extracted": result.get("extracted"),
    }
    print(json.dumps(report, indent=2))
    if honored:
        return 0
    return 2


# --------------------------------------------------------------------------
# A-2: single-shot generation arm (instruction + flipped-instruction mirror)
# --------------------------------------------------------------------------
# Slices the live multi-shot batch that owns each two-subject shot down to a
# single-shot r2v_multi beat, writes it to a throwaway BATCH_9NN scene slot, and
# dispatches it through the normal board path with slots_override=1 (one 9:16
# panel). The mirror arm flips camera_side (and screen_direction, when set) so
# we can tell whether the model reads the text channel or just defaults.
# dry_run (default) builds the real prompt with ZERO image spend; --live
# dispatches on Studio.

# Opposite value for each recognised screen_direction; the mirror arm leaves
# any other value (or a missing one) unflipped.
_DIR_FLIP = {"left-to-right": "right-to-left", "right-to-left": "left-to-right"}
_PROBE_BATCH_BASE = 900  # temp scenes land in BATCH_901+ (real batches are 001-011)


def _ep_token(episode: int) -> str:
    """Return the episode token used in scene file names, e.g. ``ep_001``."""
    return "ep_{:03d}".format(episode)


def _find_source_batch_path(project: str, episode: int, shot_id: str) -> Path:
    """Return the path of the live batch scene whose batch_shots contains shot_id.

    Scans ``<ep>_BATCH_NNN.json`` files in the project's orchestration scenes
    directory in sorted order and returns the first one where any beat's
    ``beat_metadata["batch_shots"]`` lists *shot_id*.

    Raises:
        SystemExit: if no loadable batch scene lists *shot_id*.
    """
    from recoil.pipeline.core.persistence import load_scene

    scenes_dir = core_paths.ProjectPaths.for_project(project).orchestration_scenes_dir
    ep = _ep_token(episode)
    for p in sorted(scenes_dir.glob(f"{ep}_BATCH_[0-9][0-9][0-9].json")):
        try:
            scene = load_scene(p)
        except Exception as exc:
            # Best-effort scan: one unreadable scene must not abort the search,
            # but report it — otherwise a corrupt owning batch only surfaces as
            # the misleading "no live batch contains ..." error below.
            print(f"WARNING: skipping unreadable batch scene {p}: {exc}", file=sys.stderr)
            continue
        for beat in scene.beats:
            md = beat.beat_metadata or {}
            # `or []` also covers batch_shots stored as an explicit null.
            if any(s.get("shot_id") == shot_id for s in md.get("batch_shots") or []):
                return p
    raise SystemExit(f"no live batch contains {shot_id}")


def _slice_probe_scene(src_path: Path, shot_id: str, new_seq_id: str, *, mirror: bool) -> dict:
    """Raw-JSON slice of the src batch down to a single-shot r2v_multi beat.

    Args:
        src_path: batch scene JSON file to slice.
        shot_id: id of the shot to keep.
        new_seq_id: value written to the slice's ``scene_id``.
        mirror: when true, flip the kept shot's ``raw.spatial_data``
            (camera_side always, screen_direction when it is a known value).

    Returns:
        The scene dict with ``beats`` reduced to the one r2v_multi beat that
        lists *shot_id*, and that beat's ``batch_shots`` / ``shot`` reduced to
        the target shot.

    Raises:
        SystemExit: if no r2v_multi beat in the scene lists *shot_id*.
    """
    d = json.loads(src_path.read_text(encoding="utf-8"))
    d["scene_id"] = new_seq_id
    # Pick the r2v_multi beat that actually OWNS the shot. Taking the first
    # r2v_multi beat unconditionally fails (bare StopIteration) on a batch with
    # several such beats when the shot sits in a later one.
    beat = target = None
    for candidate in d.get("beats") or []:
        cand_md = candidate.get("beat_metadata") or {}
        if cand_md.get("modality") != "r2v_multi":
            continue
        hit = next(
            (s for s in cand_md.get("batch_shots") or [] if s.get("shot_id") == shot_id),
            None,
        )
        if hit is not None:
            beat, target = candidate, hit
            break
    if beat is None:
        raise SystemExit(f"{src_path.name}: no r2v_multi beat contains {shot_id}")
    md = beat["beat_metadata"]
    if mirror:
        sd = target.setdefault("raw", {}).setdefault("spatial_data", {})
        # Flip the axis that actually drives placement. For OTS two-subject
        # shots that is camera_side (A<->B → FG side LEFT<->RIGHT in
        # _resolve_ots_assignment); screen_direction drives the wide-shot path.
        # NORMALIZE first — _normalize_camera_side maps missing/"center"/"left"
        # → "A", "right" → "B"; flipping the literal value would no-op on the
        # many shots stored as camera_side="center" (codex finding, REC-180).
        from recoil.pipeline._lib.prompt_engine import _normalize_camera_side
        norm = _normalize_camera_side(sd.get("camera_side", "A"))
        sd["camera_side"] = "B" if norm == "A" else "A"
        cur = sd.get("screen_direction")
        if cur in _DIR_FLIP:
            sd["screen_direction"] = _DIR_FLIP[cur]
    md["batch_shots"] = [target]
    md["shot"] = target
    d["beats"] = [beat]
    return d


def _shot_direction(shot_dict: dict) -> str:
    """Return the shot's ``raw.spatial_data.screen_direction``, or ``"?"``.

    ``"?"`` is returned when the value is unset or when ``raw`` /
    ``spatial_data`` is missing or stored as null.
    """
    # `get(key, {})` only defaults a MISSING key; an explicit null would come
    # back as None and break the chained lookup, so use `or {}` at each level.
    raw = shot_dict.get("raw") or {}
    spatial = raw.get("spatial_data") or {}
    return spatial.get("screen_direction") or "?"


def cmd_generate(args) -> int:
    """Build (and, with --live, dispatch) one single-shot board per shot and arm.

    For every two-subject shot in the episode plan, and for each arm
    ("instruction", plus "mirror" unless ``args.no_mirror``), this slices the
    owning batch down to that one shot, writes the slice to a temporary
    BATCH_9NN scene file, and calls ``build_and_dispatch_board`` with
    ``slots_override=1``. Without ``args.live`` the call is a dry run and no
    step runner is built. The temporary scene files are removed afterwards,
    and one record per arm is written to ``honor_rate_probe_generate.json``
    next to ``_OUT``.

    Returns:
        1 if any live arm failed to dispatch or produced no artifact, else 0.
    """
    from recoil.pipeline._lib.board_builder import build_and_dispatch_board
    from recoil.pipeline.core.persistence import scene_path

    project, episode = args.project, args.episode
    ep = _ep_token(episode)
    dry = not args.live
    step_runner = None
    if not dry:
        # Only a live run needs a step runner; the dry run passes None.
        from recoil.pipeline.cli.generate import _build_step_runner_for_episode

        step_runner = _build_step_runner_for_episode(project, episode)

    shots = [s.get("shot_id") for s in _two_subject_shots(_load_plan(project, episode))]
    arms = ["instruction"] if args.no_mirror else ["instruction", "mirror"]
    # n is pre-incremented per arm, so the first temp slot is BATCH_901.
    n = _PROBE_BATCH_BASE
    results, written, failures = [], [], 0
    print(
        f"=== honor-rate probe generate ({'DRY-RUN, free' if dry else 'LIVE, paid on Studio'}) ===\n"
        f"two-subject shots: {len(shots)} x {len(arms)} arm(s) = {len(shots) * len(arms)} boards"
    )
    try:
        for shot_id in shots:
            # Both arms of a shot are sliced from the same source batch file.
            src = _find_source_batch_path(project, episode, shot_id)
            for arm in arms:
                n += 1
                seq_id = f"BATCH_{n:03d}"
                selector = f"EP{episode:03d}_CONT_{n:03d}"
                sliced = _slice_probe_scene(src, shot_id, seq_id, mirror=(arm == "mirror"))
                direction = _shot_direction(sliced["beats"][0]["beat_metadata"]["shot"])
                # Persist the slice under the temp scene id, then dispatch it
                # via the selector built from the same number.
                p = scene_path(project, ep, seq_id)
                p.write_text(json.dumps(sliced))
                written.append(p)
                res = build_and_dispatch_board(
                    project, episode, selector,
                    step_runner=step_runner, dry_run=dry, slots_override=1,
                )
                # The conditional expression binds loosest: this is
                # (artifact or board_path) if not dry else None.
                board = res.get("artifact") or res.get("board_path") if not dry else None
                # A live dispatch that didn't succeed (or produced no artifact)
                # is a FAILED paid arm — don't silently record it as processed.
                ok = dry or (res.get("success") and board)
                if not ok and not dry:
                    failures += 1
                rec = {
                    "shot": shot_id, "arm": arm, "selector": selector,
                    "instructed_direction": direction,
                    # The MUTATED shot the board was actually generated from
                    # (mirror flips camera_side/screen_direction). Scoring MUST
                    # compute expected positions from THIS, not the original
                    # plan shot — else the mirror arm scores against the
                    # un-flipped expectation (codex finding, REC-180).
                    "scored_shot": sliced["beats"][0]["beat_metadata"]["shot"].get("raw") or {},
                    "prompt_excerpt": (res.get("prompt") or "")[:200] if dry else None,
                    "board": board,
                    "error": None if ok else (res.get("error") or "no artifact"),
                }
                results.append(rec)
                status = "prompt built" if dry else (board if ok else f"FAILED: {rec['error']}")
                print(f"  {shot_id:14s} {arm:11s} dir={direction:13s} {status}")
    finally:
        # Always remove the temporary probe scene files, even when slicing or
        # dispatch raised; a file that cannot be removed is ignored.
        for p in written:
            try:
                p.unlink()
            except OSError:
                pass

    # Reached only when the loop finished without raising.
    out = _OUT.parent / "honor_rate_probe_generate.json"
    out.write_text(json.dumps({"dry_run": dry, "results": results}, indent=2))
    print(f"\nresults -> {out}")
    if dry:
        print("DRY-RUN: verify instruction vs mirror prompts carry opposite directions, "
              "then re-run with --live on Studio (paid).")
    if failures:
        print(f"\nERROR: {failures} live arm(s) failed to dispatch — see 'error' fields above.")
        return 1
    return 0


def cmd_score_all(args) -> int:
    """Score every board in the generate manifest against ITS arm's mutated
    shot (so the mirror arm is judged vs the flipped staging it was generated
    from, not the original plan shot), and report instruction/mirror honor +
    the flip control. Needs GEMINI_API_KEY.

    Reads ``honor_rate_probe_generate.json`` next to ``_OUT``. Board paths in
    it are joined onto the project root.

    Returns:
        0 only when at least one instruction arm was scored and no board was
        missing or unscorable; 1 for a dry-run manifest, for zero scored
        instruction arms, or for any missing/unscorable arm.
    """
    project = args.project
    root = core_paths.ProjectPaths.for_project(project).project_root
    bible = _load_bible(project)
    gen = _OUT.parent / "honor_rate_probe_generate.json"
    data = json.loads(gen.read_text())
    if data.get("dry_run"):
        print("ERROR: manifest is a DRY-RUN (no boards generated) — nothing to score. "
              "Run --generate --live first.")
        return 1
    # by_shot[shot][arm] = (honored, detail): honored is True/False, or None
    # when the scorer skipped the arm; detail is the violation types or the
    # skip reason.
    by_shot: dict[str, dict] = {}
    missing = 0
    for r in data.get("results", []):
        board = r.get("board")
        if not board:
            print(f"  MISSING {r['shot']}/{r['arm']}: no board (dispatch failed?)")
            missing += 1
            continue
        p = root / board
        if not p.is_file():
            print(f"  MISSING {r['shot']}/{r['arm']}: {p}")
            missing += 1
            continue
        scored_shot = r.get("scored_shot") or {}
        res = run_spatial_compliance(p.read_bytes(), scored_shot, bible)
        # run_spatial_compliance early-returns skipped=True (PASS, no flags) for
        # ANY shot it can't actually judge — missing spatial_data.camera_side OR
        # no/insufficient characters (env/solo). Mapping "no VIOLATION" to OK for
        # those would mint a FALSE honor, so key off the scorer's own skipped flag
        # and mark the arm UNSCORABLE (None) instead of honored.
        if res.get("skipped"):
            reason = res.get("skip_reason") or "skipped"
            by_shot.setdefault(r["shot"], {})[r["arm"]] = (None, reason)
            continue
        viols = [f for f in (res.get("flags") or []) if f.get("severity") == "VIOLATION"]
        # Comma-joined, de-duplicated violation names ("none" when honored).
        vt = ",".join(sorted({(f.get("flag") or f.get("type", "?")) for f in viols})) or "none"
        by_shot.setdefault(r["shot"], {})[r["arm"]] = (not viols, vt)
    # ih/itotal: honored/scored instruction arms; mh/mtotal: same for mirror
    # arms; flip: shots where BOTH arms were honored.
    ih = itotal = mh = mtotal = flip = 0
    unscorable = 0
    has_mirror = any("mirror" in a for a in by_shot.values())
    print("=== honor-rate probe — scored against per-arm mutated staging ===")
    for shot, arms in sorted(by_shot.items()):
        i = arms.get("instruction")
        if not i:
            # No instruction result for this shot: nothing is tallied for it,
            # including any mirror arm it has.
            continue
        if i[0] is None:  # instruction arm unscorable (scorer skipped it)
            print(f"  {shot:11s} instr:SKIP ({i[1]})")
            unscorable += 1
            continue
        itotal += 1
        ih += i[0]
        line = f"  {shot:11s} instr:{'OK ' if i[0] else 'X  '}({i[1]})"
        m = arms.get("mirror")
        if m is not None and m[0] is not None:
            mtotal += 1
            mh += m[0]
            flipped = i[0] and m[0]
            flip += flipped
            line += f" mirror:{'OK ' if m[0] else 'X  '}({m[1]}) {'<- text-responsive' if flipped else ''}"
        elif m is not None:  # mirror arm present but unscorable
            unscorable += 1
            line += f" mirror:SKIP ({m[1]})"
        print(line)
    if itotal == 0:
        print("\nERROR: 0 instruction arms scored — the probe produced no usable signal "
              f"({missing} missing, {unscorable} unscorable). Not reporting a rate. FAIL.")
        return 1
    # Integer percentage, rounded down.
    pct = (100 * ih // itotal)
    print(f"\ninstruction honor: {ih}/{itotal} = {pct}%")
    if has_mirror and mtotal:
        print(f"mirror honor:      {mh}/{mtotal} = {100 * mh // mtotal}%")
        print(f"text-responsive (output flipped WITH the instruction): {flip}/{mtotal} "
              f"— high = model reads the text channel; low = needs reference-conditioning")
    elif has_mirror:
        print("mirror honor:      no mirror arms were scorable (all skipped).")
    if missing or unscorable:
        print(f"\nWARNING: {missing} missing board(s) + {unscorable} unscorable arm(s) "
              "were excluded — the rate is over the scored subset only, not the full probe. "
              "FAIL the gate (incomplete signal).")
        return 1
    return 0


def main() -> int:
    """Parse the command line and run the selected probe mode.

    Mode precedence is --score-all, then --score (which requires --shot), then
    --generate; with none of them the manifest is emitted. The chosen
    command's return value is returned as the exit status.
    """
    parser = argparse.ArgumentParser(
        description="Honor-rate probe (REC-180 board screen-direction)."
    )
    add = parser.add_argument
    add("--project", default="tartarus")
    add("--episode", type=int, default=1)
    add("--manifest", action="store_true", help="emit the probe manifest + table (free)")
    add(
        "--generate",
        action="store_true",
        help="generate single-shot boards (instruction + mirror); dry-run unless --live",
    )
    add(
        "--live",
        action="store_true",
        help="with --generate: actually dispatch the paid boards (Studio). Default is dry-run.",
    )
    add("--no-mirror", action="store_true", help="with --generate: instruction arm only")
    add(
        "--score",
        metavar="PANEL_PNG",
        help="score one generated board panel (needs GEMINI_API_KEY)",
    )
    add("--shot", help="shot id to score (with --score)")
    add(
        "--score-all",
        action="store_true",
        help="score every board in the generate manifest against its arm's mutated "
             "staging; report instruction/mirror honor + flip control (needs GEMINI_API_KEY)",
    )
    opts = parser.parse_args()

    if opts.score_all:
        handler = cmd_score_all
    elif opts.score:
        if not opts.shot:
            # parser.error prints usage and exits, so no handler is needed here.
            parser.error("--score requires --shot")
        handler = cmd_score
    elif opts.generate:
        handler = cmd_generate
    else:
        handler = cmd_manifest
    return handler(opts)


if __name__ == "__main__":
    raise SystemExit(main())
