from dataclasses import asdict
from types import SimpleNamespace
from unittest.mock import patch

from recoil.execution.step_runner import StepRunner
from recoil.execution.step_types import GateVerdict, ProjectPaths
from recoil.execution.types import GenerationResult


def test_execute_pass_populates_identity_score_from_video_identity_gate(tmp_path):
    """A passing identity gate verdict must surface as the segment's identity_score.

    The gate stub returns ``passed=True`` with ``details={"total_score": 3}``;
    the test expects ``execute_pass`` to succeed and the first segment result
    to carry an ``identity_score`` of 0.4.
    """
    # Local import: only needed for the tolerant float comparison below.
    import math

    video_dir = tmp_path / "renders"
    video_dir.mkdir()
    # Placeholder file returned by the patched boundary-frame extraction.
    frame_path = tmp_path / "boundary.jpg"
    frame_path.write_bytes(b"jpg")
    paths = ProjectPaths(
        project="test",
        project_root=tmp_path,
        frames_dir=tmp_path / "frames",
        video_dir=video_dir,
        plans_dir=tmp_path / "plans",
        previs_dir=tmp_path / "previs",
    )
    runner = StepRunner(store=SimpleNamespace(), paths=paths)

    class FakeVideoModelClient:
        # Replaces the real client via the patch below; always reports a
        # successful generation with dummy video bytes.
        def __init__(self, *args, **kwargs):
            pass

        def submit(self, payload):
            return SimpleNamespace(result=None)

        def wait_for_job(self, job, timeout_s, on_status=None):
            return GenerationResult(
                success=True,
                video_data=b"mp4",
                model="seeddance-2.0",
                cost=0.1,
                metadata={},
            )

    def identity_gate(_frame_path, _shot_data):
        # Passing verdict whose details carry the raw score under test.
        return GateVerdict(
            passed=True,
            gate_name="gate_2_video",
            reason="noticeable mismatch",
            details={"total_score": 3},
            cost=0.039,
            retriable=True,
        )

    with (
        patch(
            "recoil.execution.video_model_client.VideoModelClient",
            FakeVideoModelClient,
        ),
        # Sidecar writing is stubbed to a no-op and frame extraction is
        # stubbed to return only the placeholder frame.
        patch.object(runner, "_write_sidecar", lambda *args, **kwargs: None),
        patch.object(runner, "_extract_boundary_frames", lambda *args, **kwargs: [frame_path]),
    ):
        result = runner.execute_pass(
            pass_id="EP001_PASS_TEST",
            prompt="test prompt",
            reference_image_paths=[],
            segment_shot_ids=["EP001_SH01"],
            expected_segment_timestamps=[(0.0, 3.0)],
            gates=[identity_gate],
            pass_counter=1,
            tag="TEST",
        )

    assert result.success is True
    identity_score = result.segment_results[0].identity_score
    assert identity_score is not None
    # The score is a computed float: compare with a tolerance so the test does
    # not depend on the exact arithmetic used to derive it.
    assert math.isclose(identity_score, 0.4)


def test_execute_pass_fails_take_on_gate_failure(tmp_path):
    """REC-63 regression: a failed identity gate must fail the whole take.

    When a gate returns ``passed=False`` for a segment, ``execute_pass`` has to
    report ``success=False`` so that the EpisodeRunner strategy branch fires a
    retry. Before the fix, ``execute_pass`` came back ``success=True`` no
    matter what the gate said, leaving the gate toothless on the r2v_multi
    path.
    """
    renders_dir = tmp_path / "renders"
    renders_dir.mkdir()
    boundary_frame = tmp_path / "boundary.jpg"
    boundary_frame.write_bytes(b"jpg")

    project_paths = ProjectPaths(
        project="test",
        project_root=tmp_path,
        frames_dir=tmp_path / "frames",
        video_dir=renders_dir,
        plans_dir=tmp_path / "plans",
        previs_dir=tmp_path / "previs",
    )
    step_runner = StepRunner(store=SimpleNamespace(), paths=project_paths)

    class FakeVideoModelClient:
        # Patched in for the real client; generation always succeeds.
        def __init__(self, *args, **kwargs):
            pass

        def submit(self, payload):
            return SimpleNamespace(result=None)

        def wait_for_job(self, job, timeout_s, on_status=None):
            generation = GenerationResult(
                success=True,
                video_data=b"mp4",
                model="seeddance-2.0",
                cost=0.1,
                metadata={},
            )
            return generation

    def failing_gate(_frame_path, _shot_data):
        # Verdict that rejects the segment.
        verdict_fields = {
            "passed": False,
            "gate_name": "gate_2_video",
            "reason": "critical identity mismatch",
            "details": {"total_score": 4},
            "cost": 0.039,
            "retriable": True,
        }
        return GateVerdict(**verdict_fields)

    def skip_sidecar(*args, **kwargs):
        return None

    def single_boundary_frame(*args, **kwargs):
        return [boundary_frame]

    client_patch = patch(
        "recoil.execution.video_model_client.VideoModelClient",
        FakeVideoModelClient,
    )
    sidecar_patch = patch.object(step_runner, "_write_sidecar", skip_sidecar)
    frames_patch = patch.object(
        step_runner, "_extract_boundary_frames", single_boundary_frame
    )

    with client_patch, sidecar_patch, frames_patch:
        result = step_runner.execute_pass(
            pass_id="EP001_PASS_TEST",
            prompt="test prompt",
            reference_image_paths=[],
            segment_shot_ids=["EP001_SH01"],
            expected_segment_timestamps=[(0.0, 3.0)],
            gates=[failing_gate],
            pass_counter=1,
            tag="TEST",
        )

    # The pass as a whole fails and reports the gate's reason as its error.
    assert result.success is False
    assert result.error == "critical identity mismatch"
    # The rejected segment itself is flagged as not usable.
    first_segment = result.segment_results[0]
    assert first_segment.usable is False


def test_execute_pass_records_gate_crash_on_segment(tmp_path):
    """A gate that raises must be recorded on the segment result.

    The gate stub raises ``RuntimeError``; the test expects the first segment
    to be marked unusable, its ``gate_error`` to contain the exception
    message, and the same ``gate_error`` to appear in ``asdict(result)``.
    """
    root = tmp_path
    renders = root / "renders"
    renders.mkdir()
    placeholder_frame = root / "boundary.jpg"
    placeholder_frame.write_bytes(b"jpg")

    runner = StepRunner(
        store=SimpleNamespace(),
        paths=ProjectPaths(
            project="test",
            project_root=root,
            frames_dir=root / "frames",
            video_dir=renders,
            plans_dir=root / "plans",
            previs_dir=root / "previs",
        ),
    )

    class FakeVideoModelClient:
        # Substituted for the real client; every job finishes successfully.
        def __init__(self, *args, **kwargs):
            pass

        def submit(self, payload):
            return SimpleNamespace(result=None)

        def wait_for_job(self, job, timeout_s, on_status=None):
            return GenerationResult(
                success=True,
                video_data=b"mp4",
                model="seeddance-2.0",
                cost=0.1,
                metadata={},
            )

    def crashing_gate(_frame_path, _shot_data):
        # Gate stub that fails by raising instead of returning a verdict.
        raise RuntimeError("identity service unavailable")

    def no_sidecar(*args, **kwargs):
        return None

    def only_placeholder_frame(*args, **kwargs):
        return [placeholder_frame]

    pass_kwargs = {
        "pass_id": "EP001_PASS_TEST",
        "prompt": "test prompt",
        "reference_image_paths": [],
        "segment_shot_ids": ["EP001_SH01"],
        "expected_segment_timestamps": [(0.0, 3.0)],
        "gates": [crashing_gate],
        "pass_counter": 1,
        "tag": "TEST",
    }

    patch_client = patch(
        "recoil.execution.video_model_client.VideoModelClient",
        FakeVideoModelClient,
    )
    patch_sidecar = patch.object(runner, "_write_sidecar", no_sidecar)
    patch_frames = patch.object(
        runner, "_extract_boundary_frames", only_placeholder_frame
    )

    with patch_client, patch_sidecar, patch_frames:
        result = runner.execute_pass(**pass_kwargs)

    segment = result.segment_results[0]
    assert segment.usable is False
    assert "identity service unavailable" in segment.gate_error

    # The error must also be present in the dataclass-to-dict form.
    serialized = asdict(result)
    assert serialized["segment_results"][0]["gate_error"] == segment.gate_error
