"""Shared fal.ai transport — composed by every fal-backed adapter.

Locked 2026-04-25 (CP-2 Phase 3, JT Q3 lock). Wan, Seedance, and Kling all
run on fal.ai infrastructure and share the same transport stack:

  - Queue auth via FAL_KEY -> Authorization: Key {key}
  - Upload to fal storage, with a per-instance cache of the returned URLs
    for path-based uploads (raw-byte uploads are not cached)
  - Queue submission to queue.fal.run/{model_path}
  - Status polling (GET status_url)
  - Result fetch (GET response_url)
  - Optional blocking-subscribe wrapper for fal_client.subscribe()

Three independent stacks would be three drift surfaces. Composition pattern,
not inheritance trees: adapters hold a FalTransport instance and call into
it for transport concerns; everything model-specific (endpoints, body shape,
result parsing) stays on the adapter.

This module performs network and file I/O. Adapters must route all fal.ai
traffic through it rather than calling fal.ai directly.
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
import urllib.error
import urllib.request
from typing import Optional

logger = logging.getLogger(__name__)

# fal_client is optional. Without it the urllib-based HTTP queue helpers
# still work, while uploads and blocking subscribe raise RuntimeError.
try:
    import fal_client as _fal_client
except ImportError:
    _fal_client = None
_HAS_FAL = _fal_client is not None


# Base URL of fal.ai's HTTP queue API; model paths are appended to it.
QUEUE_BASE = "https://queue.fal.run"
# client_timeout (seconds) handed to fal_client.subscribe().
FAL_SUBSCRIBE_TIMEOUT_S = 900


# ----------------------------------------------------------------------
# Module-level helpers (stateless — used by FalAdapter directly today).
# ----------------------------------------------------------------------

def upload_bytes_to_fal(data: bytes, suffix: str = ".jpg") -> str:
    """Write ``data`` to a temporary file, upload it via fal_client, return the URL.

    Stateless: there is no caching. The temporary file is always removed,
    including when the write or the upload fails. Adapters that want
    session-level caching should upload from a stable local path with
    ``FalTransport.upload_path``, which caches by absolute path.
    ``FalTransport.upload_bytes`` delegates here and is likewise uncached.

    Args:
        data: Raw file contents to upload.
        suffix: Filename suffix for the temp file (e.g. ``".jpg"``).
            Presumably fal_client uses it to infer the content type. TODO confirm.

    Returns:
        The URL returned by ``fal_client.upload_file``.

    Raises:
        RuntimeError: If ``fal_client`` is not installed.
    """
    if not _HAS_FAL:
        raise RuntimeError("fal_client not installed — cannot upload frames")
    # Use mkstemp rather than NamedTemporaryFile(delete=False) so the path is
    # bound before any write happens. A failing write is then still covered
    # by the finally-unlink instead of leaking the temp file on disk.
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # The file is closed (flushed) before fal_client reads it by path.
        return _fal_client.upload_file(tmp_path)
    finally:
        os.unlink(tmp_path)


def upload_path_to_fal(file_path: str) -> str:
    """Upload the file at ``file_path`` straight through fal_client.

    Stateless: every call re-uploads. For path-keyed caching within a
    session, use ``FalTransport.upload_path``.

    Raises:
        RuntimeError: If ``fal_client`` is not installed.
    """
    if _HAS_FAL:
        return _fal_client.upload_file(file_path)
    raise RuntimeError("fal_client not installed — cannot upload files")


# ----------------------------------------------------------------------
# FalTransport — the per-adapter instance composed by wan / kling / fal.
# ----------------------------------------------------------------------

class FalTransport:
    """Per-adapter transport handle for fal.ai.

    Holds:
      - the auth header builder. It reads the API key from the environment
        variable named by ``auth_env_var`` (``FAL_KEY`` by default) on every
        call.
      - a per-instance upload cache (absolute path -> fal URL) so refs aren't
        re-uploaded across multiple submits in the same session.

    The cache lives on this instance. The adapter that owns this transport
    is itself cached by the registry, so the cache survives across submits.
    NOTE(review): the cache is keyed by path only, so a file rewritten in
    place keeps returning the URL of its first upload.
    """

    QUEUE_BASE = QUEUE_BASE
    # Socket timeout (seconds) applied to every HTTP queue request.
    HTTP_TIMEOUT_S = 30

    def __init__(self, auth_env_var: str = "FAL_KEY"):
        """Create a transport that authenticates via ``auth_env_var``.

        The variable is read lazily in ``headers()``, not here, so the key
        may be set after construction.
        """
        self._auth_env_var = auth_env_var
        self._upload_cache: dict[str, str] = {}  # absolute path -> fal URL

    # -- auth --

    def headers(self) -> dict:
        """Build fal.ai request headers from the configured env var.

        Raises:
            RuntimeError: If the env var is unset or empty.
        """
        key = os.environ.get(self._auth_env_var)
        if not key:
            raise RuntimeError(
                f"{self._auth_env_var} environment variable not set."
            )
        return {
            "Authorization": f"Key {key}",
            "Content-Type": "application/json",
        }

    # -- upload --

    def upload_path(self, file_path: str) -> str:
        """Upload a local file to fal storage. Cached by absolute path.

        Raises:
            RuntimeError: On a cache miss when ``fal_client`` is not installed.
        """
        abs_path = os.path.abspath(file_path)
        if abs_path in self._upload_cache:
            logger.debug("fal upload cache hit: %s", os.path.basename(abs_path))
            return self._upload_cache[abs_path]
        if not _HAS_FAL:
            raise RuntimeError("fal_client not installed — cannot upload files")
        url = _fal_client.upload_file(abs_path)
        self._upload_cache[abs_path] = url
        logger.info(
            "Uploaded to fal: %s -> %s",
            os.path.basename(abs_path), url[:80],
        )
        return url

    def upload_bytes(self, data: bytes, suffix: str = ".jpg") -> str:
        """Write bytes to a temp file, upload, return URL. NOT cached
        (no stable key for raw bytes)."""
        return upload_bytes_to_fal(data, suffix=suffix)

    # -- HTTP queue (used by Seedance / Kling-style polling paths) --

    def _request(
        self,
        method: str,
        url: str,
        json_data: Optional[dict] = None,
        *,
        allow_empty: bool = False,
    ) -> dict:
        """Send an authenticated JSON request and return the decoded body.

        Args:
            method: HTTP method, e.g. ``"GET"``, ``"POST"``, ``"PUT"``.
            url: Absolute request URL.
            json_data: Body to JSON-encode. ``None`` sends no body; an empty
                dict sends the literal ``{}``.
            allow_empty: If true, an empty success body yields ``{}`` instead
                of a JSON decode error.

        Returns:
            The decoded JSON response.

        Raises:
            RuntimeError: On an HTTP error status (the message carries the
                status code and the response body), or if the auth env var
                is unset.
        """
        # CP-2 bugfix — explicit None check. Truthiness drops empty dict
        # bodies (e.g. cancel() passes {} intending an empty JSON body),
        # which can return 400 Bad Request from APIs that require an
        # application/json payload.
        body = json.dumps(json_data).encode() if json_data is not None else None
        req = urllib.request.Request(
            url, data=body, headers=self.headers(), method=method
        )
        try:
            with urllib.request.urlopen(req, timeout=self.HTTP_TIMEOUT_S) as resp:
                raw = resp.read()
        except urllib.error.HTTPError as e:
            # errors="replace": a non-UTF-8 error page (e.g. from a proxy)
            # must not turn the RuntimeError below into a UnicodeDecodeError.
            error_body = (
                e.read().decode("utf-8", errors="replace") if e.fp else ""
            )
            logger.error(
                "fal.ai %s %s -> %d: %s", method, url, e.code, error_body
            )
            raise RuntimeError(
                f"fal.ai API error {e.code}: {error_body}"
            ) from e
        if allow_empty and not raw.strip():
            return {}
        return json.loads(raw.decode())

    def submit_queue(self, model_path: str, body: dict) -> dict:
        """POST to queue.fal.run/{model_path}. Returns full response dict
        (caller extracts request_id / status_url / response_url)."""
        url = f"{self.QUEUE_BASE}/{model_path}"
        return self._request("POST", url, body)

    def poll_status(self, status_url: str) -> dict:
        """GET status URL. Returns the raw response dict."""
        return self._request("GET", status_url)

    def fetch_result(self, response_url: str) -> dict:
        """GET response URL on COMPLETED. Returns full result payload."""
        return self._request("GET", response_url)

    def cancel(self, model_path: str, request_id: str) -> None:
        """Best-effort cancel via PUT.

        Raises RuntimeError on HTTP errors, which the caller swallows. The
        response payload is discarded, so a success reply with an empty
        body is accepted rather than failing JSON decoding.
        """
        url = f"{self.QUEUE_BASE}/{model_path}/requests/{request_id}/cancel"
        self._request("PUT", url, {}, allow_empty=True)

    # -- blocking subscribe (used by Wan) --

    def subscribe_blocking(self, model_path: str, body: dict) -> dict:
        """Wrap fal_client.subscribe() — blocks until completion.

        Wan uses this instead of HTTP queue+poll because the SDK handles
        polling internally and Wan's reliability depends on it. Do NOT
        replace with the HTTP path.

        Returns:
            Whatever ``fal_client.subscribe`` returns for the job.

        Raises:
            RuntimeError: If ``fal_client`` is not installed.
        """
        if not _HAS_FAL:
            raise RuntimeError(
                "fal_client not installed — cannot submit blocking jobs"
            )
        return _fal_client.subscribe(
            model_path,
            arguments=body,
            client_timeout=FAL_SUBSCRIBE_TIMEOUT_S,
        )


# Public surface for `from <module> import *`: the transport class, the two
# shared constants, and the two stateless upload helpers.
__all__ = [
    "FalTransport",
    "FAL_SUBSCRIBE_TIMEOUT_S",
    "QUEUE_BASE",
    "upload_bytes_to_fal",
    "upload_path_to_fal",
]
