import asyncio
import importlib
from unittest.mock import patch

import pytest

from recoil.execution.step_types import GateVerdict, StepResult
from recoil.pipeline.core.take import Beat, Scene
from recoil.pipeline.core.workflow import Workflow, WorkflowStep
from recoil.pipeline.orchestrator.episode_runner import EpisodeRunner
from recoil.pipeline.orchestrator.learning_engine import LearningEngine
from recoil.pipeline.orchestrator.production_types import RetryCostPolicy
from recoil.pipeline.orchestrator.strategy_registry import StrategyEngine


@pytest.fixture(autouse=True)
def _isolate(monkeypatch):
    """Reset global pipeline state and stub EpisodeRunner side effects per test.

    Applied automatically to every test in this module. Before and after the
    test, the core registry and the dispatch bootstrap are reset through their
    ``_reset_for_tests`` hooks. While the test runs, ``init_scenes_from_plan``,
    ``save_active_scene`` and ``ops_log.write`` in the episode_runner module
    are replaced with no-ops. ``EpisodeRunner._estimate_take_cost`` is pinned
    to ``0.0``. pytest's ``monkeypatch`` restores all of these afterwards.
    """
    from recoil.pipeline.core import registry

    # NOTE(review): loaded via importlib rather than a from-import; presumably
    # to get the submodule object itself — confirm before changing.
    dispatch = importlib.import_module("recoil.pipeline.core.dispatch")

    def reset_global_state():
        registry._reset_for_tests()
        dispatch._reset_bootstrap_for_tests()

    def do_nothing(*args, **kwargs):
        return None

    reset_global_state()

    for target in (
        "recoil.pipeline.orchestrator.episode_runner.init_scenes_from_plan",
        "recoil.pipeline.orchestrator.episode_runner.save_active_scene",
        "recoil.pipeline.orchestrator.episode_runner.ops_log.write",
    ):
        monkeypatch.setattr(target, do_nothing)
    # Take cost estimation is neutralised so budget checks never interfere.
    monkeypatch.setattr(
        EpisodeRunner,
        "_estimate_take_cost",
        lambda self, beat: 0.0,
    )

    yield

    reset_global_state()


class FakeGateAwareStepRunner:
    """Step-runner double: the first video call fails its gate, later calls succeed.

    Every ``execute_video`` call is recorded in ``calls`` and writes a small
    placeholder ``take_<n>.mp4`` under ``tmp_path``. Each record holds the shot
    id, the gates and a copy of the provider hints. On the first call, the first
    gate runs against that file and its verdict becomes a failed
    ``StepResult``. Every later call returns a successful ``StepResult`` without
    running any gate.
    """

    def __init__(self, tmp_path):
        self.tmp_path = tmp_path
        # One record per execute_video() invocation, in call order.
        self.calls = []

    def execute_video(
        self,
        *,
        shot_id,
        prompt,
        model,
        gates=None,
        provider_hints=None,
        **kwargs,
    ):
        """Record the call, write a fake take, and return a gated or successful result.

        ``prompt`` and any extra keyword arguments are accepted for interface
        compatibility but are not used.
        """
        self.calls.append(
            {
                "shot_id": shot_id,
                "gates": gates,
                "provider_hints": dict(provider_hints or {}),
            }
        )
        attempt = len(self.calls)  # 1-based, includes the current call

        take_path = self.tmp_path / f"take_{attempt}.mp4"
        take_path.write_bytes(b"mp4")

        shared_fields = {
            "shot_id": shot_id,
            "output_path": str(take_path),
            "model": model,
            "pipeline": "video",
        }

        if attempt > 1:
            # Retries always succeed and never consult the gates.
            return StepResult(
                success=True,
                final_state="video_complete",
                cost_usd=0.01,
                error=None,
                take_index=attempt - 1,
                gate_verdict=None,
                **shared_fields,
            )

        # First attempt: run the first gate on the written file and report
        # its verdict as a mechanical failure.
        verdict = gates[0](take_path, {"shot_id": shot_id})
        return StepResult(
            success=False,
            final_state="video_mechanical_failed",
            cost_usd=verdict.cost,
            error=verdict.reason,
            take_index=0,
            gate_verdict=verdict,
            **shared_fields,
        )


def test_gate_failed_video_take_reaches_strategy_retry_without_dispatch_edit(tmp_path):
    """A gate-failed first video take is followed by a retry take.

    The patched workflow builder attaches a gate that always rejects with a
    retriable identity-drift verdict. ``FakeGateAwareStepRunner`` invokes that
    gate only on its first call. The beat is therefore expected to hold a
    failed first take, plus a second take whose payload carries either a
    ``_retry_strategy`` or a ``seed`` provider hint.
    """
    shot_id = "EP001_SH01"
    observed_gate_calls = []

    def always_failing_identity_gate(video_path, shot_data):
        observed_gate_calls.append((video_path, shot_data))
        return GateVerdict(
            passed=False,
            gate_name="gate_2_video",
            reason="identity drift total_score=3",
            details={"total_score": 3},
            cost=0.039,
            retriable=True,
        )

    def fake_build_workflow(self, beat, take_index, **kw):
        # Stands in for EpisodeRunner._build_workflow_for_beat: a single
        # video_i2v step that always carries the failing gate.
        video_payload = {
            "shot_id": beat.beat_id,
            "prompt": "test prompt",
            "model": "seeddance-2.0",
            "duration": 3,
            "aspect_ratio": "9:16",
            "gates": [always_failing_identity_gate],
            "provider_hints": {},
        }
        video_step = WorkflowStep(
            step_id="video",
            modality="video_i2v",
            payload=video_payload,
        )
        return Workflow(
            workflow_id=f"{beat.beat_id}_take_{take_index}",
            steps=[video_step],
            global_provenance={"shot_id": beat.beat_id, "episode": self.episode},
        )

    fake_step_runner = FakeGateAwareStepRunner(tmp_path)
    learning_engine = LearningEngine(project="test_proj", state_dir=tmp_path)
    strategy_engine = StrategyEngine(learning=learning_engine, model="seeddance-2.0")
    runner = EpisodeRunner(
        project="test_proj",
        plan={"sequences": {"seq_01": {"shots": [{"shot_id": shot_id}]}}},
        max_takes=2,
        budget_usd=50.0,
        step_runner=fake_step_runner,
        strategy_engine=strategy_engine,
        retry_cost_policy=RetryCostPolicy(max_retry_spend_usd=6.0),
    )

    shot_metadata = {
        "shot_id": shot_id,
        "duration_s": 3,
        "shot_type": "MS",
        "description": "Jade walks down the corridor.",
        "characters": [{"char_id": "JADE"}],
    }
    beat = Beat(
        beat_id=shot_id,
        max_takes=2,
        beat_metadata={
            "scene_id": "seq_01",
            "shot": shot_metadata,
            "modality": "video_i2v",
        },
    )
    scene = Scene(
        scene_id="seq_01",
        beats=[beat],
        scene_metadata={"episode": "ep_001", "project": "test_proj"},
    )

    expected_version = 1  # REC-231 Phase 4
    with patch.object(EpisodeRunner, "_build_workflow_for_beat", fake_build_workflow):
        asyncio.run(
            runner._dispatch_one_beat(beat, scene, expected_version, dry_run=False)
        )

    # The fake runner only consults the gate on its first call.
    assert len(observed_gate_calls) == 1
    assert len(beat.takes) >= 2

    first_take_result = beat.takes[0].workflow.steps[0].receipt.run_result
    assert first_take_result.success is False
    assert first_take_result.metadata["final_state"] == "video_mechanical_failed"

    # The retry must carry either an explicit strategy or a (new) seed hint.
    retry_payload = beat.takes[1].workflow.steps[0].payload
    has_retry_strategy = bool(retry_payload.get("_retry_strategy"))
    assert has_retry_strategy or (
        retry_payload.get("provider_hints", {}).get("seed") is not None
    )
