#!/usr/bin/env python3
"""style_hold_eval.py — metrics for the krea2-flora style-hold experiment (v2 gate).

krea2-flora Phase 5 — the OBJECTIVE half of the pre-registered style-hold
experiment. The experiment decides whether reference-only conditioning holds a
Look across a style-stress corpus, or whether a style-LoRA is ever funded for
that one Look (see ``consultations/recoil/krea2-flora-image-pipeline-2026-06-02/
STYLE_HOLD_RESULT.md``). The verdict is BOTH-gated: it needs HUMAN on-look
scores AND these objective metrics. This module computes the objective half:

  1. PALETTE ΔE2000  — mean perceptual color distance (CIEDE2000) between each
     frame's dominant palette and the Look's authored ``palette.hex``. The
     pre-registered FAIL threshold is ``mean palette ΔE2000 > ~10``.

  2. STYLE-CENTROID DRIFT — CLIP (or Gram-matrix) style-embedding distance of
     each frame from the CENTROID of the Look's own ``style_refs``, compared to
     the INTERNAL SPREAD of those style_refs (their mean distance to their own
     centroid). The pre-registered FAIL condition is "drift exceeds the Look's
     own style_refs internal spread". Reporting drift relative to the look's own
     spread (not an absolute number) is what makes the threshold principled.

  3. FACE-EMBEDDING IDENTITY CHECK — rules OUT identity as the failure mode. If
     a frame reads "off-look", we must distinguish "the STYLE drifted" from "the
     CHARACTER's face drifted". This computes face-embedding distance so a human
     reading a FAIL knows whether identity (a separate, identity-LoRA problem)
     is the confound. It does NOT gate the verdict — it annotates it.

HEAVY DEPS ARE LAZY (by design)
-------------------------------
numpy, Pillow, scikit-learn (k-means palette extraction), scikit-image
(ΔE2000), torch + open-clip (CLIP style centroid) or torchvision (Gram-matrix
fallback), and a face-embedding lib (insightface / face_recognition) are
imported LAZILY *inside* the functions that use them — NEVER at module top.
Every lazy import is guarded with a clear ``StyleHoldDepError`` naming the
missing package + the ``pip install`` line. This is deliberate: the harness
validation only AST-parses + bare-imports this module, and the unattended build
environment has none of these heavy libs. The module MUST import cleanly on a
bare interpreter; the heavy machinery only materializes when JT runs the real
20-frame corpus in a supervised session.

USAGE (supervised, on the real corpus)
--------------------------------------
    PYTHONPATH=. python3 recoil/pipeline/tools/style_hold_eval.py \\
        --look noir_neon --frames /path/to/corpus_dir \\
        [--out metrics.json] [--identity-ref /path/to/face.png]

The CLI loads the Look from the Phase-1 registry (``load_registries``), globs the
frame corpus, computes the three metric families, and prints / writes a JSON
metrics report whose ``palette_de2000_mean`` + ``style_centroid_drift`` feed the
objective half of the STYLE_HOLD_RESULT.md verdict. Human on-look scores are
filled in separately (the verdict is BOTH-gated).
"""

from __future__ import annotations

import argparse
import functools
import json
import sys
from pathlib import Path
from typing import Optional

# NOTE: this module deliberately has NO heavy-dependency imports at top level.
# numpy / Pillow / skimage / colour / torch / CLIP / face-embedding libs are all
# pulled in lazily inside the functions that need them, each behind _require().
# See the module docstring for why.

# Lower-cased file extensions that list_frames() treats as corpus images.
IMAGE_SUFFIXES = (
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
)

# Pre-registered objective FAIL thresholds (see STYLE_HOLD_RESULT.md). They are
# RECORDED, not enforced: a human renders the verdict from these metrics plus
# the human on-look scores. They live in code so each report entry can carry the
# bar it is measured against.
#
# Mean palette ΔE2000 above ~10 counts as an objective-FAIL contributor.
PALETTE_DE2000_FAIL_THRESHOLD = 10.0
# Style-centroid drift has no constant here: it "FAILs" when it EXCEEDS the
# look's own style_refs internal spread, a relative bar computed on each run.


class StyleHoldDepError(ImportError):
    """Signals that a metric's heavy optional dependency is not installed.

    It is raised, never silently swallowed, so a supervised run fails loudly
    with the exact ``pip install`` command. The alternative would be a partial or
    misleading report. Because it subclasses ``ImportError``, generic
    ``except ImportError`` handlers also catch it.
    """


def _require(module: str, pip_name: Optional[str] = None):
    """Lazily import a heavy dep, raising StyleHoldDepError with install help."""
    pip_name = pip_name or module
    try:
        return __import__(module)
    except ImportError as exc:  # pragma: no cover — env-dependent
        raise StyleHoldDepError(
            f"style_hold_eval requires '{module}' for this metric but it is not "
            f"installed. Install it in the supervised run env:\n"
            f"    pip install {pip_name}"
        ) from exc


# --------------------------------------------------------------------------- #
# Frame corpus + Look loading
# --------------------------------------------------------------------------- #


def list_frames(frames_dir: Path) -> list[Path]:
    """Collect the corpus images under ``frames_dir``, recursively, sorted by path.

    A file counts as a frame when its lower-cased suffix is in ``IMAGE_SUFFIXES``.

    Raises:
        FileNotFoundError: if ``frames_dir`` is not an existing directory.
    """
    if frames_dir.is_dir():
        matches = [
            candidate
            for candidate in frames_dir.rglob("*")
            if candidate.is_file() and candidate.suffix.lower() in IMAGE_SUFFIXES
        ]
        matches.sort()
        return matches
    raise FileNotFoundError(f"frames dir not found: {frames_dir}")


def load_look(look_id: str) -> dict:
    """Fetch one Look dict from the Phase-1 registry by its id.

    ``look_loader`` is imported inside the function, not at module scope, so
    importing this module stays cheap. That import is not a heavy-dependency
    guard.

    Raises:
        KeyError: if ``look_id`` is absent; the message lists the known ids.
    """
    from recoil.pipeline._lib.look_loader import load_registries

    registry, _unused = load_registries()
    found = registry.get(look_id)
    if found is not None:
        return found
    known_ids = sorted(registry)
    raise KeyError(
        f"look_id={look_id!r} not in the Look registry (known: {known_ids})"
    )


def _resolve_style_ref_paths(look: dict) -> list[Path]:
    """Turn a Look's ``style_refs[].path`` entries into resolved absolute paths.

    Each path is joined onto the loader's ``REF_ROOT`` and ``.resolve()``-d.
    Entries whose ``path`` is missing or empty are skipped, and a Look without
    ``style_refs`` yields ``[]``.
    """
    from recoil.pipeline._lib.look_loader import REF_ROOT

    style_refs = look.get("style_refs") or []
    rel_paths = (entry.get("path") for entry in style_refs)
    return [(REF_ROOT / rel).resolve() for rel in rel_paths if rel]


# --------------------------------------------------------------------------- #
# Metric 1 — palette ΔE2000 vs Look.palette.hex
# --------------------------------------------------------------------------- #


def _hex_to_rgb01(hex_str: str) -> tuple[float, float, float]:
    """Convert '#rrggbb' (or 'rrggbb') to an (r, g, b) tuple in [0, 1]."""
    h = hex_str.lstrip("#")
    if len(h) != 6:
        raise ValueError(f"expected 6-digit hex color, got {hex_str!r}")
    return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4))  # type: ignore[return-value]


def _dominant_colors(image_path: Path, k: int) -> "list":
    """Extract up to ``k`` dominant RGB[0,1] colors from an image via k-means.

    The image is converted to RGB and thumbnailed to at most 256x256 before
    clustering. The cluster count is capped at the number of distinct colors,
    so flat images still work.

    Args:
        image_path: path of the image to analyze.
        k: maximum number of dominant colors (k-means clusters).

    Returns:
        A list of ``(r, g, b)`` tuples (k-means cluster centers) in [0, 1].

    Raises:
        StyleHoldDepError: if numpy, Pillow or scikit-learn is missing.
    """
    np = _require("numpy")
    _require("PIL", "Pillow")
    from PIL import Image
    # Guard sklearn like every other heavy dep. A bare import would raise a
    # plain ImportError that evaluate()'s per-metric isolation does not catch.
    _require("sklearn", "scikit-learn")
    from sklearn.cluster import KMeans

    # Context manager closes the source file; convert() returns an independent copy.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img.thumbnail((256, 256))  # downsample — palette is (approximately) scale-invariant
    arr = np.asarray(img, dtype=np.float64).reshape(-1, 3) / 255.0
    n_clusters = min(k, max(1, len(np.unique(arr, axis=0))))
    km = KMeans(n_clusters=n_clusters, n_init=4, random_state=0).fit(arr)
    return [tuple(c) for c in km.cluster_centers_]


def _delta_e2000(rgb_a, rgb_b) -> float:
    """Return the CIEDE2000 distance between two RGB[0,1] colors.

    Each color is reshaped to a 1x1 RGB image, converted to CIELAB with
    scikit-image, and compared with ``deltaE_ciede2000``. Dependencies are
    imported lazily through ``_require``.
    """
    np = _require("numpy")
    _require("skimage", "scikit-image")
    from skimage.color import deltaE_ciede2000, rgb2lab

    def _as_lab(rgb):
        pixel = np.asarray(rgb, dtype=np.float64).reshape(1, 1, 3)
        return rgb2lab(pixel)

    first_lab = _as_lab(rgb_a)
    second_lab = _as_lab(rgb_b)
    distance_map = deltaE_ciede2000(first_lab, second_lab)
    return float(distance_map[0, 0])


def palette_delta_e(
    frames: list[Path], palette_hex: list[str], k: int = 5
) -> dict:
    """Mean per-frame palette ΔE2000 vs the Look's authored ``palette.hex``.

    Method:
      1. For each frame, extract its ``k`` dominant colors.
      2. For each authored Look palette color, take the ΔE2000 to the NEAREST
         frame color.
      3. The frame's score is the mean over the authored palette, i.e. how far
         the frame is from carrying the Look's palette.
      4. The corpus score is the mean over frames.

    Args:
        frames: corpus image paths.
        palette_hex: the Look's authored palette as hex color strings.
        k: number of dominant colors extracted per frame.

    Returns:
        A report dict with the following keys:
          * ``palette_de2000_mean``: the corpus mean (0.0 for an empty
            ``frames`` list).
          * ``palette_de2000_fail_threshold``: the pre-registered threshold
            (>~10).
          * ``palette_de2000_exceeds_threshold``: whether the mean exceeds it.
          * ``per_frame``: the per-frame scores.
        This puts each metric next to the bar it is judged against.

    Raises:
        ValueError: if ``palette_hex`` is empty or contains a malformed color.
        StyleHoldDepError: if a heavy dependency is missing.
    """
    # An empty palette would score every frame 0.0, a silent perfect "pass"
    # on a gating metric. Refuse it, mirroring style_centroid_drift's
    # handling of a Look with no style_refs.
    if not palette_hex:
        raise ValueError("Look has no palette.hex — cannot compute palette ΔE2000.")

    look_rgb = [_hex_to_rgb01(h) for h in palette_hex]
    per_frame: list[dict] = []
    for fp in frames:
        frame_colors = _dominant_colors(fp, k)
        # For each authored palette color, the nearest frame color's ΔE2000.
        nearest = [
            min(_delta_e2000(lc, fc) for fc in frame_colors) for lc in look_rgb
        ]
        per_frame.append({"frame": str(fp), "palette_de2000": sum(nearest) / len(nearest)})

    corpus_mean = (
        sum(d["palette_de2000"] for d in per_frame) / len(per_frame)
        if per_frame else 0.0
    )
    return {
        "palette_de2000_mean": corpus_mean,
        "palette_de2000_fail_threshold": PALETTE_DE2000_FAIL_THRESHOLD,
        "palette_de2000_exceeds_threshold": corpus_mean > PALETTE_DE2000_FAIL_THRESHOLD,
        "per_frame": per_frame,
    }


# --------------------------------------------------------------------------- #
# Metric 2 — CLIP / Gram style-centroid drift vs Look.style_refs
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=4)
def _load_clip_model(
    model_name: str = "ViT-B-32", pretrained: str = "laion2b_s34b_b79k"
):
    """Create an open_clip model plus its preprocess transform, once per process.

    The cache matters because model creation and weight loading are
    expensive. Without it, every style_ref and every corpus frame would pay
    that cost again. Failed loads raise and are not cached.

    Returns:
        ``(model, preprocess)`` with the model in eval mode.

    Raises:
        StyleHoldDepError: if torch or open_clip is missing.
    """
    _require("torch")
    try:
        import open_clip
    except ImportError as exc:  # pragma: no cover — env-dependent
        raise StyleHoldDepError(
            "style centroid drift needs CLIP — install open_clip_torch:\n"
            "    pip install open_clip_torch"
        ) from exc

    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained
    )
    model.eval()
    return model, preprocess


def _clip_style_embedding(image_path: Path):
    """CLIP image embedding (L2-normalized, float64 numpy vector) for an image.

    Uses ``open_clip`` ViT-B-32 / laion2b_s34b_b79k; there is no fallback to
    the OpenAI ``clip`` package. The embedding is the style proxy. A
    Gram-matrix alternative exists in ``_gram_style_embedding``.

    Raises:
        StyleHoldDepError: if torch, numpy, Pillow or open_clip is missing.
    """
    torch = _require("torch")
    np = _require("numpy")
    _require("PIL", "Pillow")
    # Import the submodule explicitly; ``PIL.Image`` is not guaranteed to be an
    # attribute of the bare ``PIL`` package.
    from PIL import Image

    model, preprocess = _load_clip_model()
    with Image.open(image_path) as src:  # close the file handle promptly
        img = preprocess(src.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        feat = model.encode_image(img)
    feat = feat / feat.norm(dim=-1, keepdim=True)
    return np.asarray(feat.squeeze(0).cpu().numpy(), dtype=np.float64)


@functools.lru_cache(maxsize=1)
def _load_vgg_style_features():
    """Build the truncated VGG16 feature stack used for Gram matrices, once per process.

    ``features[:16]`` keeps the early conv blocks; in torchvision's VGG16
    layout that ends at relu3_3. Caching avoids rebuilding the network and
    reloading its weights for every image.

    Raises:
        StyleHoldDepError: if torch or torchvision is missing.
    """
    _require("torch")
    _require("torchvision")
    from torchvision import models

    weights = models.VGG16_Weights.DEFAULT
    return models.vgg16(weights=weights).features[:16].eval()


def _gram_style_embedding(image_path: Path):
    """Flattened, L2-normalized Gram-matrix style embedding (float64 numpy vector).

    Fallback style proxy when CLIP is not desired. It uses the Gram matrix of
    early VGG16 conv features, the classic Gatys style signature.

    Raises:
        StyleHoldDepError: if torch, numpy, torchvision or Pillow is missing.
    """
    torch = _require("torch")
    np = _require("numpy")
    _require("torchvision")
    _require("PIL", "Pillow")
    from PIL import Image
    from torchvision import transforms

    vgg = _load_vgg_style_features()
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # The pretrained VGG16 weights expect ImageNet-normalized input.
        # Without this the features (and therefore the Gram signature) are
        # computed off-distribution.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    with Image.open(image_path) as src:  # close the file handle promptly
        img = tfm(src.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        feats = vgg(img)
    _, c, h, w = feats.shape  # batch of exactly one image (unsqueeze(0) above)
    f = feats.reshape(c, h * w)
    gram = (f @ f.t()) / (c * h * w)
    vec = np.asarray(gram.reshape(-1).cpu().numpy(), dtype=np.float64)
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def _cosine_distance(a, b) -> float:
    """Return ``1 - cos(a, b)`` for two vectors, as a Python float.

    If either vector has zero norm, the norm product falls back to 1.0, so
    the result is ``1 - dot(a, b)`` rather than a division by zero.
    """
    np = _require("numpy")
    va = np.asarray(a, dtype=np.float64)
    vb = np.asarray(b, dtype=np.float64)
    scale = np.linalg.norm(va) * np.linalg.norm(vb)
    if not scale:
        scale = 1.0
    similarity = np.dot(va, vb) / scale
    return float(1.0 - similarity)


def style_centroid_drift(
    frames: list[Path], style_ref_paths: list[Path], method: str = "clip"
) -> dict:
    """Measure how far the corpus drifts from the centroid of the Look's style_refs.

    Embeds every style_ref and frame with CLIP (``method == "clip"``) or,
    for any other value, the Gram-matrix proxy. It then reports:
      * ``style_refs_internal_spread``: mean cosine distance of each style_ref
        to the refs' own centroid (the Look's own variance, i.e. the
        principled threshold);
      * ``mean_frame_drift``: mean cosine distance of each frame to that
        centroid (0.0 for an empty ``frames`` list);
      * ``drift_exceeds_internal_spread``: the pre-registered FAIL condition,
        ``mean_frame_drift > style_refs_internal_spread``.

    With a single style_ref the centroid is that ref, so the spread is ~0 and
    almost any drift trips the flag.

    Raises:
        ValueError: if ``style_ref_paths`` is empty.
    """
    np = _require("numpy")
    embed = _gram_style_embedding if method != "clip" else _clip_style_embedding

    if not style_ref_paths:
        raise ValueError("Look has no style_refs — cannot compute centroid drift.")

    ref_vectors = [embed(path) for path in style_ref_paths]
    centroid = np.mean(np.stack(ref_vectors), axis=0)

    spread_values = [_cosine_distance(vec, centroid) for vec in ref_vectors]
    internal_spread = float(np.mean(spread_values)) if spread_values else 0.0

    per_frame: list[dict] = [
        {"frame": str(fp), "centroid_drift": _cosine_distance(embed(fp), centroid)}
        for fp in frames
    ]
    drifts = [entry["centroid_drift"] for entry in per_frame]
    mean_drift = float(np.mean(drifts)) if drifts else 0.0

    return {
        "method": method,
        "style_refs_internal_spread": internal_spread,
        "mean_frame_drift": mean_drift,
        "drift_exceeds_internal_spread": mean_drift > internal_spread,
        "per_frame": per_frame,
    }


# --------------------------------------------------------------------------- #
# Metric 3 — face-embedding identity check (rules OUT identity confound)
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=1)
def _load_insightface_app():
    """Build and prepare insightface's ``buffalo_l`` FaceAnalysis app, once per process.

    Constructing and preparing the app loads its detection/recognition models.
    Caching means a corpus run pays that once instead of once per image.
    """
    from insightface.app import FaceAnalysis

    app = FaceAnalysis(name="buffalo_l")
    app.prepare(ctx_id=-1, det_size=(640, 640))
    return app


def _face_embedding(image_path: Path):
    """Face embedding (float64 numpy vector) for the LARGEST detected face.

    Prefers ``insightface`` (ArcFace, ``normed_embedding``) when importable,
    else ``face_recognition`` (dlib). Both backends pick the face with the
    largest bounding-box area.

    Returns:
        The embedding, or None when no face is detected. A frame may have zero
        characters; that is expected and NOT an error.

    Raises:
        StyleHoldDepError: if neither face library is installed, or if a
            dependency of the chosen backend (numpy, Pillow) is missing.
    """
    np = _require("numpy")

    # Probe ONLY the import. Wrapping the whole pipeline in ``except ImportError``
    # would also swallow StyleHoldDepError (an ImportError subclass, e.g. Pillow
    # missing) and silently switch embedding backends.
    try:
        from insightface.app import FaceAnalysis  # noqa: F401 — availability probe
    except ImportError:
        have_insightface = False
    else:
        have_insightface = True

    if have_insightface:
        _require("PIL", "Pillow")
        from PIL import Image

        with Image.open(image_path) as src:  # close the file handle promptly
            rgb = np.asarray(src.convert("RGB"))
        # Reverse the channel axis (RGB -> BGR) for insightface; make it
        # contiguous defensively since the reversed view has a negative stride.
        faces = _load_insightface_app().get(np.ascontiguousarray(rgb[:, :, ::-1]))
        if not faces:
            return None
        largest = max(
            faces,
            key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]),
        )
        return np.asarray(largest.normed_embedding, dtype=np.float64)

    try:
        import face_recognition
    except ImportError as exc:  # pragma: no cover — env-dependent
        raise StyleHoldDepError(
            "identity check needs a face-embedding lib — install one of:\n"
            "    pip install insightface onnxruntime    # preferred (ArcFace)\n"
            "    pip install face_recognition           # dlib fallback"
        ) from exc

    img = face_recognition.load_image_file(str(image_path))
    # face_locations yields (top, right, bottom, left) boxes. Encode only the
    # largest, honoring the same "largest face" contract as the insightface
    # branch (encs[0] would be an arbitrary detection order).
    boxes = face_recognition.face_locations(img)
    if not boxes:
        return None
    largest_box = max(boxes, key=lambda b: (b[2] - b[0]) * (b[1] - b[3]))
    encs = face_recognition.face_encodings(img, known_face_locations=[largest_box])
    if not encs:
        return None
    return np.asarray(encs[0], dtype=np.float64)


def face_identity_check(
    frames: list[Path], identity_ref: Optional[Path]
) -> dict:
    """Face-embedding distance per frame vs an optional identity reference.

    This RULES OUT identity as the failure mode. When a human flags a frame
    off-look, the report shows whether the FACE also drifted, which is an
    identity-LoRA problem separate from the style-hold question. It does NOT
    gate the verdict.

    Frames with no detected face report ``face_detected: false``, which is
    expected for 0-character frames. ``identity_ref_face_detected`` is None
    when no ref was given, and False when a ref was given but no face was
    found in it. In that case no frame gets an ``identity_distance`` and
    ``mean_identity_distance`` is None, and this flag says why.
    """
    ref_vec = _face_embedding(identity_ref) if identity_ref else None
    per_frame: list[dict] = []
    for fp in frames:
        vec = _face_embedding(fp)
        if vec is None:
            per_frame.append({"frame": str(fp), "face_detected": False})
            continue
        entry: dict = {"frame": str(fp), "face_detected": True}
        if ref_vec is not None:
            entry["identity_distance"] = _cosine_distance(vec, ref_vec)
        per_frame.append(entry)

    distances = [d["identity_distance"] for d in per_frame if "identity_distance" in d]
    return {
        "identity_ref": str(identity_ref) if identity_ref else None,
        "identity_ref_face_detected": (ref_vec is not None) if identity_ref else None,
        "frames_with_face": sum(1 for d in per_frame if d.get("face_detected")),
        "mean_identity_distance": (sum(distances) / len(distances)) if distances else None,
        "per_frame": per_frame,
        "note": (
            "Identity is an ANNOTATION, not a gate: it distinguishes 'style "
            "drifted' from 'face drifted'. The style-hold verdict is gated on "
            "palette ΔE2000 + style-centroid drift + human on-look scores."
        ),
    }


# --------------------------------------------------------------------------- #
# Top-level report assembly
# --------------------------------------------------------------------------- #


def evaluate(
    look_id: str,
    frames_dir: Path,
    *,
    style_method: str = "clip",
    identity_ref: Optional[Path] = None,
) -> dict:
    """Compute the full objective metrics report for the style-hold experiment.

    Loads the Look, globs the frame corpus, runs all three metric families, and
    returns a single JSON-serializable report. HUMAN on-look scores are NOT
    here: the STYLE_HOLD_RESULT.md verdict is BOTH-gated (these metrics AND the
    human scores), and this function is the objective half only.

    Raises:
        KeyError: if ``look_id`` is not in the registry.
        FileNotFoundError: if ``frames_dir`` is missing or holds no images, or
            ``identity_ref`` is given but is not a file.
    """
    look = load_look(look_id)
    frames = list_frames(frames_dir)
    if not frames:
        # An empty corpus makes the palette and drift means read 0.0 and so
        # trivially "pass". Refuse it instead (main() maps FileNotFoundError to
        # an input_error).
        raise FileNotFoundError(
            f"no image frames ({', '.join(IMAGE_SUFFIXES)}) found under {frames_dir}"
        )
    if identity_ref and not Path(identity_ref).is_file():
        # Fail fast: otherwise a bad path surfaces only after the expensive
        # palette + style passes have already run.
        raise FileNotFoundError(f"identity ref image not found: {identity_ref}")

    palette_hex = (look.get("palette") or {}).get("hex") or []
    style_ref_paths = _resolve_style_ref_paths(look)

    report: dict = {
        "look_id": look_id,
        "frames_dir": str(frames_dir),
        "frame_count": len(frames),
        "palette_hex": palette_hex,
        "style_ref_count": len(style_ref_paths),
    }

    # Each metric is computed independently, so a missing heavy dep (or bad
    # Look data) for ONE metric surfaces as that metric's error block rather
    # than aborting the whole run.
    try:
        report["palette"] = palette_delta_e(frames, palette_hex)
    except StyleHoldDepError as exc:
        report["palette"] = {"error": "missing_dependency", "message": str(exc)}
    except ValueError as exc:
        # e.g. a malformed palette.hex entry.
        report["palette"] = {"error": "invalid_palette", "message": str(exc)}

    try:
        report["style_centroid"] = style_centroid_drift(
            frames, style_ref_paths, method=style_method
        )
    except StyleHoldDepError as exc:
        report["style_centroid"] = {"error": "missing_dependency", "message": str(exc)}
    except ValueError as exc:
        report["style_centroid"] = {"error": "no_style_refs", "message": str(exc)}

    try:
        report["identity"] = face_identity_check(frames, identity_ref)
    except StyleHoldDepError as exc:
        report["identity"] = {"error": "missing_dependency", "message": str(exc)}

    return report


def main(argv: Optional[list[str]] = None) -> int:
    """CLI entry point: compute the objective style-hold report and emit it as JSON.

    Prints the report to stdout. When ``--out`` is given, it also writes the
    report (UTF-8) to that path, creating parent directories as needed.

    Returns:
        0 on success, or 2 when ``evaluate`` raises KeyError or
        FileNotFoundError (e.g. unknown Look id, missing frames dir). In that
        case an ``input_error`` JSON object is printed instead of a report.
    """
    parser = argparse.ArgumentParser(
        prog="style_hold_eval.py",
        description=(
            "krea2-flora style-hold experiment metrics (objective half): palette "
            "ΔE2000 vs Look.palette.hex, CLIP/Gram style-centroid drift vs "
            "Look.style_refs, and face-embedding identity check (rules OUT "
            "identity as the failure mode)."
        ),
    )
    parser.add_argument(
        "--look", required=True, metavar="LOOK_ID",
        help="Look id to load from the Phase-1 registry (e.g. noir_neon)",
    )
    parser.add_argument(
        "--frames", required=True, metavar="DIR", type=Path,
        help="Directory of the frame corpus (globbed recursively for images)",
    )
    parser.add_argument(
        "--style-method", default="clip", choices=("clip", "gram"),
        help="Style embedding method for centroid drift (default: clip)",
    )
    parser.add_argument(
        "--identity-ref", default=None, metavar="FACE_IMAGE", type=Path,
        help="Optional identity reference image for the face-embedding check",
    )
    parser.add_argument(
        "--out", default=None, metavar="JSON_PATH", type=Path,
        help="Optional path to write the JSON metrics report (also prints stdout)",
    )
    args = parser.parse_args(argv)

    try:
        report = evaluate(
            args.look,
            args.frames,
            style_method=args.style_method,
            identity_ref=args.identity_ref,
        )
    except (FileNotFoundError, KeyError) as exc:
        # str(KeyError(msg)) is repr(msg), which wraps the text in extra quotes,
        # so report the raw message from args. Other exceptions format normally.
        if isinstance(exc, KeyError) and exc.args:
            message = str(exc.args[0])
        else:
            message = str(exc)
        print(json.dumps({"error": "input_error", "message": message}, indent=2))
        return 2

    payload = json.dumps(report, indent=2, default=str)
    print(payload)
    if args.out:
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text(payload, encoding="utf-8")
    return 0


if __name__ == "__main__":
    # Run the CLI and hand its integer status back to the shell.
    exit_code = main()
    sys.exit(exit_code)
