"""REC-122 gate-2 — failed finalization must carry provider-reported cost.

Also covers wait_for_job timeouts and Flora orphan recovery
(RECOIL_ORPHAN_RECOVERY_S).
"""
import logging
from types import SimpleNamespace

import pytest

from recoil.execution.providers import PollRequest, PollResult
from recoil.execution.types import GenerationResult
from recoil.execution import video_model_client as vmc
from recoil.execution.video_model_client import VideoModelClient


class _FakeClock:
    def __init__(self):
        self.now = 0.0
        self.sleeps = []

    def time(self):
        return self.now

    def sleep(self, seconds):
        self.sleeps.append(seconds)
        self.now += seconds


class _SequencedAdapter:
    """Provider-adapter double that replays a scripted sequence of polls.

    Each ``parse_poll`` (and ``parse_result``) call consumes the next queued
    ``PollResult``. Once the script is exhausted, every further poll reports
    ``IN_PROGRESS`` with ``observed_cost`` equal to ``charged_cost``.
    ``build_poll_calls`` counts how many poll requests were built.
    """

    def __init__(
        self,
        *,
        provider_id="flora",
        poll_results=None,
        poll_url="https://flora.test/runs/run_x",
        charged_cost=1.06,
        transport=None,
    ):
        self.provider_id = provider_id
        self.transport = transport
        self.poll_url = poll_url
        # Copy the script so consuming it never mutates the caller's list.
        self.poll_results = list(poll_results) if poll_results else []
        self.build_poll_calls = 0
        self._charged_cost = charged_cost

    def build_poll(self, pj):
        """Count the call and return a GET request for ``poll_url``."""
        self.build_poll_calls = self.build_poll_calls + 1
        return PollRequest("GET", self.poll_url, {})

    def parse_poll(self, resp, pj):
        """Return the next scripted result, or IN_PROGRESS once exhausted."""
        if not self.poll_results:
            return PollResult(
                status="IN_PROGRESS",
                observed_cost=self._charged_cost,
                raw={"status": "running"},
            )
        return self.poll_results.pop(0)

    def build_result_fetch(self, pj):
        """Return None: this double never builds a result-fetch request."""
        return None

    def parse_result(self, resp, pj):
        """Parse a result response exactly like a poll response."""
        return self.parse_poll(resp, pj)

    def compute_cost(self, duration_s, tier, profile):
        """Return a fixed listed cost of 1.25, ignoring the inputs."""
        return 1.25


def _job(adapter, *, native_id="run_x", poll_url="https://flora.test/runs/run_x"):
    pj = SimpleNamespace(
        native_id=native_id,
        tier="standard_720p",
        duration_s=5,
        resolution="720p",
        native_state={
            "poll_url": poll_url,
            "charged_cost": 0.75,
            "model_path": "flora/model",
        },
    )
    marked = {}
    job = SimpleNamespace(
        result=None,
        job_id=native_id,
        _provider_adapter=adapter,
        _provider_job=pj,
        _shot_id="shot-1",
        mark_failed=lambda err: marked.setdefault("err", err),
        mark_complete=lambda result: marked.setdefault("complete", result),
    )
    return job, pj, marked


@pytest.fixture
def video_client_fakes(monkeypatch):
    """Patch video_model_client so polling tests run offline and instantly.

    Returns ``(records, clock)``:
    - ``records`` collects the kwargs of every ``_obs.record_call``.
    - ``clock`` is the ``_FakeClock`` backing ``time.time``/``time.sleep``.

    It also:
    - stubs the profile lookup (``cost_per_second`` 0.25);
    - stubs ``_lazy_imports``, whose third callable returns
      ``b"video-bytes"`` for any url;
    - stubs ``_http`` to return ``{}``;
    - sets ``RECOIL_ORPHAN_RECOVERY_S=1``.
    """
    calls = []
    fake_clock = _FakeClock()
    monkeypatch.setenv("RECOIL_ORPHAN_RECOVERY_S", "1")
    # NOTE(review): if vmc.time is the stdlib time module, these two patches
    # replace time.time/time.sleep process-wide for the test's duration.
    for attr in ("time", "sleep"):
        monkeypatch.setattr(vmc.time, attr, getattr(fake_clock, attr))
    prefix = "recoil.execution.video_model_client."
    stubs = {
        "model_profiles.get_profile": lambda model_id: {"cost_per_second": 0.25},
        "_obs.record_call": lambda **kwargs: calls.append(kwargs),
        "_lazy_imports": lambda: (GenerationResult, None, lambda url: b"video-bytes"),
        "_http": lambda *a, **k: {},
    }
    for dotted, stub in stubs.items():
        monkeypatch.setattr(prefix + dotted, stub)
    return calls, fake_clock


def test_finalize_failed_threads_observed_cost():
    """A provider-billed failure (observed_cost > 0) must surface in
    GenerationResult.cost — dropping it undercounts real spend now that
    failed results report only provider-confirmed cost (REC-122)."""
    captured = {}
    job = SimpleNamespace(
        mark_failed=lambda err: captured.setdefault("err", err),
        result=None,
    )
    provider_job = SimpleNamespace(
        native_id="run_x",
        tier="standard_720p",
        duration_s=15,
        resolution="720p",
    )
    billed_poll = SimpleNamespace(
        error="flora: BILLING_NOT_ENOUGH_CREDITS",
        observed_cost=1.23,
    )

    # _finalize_failed is invoked unbound with a minimal stand-in for self.
    result = VideoModelClient._finalize_failed(
        SimpleNamespace(_model_id="seeddance-2.0"),
        SimpleNamespace(provider_id="flora"),
        provider_job,
        job,
        billed_poll,
        submit_start=0.0,
    )

    assert result.success is False
    assert result.cost == 1.23
    error_text = result.error or ""
    assert "BILLING_NOT_ENOUGH_CREDITS" in error_text


def test_finalize_failed_zero_cost_for_unbilled_failure():
    """With no provider-observed cost and no error text, _finalize_failed
    reports a failed result whose cost is 0.0."""
    provider_job = SimpleNamespace(
        native_id="run_y",
        tier="standard_720p",
        duration_s=15,
        resolution="720p",
    )
    unbilled_poll = SimpleNamespace(error=None, observed_cost=None)

    def _ignore(err):
        return None

    result = VideoModelClient._finalize_failed(
        SimpleNamespace(_model_id="seeddance-2.0"),
        SimpleNamespace(provider_id="flora"),
        provider_job,
        SimpleNamespace(mark_failed=_ignore, result=None),
        unbilled_poll,
        submit_start=0.0,
    )

    assert result.success is False
    assert result.cost == 0.0


def test_wait_for_job_timeout_records_cost_and_attempts_cancel(monkeypatch):
    """An immediate fal timeout does four things:
    - reports the submit-time charge (0.75) as the result cost;
    - records both the listed and the observed cost;
    - names the request in the error;
    - cancels the job via the transport."""
    records = []
    cancel_calls = []
    mod = "recoil.execution.video_model_client"
    monkeypatch.setattr(
        f"{mod}.model_profiles.get_profile",
        lambda model_id: {"cost_per_second": 0.25},
    )
    monkeypatch.setattr(
        f"{mod}._obs.record_call",
        lambda **kwargs: records.append(kwargs),
    )

    def _cancel(model_path, native_id):
        cancel_calls.append((model_path, native_id))

    adapter = SimpleNamespace(
        provider_id="fal",
        transport=SimpleNamespace(cancel=_cancel),
        compute_cost=lambda duration_s, tier, profile: 1.25,
    )
    model_path = "bytedance/seedance-2.0/text-to-video"
    response_url = "https://queue.fal.run/result/req-timeout"
    pj = SimpleNamespace(
        native_id="req-timeout",
        tier="standard_720p",
        duration_s=5,
        resolution="720p",
        native_state={
            "model_path": model_path,
            "response_url": response_url,
            "charged_cost": 0.75,
        },
    )
    marked = {}
    job = SimpleNamespace(
        result=None,
        job_id="req-timeout",
        _provider_adapter=adapter,
        _provider_job=pj,
        _shot_id="shot-1",
        mark_failed=lambda err: marked.setdefault("err", err),
    )

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=0)
    error_text = result.error or ""

    assert result.success is False
    assert result.cost == 0.75
    assert "native_id=req-timeout" in error_text
    assert f"response_url={response_url}" in error_text
    assert cancel_calls == [(model_path, "req-timeout")]
    last = records[-1]
    assert last["status"] == "TIMEOUT"
    assert last["listed_cost"] == 1.25
    assert last["observed_cost"] == 0.75
    assert last["task_id"] == "req-timeout"
    assert marked["err"] == result.error


def test_wait_for_job_timeout_uses_listed_cost_without_submit_cost(monkeypatch):
    """Without a ``charged_cost`` in native_state, a timeout:
    - reports the adapter's listed cost (1.25);
    - records observed_cost as None."""
    records = []
    mod = "recoil.execution.video_model_client"
    monkeypatch.setattr(
        f"{mod}.model_profiles.get_profile",
        lambda model_id: {"cost_per_second": 0.25},
    )
    monkeypatch.setattr(
        f"{mod}._obs.record_call",
        lambda **kwargs: records.append(kwargs),
    )

    def _noop_cancel(*_args):
        return None

    adapter = SimpleNamespace(
        provider_id="fal",
        transport=SimpleNamespace(cancel=_noop_cancel),
        compute_cost=lambda duration_s, tier, profile: 1.25,
    )
    # Deliberately no "charged_cost" key: only the listed cost is available.
    state = {
        "model_path": "bytedance/seedance-2.0/text-to-video",
        "response_url": "https://queue.fal.run/result/req-listed",
    }
    pj = SimpleNamespace(
        native_id="req-listed",
        tier="standard_720p",
        duration_s=5,
        resolution="720p",
        native_state=state,
    )
    job = SimpleNamespace(
        result=None,
        job_id="req-listed",
        _provider_adapter=adapter,
        _provider_job=pj,
        mark_failed=lambda err: None,
    )

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=0)

    assert result.success is False
    assert result.cost == 1.25
    last = records[-1]
    assert last["listed_cost"] == 1.25
    assert last["observed_cost"] is None


def test_flora_poll_exceptions_recover_to_normal_success(
    monkeypatch, caplog, video_client_fakes
):
    """A transient poll exception must not sink a Flora run.

    Recovery, logged as "recovered orphaned run", must yield the same
    success result as a run that never saw an exception.
    """
    records, _clock = video_client_fakes
    attempts = []

    def flaky_http(*_args, **_kwargs):
        attempts.append(None)
        if len(attempts) == 1:
            raise RuntimeError("temporary 500")
        return {}

    def _completed():
        # Fresh, equal instance per adapter: each adapter consumes its own.
        return PollResult(
            status="COMPLETED",
            video_url="https://flora.test/out.mp4",
            observed_cost=0.75,
            raw={"status": "completed", "outputs": [{"url": "x"}]},
        )

    monkeypatch.setattr("recoil.execution.video_model_client._http", flaky_http)
    job, _pj, marked = _job(_SequencedAdapter(poll_results=[_completed()]))

    caplog.set_level(logging.WARNING)
    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=1)

    assert result == marked["complete"]
    assert result.success is True
    assert result.video_data == b"video-bytes"
    assert result.video_url == "https://flora.test/out.mp4"
    assert result.metadata["provider"] == "flora"
    assert result.metadata["request_id"] == "run_x"
    assert records[-1]["status"] == "COMPLETED"
    assert "recovered orphaned run run_x after" in caplog.text

    # flaky_http only raises on its first call, so this run polls cleanly.
    direct_job, _direct_pj, _direct_marked = _job(
        _SequencedAdapter(poll_results=[_completed()])
    )
    direct = VideoModelClient("seeddance-2.0").wait_for_job(
        direct_job, timeout_s=10
    )
    assert result == direct


def test_flora_recovery_provider_failed_returns_failure(video_client_fakes):
    """A FAILED poll is terminal.

    The scripted COMPLETED result after it is never polled (exactly two
    build_poll calls), and the call is recorded as FAILED.
    """
    records, _clock = video_client_fakes
    script = [
        PollResult(
            status="IN_PROGRESS",
            observed_cost=0.75,
            raw={"status": "running"},
        ),
        PollResult(
            status="FAILED",
            observed_cost=0.75,
            error="flora: provider failed",
            raw={"status": "failed", "error_code": "FAILED"},
        ),
        PollResult(
            status="COMPLETED",
            video_url="https://flora.test/too-late.mp4",
            raw={"status": "completed"},
        ),
    ]
    adapter = _SequencedAdapter(poll_results=script)
    job, _pj, _marked = _job(adapter)

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=1)
    error_text = result.error or ""

    assert result.success is False
    assert "provider failed" in error_text
    assert adapter.build_poll_calls == 2
    assert records[-1]["status"] == "FAILED"


def test_flora_recovery_deadline_orphans_running_with_submit_cost(
    video_client_fakes,
):
    """A Flora run that never leaves IN_PROGRESS is reported as ORPHANED.

    The failed result carries:
    - the submit-time cost;
    - the run's identifiers;
    - a "recoverable" hint in the error text.
    """
    records, _clock = video_client_fakes
    running = {"status": "IN_PROGRESS", "observed_cost": 0.75}
    adapter = _SequencedAdapter(
        poll_results=[
            PollResult(**running, raw={"status": "running"}) for _ in range(2)
        ]
    )
    job, _pj, marked = _job(adapter)

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=1)
    error_text = result.error or ""

    assert result.success is False
    assert result.cost == 0.75
    assert "native_id=run_x" in error_text
    assert "poll_url=https://flora.test/runs/run_x" in error_text
    assert "recoverable: job may still complete server-side" in error_text
    last = records[-1]
    assert last["status"] == "ORPHANED"
    assert last["observed_cost"] == 0.75
    assert marked["err"] == result.error


def test_flora_provider_failed_before_recovery_does_not_repoll(video_client_fakes):
    """A FAILED first poll ends the wait, even one with no error message.

    The queued COMPLETED result is never polled.
    """
    failed_first = PollResult(
        status="FAILED",
        observed_cost=0.75,
        error=None,
        raw={"status": "failed"},
    )
    never_polled = PollResult(
        status="COMPLETED",
        video_url="https://flora.test/should-not-poll.mp4",
        raw={"status": "completed"},
    )
    adapter = _SequencedAdapter(poll_results=[failed_first, never_polled])
    job, _pj, _marked = _job(adapter)

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=10)

    assert result.success is False
    assert adapter.build_poll_calls == 1


def test_orphan_recovery_zero_preserves_timeout_behavior(monkeypatch):
    """RECOIL_ORPHAN_RECOVERY_S=0 turns off Flora orphan recovery.

    A timeout then behaves as a plain timeout:
    - no polls are made;
    - a cancel is attempted;
    - a TIMEOUT record is written;
    - the error has no "recoverable:" hint.
    """
    records = []
    cancel_calls = []
    mod = "recoil.execution.video_model_client"
    monkeypatch.setenv("RECOIL_ORPHAN_RECOVERY_S", "0")
    monkeypatch.setattr(
        f"{mod}.model_profiles.get_profile",
        lambda model_id: {"cost_per_second": 0.25},
    )
    monkeypatch.setattr(
        f"{mod}._obs.record_call",
        lambda **kwargs: records.append(kwargs),
    )
    monkeypatch.setattr(
        f"{mod}._lazy_imports",
        lambda: (GenerationResult, None, lambda url: b"video-bytes"),
    )

    def _cancel(model_path, native_id):
        cancel_calls.append((model_path, native_id))

    adapter = _SequencedAdapter(transport=SimpleNamespace(cancel=_cancel))
    job, pj, marked = _job(adapter, native_id="run-disabled")

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=0)
    error_text = result.error or ""

    assert result.success is False
    assert result.cost == 0.75
    assert "native_id=run-disabled" in error_text
    assert "response_url=None" in error_text
    assert "recoverable:" not in error_text
    assert adapter.build_poll_calls == 0
    assert cancel_calls == [("flora/model", "run-disabled")]
    assert records[-1]["status"] == "TIMEOUT"
    assert marked["err"] == result.error
    assert pj.native_id == "run-disabled"


def test_orphan_recovery_junk_env_raises_value_error(monkeypatch, video_client_fakes):
    """A non-numeric RECOIL_ORPHAN_RECOVERY_S surfaces as a ValueError that
    names the variable."""
    monkeypatch.setenv("RECOIL_ORPHAN_RECOVERY_S", "junk")
    still_running = PollResult(
        status="IN_PROGRESS",
        observed_cost=0.75,
        raw={"status": "running"},
    )
    job, _pj, _marked = _job(_SequencedAdapter(poll_results=[still_running]))

    # Client construction stays inside the block: it may read the env too.
    with pytest.raises(ValueError, match="RECOIL_ORPHAN_RECOVERY_S"):
        client = VideoModelClient("seeddance-2.0")
        client.wait_for_job(job, timeout_s=1)


def test_non_flora_timeout_does_not_recover(video_client_fakes):
    """A non-Flora (fal) timeout must not enter orphan recovery.

    It makes no polls and reports a plain TIMEOUT with no "recoverable:"
    hint.
    """
    records, _clock = video_client_fakes
    fal_adapter = _SequencedAdapter(provider_id="fal")
    job, _pj, _marked = _job(fal_adapter, native_id="req-fal")

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=0)
    error_text = result.error or ""

    assert result.success is False
    assert fal_adapter.build_poll_calls == 0
    assert "response_url=None" in error_text
    assert "recoverable:" not in error_text
    assert records[-1]["status"] == "TIMEOUT"


def test_flora_failed_with_provider_error_payload_is_terminal(video_client_fakes):
    """A FAILED poll with a provider error code ends the wait on the first
    poll, and the provider's error text is surfaced."""
    billing_failure = PollResult(
        status="FAILED",
        observed_cost=0.75,
        error="flora: BILLING_NOT_ENOUGH_CREDITS",
        raw={"status": "failed", "error_code": "BILLING_NOT_ENOUGH_CREDITS"},
    )
    never_polled = PollResult(
        status="COMPLETED",
        video_url="https://flora.test/should-not-poll.mp4",
        raw={"status": "completed"},
    )
    adapter = _SequencedAdapter(poll_results=[billing_failure, never_polled])
    job, _pj, _marked = _job(adapter)

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=10)

    assert result.success is False
    assert "BILLING_NOT_ENOUGH_CREDITS" in (result.error or "")
    assert adapter.build_poll_calls == 1


def test_flora_completed_no_output_at_recovery_deadline_fails(video_client_fakes):
    """COMPLETED polls that never carry an output end as a failure.

    The error mentions "completed_no_output" and exactly two polls are made.
    """
    empty_completion = {"status": "COMPLETED", "observed_cost": 0.75}
    adapter = _SequencedAdapter(
        poll_results=[
            PollResult(
                **empty_completion,
                raw={"status": "completed", "outputs": []},
            )
            for _ in range(2)
        ]
    )
    job, _pj, _marked = _job(adapter)

    result = VideoModelClient("seeddance-2.0").wait_for_job(job, timeout_s=10)

    assert result.success is False
    assert "completed_no_output" in (result.error or "")
    assert adapter.build_poll_calls == 2
