"""Atlas Cloud adapter for seeddance-2.0 (DEPRECATED 2026-04-15).

Preserved for historical runs and emergency re-activation. Not routable
by default — provider_strategy.json does not list atlas as primary or
fallback. Explicit env override (RECOIL_PROVIDER_OVERRIDE=atlas) still
works for one-off diagnostics.

Extracted from SeedDanceClient in the Provider Adapter Refactor
(2026-04-17). Preserves exact wire behavior of the atlas branch.
"""

from __future__ import annotations

import json
import logging
import os
from typing import Optional

from recoil.execution.providers.base import (
    PollRequest,
    PollResult,
    ProviderJob,
    SubmitRequest,
    UnifiedVideoPayload,
)

logger = logging.getLogger(__name__)


# Root of the Atlas Cloud model API; the submit and poll URLs are built on it.
_ATLAS_BASE = "https://api.atlascloud.ai/api/v1/model"

# Internal action key -> Atlas path segment naming the generation mode
# (text-driven, image-driven, or reference-driven video).
_ACTION_MAP = dict(
    t2v="text-to-video",
    i2v="image-to-video",
    r2v="reference-to-video",
)


def _atlas_model_path(tier: str, action: str) -> str:
    """Return the Atlas model path for ``action`` under ``tier``.

    Atlas selects the fast variant through a ``-fast`` suffix on the model
    name rather than through an extra path segment. Any tier whose name
    begins with ``"fast"`` picks that variant.
    """
    model_name = "seedance-2.0-fast" if tier.startswith("fast") else "seedance-2.0"
    return f"bytedance/{model_name}/{action}"


def _atlas_resolution(tier: str, fallback: str) -> str:
    """Pick the output resolution implied by the tier name.

    A tier ending in ``_480p`` or ``_720p`` forces that resolution. Any other
    tier defers to ``fallback``, which is returned unchanged.
    """
    for label in ("480p", "720p"):
        if tier.endswith("_" + label):
            return label
    return fallback


def _infer_action(payload: UnifiedVideoPayload) -> str:
    """Classify the request as ``"i2v"``, ``"r2v"`` or ``"t2v"``.

    A start image (any non-None ``image``) wins over reference images. A
    non-empty ``reference_images`` alone selects reference-to-video.
    Otherwise the request is plain text-to-video.
    """
    if payload.image is None:
        return "r2v" if payload.reference_images else "t2v"
    return "i2v"


class AtlasAdapter:
    """Deprecated Atlas Cloud adapter for the seeddance-2.0 model.

    Builds ``generateVideo`` submit requests and polls
    ``prediction/{id}`` for status. The finished video URL is read out of
    the poll payload itself, because Atlas has no separate result endpoint.
    """

    provider_id = "atlas"
    supported_models = ["seeddance-2.0"]
    auth_env_var = "ATLAS_CLOUD_API_KEY"
    base_url = _ATLAS_BASE
    max_prompt_chars = None  # this adapter declares no prompt-length cap
    status = "deprecated"
    capabilities = {
        "t2v": True,
        "i2v": True,
        "r2v": True,
        "end_frame": True,
        "audio": True,
        "negative_prompt": True,
        "resolution_480p": True,
        "resolution_720p": True,
        "resolution_1080p": False,
    }

    # Per-second rate used by compute_cost when the profile gives no rate
    # for the tier (documented observed floor).
    _FALLBACK_COST_PER_SECOND = 0.225

    @staticmethod
    def _unwrap(resp: Optional[dict]) -> dict:
        """Return the ``data`` envelope of an Atlas response, else ``resp`` itself.

        Atlas wraps payloads as ``{"data": {...}}``. If ``data`` is absent,
        null or not an object, the top-level dict is used instead, so callers
        never call ``.get`` on ``None``. A ``None`` response yields ``{}``.
        """
        if not isinstance(resp, dict):
            return {}
        inner = resp.get("data")
        return inner if isinstance(inner, dict) else resp

    @staticmethod
    def _hosted_url(value: object, field: str) -> Optional[str]:
        """Return ``value`` if it is an ``http``-prefixed URL string, else ``None``.

        Atlas only fetches pre-hosted media, so anything else cannot be sent.
        A non-None value that is dropped is logged, so a silently degraded
        request (e.g. image-to-video without an image) is visible in the logs.
        """
        if value is None:
            return None
        if isinstance(value, str) and value.startswith("http"):
            return value
        logger.warning(
            "atlas adapter: dropping %s that is not an http(s) URL; "
            "Atlas requires pre-hosted media",
            field,
        )
        return None

    def _headers(self) -> dict:
        """Return JSON request headers carrying the bearer token from the env.

        Raises:
            RuntimeError: if the ``ATLAS_CLOUD_API_KEY`` variable is unset or empty.
        """
        key = os.environ.get(self.auth_env_var)
        if not key:
            raise RuntimeError(f"{self.auth_env_var} environment variable not set.")
        return {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
            "User-Agent": "RecoilPipeline/1.0",
        }

    def build_submit(self, payload: UnifiedVideoPayload, tier: str) -> SubmitRequest:
        """Build the ``generateVideo`` POST for ``payload`` at ``tier``.

        Image inputs (start image, end frame, reference images) are forwarded
        only when they are ``http``-prefixed URL strings. Callers must pass
        URLs, or a pre-uploaded fal URL, which Atlas will fetch. Other values
        are dropped with a warning.

        Raises:
            RuntimeError: if the API key environment variable is missing.
        """
        action = _ACTION_MAP[_infer_action(payload)]

        body: dict = {
            "model": _atlas_model_path(tier, action),
            "prompt": payload.prompt,
            # int() truncates fractional seconds (e.g. 5.9 -> 5).
            "duration": int(payload.duration_s),
            "resolution": _atlas_resolution(tier, payload.resolution),
            "ratio": payload.aspect_ratio or "adaptive",
            "generate_audio": bool(payload.generate_audio),
            "watermark": False,
        }

        image_url = self._hosted_url(payload.image, "image")
        if image_url:
            body["image_url"] = image_url
        end_image_url = self._hosted_url(payload.image_tail, "image_tail")
        if end_image_url:
            body["end_image_url"] = end_image_url
        if payload.reference_images:
            candidates = (
                self._hosted_url(ref, "reference image")
                for ref in payload.reference_images
            )
            urls = [u for u in candidates if u]
            if urls:
                body["image_urls"] = urls
        if payload.negative_prompt:
            body["negative_prompt"] = payload.negative_prompt

        url = f"{_ATLAS_BASE}/generateVideo"
        return SubmitRequest(method="POST", url=url, headers=self._headers(), body=body)

    def parse_submit(
        self, resp: dict, payload: UnifiedVideoPayload, tier: str
    ) -> ProviderJob:
        """Turn a ``generateVideo`` response into a trackable ``ProviderJob``.

        The prediction id is read from the ``data`` envelope, falling back to
        the top level.

        Raises:
            RuntimeError: if no prediction id is present. The message includes
                the provider's ``message`` field when available, otherwise the
                raw response.
        """
        data = self._unwrap(resp)
        prediction_id = data.get("id") or resp.get("id")
        if not prediction_id:
            detail = data.get("message") or resp.get("message") or resp
            raise RuntimeError(
                f"atlas adapter: no prediction id in submit response: {detail}"
            )
        action = _ACTION_MAP[_infer_action(payload)]
        return ProviderJob(
            provider_id=self.provider_id,
            model_id=payload.model_id,
            native_id=prediction_id,
            tier=tier,
            duration_s=payload.duration_s,
            resolution=_atlas_resolution(tier, payload.resolution),
            native_state={"action": action},
        )

    def build_poll(self, job: ProviderJob) -> PollRequest:
        """Build the GET request for the job's ``prediction/{id}`` status."""
        url = f"{_ATLAS_BASE}/prediction/{job.native_id}"
        return PollRequest(method="GET", url=url, headers=self._headers())

    def parse_poll(self, resp: dict, job: ProviderJob) -> PollResult:
        """Map an Atlas status payload onto COMPLETED / FAILED / IN_PROGRESS.

        The returned ``raw`` is the unwrapped ``data`` payload. Atlas includes
        the outputs in the same payload, so ``parse_result`` can read them
        from it.
        """
        data = self._unwrap(resp)
        state = str(data.get("status") or "unknown").lower()
        if state in ("completed", "succeeded"):
            return PollResult(status="COMPLETED", raw=data)
        if state == "failed":
            return PollResult(
                status="FAILED",
                error=data.get("message") or "Atlas Cloud generation failed",
                raw=data,
            )
        # Any other status, including a missing one, keeps the job polling.
        # NOTE(review): a terminal state such as "cancelled" would also land
        # here; confirm Atlas's full status vocabulary.
        return PollResult(status="IN_PROGRESS", raw=data)

    def build_result_fetch(self, job: ProviderJob) -> Optional[PollRequest]:
        """Return ``None``: Atlas returns the result with the status call."""
        return None

    def parse_result(self, resp: dict, job: ProviderJob) -> PollResult:
        """Extract the video URL from a completed Atlas payload.

        This is called by the registry/client when ``build_result_fetch``
        returns ``None``. ``resp`` is then the poll response already in hand,
        either enveloped or already unwrapped; both shapes are accepted. A
        missing video URL is reported as FAILED rather than raised.
        """
        data = self._unwrap(resp)
        outputs = data.get("outputs") or []
        if isinstance(outputs, str):
            # Guard: indexing a bare URL string would yield its first character.
            outputs = [outputs]
        video_url = outputs[0] if outputs else None
        if not video_url:
            return PollResult(
                status="FAILED",
                error=(
                    "provider returned COMPLETED with no video_url: "
                    + json.dumps(resp, default=str, sort_keys=True)[:500]
                ),
                raw=resp,
            )
        return PollResult(
            status="COMPLETED",
            video_url=video_url,
            audio_url=None,
            observed_cost=None,
            raw=data,
        )

    def compute_cost(self, duration_s: float, tier: str, profile: dict) -> float:
        """Estimate the job cost as per-second rate * ``duration_s``.

        The rate is looked up at ``profile["providers"]["atlas"]["tiers"][tier]``.
        ``cost_per_second_observed`` is preferred over ``cost_per_second_listed``.
        If neither is set, ``_FALLBACK_COST_PER_SECOND`` is used. Missing or
        null levels in ``profile`` are treated as empty. The currency unit
        comes from the profile (presumably USD — confirm).
        """
        providers = (profile or {}).get("providers") or {}
        atlas_block = providers.get("atlas") or {}
        tier_block = (atlas_block.get("tiers") or {}).get(tier) or {}
        rate = tier_block.get("cost_per_second_observed") or tier_block.get(
            "cost_per_second_listed"
        )
        if rate is None:
            rate = self._FALLBACK_COST_PER_SECOND
        return float(rate) * float(duration_s)


# Module-level adapter instance exported for callers (presumably looked up by
# the provider registry via this name — confirm against the registry).
ADAPTER = AtlasAdapter()

__all__ = ["AtlasAdapter", "ADAPTER"]
