#!/usr/bin/env python3
"""
validate_asset_refs.py — Validate v3 character asset reference images.

Renamed from validate_canonical_refs.py 2026-05-26. The `_canonical/` legacy
layout was retired; assets now live under `assets/char/{slug}/base/` (v3 layout).

Uses Gemini Flash vision to check character refs for common problems:
  - Multi-figure images (grids, 2-panel layouts, multiple people)
  - Jewelry/piercing identity bleed (earrings, studs, piercings)
  - Aspect ratio consistency across a character's ref set
  - Pose/facing mismatch vs filename (front.png shows back, etc.)

Usage:
    python3 -m tools.validate_asset_refs --project afterimage-anime
    python3 -m tools.validate_asset_refs --project tartarus --verbose
    python3 -m tools.validate_asset_refs --project afterimage-anime --character sadie

Exit codes:
    0 = all checks pass
    1 = one or more checks failed
    2 = script error (missing project, no images, API failure)
"""

import argparse
import json
import logging
import os
import sys
from pathlib import Path

# Project root setup — matches existing tools pattern
# Puts the directory two levels above this file (presumably the repo root) at
# the front of sys.path so the `recoil` package is importable. The `recoil`
# import below depends on this insertion, so it must stay after it.
_PROJECT_ROOT = Path(__file__).parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from recoil.core.paths import ProjectPaths

# Module logger; named explicitly under "starsend." rather than via __name__.
logger = logging.getLogger("starsend.validate_asset_refs")

# ── Constants ────────────────────────────────────────────────────────

# File suffixes (compared lowercased) that count as reference images when a
# character directory is scanned.
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp")

# Angles we expect in canonical ref sets (hero is optional/separate)
# NOTE(review): not referenced anywhere else in this file; presumably kept for
# other tools or a future completeness check — confirm before removing.
EXPECTED_ANGLES = ("front", "profile", "three_quarter", "back")

# Map filename stems to expected facing direction.
# Membership here decides whether a file gets a pose check at all ("hero" is
# listed but exempt), and the text is quoted in pose-mismatch messages.
POSE_EXPECTATIONS = {
    "front": "facing the camera / viewer (front-on)",
    "back": "facing away from the camera (back/rear view)",
    "profile": "facing to the side (side profile, left or right)",
    "three_quarter": "facing at a three-quarter angle (between front and side)",
    "hero": "any pose (hero reference, no specific facing required)",
}

# Gemini model used for every per-image vision call (also shown in the report header).
GEMINI_MODEL = "gemini-2.5-flash"

# ── Gemini client ────────────────────────────────────────────────────

_gemini_client = None


def _get_client():
    """Build the Gemini client on first use and reuse it afterwards.

    The API key is taken from GOOGLE_API_KEY, falling back to GEMINI_API_KEY;
    when neither is set the SDK is left to discover credentials itself.
    Terminates the process with exit status 2 if google-genai is missing.
    """
    global _gemini_client
    if _gemini_client is not None:
        return _gemini_client

    try:
        from google import genai
    except ImportError:
        logger.error("google-genai SDK not installed. pip install google-genai")
        sys.exit(2)

    key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    # An empty kwargs dict lets the SDK fall back to ADC / its own lookup.
    client_kwargs = {"api_key": key} if key else {}
    _gemini_client = genai.Client(**client_kwargs)
    return _gemini_client


# ── MIME type helper ─────────────────────────────────────────────────

_MIME_MAP = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
}


def _mime_for(path: Path) -> str:
    """Return the MIME type for *path*'s extension, matched case-insensitively.

    Extensions not in ``_MIME_MAP`` (including none at all) map to image/png.
    """
    suffix = path.suffix.lower()
    if suffix in _MIME_MAP:
        return _MIME_MAP[suffix]
    return "image/png"


# ── Vision analysis ──────────────────────────────────────────────────

# Instruction sent alongside every image. The JSON keys it requests
# (figure_count, is_grid, jewelry_details, facing_direction,
# clothing_anomalies) are exactly the keys the _check_* functions read, and the
# facing_direction vocabulary is what _check_pose_match compares against —
# keep them in sync when editing this text.
_ANALYSIS_PROMPT = """\
Analyze this character reference image and return a JSON object with these fields:

1. "figure_count": integer — How many distinct human figures (full or partial bodies) \
are visible? Count carefully. A single person shown once = 1. A grid/collage showing \
the same person multiple times = count each panel. Two different people = 2.

2. "is_grid": boolean — Is this a multi-panel grid, contact sheet, collage, or \
side-by-side layout? True if the image contains panel borders, dividing lines, \
or is clearly composed of multiple separate images tiled together.

3. "jewelry_details": string — Describe ANY earrings, studs, piercings, ear holes, \
hoops, or ear jewelry visible on the figure. If none are visible, return "none". \
Be very specific about what you see (e.g., "small silver stud in left ear", \
"gold hoop earrings in both ears", "visible piercing hole in right earlobe"). \
Look carefully at both ears.

4. "facing_direction": string — Which direction is the figure primarily facing? \
One of: "front" (facing camera), "back" (facing away), "profile_left" (side view \
facing left), "profile_right" (side view facing right), "three_quarter_left", \
"three_quarter_right". If multiple figures, describe the primary/largest one.

5. "clothing_anomalies": string — Note any obvious clothing problems: buttons on \
the back of a shirt (should be front), zippers in wrong places, impossible garment \
construction. Return "none" if clothing looks normal.

Return ONLY valid JSON, no markdown fences, no explanation.\
"""


def _analyze_image(image_path: Path) -> dict | None:
    """Send an image to Gemini Flash for structured analysis.

    The image bytes and ``_ANALYSIS_PROMPT`` are sent to ``GEMINI_MODEL`` with
    a JSON response requested at low temperature. Markdown code fences around
    the reply are stripped before parsing.

    Returns the parsed JSON object, or None on any failure: SDK missing,
    unreadable file, API error, empty reply, invalid JSON, or JSON that is not
    an object. A one-element array wrapping an object is unwrapped. Callers
    read the result with ``dict.get``, so non-dicts are rejected here instead
    of crashing the whole run later.
    """
    try:
        from google.genai import types
    except ImportError:
        logger.error("google.genai.types not available")
        return None

    client = _get_client()
    mime = _mime_for(image_path)

    # A missing/unreadable file is a per-image failure, not a fatal one.
    try:
        image_bytes = image_path.read_bytes()
    except OSError as e:
        logger.error("Could not read %s: %s", image_path.name, e)
        return None

    try:
        response = client.models.generate_content(
            model=GEMINI_MODEL,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type=mime),
                _ANALYSIS_PROMPT,
            ],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                temperature=0.1,
            ),
        )
        raw = (response.text or "").strip()
    except Exception as e:  # SDK surfaces network/auth/quota errors as varied types
        logger.error("Gemini analysis failed for %s: %s", image_path.name, e)
        return None

    # Strip markdown code fences if present
    if raw.startswith("```json"):
        raw = raw[7:]
    elif raw.startswith("```"):
        raw = raw[3:]
    raw = raw.removesuffix("```").strip()
    if not raw:
        logger.error("Gemini returned an empty response for %s", image_path.name)
        return None

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        logger.error("Gemini analysis failed for %s: %s", image_path.name, e)
        return None

    # Tolerate an object wrapped in a one-element array; reject any other
    # non-object so downstream `.get()` calls cannot raise AttributeError.
    if isinstance(parsed, list) and len(parsed) == 1:
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        logger.error(
            "Gemini returned %s instead of a JSON object for %s",
            type(parsed).__name__, image_path.name,
        )
        return None
    return parsed


# ── Check functions ──────────────────────────────────────────────────

def _check_single_figure(analysis: dict, image_path: Path) -> tuple[str, str]:
    """Check that exactly one figure is present and it's not a grid.

    Args:
        analysis: Parsed model reply. Reads ``figure_count`` (missing → 0) and
            ``is_grid`` (missing → False).
        image_path: Unused; kept so every per-image check shares one signature.

    Returns (status, message) where status is PASS/FAIL. A grid always fails,
    even when one figure is reported. Values the model returned as strings
    (``"2"``, ``"false"``) are coerced. A count that cannot be read as a number
    fails instead of raising, and zero or negative counts fail as "no figure".
    """
    raw_count = analysis.get("figure_count", 0)
    is_grid = analysis.get("is_grid", False)
    if isinstance(is_grid, str):
        # A JSON string like "false" would otherwise be truthy.
        is_grid = is_grid.strip().lower() in ("true", "yes", "1")

    try:
        count = int(float(raw_count))
    except (TypeError, ValueError, OverflowError):
        count = None

    if is_grid:
        return ("FAIL", f"Grid/multi-panel layout detected ({raw_count} figure(s))")

    if count is None:
        return ("FAIL", f"Unreadable figure_count from model: {raw_count!r}")

    if count <= 0:
        return ("FAIL", "No human figure detected")

    if count > 1:
        return ("FAIL", f"Multiple figures detected: {count} people in frame")

    return ("PASS", "Single figure confirmed")


def _check_jewelry(analysis: dict, image_path: Path) -> tuple[str, str]:
    """Check for earrings/piercings/studs — always report as WARN if found.

    Args:
        analysis: Parsed model reply; reads ``jewelry_details`` (missing → "none").
        image_path: Unused; kept so every per-image check shares one signature.

    Returns (status, message) where status is PASS/WARN. Empty/false values
    and "none"/"n/a" (case-insensitive, trailing period ignored) pass. A list
    value is joined with ", "; any other non-string value is stringified
    instead of crashing on ``.lower()``.
    """
    details = analysis.get("jewelry_details", "none")
    if not details:
        return ("PASS", "No ear jewelry or piercings detected")

    if isinstance(details, (list, tuple)):
        text = ", ".join(str(item) for item in details)
    else:
        text = str(details)

    if text.strip().rstrip(".").strip().lower() in ("none", "n/a", ""):
        return ("PASS", "No ear jewelry or piercings detected")

    return ("WARN", f"Jewelry/piercings found: {text}")


def _check_pose_match(analysis: dict, image_path: Path) -> tuple[str, str]:
    """Check that the figure's facing direction matches the filename.

    Only stems that appear in ``POSE_EXPECTATIONS`` (other than ``hero``) are
    checked; every other filename passes. The model's ``facing_direction`` is
    lowercased, and spaces/hyphens are folded to underscores for matching, so
    "three-quarter left" is treated like "three_quarter_left". A missing or
    null value fails the check instead of raising.

    Returns (status, message) where status is PASS/FAIL.
    """
    stem = image_path.stem.lower()
    if stem not in POSE_EXPECTATIONS or stem == "hero":
        return ("PASS", "No specific pose requirement for this filename")

    raw = analysis.get("facing_direction")
    facing = str(raw).lower() if raw is not None else ""
    normalized = facing.replace("-", "_").replace(" ", "_")

    # Accepted values: front → "front"; back → "back";
    # profile → "profile_left"/"profile_right";
    # three_quarter → "three_quarter_left"/"three_quarter_right".
    # Each accepted value contains the stem itself, so one substring test on
    # the stem covers all of them (and a bare "profile"/"three_quarter").
    if stem in normalized:
        return ("PASS", f"Pose matches filename ({facing})")

    expected = POSE_EXPECTATIONS[stem]
    return ("FAIL", f"Pose mismatch: filename is '{stem}' but figure is facing '{facing}' (expected {expected})")


def _check_clothing(analysis: dict, image_path: Path) -> tuple[str, str]:
    """Check for clothing anomalies (back buttons, etc.).

    Args:
        analysis: Parsed model reply; reads ``clothing_anomalies`` (missing → "none").
        image_path: Unused; kept so every per-image check shares one signature.

    Returns (status, message) where status is PASS/WARN. Empty/false values
    and "none"/"n/a" (case-insensitive, trailing period ignored) pass. A list
    value is joined with ", "; any other non-string value is stringified
    instead of crashing on ``.lower()``.
    """
    anomalies = analysis.get("clothing_anomalies", "none")
    if not anomalies:
        return ("PASS", "No clothing anomalies detected")

    if isinstance(anomalies, (list, tuple)):
        text = ", ".join(str(item) for item in anomalies)
    else:
        text = str(anomalies)

    if text.strip().rstrip(".").strip().lower() in ("none", "n/a", ""):
        return ("PASS", "No clothing anomalies detected")

    return ("WARN", f"Clothing anomaly: {text}")


# ── Aspect ratio check (local, no API) ──────────────────────────────

def _get_image_dimensions(path: Path) -> tuple[int, int] | None:
    """Get image width x height without heavy dependencies.

    Tries PIL first, falls back to reading PNG/JPEG headers directly.
    """
    try:
        from PIL import Image
        with Image.open(path) as img:
            return img.size  # (width, height)
    except ImportError:
        pass

    # Fallback: read PNG header (first 24 bytes contain dimensions)
    try:
        data = path.read_bytes()[:32]
        if data[:8] == b'\x89PNG\r\n\x1a\n':
            import struct
            w = struct.unpack('>I', data[16:20])[0]
            h = struct.unpack('>I', data[20:24])[0]
            return (w, h)
        # JPEG: search for SOF0 marker
        if data[:2] == b'\xff\xd8':
            import struct
            full = path.read_bytes()
            idx = 2
            while idx < len(full) - 9:
                if full[idx] != 0xFF:
                    break
                marker = full[idx + 1]
                if marker in (0xC0, 0xC1, 0xC2):
                    h = struct.unpack('>H', full[idx + 5:idx + 7])[0]
                    w = struct.unpack('>H', full[idx + 7:idx + 9])[0]
                    return (w, h)
                length = struct.unpack('>H', full[idx + 2:idx + 4])[0]
                idx += 2 + length
    except Exception:
        pass

    return None


def _check_aspect_ratios(
    char_images: list[Path],
    *,
    tolerance: float = 0.03,
) -> list[tuple[str, str, str]]:
    """Check that all images for a character share the same aspect ratio.

    Each readable image's width/height ratio is compared with a reference
    ratio: the one with the most other images within ``tolerance`` of it
    (ties go to the earliest image in ``char_images``). Images within
    ``tolerance`` of the reference PASS and the rest FAIL. Comparing
    distances directly, rather than snapping ratios to fixed-width buckets,
    means near-identical ratios are never split by a bucket edge. For example,
    2064x2048 (≈1.008) and 2048x2048 (1.0) rounded into different buckets
    (1.02 vs 0.99) and were reported as a mismatch.

    Args:
        char_images: Image files belonging to one character.
        tolerance: Largest absolute difference in width/height ratio that
            still counts as the same aspect ratio.

    Returns list of (status, filename, message) tuples: one WARN per image
    whose dimensions could not be read, then one PASS/FAIL per readable image.
    """
    results = []
    ratios = {}

    for img_path in char_images:
        dims = _get_image_dimensions(img_path)
        if dims is None:
            results.append(("WARN", img_path.name, "Could not read image dimensions"))
            continue
        w, h = dims
        exact = w / h if h > 0 else 0
        ratios[img_path.name] = (exact, w, h)

    if not ratios:
        return results

    # Tiny epsilon so a difference of exactly `tolerance` survives float error.
    limit = tolerance + 1e-9
    values = [exact for exact, _, _ in ratios.values()]
    reference = max(
        values,
        key=lambda r: sum(abs(r - other) <= limit for other in values),
    )

    for name, (exact, w, h) in ratios.items():
        shown = round(exact, 2)
        if abs(exact - reference) <= limit:
            results.append(("PASS", name, f"Aspect ratio {w}x{h} ({shown})"))
        else:
            results.append((
                "FAIL", name,
                f"Aspect ratio mismatch: {w}x{h} ({shown}) — "
                f"expected ~{reference:.2f}"
            ))

    return results


# ── Report formatting ────────────────────────────────────────────────

_STATUS_SYMBOLS = {
    "PASS": "\033[32mPASS\033[0m",
    "WARN": "\033[33mWARN\033[0m",
    "FAIL": "\033[31mFAIL\033[0m",
}


def _format_status(status: str) -> str:
    """Wrap a known status (PASS/WARN/FAIL) in its ANSI color code.

    Any other status string is returned unchanged.
    """
    try:
        return _STATUS_SYMBOLS[status]
    except KeyError:
        return status


# ── Main logic ───────────────────────────────────────────────────────

def validate_asset_refs(
    project: str,
    character_filter: str | None = None,
    verbose: bool = False,
) -> bool:
    """Validate all v3 character asset refs for a project.

    Scans every character directory under the project's ``char`` asset-class
    directory, skipping names that start with "." or "_". For each character
    it runs the local aspect-ratio check over its images, then sends each image
    to Gemini for the figure / jewelry / pose / clothing checks. A color-coded
    report and a summary line are printed to stdout.

    Args:
        project: Project name passed to ``ProjectPaths.for_project``.
        character_filter: If given, only the character directory whose name
            matches this value case-insensitively is validated.
        verbose: Also print passing checks and log each raw Gemini analysis at
            DEBUG level.

    Returns True if all checks pass (no FAILs), False otherwise. Also returns
    False, after printing an ERROR line, when the char directory is missing,
    the requested character is not found, or no character directories exist.
    WARNs never make the result False.

    NOTE(review): the module docstring places v3 images under
    ``{slug}/base/``, but this function only picks up files directly inside
    each character directory — confirm which layout is intended.
    """
    # Directory holding one sub-directory per character; the variable name
    # appears to date from the retired `_canonical/` layout.
    canonical_root = ProjectPaths.for_project(project).asset_class_dir("char")

    if not canonical_root.is_dir():
        print(f"ERROR: char assets directory not found: {canonical_root}")
        return False

    subject_dirs = sorted([
        d for d in canonical_root.iterdir()
        if d.is_dir() and not d.name.startswith((".", "_"))
    ])

    if character_filter:
        subject_dirs = [d for d in subject_dirs if d.name.lower() == character_filter.lower()]
        if not subject_dirs:
            print(f"ERROR: Character '{character_filter}' not found in {canonical_root}")
            return False

    if not subject_dirs:
        print(f"ERROR: No character directories found in {canonical_root}")
        return False

    print(f"\n{'='*60}")
    print(f"  Asset Ref Validation: {project}")
    print(f"  Characters: {', '.join(d.name for d in subject_dirs)}")
    print(f"  Model: {GEMINI_MODEL}")
    print(f"{'='*60}\n")

    # Running tallies for the summary; only FAILs clear all_passed.
    all_passed = True
    total_checks = 0
    pass_count = 0
    warn_count = 0
    fail_count = 0

    for subject_path in subject_dirs:
        # Collect image files (skip _thumbs, .DS_Store, etc.)
        images = sorted([
            f for f in subject_path.iterdir()
            if f.is_file()
            and f.suffix.lower() in IMAGE_EXTS
            and not f.name.startswith(".")
        ])

        # An empty character directory counts as one WARN check.
        if not images:
            print(f"  [{_format_status('WARN')}] {subject_path.name}/  — No images found")
            warn_count += 1
            total_checks += 1
            continue

        print(f"  Character: {subject_path.name}")
        print(f"  {'─'*50}")

        # 1. Aspect ratio check (local, no API cost)
        ar_results = _check_aspect_ratios(images)
        for status, filename, message in ar_results:
            total_checks += 1
            if status == "PASS":
                pass_count += 1
                if verbose:
                    print(f"    [{_format_status(status)}] {filename} — Aspect: {message}")
            elif status == "WARN":
                warn_count += 1
                print(f"    [{_format_status(status)}] {filename} — Aspect: {message}")
            else:
                fail_count += 1
                all_passed = False
                print(f"    [{_format_status(status)}] {filename} — Aspect: {message}")

        # 2. Per-image vision checks via Gemini Flash
        for img_path in images:
            # end=" " keeps the per-check status line on the same row as this prompt.
            print(f"\n    Analyzing {img_path.name}...", end=" ", flush=True)
            analysis = _analyze_image(img_path)

            # A failed API call/parse is recorded as a WARN, so it does not fail the run.
            if analysis is None:
                print(f"\n    [{_format_status('WARN')}] {img_path.name} — API analysis failed (skipped)")
                warn_count += 1
                total_checks += 1
                continue

            if verbose:
                logger.debug("Raw analysis for %s: %s", img_path.name, json.dumps(analysis, indent=2))

            checks = [
                ("Figure", _check_single_figure(analysis, img_path)),
                ("Jewelry", _check_jewelry(analysis, img_path)),
                ("Pose", _check_pose_match(analysis, img_path)),
                ("Clothing", _check_clothing(analysis, img_path)),
            ]

            # Print results inline
            statuses = []
            for label, (status, message) in checks:
                total_checks += 1
                if status == "PASS":
                    pass_count += 1
                    statuses.append(f"\033[32m{label}\033[0m")
                elif status == "WARN":
                    warn_count += 1
                    statuses.append(f"\033[33m{label}\033[0m")
                else:
                    fail_count += 1
                    all_passed = False
                    statuses.append(f"\033[31m{label}\033[0m")

            print(" | ".join(statuses))

            # Print details for non-PASS checks
            for label, (status, message) in checks:
                if status != "PASS":
                    print(f"      [{_format_status(status)}] {label}: {message}")
                elif verbose:
                    print(f"      [{_format_status(status)}] {label}: {message}")

        print()

    # Summary
    print(f"{'='*60}")
    print(f"  Summary: {total_checks} checks — "
          f"\033[32m{pass_count} passed\033[0m, "
          f"\033[33m{warn_count} warnings\033[0m, "
          f"\033[31m{fail_count} failed\033[0m")

    if all_passed:
        print(f"  Result: \033[32mALL CHECKS PASSED\033[0m")
    else:
        print(f"  Result: \033[31mFAILURES DETECTED\033[0m")

    print(f"{'='*60}\n")

    return all_passed


# ── CLI ──────────────────────────────────────────────────────────────

def main():
    """Command-line entry point; always ends by calling ``sys.exit``.

    Exit status follows the module docstring's contract:
        0 — every check passed;
        1 — at least one check failed, or validation returned False because the
            char directory / requested character was missing;
        2 — validation raised an unexpected exception (e.g. project lookup or an
            API/SDK error), or the Gemini SDK is not installed.
    Without the explicit handler, an uncaught exception would exit with
    Python's default status 1, which cannot be told apart from a check failure.
    """
    parser = argparse.ArgumentParser(
        description="Validate v3 character asset reference images.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 -m tools.validate_asset_refs --project afterimage-anime
    python3 -m tools.validate_asset_refs --project tartarus --verbose
    python3 -m tools.validate_asset_refs --project afterimage-anime --character sadie
        """,
    )
    parser.add_argument(
        "--project", required=True,
        help="Project name (e.g., afterimage-anime, tartarus)",
    )
    parser.add_argument(
        "--character",
        help="Validate only this character (by folder name)",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Show all check details, including passes",
    )

    args = parser.parse_args()

    # Configure logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(levelname)s: %(message)s",
    )

    try:
        passed = validate_asset_refs(
            project=args.project,
            character_filter=args.character,
            verbose=args.verbose,
        )
    except Exception:
        # SystemExit (e.g. the missing-SDK exit(2)) and KeyboardInterrupt are
        # not Exception subclasses and pass through unchanged.
        logger.exception("Validation aborted by an unexpected error")
        sys.exit(2)

    sys.exit(0 if passed else 1)


if __name__ == "__main__":
    main()
