"""Element manager for Kling O3 Omni Elements via fal.ai.

Handles character and location element payloads for the multi-prompt I2V pipeline.
Elements are reference images (character identity, location context) that get
injected into the fal.ai request so the model maintains identity and environment
consistency across shots.

fal.ai Elements format:
  elements: [
    {
      "frontal_image_url": "data:image/png;base64,...",
      "reference_image_urls": ["data:image/png;base64,...", ...]
    }
  ]

Prompt injection: reference elements as @Element1, @Element2, @Element3, @Element4
(1-indexed, matching the order of the elements array). fal.ai accepts at most 4
elements per request, or 3 when a start/tail image is also supplied; this module
caps requests at 3 (ElementManager.MAX_ELEMENTS).
Elements can be characters, props, locations, or any object needing consistency.

Priority order when filling the element cap:
  1. Characters and props, in the order given (identity consistency is highest priority)
  2. Location (environment consistency)
"""

from __future__ import annotations

import base64
import logging
from pathlib import Path

from recoil.core.paths import ensure_pipeline_importable

# Runs at import time, before any of the lazy `recoil.pipeline._lib` imports
# inside the methods below execute.
# NOTE(review): presumably this makes recoil.pipeline importable (e.g. via
# sys.path) — confirm in recoil.core.paths.
ensure_pipeline_importable()

# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)

# Supported image MIME types, keyed by lowercase extension without the dot.
# Used as the fallback when magic-byte sniffing finds no known signature, and
# to detect extension/content mismatches.
_MIME_MAP = dict(
    png="image/png",
    jpg="image/jpeg",
    jpeg="image/jpeg",
    webp="image/webp",
)


def _detect_mime(raw: bytes, extension: str) -> str:
    """Sniff an image's MIME type from its leading signature bytes.

    The extension is consulted only when no known signature matches, because
    extensions are frequently wrong (e.g. JPEG data saved as .png).

    Args:
        raw: File contents; only the first 12 bytes are inspected.
        extension: Extension without the leading dot, looked up in
            ``_MIME_MAP`` for the fallback.

    Returns:
        The detected MIME type, or ``"image/png"`` if neither the signature
        nor the extension is recognised.
    """
    head = raw[:12]
    if head.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if head.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    # WebP is a RIFF container: "RIFF" <4-byte size> "WEBP".
    is_webp = head[:4] == b"RIFF" and head[8:12] == b"WEBP"
    if is_webp:
        return "image/webp"
    return _MIME_MAP.get(extension, "image/png")


def _to_data_uri(path: Path) -> str | None:
    """Convert an image file to a base64 data URI (``data:<mime>;base64,...``).

    The MIME type is sniffed from the file's magic bytes via ``_detect_mime``.
    The extension is only a fallback, since files are often saved with the
    wrong extension. A mismatch between extension and content is logged as a
    warning.

    Args:
        path: Image file to encode.

    Returns:
        The data URI, or None if the file cannot be read or is empty. An
        empty payload cannot be a valid image, so it is not sent on.
    """
    try:
        raw = path.read_bytes()
    except OSError as e:
        logger.error("Failed to read ref %s: %s", path, e)
        return None

    # Without this guard an empty file becomes "data:image/png;base64,".
    if not raw:
        logger.error("Ref %s is empty; skipping", path)
        return None

    suffix = path.suffix.lower().lstrip(".")
    mime_type = _detect_mime(raw, suffix)

    if mime_type != _MIME_MAP.get(suffix, "image/png"):
        logger.warning(
            "MIME mismatch: %s has .%s extension but contains %s data",
            path.name, suffix, mime_type,
        )

    encoded = base64.b64encode(raw).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def extract_batch_location(plan_shots: list[dict]) -> str | None:
    """Extract the dominant location_id from a batch of plan shots.

    For multi-shot sequences, picks the most common location across the
    batch. If all shots share the same location, returns that. If mixed,
    returns the most frequent one (environment context for the majority).
    Ties go to the location seen first in the batch, because
    ``Counter.most_common`` orders equal counts by first occurrence.

    Shots whose ``asset_data`` is missing or null, or whose ``location_id``
    is missing or empty, are ignored.

    Args:
        plan_shots: List of plan shot dicts with asset_data.location_id.

    Returns:
        location_id string, or None if no location data found.
    """
    from collections import Counter

    counts = Counter()
    for shot in plan_shots:
        # `or {}` also covers "asset_data": null, which .get(key, {}) does not.
        asset_data = shot.get("asset_data") or {}
        loc = asset_data.get("location_id")
        if loc:
            counts[loc] += 1

    if not counts:
        return None

    # Most common location wins.
    return counts.most_common(1)[0][0]


class ElementManager:
    """Build and manage character + location element payloads for fal.ai.

    Identity elements (characters/props) claim slots first, then an optional
    location element, up to ``MAX_ELEMENTS`` in total. Every image is embedded
    inline as a base64 data URI.
    """

    MAX_ELEMENTS = 3  # With start/tail image, fal.ai caps at 3 (4 for pure T2V)

    def __init__(self, mode: str = "inline", validate_refs: bool = True):
        """Create a manager.

        Args:
            mode: ``"inline"`` or ``"url"``. NOTE(review): the value is only
                stored. Every payload built here uses inline data URIs, so
                ``"url"`` currently behaves like ``"inline"``.
            validate_refs: If True, run ``RefImageCritic`` on refs before use.

        Raises:
            ValueError: If ``mode`` is not ``"inline"`` or ``"url"``.
        """
        if mode not in ("inline", "url"):
            raise ValueError(f"Unsupported mode: {mode!r}. Use 'inline' or 'url'.")
        self.mode = mode
        self.validate_refs = validate_refs

    def _validate_ref(self, ref_path: Path, element_id: str) -> bool:
        """Validate a ref image before use as an element.

        Returns True when validation is disabled, when the ref passes, or when
        the validator itself raises. That last case fails open, so a broken
        critic never blocks generation. Returns False only when the critic
        reports a failure.
        """
        if not self.validate_refs:
            return True
        try:
            from recoil.pipeline._lib.critics.ref_image_critic import RefImageCritic
            # NOTE(review): the critic is always configured for a human
            # character, even for prop and location refs. Confirm this does
            # not reject valid non-character images.
            critic = RefImageCritic(character_type="human")
            _, result = critic.run(str(ref_path))
            if not result.passed:
                failed = [f"{d.name}: {d.message}" for d in result.failed_dimensions]
                logger.warning(
                    "Ref validation FAILED for %s (%s): %s",
                    element_id, ref_path.name, "; ".join(failed),
                )
                return False
            logger.info("Ref validation passed for %s (%s)", element_id, ref_path.name)
            return True
        except Exception as e:
            logger.warning("Ref validation error (non-blocking): %s", e)
            return True  # Fail-open on validator errors

    def _resolve_refs(self, element_id: str, project: str) -> list[Path]:
        """Resolve reference images for a character or prop.

        Search order:
          1. recoil_bridge.get_character_refs() (characters)
          2. <project>/output/refs/props/{id}/, trying {id} as given, then
             lowercased, then uppercased. Inside that directory the frontal is
             the first image whose stem contains ``_front`` (which also covers
             ``_default_front``). It is followed by ``_3q``, ``_side`` and
             ``_back`` angles. If nothing matches, the alphabetically first
             image is used on its own.

        Directory entries are processed in sorted name order so the selected
        refs are deterministic (``Path.iterdir`` order is arbitrary).
        """
        if not element_id:
            return []

        from recoil.pipeline._lib.recoil_bridge import get_character_refs
        from recoil.core.paths import projects_root

        # Try character refs first
        refs = get_character_refs(element_id.lower(), project=project)
        if refs:
            return refs

        props_base = projects_root() / project / "output" / "refs" / "props"
        if not props_base.is_dir():
            return []

        # Find matching prop dir (exact, lowercase, uppercase)
        prop_dir = None
        for candidate in (element_id, element_id.lower(), element_id.upper()):
            if (props_base / candidate).is_dir():
                prop_dir = props_base / candidate
                break

        if prop_dir is None:
            return []

        img_exts = (".png", ".jpg", ".jpeg", ".webp")
        # List the directory once, in a deterministic order.
        images = sorted(
            (f for f in prop_dir.iterdir() if f.is_file() and f.suffix.lower() in img_exts),
            key=lambda p: p.name,
        )

        result: list[Path] = []

        # Frontal first
        frontal = next((f for f in images if "_front" in f.stem.lower()), None)
        if frontal is not None:
            result.append(frontal)

        # Additional angles. Skip files already chosen so that a file matching
        # several suffixes is not sent twice.
        for angle in ("_3q", "_side", "_back"):
            for f in images:
                if angle in f.stem.lower() and f not in result:
                    result.append(f)

        # If no structured refs found, just take the first image
        if not result and images:
            result.append(images[0])

        return result

    def _resolve_location_refs(self, location_id: str, project: str) -> list[Path]:
        """Resolve reference images for a location.

        Tries the SSOT resolver's ``hero`` identity ref first. Falls back to
        ``recoil_bridge.get_location_refs()`` when that is not found.
        """
        from recoil.core.paths import ProjectPaths, RefNotFoundError

        paths = ProjectPaths.for_project(project)
        refs = []
        try:
            ref = paths.resolve_ref("loc", location_id, "identity", variant="hero")
            refs.append(ref.path)
        except (RefNotFoundError, FileNotFoundError):
            pass
        if not refs:
            from recoil.pipeline._lib.recoil_bridge import get_location_refs
            refs = get_location_refs(location_id, project=project)
        return refs

    def _build_element_entry(self, refs: list[Path], element_id: str = "") -> dict | None:
        """Build a single fal.ai element dict from a list of ref image paths.

        First ref = frontal_image_url (hero), rest = reference_image_urls
        (up to 3 additional). Returns None if the frontal can't be encoded or
        fails validation.

        fal.ai requires reference_image_urls to be non-empty. If no
        additional angles exist, the frontal is duplicated as a reference.
        """
        if not refs:
            return None

        # Validate hero ref — if it fails, skip the whole element
        if not self._validate_ref(refs[0], element_id):
            return None

        frontal_uri = _to_data_uri(refs[0])
        if not frontal_uri:
            return None

        reference_uris = []
        for ref_path in refs[1:4]:  # Up to 3 additional angles
            # Validate each additional ref — skip failures but keep the element
            if not self._validate_ref(ref_path, element_id):
                continue
            uri = _to_data_uri(ref_path)
            if uri:
                reference_uris.append(uri)

        # fal.ai rejects elements with empty reference_image_urls
        if not reference_uris:
            reference_uris = [frontal_uri]

        return {
            "frontal_image_url": frontal_uri,
            "reference_image_urls": reference_uris,
        }

    def _collect_elements(
        self,
        element_ids: list[str],
        project: str,
        location_id: str | None,
    ) -> tuple[list[dict], bool]:
        """Resolve, validate and encode elements in a single pass.

        Shared by build_fal_elements() and build_elements_with_info(), so the
        latter does not resolve, validate and encode every ref twice.

        Returns:
            (elements_list, has_location): the element dicts in slot order,
            and whether a location element was appended.
        """
        elements_list: list[dict] = []

        # Phase 1: Characters/props (highest priority — identity consistency).
        # IDs with no usable refs are skipped without consuming a slot, so a
        # later ID can still fill it.
        for element_id in element_ids:
            if len(elements_list) >= self.MAX_ELEMENTS:
                break
            refs = self._resolve_refs(element_id, project)
            if not refs:
                logger.warning("No refs found for %r in project %r", element_id, project)
                continue

            entry = self._build_element_entry(refs, element_id=element_id)
            if entry:
                elements_list.append(entry)

        # Phase 2: Location ref (if slots remain and location_id provided)
        has_location = False
        if location_id and len(elements_list) < self.MAX_ELEMENTS:
            loc_refs = self._resolve_location_refs(location_id, project)
            if loc_refs:
                # Location gets a single frontal image (the best/first ref)
                entry = self._build_element_entry(loc_refs[:1], element_id=location_id)
                if entry:
                    elements_list.append(entry)
                    has_location = True
                    logger.info(
                        "Added location element for %r (%s) — slot %d/%d",
                        location_id, loc_refs[0].name,
                        len(elements_list), self.MAX_ELEMENTS,
                    )
            else:
                logger.debug("No location refs found for %r in project %r", location_id, project)

        return elements_list, has_location

    def build_fal_elements(
        self,
        element_ids: list[str],
        project: str,
        location_id: str | None = None,
    ) -> dict:
        """Build elements payload in fal.ai format.

        Resolves refs for characters, props, AND optionally a location.
        Each element gets a frontal image and up to 3 additional angle refs.

        Priority order (filling up to MAX_ELEMENTS slots):
          1. Characters/props from element_ids, in order (identity first).
             IDs with no usable refs are skipped without using a slot.
          2. Location ref (environment consistency, single image)

        Args:
            element_ids: Character or prop IDs (e.g. ["KIT", "DEAD_COURIER"]).
            project: Project name for ref resolution.
            location_id: Optional location key (e.g. "INT. LEVIATHAN - LOWER DECK").
                If provided and slots remain after characters, adds location as
                an element with its best ref image as frontal_image_url.

        Returns:
            {"elements": [{frontal_image_url, reference_image_urls}, ...]}
            Empty dict if no element could be built.
        """
        elements_list, _ = self._collect_elements(element_ids, project, location_id)
        if not elements_list:
            return {}
        return {"elements": elements_list}

    @staticmethod
    def build_elements_for_fal(
        char_ids: list[str],
        project: str,
        location_id: str | None = None,
    ) -> dict:
        """Convenience: build fal.ai elements payload with optional location.

        Args:
            char_ids: Character/prop IDs for identity elements.
            project: Project name.
            location_id: Optional location key for environment element.
        """
        mgr = ElementManager(mode="inline")
        return mgr.build_fal_elements(char_ids, project, location_id=location_id)

    @staticmethod
    def build_elements_with_info(
        char_ids: list[str],
        project: str,
        location_id: str | None = None,
    ) -> tuple[dict, bool, int]:
        """Build elements payload and return metadata for prompt injection.

        Builds the payload once. Earlier versions built it twice, once without
        the location, just to count character elements.

        Returns:
            (elements_payload, has_location_element, total_elements)

            - elements_payload: The fal.ai payload dict (or empty dict).
            - has_location_element: True if a location element was included.
            - total_elements: Total number of elements in the payload.
        """
        mgr = ElementManager(mode="inline")
        elements_list, has_location = mgr._collect_elements(char_ids, project, location_id)
        payload = {"elements": elements_list} if elements_list else {}
        return payload, has_location, len(elements_list)

    @staticmethod
    def inject_element_refs(
        prompt: str,
        shot_char_ids: list[str],
        batch_char_ids: list[str],
        has_location_element: bool = False,
        total_elements: int = 0,
    ) -> str:
        """Inject @Element references into prompt text.

        Maps shot characters to their @Element index in the batch. Matching is
        case-insensitive; if an ID appears twice in the batch, the first
        occurrence wins.
        Example: batch_char_ids=["KIT", "NAVI"], shot has ["KIT"] -> "@Element1"

        When has_location_element is True, the location element is the last
        element in the payload (after all characters). Its @Element reference
        is always appended so every shot gets environment context.

        NOTE: indices assume every ID in batch_char_ids produced an element,
        in order. If an ID was skipped while building the payload (no refs,
        failed validation), later elements shift down and these references
        will be off by the number of skipped IDs.

        Args:
            prompt: Original prompt text.
            shot_char_ids: Characters in this specific shot.
            batch_char_ids: All characters in the elements payload (order matters).
            has_location_element: If True, a location element follows the characters.
            total_elements: Total number of elements in the payload (chars + location).
                Used to determine the location's @Element index.

        Returns:
            Prompt with @Element references appended.
        """
        if not batch_char_ids and not has_location_element:
            return prompt

        # 1-based element index per uppercased ID; setdefault keeps the first
        # occurrence, matching list.index() semantics.
        index_by_id: dict[str, int] = {}
        for position, batch_id in enumerate(batch_char_ids or [], start=1):
            index_by_id.setdefault(batch_id.upper(), position)

        # Character @Element refs
        refs = [
            f"@Element{index_by_id[char_id.upper()]}"
            for char_id in (shot_char_ids or [])
            if char_id.upper() in index_by_id
        ]

        # Location @Element ref — always last, always included for every shot
        if has_location_element and total_elements > 0:
            refs.append(f"@Element{total_elements}")

        if refs:
            prompt = prompt.rstrip() + " " + " ".join(refs)

        return prompt

    @staticmethod
    def _policy_element_limit(beat_policy) -> int | None:
        """Maximum number of character elements a beat policy allows.

        ENVIRONMENT beats allow none. FACE_DOMINANT beats allow at most one.
        Every other style is capped by ``beat_policy.max_loras``. A None
        result means no cap when used as a slice bound.
        """
        limit = beat_policy.max_loras
        style = getattr(beat_policy, "composition_style", None)
        if style == "ENVIRONMENT":
            return 0
        if style == "FACE_DOMINANT":
            return 1 if limit is None else min(limit, 1)
        return limit

    def build_fal_elements_with_policy(
        self,
        element_ids: list[str],
        project: str,
        location_id: str | None = None,
        beat_policy=None,
    ) -> dict:
        """Build elements payload respecting beat composition policy.

        When beat_policy is provided:
        - Limits character elements to beat_policy.max_loras
        - ENVIRONMENT beats get 0 character elements (location only)
        - FACE_DOMINANT beats get max 1 character element

        Returns the same format as build_fal_elements().
        """
        if beat_policy is None:
            return self.build_fal_elements(element_ids, project, location_id)

        limit = self._policy_element_limit(beat_policy)
        return self.build_fal_elements(element_ids[:limit], project, location_id)

    @staticmethod
    def split_characters_by_policy(
        characters: list[str],
        beat_policy=None,
    ) -> tuple[list[str], list[str]]:
        """Split characters into element refs and text-only refs.

        Returns (element_chars, text_only_chars) based on beat policy, using
        the same limits as build_fal_elements_with_policy(): max_loras, 0 for
        ENVIRONMENT beats, at most 1 for FACE_DOMINANT beats. Characters
        beyond the limit become text-only in the prompt.
        """
        if beat_policy is None:
            return characters, []

        limit = ElementManager._policy_element_limit(beat_policy)
        return characters[:limit], characters[limit:]


def _anchor_text(value) -> str:
    """Render a bible field value as anchor text.

    None becomes "". Strings are returned unchanged. Lists/tuples are
    comma-joined (recursively). Any other scalar (e.g. an int age) goes
    through str().
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ", ".join(_anchor_text(item) for item in value)
    return str(value)


def build_identity_anchor(character: dict) -> str:
    """Build a frozen identity anchor string for a character.

    This string is NEVER modified per-shot. It is copied verbatim into
    every prompt that features this character. The Genra pattern:
    same words every time = only seed variance, no wording variance.

    Format: [age + gender] + [hair] + [skin tone + features] +
            [top clothing] + [bottom clothing] + [signature detail]

    Each field falls back to its alternate key (e.g. ``hair`` ->
    ``hair_description``) when the primary key is missing, null or empty.
    Null values are skipped rather than rendered as "None". List values are
    comma-joined.

    Args:
        character: Character dict from global_bible.json

    Returns:
        Frozen description string, or empty string if character data insufficient.
    """
    parts: list[str] = []

    # Age + gender. Convert to text first so a null field is dropped instead
    # of producing e.g. "None female".
    age_gender = " ".join(
        text
        for text in (_anchor_text(character.get("age")), _anchor_text(character.get("gender")))
        if text
    ).strip()
    if age_gender:
        parts.append(age_gender)

    # Hair
    hair = _anchor_text(character.get("hair") or character.get("hair_description"))
    if hair:
        parts.append(f"hair: {hair}")

    # Skin tone + distinctive features
    skin = _anchor_text(character.get("skin_tone"))
    if skin:
        parts.append(f"skin: {skin}")
    features = _anchor_text(
        character.get("distinctive_features") or character.get("facial_features")
    )
    if features:
        parts.append(features)

    # Clothing: dict with top/bottom, or a plain string / list of items
    clothing = character.get("clothing") or character.get("wardrobe")
    if isinstance(clothing, dict):
        top = _anchor_text(clothing.get("top"))
        bottom = _anchor_text(clothing.get("bottom"))
        if top:
            parts.append(f"wearing {top}")
        if bottom:
            parts.append(bottom)
    else:
        worn = _anchor_text(clothing)
        if worn:
            parts.append(f"wearing {worn}")

    # Signature detail (cross-shot anchor)
    signature = _anchor_text(
        character.get("signature_detail") or character.get("accessories")
    )
    if signature:
        parts.append(f"signature: {signature}")

    return ", ".join(parts)


def get_identity_anchor(global_bible: dict, character_name: str) -> str:
    """Return the identity anchor for ``character_name`` from the global bible.

    The exact key is tried first, then the first key matching
    case-insensitively. A non-empty ``identity_anchor`` stored on the
    character is returned verbatim. Otherwise one is built on the fly with
    ``build_identity_anchor`` and a warning asks for it to be frozen in the
    bible.

    Args:
        global_bible: The full global_bible.json dict; characters are read
            from its ``"characters"`` mapping.
        character_name: Character name to look up.

    Returns:
        The anchor string. Returns "" if the character is unknown (a warning
        is logged) or has too little data to build one.
    """
    roster = global_bible.get("characters", {})

    entry = roster.get(character_name)
    if entry is None:
        # Case-insensitive fallback; the first matching key wins.
        entry = next(
            (data for key, data in roster.items() if key.lower() == character_name.lower()),
            None,
        )

    if entry is None:
        logger.warning("Character '%s' not found in global bible", character_name)
        return ""

    frozen = entry.get("identity_anchor", "")
    if frozen:
        return frozen

    built = build_identity_anchor(entry)
    if built:
        logger.warning(
            "Character '%s' has no frozen identity_anchor in bible. "
            "Built dynamically: '%s'. Freeze this in the bible for consistency.",
            character_name, built[:100]
        )
    return built
