from __future__ import annotations

import io
import json
import urllib.error
import urllib.request
from types import SimpleNamespace

import pytest

from recoil.execution.providers import flora
from recoil.execution.providers.base import UnifiedVideoPayload
from recoil.execution.providers.payload_hints import StepRunnerHints, coerce_to_dict
from recoil.execution.types import GenerationResult
from recoil.pipeline.cli import generate


# Local recoil project name shared by every test; the autouse
# `_projects_root` fixture creates a directory with this name under the
# temporary projects root.
PROJECT = "fixture"


@pytest.fixture(autouse=True)
def _projects_root(tmp_path, monkeypatch):
    """Provide an isolated recoil projects root for every test in this module.

    Creates ``<tmp>/projects`` with its ``.recoil-data-root`` marker file and a
    ``PROJECT`` subdirectory, exports the root via ``RECOIL_PROJECTS_ROOT``, and
    returns the project directory.
    """
    projects_dir = tmp_path / "projects"
    project_dir = projects_dir / PROJECT

    projects_dir.mkdir()
    marker = projects_dir / ".recoil-data-root"
    marker.touch()
    project_dir.mkdir()

    monkeypatch.setenv("RECOIL_PROJECTS_ROOT", str(projects_dir))
    return project_dir


class _Resp:
    def __init__(self, body: bytes = b"{}"):
        self._body = body

    def __enter__(self):
        return self

    def __exit__(self, *_exc):
        return False

    def read(self) -> bytes:
        return self._body


def _asset_from_url(url: str, suffix: str) -> str:
    return url.split("/assets/", 1)[1].split(suffix, 1)[0]


def _http_error(url: str, code: int) -> urllib.error.HTTPError:
    return urllib.error.HTTPError(
        url,
        code,
        "error",
        hdrs={},
        fp=io.BytesIO(f"err-{code}".encode()),
    )


def _install_flora_http(
    monkeypatch,
    *,
    attach_errors: dict[str, int] | None = None,
    start_reserve: int = 0,
):
    """Replace ``urllib.request.urlopen`` with an in-memory fake of the Flora API.

    Requests are matched in this order:

    * ``POST`` to a URL ending in ``/assets``: reserves a new asset. Ids are
      ``asset_<n>``, with ``n`` counting up from ``start_reserve + 1``.
    * any URL under ``https://upload.example/``: upload, empty body.
    * any URL containing ``/complete``: returns the asset's ``download_url``.
    * any URL containing ``/attach``: returns ``{}``, unless the asset id is a
      key of ``attach_errors``. In that case the mapped HTTP status is raised
      once and the key is removed.

    Any other request raises ``AssertionError``. Returns the list to which every
    request object is appended, in call order.
    """
    recorded: list[urllib.request.Request] = []
    pending_errors = dict(attach_errors or {})  # private copy; popped as errors fire
    reserved = start_reserve

    def _reserve_body() -> bytes:
        nonlocal reserved
        reserved += 1
        asset_id = f"asset_{reserved}"
        reservation = {
            "asset_id": asset_id,
            "url": f"https://cdn.example/{asset_id}-reserve.png",
            "upload": {
                "url": f"https://upload.example/{asset_id}",
                "method": "POST",
                "file_field": "file",
                "form_fields": {"token": "t"},
            },
        }
        return json.dumps(reservation).encode()

    def _fake_urlopen(req, timeout=30):
        recorded.append(req)
        url, method = req.full_url, req.get_method()

        if method == "POST" and url.endswith("/assets"):
            return _Resp(_reserve_body())
        if url.startswith("https://upload.example/"):
            return _Resp(b"")
        if "/complete" in url:
            asset_id = _asset_from_url(url, "/complete")
            completed = {"download_url": f"https://cdn.example/{asset_id}.png"}
            return _Resp(json.dumps(completed).encode())
        if "/attach" in url:
            asset_id = _asset_from_url(url, "/attach")
            if asset_id in pending_errors:
                raise _http_error(url, pending_errors.pop(asset_id))
            return _Resp(b"{}")
        raise AssertionError(f"unexpected request: {method} {url}")

    monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
    return recorded


def _write_ref(tmp_path, data: bytes = b"same-bytes") -> str:
    path = tmp_path / "ref.png"
    path.write_bytes(data)
    return str(path)


def _upload(ref: str, *, project_id: str = "prj_a", project: str | None = PROJECT):
    """Send the single local file *ref* through ``flora._upload_local_refs``.

    Uses fixed test credentials (``ak_test`` / ``ws_test``) and returns the
    uploader's result unchanged. Pass ``project=None`` to upload without a
    recoil project.
    """
    credentials = {"api_key": "ak_test", "workspace_id": "ws_test"}
    return flora._upload_local_refs(
        [ref],
        project_id=project_id,
        project=project,
        **credentials,
    )


def test_second_upload_same_bytes_project_only_attaches_cached_asset(tmp_path, monkeypatch):
    """Re-uploading identical bytes to the same Flora project reuses the cache.

    The second upload returns the cached hosted URL and issues exactly one
    request: the attach of cached ``asset_1`` to ``prj_a``.
    """
    cached_url = "https://cdn.example/asset_1.png"
    ref = _write_ref(tmp_path)
    calls = _install_flora_http(monkeypatch)

    assert _upload(ref) == [cached_url]
    calls.clear()  # only the repeat upload's traffic matters below

    repeat = _upload(ref)
    assert repeat == [cached_url]
    assert len(calls) == 1
    attach_request = calls[0]
    assert attach_request.full_url.endswith("/projects/prj_a/assets/asset_1/attach")


def test_same_bytes_different_flora_project_id_reuploads(tmp_path, monkeypatch):
    """Cached uploads are not reused across Flora project ids.

    Bytes already uploaded for ``prj_a`` are uploaded again for ``prj_b``.
    That means four POST requests against the fake, the last attaching the
    new ``asset_2`` to ``prj_b``.
    """
    ref = _write_ref(tmp_path)
    calls = _install_flora_http(monkeypatch)

    _upload(ref, project_id="prj_a")
    calls.clear()

    hosted = _upload(ref, project_id="prj_b")

    assert hosted == ["https://cdn.example/asset_2.png"]
    methods = [request.get_method() for request in calls]
    assert methods == ["POST"] * 4
    last_request = calls[-1]
    assert last_request.full_url.endswith("/projects/prj_b/assets/asset_2/attach")


def test_corrupt_cache_fails_open_to_reupload(tmp_path, monkeypatch):
    """An unparseable cache file does not raise; the ref is simply uploaded again."""
    ref = _write_ref(tmp_path)

    broken_cache = flora._cache_path(PROJECT)
    broken_cache.parent.mkdir(parents=True, exist_ok=True)
    broken_cache.write_text("{not-json", encoding="utf-8")

    calls = _install_flora_http(monkeypatch)
    hosted = _upload(ref)

    assert hosted == ["https://cdn.example/asset_1.png"]
    # Same request count as a first-time upload: no cached-attach shortcut.
    assert len(calls) == 4


def test_cached_attach_404_drops_entry_and_reuploads_once(tmp_path, monkeypatch):
    """A 404 when attaching a cached asset triggers exactly one fresh upload.

    The first fake seeds the cache with ``asset_1``. The replacement fake makes
    the attach of ``asset_1`` fail with 404, and its numbering starts at
    ``start_reserve=1``, so the fresh reservation is ``asset_2``. That is
    distinct from the stale cached id.

    Expected result:

    * five requests: the failing cached attach plus a fresh upload;
    * ``asset_2``'s hosted URL is returned;
    * the cache entry for the file's hash is replaced by ``asset_2``.
    """
    ref = _write_ref(tmp_path)
    # Seed the cache. The request log of this first fake is not needed, but the
    # seed result is checked so a broken seed fails here, not further down.
    _install_flora_http(monkeypatch)
    assert _upload(ref) == ["https://cdn.example/asset_1.png"]

    calls = _install_flora_http(
        monkeypatch,
        attach_errors={"asset_1": 404},
        start_reserve=1,
    )
    assert _upload(ref) == ["https://cdn.example/asset_2.png"]
    assert len(calls) == 5

    entry = flora._cache_read(flora._cache_path(PROJECT))[flora._file_sha256(ref)]
    assert entry["asset_id"] == "asset_2"
    assert entry["hosted_url"] == "https://cdn.example/asset_2.png"
    assert entry["hosted_url"] == "https://cdn.example/asset_2.png"


def test_non_404_cached_attach_failure_raises(tmp_path, monkeypatch):
    """A non-404 error on the cached attach propagates instead of re-uploading.

    With ``asset_1`` cached, a 500 on its attach must surface as a
    ``RuntimeError`` mentioning ``HTTP 500``. No request may follow the single
    failed attach.
    """
    ref = _write_ref(tmp_path)
    _install_flora_http(monkeypatch)
    _upload(ref)  # seeds the cache with asset_1

    failing_calls = _install_flora_http(monkeypatch, attach_errors={"asset_1": 500})
    with pytest.raises(RuntimeError, match="HTTP 500"):
        _upload(ref)

    assert len(failing_calls) == 1


def test_project_none_disables_cache_reads_and_writes(tmp_path, monkeypatch):
    """With ``project=None`` the upload cache is neither read nor written.

    Both cache helpers are replaced by stubs that fail the test if called.
    """

    def _forbidden(label):
        # Stub that fails the test with *label* whenever it is invoked.
        return lambda *_a, **_kw: pytest.fail(label)

    ref = _write_ref(tmp_path)
    calls = _install_flora_http(monkeypatch)
    monkeypatch.setattr(flora, "_cache_read", _forbidden("read"))
    monkeypatch.setattr(flora, "_cache_update", _forbidden("write"))

    hosted = _upload(ref, project=None)

    assert hosted == ["https://cdn.example/asset_1.png"]
    assert len(calls) == 4
    assert not flora._cache_path(PROJECT).exists()


def test_cache_update_read_modify_write_preserves_existing_entries(_projects_root):
    """Each ``_cache_update`` receives the stored mapping, so earlier entries survive.

    The updater adds one key per call. After two calls, both keys must be
    readable back from the cache file.
    """

    def _adding(key, asset_id):
        # Updater returning a copy of the stored mapping plus one new entry.
        return lambda data: {**data, key: {"asset_id": asset_id}}

    cache_file = flora._cache_path(PROJECT)
    entries = [("a", "one"), ("b", "two")]

    for key, asset_id in entries:
        flora._cache_update(cache_file, _adding(key, asset_id))

    expected = {key: {"asset_id": asset_id} for key, asset_id in entries}
    assert flora._cache_read(cache_file) == expected


def test_build_submit_threads_project_hint_to_ref_upload(tmp_path, monkeypatch):
    """``build_submit`` passes the payload's ``hints.project`` to the ref uploader.

    ``flora._upload_local_refs`` is replaced by a fake that records the
    ``project`` argument it receives.
    """
    ref = _write_ref(tmp_path)
    env = {
        "FLORA_API_KEY": "ak_test",
        "RECOIL_FLORA_WORKSPACE": "ws_test",
        "RECOIL_FLORA_PROJECT": "prj_a",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    seen: dict[str, object] = {}

    def _fake_upload(local_paths, api_key, workspace_id, project_id, project=None):
        seen["project"] = project
        return ["https://cdn.example/ref.png"]

    monkeypatch.setattr(flora, "_upload_local_refs", _fake_upload)

    hints = StepRunnerHints(project=PROJECT, episode=7)
    payload = UnifiedVideoPayload(prompt="prompt", reference_images=[ref], hints=hints)
    flora.ADAPTER.build_submit(payload, tier="standard_720p")

    assert seen["project"] == PROJECT


def test_generate_step_runner_stamps_real_episode_on_execute_pass(_projects_root, monkeypatch):
    """The per-episode step runner sends the real project and episode in ``hints``.

    ``VideoModelClient`` is swapped for a fake whose ``submit`` records the
    payload's ``hints``. It returns an unsuccessful ``GenerationResult``,
    presumably so ``execute_pass`` stops early; TODO confirm.
    """
    import recoil.execution.video_model_client as video_model_client

    captured: dict[str, object] = {}

    class _FakeVideoModelClient:
        def __init__(self, model_id, tier=None):
            self.model_id = model_id
            self.tier = tier

        def submit(self, payload):
            captured["hints"] = payload["hints"]
            failed = GenerationResult(
                success=False,
                model=self.model_id,
                error="stop",
                cost=0.0,
            )
            return SimpleNamespace(result=failed)

    monkeypatch.setattr(video_model_client, "VideoModelClient", _FakeVideoModelClient)

    runner = generate._build_step_runner_for_episode(PROJECT, 7)
    runner.execute_pass(
        pass_id="EP007_PASS",
        prompt="prompt",
        reference_image_paths=[],
        segment_shot_ids=["EP007_SH10"],
        expected_segment_timestamps=[(0.0, 3.0)],
        model="seeddance-2.0",
        # REC-143: grouping identity is now validated PRE-spend; supply the
        # same minimal identity the execute_pass contract tests use.
        pass_counter=1,
        tag="TEST",
    )

    hints = coerce_to_dict(captured["hints"])
    assert hints["project"] == PROJECT
    assert hints["episode"] == 7
