"""Wan 2.7 adapter — image-to-video + reference-to-video via fal.ai SDK.

Extracted from WanClient (recoil/execution/api_client.py) in CP-2 of the
June 2026 refactor. Preserves exact wire behavior:
  - i2v: fal-ai/wan/v2.7/image-to-video via fal_client.subscribe (blocking)
  - r2v: fal-ai/wan/v2.7/reference-to-video via fal_client.subscribe (blocking)

Auth: FAL_KEY environment variable (via FalTransport).

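Implementation note: Wan does NOT use the HTTP queue pattern;
fal_client.subscribe() blocks until completion. Adapter contract:
build_submit returns a SubmitRequest with a sentinel URL ("fal-sdk://wan-i2v"
or "fal-sdk://wan-r2v"), and the Phase 5 executor then calls
direct_subscribe_i2v / direct_subscribe_r2v on this adapter, bypassing the
HTTP transport entirely.

The same adapter also serves HappyHorse-1.0 (alibaba/happy-horse/*), which
shares the subscribe wire pattern but uses different payload field names and
does not accept Wan's enable_* flags.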

Composes shared FalTransport (Q3 lock, 2026-04-25). Upload caching, auth,
and the fal_client.subscribe wrapper all live in fal_transport.py — Wan,
Seedance, and Kling pull from that single source so wire behavior cannot
drift across the three.
"""

from __future__ import annotations

import base64
import json
import logging
from typing import Any, Optional

from recoil.execution.providers.base import (
    PollRequest,
    PollResult,
    ProviderJob,
    SubmitRequest,
    UnifiedVideoPayload,
)
from recoil.execution.providers.fal_transport import FalTransport
from recoil.execution.providers.payload_hints import coerce_to_dict

# Module logger; request bodies and endpoints are logged at INFO level.
logger = logging.getLogger(name=__name__)


# fal.ai endpoint paths for Wan 2.7, keyed by sub-path. Kept identical to
# WanClient.ENDPOINTS.
_ENDPOINTS = dict(
    i2v="fal-ai/wan/v2.7/image-to-video",
    r2v="fal-ai/wan/v2.7/reference-to-video",
)

# fal.ai endpoint paths for HappyHorse-1.0 (Alibaba), the backup retry tier.
_HAPPY_HORSE_ENDPOINTS = dict(
    i2v="alibaba/happy-horse/image-to-video",
    r2v="alibaba/happy-horse/reference-to-video",
)

# Model ids that always resolve to the Wan 2.7 endpoint table.
_WAN_MODELS = frozenset(("wan-2.7-i2v", "wan-2.7-r2v"))


def _endpoint_for(model_id: str, sub_path: str) -> str:
    """Resolve the fal endpoint path for a model id and sub-path.

    WanAdapter serves both Wan 2.7 and HappyHorse-1.0 since both go through
    the fal_client.subscribe wire pattern; HappyHorse-specific payload
    differences (token syntax, field names, no enable_* flags) are dealt with
    in the direct_subscribe_* methods, not here.

    Known Wan ids map to the Wan table; otherwise any id containing
    "happy-horse" maps to the HappyHorse table; everything else falls back
    to Wan. An unknown ``sub_path`` raises KeyError.
    """
    use_happy_horse = model_id not in _WAN_MODELS and "happy-horse" in model_id
    table = _HAPPY_HORSE_ENDPOINTS if use_happy_horse else _ENDPOINTS
    return table[sub_path]


# Sentinel values placed in SubmitRequest.url. The Phase 5 executor matches
# on these and routes to direct_subscribe_* rather than the HTTP path.
_SENTINEL_I2V, _SENTINEL_R2V = "fal-sdk://wan-i2v", "fal-sdk://wan-r2v"

DEFAULT_NEGATIVE_PROMPT = (
    "blurry, low quality, distorted, deformed, watermark, text, logo"
)


def _bytes_to_data_uri(image_bytes: bytes) -> str:
    """Encode raw image bytes as an inline base64 ``data:`` URI.

    Carried over verbatim from WanClient._bytes_to_data_uri. The MIME type
    is sniffed from the leading magic bytes: JPEG and WEBP are recognised
    explicitly; PNG and anything unrecognised are labelled image/png.
    """
    if image_bytes[:3] == b"\xff\xd8\xff":
        mime = "image/jpeg"
    elif image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        mime = "image/webp"
    else:
        # Covers the PNG signature (\x89PNG\r\n\x1a\n) and the fallback.
        mime = "image/png"
    encoded = base64.b64encode(image_bytes).decode()
    return "data:" + mime + ";base64," + encoded


def _coerce_image_to_url(value: Any, transport: FalTransport) -> str:
    """Normalize one image reference to a URL that fal can fetch.

    Used for UnifiedVideoPayload.image, image_tail and each reference_images
    entry.

    Accepts:
      - str URL ("http://", "https://", "fal://", "data:") -> returned as-is
      - str path to an existing file -> uploaded via transport.upload_path
      - bytes / bytearray / memoryview -> uploaded via transport.upload_bytes
      - str that is strict base64 -> inline data URI (MIME sniffed from the
        decoded magic bytes)
      - any other str -> returned as-is (best-effort pass-through)

    Raises:
        TypeError: ``value`` is neither a str nor a bytes-like object.
    """
    import os

    if isinstance(value, (bytes, bytearray, memoryview)):
        return transport.upload_bytes(bytes(value))
    if not isinstance(value, str):
        raise TypeError(
            f"wan adapter: cannot coerce {type(value).__name__} to image URL"
        )
    if value.startswith(("http://", "https://", "fal://", "data:")):
        return value
    # Heuristic: a path that exists on disk is uploaded; otherwise the
    # string is tried as base64 image data.
    if os.path.isfile(value):
        return transport.upload_path(value)
    try:
        decoded = base64.b64decode(value, validate=True)
    except ValueError:
        # binascii.Error (bad alphabet / padding) subclasses ValueError, as
        # does the error raised for non-ASCII input. Last resort: pass the
        # string through (caller may have provided a URL form we don't
        # recognize).
        return value
    return _bytes_to_data_uri(decoded)


# ----------------------------------------------------------------------
# The adapter
# ----------------------------------------------------------------------


class WanAdapter:
    """Provider adapter for Wan 2.7 and HappyHorse-1.0 on fal.ai.

    Both model families are driven through FalTransport.subscribe_blocking
    rather than the HTTP queue. ``build_submit`` therefore only returns a
    sentinel SubmitRequest; the executor calls ``direct_subscribe_i2v`` /
    ``direct_subscribe_r2v`` and feeds the dict they return into
    ``parse_submit`` / ``parse_result``. Polling is a no-op.
    """

    provider_id = "wan"
    supported_models = ["wan-2.7-i2v", "wan-2.7-r2v", "happy-horse-i2v", "happy-horse-r2v"]
    auth_env_var = "FAL_KEY"
    base_url = "https://queue.fal.run"
    max_prompt_chars = None
    status = "primary"
    capabilities = {
        "t2v": False,
        "i2v": True,
        "r2v": True,
        "end_frame": True,         # i2v supports end_image_url
        "audio": False,
        "negative_prompt": True,
        "resolution_480p": False,
        "resolution_720p": True,
        "resolution_1080p": False,
    }

    def __init__(self):
        # Per-instance transport: holds the FAL_KEY auth + upload cache.
        # Registry caches the adapter, so this transport (and its cache)
        # survive across submits in one session.
        self._transport = FalTransport(auth_env_var=self.auth_env_var)

    # ---- shared helpers ----

    @staticmethod
    def _sentinel_body(payload: UnifiedVideoPayload, tier: str) -> dict:
        """Return the summary body attached to a sentinel SubmitRequest.

        The sentinel URL is not an HTTP endpoint; the actual wire body is
        assembled by direct_subscribe_*.
        """
        return {
            "prompt": payload.prompt,
            "model": payload.model_id,
            "duration": payload.duration_s,
            "resolution": payload.resolution,
            "tier": tier,
        }

    @staticmethod
    def _extract_video_url(resp: Optional[dict]) -> Optional[str]:
        """Return the video URL in ``resp``, or None when absent.

        Understands both the flattened direct_subscribe_* shape
        (``{"video_url": ...}``) and the raw fal result shape
        (``{"video": {"url": ...}}``).
        """
        if not resp:
            return None
        return resp.get("video_url") or (resp.get("video") or {}).get("url")

    @staticmethod
    def _package_result(result_data: dict, fallback_id: str) -> dict:
        """Flatten a raw fal result into {video_url, raw, seed, native_id}.

        ``fallback_id`` is used as native_id when fal returns no request_id.
        """
        return {
            "video_url": (result_data.get("video") or {}).get("url"),
            "raw": result_data,
            "seed": result_data.get("seed"),
            "native_id": result_data.get("request_id") or fallback_id,
        }

    # ---- submit ----

    def build_submit(
        self, payload: UnifiedVideoPayload, tier: str
    ) -> SubmitRequest:
        """Build a sentinel SubmitRequest. Wan bypasses HTTP — the Phase 5
        executor sees the sentinel URL and calls direct_subscribe_* instead.

        Sub-path is inferred from payload shape:
          payload.image present -> i2v
          payload.reference_images / reference_videos present -> r2v

        Raises:
            ValueError: the payload has neither a start image nor any
                references (Wan has no text-to-video mode).
        """
        if payload.image is not None:
            url = _SENTINEL_I2V
        elif payload.reference_images or payload.reference_videos:
            url = _SENTINEL_R2V
        else:
            raise ValueError(
                "WanAdapter.build_submit: payload missing image (i2v) and "
                "reference_images/reference_videos (r2v) — Wan cannot do T2V."
            )
        return SubmitRequest(
            method="POST",
            url=url,
            headers={},
            body=self._sentinel_body(payload, tier),
        )

    def parse_submit(
        self, resp: dict, payload: UnifiedVideoPayload, tier: str
    ) -> ProviderJob:
        """Wrap a direct_subscribe_* result in a ProviderJob.

        The executor calls direct_subscribe_* and feeds its return value here
        as ``resp``. The sub-path is re-derived from the payload the same
        way build_submit does (image present -> i2v, otherwise r2v).
        """
        # NOTE(review): the seed fallback is inherited behavior; seeds are
        # not guaranteed to be unique identifiers.
        native_id = (
            resp.get("native_id")
            or resp.get("request_id")
            or resp.get("seed")
            or "wan-job"
        )
        sub_path = "i2v" if payload.image is not None else "r2v"
        return ProviderJob(
            provider_id=self.provider_id,
            model_id=payload.model_id,
            native_id=str(native_id),
            tier=tier,
            duration_s=payload.duration_s,
            resolution=payload.resolution,
            native_state={
                "sub_path": sub_path,
                "model_path": _endpoint_for(payload.model_id, sub_path),
                "video_url": self._extract_video_url(resp),
                "raw": resp,
            },
        )

    # ---- poll ----

    def build_poll(self, job: ProviderJob) -> PollRequest:
        """Return a no-op poll request (subscribe already blocked to completion)."""
        return PollRequest(method="GET", url="fal-sdk://wan-noop", headers={})

    def parse_poll(self, resp: dict, job: ProviderJob) -> PollResult:
        """Always report COMPLETED: the blocking subscribe finished synchronously."""
        return PollResult(status="COMPLETED", raw=resp or {})

    # ---- result fetch ----

    def build_result_fetch(self, job: ProviderJob) -> Optional[PollRequest]:
        """No separate fetch: the result is already in hand from direct_subscribe_*."""
        return None

    def parse_result(self, resp: dict, job: ProviderJob) -> PollResult:
        """Turn a completed job into a final PollResult.

        The video URL is taken from ``resp`` (flattened or raw fal shape)
        and falls back to the URL captured in ``job.native_state`` by
        parse_submit. A missing URL yields a FAILED result whose error
        embeds (a truncated dump of) ``resp``.
        """
        video_url = (
            self._extract_video_url(resp)
            or job.native_state.get("video_url")
        )
        if not video_url:
            return PollResult(
                status="FAILED",
                error=(
                    "provider returned COMPLETED with no video_url: "
                    + json.dumps(resp, default=str, sort_keys=True)[:500]
                ),
                raw=resp or {},
            )
        return PollResult(
            status="COMPLETED",
            video_url=video_url,
            audio_url=None,
            observed_cost=None,
            raw=resp or job.native_state.get("raw") or {},
        )

    # ---- cost ----

    def compute_cost(
        self, duration_s: float, tier: str, profile: dict
    ) -> float:
        """Return ``profile["cost_per_second"] * duration_s``.

        ``tier`` is accepted for interface parity but not used. A missing
        profile or rate yields 0.0.
        """
        rate = (profile or {}).get("cost_per_second")
        if rate is None:
            # Real value comes from model_profiles.json.
            # NOTE(review): a 0.0 fallback under-reports cost; confirm intended.
            rate = 0.0
        return float(rate) * float(duration_s)

    # ---- direct-subscribe path (used by execute_video in Phase 5) ----

    def direct_subscribe_i2v(self, payload: UnifiedVideoPayload) -> dict:
        """Blocking i2v submit. Mirrors WanClient._submit_i2v body shape.

        Returns {video_url, raw, seed, native_id} — the executor wraps this
        into a PollResult via parse_submit/parse_result.

        Raises:
            ValueError: ``payload.image`` is None.
        """
        if payload.image is None:
            raise ValueError(
                "WanAdapter.direct_subscribe_i2v: payload.image required"
            )
        endpoint = _endpoint_for(payload.model_id, "i2v")
        is_wan = payload.model_id in _WAN_MODELS
        # Normalize hints once; every optional field below reads from it.
        hints = coerce_to_dict(payload.hints)

        body: dict = {
            "prompt": payload.prompt,
            "image_url": _coerce_image_to_url(payload.image, self._transport),
            "duration": payload.duration_s,
            "resolution": payload.resolution,
        }
        if is_wan:
            # Wan-specific flags. HappyHorse rejects unknown params.
            body["enable_prompt_expansion"] = bool(
                hints.get("enable_prompt_expansion", False)
            )
            body["enable_safety_checker"] = bool(
                hints.get("enable_safety_checker", True)
            )

        if payload.image_tail is not None:
            body["end_image_url"] = _coerce_image_to_url(
                payload.image_tail, self._transport
            )

        if payload.negative_prompt:
            body["negative_prompt"] = payload.negative_prompt

        seed = hints.get("seed")
        if seed is not None:
            body["seed"] = seed

        audio_url = hints.get("audio_url")
        if audio_url:
            body["audio_url"] = audio_url

        video_url_in = hints.get("video_url")
        if video_url_in:
            body["video_url"] = video_url_in

        logger.info("Wan I2V request body keys: %s", sorted(body.keys()))
        logger.info(
            "Wan I2V endpoint: %s, duration=%s, end_frame=%s",
            endpoint, body["duration"], "end_image_url" in body,
        )

        result_data = self._transport.subscribe_blocking(endpoint, body)
        return self._package_result(result_data, "wan-i2v")

    def direct_subscribe_r2v(self, payload: UnifiedVideoPayload) -> dict:
        """Blocking r2v submit. Mirrors WanClient._submit_r2v body shape.

        Returns {video_url, raw, seed, native_id}.

        Raises:
            ValueError: the payload has no reference images or videos.
        """
        if not (payload.reference_images or payload.reference_videos):
            raise ValueError(
                "WanAdapter.direct_subscribe_r2v: payload.reference_images "
                "or reference_videos required"
            )
        endpoint = _endpoint_for(payload.model_id, "r2v")
        is_wan = payload.model_id in _WAN_MODELS
        # Normalize hints once; every optional field below reads from it.
        hints = coerce_to_dict(payload.hints)

        body: dict = {
            "prompt": payload.prompt,
            "duration": payload.duration_s,
            "resolution": payload.resolution,
            "aspect_ratio": payload.aspect_ratio or "9:16",
        }
        if is_wan:
            body["enable_safety_checker"] = bool(
                hints.get("enable_safety_checker", True)
            )

        # Tracked separately so the log is correct for both Wan and
        # HappyHorse, which use different body keys for the same list.
        ref_count = 0
        if payload.reference_images:
            ref_urls = [
                _coerce_image_to_url(ref, self._transport)
                for ref in payload.reference_images
            ]
            # Wan: reference_image_urls. HappyHorse: image_urls.
            ref_key = "reference_image_urls" if is_wan else "image_urls"
            body[ref_key] = ref_urls
            ref_count = len(ref_urls)

        if payload.reference_videos:
            ref_video_urls: list[str] = []
            for v in payload.reference_videos:
                if not isinstance(v, str):
                    logger.warning(
                        "wan r2v: skipping invalid video ref %r", str(v)[:60]
                    )
                elif v.startswith(("http://", "https://", "fal://")):
                    ref_video_urls.append(v)
                else:
                    # Local path -> upload (cached by the transport).
                    ref_video_urls.append(self._transport.upload_path(v))
            if ref_video_urls:
                body["reference_video_urls"] = ref_video_urls

        if hints.get("multi_shots"):
            body["multi_shots"] = True

        seed = hints.get("seed")
        if seed is not None:
            body["seed"] = seed

        if payload.negative_prompt:
            body["negative_prompt"] = payload.negative_prompt

        logger.info("Wan R2V request body keys: %s", sorted(body.keys()))
        logger.info(
            "Wan R2V endpoint: %s, refs=%d, multi_shots=%s",
            endpoint,
            ref_count,
            body.get("multi_shots", False),
        )

        result_data = self._transport.subscribe_blocking(endpoint, body)
        return self._package_result(result_data, "wan-r2v")


__all__ = ["WanAdapter", "ADAPTER"]

# Module-level singleton; constructing it also builds its FalTransport.
ADAPTER = WanAdapter()
