"""sync.so lipsync adapter (CP-8).

Two-step protocol: upload local files → submit /v2/generate → poll /v2/generate/{id}
→ download output_url. Synchronous wrapper over urlopen; retries 5xx + network
on each step. Fail-fast on 401/402/422/429.

Public surface:
    lipsync_video(...) -> LipSyncResult
    LipSyncResult dataclass
    LipSyncError + subclasses (AuthError, QuotaError, PayloadError,
        RateLimitError, ServerError, NetworkError, JobFailedError, JobTimeoutError)
"""

from __future__ import annotations

import json
import logging
import mimetypes
import os
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional

# Phase C: canonical pattern set (re-aliased to preserve import surface).
from recoil.pipeline.core.failure_mode import TRANSIENT_HTTP_CODES as RETRYABLE_HTTP

logger = logging.getLogger(__name__)

# --- sync.so v2 endpoints ---------------------------------------------------
GENERATE_URL = "https://api.sync.so/v2/generate"  # submit; poll at /{job_id}
UPLOAD_URL = "https://api.sync.so/v2/upload"      # multipart file upload

# Environment variable the API key is read from unless the caller overrides it.
DEFAULT_AUTH_ENV_VAR = "SYNC_SO_API_KEY"

# --- Timing defaults (seconds) ------------------------------------------------
DEFAULT_PER_CALL_TIMEOUT_S = 60.0  # timeout for a single HTTP request
DEFAULT_POLL_INTERVAL_S = 5.0      # sleep between job-status polls
DEFAULT_GEN_TIMEOUT_S = 600.0      # budget for the whole polling phase


class LipSyncError(RuntimeError):
    """Base class for every error raised by this adapter.

    Also raised directly for HTTP statuses that map to no more specific class.
    """


class AuthError(LipSyncError):
    """API key missing from the environment, or HTTP 401 from sync.so."""


class QuotaError(LipSyncError):
    """HTTP 402 from sync.so."""


class PayloadError(LipSyncError):
    """Bad input or response payload.

    Raised for a missing local input file, HTTP 422, or a response that lacks
    an expected field (upload URL, job id, output URL).
    """


class RateLimitError(LipSyncError):
    """HTTP 429 from sync.so (fail-fast, not retried)."""


class ServerError(LipSyncError):
    """Retryable HTTP status (one of ``RETRYABLE_HTTP``); surfaces after retries."""


class NetworkError(LipSyncError):
    """Transport failure (URL error, timeout, connection error); surfaces after retries."""


class JobFailedError(LipSyncError):
    """The generation job reported a failure status."""


class JobTimeoutError(LipSyncError):
    """The job did not reach a terminal status within the polling budget."""


@dataclass
class LipSyncResult:
    """Result record returned by ``lipsync_video`` after a successful run.

    Field order defines the positional constructor; keep it stable.
    """

    # Local file the generated video was written to.
    output_path: Path
    # Duration reported in the final job status, or None when absent.
    duration_s: Optional[float]
    # Estimated cost from the model profile rate; 0.0 when unavailable.
    cost_usd: float
    # Model id the job was submitted with.
    model: str
    # sync.so job identifier, when known.
    job_id: Optional[str]
    # Free-form diagnostic metadata; each instance gets its own fresh dict.
    raw_metadata: dict = field(default_factory=dict)


def _default_transport(url: str, *, headers: dict, body: Optional[bytes],
                       method: str = "GET", timeout: float = DEFAULT_PER_CALL_TIMEOUT_S):
    """Issue one HTTP request via the standard library.

    Returns the ``urllib.request.urlopen`` response object, which callers use
    as a context manager (reading ``status`` and ``read()``). Callers can swap
    this out through the ``transport`` parameter of ``lipsync_video``.
    """
    request = urllib.request.Request(
        url,
        data=body,
        headers=headers,
        method=method,
    )
    response = urllib.request.urlopen(request, timeout=timeout)
    return response


def _classify_http(code: int, body_bytes: bytes) -> LipSyncError:
    """Map an HTTP status code to the matching ``LipSyncError`` subclass.

    The exception is *returned*, not raised, so the caller decides whether to
    raise it. Up to 512 characters of the decoded response body are appended
    to the message (undecodable bytes are replaced).
    """
    detail = body_bytes.decode("utf-8", errors="replace")[:512]
    # Statuses with a dedicated fail-fast exception type.
    fail_fast = {
        401: (AuthError, "sync.so 401"),
        402: (QuotaError, "sync.so 402"),
        422: (PayloadError, "sync.so 422"),
        429: (RateLimitError, "sync.so 429"),
    }
    entry = fail_fast.get(code)
    if entry is not None:
        exc_type, prefix = entry
        return exc_type(f"{prefix}: {detail}")
    exc_type = ServerError if code in RETRYABLE_HTTP else LipSyncError
    return exc_type(f"sync.so {code}: {detail}")


def _do_with_retries(fn, *, max_retries: int = 3) -> Any:
    """Call ``fn()`` and return its result, retrying transient failures.

    Retried, sleeping 1s, 2s, 4s, ... between attempts:
      * ``ServerError`` / ``NetworkError`` raised by ``fn`` itself;
      * ``urllib.error.HTTPError`` whose status classifies as ``ServerError``
        (i.e. is in ``RETRYABLE_HTTP``);
      * ``urllib.error.URLError``, ``TimeoutError`` and ``ConnectionError``
        (wrapped in ``NetworkError``).

    Everything else propagates on first occurrence: the fail-fast 4xx types
    (Auth/Quota/Payload/RateLimit), any other non-retryable HTTP status
    (a plain ``LipSyncError``), and any exception type not listed above.

    Args:
        fn: zero-argument callable performing one attempt.
        max_retries: retries after the first attempt. Negative values are
            treated as 0, so ``fn`` always runs at least once.

    Returns:
        The value returned by the first successful ``fn()`` call.

    Raises:
        LipSyncError: a non-retryable failure, or the last transient error
            once retries are exhausted (original exception kept as __cause__).
    """
    attempts = max(0, max_retries) + 1
    last: Optional[LipSyncError] = None
    for attempt in range(attempts):
        try:
            return fn()
        except (ServerError, NetworkError) as e:
            last = e
        except urllib.error.HTTPError as e:
            try:
                body = e.read() or b""
            except Exception:  # best effort: the body only feeds the message
                body = b""
            err = _classify_http(e.code, body)
            if not isinstance(err, ServerError):
                # 4xx (incl. 400/403/404) won't succeed on retry: fail fast,
                # matching how a status-checked response is handled.
                raise err from e
            err.__cause__ = e
            last = err
        except (urllib.error.URLError, TimeoutError, ConnectionError) as e:
            last = NetworkError(f"network: {e}")
            last.__cause__ = e
        if attempt + 1 < attempts:
            time.sleep(2 ** attempt)
    # The loop only falls through after a caught transient failure.
    assert last is not None
    raise last


def _compute_cost(model_id: str, duration_s: Optional[float]) -> float:
    """Best-effort cost estimate: profile ``cost_per_second`` times ``duration_s``.

    Returns 0.0 when ``duration_s`` is None, when the model profile has no
    ``cost_per_second`` entry, or when the lookup/arithmetic raises (the error
    is logged at DEBUG level and otherwise ignored).
    """
    if duration_s is not None:
        try:
            # Lazy import; any failure here (presumably including a missing
            # profile registry) only zeroes the cost.
            from recoil.core.model_profiles import get_profile
            rate = get_profile(model_id).get("cost_per_second")
            if rate is not None:
                return float(rate) * float(duration_s)
        except Exception as e:
            logger.debug(f"sync.so cost compute fell back ({e})")
    return 0.0


def _upload_file(path: Path, *, api_key: str, transport: Callable,
                 timeout_s: float, max_retries: int) -> str:
    """Multipart-upload a local file to sync.so and return its URL.

    The whole file is read into memory and sent as a single ``file`` form
    field. The response must be a JSON object carrying ``url`` (or
    ``presigned_url``).

    Args:
        path: local file to upload.
        api_key: sent as the ``x-api-key`` header.
        transport: request callable with the ``_default_transport`` signature.
        timeout_s: per-request timeout.
        max_retries: retries for transient failures (see ``_do_with_retries``).

    Returns:
        The URL string from the upload response.

    Raises:
        PayloadError: the response is not a JSON object or carries no URL.
        LipSyncError: (or subclass) as classified from a non-2xx status.
    """
    # Fresh random boundary per upload: a fixed string could, however
    # unlikely, occur inside the binary payload and corrupt the framing.
    boundary = f"----RecoilCP8Boundary{os.urandom(12).hex()}"
    mime = mimetypes.guess_type(str(path))[0] or "application/octet-stream"
    # Percent-encode characters that would break the quoted filename
    # parameter (the same escaping HTML form submission applies).
    filename = (path.name.replace("\r", "%0D").replace("\n", "%0A")
                .replace('"', "%22"))
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
        f"Content-Type: {mime}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    body = b"".join((head, path.read_bytes(), tail))

    headers = {
        "x-api-key": api_key,
        "Content-Type": f"multipart/form-data; boundary={boundary}",
    }

    def _do() -> str:
        with transport(UPLOAD_URL, headers=headers, body=body,
                       method="POST", timeout=timeout_s) as resp:
            code = getattr(resp, "status", 200)
            # Accept any 2xx (e.g. 201 Created), like the submit step does.
            if code not in range(200, 300):
                raise _classify_http(code, resp.read() if hasattr(resp, "read") else b"")
            data = json.loads(resp.read().decode("utf-8"))
        url = None
        if isinstance(data, dict):
            url = data.get("url") or data.get("presigned_url")
        if not url:
            raise PayloadError(f"sync.so upload returned no URL: {data}")
        return url

    return _do_with_retries(_do, max_retries=max_retries)


def lipsync_video(
    *,
    video_path: Path,
    audio_path: Path,
    model_id: str = "lipsync-2.0",
    output_format: str = "mp4",
    output_dir: Path,
    file_stem: str,
    sync_mode: str = "loop",
    fps: Optional[int] = None,
    api_key_env_var: str = DEFAULT_AUTH_ENV_VAR,
    timeout_s: float = DEFAULT_GEN_TIMEOUT_S,
    poll_interval_s: float = DEFAULT_POLL_INTERVAL_S,
    per_call_timeout_s: float = DEFAULT_PER_CALL_TIMEOUT_S,
    max_retries: int = 3,
    transport: Optional[Callable] = None,
) -> LipSyncResult:
    """Run sync.so lipsync over (video_path, audio_path) and return local output.

    Steps:
      1. Upload video_path -> URL.
      2. Upload audio_path -> URL.
      3. POST /v2/generate {model, input, options} -> job_id.
      4. Poll GET /v2/generate/{job_id} until a terminal status or timeout.
      5. Download the output URL to output_dir/file_stem.<output_format>.

    Args:
        video_path, audio_path: local input files; both must exist.
        model_id: sync.so model name; also used for the cost lookup.
        output_format: extension of the written file (no conversion happens).
        output_dir: destination directory, created if missing.
        file_stem: output file name without extension.
        sync_mode, fps: forwarded in the request ``options``; None is omitted.
        api_key_env_var: environment variable holding the API key.
        timeout_s: budget for the polling phase only; uploads, submit and
            download are bounded by per_call_timeout_s and retries instead.
        poll_interval_s: sleep between status polls.
        per_call_timeout_s: per-request timeout (the download gets 4x this).
        max_retries: retries per HTTP step for transient failures.
        transport: callable ``(url, *, headers, body, method, timeout)``
            returning a context-managed response; defaults to urllib.

    Returns:
        LipSyncResult describing the written file.

    Raises:
        PayloadError: missing input file, HTTP 422, a malformed response, or
            an empty download.
        AuthError: API key env var unset, or HTTP 401.
        QuotaError / RateLimitError: HTTP 402 / 429 (fail-fast).
        ServerError / NetworkError: transient failures after retries.
        LipSyncError: any other non-retryable HTTP status.
        JobFailedError: the job reported a terminal failure status.
        JobTimeoutError: no terminal status within timeout_s.
    """
    video_path = Path(video_path)
    audio_path = Path(audio_path)
    if not video_path.is_file():
        raise PayloadError(f"video_path not found: {video_path}")
    if not audio_path.is_file():
        raise PayloadError(f"audio_path not found: {audio_path}")

    api_key = os.environ.get(api_key_env_var)
    if not api_key:
        raise AuthError(f"{api_key_env_var} not set in environment")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{file_stem}.{output_format}"

    tport = transport or _default_transport

    # 1+2 — uploads
    video_url = _upload_file(video_path, api_key=api_key, transport=tport,
                             timeout_s=per_call_timeout_s, max_retries=max_retries)
    audio_url = _upload_file(audio_path, api_key=api_key, transport=tport,
                             timeout_s=per_call_timeout_s, max_retries=max_retries)

    # 3 — submit
    options = {"sync_mode": sync_mode, "fps": fps}
    submit_body = json.dumps({
        "model": model_id,
        "input": [
            {"type": "video", "url": video_url},
            {"type": "audio", "url": audio_url},
        ],
        "options": {k: v for k, v in options.items() if v is not None},
    }).encode("utf-8")
    submit_headers = {"x-api-key": api_key, "Content-Type": "application/json"}

    def _submit() -> str:
        with tport(GENERATE_URL, headers=submit_headers, body=submit_body,
                   method="POST", timeout=per_call_timeout_s) as resp:
            code = getattr(resp, "status", 200)
            if code not in (200, 201, 202):
                raise _classify_http(code, resp.read() if hasattr(resp, "read") else b"")
            data = json.loads(resp.read().decode("utf-8"))
        jid = None
        if isinstance(data, dict):
            jid = data.get("id") or data.get("job_id")
        if not jid:
            raise PayloadError(f"sync.so submit returned no id: {data}")
        return jid

    # NOTE(review): the submit POST is retried on 5xx/network errors like every
    # other step; if the server created the job before failing, a retry may
    # start a duplicate job. Confirm whether /v2/generate is idempotent.
    job_id: str = _do_with_retries(_submit, max_retries=max_retries)

    # 4 — poll
    poll_url = f"{GENERATE_URL}/{job_id}"
    poll_headers = {"x-api-key": api_key}

    def _poll() -> dict:
        with tport(poll_url, headers=poll_headers, body=None, method="GET",
                   timeout=per_call_timeout_s) as resp:
            code = getattr(resp, "status", 200)
            if code != 200:
                raise _classify_http(code, resp.read() if hasattr(resp, "read") else b"")
            data = json.loads(resp.read().decode("utf-8"))
        if not isinstance(data, dict):
            raise PayloadError(f"sync.so status for job {job_id} is not an object: {data!r}")
        return data

    # FAILED is the failure status this adapter was written against.
    # NOTE(review): REJECTED/CANCELED/CANCELLED are assumed to be sync.so
    # terminal states too — confirm against the API reference. Names that
    # never occur are harmless; without them such jobs wait out timeout_s.
    failure_statuses = {"FAILED", "REJECTED", "CANCELED", "CANCELLED"}
    # Monotonic clock: immune to wall-clock jumps (NTP/DST) during long polls.
    deadline = time.monotonic() + timeout_s
    last_status: dict = {}
    while True:
        if time.monotonic() > deadline:
            raise JobTimeoutError(
                f"sync.so job {job_id} did not terminate within {timeout_s}s; "
                f"last status: {last_status}"
            )
        last_status = _do_with_retries(_poll, max_retries=max_retries)
        status = str(last_status.get("status") or "").upper()
        if status == "COMPLETED":
            break
        if status in failure_statuses:
            raise JobFailedError(
                f"sync.so job {job_id} {status}: "
                f"{last_status.get('error', '<no error msg>')}"
            )
        time.sleep(poll_interval_s)

    out_url = last_status.get("outputUrl") or last_status.get("output_url")
    if not out_url:
        raise PayloadError(f"sync.so completed without output URL: {last_status}")

    # 5 — download (no API key header is sent to the output URL)
    def _download() -> bytes:
        with tport(out_url, headers={}, body=None, method="GET",
                   timeout=per_call_timeout_s * 4) as resp:
            code = getattr(resp, "status", None)
            # Like the other steps: never write an HTTP error body as the video.
            if code is not None and code not in range(200, 300):
                raise _classify_http(code, resp.read() if hasattr(resp, "read") else b"")
            return resp.read()

    payload = _do_with_retries(_download, max_retries=max_retries)
    if not payload:
        # The output URL is deliberately omitted: it may carry access tokens.
        raise PayloadError(f"sync.so job {job_id} output download was empty")
    # Write via a sibling temp file + atomic rename so an interrupted write
    # never leaves a truncated file at output_path.
    tmp_path = output_path.with_name(output_path.name + ".part")
    tmp_path.write_bytes(payload)
    os.replace(tmp_path, output_path)

    # Explicit None check so a reported duration of 0 is not discarded.
    raw_duration = last_status.get("duration_s")
    if raw_duration is None:
        raw_duration = last_status.get("duration")
    try:
        duration_s = float(raw_duration) if raw_duration is not None else None
    except (TypeError, ValueError):
        # The video is already on disk; don't fail the run over metadata.
        logger.warning("sync.so job %s reported unparseable duration %r",
                       job_id, raw_duration)
        duration_s = None

    return LipSyncResult(
        output_path=output_path,
        duration_s=duration_s,
        cost_usd=_compute_cost(model_id, duration_s),
        model=model_id,
        job_id=job_id,
        raw_metadata={
            "video_path": str(video_path),
            "audio_path": str(audio_path),
            "sync_mode": sync_mode,
            "fps": fps,
            "final_status_keys": sorted(last_status.keys()),
        },
    )


# Public surface: the entry point, its result record, and the exception
# hierarchy (every error class listed derives from LipSyncError).
__all__ = [
    "lipsync_video", "LipSyncResult",
    "LipSyncError",
    "AuthError", "QuotaError", "PayloadError", "RateLimitError",
    "ServerError", "NetworkError",
    "JobFailedError", "JobTimeoutError",
]
