"""Additional unit tests for Phase 3 -- edge cases not covered by Tasks 1-10.

Tests:
1. _extract_failure_mode() helper: missing details, unknown failure_category, None verdict
2. StopOnReview.ON_HARD_FAIL does NOT abort on needs_review (only on icu_escalated/crashed)
3. StopOnReview.ON_ANY_REVIEW aborts on any review-queue outcome, including attempts_exhausted
4. Review queue concurrent resolve safety
5. budget_exhausted_success vs budget_exhausted: budget runs out after a failed attempt that did / did not leave output
6. _is_identity_drift_hard() edge cases
7. _resolve_action() for all action types
"""

import json
import threading
from unittest.mock import MagicMock, patch

from recoil.pipeline._lib.coverage_context import OpResult, StopOnReview


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_gate_verdict(passed=True, gate_name="gate_1", reason="", details=None):
    """Return a MagicMock standing in for a gate verdict.

    Falsy ``details`` (``None`` or an empty mapping) is replaced by a fresh
    empty dict, and ``cost`` is always 0.01.
    """
    return MagicMock(
        passed=passed,
        gate_name=gate_name,
        reason=reason,
        details=details or {},
        cost=0.01,
    )


def _make_step_result(
    success=True,
    final_state="keyframe_generated",
    output_path="/tmp/output.jpg",
    cost_usd=0.15,
    error=None,
    gate_verdict=None,
):
    """Return a MagicMock standing in for a step-runner result.

    Every keyword is copied verbatim onto an attribute of the same name.
    """
    return MagicMock(
        success=success,
        final_state=final_state,
        output_path=output_path,
        cost_usd=cost_usd,
        error=error,
        gate_verdict=gate_verdict,
    )


def _make_paths(tmp_path):
    """Return a mocked paths object rooted at *tmp_path*.

    Also creates ``<tmp_path>/state/visual`` on disk (idempotently).
    """
    visual_dir = tmp_path / "state" / "visual"
    visual_dir.mkdir(parents=True, exist_ok=True)
    return MagicMock(project="test-proj", project_root=tmp_path)


def _make_budget_guard(limit=100.0, per_shot_cap=None):
    """Build a real BudgetGuard labelled "test".

    *limit* becomes ``limit_usd`` and *per_shot_cap* becomes
    ``per_shot_cap_usd``.
    """
    from recoil.pipeline._lib.budget_manager import BudgetGuard

    guard_kwargs = {
        "limit_usd": limit,
        "label": "test",
        "per_shot_cap_usd": per_shot_cap,
    }
    return BudgetGuard(**guard_kwargs)


def _make_shot(shot_id="SH01", prompt="Test prompt", pipeline="keyframe", **extra):
    """Return a minimal shot dict; any *extra* keywords become additional keys."""
    return {"shot_id": shot_id, "prompt": prompt, "pipeline": pipeline, **extra}


# ---------------------------------------------------------------------------
# Tests: _extract_failure_mode() edge cases
# ---------------------------------------------------------------------------


class TestExtractFailureMode:
    """Independent tests for the _extract_failure_mode() helper.

    Covers the "no failure" paths (no verdict, passing verdict), the
    direct ``details["failure_category"]`` mapping, and the substring
    fallbacks on gate name / category ("identity", "safety",
    "content_filter", "anatomy", "face_merge").
    """

    def test_none_gate_verdict_returns_none_failure(self):
        """gate_verdict is None -> FailureMode.NONE."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        result = _make_step_result(gate_verdict=None)
        assert _extract_failure_mode(result) == FailureMode.NONE

    def test_passed_verdict_returns_none_failure(self):
        """gate_verdict.passed=True -> FailureMode.NONE."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(passed=True)
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.NONE

    def test_missing_details_dict_returns_unknown(self):
        """details is None on a generically-named gate -> FailureMode.UNKNOWN."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(passed=False, gate_name="gate_1")
        # _make_gate_verdict normalises details=None to {}, so override it
        # afterwards to exercise a verdict whose details really are None.
        gv.details = None
        result = _make_step_result(gate_verdict=gv)
        # "gate_1" contains none of the substrings the gate-name fallbacks
        # below key on, so the only acceptable outcome is UNKNOWN. (The old
        # assertion accepted four different modes and could not fail.)
        assert _extract_failure_mode(result) == FailureMode.UNKNOWN

    def test_empty_failure_category_returns_unknown(self):
        """Empty failure_category string -> FailureMode.UNKNOWN."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="some_gate",
            details={"failure_category": ""},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.UNKNOWN

    def test_unknown_failure_category_string(self):
        """Unrecognized failure_category -> FailureMode.UNKNOWN."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="gate_1",
            details={"failure_category": "completely_novel_failure_xyz"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.UNKNOWN

    def test_valid_failure_category_maps_correctly(self):
        """Known failure_category string -> correct FailureMode."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="gate_1",
            details={"failure_category": "anatomy_face_merge"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.ANATOMY_FACE_MERGE

    def test_identity_gate_name_fallback(self):
        """Gate name containing 'identity' maps to IDENTITY_DRIFT."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="identity_check",
            details={"failure_category": "some_unmapped_value"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.IDENTITY_DRIFT

    def test_safety_gate_name_fallback(self):
        """Gate name containing 'safety' maps to SAFETY_SOFTENED."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="safety_filter",
            details={"failure_category": "irrelevant_category"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.SAFETY_SOFTENED

    def test_content_filter_category_fallback(self):
        """failure_category containing 'content_filter' maps to SAFETY_SOFTENED."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="gate_1",
            details={"failure_category": "content_filter_violation"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.SAFETY_SOFTENED

    def test_anatomy_gate_name_fallback(self):
        """Gate name containing 'anatomy' maps to ANATOMY_FACE_MERGE."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="anatomy_critic",
            details={"failure_category": "unrecognized"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.ANATOMY_FACE_MERGE

    def test_face_merge_category_fallback(self):
        """failure_category containing 'face_merge' maps to ANATOMY_FACE_MERGE."""
        from recoil.pipeline._lib.run_shot import _extract_failure_mode
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            gate_name="gate_1",
            details={"failure_category": "face_merge_detected"},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _extract_failure_mode(result) == FailureMode.ANATOMY_FACE_MERGE


# ---------------------------------------------------------------------------
# Tests: _is_identity_drift_hard() edge cases
# ---------------------------------------------------------------------------


class TestIsIdentityDriftHard:
    """Edge cases for _is_identity_drift_hard().

    These tests pin that only a verdict with a non-empty
    ``details["hard_failures"]`` list counts as hard drift. A missing
    verdict, an empty list, or ``details=None`` must all yield False.
    """

    @staticmethod
    def _is_hard(step_result):
        """Call the helper under test, imported lazily like the other tests."""
        from recoil.pipeline._lib.run_shot import _is_identity_drift_hard

        return _is_identity_drift_hard(step_result)

    def test_none_verdict(self):
        """No gate verdict at all -> not hard drift."""
        step = _make_step_result(gate_verdict=None)
        assert self._is_hard(step) is False

    def test_empty_hard_failures(self):
        """An empty hard_failures list -> not hard drift."""
        verdict = _make_gate_verdict(passed=False, details={"hard_failures": []})
        assert self._is_hard(_make_step_result(gate_verdict=verdict)) is False

    def test_has_hard_failures(self):
        """At least one listed hard failure -> hard drift."""
        verdict = _make_gate_verdict(
            passed=False,
            details={"hard_failures": ["face_shape", "eye_color"]},
        )
        assert self._is_hard(_make_step_result(gate_verdict=verdict)) is True

    def test_missing_details_attribute(self):
        """A bare mock verdict whose details is None -> not hard drift."""
        verdict = MagicMock()
        verdict.details = None
        assert self._is_hard(_make_step_result(gate_verdict=verdict)) is False


# ---------------------------------------------------------------------------
# Tests: _resolve_action() for action types
# ---------------------------------------------------------------------------


class TestResolveAction:
    """_resolve_action() maps a FailureMode (plus the step result) to an action name."""

    def test_accept_action(self):
        """BACKGROUND_CONTAMINATION -> ACCEPT."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        result = _make_step_result()
        assert _resolve_action(FailureMode.BACKGROUND_CONTAMINATION, result) == "ACCEPT"

    def test_auto_reroll_action(self):
        """ANATOMY_FACE_MERGE -> AUTO_REROLL."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        result = _make_step_result()
        assert _resolve_action(FailureMode.ANATOMY_FACE_MERGE, result) == "AUTO_REROLL"

    def test_soften_retry_action(self):
        """SAFETY_SOFTENED -> SOFTEN_RETRY."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        result = _make_step_result()
        assert _resolve_action(FailureMode.SAFETY_SOFTENED, result) == "SOFTEN_RETRY"

    def test_review_queue_action(self):
        """UNKNOWN -> REVIEW_QUEUE."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        result = _make_step_result()
        assert _resolve_action(FailureMode.UNKNOWN, result) == "REVIEW_QUEUE"

    def test_identity_drift_soft_path(self):
        """Soft identity drift -> ACCEPT."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(passed=False, details={"hard_failures": []})
        result = _make_step_result(gate_verdict=gv)
        assert _resolve_action(FailureMode.IDENTITY_DRIFT, result) == "ACCEPT"

    def test_identity_drift_hard_path(self):
        """Hard identity drift -> AUTO_REROLL."""
        from recoil.pipeline._lib.run_shot import _resolve_action
        from recoil.pipeline._lib.critics import FailureMode

        gv = _make_gate_verdict(
            passed=False,
            details={"hard_failures": ["face_shape"]},
        )
        result = _make_step_result(gate_verdict=gv)
        assert _resolve_action(FailureMode.IDENTITY_DRIFT, result) == "AUTO_REROLL"

    def test_unmapped_failure_mode_returns_review_queue(self):
        """A FailureMode not in ACTION_MAP -> REVIEW_QUEUE."""
        from recoil.pipeline._lib.run_shot import _resolve_action

        result = _make_step_result()
        # A fresh MagicMock already has an identity-based __hash__/__eq__, so
        # it equals no real FailureMode and cannot match any mapping key; no
        # magic-method override or ACTION_MAP patching is needed.
        unmapped_mode = MagicMock()
        assert _resolve_action(unmapped_mode, result) == "REVIEW_QUEUE"


# ---------------------------------------------------------------------------
# Tests: StopOnReview edge cases
# ---------------------------------------------------------------------------


class TestStopOnReviewEdgeCases:
    """run_episode() abort behaviour under the StopOnReview policies.

    Each test drives run_episode() with run_shot patched out. SH01 returns a
    scripted non-"ok" OpResult, and every other shot returns status="ok".
    """

    @staticmethod
    def _episode_paths(tmp_path):
        """Return a mocked paths object rooted at *tmp_path*, creating the runs dir."""
        paths = MagicMock()
        paths.project = "test-proj"
        paths.project_root = tmp_path
        paths.plans_dir = tmp_path / "state" / "visual" / "plans"
        (tmp_path / "state" / "visual" / "runs").mkdir(parents=True, exist_ok=True)
        return paths

    @staticmethod
    def _shot_plan(count):
        """Return *count* primary keyframe shots SH01..SHnn, one per scene SC01..SCnn."""
        return [
            {
                "shot_id": f"SH{i:02d}",
                "prompt": f"p{i}",
                "pipeline": "keyframe",
                "scene_index": f"SC{i:02d}",
                "shot_type": "primary",
            }
            for i in range(1, count + 1)
        ]

    @staticmethod
    def _fake_run_shot(first_result, ok_op_id):
        """Build a run_shot stand-in plus the list of shot_ids it is called with.

        SH01 receives *first_result*. Every other shot gets a fresh "ok"
        OpResult with *ok_op_id*. The explicit signature mirrors run_shot's,
        so a call with an unexpected keyword raises TypeError instead of
        being silently accepted.
        """
        calls = []

        def fake_run_shot(
            shot,
            store,
            paths,
            budget_guard,
            model,
            step_runner=None,
            run_id="",
            style_anchor_path=None,
            coverage_context=None,
            stop_on_review=StopOnReview.NEVER,
        ):
            sid = shot["shot_id"]
            calls.append(sid)
            if sid == "SH01":
                return first_result
            return OpResult(
                status="ok",
                shot_id=sid,
                op_id=ok_op_id,
                output_path=f"/tmp/{sid}.jpg",
                cost_usd=0.10,
            )

        return fake_run_shot, calls

    def _run_episode(self, tmp_path, shot_count, stop_on_review, fake_run_shot):
        """Run run_episode() over a synthetic *shot_count*-shot plan with run_shot patched."""
        from recoil.pipeline._lib.run_episode import run_episode

        paths = self._episode_paths(tmp_path)
        shot_plan = self._shot_plan(shot_count)
        with patch(
            "recoil.pipeline._lib.run_episode.run_shot", side_effect=fake_run_shot
        ):
            return run_episode(
                project="test-proj",
                episode_id="EP001",
                model="test-model",
                budget_usd=50.0,
                concurrency=1,
                stop_on_review=stop_on_review,
                no_style_anchor=True,
                step_runner=MagicMock(),
                store=MagicMock(),
                paths=paths,
                shot_plan=shot_plan,
            )

    def test_on_hard_fail_does_not_abort_on_needs_review(self, tmp_path):
        """ON_HARD_FAIL should NOT abort when a shot just enters needs_review."""
        # needs_review but NOT icu_escalated or crashed
        needs_review = OpResult(
            status="needs_review",
            shot_id="SH01",
            op_id=f"op_{'a' * 12}",
            review_queue_id=f"rq_{'b' * 32}",
            failure_mode="unknown",
        )
        fake, calls = self._fake_run_shot(needs_review, f"op_{'c' * 12}")

        result = self._run_episode(tmp_path, 2, StopOnReview.ON_HARD_FAIL, fake)

        # Both shots should run -- needs_review does NOT trigger ON_HARD_FAIL abort
        assert len(calls) == 2
        assert result.aborted is False

    def test_on_hard_fail_aborts_on_crashed(self, tmp_path):
        """ON_HARD_FAIL should abort when a shot crashes."""
        crashed = OpResult(
            status="crashed",
            shot_id="SH01",
            op_id=f"op_{'d' * 12}",
        )
        fake, _calls = self._fake_run_shot(crashed, f"op_{'e' * 12}")

        result = self._run_episode(tmp_path, 3, StopOnReview.ON_HARD_FAIL, fake)

        assert result.aborted is True

    def test_on_any_review_aborts_on_attempts_exhausted(self, tmp_path):
        """ON_ANY_REVIEW aborts on any review queue entry (attempts_exhausted)."""
        exhausted = OpResult(
            status="attempts_exhausted",
            shot_id="SH01",
            op_id=f"op_{'f' * 12}",
            review_queue_id=f"rq_{'g' * 32}",
            attempts=4,
            failure_mode="anatomy_face_merge",
        )
        fake, _calls = self._fake_run_shot(exhausted, f"op_{'h' * 12}")

        result = self._run_episode(tmp_path, 2, StopOnReview.ON_ANY_REVIEW, fake)

        assert result.aborted is True


# ---------------------------------------------------------------------------
# Tests: Review queue concurrent resolve safety
# ---------------------------------------------------------------------------


class TestReviewQueueConcurrentResolve:
    """Concurrency safety of review_queue.resolve() on a shared JSONL file."""

    def test_concurrent_resolve_writes(self, tmp_path):
        """Multiple threads resolving simultaneously must not corrupt the file."""
        from recoil.pipeline._lib.review_queue import enqueue, resolve, list_pending

        queue_path = tmp_path / "review_queue.jsonl"
        n_entries = 20

        # Enqueue serially; only the resolves below are meant to race.
        entries = [
            enqueue(
                queue_path=queue_path,
                project="p",
                episode_id="E1",
                shot_id=f"SH{i:03d}",
                run_id="r1",
                reason="attempts_exhausted",
                failure_mode="unknown",
                attempts=[],
                total_cost_usd=0.01 * i,
            )
            for i in range(n_entries)
        ]

        errors = []
        # Hold every worker at a barrier so the resolves start together.
        # Without it, threads launched one by one may finish before the next
        # starts and the test barely exercises concurrency. The timeout keeps
        # a stuck worker from hanging the suite (it surfaces as an error).
        start_gate = threading.Barrier(n_entries)

        def worker(entry):
            try:
                start_gate.wait(timeout=10)
                resolve(
                    queue_path=queue_path,
                    rq_id=entry["rq_id"],
                    resolution="approved",
                    notes=f"resolved {entry['shot_id']}",
                )
            except Exception as exc:  # collected and asserted on below
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(entry,)) for entry in entries]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors, errors
        # All entries should be resolved -- none pending
        pending = list_pending(queue_path=queue_path)
        assert len(pending) == 0
        # One enqueue line plus one resolve line per entry
        lines = queue_path.read_text().strip().splitlines()
        assert len(lines) == 2 * n_entries
        for line in lines:
            json.loads(line)  # each line must be valid JSON


# ---------------------------------------------------------------------------
# Tests: budget_exhausted_success on last affordable attempt
# ---------------------------------------------------------------------------


class TestBudgetExhaustedSuccess:
    """run_shot() final status when the budget runs out after a failed attempt."""

    @staticmethod
    def _failing_runner(output_path):
        """Return a step runner whose every keyframe attempt fails the gate.

        Each attempt costs 0.15 USD and fails with the non-ACCEPT category
        ``anatomy_limb_miscount``. Its output path is *output_path*, which
        may be None to simulate an attempt that produced nothing usable.
        """
        runner = MagicMock()
        fail_verdict = _make_gate_verdict(
            passed=False,
            gate_name="gate_1",
            details={"failure_category": "anatomy_limb_miscount"},
        )
        runner.execute_keyframe.return_value = _make_step_result(
            success=False,
            final_state="keyframe_mechanical_failed",
            output_path=output_path,
            cost_usd=0.15,
            gate_verdict=fail_verdict,
        )
        return runner

    def test_shot_succeeds_on_last_affordable_attempt(self, tmp_path):
        """Only one attempt is affordable and it left output -> budget_exhausted_success."""
        from recoil.pipeline._lib.run_shot import run_shot

        paths = _make_paths(tmp_path)
        # A 0.20 limit covers one 0.15 attempt but not a second (0.30 total).
        guard = _make_budget_guard(limit=0.20)
        runner = self._failing_runner("/tmp/last_attempt.jpg")

        result = run_shot(
            shot=_make_shot(),
            store=MagicMock(),
            paths=paths,
            budget_guard=guard,
            model="test-model",
            step_runner=runner,
        )

        # Attempt 1 runs (0.15) and fails the gate; a retry would exceed the
        # budget. Because attempt 1 left output, the shot is expected to end
        # as budget_exhausted_success rather than plain budget_exhausted.
        assert result.status == "budget_exhausted_success"
        assert result.output_path == "/tmp/last_attempt.jpg"
        assert result.cost_usd > 0

    def test_budget_exhausted_no_output(self, tmp_path):
        """Budget exhausted with no usable output -> budget_exhausted (not success)."""
        from recoil.pipeline._lib.run_shot import run_shot

        paths = _make_paths(tmp_path)
        # Same budget shape as above, but the failed attempt produces no file.
        guard = _make_budget_guard(limit=0.20)
        runner = self._failing_runner(None)

        result = run_shot(
            shot=_make_shot(),
            store=MagicMock(),
            paths=paths,
            budget_guard=guard,
            model="test-model",
            step_runner=runner,
        )

        # No usable output -> plain budget_exhausted
        assert result.status == "budget_exhausted"
        assert result.output_path is None
