"""PiAPI adapter for seeddance-2.0.

API shape (confirmed 2026-04-17 from piapi.ai docs):

  Submit:
    POST https://api.piapi.ai/api/v1/task
    Headers: X-API-Key: <PIAPI_API_KEY>, Content-Type: application/json
    Body: {
      "model": "seedance",
      "task_type": "seedance-2" | "seedance-2-fast" | ...,
      "input": {
        "prompt": <str>,
        "mode": "text_to_video" | "first_last_frames" | "omni_reference",
        "duration": <int 4..15>,
        "aspect_ratio": <str>,
        "image_urls": [<url>, ...]    # I2V: one; first_last_frames: two; omni: refs
        "video_urls": [<url>, ...],   # optional
        "negative_prompt": <str>       # if supported
      }
    }

  Poll:
    POST https://api.piapi.ai/api/v1/task/{task_id}
    (yes — PiAPI uses POST to fetch task status in v1; empty body)
    Result: {
      "data": {
        "status": "pending" | "processing" | "completed" | "failed",
        "output": {"video": {"url": <str>}, "audio": {"url": <str|null>}},
        "meta": {"usage": {"consume": <credits_float>}}   # observed cost signal
      }
    }

PiAPI's observed cost signal (`meta.usage.consume`) is in credits, not USD.
parse_result converts it to USD at a fixed $0.01 per credit; compute_cost
returns the listed cost (the tier's per-second rate times the duration) as
the reference point. The drift report compares billed-credits-ratio to
listed-credits-ratio to detect pricing changes.
"""

from __future__ import annotations

import json
import logging
import os
from typing import Optional

from recoil.execution.providers.base import (
    PollRequest,
    PollResult,
    ProviderJob,
    SubmitRequest,
    UnifiedVideoPayload,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Root of PiAPI's v1 REST API; the submit ("/task") and poll
# ("/task/{task_id}") URLs below are built from it.
_BASE = "https://api.piapi.ai/api/v1"


def _mode_for_payload(payload: UnifiedVideoPayload) -> str:
    """Choose the PiAPI ``input.mode`` value for *payload*.

    Returns:
        ``"first_last_frames"`` when both a start image and a tail image are
        set, ``"omni_reference"`` when only a start image or any reference
        images are set, and ``"text_to_video"`` otherwise.
    """
    has_start_frame = payload.image is not None
    if not has_start_frame:
        # No start frame: reference images alone still select omni_reference.
        return "omni_reference" if payload.reference_images else "text_to_video"
    # PiAPI treats both I2V and R2V as omni_reference; a tail frame on top of
    # the start frame switches to first/last-frame mode.
    return "omni_reference" if payload.image_tail is None else "first_last_frames"


def _collect_image_urls(payload: UnifiedVideoPayload) -> list[str]:
    """Gather the http(s) image URLs carried by *payload*, in request order.

    Order is: start image, tail image, then reference images. Only ``str``
    values whose scheme is ``http://`` or ``https://`` (matched
    case-insensitively) are kept; anything else (``None``, bytes, local
    paths, strings that merely begin with the letters "http") is skipped,
    because only URLs are forwarded to PiAPI here.

    Returns:
        A new list of URL strings, possibly empty.
    """
    candidates = [
        payload.image,
        payload.image_tail,
        *(payload.reference_images or []),
    ]
    return [
        item
        for item in candidates
        # Require the full scheme separator so e.g. "httpfoo.png" is not
        # mistaken for a URL; lower() because URL schemes are case-insensitive.
        if isinstance(item, str)
        and item.lower().startswith(("http://", "https://"))
    ]


class PiApiAdapter:
    """Adapter that maps the unified video payload onto PiAPI's task API.

    The ``build_*`` methods return request descriptors and the ``parse_*``
    methods interpret already-decoded JSON responses; no HTTP call is made
    inside this class.
    """

    provider_id = "piapi"
    supported_models = ["seeddance-2.0"]
    auth_env_var = "PIAPI_API_KEY"
    base_url = _BASE
    max_prompt_chars = None
    status = "testing"
    capabilities = {
        "t2v": True,
        "i2v": True,
        "r2v": True,
        # PiAPI does not expose end-frame as first-class; fal handles end_frame
        # via capability_exceptions in provider_strategy.json. Mark False so
        # the registry refuses to route end_frame to piapi without override.
        "end_frame": False,
        "audio": True,
        "negative_prompt": True,
        "resolution_480p": False,
        "resolution_720p": True,
        "resolution_1080p": True,
    }

    # USD value assigned to one PiAPI credit when turning
    # ``meta.usage.consume`` into ``observed_cost``.
    _USD_PER_CREDIT = 0.01

    # Fewest image URLs each image-driven mode needs to be meaningful.
    _MIN_IMAGES_BY_MODE = {"first_last_frames": 2, "omni_reference": 1}

    @staticmethod
    def _unwrap(resp: dict) -> dict:
        """Return the ``data`` envelope of *resp*, or *resp* itself.

        Falls back to *resp* when ``data`` is missing, empty, or not a dict,
        so callers can always use ``.get`` on the result.
        """
        data = resp.get("data")
        return data if data and isinstance(data, dict) else resp

    @staticmethod
    def _failure_message(data: dict) -> str:
        """Extract a human-readable error from a failed task's *data*.

        ``error`` may be a dict carrying ``message`` or a bare string; when
        neither yields text, fall back to ``data["message"]`` and finally to
        a generic message.
        """
        err = data.get("error")
        if isinstance(err, dict):
            detail = err.get("message")
        elif isinstance(err, str):
            detail = err
        else:
            detail = None
        return detail or data.get("message") or "piapi generation failed"

    def _headers(self) -> dict:
        """Build the auth/content headers for a PiAPI request.

        Raises:
            RuntimeError: if the ``auth_env_var`` environment variable is
                unset or empty.
        """
        key = os.environ.get(self.auth_env_var)
        if not key:
            raise RuntimeError(f"{self.auth_env_var} environment variable not set.")
        return {
            "X-API-Key": key,
            "Content-Type": "application/json",
        }

    def build_submit(self, payload: UnifiedVideoPayload, tier: str) -> SubmitRequest:
        """Build the task-creation request for *payload*.

        Args:
            payload: Unified generation request; prompt, duration, aspect
                ratio, images, negative prompt and audio flag are read.
            tier: Sent verbatim as PiAPI's ``task_type``.

        Returns:
            A ``SubmitRequest`` POSTing to ``{_BASE}/task``.

        Raises:
            RuntimeError: if the API key environment variable is not set.
        """
        mode = _mode_for_payload(payload)
        image_urls = _collect_image_urls(payload)

        # Non-URL images are dropped by _collect_image_urls; surface that
        # instead of silently sending an image mode without its images.
        needed = self._MIN_IMAGES_BY_MODE.get(mode, 0)
        if len(image_urls) < needed:
            logger.warning(
                "piapi adapter: mode %r needs at least %d image URL(s) but only "
                "%d usable URL(s) were found; non-URL images are not sent",
                mode,
                needed,
                len(image_urls),
            )

        inp: dict = {
            "prompt": payload.prompt,
            "mode": mode,
            # Fractional durations are truncated toward zero by int().
            "duration": int(payload.duration_s),
        }
        # Optional fields are only included when set, so absent values are
        # omitted from the request body rather than sent as null/empty.
        if payload.aspect_ratio:
            inp["aspect_ratio"] = payload.aspect_ratio
        if image_urls:
            inp["image_urls"] = image_urls
        if payload.negative_prompt:
            inp["negative_prompt"] = payload.negative_prompt
        if payload.generate_audio:
            inp["generate_audio"] = True

        body = {
            "model": "seedance",
            "task_type": tier,       # e.g. "seedance-2" or "seedance-2-fast"
            "input": inp,
        }
        return SubmitRequest(
            method="POST",
            url=f"{_BASE}/task",
            headers=self._headers(),
            body=body,
        )

    def parse_submit(
        self, resp: dict, payload: UnifiedVideoPayload, tier: str
    ) -> ProviderJob:
        """Turn a task-creation response into a ``ProviderJob``.

        The task id is read from ``task_id`` (or ``id``) inside the ``data``
        envelope, or from the top level when there is no envelope.

        Raises:
            RuntimeError: if no task id is present in the response.
        """
        data = self._unwrap(resp)
        task_id = data.get("task_id") or data.get("id")
        if not task_id:
            raise RuntimeError(
                f"piapi adapter: no task_id in submit response: "
                f"{data.get('message') or resp}"
            )
        return ProviderJob(
            provider_id=self.provider_id,
            model_id=payload.model_id,
            native_id=task_id,
            tier=tier,
            duration_s=payload.duration_s,
            resolution=payload.resolution,
            native_state={},
        )

    def build_poll(self, job: ProviderJob) -> PollRequest:
        """Build the status request for *job*.

        Raises:
            RuntimeError: if the API key environment variable is not set.
        """
        return PollRequest(
            method="POST",     # PiAPI poll is POST with empty body
            url=f"{_BASE}/task/{job.native_id}",
            headers=self._headers(),
        )

    def parse_poll(self, resp: dict, job: ProviderJob) -> PollResult:
        """Interpret a status response.

        ``completed`` is delegated to :meth:`parse_result`, ``failed`` yields
        a FAILED result with the best available error text, and every other
        status (including missing or unrecognized ones) is reported as
        IN_PROGRESS. Status matching is case-insensitive.
        """
        data = self._unwrap(resp)
        state = str(data.get("status") or "unknown").lower()
        if state == "completed":
            return self.parse_result(resp, job)
        if state == "failed":
            return PollResult(
                status="FAILED", error=self._failure_message(data), raw=data
            )
        return PollResult(status="IN_PROGRESS", raw=data)

    def build_result_fetch(self, job: ProviderJob) -> Optional[PollRequest]:
        """Return ``None``: no separate result request is needed."""
        # PiAPI returns result in the status response; no separate fetch.
        return None

    def parse_result(self, resp: dict, job: ProviderJob) -> PollResult:
        """Extract video/audio URLs and observed cost from a completed task.

        Returns:
            A COMPLETED ``PollResult`` when a video URL is present, otherwise
            a FAILED one whose error embeds up to 500 characters of the
            serialized response.
        """
        data = self._unwrap(resp)
        output = data.get("output")
        if not isinstance(output, dict):
            output = {}

        # Two output shapes are accepted: {"video": {"url": ...}} and a flat
        # {"video_url": ...}.
        video = output.get("video")
        video_url = None
        if isinstance(video, dict):
            video_url = video.get("url")
        elif isinstance(output.get("video_url"), str):
            video_url = output["video_url"]

        if not video_url:
            # The failure path reports the full response envelope (resp),
            # not just the unwrapped data.
            return PollResult(
                status="FAILED",
                error=(
                    "provider returned COMPLETED with no video_url: "
                    + json.dumps(resp, default=str, sort_keys=True)[:500]
                ),
                raw=resp,
            )

        audio = output.get("audio")
        audio_url = audio.get("url") if isinstance(audio, dict) else None

        # meta.usage.consume is billed in credits; observed_cost is stored in
        # USD using the fixed _USD_PER_CREDIT rate. When credits are absent
        # or non-numeric, observed_cost stays None. bool is excluded
        # explicitly because it is a subclass of int.
        meta = data.get("meta") or {}
        usage = (meta.get("usage") if isinstance(meta, dict) else None) or {}
        credits = usage.get("consume") if isinstance(usage, dict) else None
        observed_cost = None
        if isinstance(credits, (int, float)) and not isinstance(credits, bool):
            observed_cost = float(credits) * self._USD_PER_CREDIT

        return PollResult(
            status="COMPLETED",
            video_url=video_url,
            audio_url=audio_url,
            observed_cost=observed_cost,
            raw=data,
        )

    def compute_cost(self, duration_s: float, tier: str, profile: dict) -> float:
        """Return the listed cost for *duration_s* seconds on *tier*.

        The per-second rate is looked up at
        ``profile["providers"]["piapi"]["tiers"][tier]["cost_per_second"]``.
        When that is missing, the legacy flat profile keys are used
        (``cost_per_second_piapi_fast`` for tiers containing "fast", else
        ``cost_per_second_piapi_standard``) with defaults of 0.10 and 0.13.
        A ``None`` or empty *profile* therefore yields the default rates.
        """
        # Normalize once so both the nested lookup and the legacy fallback
        # tolerate profile=None.
        profile = profile or {}
        providers = profile.get("providers") or {}
        piapi_block = providers.get("piapi") or {}
        tier_block = (piapi_block.get("tiers") or {}).get(tier) or {}
        rate = tier_block.get("cost_per_second")
        if rate is None:
            # Fallback to legacy flat fields in profile.
            if "fast" in tier:
                rate = profile.get("cost_per_second_piapi_fast", 0.10)
            else:
                rate = profile.get("cost_per_second_piapi_standard", 0.13)
        return float(rate) * float(duration_s)


# Shared module-level adapter instance, created at import time.
# NOTE(review): presumably looked up by the provider registry — confirm.
ADAPTER = PiApiAdapter()

# Names exported by `from <this module> import *`.
__all__ = ["PiApiAdapter", "ADAPTER"]
