"""Tests for execution/providers/elevenlabs.py (CP-8 Phase 2).

All transport is mocked via the `transport=` kwarg. NO live API calls.
"""

from __future__ import annotations

import json
import urllib.error
from typing import Optional

import pytest

from recoil.execution.providers import elevenlabs as eleven
from recoil.execution.providers.elevenlabs import (
    synthesize_speech,
    SynthesisResult,
    AuthError,
    QuotaError,
    PayloadError,
    RateLimitError,
    ServerError,
    NetworkError,
)


# --------------------------------------------------------------------------
# Mock transport plumbing
# --------------------------------------------------------------------------

class _MockResponse:
    """Mimics the urlopen-returned context-managed response object."""

    def __init__(self, *, status: int = 200, body: bytes = b"",
                 headers: Optional[dict] = None):
        self.status = status
        self._body = body
        self.headers = headers or {}

    def read(self) -> bytes:
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


class _MockTransport:
    """Records calls; returns queued responses or raises queued exceptions."""

    def __init__(self, *responses):
        # Each entry is either a _MockResponse or an Exception instance.
        self._responses = list(responses)
        self.calls: list[dict] = []
        self.call_count = 0

    def __call__(self, url, *, headers, body, timeout):
        self.call_count += 1
        self.calls.append({
            "url": url,
            "headers": dict(headers),
            "body": body,
            "timeout": timeout,
        })
        if not self._responses:
            raise AssertionError("MockTransport ran out of responses")
        nxt = self._responses.pop(0)
        if isinstance(nxt, BaseException):
            raise nxt
        return nxt


@pytest.fixture(autouse=True)
def _set_api_key(monkeypatch):
    """Give every test a dummy ElevenLabs API key through the environment."""
    env_var, fake_key = "ELEVENLABS_API_KEY", "test-key-123"
    monkeypatch.setenv(env_var, fake_key)


@pytest.fixture(autouse=True)
def _no_sleep(monkeypatch):
    """Make backoff waits instant so the suite runs fast.

    Patches ``sleep`` on the ``time`` module the provider references; the
    monkeypatch is undone when each test finishes.
    """
    def _skip_sleep(*_args):
        return None

    monkeypatch.setattr(eleven.time, "sleep", _skip_sleep)


# --------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------

def test_happy_path_returns_synthesis_result_and_writes_file(tmp_path):
    """A single 200 response yields a populated SynthesisResult and an audio file."""
    payload = b"FAKE_MP3_BYTES"
    reply_headers = {
        "request-id": "req-abc-123",
        "content-length": str(len(payload)),
    }
    transport = _MockTransport(
        _MockResponse(status=200, body=payload, headers=reply_headers)
    )

    result = synthesize_speech(
        text="hello world",
        voice_id="voice_x",
        output_dir=tmp_path,
        file_stem="utt_001",
        transport=transport,
    )

    expected_path = tmp_path / "utt_001.mp3"
    assert isinstance(result, SynthesisResult)
    assert result.output_path == expected_path
    assert expected_path.read_bytes() == payload
    assert result.voice_id == "voice_x"
    # Default model when the caller does not pass model_id.
    assert result.model == "eleven_multilingual_v2"
    assert result.request_id == "req-abc-123"
    assert transport.call_count == 1


def test_missing_api_key_raises_auth_error_before_http(tmp_path, monkeypatch):
    """Without an API key in the environment, AuthError is raised before any request."""
    monkeypatch.delenv("ELEVENLABS_API_KEY", raising=False)
    transport = _MockTransport()  # empty script: any call would be a failure

    request = dict(
        text="hello",
        voice_id="voice_x",
        output_dir=tmp_path,
        file_stem="utt",
        transport=transport,
    )
    with pytest.raises(AuthError):
        synthesize_speech(**request)

    assert not transport.call_count


def test_empty_text_raises_payload_error_before_http(tmp_path):
    """An empty ``text`` is rejected locally, without touching the transport."""
    transport = _MockTransport()
    with pytest.raises(PayloadError):
        synthesize_speech(text="", voice_id="voice_x", output_dir=tmp_path,
                          file_stem="utt", transport=transport)
    assert transport.calls == []


def test_empty_voice_id_raises_payload_error_before_http(tmp_path):
    """An empty ``voice_id`` is a local validation error; no request is made."""
    transport = _MockTransport()
    kwargs = {
        "text": "hi",
        "voice_id": "",
        "output_dir": tmp_path,
        "file_stem": "utt",
        "transport": transport,
    }
    with pytest.raises(PayloadError):
        synthesize_speech(**kwargs)
    assert transport.call_count == 0


def test_http_401_fails_fast_no_retry(tmp_path):
    """HTTP 401 surfaces as AuthError after exactly one attempt."""
    unauthorized = urllib.error.HTTPError(
        "https://api.elevenlabs.io/", 401, "unauth", None, None,
    )
    transport = _MockTransport(unauthorized)

    with pytest.raises(AuthError):
        synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                          file_stem="u", transport=transport)

    assert transport.call_count == 1, "401 must not be retried"


def test_http_402_fails_fast_no_retry(tmp_path):
    """HTTP 402 (quota) surfaces as QuotaError without retrying."""
    transport = _MockTransport(
        urllib.error.HTTPError(
            "https://api.elevenlabs.io/", 402, "quota", None, None,
        )
    )
    with pytest.raises(QuotaError):
        synthesize_speech(
            text="hi",
            voice_id="v",
            output_dir=tmp_path,
            file_stem="u",
            transport=transport,
        )
    assert len(transport.calls) == 1


def test_http_422_fails_fast_no_retry(tmp_path):
    """HTTP 422 (rejected payload) maps to PayloadError and is attempted only once."""
    rejected = urllib.error.HTTPError(
        url="https://api.elevenlabs.io/", code=422, msg="bad payload",
        hdrs=None, fp=None,
    )
    transport = _MockTransport(rejected)
    call = dict(text="hi", voice_id="v", output_dir=tmp_path,
                file_stem="u", transport=transport)

    with pytest.raises(PayloadError):
        synthesize_speech(**call)

    assert transport.call_count == 1


def test_http_429_fails_fast_no_retry(tmp_path):
    """HTTP 429 raises RateLimitError immediately; the provider does not retry it."""
    api_url = "https://api.elevenlabs.io/"
    transport = _MockTransport(
        urllib.error.HTTPError(api_url, 429, "rate", None, None)
    )
    with pytest.raises(RateLimitError):
        synthesize_speech(
            text="hi", voice_id="v", output_dir=tmp_path,
            file_stem="u", transport=transport,
        )
    assert transport.call_count == 1, "429 must not be retried"


def test_http_500_retried_then_succeeds(tmp_path):
    """Three HTTP 500s followed by a 200: the provider retries and writes the audio."""

    def server_error():
        # Build a fresh instance per scripted failure rather than raising one
        # shared exception object repeatedly.
        return urllib.error.HTTPError(
            url="x", code=500, msg="srv", hdrs=None, fp=None,
        )

    audio = b"OK"
    transport = _MockTransport(
        server_error(), server_error(), server_error(),
        _MockResponse(status=200, body=audio, headers={}),
    )
    result = synthesize_speech(
        text="hi", voice_id="v", output_dir=tmp_path,
        file_stem="u", transport=transport,
    )
    # Three failed attempts plus the successful fourth.
    assert transport.call_count == 4
    assert result.output_path.read_bytes() == audio


def test_http_500_exhausted_raises_server_error(tmp_path):
    """Four consecutive HTTP 500s end in ServerError after four attempts."""

    def server_error():
        # Build a fresh instance per scripted failure rather than raising one
        # shared exception object repeatedly.
        return urllib.error.HTTPError(
            url="x", code=500, msg="srv", hdrs=None, fp=None,
        )

    transport = _MockTransport(*(server_error() for _ in range(4)))
    with pytest.raises(ServerError):
        synthesize_speech(
            text="hi", voice_id="v", output_dir=tmp_path,
            file_stem="u", transport=transport,
        )
    assert transport.call_count == 4


def test_network_blip_url_error_retried_then_succeeds(tmp_path):
    """A single URLError is retried and the follow-up 200 response is written."""
    expected_audio = b"DATA"
    script = [
        urllib.error.URLError("conn reset"),
        _MockResponse(status=200, body=expected_audio, headers={}),
    ]
    transport = _MockTransport(*script)

    outcome = synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                                file_stem="u", transport=transport)

    assert transport.call_count == len(script)
    assert outcome.output_path.read_bytes() == expected_audio


def test_network_exhausted_raises_network_error(tmp_path):
    """Four consecutive URLErrors end in NetworkError after four attempts."""
    failures = [urllib.error.URLError(f"e{n}") for n in range(1, 5)]
    transport = _MockTransport(*failures)
    with pytest.raises(NetworkError):
        synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                          file_stem="u", transport=transport)
    assert transport.call_count == 4


def test_voice_settings_default_merged(tmp_path):
    """With no override, the request body carries the default voice_settings."""
    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                      file_stem="u", transport=transport)

    raw_body = transport.calls[0]["body"]
    settings = json.loads(raw_body.decode("utf-8"))["voice_settings"]
    expected_numeric = {"stability": 0.5, "similarity_boost": 0.75, "style": 0.0}
    for key, value in expected_numeric.items():
        assert settings[key] == value, key
    assert settings["use_speaker_boost"] is True


def test_voice_settings_override_applied(tmp_path):
    """Caller-supplied voice_settings replace only the keys they name."""
    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    synthesize_speech(
        text="hi",
        voice_id="v",
        output_dir=tmp_path,
        file_stem="u",
        voice_settings={"stability": 0.9, "similarity_boost": 0.1},
        transport=transport,
    )
    body = json.loads(transport.calls[0]["body"].decode("utf-8"))
    sent = body["voice_settings"]
    # Overridden keys take the caller's values...
    assert (sent["stability"], sent["similarity_boost"]) == (0.9, 0.1)
    # ...while untouched defaults survive the merge.
    assert sent["style"] == 0.0
    assert sent["use_speaker_boost"] is True


def test_output_format_mp3_yields_mp3_extension(tmp_path):
    """``output_format="mp3"`` produces a file with the ``.mp3`` suffix."""
    result = synthesize_speech(
        text="hi", voice_id="v", output_dir=tmp_path, file_stem="u",
        output_format="mp3",
        transport=_MockTransport(_MockResponse(status=200, body=b"x", headers={})),
    )
    assert result.output_path.suffix == ".mp3"


def test_output_format_wav_yields_wav_extension(tmp_path):
    """``output_format="wav"`` produces a file with the ``.wav`` suffix."""
    ok_reply = _MockResponse(status=200, body=b"x", headers={})
    transport = _MockTransport(ok_reply)
    result = synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                               file_stem="u", output_format="wav",
                               transport=transport)
    suffix = result.output_path.suffix
    assert suffix == ".wav"


def test_request_id_parsed_from_request_id_header(tmp_path):
    """The ``request-id`` response header becomes ``SynthesisResult.request_id``."""
    reply = _MockResponse(status=200, body=b"x", headers={"request-id": "rid-1"})
    result = synthesize_speech(text="hi", voice_id="v", output_dir=tmp_path,
                               file_stem="u", transport=_MockTransport(reply))
    assert result.request_id == "rid-1"


def test_request_id_parsed_from_x_request_id_header(tmp_path):
    """The ``x-request-id`` response header also populates ``request_id``."""
    headers = {"x-request-id": "xrid-2"}
    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers=headers))
    outcome = synthesize_speech(
        text="hi",
        voice_id="v",
        output_dir=tmp_path,
        file_stem="u",
        transport=transport,
    )
    assert outcome.request_id == "xrid-2"


def test_cost_compute_uses_model_profile(tmp_path, monkeypatch):
    """Cost is derived from the model profile's ``cost_per_1k_chars``.

    ``_compute_cost`` imports ``recoil.core.model_profiles`` lazily, so a fake
    module is installed for the duration of this test.
    """
    import sys
    import types

    requested_models = []

    def fake_get_profile(model_id):
        # Record rather than assert here: an AssertionError raised inside the
        # provider's cost lookup could be swallowed by its fallback handling
        # and show up only as a confusing cost mismatch.
        requested_models.append(model_id)
        return {"cost_per_1k_chars": 0.30}

    fake_module = types.ModuleType("recoil.core.model_profiles")
    fake_module.get_profile = fake_get_profile
    monkeypatch.setitem(sys.modules, "recoil.core.model_profiles", fake_module)
    # `from recoil.core import model_profiles` resolves through the parent
    # package attribute when the real submodule is already loaded, bypassing
    # sys.modules; patch that attribute too so every import form sees the fake.
    parent_pkg = sys.modules.get("recoil.core")
    if parent_pkg is not None:
        monkeypatch.setattr(parent_pkg, "model_profiles", fake_module, raising=False)

    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    text = "a" * 1000  # exactly 1k chars -> cost = 0.30
    result = synthesize_speech(
        text=text, voice_id="v", output_dir=tmp_path,
        file_stem="u", transport=transport,
    )
    # Profile consulted, and only for the default model.
    assert set(requested_models) == {"eleven_multilingual_v2"}
    assert result.cost_usd == pytest.approx(0.30)


def test_cost_compute_falls_back_to_zero_when_unregistered(tmp_path, monkeypatch):
    """An unregistered model (``get_profile`` raises KeyError) costs 0.0."""
    import sys
    import types

    requested_models = []

    def fake_get_profile(model_id):
        requested_models.append(model_id)
        raise KeyError(model_id)

    fake_module = types.ModuleType("recoil.core.model_profiles")
    fake_module.get_profile = fake_get_profile
    monkeypatch.setitem(sys.modules, "recoil.core.model_profiles", fake_module)
    # `from recoil.core import model_profiles` resolves through the parent
    # package attribute when the real submodule is already loaded, bypassing
    # sys.modules; patch that attribute too so every import form sees the fake.
    parent_pkg = sys.modules.get("recoil.core")
    if parent_pkg is not None:
        monkeypatch.setattr(parent_pkg, "model_profiles", fake_module, raising=False)

    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    result = synthesize_speech(
        text="hello", voice_id="v", output_dir=tmp_path,
        file_stem="u", transport=transport,
    )
    # Guard against a vacuous pass: the fallback must actually be exercised.
    assert requested_models, "get_profile was never consulted"
    assert result.cost_usd == 0.0


def test_output_dir_created_if_missing(tmp_path):
    """A non-existent, multi-level output_dir is created before the file is written."""
    target_dir = tmp_path.joinpath("deep", "nested", "out")
    assert not target_dir.exists()

    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    result = synthesize_speech(text="hi", voice_id="v", output_dir=target_dir,
                               file_stem="u", transport=transport)

    assert target_dir.is_dir()
    assert result.output_path.parent == target_dir


def test_request_body_json_shape(tmp_path):
    """The JSON body carries text, model_id and voice_settings; the URL names the voice."""
    transport = _MockTransport(_MockResponse(status=200, body=b"x", headers={}))
    synthesize_speech(
        text="hello there",
        voice_id="vid_42",
        model_id="eleven_turbo_v2",
        output_dir=tmp_path,
        file_stem="u",
        transport=transport,
    )
    first_call = transport.calls[0]
    payload = json.loads(first_call["body"].decode("utf-8"))
    assert payload["text"] == "hello there"
    assert payload["model_id"] == "eleven_turbo_v2"
    assert "voice_settings" in payload
    # The request URL must embed the voice id.
    assert "vid_42" in first_call["url"]
