"""CP-8 Phase 4 — AudioRunner unit tests.

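Keeps the ElevenLabs adapter off the network in two ways. Either a fake HTTP
transport is passed via the `_transport` payload key (the runner threads it
into the adapter), or `synthesize_speech` is patched directly to capture its
kwargs or raise. Exercises happy-path, missing-key validation, each adapter
error class, output dir / file_stem overrides, voice settings, output format
and retry/timeout threading, the RunResult id format, and ModalityRunner
Protocol conformance.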
"""

import sys
import pathlib
from unittest.mock import patch

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable, RECOIL_ROOT  # noqa: E402

ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.registry import (  # noqa: E402
    MODALITY_AUDIO_T2A,
    ModalityRunner,
    RunResult,
    _reset_for_tests,
)
from recoil.pipeline.core.runners.audio_runner import AudioRunner  # noqa: E402
from recoil.execution.providers import elevenlabs as _eleven  # noqa: E402


# ── Fake transport response ────────────────────────────────────────────


class _FakeResponse:
    """Mimics urllib HTTPResponse just enough for the adapter.

    Exposes ``status``, ``headers`` and ``read()``, and can be used as a
    context manager (``with transport(...) as resp:``).

    Args:
        body: Bytes returned by ``read()``.
        status: HTTP status code exposed as ``.status`` (default 200).
        headers: Response headers. ``None`` selects a default containing a
            ``request-id`` so request-id plumbing can be asserted; any
            explicit mapping -- including an empty one -- is used verbatim.
    """

    def __init__(self, body: bytes, status: int = 200, headers=None):
        self._body = body
        self.status = status
        # `is None` instead of `or`: an explicitly empty dict must stay empty
        # so a test can simulate a response that carries no request-id.
        self.headers = (
            {"request-id": "req_test_123"} if headers is None else headers
        )

    def read(self) -> bytes:
        """Return the full canned response body."""
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any exception raised in the `with` body propagate.
        return False


def _make_happy_transport(audio_bytes: bytes = b"FAKE_MP3_BYTES"):
    """Build a fake transport that answers every request with 200 OK.

    The returned callable ignores the request it is given and always hands
    back a fresh ``_FakeResponse`` whose body is ``audio_bytes``.
    """

    def _ok_transport(url, *, headers, body, timeout):
        response = _FakeResponse(audio_bytes)
        return response

    return _ok_transport


def _make_raising_transport(exc):
    """Build a fake transport that raises ``exc`` on every call.

    Lets a test make the transport layer itself fail beneath the adapter.
    """

    def _failing_transport(url, *, headers, body, timeout):
        raise exc

    return _failing_transport


@pytest.fixture(autouse=True)
def reset_registry_and_env(monkeypatch):
    """Give every test a clean registry and a dummy ElevenLabs API key.

    The registry is reset both before and after each test. The env var is
    set via ``monkeypatch`` so pytest restores its prior value on teardown.
    """
    _reset_for_tests()
    monkeypatch.setenv("ELEVENLABS_API_KEY", "test-key-abc")
    yield None
    # Teardown: leave no registry state behind for the next test.
    _reset_for_tests()


# ── Modality / construction contract ───────────────────────────────────


def test_modality_constant():
    """AudioRunner.modality equals the registry constant, which is 'audio_t2a'."""
    assert AudioRunner.modality == MODALITY_AUDIO_T2A
    assert MODALITY_AUDIO_T2A == "audio_t2a"


def test_zero_arg_construction():
    """AudioRunner() needs no constructor args and exposes its modality."""
    runner = AudioRunner()
    assert isinstance(runner, AudioRunner)
    assert runner.modality == "audio_t2a"


def test_satisfies_modality_runner_protocol():
    """An AudioRunner instance passes the runtime ModalityRunner check."""
    runner = AudioRunner()
    assert isinstance(runner, ModalityRunner)
    run_attr = getattr(runner, "run", None)
    assert run_attr is not None and callable(run_attr)
    assert hasattr(runner, "modality")


# ── Happy path ─────────────────────────────────────────────────────────


def test_happy_path_returns_success_runresult(tmp_path):
    """A 200 OK transport yields a successful RunResult with audio on disk."""
    payload = {
        "shot_id": "EP001_SH02",
        "text": "Hello world",
        "voice_id": "voice_xyz",
        "model": "eleven_multilingual_v2",
        "output_dir": str(tmp_path),
        "_transport": _make_happy_transport(),
    }
    result = AudioRunner().run(payload)

    assert isinstance(result, RunResult)
    assert result.success is True
    assert result.error is None
    assert result.modality == "audio_t2a"
    assert result.id.startswith("EP001_SH02_audio_t2a_")
    assert result.output_path is not None

    written = pathlib.Path(result.output_path)
    assert written.is_file()
    assert written.read_bytes() == b"FAKE_MP3_BYTES"


def test_happy_path_metadata_keys(tmp_path):
    """A successful run surfaces adapter details in RunResult.metadata."""
    text = "Hello world"
    result = AudioRunner().run({
        "shot_id": "EP001_SH02",
        "text": text,
        "voice_id": "voice_xyz",
        "model": "eleven_multilingual_v2",
        "output_dir": str(tmp_path),
        "_transport": _make_happy_transport(),
    })
    metadata = result.metadata

    expected_values = {
        "final_state": "succeeded",
        "model": "eleven_multilingual_v2",
        "voice_id": "voice_xyz",
        "request_id": "req_test_123",
        "char_count": len(text),
        "output_format": "mp3",
    }
    for key, wanted in expected_values.items():
        assert metadata[key] == wanted, key

    for key in ("duration_s", "cost_usd"):
        assert key in metadata, key
    assert isinstance(metadata["cost_usd"], float)


def test_happy_path_cost_usd_positive_when_profile_registered(tmp_path):
    """cost_usd > 0 when a model profile + cost_per_1k_chars is registered.

    ``_compute_cost`` is stubbed to a fixed positive value, so the outcome
    does not depend on the on-disk model_profiles.json content.
    """
    payload = {
        "shot_id": "EP001_SH02",
        "text": "Hello world",
        "voice_id": "voice_xyz",
        "model": "eleven_multilingual_v2",
        "output_dir": str(tmp_path),
        "_transport": _make_happy_transport(),
    }
    runner = AudioRunner()
    with patch.object(_eleven, "_compute_cost", return_value=0.025):
        result = runner.run(payload)

    assert result.success is True
    assert result.metadata["cost_usd"] > 0


# ── Missing-key validation (failure-RunResult, NEVER raise) ────────────


def test_missing_shot_id_returns_failure():
    """No shot_id: failure result (not an exception) with an 'unknown' id prefix."""
    payload = dict(
        text="x",
        voice_id="v",
        model="m",
        _transport=_make_happy_transport(),
    )
    result = AudioRunner().run(payload)

    assert result.success is False
    assert "shot_id" in result.error
    assert result.metadata["final_state"] == "failed"
    assert result.id.startswith("unknown_audio_t2a_")


def test_missing_text_returns_failure():
    """No text: failure result whose error names the missing key."""
    payload = dict(
        shot_id="S1",
        voice_id="v",
        model="m",
        _transport=_make_happy_transport(),
    )
    result = AudioRunner().run(payload)

    assert result.success is False
    assert "text" in result.error


def test_missing_voice_id_returns_failure():
    """No voice_id: failure result whose error names the missing key."""
    payload = dict(
        shot_id="S1",
        text="x",
        model="m",
        _transport=_make_happy_transport(),
    )
    result = AudioRunner().run(payload)

    assert result.success is False
    assert "voice_id" in result.error


def test_missing_model_returns_failure():
    """No model: failure result whose error names the missing key."""
    payload = dict(
        shot_id="S1",
        text="x",
        voice_id="v",
        _transport=_make_happy_transport(),
    )
    result = AudioRunner().run(payload)

    assert result.success is False
    assert "model" in result.error


def test_empty_payload_mentions_all_missing_keys():
    """An empty payload fails, names every required key, and zeroes metadata."""
    result = AudioRunner().run({})
    assert result.success is False

    required = ("shot_id", "text", "voice_id", "model")
    unmentioned = [key for key in required if key not in result.error]
    assert not unmentioned, unmentioned

    expected_metadata = dict(
        final_state="failed",
        cost_usd=0.0,
        model=None,
        voice_id=None,
        request_id=None,
        char_count=0,
    )
    assert result.metadata == expected_metadata


# ── Adapter error mapping → failure-RunResult.error_class ──────────────


def _run_with_adapter_error(exc, tmp_path):
    """Run an AudioRunner with ``synthesize_speech`` patched to raise ``exc``.

    Returns whatever ``run()`` returns. The tests built on this helper expect
    the runner to convert the exception into a failure RunResult.
    """
    payload = {
        "shot_id": "SX",
        "text": "hi",
        "voice_id": "v",
        "model": "m",
        "output_dir": str(tmp_path),
    }
    runner = AudioRunner()
    with patch.object(_eleven, "synthesize_speech", side_effect=exc):
        result = runner.run(payload)
    return result


def test_adapter_auth_error_maps_to_failure(tmp_path):
    """An adapter AuthError becomes a failure tagged with its class name."""
    result = _run_with_adapter_error(_eleven.AuthError("401"), tmp_path)

    assert result.success is False
    assert result.metadata["error_class"] == "AuthError"
    assert "AuthError" in result.error


def test_adapter_quota_error_maps_to_failure(tmp_path):
    """An adapter QuotaError becomes a failure tagged with its class name."""
    out = _run_with_adapter_error(_eleven.QuotaError("402"), tmp_path)
    assert out.success is False
    assert out.metadata["error_class"] == "QuotaError"
    # Match the auth/unexpected-error tests: the class must also be visible
    # in the human-readable error string.
    assert "QuotaError" in out.error

def test_adapter_payload_error_maps_to_failure(tmp_path):
    """An adapter PayloadError becomes a failure tagged with its class name."""
    out = _run_with_adapter_error(_eleven.PayloadError("422"), tmp_path)
    assert out.success is False
    assert out.metadata["error_class"] == "PayloadError"
    # Match the auth/unexpected-error tests: the class must also be visible
    # in the human-readable error string.
    assert "PayloadError" in out.error

def test_adapter_server_error_maps_to_failure(tmp_path):
    """An adapter ServerError becomes a failure tagged with its class name."""
    out = _run_with_adapter_error(_eleven.ServerError("500x3"), tmp_path)
    assert out.success is False
    assert out.metadata["error_class"] == "ServerError"
    # Match the auth/unexpected-error tests: the class must also be visible
    # in the human-readable error string.
    assert "ServerError" in out.error

def test_adapter_network_error_maps_to_failure(tmp_path):
    """An adapter NetworkError becomes a failure tagged with its class name."""
    out = _run_with_adapter_error(_eleven.NetworkError("timeout"), tmp_path)
    assert out.success is False
    assert out.metadata["error_class"] == "NetworkError"
    # Match the auth/unexpected-error tests: the class must also be visible
    # in the human-readable error string.
    assert "NetworkError" in out.error

def test_unexpected_exception_caught_not_reraised(tmp_path):
    """A bare ValueError from the adapter must NOT propagate out of run()."""
    result = _run_with_adapter_error(ValueError("boom"), tmp_path)

    assert result.success is False
    assert result.metadata["error_class"] == "ValueError"
    assert "ValueError" in result.error


# ── Output dir / file_stem ─────────────────────────────────────────────


def test_default_output_dir_is_recoil_root_audio_outputs():
    """Without output_dir, the adapter receives $RECOIL_ROOT/_audio_outputs."""
    captured = {}

    def _fake_synthesize(**kwargs):
        # Record what the runner passed and return a canned result instead
        # of performing real synthesis.
        captured.update(kwargs)
        return _eleven.SynthesisResult(
            output_path=pathlib.Path("/tmp/x.mp3"),
            duration_s=None,
            cost_usd=0.0,
            model=kwargs["model_id"],
            voice_id=kwargs["voice_id"],
            request_id="rq",
            raw_metadata={"char_count": len(kwargs["text"])},
        )

    payload = {"shot_id": "SH", "text": "hello", "voice_id": "v", "model": "m"}
    runner = AudioRunner()
    with patch.object(_eleven, "synthesize_speech", side_effect=_fake_synthesize):
        result = runner.run(payload)

    assert result.success is True
    assert captured["output_dir"] == RECOIL_ROOT / "_audio_outputs"


def test_custom_output_dir_and_file_stem_honored(tmp_path):
    """Caller-supplied output_dir and file_stem reach the adapter unchanged."""
    custom_dir = tmp_path / "custom"
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs; return a canned result instead of
        # performing real synthesis.
        seen.update(kwargs)
        return _eleven.SynthesisResult(
            output_path=pathlib.Path("/tmp/y.mp3"),
            duration_s=None, cost_usd=0.0,
            model=kwargs["model_id"], voice_id=kwargs["voice_id"],
            request_id=None, raw_metadata={"char_count": len(kwargs["text"])},
        )

    with patch.object(_eleven, "synthesize_speech", side_effect=_capturing):
        out = AudioRunner().run({
            "shot_id": "SH",
            "text": "hi",
            "voice_id": "v",
            "model": "m",
            "output_dir": str(custom_dir),
            "file_stem": "my_stem",
        })
    # Check the run outcome first: a runner failure would otherwise surface
    # only as a confusing KeyError/mismatch on the captured kwargs.
    assert out.success is True
    assert seen["output_dir"] == custom_dir
    assert seen["file_stem"] == "my_stem"


# ── voice_settings + output_format threading ───────────────────────────


def test_voice_settings_threaded_to_adapter(tmp_path):
    """Flat voice-setting payload keys are grouped into adapter voice_settings."""
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs; return a canned result instead of
        # performing real synthesis.
        seen.update(kwargs)
        return _eleven.SynthesisResult(
            output_path=pathlib.Path("/tmp/z.mp3"),
            duration_s=None, cost_usd=0.0,
            model=kwargs["model_id"], voice_id=kwargs["voice_id"],
            request_id=None, raw_metadata={"char_count": len(kwargs["text"])},
        )

    with patch.object(_eleven, "synthesize_speech", side_effect=_capturing):
        out = AudioRunner().run({
            "shot_id": "SH",
            "text": "hi",
            "voice_id": "v",
            "model": "m",
            "output_dir": str(tmp_path),
            "stability": 0.7,
            "similarity_boost": 0.9,
            "style": 0.2,
            "use_speaker_boost": False,
        })
    # The run itself must succeed; otherwise the kwargs check is meaningless.
    assert out.success is True
    assert seen["voice_settings"] == {
        "stability": 0.7, "similarity_boost": 0.9,
        "style": 0.2, "use_speaker_boost": False,
    }


def test_output_format_wav_threaded(tmp_path):
    """output_format='wav' reaches the adapter and is echoed in metadata."""
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs; return a canned result instead of
        # performing real synthesis.
        seen.update(kwargs)
        return _eleven.SynthesisResult(
            output_path=pathlib.Path("/tmp/z.wav"),
            duration_s=None, cost_usd=0.0,
            model=kwargs["model_id"], voice_id=kwargs["voice_id"],
            request_id=None, raw_metadata={"char_count": len(kwargs["text"])},
        )

    with patch.object(_eleven, "synthesize_speech", side_effect=_capturing):
        out = AudioRunner().run({
            "shot_id": "SH",
            "text": "hi",
            "voice_id": "v",
            "model": "m",
            "output_dir": str(tmp_path),
            "output_format": "wav",
        })
    # Metadata from a failed run would be a different shape; pin success first.
    assert out.success is True
    assert seen["output_format"] == "wav"
    assert out.metadata["output_format"] == "wav"


def test_output_format_wav_writes_wav_extension_via_adapter(tmp_path):
    """End-to-end through the real adapter: output_format=wav writes .wav file."""
    fake_transport = _make_happy_transport(audio_bytes=b"RIFF....")
    payload = dict(
        shot_id="SH",
        text="hi",
        voice_id="v",
        model="eleven_multilingual_v2",
        output_dir=str(tmp_path),
        output_format="wav",
        _transport=fake_transport,
    )
    result = AudioRunner().run(payload)

    assert result.success is True
    assert result.output_path.endswith(".wav")
    assert pathlib.Path(result.output_path).exists()


def test_max_retries_and_timeout_threaded(tmp_path):
    """max_retries and timeout_s payload keys reach the adapter unchanged."""
    seen = {}

    def _capturing(**kwargs):
        # Record the adapter kwargs; return a canned result instead of
        # performing real synthesis.
        seen.update(kwargs)
        return _eleven.SynthesisResult(
            output_path=pathlib.Path("/tmp/x.mp3"),
            duration_s=None, cost_usd=0.0,
            model=kwargs["model_id"], voice_id=kwargs["voice_id"],
            request_id=None, raw_metadata={"char_count": len(kwargs["text"])},
        )

    with patch.object(_eleven, "synthesize_speech", side_effect=_capturing):
        out = AudioRunner().run({
            "shot_id": "SH",
            "text": "hi",
            "voice_id": "v",
            "model": "m",
            "output_dir": str(tmp_path),
            "max_retries": 7,
            "timeout_s": 12.5,
        })
    # The run itself must succeed; otherwise the kwargs check is meaningless.
    assert out.success is True
    assert seen["max_retries"] == 7
    assert seen["timeout_s"] == 12.5


# ── RunResult.id format ────────────────────────────────────────────────


def test_runresult_id_format(tmp_path):
    """A successful run's id is ``{shot_id}_audio_t2a_{digits}``."""
    out = AudioRunner().run({
        "shot_id": "EP002_SH04",
        "text": "hi",
        "voice_id": "v",
        "model": "eleven_multilingual_v2",
        "output_dir": str(tmp_path),
        "_transport": _make_happy_transport(),
    })
    # Pin success so this test checks the id of a *successful* run, not a
    # failure result that happens to share the format.
    assert out.success is True
    # Format: {shot_id}_audio_t2a_{timestamp}
    prefix, sep, suffix = out.id.rpartition("_")
    assert sep == "_"
    assert prefix == "EP002_SH04_audio_t2a"
    assert suffix.isdigit()
