"""Element manager for Kling O3 Omni Elements via fal.ai.

Handles character and location element payloads for the multi-prompt I2V pipeline.
Elements are reference images (character identity, location context) that get
injected into the fal.ai request so the model maintains identity and environment
consistency across shots.

fal.ai Elements format:
  elements: [
    {
      "frontal_image_url": "data:image/png;base64,...",
      "reference_image_urls": ["data:image/png;base64,...", ...]
    }
  ]

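Prompt injection: reference elements as @Element1, @Element2, @Element3, @Element4
(1-indexed, matches order in the elements array). fal.ai accepts at most 4 elements
per request for pure T2V and 3 when a start/tail image is supplied; this module
enforces the latter via ElementManager.MAX_ELEMENTS.
Elements can be characters, props, locations, or any object needing consistency.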

Priority order when filling the element cap:
  1. Characters (identity consistency is highest priority)
  2. Location (environment consistency)
  3. Props
"""

from __future__ import annotations

import base64
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Supported image MIME types, keyed by lower-case file extension without the
# leading dot. Used as the fallback when magic-byte sniffing finds no match.
_MIME_MAP = {
    "png": "image/png",
    "jpg": "image/jpeg",
    "jpeg": "image/jpeg",
    "webp": "image/webp",
}


def _detect_mime(raw: bytes, extension: str) -> str:
    """Detect actual MIME type from file magic bytes, ignoring extension.

    File extensions lie (e.g. JPEG saved as .png). Magic bytes don't.
    Falls back to extension-based lookup if magic bytes don't match.
    """
    if raw[:3] == b"\xff\xd8\xff":
        return "image/jpeg"
    if raw[:8] == b"\x89PNG\r\n\x1a\n":
        return "image/png"
    if raw[:4] == b"RIFF" and raw[8:12] == b"WEBP":
        return "image/webp"
    # Fallback to extension
    return _MIME_MAP.get(extension, "image/png")


def _to_data_uri(path: Path) -> str | None:
    """Encode an image file as a ``data:<mime>;base64,<payload>`` URI.

    The MIME type placed in the URI comes from ``_detect_mime``, which sniffs
    the file's magic bytes first and only falls back to the extension. A
    warning is logged when that result differs from what the extension alone
    would map to.

    Args:
        path: Location of the image file to read.

    Returns:
        The data URI string, or None if the file could not be read (the
        failure is logged at error level).
    """
    try:
        payload = path.read_bytes()
    except OSError as e:
        logger.error("Failed to read ref %s: %s", path, e)
        return None

    # Normalise ".PNG" -> "png" so it lines up with the _MIME_MAP keys.
    ext = path.suffix.lower().lstrip(".")
    sniffed_mime = _detect_mime(payload, ext)
    extension_mime = _MIME_MAP.get(ext, "image/png")

    if sniffed_mime != extension_mime:
        logger.warning(
            "MIME mismatch: %s has .%s extension but contains %s data",
            path.name, ext, sniffed_mime,
        )

    b64_text = base64.b64encode(payload).decode("ascii")
    return "data:" + sniffed_mime + ";base64," + b64_text


def extract_batch_location(plan_shots: list[dict]) -> str | None:
    """Extract the dominant location_id from a batch of plan shots.

    Counts the ``asset_data.location_id`` value of every shot and returns the
    most frequent one. When several locations are tied, the one that appears
    first in the batch wins.

    Shots without ``asset_data``, with ``asset_data`` set to None, or with a
    missing/empty ``location_id`` are ignored rather than raising.

    Args:
        plan_shots: List of plan shot dicts with asset_data.location_id.
            An empty list (or None) yields None.

    Returns:
        location_id string, or None if no location data found.
    """
    from collections import Counter

    tally = Counter()
    for shot in plan_shots or ():
        # `or {}` also covers an explicit `"asset_data": None`, which a
        # `.get("asset_data", {})` default would let through as None.
        asset_data = shot.get("asset_data") or {}
        location_id = asset_data.get("location_id")
        if location_id:
            tally[location_id] += 1

    if not tally:
        return None

    # Most common location wins.
    return tally.most_common(1)[0][0]


class ElementManager:
    """Build and manage character + location element payloads for fal.ai.

    Elements are emitted in priority order: the ids passed as ``element_ids``
    first, then an optional location. The payload is capped at
    ``MAX_ELEMENTS`` entries.
    """

    MAX_ELEMENTS = 3  # With start/tail image, fal.ai caps at 3 (4 for pure T2V)

    # Additional ref keys (beyond "hero") collected per element type, in
    # priority order. Types not listed here contribute only their hero ref.
    _EXTRA_REF_KEYS = {
        "characters": ("front", "three_quarter", "profile", "back"),
        "locations": ("wide", "medium", "closeup"),
    }

    def __init__(self, mode: str = "inline"):
        """Create a manager.

        Args:
            mode: Either "inline" or "url". The value is validated and stored
                on the instance; the methods of this class always emit inline
                base64 data URIs and do not read it.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode not in ("inline", "url"):
            raise ValueError(f"Unsupported mode: {mode!r}. Use 'inline' or 'url'.")
        self.mode = mode

    def _resolve_refs(
        self,
        element_id: str,
        project: str,
        element_type: str,
    ) -> list[Path]:
        """Resolve refs via the canonical entry point. element_type required.

        Args:
            element_id: Id of the character, prop or location.
            project: Project name for ref resolution.
            element_type: "characters", "props" or "locations"; selects which
                extra ref keys are collected after the hero.

        Returns:
            Paths in priority order: hero first, then the extra angles/views
            for the element type, without duplicates. Returns [] (and logs a
            warning) if MissingCanonicalRefsError is raised.
        """
        from recoil.core.ref_resolver import get_element_refs, MissingCanonicalRefsError
        try:
            refs_dict = get_element_refs(element_id, project, element_type)
        except MissingCanonicalRefsError:
            logger.warning("No canonical refs for %s/%s in %s", element_type, element_id, project)
            return []

        ordered: list[Path] = []
        if "hero" in refs_dict:
            ordered.append(refs_dict["hero"])
        for key in self._EXTRA_REF_KEYS.get(element_type, ()):
            candidate = refs_dict.get(key)
            # Skip absent keys and paths already present (e.g. same as hero).
            if candidate and candidate not in ordered:
                ordered.append(candidate)
        return ordered

    def _resolve_location_refs(self, location_id: str, project: str) -> list[Path]:
        """Resolve refs for a location; shorthand for the "locations" type."""
        return self._resolve_refs(location_id, project, "locations")

    def _build_element_entry(self, refs: list[Path]) -> dict | None:
        """Build a single fal.ai element dict from a list of ref image paths.

        First ref = frontal_image_url (hero), rest = reference_image_urls
        (up to 3 additional). Additional refs that fail to encode are dropped.

        fal.ai requires reference_image_urls to be non-empty, so if no
        additional ref could be encoded the frontal is duplicated as the
        single reference.

        Args:
            refs: Image paths in priority order.

        Returns:
            {"frontal_image_url": ..., "reference_image_urls": [...]}, or
            None if ``refs`` is empty or the frontal image can't be encoded.
        """
        if not refs:
            return None

        frontal_uri = _to_data_uri(refs[0])
        if not frontal_uri:
            return None

        reference_uris = []
        for ref_path in refs[1:4]:  # Up to 3 additional angles
            uri = _to_data_uri(ref_path)
            if uri:
                reference_uris.append(uri)

        # fal.ai rejects elements with empty reference_image_urls
        if not reference_uris:
            reference_uris = [frontal_uri]

        return {
            "frontal_image_url": frontal_uri,
            "reference_image_urls": reference_uris,
        }

    def _collect_elements(
        self,
        element_ids: list[str],
        project: str,
        location_id: str | None = None,
        element_types: dict[str, str] | None = None,
    ) -> tuple[list[dict], bool]:
        """Resolve and encode all elements in a single pass.

        Shared worker behind ``build_fal_elements`` and
        ``build_elements_with_info`` so images are read and encoded only once.

        Returns:
            (elements, location_added): the element dicts in payload order and
            whether the last one is the location element.
        """
        elements: list[dict] = []

        # Phase 1: Characters/props (highest priority: identity consistency).
        # Only the first MAX_ELEMENTS ids are considered; an id that yields no
        # usable refs is skipped and its slot is not back-filled by later ids.
        for element_id in element_ids[:self.MAX_ELEMENTS]:
            if element_types is None:
                element_type = "characters"
            else:
                element_type = element_types.get(element_id, "characters")

            refs = self._resolve_refs(element_id, project, element_type)
            if not refs:
                logger.warning("No refs found for %r in project %r", element_id, project)
                continue

            entry = self._build_element_entry(refs)
            if entry:
                elements.append(entry)

        # Phase 2: Location ref (if slots remain and location_id provided)
        location_added = False
        if location_id and len(elements) < self.MAX_ELEMENTS:
            loc_refs = self._resolve_location_refs(location_id, project)
            if loc_refs:
                # Location gets a single frontal image (the best/first ref)
                entry = self._build_element_entry(loc_refs[:1])
                if entry:
                    elements.append(entry)
                    location_added = True
                    logger.info(
                        "Added location element for %r (%s) — slot %d/%d",
                        location_id, loc_refs[0].name,
                        len(elements), self.MAX_ELEMENTS,
                    )
            else:
                logger.debug("No location refs found for %r in project %r", location_id, project)

        return elements, location_added

    def build_fal_elements(
        self,
        element_ids: list[str],
        project: str,
        location_id: str | None = None,
        element_types: dict[str, str] | None = None,
    ) -> dict:
        """Build elements payload in fal.ai format.

        Resolves refs for characters, props, AND optionally a location.
        Each element gets a frontal image and up to 3 additional angle refs.

        Priority order (filling up to MAX_ELEMENTS element slots):
          1. Characters/props from element_ids (identity first)
          2. Location ref (environment consistency, single image)

        Args:
            element_ids: Character or prop IDs (e.g. ["KIT", "DEAD_COURIER"]).
                Only the first MAX_ELEMENTS ids are considered.
            project: Project name for ref resolution.
            location_id: Optional location key (e.g. "INT. LEVIATHAN - LOWER DECK").
                If provided and slots remain after characters, adds location as
                an element with its best ref image as frontal_image_url.
            element_types: Optional mapping of element_id -> element_type
                ("characters" | "props" | "locations"). When None, every entry
                in element_ids is treated as "characters"; ids missing from
                the mapping also default to "characters".

        Returns:
            {"elements": [{frontal_image_url, reference_image_urls}, ...]}
            Empty dict if no refs found.
        """
        elements, _ = self._collect_elements(
            element_ids, project, location_id, element_types,
        )
        if not elements:
            return {}
        return {"elements": elements}

    @staticmethod
    def build_elements_for_fal(
        char_ids: list[str],
        project: str,
        location_id: str | None = None,
        element_types: dict[str, str] | None = None,
    ) -> dict:
        """Convenience: build fal.ai elements payload with optional location.

        Args:
            char_ids: Character/prop IDs for identity elements.
            project: Project name.
            location_id: Optional location key for environment element.
            element_types: Optional element_id -> element_type mapping,
                forwarded to ``build_fal_elements``.

        Returns:
            Same payload dict as ``build_fal_elements``.
        """
        mgr = ElementManager(mode="inline")
        return mgr.build_fal_elements(
            char_ids, project, location_id=location_id, element_types=element_types,
        )

    @staticmethod
    def build_elements_with_info(
        char_ids: list[str],
        project: str,
        location_id: str | None = None,
        element_types: dict[str, str] | None = None,
    ) -> tuple[dict, bool, int]:
        """Build elements payload and return metadata for prompt injection.

        The payload is built in a single pass, so each ref image is read and
        encoded once.

        Args:
            char_ids: Character/prop IDs for identity elements.
            project: Project name.
            location_id: Optional location key for environment element.
            element_types: Optional element_id -> element_type mapping; when
                None every id is treated as "characters".

        Returns:
            (elements_payload, has_location_element, total_elements)

            - elements_payload: The fal.ai payload dict (or empty dict).
            - has_location_element: True if a location element was included.
            - total_elements: Total number of elements in the payload.
        """
        mgr = ElementManager(mode="inline")
        elements, has_location = mgr._collect_elements(
            char_ids, project, location_id, element_types,
        )
        payload = {"elements": elements} if elements else {}
        return payload, has_location, len(elements)

    @staticmethod
    def inject_element_refs(
        prompt: str,
        shot_char_ids: list[str],
        batch_char_ids: list[str],
        has_location_element: bool = False,
        total_elements: int = 0,
    ) -> str:
        """Inject @Element references into prompt text.

        Maps shot characters to their @Element index in the batch.
        Example: batch_char_ids=["KIT", "NAVI"], shot has ["KIT"] -> "@Element1"
        Ids are compared case-insensitively; shot characters that are not in
        the batch are ignored.

        When has_location_element is True, the location element is the last
        element in the payload (after all characters). Its @Element reference
        is always appended so every shot gets environment context.

        NOTE: the index is purely positional. batch_char_ids must list exactly
        the characters that made it into the elements payload, in payload
        order; ids that were skipped while building the payload would shift
        every later index.

        Args:
            prompt: Original prompt text.
            shot_char_ids: Characters in this specific shot.
            batch_char_ids: All characters in the elements payload (order matters).
            has_location_element: If True, a location element follows the characters.
            total_elements: Total number of elements in the payload (chars + location).
                Used to determine the location's @Element index.

        Returns:
            Prompt with @Element references appended (trailing whitespace of
            the original prompt is stripped first), or the prompt unchanged
            if there is nothing to reference.
        """
        if not batch_char_ids and not has_location_element:
            return prompt

        refs = []

        # Character @Element refs
        if batch_char_ids and shot_char_ids:
            # 1-based slot per normalised id; the first occurrence wins, which
            # matches list.index() semantics.
            slot_by_id: dict[str, int] = {}
            for slot, batch_id in enumerate(batch_char_ids, start=1):
                slot_by_id.setdefault(batch_id.upper(), slot)

            for char_id in shot_char_ids:
                slot = slot_by_id.get(char_id.upper())
                if slot is not None:
                    refs.append(f"@Element{slot}")

        # Location @Element ref: always last, always included for every shot
        if has_location_element and total_elements > 0:
            refs.append(f"@Element{total_elements}")

        if refs:
            prompt = prompt.rstrip() + " " + " ".join(refs)

        return prompt
