"""Tests for run_episode -- orchestration, Style Anchor, coverage, abort/resume.

Tests that drive run_episode mock run_shot to avoid deep dependency chains;
the summary, coverage-grouping and run-state tests exercise helpers directly.
"""

import json
from pathlib import Path
from unittest.mock import MagicMock, patch

from recoil.pipeline._lib.coverage_context import OpResult, EpisodeResult, StopOnReview


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_paths(tmp_path):
    """Create a minimal ProjectPaths-like object."""
    paths = MagicMock()
    paths.project = "test-proj"
    paths.project_root = tmp_path
    paths.plans_dir = tmp_path / "_pipeline" / "state" / "visual" / "plans"
    (tmp_path / "_pipeline" / "state" / "visual").mkdir(parents=True, exist_ok=True)
    (tmp_path / "_pipeline" / "state" / "visual" / "runs").mkdir(parents=True, exist_ok=True)
    return paths


def _make_ok_result(shot_id, cost=0.10):
    """Build a successful, single-attempt OpResult for ``shot_id``.

    The op id is a fixed dummy (``op_`` followed by twelve ``a``), and the
    output path is ``/tmp/<shot_id>.jpg``. Nothing is written to disk.
    """
    dummy_op_id = "op_" + "a" * 12
    return OpResult(
        shot_id=shot_id,
        status="ok",
        op_id=dummy_op_id,
        attempts=1,
        cost_usd=cost,
        output_path=f"/tmp/{shot_id}.jpg",
    )


def _make_failed_result(shot_id, status="attempts_exhausted", cost=0.30):
    """Build a non-ok OpResult that carries a review-queue entry.

    ``status`` defaults to ``attempts_exhausted``. The result has no output
    path, records four attempts with failure mode ``anatomy_face_merge``, and
    uses fixed dummy op / review-queue ids.
    """
    fields = dict(
        status=status,
        shot_id=shot_id,
        op_id="op_" + "b" * 12,
        output_path=None,
        cost_usd=cost,
        attempts=4,
        failure_mode="anatomy_face_merge",
        review_queue_id="rq_" + "c" * 32,
    )
    return OpResult(**fields)


def _make_shot_plan(n=4):
    """Generate a simple shot plan with n shots."""
    return [
        {
            "shot_id": f"SH{i + 1:02d}",
            "prompt": f"Shot {i + 1} prompt",
            "pipeline": "keyframe",
            "scene_index": f"SC{(i // 2) + 1:02d}",
            "shot_type": "primary" if i % 2 == 0 else "coverage_ws",
        }
        for i in range(n)
    ]


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestRunEpisode:
    """Tests for run_episode and its private helpers.

    Episode-level tests patch ``recoil.pipeline._lib.run_episode.run_shot``
    with a local fake. The fakes that declare an explicit signature also act
    as a contract check on the arguments run_episode passes to run_shot.
    """

    def test_3_passing_1_failing(self, tmp_path):
        """3 clean shots + 1 failing -> review queue gets exactly 1 entry."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(4)

        call_count = {"n": 0}

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            call_count["n"] += 1
            sid = shot["shot_id"]
            if sid == "SH03":
                return _make_failed_result(sid)
            return _make_ok_result(sid)

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            with patch(
                "recoil.pipeline._lib.run_episode._generate_style_anchor",
                return_value=Path("/tmp/anchor.jpg"),
            ):
                result = run_episode(
                    project="test-proj",
                    episode_id="EP001",
                    model="test-model",
                    budget_usd=50.0,
                    no_style_anchor=True,
                    step_runner=MagicMock(),
                    store=MagicMock(),
                    paths=paths,
                    shot_plan=shot_plan,
                )

        assert isinstance(result, EpisodeResult)
        # Every planned shot is dispatched exactly once; a failed shot is not
        # re-run at the episode level.
        assert call_count["n"] == len(shot_plan)
        assert result.completed == 4
        assert result.by_status.get("ok", 0) == 3
        assert result.by_status.get("attempts_exhausted", 0) == 1
        assert result.review_queue_count == 1

    def test_style_anchor_generation(self, tmp_path):
        """Style Anchor is generated and saved to correct path."""
        from recoil.pipeline._lib.run_episode import _generate_style_anchor

        paths = _make_paths(tmp_path)
        budget_guard = MagicMock()
        budget_guard.would_exceed.return_value = False
        budget_guard.would_exceed_per_shot.return_value = False

        # Mock run_shot to return ok with an output file
        output_file = tmp_path / "output.jpg"
        output_file.write_bytes(b"fake image data")

        mock_result = _make_ok_result("EP001_STYLE_ANCHOR")
        mock_result.output_path = str(output_file)

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", return_value=mock_result
        ):
            anchor_path = _generate_style_anchor(
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                budget_guard=budget_guard,
                model="test-model",
                episode_id="EP001",
                shots=_make_shot_plan(2),
                run_id="run_test123456",
            )

        # NOTE(review): only non-None is checked here; the exact destination
        # path is not pinned by this test.
        assert anchor_path is not None

    def test_style_anchor_failure_aborts(self, tmp_path):
        """3 failed style anchors -> episode aborts."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(2)

        # Style anchor always fails. Accept any call shape (positional or
        # keyword) so the fake itself can never raise a TypeError, which
        # could otherwise masquerade as the crash under test.
        def mock_run_shot(*args, **kwargs):
            return _make_failed_result("STYLE_ANCHOR", status="crashed")

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            result = run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=50.0,
                no_style_anchor=False,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

        assert result.aborted is True
        assert result.abort_reason == "style_anchor_failed"

    def test_no_style_anchor_bypass(self, tmp_path):
        """no_style_anchor=True skips anchor generation."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(2)

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            return _make_ok_result(shot["shot_id"])

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            with patch(
                "recoil.pipeline._lib.run_episode._generate_style_anchor"
            ) as mock_anchor:
                result = run_episode(
                    project="test-proj",
                    episode_id="EP001",
                    model="test-model",
                    budget_usd=50.0,
                    no_style_anchor=True,
                    step_runner=MagicMock(),
                    store=MagicMock(),
                    paths=paths,
                    shot_plan=shot_plan,
                )

        mock_anchor.assert_not_called()
        assert result.completed == 2

    def test_resume_skips_completed(self, tmp_path):
        """Resume skips completed shots, retries crashed."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(3)

        # Write a previous run state
        run_id = "run_prevrun12345"
        state = {
            "run_id": run_id,
            "episode_id": "EP001",
            "aborted": True,
            "abort_reason": "sigterm",
            "style_anchors": {},
            "budget_spent": 0.20,
            "shots": [
                {
                    "shot_id": "SH01",
                    "status": "ok",
                    "op_id": "op_aaa",
                    "output_path": "/tmp/SH01.jpg",
                    "cost_usd": 0.10,
                    "attempts": 1,
                    "failure_mode": None,
                    "validation_notes": [],
                    "review_queue_id": None,
                },
                {
                    "shot_id": "SH02",
                    "status": "crashed",
                    "op_id": "op_bbb",
                    "output_path": None,
                    "cost_usd": 0.10,
                    "attempts": 1,
                    "failure_mode": None,
                    "validation_notes": [],
                    "review_queue_id": None,
                },
            ],
        }
        runs_dir = tmp_path / "_pipeline" / "state" / "visual" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / f"{run_id}.json").write_text(json.dumps(state))

        called_shots = []

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            called_shots.append(shot["shot_id"])
            return _make_ok_result(shot["shot_id"])

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            result = run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=50.0,
                no_style_anchor=True,
                resume_run_id=run_id,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

        # SH01 should be skipped (status=ok -> SKIP_ON_RESUME)
        # SH02 should be retried (status=crashed -> RETRY_ON_RESUME)
        # SH03 was not in previous run -> should be run
        assert "SH01" not in called_shots
        assert "SH02" in called_shots
        assert "SH03" in called_shots

    def test_budget_abort(self, tmp_path):
        """100% budget hit -> no new shots start, in-flight complete."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(5)

        call_count = {"n": 0}

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            call_count["n"] += 1
            # Each shot costs $20; with $30 budget, only ~1-2 should succeed
            budget_guard.charge(20.0, reserved_amount=0.15)
            return _make_ok_result(shot["shot_id"], cost=20.0)

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            result = run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=30.0,
                concurrency=1,
                no_style_anchor=True,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

        # Should have aborted due to budget exhaustion
        assert result.aborted is True or result.total_cost_usd >= 30.0
        # The cost check alone would pass even if every shot ran; also verify
        # that the episode stopped launching shots once the budget was hit.
        assert call_count["n"] < len(shot_plan)

    def test_morning_summary(self, tmp_path):
        """Verify morning_summary format."""
        result = EpisodeResult(
            run_id="run_test123456",
            episode_id="EP001",
            total_shots=5,
            completed=5,
            by_status={"ok": 3, "attempts_exhausted": 1, "needs_review": 1},
            total_cost_usd=2.50,
            budget_remaining_usd=47.50,
            aborted=False,
            review_queue_count=2,
            shot_results=[
                OpResult(
                    status="ok",
                    shot_id="SH01",
                    # Same op-id shape as the helpers ("op_" + 12 chars);
                    # "op_a" * 3 would have produced "op_aop_aop_a".
                    op_id=f"op_{'a' * 12}",
                    validation_notes=["identity_drift: accepted (soft finding)"],
                ),
            ],
        )

        summary = result.morning_summary()
        assert "EP001" in summary
        assert "run_test123456" in summary
        assert "5/5" in summary
        assert "$2.50" in summary

    def test_coverage_grouping(self, tmp_path):
        """Verify primary runs before siblings within a coverage pass."""
        from recoil.pipeline._lib.run_episode import _group_by_coverage_pass

        shots = [
            {
                "shot_id": "SH01",
                "scene_index": "SC01",
                "shot_type": "primary",
                "coverage_pass_id": "SC01_A",
            },
            {
                "shot_id": "SH02",
                "scene_index": "SC01",
                "shot_type": "coverage_ws",
                "coverage_pass_id": "SC01_A",
            },
            {
                "shot_id": "SH03",
                "scene_index": "SC02",
                "shot_type": "primary",
                "coverage_pass_id": "SC02_A",
            },
        ]

        groups = _group_by_coverage_pass(shots)
        assert len(groups) == 2
        # First group: SC01_A with SH01 as primary and SH02 as sibling
        assert groups[0]["primary"]["shot_id"] == "SH01"
        assert len(groups[0]["siblings"]) == 1
        assert groups[0]["siblings"][0]["shot_id"] == "SH02"
        # Second group: SC02_A with SH03 as a lone primary
        assert groups[1]["primary"]["shot_id"] == "SH03"
        assert not groups[1]["siblings"]

    def test_stop_on_any_review_aborts(self, tmp_path):
        """StopOnReview.ON_ANY_REVIEW aborts when any shot enters review."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(4)

        call_count = {"n": 0}

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            call_count["n"] += 1
            sid = shot["shot_id"]
            # First shot enters review queue
            if call_count["n"] == 1:
                return OpResult(
                    status="needs_review",
                    shot_id=sid,
                    op_id=f"op_{'d' * 12}",
                    review_queue_id=f"rq_{'e' * 32}",
                )
            return _make_ok_result(sid)

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            result = run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=50.0,
                concurrency=1,
                stop_on_review=StopOnReview.ON_ANY_REVIEW,
                no_style_anchor=True,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

        assert result.aborted is True
        # Aborting must actually stop the episode before the whole plan runs.
        assert call_count["n"] < len(shot_plan)

    def test_stop_on_hard_fail_aborts(self, tmp_path):
        """StopOnReview.ON_HARD_FAIL aborts on ICU escalation."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = _make_paths(tmp_path)
        shot_plan = _make_shot_plan(4)

        call_count = {"n": 0}

        def mock_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            call_count["n"] += 1
            sid = shot["shot_id"]
            # First shot escalates to ICU (a hard failure)
            if call_count["n"] == 1:
                return OpResult(
                    status="icu_escalated",
                    shot_id=sid,
                    op_id=f"op_{'f' * 12}",
                    review_queue_id=f"rq_{'g' * 32}",
                )
            return _make_ok_result(sid)

        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=mock_run_shot
        ):
            result = run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=50.0,
                concurrency=1,
                stop_on_review=StopOnReview.ON_HARD_FAIL,
                no_style_anchor=True,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

        assert result.aborted is True
        # Aborting must actually stop the episode before the whole plan runs.
        assert call_count["n"] < len(shot_plan)

    def test_state_dump_and_load(self, tmp_path):
        """Run state is saved and loadable for resume."""
        from recoil.pipeline._lib.run_episode import _save_run_state, _load_run_state
        from recoil.pipeline._lib.budget_manager import BudgetGuard

        paths = _make_paths(tmp_path)
        guard = BudgetGuard(limit_usd=50.0, label="test")
        guard.charge(5.0)

        results = [_make_ok_result("SH01"), _make_failed_result("SH02")]

        state_path = _save_run_state(
            paths,
            "run_test123456",
            results,
            guard,
            "EP001",
            aborted=False,
            style_anchors={"episode": Path("/tmp/anchor.jpg")},
        )

        assert state_path.exists()

        loaded = _load_run_state(paths, "run_test123456")
        assert loaded is not None
        assert loaded["run_id"] == "run_test123456"
        assert loaded["budget_spent"] == 5.0
        assert len(loaded["shots"]) == 2
        assert loaded["shots"][0]["status"] == "ok"
        assert loaded["shots"][1]["status"] == "attempts_exhausted"
