"""Unit tests for look_loader — registries, FK validation, ref-budget caps.

Fast + offline. The committed-registry test exercises the real
recoil/config/{looks,identities} + CINEMA_MODES.yaml. The negative/cap tests
build self-contained tmp_path fixtures so they do not depend on (or mutate)
the committed example files.
"""

from __future__ import annotations

import struct
import zlib

import pytest

from recoil.pipeline._lib.look_loader import (
    LookRegistryError,
    build_look_bundle,
    load_registries,
    reload_registries,
)


# --------------------------------------------------------------------------- #
# Fixture helpers
# --------------------------------------------------------------------------- #


def _write_png(path):
    """Write a minimal valid 1x1 RGBA PNG (stdlib only)."""
    path.parent.mkdir(parents=True, exist_ok=True)
    raw = b"\x00" + b"\x00\x00\x00\xff"

    def chunk(typ, data):
        c = typ + data
        return struct.pack(">I", len(data)) + c + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF)

    sig = b"\x89PNG\r\n\x1a\n"
    ihdr = struct.pack(">IIBBBBB", 1, 1, 8, 6, 0, 0, 0)
    idat = zlib.compress(raw)
    path.write_bytes(sig + chunk(b"IHDR", ihdr) + chunk(b"IDAT", idat) + chunk(b"IEND", b""))


def _write_yaml(path, text):
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text)


def _make_cinema(tmp_path, modes=("noir_tension",)):
    cinema = tmp_path / "CINEMA_MODES.yaml"
    body = "schema_version: 1\nmodes:\n"
    for m in modes:
        body += f"  {m}:\n    body: arri_alexa_35\n"
    cinema.write_text(body)
    return cinema


def _make_model_profiles(tmp_path):
    """Minimal model_profiles.json with the two models the cap tests need."""
    import json

    path = tmp_path / "model_profiles.json"
    path.write_text(json.dumps({
        "schema_version": 1,
        "seedream-v4.5": {
            "modality": "image",
            "max_reference_images": 10,
            "max_character_refs": 4,
            "supports_lora": False,
        },
        "nano-banana": {
            "modality": "image",
            "max_reference_images": 6,
            "max_character_refs": 3,
            "supports_lora": False,
        },
        # Real-world common case: a model that leaves max_character_refs null
        # (no separate identity sub-cap). Mirrors gemini-3.1-flash-image-preview,
        # gpt-image-2, seedream-v5-lite, and most video models in the committed
        # profiles. null != 0 — identities must use the full ref budget here.
        "flash-null-cap": {
            "modality": "image",
            "max_reference_images": 11,
            "max_character_refs": None,
            "supports_lora": False,
        },
    }))
    return path


# --------------------------------------------------------------------------- #
# 1. Committed registries parse
# --------------------------------------------------------------------------- #


def test_committed_registries_parse():
    """The committed looks/identities registries load and cross-reference cleanly."""
    # Presumably resets any cached registry state so the committed files are
    # read fresh - TODO confirm against look_loader.reload_registries.
    reload_registries()
    looks, identities = load_registries()

    assert "noir_neon" in looks
    # The example Look's FK must name a mode in the committed CINEMA_MODES.yaml.
    # _committed_cinema_modes() already returns a set; no need to rebuild it.
    assert looks["noir_neon"]["extends_cinema_mode"] in _committed_cinema_modes()
    # At least the example identity is present.
    assert "kara_voss" in identities


def _committed_cinema_modes():
    """Return the set of mode names declared in the committed CINEMA_MODES.yaml."""
    import yaml
    from recoil.core.paths import CONFIG_DIR

    cinema_text = (CONFIG_DIR / "CINEMA_MODES.yaml").read_text()
    parsed = yaml.safe_load(cinema_text)
    # A missing or null `modes:` key yields an empty set rather than an error.
    declared = parsed.get("modes") or {}
    return set(declared.keys())


# --------------------------------------------------------------------------- #
# 2. Bad FK raises
# --------------------------------------------------------------------------- #


def test_bad_cinema_mode_fk_raises(tmp_path):
    """A Look whose extends_cinema_mode names an unknown mode is rejected."""
    cinema = _make_cinema(tmp_path, modes=("noir_tension",))
    looks_dir = tmp_path / "looks"

    # Write the referenced image so the dangling FK is the only defect.
    _write_png(looks_dir / "badfk" / "ref_01.png")
    look_yaml = (
        "schema_version: 1\n"
        "look_id: badfk\n"
        "extends_cinema_mode: does_not_exist\n"
        "style_refs:\n"
        "  - { path: \"looks/badfk/ref_01.png\", priority: 1 }\n"
    )
    _write_yaml(looks_dir / "badfk.yaml", look_yaml)

    with pytest.raises(LookRegistryError, match="extends_cinema_mode"):
        load_registries(
            looks_dir=looks_dir,
            identities_dir=tmp_path / "identities",
            cinema_path=cinema,
            ref_root=tmp_path,
            use_cache=False,
        )


def test_missing_ref_file_raises(tmp_path):
    """A Look whose style ref points at a non-existent file is rejected."""
    cinema = _make_cinema(tmp_path)
    looks_dir = tmp_path / "looks"

    # Deliberately no _write_png() call: the referenced image does not exist.
    look_yaml = (
        "schema_version: 1\n"
        "look_id: noref\n"
        "extends_cinema_mode: noir_tension\n"
        "style_refs:\n"
        "  - { path: \"looks/noref/missing.png\", priority: 1 }\n"
    )
    _write_yaml(looks_dir / "noref.yaml", look_yaml)

    # NOTE(review): the ref filename itself contains "missing", so this match
    # also passes if the error message merely echoes the path.
    with pytest.raises(LookRegistryError, match="missing"):
        load_registries(
            looks_dir=looks_dir,
            identities_dir=tmp_path / "identities",
            cinema_path=cinema,
            ref_root=tmp_path,
            use_cache=False,
        )


def test_missing_schema_version_raises(tmp_path):
    """A Look YAML without a top-level schema_version is rejected."""
    cinema = _make_cinema(tmp_path)
    looks_dir = tmp_path / "looks"

    # Provide the ref image so the absent version field is the only defect.
    _write_png(looks_dir / "nover" / "ref_01.png")
    look_yaml = (
        "look_id: nover\n"
        "extends_cinema_mode: noir_tension\n"
        "style_refs:\n"
        "  - { path: \"looks/nover/ref_01.png\", priority: 1 }\n"
    )
    _write_yaml(looks_dir / "nover.yaml", look_yaml)

    with pytest.raises(LookRegistryError, match="schema_version"):
        load_registries(
            looks_dir=looks_dir,
            identities_dir=tmp_path / "identities",
            cinema_path=cinema,
            ref_root=tmp_path,
            use_cache=False,
        )


# --------------------------------------------------------------------------- #
# 3. Ref-budget caps + truncation
# --------------------------------------------------------------------------- #


def _look_with_style_refs(n):
    return {
        "look_id": "capL",
        "extends_cinema_mode": "noir_tension",
        "aspect_default": "9:16",
        "creativity": "medium",
        "palette": {"hex": ["#000"], "lut": None},
        "look_pack": {"positive": ["x"], "avoid": ["y"]},
        "style_refs": [
            {"path": f"looks/capL/s{i}.png", "priority": i} for i in range(1, n + 1)
        ],
    }


def _identity_with_refs(n, ident_id="capI"):
    return {
        "identity_id": ident_id,
        "trigger": ident_id.upper(),
        "ref_set": [
            {"path": f"refs/{ident_id}/r{i}.png", "role": "identity", "priority": i}
            for i in range(1, n + 1)
        ],
    }


def test_seedream_respects_max_reference_images(tmp_path):
    """seedream-v4.5: identity refs fill first; style refs get the rest of 10."""
    profiles = _make_model_profiles(tmp_path)
    # Wants 4 identity refs (sub-cap 4) + 12 style refs = 16, but max_refs is 10.
    bundle = build_look_bundle(
        _look_with_style_refs(12),
        [_identity_with_refs(4)],
        "seedream-v4.5",
        profiles_path=profiles,
    )
    budget = bundle.ref_budget

    # All 4 identity refs fit under both the sub-cap (4) and max_refs (10).
    assert len(bundle.identity_refs) == 4
    # 10 - 4 leaves room for exactly 6 style refs.
    assert len(bundle.style_refs) == 6
    assert budget["max_refs"] == 10
    assert budget["used_identity"] == 4
    assert budget["used_style"] == 6

    # The other 12 - 6 = 6 style refs are reported as cut by max_reference_images.
    budget_cuts = sum(
        1 for entry in budget["truncated"] if entry["reason"] == "max_reference_images"
    )
    assert budget_cuts == 6

    # Survivors are the six lowest priority numbers.
    assert sorted(int(ref.priority) for ref in bundle.style_refs) == list(range(1, 7))
    assert bundle.backing == "references"
    assert bundle.loras == []


def test_nano_banana_respects_max_character_refs(tmp_path):
    """nano-banana: identity refs are clipped to max_character_refs (3)."""
    profiles = _make_model_profiles(tmp_path)
    # 5 identity refs against a sub-cap of 3 -> 2 should be dropped.
    bundle = build_look_bundle(
        _look_with_style_refs(2),
        [_identity_with_refs(5)],
        "nano-banana",
        profiles_path=profiles,
    )
    budget = bundle.ref_budget

    assert budget["max_character_refs"] == 3
    assert len(bundle.identity_refs) == 3
    # Survivors are the lowest priority numbers.
    assert sorted(int(ref.priority) for ref in bundle.identity_refs) == [1, 2, 3]

    char_cuts = [e for e in budget["truncated"] if e["reason"] == "max_character_refs"]
    assert len(char_cuts) == 2

    # 6 - 3 = 3 free slots and only 2 style refs requested: all kept.
    assert len(bundle.style_refs) == 2
    # The identity's trigger (its id upper-cased) is surfaced on the bundle.
    assert "CAPI" in bundle.triggers


def test_null_char_cap_uses_full_ref_budget(tmp_path):
    """null max_character_refs ≠ 0: identities use the full ref budget."""
    profiles = _make_model_profiles(tmp_path)
    # flash-null-cap: max_refs=11, max_character_refs=null.
    # Request 3 identity refs + 4 style refs.
    bundle = build_look_bundle(
        _look_with_style_refs(4),
        [_identity_with_refs(3)],
        "flash-null-cap",
        profiles_path=profiles,
    )
    budget = bundle.ref_budget

    # A null cap means "no identity sub-cap", so nothing is dropped.
    assert len(bundle.identity_refs) == 3
    char_cuts = [e for e in budget["truncated"] if e["reason"] == "max_character_refs"]
    assert not char_cuts

    # 11 - 3 = 8 free slots >= 4 style refs, so every style ref fits as well.
    assert len(bundle.style_refs) == 4
    assert budget["used_identity"] == 3
    assert budget["used_style"] == 4


def test_explicit_zero_char_cap_drops_all_identity_refs(tmp_path):
    """max_character_refs=0 is a real cap (unlike null): no identity ref survives."""
    import json

    t2i_profile = {
        "modality": "image",
        "max_reference_images": 0,
        "max_character_refs": 0,
        "supports_lora": False,
    }
    profiles = tmp_path / "model_profiles.json"
    profiles.write_text(json.dumps({"schema_version": 1, "t2i-only": t2i_profile}))

    bundle = build_look_bundle(
        _look_with_style_refs(2),
        [_identity_with_refs(2)],
        "t2i-only",
        profiles_path=profiles,
    )

    assert len(bundle.identity_refs) == 0
    assert len(bundle.style_refs) == 0
    # Both identity refs are attributed to the zero character cap.
    char_cuts = [
        e for e in bundle.ref_budget["truncated"] if e["reason"] == "max_character_refs"
    ]
    assert len(char_cuts) == 2


def test_unknown_model_uses_conservative_default_budget(tmp_path):
    """A model absent from model_profiles.json (e.g. the 'nbp'/'flash' aliases)
    must NOT drop the whole Look — it degrades to a conservative ref budget."""
    # The fixture profiles define no 'nbp' entry.
    profiles = _make_model_profiles(tmp_path)

    # Must not raise, and must hand back a usable bundle.
    bundle = build_look_bundle(
        _look_with_style_refs(2),
        [_identity_with_refs(2)],
        "nbp",
        profiles_path=profiles,
    )
    assert bundle is not None

    # Default budget is max_refs=4 with no identity sub-cap: 2 identity +
    # 2 style refs (4 <= 4) all ride along, so the Look is not silently lost.
    assert len(bundle.identity_refs) == 2
    assert len(bundle.style_refs) == 2
    budget = bundle.ref_budget
    assert budget["max_refs"] == 4

    # With no sub-cap, nothing is dropped for max_character_refs.
    char_cuts = [e for e in budget["truncated"] if e["reason"] == "max_character_refs"]
    assert not char_cuts


def test_identity_refs_placed_first_and_bounded_by_total_budget(tmp_path):
    """nano-banana: identity refs are budgeted first, style refs share the rest.

    With max_refs=6 and max_character_refs=3, the 3 identity refs take 3 slots.
    That leaves 3 slots for the 5 requested style refs, so 2 style refs are
    truncated for max_reference_images.
    """
    profiles = _make_model_profiles(tmp_path)
    # 3 identity refs + 5 style refs against max_refs=6, max_character_refs=3.
    look = _look_with_style_refs(5)
    identity = _identity_with_refs(3)
    bundle = build_look_bundle(look, [identity], "nano-banana", profiles_path=profiles)

    # Identity uses 3 slots; the remaining 6 - 3 = 3 slots go to style refs.
    assert len(bundle.identity_refs) == 3
    assert len(bundle.style_refs) == 3
    style_trunc = [
        t for t in bundle.ref_budget["truncated"]
        if t["reason"] == "max_reference_images"
    ]
    # 5 requested - 3 kept = 2 style refs cut by the total budget.
    assert len(style_trunc) == 2
