"""CP-8 Phase 4 — LipSyncPostProcessor unit tests.

Fakes the sync.so HTTP layer via the `_transport=` payload key to exercise the
5-step protocol (upload×2, submit, poll PROCESSING, poll COMPLETED, download),
and patches `sync_so.lipsync_video` directly to cover every adapter error
class, argument threading into the adapter, and the never-raise contract.
"""

import json
import sys
import pathlib
from unittest.mock import patch

# Prepend the directory four levels above this file to sys.path so the
# `recoil` package imported on the next line can be resolved.
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable, RECOIL_ROOT  # noqa: E402

# Project hook that presumably makes `recoil.pipeline` importable (the imports
# below depend on it) — NOTE(review): confirm in recoil.core.paths.
ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.registry import (  # noqa: E402
    MODALITY_LIPSYNC_POST,
    ModalityRunner,
    RunResult,
    _reset_for_tests,
)
from recoil.pipeline.core.runners.lipsync_post import LipSyncPostProcessor  # noqa: E402
from recoil.execution.providers import sync_so as _sync_so  # noqa: E402


# ── Fake response + 5-step protocol transport ──────────────────────────


class _FakeResponse:
    def __init__(self, body: bytes, status: int = 200):
        self._body = body
        self.status = status
        self.headers = {}

    def read(self) -> bytes:
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


def _make_5step_transport(
    *,
    job_id: str = "job_abc_123",
    output_bytes: bytes = b"FAKE_MP4_BYTES",
    duration_s: float = 4.2,
    poll_processing_count: int = 1,
):
    """Build a fake transport for the sync.so 5-step protocol.

    Simulates: upload×2, submit, poll PROCESSING×N, poll COMPLETED, download.

    Args:
        job_id: Id returned by the submit step.
        output_bytes: Body served by the final download step.
        duration_s: ``duration_s`` reported in the COMPLETED poll response.
        poll_processing_count: How many polls answer PROCESSING before the
            first COMPLETED response.

    Returns:
        A callable ``(url, *, headers, body, method="GET", timeout=60.0)``
        returning ``_FakeResponse`` objects. Every request it receives is
        appended to its ``calls`` attribute as a ``(method, url)`` tuple so
        tests can inspect the request sequence.

    Raises (from the returned callable):
        AssertionError: for any request that matches none of the steps.
    """
    calls = []
    state = {"poll_count": 0}

    def _transport(url, *, headers, body, method="GET", timeout=60.0):
        calls.append((method, url))
        # Steps 1+2: upload (/v2/upload) — return a presigned URL.
        if "/v2/upload" in url:
            return _FakeResponse(json.dumps(
                {"url": "https://cdn.sync.so/file_xyz"}
            ).encode("utf-8"))
        # Step 3: submit (POST /v2/generate) — return the job id.
        if url.endswith("/v2/generate") and method == "POST":
            return _FakeResponse(json.dumps({"id": job_id}).encode("utf-8"))
        # Steps 4+5: poll (GET /v2/generate/{job_id}) — PROCESSING for the
        # first ``poll_processing_count`` polls, COMPLETED afterwards.
        if "/v2/generate/" in url and method == "GET":
            state["poll_count"] += 1
            if state["poll_count"] <= poll_processing_count:
                return _FakeResponse(json.dumps(
                    {"status": "PROCESSING"}
                ).encode("utf-8"))
            return _FakeResponse(json.dumps({
                "status": "COMPLETED",
                "outputUrl": "https://cdn.sync.so/output_v.mp4",
                "duration_s": duration_s,
            }).encode("utf-8"))
        # Final step: download from the outputUrl.
        if "output_v.mp4" in url:
            return _FakeResponse(output_bytes)
        raise AssertionError(
            f"unexpected {method} request in fake transport: {url}"
        )

    # Expose the request log so callers can assert on the protocol sequence.
    _transport.calls = calls
    return _transport


def _make_video_and_audio(tmp_path):
    v = tmp_path / "in.mp4"
    a = tmp_path / "in.mp3"
    v.write_bytes(b"FAKE VIDEO BYTES")
    a.write_bytes(b"FAKE AUDIO BYTES")
    return v, a


@pytest.fixture(autouse=True)
def reset_registry_and_env(monkeypatch):
    """Isolate every test in this module (autouse).

    Calls the registry's ``_reset_for_tests`` before and after each test and
    sets a dummy ``SYNC_SO_API_KEY`` through ``monkeypatch`` (undone at
    teardown). The key is presumably read by the sync.so adapter — confirm.
    """
    _reset_for_tests()
    monkeypatch.setenv("SYNC_SO_API_KEY", "test-syncso-key")
    yield
    # Teardown: leave registry state clean for whatever runs next.
    _reset_for_tests()


# ── Modality / construction ────────────────────────────────────────────


def test_modality_constant():
    """The runner's modality equals the registry constant and its literal."""
    registry_value = MODALITY_LIPSYNC_POST
    assert LipSyncPostProcessor.modality == registry_value
    assert registry_value == "lipsync_post"


def test_zero_arg_construction():
    """The processor builds with no arguments and reports its modality."""
    processor = LipSyncPostProcessor()
    assert isinstance(processor, LipSyncPostProcessor)
    assert processor.modality == "lipsync_post"


def test_satisfies_modality_runner_protocol():
    """Instances satisfy ModalityRunner and expose a callable run + modality."""
    processor = LipSyncPostProcessor()
    assert isinstance(processor, ModalityRunner)
    assert hasattr(processor, "run")
    assert callable(processor.run)
    assert hasattr(processor, "modality")


# ── Happy path (5-step protocol) ───────────────────────────────────────


def test_happy_path_returns_success_runresult(tmp_path):
    """A full fake 5-step run yields a successful RunResult and output file."""
    video, audio = _make_video_and_audio(tmp_path)
    processor = LipSyncPostProcessor()
    payload = {
        "shot_id": "EP001_SH02",
        "video_path": str(video),
        "audio_path": str(audio),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
        "_transport": _make_5step_transport(),
        "poll_interval_s": 0.0,
    }
    result = processor.run(payload)

    assert isinstance(result, RunResult)
    assert result.success is True
    assert result.error is None
    assert result.modality == "lipsync_post"
    assert result.id.startswith("EP001_SH02_lipsync_post_")
    assert result.output_path is not None
    written = pathlib.Path(result.output_path)
    assert written.is_file()
    assert written.read_bytes() == b"FAKE_MP4_BYTES"


def test_happy_path_metadata_keys(tmp_path):
    """A successful run reports job, model, timing, inputs and cost in metadata."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    out = r.run({
        "shot_id": "S",
        "video_path": str(v),
        "audio_path": str(a),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
        "_transport": _make_5step_transport(job_id="J42", duration_s=3.7),
        "poll_interval_s": 0.0,
    })
    # Check success first so a failed run reports its error instead of an
    # opaque KeyError/mismatch on the metadata assertions below.
    assert out.success is True, out.error
    md = out.metadata
    assert md["final_state"] == "succeeded"
    assert md["model"] == "lipsync-2.0"
    assert md["job_id"] == "J42"
    assert md["duration_s"] == 3.7
    assert md["video_path"] == str(v)
    assert md["audio_path"] == str(a)
    assert md["sync_mode"] == "loop"
    assert md["fps"] is None
    assert "cost_usd" in md
    assert isinstance(md["cost_usd"], float)


# ── Missing required keys → failure-RunResult, NEVER raise ─────────────


def test_missing_shot_id_returns_failure(tmp_path):
    """A payload without ``shot_id`` fails validation before the adapter runs.

    The adapter is patched so that a validation regression cannot fall
    through to a real sync.so request (real input files and an API key are
    present here).
    """
    v, a = _make_video_and_audio(tmp_path)
    with patch.object(_sync_so, "lipsync_video") as adapter:
        out = LipSyncPostProcessor().run({
            "video_path": str(v), "audio_path": str(a), "model": "lipsync-2.0",
        })
    adapter.assert_not_called()
    assert out.success is False
    assert "shot_id" in out.error
    assert out.metadata["final_state"] == "failed"
    assert out.id.startswith("unknown_lipsync_post_")


def test_missing_video_path_returns_failure(tmp_path):
    """A payload without ``video_path`` fails validation before the adapter runs.

    The adapter is patched so a validation regression cannot reach sync.so.
    """
    _, a = _make_video_and_audio(tmp_path)
    with patch.object(_sync_so, "lipsync_video") as adapter:
        out = LipSyncPostProcessor().run({
            "shot_id": "S", "audio_path": str(a), "model": "lipsync-2.0",
        })
    adapter.assert_not_called()
    assert out.success is False
    assert "video_path" in out.error


def test_missing_audio_path_returns_failure(tmp_path):
    """A payload without ``audio_path`` fails validation before the adapter runs.

    The adapter is patched so a validation regression cannot reach sync.so.
    """
    v, _ = _make_video_and_audio(tmp_path)
    with patch.object(_sync_so, "lipsync_video") as adapter:
        out = LipSyncPostProcessor().run({
            "shot_id": "S", "video_path": str(v), "model": "lipsync-2.0",
        })
    adapter.assert_not_called()
    assert out.success is False
    assert "audio_path" in out.error


def test_missing_model_returns_failure(tmp_path):
    """A payload without ``model`` fails validation before the adapter runs.

    Both input files exist here, so without the patch a validation regression
    would upload them to the real sync.so API.
    """
    v, a = _make_video_and_audio(tmp_path)
    with patch.object(_sync_so, "lipsync_video") as adapter:
        out = LipSyncPostProcessor().run({
            "shot_id": "S", "video_path": str(v), "audio_path": str(a),
        })
    adapter.assert_not_called()
    assert out.success is False
    assert "model" in out.error


# ── Disk-existence handled by adapter via PayloadError ─────────────────


def test_video_path_not_on_disk_returns_failure(tmp_path):
    """A nonexistent ``video_path`` is rejected by the adapter as PayloadError.

    The fake transport keeps the test hermetic: whatever order the adapter
    checks files and uploads in, no request can reach the real API.
    """
    _, a = _make_video_and_audio(tmp_path)
    out = LipSyncPostProcessor().run({
        "shot_id": "S",
        "video_path": str(tmp_path / "does_not_exist.mp4"),
        "audio_path": str(a),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
        "_transport": _make_5step_transport(),
        "poll_interval_s": 0.0,
    })
    assert out.success is False
    assert out.metadata["error_class"] == "PayloadError"


def test_audio_path_not_on_disk_returns_failure(tmp_path):
    """A nonexistent ``audio_path`` is rejected by the adapter as PayloadError.

    The video file exists, so if the adapter uploaded it before checking the
    audio path, a missing transport would hit the real API; the fake
    transport keeps the test hermetic.
    """
    v, _ = _make_video_and_audio(tmp_path)
    out = LipSyncPostProcessor().run({
        "shot_id": "S",
        "video_path": str(v),
        "audio_path": str(tmp_path / "missing.mp3"),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
        "_transport": _make_5step_transport(),
        "poll_interval_s": 0.0,
    })
    assert out.success is False
    assert out.metadata["error_class"] == "PayloadError"


# ── Adapter error class mapping ────────────────────────────────────────


def _run_with_adapter_error(exc, tmp_path):
    """Run the processor while ``sync_so.lipsync_video`` is patched to raise *exc*.

    Returns the RunResult produced under the patch (shot id ``SX``).
    """
    video, audio = _make_video_and_audio(tmp_path)
    processor = LipSyncPostProcessor()
    payload = {
        "shot_id": "SX",
        "video_path": str(video),
        "audio_path": str(audio),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
    }
    with patch.object(_sync_so, "lipsync_video", side_effect=exc):
        result = processor.run(payload)
    return result


def test_adapter_auth_error_no_retry(tmp_path):
    """An adapter AuthError fails the run once, without a runner-level retry.

    The adapter is replaced wholesale, so this checks only that the runner
    does not call it again; adapter-internal retries are not exercised here.
    """
    v, a = _make_video_and_audio(tmp_path)
    with patch.object(
        _sync_so, "lipsync_video", side_effect=_sync_so.AuthError("401")
    ) as adapter:
        out = LipSyncPostProcessor().run({
            "shot_id": "SX",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
            "output_dir": str(tmp_path / "out"),
        })
    assert out.success is False
    assert out.metadata["error_class"] == "AuthError"
    assert adapter.call_count == 1


def test_adapter_quota_error(tmp_path):
    """A QuotaError from the adapter is reported as the failure's error_class."""
    quota_exc = _sync_so.QuotaError("402")
    result = _run_with_adapter_error(quota_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "QuotaError"


def test_adapter_job_failed_error(tmp_path):
    """A JobFailedError from the adapter is reported as the failure's error_class."""
    failed_exc = _sync_so.JobFailedError("worker died")
    result = _run_with_adapter_error(failed_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "JobFailedError"


def test_adapter_job_timeout_error(tmp_path):
    """A JobTimeoutError from the adapter is reported as the failure's error_class."""
    timeout_exc = _sync_so.JobTimeoutError("> 600s")
    result = _run_with_adapter_error(timeout_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "JobTimeoutError"


def test_adapter_payload_error(tmp_path):
    """A PayloadError from the adapter is reported as the failure's error_class."""
    payload_exc = _sync_so.PayloadError("422")
    result = _run_with_adapter_error(payload_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "PayloadError"


def test_adapter_server_error(tmp_path):
    """A ServerError from the adapter is reported as the failure's error_class."""
    server_exc = _sync_so.ServerError("500")
    result = _run_with_adapter_error(server_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "ServerError"


def test_unexpected_exception_caught_not_reraised(tmp_path):
    """A non-adapter exception is converted to a failure result, not re-raised."""
    stray_exc = ValueError("boom")
    result = _run_with_adapter_error(stray_exc, tmp_path)
    assert result.success is False
    assert result.metadata["error_class"] == "ValueError"


# ── Output dir / file_stem ─────────────────────────────────────────────


def test_default_output_dir_is_recoil_root_lipsync_outputs(tmp_path):
    """Without ``output_dir`` the adapter receives RECOIL_ROOT/_lipsync_outputs."""
    video, audio = _make_video_and_audio(tmp_path)
    processor = LipSyncPostProcessor()
    captured = {}

    def _fake_lipsync_video(**kwargs):
        # Record the adapter kwargs and hand back a minimal successful result.
        captured.update(kwargs)
        return _sync_so.LipSyncResult(
            output_path=pathlib.Path("/tmp/x.mp4"),
            duration_s=1.0,
            cost_usd=0.0,
            model=kwargs["model_id"],
            job_id="J",
            raw_metadata={},
        )

    payload = {
        "shot_id": "S",
        "video_path": str(video),
        "audio_path": str(audio),
        "model": "lipsync-2.0",
    }
    with patch.object(_sync_so, "lipsync_video", side_effect=_fake_lipsync_video):
        result = processor.run(payload)
    assert result.success is True
    assert captured["output_dir"] == RECOIL_ROOT / "_lipsync_outputs"


def test_custom_output_dir_and_file_stem(tmp_path):
    """``output_dir`` and ``file_stem`` from the payload reach the adapter."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs and return a minimal successful result.
        seen.update(kwargs)
        return _sync_so.LipSyncResult(
            # Path under tmp_path instead of a hard-coded POSIX /tmp path.
            output_path=tmp_path / "x.mp4",
            duration_s=1.0, cost_usd=0.0,
            model=kwargs["model_id"], job_id="J",
            raw_metadata={},
        )

    with patch.object(_sync_so, "lipsync_video", side_effect=_capturing):
        out = r.run({
            "shot_id": "S",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
            "output_dir": str(tmp_path / "custom"),
            "file_stem": "my_stem",
        })
    # Surface the runner's own error rather than a KeyError on `seen`.
    assert out.success is True, out.error
    assert seen["output_dir"] == pathlib.Path(str(tmp_path / "custom"))
    assert seen["file_stem"] == "my_stem"


# ── Threading: sync_mode, fps, timeout, poll, retries ──────────────────


def test_sync_mode_threaded_default_loop(tmp_path):
    """Without ``sync_mode`` the adapter gets ``"loop"`` and metadata echoes it."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs and return a minimal successful result.
        seen.update(kwargs)
        return _sync_so.LipSyncResult(
            output_path=tmp_path / "x.mp4",
            duration_s=1.0, cost_usd=0.0,
            model=kwargs["model_id"], job_id="J",
            raw_metadata={},
        )

    with patch.object(_sync_so, "lipsync_video", side_effect=_capturing):
        out = r.run({
            "shot_id": "S",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
        })
    # Surface the runner's own error rather than a KeyError on `seen`.
    assert out.success is True, out.error
    assert seen["sync_mode"] == "loop"
    assert out.metadata["sync_mode"] == "loop"


def test_sync_mode_threaded_custom(tmp_path):
    """An explicit ``sync_mode`` is forwarded to the adapter and echoed in metadata."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs and return a minimal successful result.
        seen.update(kwargs)
        return _sync_so.LipSyncResult(
            output_path=tmp_path / "x.mp4",
            duration_s=1.0, cost_usd=0.0,
            model=kwargs["model_id"], job_id="J",
            raw_metadata={},
        )

    with patch.object(_sync_so, "lipsync_video", side_effect=_capturing):
        out = r.run({
            "shot_id": "S",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
            "sync_mode": "cut_off",
        })
    # Surface the runner's own error rather than a KeyError on `seen`.
    assert out.success is True, out.error
    assert seen["sync_mode"] == "cut_off"
    assert out.metadata["sync_mode"] == "cut_off"


def test_fps_threaded(tmp_path):
    """``fps`` is forwarded to the adapter and echoed in metadata."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs and return a minimal successful result.
        seen.update(kwargs)
        return _sync_so.LipSyncResult(
            output_path=tmp_path / "x.mp4",
            duration_s=1.0, cost_usd=0.0,
            model=kwargs["model_id"], job_id="J",
            raw_metadata={},
        )

    with patch.object(_sync_so, "lipsync_video", side_effect=_capturing):
        out = r.run({
            "shot_id": "S",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
            "fps": 24,
        })
    # Surface the runner's own error rather than a KeyError on `seen`.
    assert out.success is True, out.error
    assert seen["fps"] == 24
    assert out.metadata["fps"] == 24


def test_timeout_poll_retries_threaded(tmp_path):
    """``timeout_s``, ``poll_interval_s`` and ``max_retries`` reach the adapter."""
    v, a = _make_video_and_audio(tmp_path)
    r = LipSyncPostProcessor()
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs and return a minimal successful result.
        seen.update(kwargs)
        return _sync_so.LipSyncResult(
            output_path=tmp_path / "x.mp4",
            duration_s=1.0, cost_usd=0.0,
            model=kwargs["model_id"], job_id="J",
            raw_metadata={},
        )

    with patch.object(_sync_so, "lipsync_video", side_effect=_capturing):
        out = r.run({
            "shot_id": "S",
            "video_path": str(v),
            "audio_path": str(a),
            "model": "lipsync-2.0",
            "timeout_s": 999.0,
            "poll_interval_s": 1.5,
            "max_retries": 7,
        })
    # Surface the runner's own error rather than a KeyError on `seen`.
    assert out.success is True, out.error
    assert seen["timeout_s"] == 999.0
    assert seen["poll_interval_s"] == 1.5
    assert seen["max_retries"] == 7


# ── RunResult.id format ────────────────────────────────────────────────


def test_runresult_id_format(tmp_path):
    """RunResult.id is ``<shot_id>_lipsync_post_<digits>``."""
    video, audio = _make_video_and_audio(tmp_path)
    result = LipSyncPostProcessor().run({
        "shot_id": "EP002_SH04",
        "video_path": str(video),
        "audio_path": str(audio),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / "out"),
        "_transport": _make_5step_transport(),
        "poll_interval_s": 0.0,
    })
    prefix, _, suffix = result.id.rpartition("_")
    assert prefix == "EP002_SH04_lipsync_post"
    assert suffix.isdigit()
