from __future__ import annotations

import json
import sys
from pathlib import Path

from PIL import Image

from recoil.pipeline._lib import story_gate as sg
from recoil.pipeline.cli import generate


def _write_board(projects_root: Path, project: str, filename: str) -> Path:
    """Create a placeholder storyboard PNG plus its JSON sidecar and return the PNG path.

    The PNG is a 200x200 white image written to
    ``<projects_root>/<project>/prep/ep_001/storyboards/<filename>``.
    Next to it, ``<filename>.json`` holds:

    - a four-beat authoring prompt;
    - four panel stubs (indices 1-4);
    - a fixed dummy ``source_sha256`` of 64 ``"f"`` characters.
    """
    board = projects_root.joinpath(project, "prep", "ep_001", "storyboards", filename)
    board.parent.mkdir(parents=True, exist_ok=True)
    Image.new("RGB", (200, 200), "white").save(board)

    beats = (
        "1. Jade looks at CRYO-07.",
        "2. Chassis catches cable.",
        "3. Jade reacts.",
        "4. Door opens.",
    )
    prompt = "SCENE CONTEXT\nKnown cause.\nAUTHORING\n" + "\n".join(beats)
    sidecar = dict(
        prompt=prompt,
        panels=[dict(index=n) for n in (1, 2, 3, 4)],
        source_sha256="f" * 64,
    )
    sidecar_path = board.parent / f"{board.name}.json"
    sidecar_path.write_text(json.dumps(sidecar), encoding="utf-8")
    return board


def _labels(
    tmp_path: Path,
    project: str,
    *,
    boards: list[str],
    cases: list[dict],
) -> Path:
    """Write a schema-v1 ``labels.json`` for episode 1 of *project* and return its path.

    The file is always written to ``tmp_path / "labels.json"``. Any existing
    file with that name is overwritten.
    """
    payload = dict(
        schema_version=1,
        project=project,
        episode=1,
        cases=cases,
        boards=boards,
    )
    target = tmp_path / "labels.json"
    target.write_text(json.dumps(payload), encoding="utf-8")
    return target


def _case_cont004(expected: str = "fail") -> dict:
    """Label for CONT_004 panel 1, where the gaze target sits behind the actor.

    By default this is a named ``fail`` case routed to ``board_problem`` and
    attributed to ``object_of_gaze_in_frame_and_front``. Pass another
    *expected* value to relabel the same locus.
    """
    return dict(
        artifact="prep/ep_001/storyboards/EP001_CONT_004_v03.png",
        locus=dict(panel=1),
        expected=expected,
        expected_route_any=["board_problem"],
        expected_check="object_of_gaze_in_frame_and_front",
        reason="target behind actor",
    )


def _case_presumed() -> dict:
    """Label for CONT_004 panels 2-4, marked ``presumed_pass`` (weak negatives).

    The label has no expected route or check.
    """
    return dict(
        artifact="prep/ep_001/storyboards/EP001_CONT_004_v03.png",
        locus=dict(panels=[2, 3, 4]),
        expected="presumed_pass",
        reason="weak negatives",
    )


def _case_cont007() -> dict:
    """Label for CONT_007 panel 2, a named ``fail`` case for missing motivation.

    Either ``script_problem`` or ``mixed`` routing is accepted, and the
    expected failing check is ``causal_setup_present``.
    """
    return dict(
        artifact="prep/ep_001/storyboards/EP001_CONT_007_v02.png",
        locus=dict(panel=2),
        expected="fail",
        expected_route_any=["script_problem", "mixed"],
        expected_check="causal_setup_present",
        reason="missing motivation",
    )


def _verdict(
    *,
    board_id: str,
    route: str = "ok",
    fail_panel: int | None = None,
    fail_check: str = "object_of_gaze_in_frame_and_front",
) -> dict:
    """Build a mock image-judge verdict for a four-panel board.

    Every forced check passes on panels 1-4 and on the transitions 1->2,
    2->3 and 3->4. The one exception: when *fail_panel* is given, that
    panel's *fail_check* is replaced by a failing ``_check``. *route* becomes
    both ``routing.class`` and ``routing.evidence``.
    """
    check_names = (
        "depicts_beat",
        "spatially_possible",
        "eyeline_consistent",
        "object_of_gaze_in_frame_and_front",
        "causal_setup_present",
    )

    def build_panel(index: int) -> dict:
        forced = {name: _check(True) for name in check_names}
        if index == fail_panel:
            # Overwrites in place, so key order is kept for known check names.
            forced[fail_check] = _check(False, reason=f"{fail_check} failed")
        return {
            "index": index,
            "description": f"panel {index}",
            "forced_checks": forced,
            "fix_hint": "",
            "fix_hint_injectable": False,
        }

    def build_transition(start: int) -> dict:
        return {
            "from": start,
            "to": start + 1,
            "forced_checks": {"causal_setup_present": _check(True)},
            "fix_hint": "",
            "fix_hint_injectable": False,
        }

    panels = [build_panel(index) for index in range(1, 5)]
    transitions = [build_transition(start) for start in (1, 2, 3)]
    return {
        "schema_version": sg.SCHEMA_VERSION,
        "judge_model": "mock",
        "prompt_version": sg.PROMPT_VERSION,
        "board_id": board_id,
        "source_sha256": "f" * 64,
        "text_stageability": None,
        "panels": panels,
        "transitions": transitions,
        "routing": {"class": route, "confidence": 0.9, "evidence": route},
    }


def _check(passed: bool, *, reason: str = "ok") -> dict:
    """Return one forced-check result at 0.9 confidence.

    Passing checks are tagged ``SOFT`` and failing ones ``HARD``.
    """
    severity = "SOFT" if passed else "HARD"
    return dict(passed=passed, severity=severity, confidence=0.9, reason=reason)


def _mock_text_stageability() -> dict:
    """Return a text-stageability payload that flags beat 2 as having no cause.

    The payload has a single HARD, non-injectable ``causal_setup_present``
    finding. A fresh dict is built on every call, so callers may mutate it.
    """
    finding = dict(
        beat_index=2,
        check="causal_setup_present",
        passed=False,
        severity="HARD",
        confidence=0.91,
        problem_kind="absent_cause",
        reason="missing cause",
        injectable=False,
        suggested_script_question="What caused it?",
    )
    return dict(stageable=False, findings=[finding])


def _board_id_from_prompt(prompt: str) -> str:
    """Return the line that immediately follows the first ``BOARD ID`` header.

    Raises ``ValueError`` when the header is absent. Raises ``IndexError`` when
    nothing follows it.
    """
    header = "BOARD ID\n"
    body_start = prompt.index(header) + len(header)
    remaining_lines = prompt[body_start:].splitlines()
    return remaining_lines[0]


def test_calibration_named_case_vector_and_or_aggregate(tmp_path, monkeypatch):
    """Per-sample catches form the case's catch vector; any catch counts toward recall.

    The image judge flags panel 1 only on the 1st, 3rd and 5th call for each
    (board id, image count) key. With one such call per sample, five samples
    must yield ``[1, 0, 1, 0, 1]``. The case then counts as caught (OR over
    samples) and named recall is 1.0.
    """
    projects_root = tmp_path / "projects"
    project = "fixture"
    _write_board(projects_root, project, "EP001_CONT_004_v03.png")
    labels = _labels(
        tmp_path,
        project,
        boards=["EP001_CONT_004_v03.png"],
        cases=[_case_cont004(), _case_presumed()],
    )
    # Number of image-judge calls seen so far, keyed by (board id, image count).
    calls: dict[tuple[str, int], int] = {}

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is None:
            return _mock_text_stageability()
        # images is known to be non-None here, so no fallback is needed.
        key = (_board_id_from_prompt(prompt), len(images))
        calls[key] = calls.get(key, 0) + 1
        # Calls 1, 3 and 5 catch the defect; calls 2 and 4 miss it.
        hit = calls[key] in (1, 3, 5)
        return _verdict(
            board_id=key[0],
            route="board_problem" if hit else "ok",
            fail_panel=1 if hit else None,
        )

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels,
        projects_root,
        samples=5,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    # Use next() so a missing entry fails with a clear message, not an IndexError.
    named = next(
        (
            entry
            for entry in report["per_case"]
            if entry["expected"] == "fail" and entry["tier"] == "mock-opus"
        ),
        None,
    )
    assert named is not None, "no per_case entry for the named (expected=fail) case"
    assert named["catch_vector"] == [1, 0, 1, 0, 1]
    assert named["caught_or"] is True
    assert report["aggregate"]["named_recall_or"] == 1.0


def test_calibration_never_fails_named_case_recall_below_one(tmp_path, monkeypatch):
    """Calibration completes, rather than raising, when the named case is missed.

    The text pass reports no findings and every image verdict is all-pass,
    so the run must still return a report, with named recall below 1.0.
    """
    board_name = "EP001_CONT_004_v03.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_cont004()],
    )

    def blind_judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            return _verdict(board_id=_board_id_from_prompt(prompt))
        return {"stageable": True, "findings": []}

    monkeypatch.setattr(sg, "_judge_call", blind_judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=2,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    recall = report["aggregate"]["named_recall_or"]
    assert recall < 1.0


def test_calibration_groups_results_per_tier(tmp_path, monkeypatch):
    """Every requested tier shows up in ``aggregate.by_tier`` and in ``per_case``."""
    board_name = "EP001_CONT_004_v03.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_cont004()],
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            # Image pass: panel 1 fails its gaze check and routes to the board.
            return _verdict(
                board_id=_board_id_from_prompt(prompt),
                route="board_problem",
                fail_panel=1,
            )
        return {"stageable": True, "findings": []}

    monkeypatch.setattr(sg, "_judge_call", judge)

    expected_tiers = {"mock-opus", "mock-sonnet"}
    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=1,
        tiers=("mock-opus", "mock-sonnet"),
        crops_modes=(True,),
    )

    assert set(report["aggregate"]["by_tier"]) == expected_tiers
    seen_tiers = {row["tier"] for row in report["per_case"]}
    assert seen_tiers == expected_tiers


def test_boards_basename_resolves_to_storyboards_dir(tmp_path, monkeypatch):
    """A bare PNG filename in ``boards`` resolves to the fixture's storyboards folder.

    Nothing should be skipped, and the single sample must register a catch.
    """
    data_root = tmp_path / "projects"
    board_path = _write_board(data_root, "fixture", "EP001_CONT_004_v03.png")
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_path.name],
        cases=[_case_cont004()],
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            return _verdict(
                board_id=_board_id_from_prompt(prompt),
                route="board_problem",
                fail_panel=1,
            )
        return {"stageable": True, "findings": []}

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=1,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    assert report["aggregate"]["skipped"] == []
    first_case = report["per_case"][0]
    assert first_case["catch_vector"] == [1]


def test_cont007_case_gets_text_stageability_vector(tmp_path, monkeypatch):
    """The CONT_007 case reports a per-sample text-stageability catch vector.

    The text pass flags beat 2 on every sample, so three samples give
    ``[1, 1, 1]``.
    """
    board_name = "EP001_CONT_007_v02.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_cont007()],
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            return _verdict(
                board_id=_board_id_from_prompt(prompt),
                route="script_problem",
                fail_panel=2,
                fail_check="causal_setup_present",
            )
        return _mock_text_stageability()

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=3,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    text_result = report["per_case"][0]["text_stageability"]
    assert text_result["catch_vector"] == [1, 1, 1]


def test_presumed_pass_hard_flag_is_listed_not_counted_as_fp(tmp_path, monkeypatch):
    """A HARD flag on a presumed_pass locus is reported, not held against recall.

    Only a presumed_pass label exists, and panel 2 is flagged. Named recall
    stays 1.0, and the flag appears in ``hard_flags_on_presumed_pass``.
    """
    board_name = "EP001_CONT_004_v03.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_presumed()],
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            return _verdict(
                board_id=_board_id_from_prompt(prompt),
                route="board_problem",
                fail_panel=2,
            )
        return {"stageable": True, "findings": []}

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=1,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    aggregate = report["aggregate"]
    assert aggregate["named_recall_or"] == 1.0
    first_flag = aggregate["hard_flags_on_presumed_pass"][0]
    assert first_flag["panel"] == 2


def test_missing_board_file_records_skip_without_crashing(tmp_path, monkeypatch):
    """A labelled board whose PNG does not exist is recorded as a skip instead of raising.

    This is the only calibration test that never writes a board. Without a
    stub, a regression in the skip path would send requests to the real
    ``sg._judge_call``. The judge is therefore replaced with a local fake,
    and the test asserts the image judge is never asked to look at the
    missing board.
    """
    projects_root = tmp_path / "projects"
    project = "fixture"
    labels = _labels(
        tmp_path,
        project,
        boards=["EP001_CONT_004_v03.png"],
        cases=[_case_cont004()],
    )
    # Prompts of any image-judge calls; a skipped board should produce none.
    image_calls: list[str] = []

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is None:
            return {"stageable": True, "findings": []}
        image_calls.append(prompt)
        return _verdict(board_id="unexpected-image-call")

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels,
        projects_root,
        samples=1,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    assert report["per_case"][0]["skipped"] is True
    assert "missing PNG" in report["aggregate"]["skipped"][0]["reason"]
    assert image_calls == [], "image judge was called for a board with no PNG"


def test_cli_story_gate_eval_exit_4_when_thresholds_fail(tmp_path, monkeypatch, capsys):
    """``--story-gate-eval`` exits with code 4 and prints the report when named recall is 0.

    The fake calibration checks the arguments the CLI forwards. Without
    ``--no-crops-compare``, both crop modes ``(True, False)`` are requested.
    """
    data_root = tmp_path / "projects"
    data_root.mkdir()
    (data_root / ".recoil-data-root").touch()
    labels = tmp_path / "labels.json"
    labels.write_text("{}", encoding="utf-8")
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(data_root))

    def fake_run_calibration(labels_path, root, *, samples, tiers, crops_modes):
        assert labels_path == labels
        assert root == data_root
        assert samples == 2
        assert tiers == ("mock-opus",)
        assert crops_modes == (True, False)
        missed_case = dict(
            case_index=0,
            tier="mock-opus",
            crops_mode=True,
            catch_vector=[0, 0],
            caught_or=False,
        )
        aggregate = dict(
            named_recall_or=0.0,
            schema_valid_rate=1.0,
            judge_unavailable_count=0,
            hard_flags_on_presumed_pass=[],
        )
        config = dict(samples=2, tiers=["mock-opus"], crops_modes=[True, False])
        return {"per_case": [missed_case], "aggregate": aggregate, "config": config}

    monkeypatch.setattr(generate, "run_calibration", fake_run_calibration)
    cli_args = [
        "--project", "fixture",
        "--episode", "1",
        "--story-gate-eval",
        "--labels", str(labels),
        "--samples", "2",
        "--tier", "mock-opus",
    ]
    monkeypatch.setattr(sys, "argv", ["generate.py", *cli_args])

    exit_code = generate.main()
    printed = capsys.readouterr().out
    assert exit_code == 4
    assert '"named_recall_or": 0.0' in printed
    assert "StoryGate calibration" in printed


def test_cli_no_crops_compare_runs_crops_true_only(tmp_path, monkeypatch, capsys):
    """``--no-crops-compare`` makes the CLI request only the ``crops=True`` mode.

    With perfect named recall in the fake report, the CLI must exit with 0.
    """
    projects_root = tmp_path / "projects"
    projects_root.mkdir()
    (projects_root / ".recoil-data-root").touch()
    labels = tmp_path / "labels.json"
    labels.write_text("{}", encoding="utf-8")
    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_root))
    # Records the crops_modes the CLI forwarded to run_calibration.
    seen = {}

    def fake_run_calibration(labels_path, root, *, samples, tiers, crops_modes):
        seen["crops_modes"] = crops_modes
        return {
            "per_case": [],
            "aggregate": {
                "named_recall_or": 1.0,
                "schema_valid_rate": 1.0,
                "judge_unavailable_count": 0,
                "hard_flags_on_presumed_pass": [],
            },
            "config": {},
        }

    monkeypatch.setattr(generate, "run_calibration", fake_run_calibration)
    monkeypatch.setattr(
        sys,
        "argv",
        [
            "generate.py",
            "--project",
            "fixture",
            "--episode",
            "1",
            "--story-gate-eval",
            "--labels",
            str(labels),
            "--no-crops-compare",
        ],
    )

    # Keep the side-effecting call out of the assert statement, matching the
    # sibling CLI test. Assert statements are removed under ``python -O``.
    code = generate.main()
    capsys.readouterr()
    assert code == 0
    # .get: if the fake was never called, fail on the comparison, not a KeyError.
    assert seen.get("crops_modes") == (True,)


def test_text_catch_counts_toward_named_recall_when_image_judge_misses(tmp_path, monkeypatch):
    """Codex merge-gate finding (PR #76).

    For script-stageability cases, the pre-generation text pass IS the
    capability under test. A text catch must therefore count toward
    ``named_recall_or`` even when every image verdict misses.
    """
    board_name = "EP001_CONT_007_v02.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_cont007()],
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            # Image judge MISSES: a clean verdict with no failed checks.
            return _verdict(board_id=_board_id_from_prompt(prompt), route="ok")
        # Text pass CATCHES the unmotivated action.
        return _mock_text_stageability()

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=2,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    case_row = report["per_case"][0]
    tier_stats = report["aggregate"]["by_tier"]["mock-opus"]
    assert case_row["text_stageability"]["caught_or"] is True
    assert case_row["catch_vector"] == [0, 0]
    assert tier_stats["named_recall_or"] == 1.0


def test_text_catch_on_wrong_beat_does_not_count(tmp_path, monkeypatch):
    """PR #76 gate r2: a text finding on the wrong beat is not a catch.

    The finding is a HARD causal one, but it is reported on beat 1 while the
    labelled locus is panel 2, so it must not count for the case.
    """
    board_name = "EP001_CONT_007_v02.png"
    data_root = tmp_path / "projects"
    _write_board(data_root, "fixture", board_name)
    labels_path = _labels(
        tmp_path,
        "fixture",
        boards=[board_name],
        cases=[_case_cont007()],  # locus is panel 2
    )

    def judge(prompt, images, *, model=None, max_attempts=3):
        if images is not None:
            return _verdict(board_id=_board_id_from_prompt(prompt), route="ok")
        misplaced = _mock_text_stageability()
        misplaced["findings"][0]["beat_index"] = 1  # wrong locus
        return misplaced

    monkeypatch.setattr(sg, "_judge_call", judge)

    report = sg.run_calibration(
        labels_path,
        data_root,
        samples=1,
        tiers=("mock-opus",),
        crops_modes=(True,),
    )

    case_row = report["per_case"][0]
    tier_stats = report["aggregate"]["by_tier"]["mock-opus"]
    assert case_row["text_stageability"]["caught_or"] is False
    assert tier_stats["named_recall_or"] == 0.0
