"""Local ComfyUI provider adapter — Krea 2 Turbo (bf16) image generation.

Unlike fal/flora (hosted queue APIs), this adapter talks to a LOCAL ComfyUI
server over HTTP (POST /prompt, poll /history/{id}, fetch /view). It exists so
the pipeline can route image keyframes to a self-hosted, uncensored open model
for graphic dramatic shots that hosted providers refuse.

Image generation is synchronous from the caller's view: ``execute_keyframe``
dispatches via :meth:`direct_submit_image` (mirroring fal/flora/google), which
builds the Krea 2 workflow, submits it, polls until the prompt executes, and
reads the rendered PNG bytes back. The six async-video Protocol methods
(build/parse for submit, poll and result fetch) are not used on the image path
and raise NotImplementedError.

Krea 2 specifics (validated empirically 2026-06-24, see
consultations/recoil/krea2-oss-local-image-2026-06-24/SPIKE_AND_ADAPTER.md):
- bf16 weights ONLY on Apple Silicon (fp8 dies on MPS).
- Graph: UNETLoader(bf16) -> CLIPLoader(type "krea2") -> CLIPTextEncode ->
  ConditioningKrea2Rebalance(mult) -> KSampler(8, cfg 1, euler, simple) ->
  VAEDecode -> SaveImage. The rebalance node bypasses Krea 2's trained safety
  dilution; without it graphic content is softened. The official template's
  prompt_enhance LLM subgraph is intentionally NOT used.

Env:
- COMFYUI_URL — base URL of the ComfyUI server (default http://127.0.0.1:8188).
Cost: local inference is unmetered -> cost_usd = 0.0.
"""

from __future__ import annotations

import json
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from typing import Optional

from recoil.execution.providers.base import (
    CAPABILITY_KEYS,
    PollRequest,
    PollResult,
    ProviderJob,
    SubmitRequest,
    UnifiedVideoPayload,
)
from recoil.execution.providers.payload_hints import coerce_to_dict

# Krea 2 Turbo asset filenames used by the workflow graph (SPIKE doc §1).
# The ComfyUI server is expected to have them installed under
# models/diffusion_models (UNet), models/text_encoders (CLIP) and
# models/vae (VAE) respectively.
_UNET: str = "krea2_turbo_bf16.safetensors"
_CLIP: str = "qwen3vl_4b_fp8_scaled.safetensors"
_VAE: str = "qwen_image_vae.safetensors"
# Default per_layer_weights string handed to ConditioningKrea2Rebalance when
# the caller supplies no "rebalance_weights" hint.
_REBALANCE_DEFAULT_WEIGHTS: str = "1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.5,5.0,1.1,4.0,1.0"

# Supported aspect_ratio strings mapped to (width, height) in pixels. Every
# entry is roughly one megapixel and both sides are multiples of 16.
_AR_DIMS: dict[str, tuple[int, int]] = {
    "1:1": (1024, 1024),
    "9:16": (768, 1344),
    "16:9": (1344, 768),
}

# Polling budget for a submitted prompt: overall deadline (seconds) and the
# pause between successive /history checks (seconds).
_POLL_TIMEOUT_S: int = 600
_POLL_INTERVAL_S: float = 2.0


class ComfyUIAdapter:
    """Krea 2 Turbo image generation against a local ComfyUI HTTP server.

    Only the synchronous image path (:meth:`direct_submit_image`) is
    implemented. The async-video Protocol methods raise NotImplementedError.
    Local inference is unmetered, so all reported costs are 0.0.
    """

    provider_id = "comfyui"
    supported_models = ["krea-2-turbo"]
    auth_env_var = "COMFYUI_URL"  # not a secret; the endpoint URL
    # Resolved once, at import time. ``or`` (rather than a .get default) so an
    # exported-but-empty COMFYUI_URL still falls back to localhost instead of
    # producing relative URLs that urlopen cannot open.
    base_url = os.environ.get("COMFYUI_URL") or "http://127.0.0.1:8188"
    max_prompt_chars = None
    status = "primary"

    # Honest capability surface: Krea 2 Turbo runs at cfg 1.0 (8-step distilled),
    # so a negative prompt is inert — NOT advertised. Image dims derive from
    # aspect_ratio, not a resolution tier; execute_keyframe passes
    # resolution="720p", so only resolution_720p is claimed (the others would be
    # silently ignored — see base.py "never silently degrade").
    capabilities = {k: False for k in CAPABILITY_KEYS}
    capabilities.update(
        {
            "t2v": True,  # text-to-image (keyframe t2i routes through the "t2v" key)
            "resolution_720p": True,
        }
    )

    # ------------------------------------------------------------------
    # Workflow construction
    # ------------------------------------------------------------------
    def _build_workflow(
        self,
        prompt: str,
        *,
        seed: int,
        width: int,
        height: int,
        rebalance_multiplier: float,
        rebalance_weights: str,
        filename_prefix: str,
    ) -> dict:
        """ComfyUI API-format graph for Krea 2 Turbo with the rebalance node.

        cfg is fixed at 1.0 (Turbo is 8-step distilled), so the negative
        conditioning is inert — it is always ConditioningZeroOut. The adapter
        does not advertise the negative_prompt capability for this reason.
        The rebalance node (id "13") is only inserted when the multiplier is
        non-zero and differs from 1.0; otherwise the text encoding feeds the
        sampler directly.
        """
        g: dict = {
            "10": {"class_type": "UNETLoader",
                   "inputs": {"unet_name": _UNET, "weight_dtype": "default"}},
            "11": {"class_type": "CLIPLoader",
                   "inputs": {"clip_name": _CLIP, "type": "krea2"}},
            "12": {"class_type": "VAELoader", "inputs": {"vae_name": _VAE}},
            "6": {"class_type": "CLIPTextEncode",
                  "inputs": {"text": prompt, "clip": ["11", 0]}},
            "5": {"class_type": "EmptyLatentImage",
                  "inputs": {"width": width, "height": height, "batch_size": 1}},
            "7": {"class_type": "ConditioningZeroOut",
                  "inputs": {"conditioning": ["6", 0]}},
        }
        # Rebalance: bypass Krea 2's trained safety dilution.
        positive = ["6", 0]
        if rebalance_multiplier and rebalance_multiplier != 1.0:
            g["13"] = {"class_type": "ConditioningKrea2Rebalance",
                       "inputs": {"conditioning": ["6", 0],
                                  "multiplier": float(rebalance_multiplier),
                                  "per_layer_weights": rebalance_weights}}
            positive = ["13", 0]
        g["3"] = {"class_type": "KSampler",
                  "inputs": {"seed": int(seed), "steps": 8, "cfg": 1.0,
                             "sampler_name": "euler", "scheduler": "simple",
                             "denoise": 1.0, "model": ["10", 0],
                             "positive": positive, "negative": ["7", 0],
                             "latent_image": ["5", 0]}}
        g["8"] = {"class_type": "VAEDecode",
                  "inputs": {"samples": ["3", 0], "vae": ["12", 0]}}
        g["9"] = {"class_type": "SaveImage",
                  "inputs": {"filename_prefix": filename_prefix, "images": ["8", 0]}}
        return g

    @staticmethod
    def _filename_prefix(shot_id: Optional[str]) -> str:
        """Build the SaveImage filename_prefix from the shot id.

        Any character other than letters, digits, "-", "_" and "." becomes
        "_". ComfyUI interprets path separators in filename_prefix as
        subfolders and refuses prefixes that resolve outside its output
        directory.
        """
        stem = str(shot_id) if shot_id else "krea2"
        safe = "".join(c if (c.isalnum() or c in "-_.") else "_" for c in stem)
        return safe + "_comfyui"

    # ------------------------------------------------------------------
    # HTTP helpers (adapters own I/O)
    # ------------------------------------------------------------------
    def _url(self, path: str) -> str:
        """Return the absolute URL for ``path`` (which must start with "/").

        A trailing slash on base_url is dropped, so a COMFYUI_URL such as
        "http://host:8188/" does not yield "http://host:8188//prompt".
        """
        return self.base_url.rstrip("/") + path

    def _post_json(self, path: str, body: dict) -> dict:
        """POST ``body`` as JSON to ``path`` and return the decoded JSON reply."""
        req = urllib.request.Request(
            self._url(path),
            data=json.dumps(body).encode(),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read().decode())

    def _get_json(self, path: str) -> dict:
        """GET ``path`` and return the decoded JSON reply."""
        with urllib.request.urlopen(self._url(path), timeout=30) as resp:
            return json.loads(resp.read().decode())

    def _get_bytes(self, path: str) -> bytes:
        """GET ``path`` and return the raw response body."""
        with urllib.request.urlopen(self._url(path), timeout=120) as resp:
            return resp.read()

    @staticmethod
    def _http_error_body(err: urllib.error.HTTPError, limit: int = 800) -> str:
        """Best-effort, length-capped text of an HTTP error response body.

        Decoding uses errors="replace" so a non-UTF-8 error page cannot mask
        the original HTTP failure with a UnicodeDecodeError.
        """
        try:
            raw = err.read()
        except OSError:
            return ""
        return (raw or b"").decode("utf-8", errors="replace")[:limit]

    # ------------------------------------------------------------------
    # Synchronous image dispatch — the live path execute_keyframe uses.
    # ------------------------------------------------------------------
    def direct_submit_image(self, payload: UnifiedVideoPayload) -> dict:
        """Build + submit the Krea 2 graph, poll, return rendered PNG bytes.

        Returns {"image_bytes", "cost_usd", "native_id"} for execute_keyframe
        (which reads image_bytes + cost_usd at step_runner.py:1952/1958).

        Hints read from ``payload.hints``: width/height (both required to
        override the aspect-ratio dims), seed (a time-derived value is used
        when absent), rebalance_multiplier (default 4.0) and
        rebalance_weights.

        Raises:
            RuntimeError: ComfyUI rejected the graph, was unreachable,
                reported an execution error, finished without an image, or
                served an empty/failed /view response.
            TimeoutError: the prompt did not finish within _POLL_TIMEOUT_S.
        """
        hints = coerce_to_dict(payload.hints) or {}
        ar = payload.aspect_ratio or "9:16"
        # NOTE(review): unknown aspect ratios silently fall back to 9:16.
        width, height = _AR_DIMS.get(ar, _AR_DIMS["9:16"])
        if hints.get("width") and hints.get("height"):
            width, height = int(hints["width"]), int(hints["height"])
        seed = int(hints.get("seed") if hints.get("seed") is not None
                   else (int(time.time() * 1000) & 0x7FFFFFFF))
        multiplier = float(hints.get("rebalance_multiplier", 4.0))
        weights = hints.get("rebalance_weights") or _REBALANCE_DEFAULT_WEIGHTS

        workflow = self._build_workflow(
            payload.prompt, seed=seed, width=width, height=height,
            rebalance_multiplier=multiplier, rebalance_weights=weights,
            filename_prefix=self._filename_prefix(payload.shot_id),
        )
        prompt_id = self._submit(workflow)
        img = self._wait_for_image(prompt_id)
        image_bytes = self._fetch_image(img, prompt_id)
        return {"image_bytes": image_bytes, "cost_usd": 0.0,
                "native_id": prompt_id}

    def _submit(self, workflow: dict) -> str:
        """POST the graph to /prompt and return ComfyUI's prompt_id."""
        try:
            submit = self._post_json("/prompt", {"prompt": workflow})
        except urllib.error.HTTPError as e:
            raise RuntimeError(
                f"ComfyUI rejected the workflow ({e.code}): "
                f"{self._http_error_body(e)}"
            ) from e
        except urllib.error.URLError as e:
            raise RuntimeError(
                f"ComfyUI unreachable at {self.base_url} ({e}). "
                f"Is the server running? Set COMFYUI_URL."
            ) from e

        prompt_id = submit.get("prompt_id") if isinstance(submit, dict) else None
        if not prompt_id:
            raise RuntimeError(f"ComfyUI /prompt returned no prompt_id: {submit}")
        return prompt_id

    def _wait_for_image(self, prompt_id: str) -> dict:
        """Poll /history until the prompt yields an image record.

        Returns the first image descriptor ({"filename", "subfolder",
        "type"}) found in the prompt's outputs. Fails fast when the history
        record reports an error, or reports completion without any image,
        instead of waiting out the full deadline.
        """
        deadline = time.monotonic() + _POLL_TIMEOUT_S
        while True:
            try:
                hist = self._get_json(f"/history/{prompt_id}")
            except OSError as e:  # URLError/HTTPError and socket timeouts
                raise RuntimeError(
                    f"ComfyUI polling failed for prompt {prompt_id} at "
                    f"{self.base_url} ({e})"
                ) from e
            rec = hist.get(prompt_id)
            if rec is not None:
                status = rec.get("status") or {}
                if status.get("status_str") == "error":
                    msg = self._extract_error(rec)
                    raise RuntimeError(f"ComfyUI execution failed: {msg}")
                img = self._first_image(rec)
                if img is not None:
                    return img
                # Finished but nothing to fetch: further polling cannot help.
                if status.get("completed"):
                    outputs = sorted((rec.get("outputs") or {}).keys())
                    raise RuntimeError(
                        f"ComfyUI prompt {prompt_id} completed but produced no "
                        f"image (output nodes: {outputs})"
                    )
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"ComfyUI prompt {prompt_id} did not finish within "
                    f"{_POLL_TIMEOUT_S}s"
                )
            time.sleep(_POLL_INTERVAL_S)

    def _fetch_image(self, img: dict, prompt_id: str) -> bytes:
        """Download an output image via /view and return its non-empty bytes."""
        qs = urllib.parse.urlencode(
            {"filename": img["filename"],
             "subfolder": img.get("subfolder", ""),
             "type": img.get("type", "output")}
        )
        try:
            image_bytes = self._get_bytes("/view?" + qs)
        except OSError as e:  # URLError/HTTPError and socket timeouts
            raise RuntimeError(
                f"ComfyUI /view failed for prompt {prompt_id} ({e})"
            ) from e
        if not image_bytes:
            raise RuntimeError(f"ComfyUI /view returned no bytes for {img}")
        return image_bytes

    @staticmethod
    def _first_image(rec: dict) -> Optional[dict]:
        """Return the first image descriptor in a history record, or None."""
        for node_out in (rec.get("outputs") or {}).values():
            images = node_out.get("images") if isinstance(node_out, dict) else None
            if images:
                return images[0]
        return None

    @staticmethod
    def _extract_error(rec: dict) -> str:
        """Summarize a failed history record as "<node_type>: <message>".

        Malformed message entries are skipped rather than raising. When no
        execution_error is present, an interruption is reported if one was
        recorded; otherwise "unknown error".
        """
        messages = (rec.get("status") or {}).get("messages") or []
        interrupted = False
        for m in messages:
            if not isinstance(m, (list, tuple)) or len(m) < 2:
                continue
            kind, detail = m[0], m[1]
            if kind == "execution_error" and isinstance(detail, dict):
                return f"{detail.get('node_type')}: {detail.get('exception_message')}"
            if kind == "execution_interrupted":
                interrupted = True
        return "execution interrupted" if interrupted else "unknown error"

    # ------------------------------------------------------------------
    # compute_cost — local inference is free.
    # ------------------------------------------------------------------
    def compute_cost(self, duration_s: float, tier: str, profile: dict) -> float:
        return 0.0

    # ------------------------------------------------------------------
    # Async-video Protocol surface — unused on the image path.
    # ------------------------------------------------------------------
    def build_submit(self, payload: UnifiedVideoPayload, tier: str) -> SubmitRequest:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")

    def parse_submit(self, resp: dict, payload: UnifiedVideoPayload, tier: str) -> ProviderJob:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")

    def build_poll(self, job: ProviderJob) -> PollRequest:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")

    def parse_poll(self, resp: dict, job: ProviderJob) -> PollResult:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")

    def build_result_fetch(self, job: ProviderJob) -> Optional[PollRequest]:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")

    def parse_result(self, resp: dict, job: ProviderJob) -> PollResult:
        raise NotImplementedError("ComfyUIAdapter is image-only (direct_submit_image)")


# Module-level adapter instance exported for the provider registry to use.
ADAPTER: ComfyUIAdapter = ComfyUIAdapter()
