"""cinema_loader.py — Load and resolve Cinema Modes (CINEMA_MODES.yaml).

Single source of truth for cinematography catalogs (camera bodies, lens
systems, filtration, film stocks, texture grain, color grades) and the
named modes that compose them by reference.

Provides:
  * load_cinema_modes() — parse + validate + cache the YAML.
  * reload_cinema_modes() — clear cache (for tests + dev reloads).
  * resolve_mode(mode_id, shot_overrides=None) — return a flat dict with the
    six catalog-token keys + aperture + shutter, or None for an unknown mode.
  * render_cinema_tokens(mode_id, model_id, shot_overrides=None) — return
    the final per-model-shaped string ready to slot into a builder.

Validation policy:
  * Load-time: HARD crash on invalid catalog references in modes
    (Reference Integrity per Law 1 — modes promise to compose catalogs;
    a broken reference is a config bug, not a runtime fluke).
  * Runtime (per-shot): SOFT — unknown mode_id logs WARNING; resolve_mode
    returns None and render_cinema_tokens returns an empty string. Matches
    `_get_seeddance_film_stock()` behavior at prompt_engine.py:2426.

Cache: module-level + mtime-stamped. If the YAML file's mtime changed
since last load (dev workflow: edit YAML, re-run generation), re-read.
Production runs load once at startup — the mtime check is one stat() call.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

import yaml

from recoil.core.paths import CONFIG_DIR
from recoil.pipeline._lib.bible_loader import load_bible, get_global_defaults

logger = logging.getLogger(__name__)

_CINEMA_PATH = CONFIG_DIR / "CINEMA_MODES.yaml"

_SHOT_NAMES = {
    "ECU": "extreme close-up", "CU": "close-up", "BCU": "big close-up",
    "MCU": "medium close-up", "MS": "medium shot", "MFS": "medium full shot",
    "MLS": "medium long shot", "MWS": "medium wide shot", "LS": "long shot",
    "FS": "full shot", "WS": "wide shot", "EWS": "extreme wide shot",
    "VLS": "very long shot", "WIDE": "wide shot", "INSERT": "insert shot",
}

_MOVEMENT_NAMES = {
    "pan": "panning", "tilt": "tilting", "push_in": "push-in",
    "pull_back": "pull-back", "tracking": "tracking", "crane": "crane",
    "handheld": "handheld", "steadicam": "Steadicam", "dolly": "dolly",
    "push": "push-in", "pull": "pull-back", "track": "tracking",
    "dolly_in": "dolly-in", "dolly_out": "dolly-out",
}

_cinema: Optional[dict] = None
_cinema_mtime: Optional[float] = None

# The six catalog-reference fields every mode must declare.
_MODE_CATALOG_FIELDS = {
    "body":       "camera_bodies",
    "lens":       "lens_systems",
    "filtration": "filtration",
    "stock":      "film_stocks",
    "grain":      "texture_grain",
    "grade":      "color_grades",
}


class CinemaConfigError(ValueError):
    """Signal that CINEMA_MODES.yaml is malformed.

    Raised by load-time validation for structural problems (wrong types,
    missing sections, blank prompt tokens) and for mode fields that
    reference an unknown catalog entry. Subclasses ValueError.
    """


def _validate_cinema_config(d: dict) -> None:
    """Validate structure + reference integrity of the parsed YAML document.

    Checks, in order: top-level mapping, ``schema_version == 1``, non-empty
    ``catalogs`` and ``modes`` mappings, presence of every required catalog,
    a non-blank string ``prompt_tokens`` on every catalog entry, and that
    each mode's six catalog-reference fields name an existing catalog entry.

    Args:
        d: The object produced by parsing CINEMA_MODES.yaml.

    Raises:
        CinemaConfigError: On the first violation found. Malformed values
            (e.g. a list where a catalog id is expected) are reported as
            CinemaConfigError as well, never as a bare TypeError.
    """
    if not isinstance(d, dict):
        raise CinemaConfigError("CINEMA_MODES.yaml top-level must be a mapping")

    version = d.get("schema_version")
    # bool is an int subclass (True == 1), so reject it explicitly.
    if isinstance(version, bool) or version != 1:
        raise CinemaConfigError(
            f"CINEMA_MODES.yaml schema_version must be 1, got {version!r}"
        )

    catalogs = d.get("catalogs") or {}
    modes = d.get("modes") or {}

    if not isinstance(catalogs, dict) or not catalogs:
        raise CinemaConfigError("CINEMA_MODES.yaml 'catalogs' missing or empty")
    if not isinstance(modes, dict) or not modes:
        raise CinemaConfigError("CINEMA_MODES.yaml 'modes' missing or empty")

    # Required catalogs present.
    missing_catalogs = set(_MODE_CATALOG_FIELDS.values()) - catalogs.keys()
    if missing_catalogs:
        raise CinemaConfigError(
            f"CINEMA_MODES.yaml missing required catalogs: {sorted(missing_catalogs)}"
        )

    # Every catalog entry has a non-blank string prompt_tokens.
    for cat_name, cat in catalogs.items():
        if not isinstance(cat, dict):
            raise CinemaConfigError(f"catalog '{cat_name}' must be a mapping")
        for entry_id, entry in cat.items():
            if not isinstance(entry, dict) or "prompt_tokens" not in entry:
                raise CinemaConfigError(
                    f"catalog entry '{cat_name}.{entry_id}' missing 'prompt_tokens'"
                )
            tokens = entry["prompt_tokens"]
            if not isinstance(tokens, str) or not tokens.strip():
                raise CinemaConfigError(
                    f"catalog entry '{cat_name}.{entry_id}' has empty prompt_tokens"
                )

    # Reference integrity: every mode's catalog reference must resolve.
    for mode_id, mode in modes.items():
        if not isinstance(mode, dict):
            raise CinemaConfigError(f"mode '{mode_id}' must be a mapping")
        for field, cat_name in _MODE_CATALOG_FIELDS.items():
            ref = mode.get(field)
            if ref is None:
                raise CinemaConfigError(
                    f"mode '{mode_id}' missing required field '{field}'"
                )
            catalog = catalogs[cat_name]
            try:
                known = ref in catalog
            except TypeError:
                # Unhashable reference (a YAML list/mapping) cannot be a key.
                known = False
            if not known:
                # key=str keeps the listing sortable even if YAML produced
                # mixed-type keys (e.g. an unquoted numeric id next to names).
                raise CinemaConfigError(
                    f"mode '{mode_id}' field '{field}' references "
                    f"'{ref}' which is not a key in catalog '{cat_name}'. "
                    f"Available: {sorted(catalog, key=str)}"
                )


def load_cinema_modes() -> dict:
    """Load + cache CINEMA_MODES.yaml, re-reading when the file mtime changed.

    The cache check costs a single ``stat()`` call: the file's existence and
    its mtime are obtained together, so there is no window between an
    existence check and the stat in which the file could disappear.

    Returns:
        The parsed and validated YAML document (shared cached object).

    Raises:
        FileNotFoundError: If CINEMA_MODES.yaml does not exist.
        CinemaConfigError: If the document fails validation. The previous
            cache (if any) is left untouched in that case.
    """
    global _cinema, _cinema_mtime

    try:
        current_mtime = _CINEMA_PATH.stat().st_mtime
    except (FileNotFoundError, NotADirectoryError):
        raise FileNotFoundError(
            f"CINEMA_MODES.yaml not found at {_CINEMA_PATH}"
        ) from None

    if _cinema is not None and _cinema_mtime == current_mtime:
        return _cinema

    raw = yaml.safe_load(_CINEMA_PATH.read_text(encoding="utf-8"))
    # Validate before publishing so a broken edit never replaces a good cache.
    _validate_cinema_config(raw)
    _cinema = raw
    _cinema_mtime = current_mtime
    return _cinema


def reload_cinema_modes() -> None:
    """Drop the cached config and read CINEMA_MODES.yaml again.

    Intended for tests and dev reloads. Any error raised by
    load_cinema_modes() propagates, leaving the cache empty.
    """
    global _cinema, _cinema_mtime
    _cinema = _cinema_mtime = None
    load_cinema_modes()


def resolve_mode(
    mode_id: str,
    shot_overrides: Optional[dict] = None,
) -> Optional[dict]:
    """Resolve a mode_id (+ optional per-shot overrides) into a flat dict.

    Returns a dict with keys: body_tokens, lens_tokens, filtration_tokens,
    stock_tokens, grain_tokens, grade_tokens, aperture, shutter — each
    a plain string. ``aperture`` / ``shutter`` are optional in a mode: a
    missing or null value becomes "", and a non-string YAML scalar (e.g.
    ``2.8`` or ``180``) is converted with ``str()`` so the "plain string"
    promise holds for callers that join the values.

    Returns None if mode_id is empty or unknown (unknown is a soft fail with
    a WARNING log).

    `shot_overrides` (per-shot dict) may contain any subset of the
    catalog-reference keys (body, lens, filtration, stock, grain, grade)
    whose values are catalog IDs. Invalid catalog IDs in overrides log a
    WARNING and are dropped (the mode's default for that field stands).
    Override values that are inline strings (not catalog IDs) are also
    dropped with a WARNING — modes never contain raw prompt text per the
    No Inline Specs rule.
    """
    if not mode_id:
        return None

    cinema = load_cinema_modes()
    modes = cinema["modes"]
    catalogs = cinema["catalogs"]

    try:
        mode = modes[mode_id]
    except (KeyError, TypeError):
        # TypeError: an unhashable mode_id can never be a key — same soft fail.
        logger.warning(
            "cinema_mode %r not found in CINEMA_MODES.yaml — "
            "falling back to no-cinema-tokens. Available: %s",
            mode_id, sorted(modes, key=str),
        )
        return None

    overrides = shot_overrides or {}
    resolved: dict = {}
    for field, cat_name in _MODE_CATALOG_FIELDS.items():
        catalog = catalogs[cat_name]
        ref = mode[field]  # load-time validation guarantees this resolves
        override_ref = overrides.get(field)
        if override_ref is not None:
            # Per-shot override applies only if it names a valid catalog entry.
            if isinstance(override_ref, str) and override_ref in catalog:
                ref = override_ref
            else:
                logger.warning(
                    "cinema_overrides %r=%r invalid (not a key in catalog %r) — "
                    "falling back to mode default %r",
                    field, override_ref, cat_name, ref,
                )
        resolved[f"{field}_tokens"] = catalog[ref]["prompt_tokens"]

    # aperture/shutter are not validated at load time, so normalise here.
    for extra in ("aperture", "shutter"):
        value = mode.get(extra)
        resolved[extra] = "" if value is None else str(value)
    return resolved


def _compress_tokens(prompt_tokens: str) -> str:
    """Compress a catalog prompt_tokens string to its head clause.

    The first comma-delimited segment is the head noun phrase
    (e.g. 'Cooke S4/i spherical prime lenses, classic Cooke Look, ...' →
    'Cooke S4/i spherical prime lenses'). Deterministic, no NLP.
    """
    head = prompt_tokens.split(",", 1)[0].strip()
    return head


# Ordered field render sequence — matches natural English cinematography
# describing convention: body → lens → aperture → shutter → filtration →
# stock → grain → grade. Used by render_cinema_tokens to assemble the
# resolved field strings into a single output clause.
_RENDER_ORDER = (
    "body", "lens", "aperture", "shutter",
    "filtration", "stock", "grain", "grade",
)


def render_cinema_tokens(
    mode_id: Optional[str],
    model_id: str,
    shot_overrides: Optional[dict] = None,
) -> str:
    """Turn a cinema mode into the token string shaped for one model.

    The mode is resolved (with any per-shot overrides), then each field is
    emitted in _RENDER_ORDER according to the model's ``cinema_token_map``
    strategy for that field:

      * ``None``         — the model opts out of the field; it is skipped.
      * ``"compressed"`` — only the head clause of the value is emitted.
      * anything else    — the value is emitted in full. A field absent
        from the map defaults to "full", so models without a wired map still
        get every token.

    If the model profile cannot be loaded, an empty map is used.

    Args:
        mode_id: A key in CINEMA_MODES.yaml `modes` (e.g. 'narrative_cinematic').
        model_id: Model identifier as it appears in model_profiles.json
                  (e.g. 'seeddance-2.0', 'kling-v3', 'wan-2.7-r2v', 'veo-3.1').
        shot_overrides: Optional per-shot partial overrides (catalog ids).

    Returns:
        A single clause without a terminal period, or "" when the mode is
        missing/unknown or no field produced any text. Parts are joined with
        ", " if any field in the model's map is "compressed", otherwise
        with ". ".
    """
    tokens = resolve_mode(mode_id, shot_overrides=shot_overrides)
    if tokens is None:
        return ""

    # Look up the per-model strategy map; any failure means "no map wired".
    try:
        from recoil.core.model_profiles import get_profile
        strategies = get_profile(model_id).get("cinema_token_map") or {}
    except Exception:  # noqa: BLE001 — soft fail
        strategies = {}

    fragments: list[str] = []
    for field in _RENDER_ORDER:
        strategy = strategies.get(field, "full")
        if strategy is None:
            continue  # explicit opt-out for this field
        # Catalog fields live under "<field>_tokens"; aperture/shutter are bare.
        key = field if field in ("aperture", "shutter") else f"{field}_tokens"
        text = tokens.get(key, "")
        if not text:
            continue
        fragments.append(
            _compress_tokens(text) if strategy == "compressed" else text
        )

    if not fragments:
        return ""

    # One compressed field anywhere in the map switches the whole clause to
    # the comma-separated form; otherwise fragments read as short sentences.
    uses_compression = any(
        strategies.get(field) == "compressed" for field in _RENDER_ORDER
    )
    separator = ", " if uses_compression else ". "
    return separator.join(fragments)


def _lens_from_table(table: object, shot_type: str) -> str:
    """Return the lens string mapped to *shot_type* in *table*, or "" if absent."""
    if not isinstance(table, dict):
        return ""
    value = table.get(shot_type)
    return str(value) if value else ""


def render_camera_line(
    shot: "CanonicalShot",
    mode: dict | None,
    model_id: str,
    bible: dict | None = None,
) -> str:
    """Returns "Camera: <shot_size>, <lens_type>, <camera_move>." or "" if shot data missing.

    lens_type resolution order:
      1. shot.cinematography.lens_type_override  (per-shot director override;
         read from a dict key or, for non-dict objects, from an attribute)
      2. mode.lens_per_shot_size_override[shot.shot_size]  (per-mode)
      3. bible.lens_per_shot_size[shot.shot_size]  (project-default from PROMPT_BIBLE)
      4. "" (omitted entirely, soft fail)

    The shot_size is read from prompt_data["shot_type"], stripped and
    normalized to uppercase for the lens_per_shot_size lookup. Camera move is
    read from prompt_data["camera_movement"]; it is matched against the
    movement table case-insensitively with spaces/hyphens treated as
    underscores, "static" (any case) is omitted, and an unrecognised move is
    emitted verbatim.

    Args:
        shot: CanonicalShot instance (or any object with .raw and
            .cinematography attrs), or a plain dict carrying "prompt_data"
            and optionally "cinematography" directly.
        mode: Resolved mode dict from CINEMA_MODES.yaml (or None if no mode active).
        model_id: Model identifier. Ids starting with "veo" get no
            "Camera: " prefix; all others do.
        bible: PROMPT_BIBLE dict (optional; global defaults are loaded via
            get_global_defaults() if None).

    Returns:
        Camera-line string, or empty string if shot data is insufficient
        (no prompt_data, or a missing/blank shot_type).
    """
    if isinstance(shot, dict):
        prompt_data = shot.get("prompt_data")
        cinematography = shot.get("cinematography")
    else:
        raw = getattr(shot, "raw", None) or {}
        prompt_data = raw.get("prompt_data")
        cinematography = getattr(shot, "cinematography", None)

    # A present-but-null prompt_data is "shot data missing", not a crash.
    if not isinstance(prompt_data, dict):
        return ""

    shot_type = str(prompt_data.get("shot_type") or "").strip().upper()
    if not shot_type:
        return ""
    camera_movement = str(prompt_data.get("camera_movement") or "").strip()

    # Level 1: per-shot director override (dict key or object attribute).
    if isinstance(cinematography, dict):
        override = cinematography.get("lens_type_override")
    else:
        override = getattr(cinematography, "lens_type_override", None)
    lens_type = str(override) if override else ""

    # Level 2: per-mode override table.
    if not lens_type and mode is not None:
        lens_type = _lens_from_table(
            mode.get("lens_per_shot_size_override"), shot_type
        )

    # Level 3: project-default from PROMPT_BIBLE global_defaults.
    if not lens_type:
        if bible is None:
            try:
                gd = get_global_defaults()
            except Exception:  # noqa: BLE001
                gd = {}
        else:
            gd = bible.get("global_defaults")
        # Same type guard for both sources of global defaults.
        if not isinstance(gd, dict):
            gd = {}
        lens_type = _lens_from_table(gd.get("lens_per_shot_size"), shot_type)

    # Level 4: omit entirely (soft fail) — lens_type stays "".

    parts = [_SHOT_NAMES.get(shot_type, shot_type.lower())]
    if lens_type:
        parts.append(lens_type)

    # Normalise only for the lookup; unknown moves keep the author's wording.
    movement_key = camera_movement.lower().replace("-", "_").replace(" ", "_")
    if movement_key and movement_key != "static":
        parts.append(_MOVEMENT_NAMES.get(movement_key, camera_movement))

    prefix = "" if model_id and model_id.startswith("veo") else "Camera: "
    return f"{prefix}{', '.join(parts)}."


def render_constraint_block(
    constraints: list[str],
    model_id: str,
    bible: dict | None = None,
    profiles: dict | None = None,
) -> tuple[str, list[str]]:
    """Returns (positive_suffix, negative_prompt_list).

    positive_suffix: "" or a closing "Constraints: no X; no Y; no Z." string.
        Each phrase is split on commas and every non-empty fragment becomes
        one "no <fragment>" ban; bans are joined with "; ".
    negative_prompt_list: [] or the phrase strings, unmodified, for the API
        negative_prompt param.

    Emit policy: reads supports_negative_prompt from the model profile.
      - False/missing -> bans go to positive_suffix only.
      - True -> bans go to negative_prompt_list; positive_suffix is empty.

    Soft-fail cases (WARNING log, slug skipped): unknown slug; slug whose
    dictionary value is not a non-blank string. A missing/unloadable bible or
    profile is treated as empty.

    Args:
        constraints: List of constraint slugs (e.g. ["no_modern_color_grade", ...]).
            A single bare string is accepted and treated as one slug.
        model_id: Model identifier for emit-mode lookup.
        bible: PROMPT_BIBLE dict (optional; loaded via load_bible() if None).
        profiles: Model profiles dict keyed by model id (optional; the
            profile is fetched via get_profile() if None).

    Returns:
        Tuple of (positive_suffix_string, negative_prompt_phrase_list).
    """
    if not constraints:
        return ("", [])
    if isinstance(constraints, str):
        # A bare slug would otherwise be iterated character by character.
        constraints = [constraints]

    if bible is None:
        try:
            bible_data = load_bible()
        except Exception:  # noqa: BLE001
            bible_data = {}
    else:
        bible_data = bible
    constraint_dict = (
        bible_data.get("constraint_dictionary")
        if isinstance(bible_data, dict) else None
    )
    if not isinstance(constraint_dict, dict):
        constraint_dict = {}

    phrases: list[str] = []
    for slug in constraints:
        phrase = constraint_dict.get(slug)
        if phrase is None:
            logger.warning(
                "render_constraint_block: unknown constraint slug %r — skipping. "
                "Available: %s",
                slug, sorted(constraint_dict, key=str),
            )
            continue
        if not isinstance(phrase, str) or not phrase.strip():
            logger.warning(
                "render_constraint_block: constraint slug %r maps to a "
                "non-string or blank phrase %r — skipping.",
                slug, phrase,
            )
            continue
        phrases.append(phrase)

    if not phrases:
        return ("", [])

    if profiles is not None:
        model_profile = profiles.get(model_id)
    else:
        try:
            from recoil.core.model_profiles import get_profile
            model_profile = get_profile(model_id)
        except Exception:  # noqa: BLE001
            model_profile = None
    if not isinstance(model_profile, dict):
        model_profile = {}

    if model_profile.get("supports_negative_prompt", False):
        return ("", phrases)

    # Drop empty fragments ("a, , b" or a trailing comma) so no dangling
    # "no " ban is emitted.
    bans = [
        f"no {fragment}"
        for phrase in phrases
        for fragment in (piece.strip() for piece in phrase.split(","))
        if fragment
    ]
    if not bans:
        return ("", [])
    return ("Constraints: " + "; ".join(bans) + ".", [])


# Public API: the names exported by a star-import of this module.
__all__ = [
    "CinemaConfigError",
    "load_cinema_modes",
    "reload_cinema_modes",
    "render_camera_line",
    "render_cinema_tokens",
    "render_constraint_block",
    "resolve_mode",
]
