#!/usr/bin/env python3
"""
extract_palette.py — Extract Dominant HEX Color Palettes from Character Hero Images

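Reads character hero reference images (PNG) and extracts the dominant colors
(8 by default, see --num-colors) using a hand-written K-means clustering.
Excludes white/near-white studio backgrounds, near-black pixels and neutral
greys before clustering. Outputs HEX codes with descriptive category names and
writes them into visual/breakdown.json unless --dry-run, --no-update or --json
is given.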

Usage:
    python3 extract_palette.py leviathan/
    python3 extract_palette.py leviathan/ --character JINX
    python3 extract_palette.py leviathan/ --dry-run
    python3 extract_palette.py leviathan/ --no-update
    python3 extract_palette.py leviathan/ --json
    python3 extract_palette.py leviathan/ --num-colors 6
    python3 extract_palette.py leviathan/ --white-threshold 220

Dependencies:
    pip install Pillow
"""

import argparse
import json
import math
import random
import sys
from pathlib import Path
from typing import Dict, List, Tuple


# ── Path Resolution ──────────────────────────────────────────────────────

def resolve_project_path(project_arg: str) -> Path:
    """Locate the project directory named on the command line.

    The search base is the directory two levels above this script when it
    lives in ``recoil/tools``; otherwise it is the current working directory.
    Candidates are tried in order and the first existing directory wins:

    1. ``<base>/<trimmed>``, where *trimmed* is ``project_arg`` with
       surrounding ``/`` and then ``\\`` characters stripped
    2. ``project_arg`` taken literally (absolute, or relative to the cwd)
    3. ``<cwd>/<trimmed>``

    Prints an error to stderr and exits with status 1 if none is a directory.
    """
    here = Path(__file__).resolve().parent
    in_recoil_tools = here.name == "tools" and here.parent.name == "recoil"
    base = here.parent.parent if in_recoil_tools else Path.cwd()

    trimmed = project_arg.strip("/").strip("\\")

    def _candidates():
        # Yielded lazily so later paths are only built if earlier ones miss.
        yield base / trimmed
        yield Path(project_arg)
        yield Path.cwd() / trimmed

    for option in _candidates():
        if option.is_dir():
            return option

    print(f"ERROR: Project directory not found: {project_arg}", file=sys.stderr)
    sys.exit(1)


# ── Color Utilities ──────────────────────────────────────────────────────

def rgb_to_hex(r: int, g: int, b: int) -> str:
    """Format three 0-255 channel values as an uppercase ``#RRGGBB`` string."""
    return "#" + "".join(format(channel, "02X") for channel in (r, g, b))


def hex_to_rgb(hex_str: str) -> Tuple[int, int, int]:
    """Parse a ``#RRGGBB`` (or ``RRGGBB``) string into an (r, g, b) tuple of ints."""
    digits = hex_str.lstrip("#")
    red, green, blue = (int(digits[pos:pos + 2], 16) for pos in (0, 2, 4))
    return (red, green, blue)


def color_distance(c1: Tuple[int, int, int], c2: Tuple[int, int, int]) -> float:
    """Return the straight-line (Euclidean) distance between two colors in RGB space."""
    squared_total = sum((x - y) ** 2 for x, y in zip(c1, c2))
    return math.sqrt(squared_total)


def rgb_to_hsl(r: int, g: int, b: int) -> Tuple[float, float, float]:
    """Convert 0-255 RGB channels to HSL.

    Returns:
        (hue, saturation, lightness) with hue in degrees [0, 360) and
        saturation/lightness as fractions in [0, 1]. Achromatic colors
        (max and min channel within 1e-6) report hue and saturation 0.
    """
    red, green, blue = r / 255.0, g / 255.0, b / 255.0
    hi = max(red, green, blue)
    lo = min(red, green, blue)
    chroma = hi - lo
    lightness = (hi + lo) / 2.0

    if chroma < 1e-6:
        return (0.0, 0.0, lightness)

    # Saturation normalises chroma by the distance to the nearer lightness extreme.
    denominator = (hi + lo) if lightness < 0.5 else (2.0 - hi - lo)
    saturation = chroma / denominator

    # Hue: which 60-degree sector the dominant channel places the color in.
    if hi == red:
        sector = ((green - blue) / chroma) % 6
    elif hi == green:
        sector = ((blue - red) / chroma) + 2
    else:
        sector = ((red - green) / chroma) + 4
    hue = 60.0 * sector
    if hue < 0:
        hue += 360.0

    return (hue, saturation, lightness)


def is_white_or_near_white(r: int, g: int, b: int, threshold: int = 220) -> bool:
    """Return True when every channel is at or above ``threshold`` (studio-white background)."""
    return all(channel >= threshold for channel in (r, g, b))


def is_near_black(r: int, g: int, b: int, threshold: int = 25) -> bool:
    """Return True when every channel is at or below ``threshold`` (very dark pixel)."""
    return all(channel <= threshold for channel in (r, g, b))


def is_very_desaturated(r: int, g: int, b: int) -> bool:
    """Return True for mid-lightness neutral greys.

    A color qualifies when its HSL saturation is below 0.05 and its lightness
    lies strictly between 0.15 and 0.85. Very dark and very light greys are
    therefore not flagged here.
    """
    saturation, lightness = rgb_to_hsl(r, g, b)[1:]
    if saturation >= 0.05:
        return False
    return 0.15 < lightness < 0.85


# ── K-Means Clustering (Manual Implementation) ──────────────────────────

def _squared_distance(c1: Tuple[int, int, int], c2: Tuple[int, int, int]) -> int:
    """Squared Euclidean RGB distance; sufficient (and cheaper) for comparisons."""
    dr = c1[0] - c2[0]
    dg = c1[1] - c2[1]
    db = c1[2] - c2[2]
    return dr * dr + dg * dg + db * db


def _nearest_centroid(
    px: Tuple[int, int, int],
    centroids: List[Tuple[int, int, int]],
) -> int:
    """Return the index of the centroid closest to ``px`` (lowest index wins ties)."""
    r, g, b = px
    best_idx = 0
    best_dist = float("inf")
    for ci, (cr, cg, cb) in enumerate(centroids):
        dr, dg, db = r - cr, g - cg, b - cb
        d = dr * dr + dg * dg + db * db
        if d < best_dist:
            best_dist = d
            best_idx = ci
    return best_idx


def kmeans_cluster(
    pixels: List[Tuple[int, int, int]],
    k: int,
    max_iterations: int = 30,
    seed: int = 42,
) -> List[Tuple[Tuple[int, int, int], int]]:
    """
    Cluster RGB pixels with a hand-written K-means using K-means++ seeding.

    Args:
        pixels: RGB tuples to cluster.
        k: Number of clusters to seed. ``k <= 0`` (or no pixels) yields ``[]``.
        max_iterations: Upper bound on refinement passes; refinement stops
            early once no centroid moves by more than 1 RGB unit.
        seed: Seed for a private ``random.Random`` used during seeding, so the
            result is reproducible for the same input.

    Returns:
        List of (centroid_rgb, pixel_count) sorted by count descending.
        Centroids are integer (floor) means of their member pixels. Clusters
        that end up with no pixels are omitted, so fewer than ``k`` entries
        may be returned.
    """
    if not pixels or k <= 0:
        return []

    rng = random.Random(seed)
    n = len(pixels)

    # K-means++ seeding: first centroid uniformly at random, each further one
    # with probability proportional to its squared distance to the nearest
    # centroid chosen so far.
    centroids = [pixels[rng.randint(0, n - 1)]]
    # Running squared distance of every pixel to its nearest chosen centroid,
    # updated incrementally so seeding costs O(n*k) instead of O(n*k^2).
    nearest_sq = [_squared_distance(px, centroids[0]) for px in pixels]
    for _ in range(1, k):
        total = sum(nearest_sq)
        if total == 0:
            # Every pixel coincides with an existing centroid; nothing to weight by.
            chosen = pixels[rng.randint(0, n - 1)]
        else:
            target = rng.random() * total
            cumulative = 0
            chosen = pixels[0]
            for px, d in zip(pixels, nearest_sq):
                cumulative += d
                # Strict '>' so zero-weight pixels (already centroids) are never picked.
                if cumulative > target:
                    chosen = px
                    break
        centroids.append(chosen)
        nearest_sq = [
            min(d, _squared_distance(px, chosen)) for px, d in zip(pixels, nearest_sq)
        ]

    # Lloyd refinement: assign pixels, then move centroids to member means.
    for _ in range(max_iterations):
        # Per-cluster channel sums and sizes (avoids storing member lists).
        sums = [[0, 0, 0] for _ in range(k)]
        sizes = [0] * k
        for px in pixels:
            ci = _nearest_centroid(px, centroids)
            acc = sums[ci]
            acc[0] += px[0]
            acc[1] += px[1]
            acc[2] += px[2]
            sizes[ci] += 1

        new_centroids = []
        converged = True
        for ci in range(k):
            size = sizes[ci]
            if size == 0:
                # Empty cluster: keep old centroid
                new_centroids.append(centroids[ci])
                continue
            sr, sg, sb = sums[ci]
            new_c = (sr // size, sg // size, sb // size)
            # Squared distance > 1 is equivalent to Euclidean distance > 1.0.
            if _squared_distance(new_c, centroids[ci]) > 1:
                converged = False
            new_centroids.append(new_c)

        centroids = new_centroids
        if converged:
            break

    # Final assignment against the settled centroids.
    counts = [0] * k
    for px in pixels:
        counts[_nearest_centroid(px, centroids)] += 1

    # Drop clusters nobody belongs to; they carry no palette information.
    result = [(c, cnt) for c, cnt in zip(centroids, counts) if cnt > 0]
    result.sort(key=lambda item: item[1], reverse=True)
    return result


def merge_similar_colors(
    clusters: List[Tuple[Tuple[int, int, int], int]],
    merge_threshold: float = 35.0,
) -> List[Tuple[Tuple[int, int, int], int]]:
    """Merge clusters whose centroids are closer than ``merge_threshold``.

    Clusters are compared pairwise in input order; when two are within the
    threshold (Euclidean RGB distance), the earlier one absorbs the later one.
    The merged centroid is the pixel-count-weighted average of the two
    centroids (integer floor), or their plain average when both counts are
    zero. Passes repeat until nothing changes, so chains of near neighbours
    collapse together.

    Args:
        clusters: (centroid_rgb, pixel_count) pairs, e.g. from kmeans_cluster.
            The input list itself is not modified.
        merge_threshold: Distance below which two clusters are merged.

    Returns:
        Merged (centroid_rgb, pixel_count) pairs sorted by count, descending.
    """
    # Entries are set to None once absorbed, then compacted after each pass.
    merged = list(clusters)
    changed = True

    while changed:
        changed = False
        for i in range(len(merged)):
            if merged[i] is None:
                continue
            for j in range(i + 1, len(merged)):
                if merged[j] is None:
                    continue

                # Re-read merged[i] each time: it may already have absorbed an earlier j.
                c1, n1 = merged[i]
                c2, n2 = merged[j]
                if color_distance(c1, c2) < merge_threshold:
                    total = n1 + n2
                    if total > 0:
                        # Weighted average centroid
                        new_c = (
                            (c1[0] * n1 + c2[0] * n2) // total,
                            (c1[1] * n1 + c2[1] * n2) // total,
                            (c1[2] * n1 + c2[2] * n2) // total,
                        )
                    else:
                        # Both clusters are empty: weighting would divide by zero.
                        new_c = (
                            (c1[0] + c2[0]) // 2,
                            (c1[1] + c2[1]) // 2,
                            (c1[2] + c2[2]) // 2,
                        )
                    merged[i] = (new_c, total)
                    merged[j] = None
                    changed = True

        merged = [x for x in merged if x is not None]

    merged.sort(key=lambda x: x[1], reverse=True)
    return merged


# ── Color Category Labeling ──────────────────────────────────────────────

# Skin tone detection thresholds used by classify_color_category(), expressed
# in the units returned by rgb_to_hsl(): hue in degrees, saturation and
# lightness as 0-1 fractions. Both range bounds are inclusive. These are rough
# heuristics, not a calibrated skin model.
SKIN_HUE_RANGE = (5, 45)     # Warm orangish-brown hues (degrees)
SKIN_SAT_MIN = 0.15  # minimum saturation; greyer colors are never labelled skin
SKIN_LIGHT_RANGE = (0.25, 0.80)  # lightness window for skin candidates


def classify_color_category(
    r: int, g: int, b: int,
    rank: int,
    used_categories: set,
) -> str:
    """
    Pick a descriptive, not-yet-used category label for a palette color.

    Rules are checked in priority order. A rule whose condition holds hands
    out the first of its labels absent from ``used_categories``. If all of
    its labels are taken, evaluation continues with the next rule:

    1. skin-like HSL (SKIN_* ranges)        -> skin, skin_shadow
    2. lightness < 0.15                     -> hair, shadow
    3. lightness < 0.35 and saturation < 0.4 -> hair
    4. lightness > 0.80                     -> highlight, detail
    5. rank < 4 (dominant colors)           -> wardrobe_primary, wardrobe_secondary
    6. saturation > 0.5                     -> accent
    7. always                               -> wardrobe_primary, wardrobe_secondary,
                                               accent, detail, highlight, shadow

    Falls back to ``color_<rank>`` when every label is taken. The caller is
    expected to add the returned label to ``used_categories``.
    """
    hue, sat, light = rgb_to_hsl(r, g, b)

    is_skin_tone = (
        SKIN_HUE_RANGE[0] <= hue <= SKIN_HUE_RANGE[1]
        and sat >= SKIN_SAT_MIN
        and SKIN_LIGHT_RANGE[0] <= light <= SKIN_LIGHT_RANGE[1]
    )

    # (condition, labels in order of preference). Wardrobe for dominant ranks
    # precedes accent so vivid main clothing is not mislabelled as an accent.
    rules = (
        (is_skin_tone, ("skin", "skin_shadow")),
        (light < 0.15, ("hair", "shadow")),
        (light < 0.35 and sat < 0.4, ("hair",)),
        (light > 0.80, ("highlight", "detail")),
        (rank < 4, ("wardrobe_primary", "wardrobe_secondary")),
        (sat > 0.5, ("accent",)),
        (True, ("wardrobe_primary", "wardrobe_secondary",
                "accent", "detail", "highlight", "shadow")),
    )

    for applies, labels in rules:
        if not applies:
            continue
        for label in labels:
            if label not in used_categories:
                return label

    return f"color_{rank}"


# ── Palette Extraction ───────────────────────────────────────────────────

def extract_palette_from_image(
    image_path: Path,
    num_colors: int = 8,
    white_threshold: int = 220,
    sample_step: int = 3,
    merge_threshold: float = 35.0,
) -> List[Dict]:
    """
    Extract dominant color palette from a single image.

    Pixels are sampled from the image's flattened pixel sequence. Background-like
    pixels (near-white, near-black, neutral grey) are discarded. The remaining
    pixels are over-clustered with K-means, near-duplicate clusters are merged,
    and the largest clusters are labelled with a category name. A one-line
    sampling summary is printed to stdout.

    Args:
        image_path: Path to PNG image
        num_colors: Maximum number of colors to return (>= 1; 5-8 is typical)
        white_threshold: Pixels whose R, G and B are all >= this value are
            excluded as studio background
        sample_step: Keep every Nth pixel of the flattened pixel sequence for
            performance (>= 1; 3 keeps ~33% of pixels)
        merge_threshold: Clusters whose centroids are closer than this
            Euclidean RGB distance are merged (default 35.0)

    Returns:
        List of dicts: [{"hex": "#AABBCC", "rgb": [170,187,204], "name": "wardrobe_primary", "percentage": 23.4}]
        Percentages are relative to the pixels covered by the returned colors
        only, so they sum to ~100. Empty list if no foreground pixels remain.

    Raises:
        ValueError: if num_colors or sample_step is less than 1.
        Errors from Pillow's Image.open (e.g. FileNotFoundError,
        PIL.UnidentifiedImageError, both OSError subclasses) propagate.
    """
    import itertools

    from PIL import Image

    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")
    if sample_step < 1:
        raise ValueError(f"sample_step must be at least 1, got {sample_step}")

    # convert() returns a converted copy, so the source file can be closed
    # immediately instead of leaking one open handle per processed image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    try:
        flat_pixels = img.get_flattened_data()
    except AttributeError:
        # Older Pillow versions don't have get_flattened_data
        flat_pixels = img.getdata()

    # Walk the row-major pixel sequence lazily, keeping every sample_step-th
    # pixel, instead of first materialising a list of every pixel.
    pixels = []
    total_sampled = 0
    for r, g, b in itertools.islice(flat_pixels, 0, None, sample_step):
        total_sampled += 1

        # Exclude white/near-white (studio background)
        if is_white_or_near_white(r, g, b, threshold=white_threshold):
            continue

        # Exclude near-black (pure black borders/artifacts)
        if is_near_black(r, g, b, threshold=10):
            continue

        # Exclude extremely desaturated greys (background shadow)
        if is_very_desaturated(r, g, b):
            continue

        pixels.append((r, g, b))

    if not pixels:
        print(f"  WARNING: No non-background pixels found in {image_path.name}", file=sys.stderr)
        return []

    bg_pct = (1.0 - len(pixels) / total_sampled) * 100
    print(f"  Sampled {total_sampled} pixels, {len(pixels)} after background removal ({bg_pct:.0f}% background)")

    # Over-cluster then merge (better results than clustering to exact target)
    initial_k = min(num_colors * 3, 20)
    clusters = kmeans_cluster(pixels, k=initial_k, max_iterations=40)
    clusters = merge_similar_colors(clusters, merge_threshold=merge_threshold)

    # Keep the N largest clusters; percentages below are relative to these only.
    clusters = clusters[:num_colors]
    total_cluster_pixels = sum(count for _, count in clusters)

    # Build palette entries with category labels (labels are handed out in
    # rank order, so larger clusters get first pick).
    palette = []
    used_categories: set = set()

    for rank, (centroid, count) in enumerate(clusters):
        r, g, b = centroid
        pct = (count / total_cluster_pixels) * 100 if total_cluster_pixels > 0 else 0

        category = classify_color_category(r, g, b, rank, used_categories)
        used_categories.add(category)

        palette.append({
            "hex": rgb_to_hex(r, g, b),
            "rgb": [r, g, b],
            "name": category,
            "percentage": round(pct, 1),
        })

    return palette


# ── Console Output ───────────────────────────────────────────────────────

def print_palette(character: str, palette: List[Dict]):
    """Pretty-print one character's palette to stdout.

    Each row shows a 24-bit ANSI background-color swatch, the HEX code, the
    category name and the percentage. The escape codes are emitted
    unconditionally, even when stdout is not a terminal.
    """
    print(f"\n  {character}")
    print(f"  {'─' * 50}")

    for entry in palette:
        hex_code = entry["hex"]
        name = entry["name"]
        pct = entry["percentage"]
        r, g, b = entry["rgb"]

        # ANSI color block (24-bit true color)
        color_block = f"\033[48;2;{r};{g};{b}m    \033[0m"
        print(f"  {color_block}  {hex_code}  {name:<22s} ({pct:5.1f}%)")

    print()


# ── Breakdown Integration ────────────────────────────────────────────────

def update_breakdown(
    breakdown: dict,
    palettes: Dict[str, List[Dict]],
) -> int:
    """
    Write color_palette entries into the ``characters`` mapping of breakdown.

    Only characters whose key already exists in ``breakdown["characters"]`` are
    updated (exact, case-sensitive match); ``breakdown`` is modified in place.
    If ``characters`` is missing or is not a dict, nothing is changed.

    Returns count of characters updated.
    """
    characters = breakdown.get("characters", {})
    if not isinstance(characters, dict):
        # Unexpected schema (e.g. a list): nothing can be matched by key.
        return 0

    updated = 0
    for char_key, palette in palettes.items():
        if char_key in characters:
            characters[char_key]["color_palette"] = palette
            updated += 1

    return updated


# ── Main ─────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: extract palettes for hero images and report or store them."""
    import contextlib

    parser = argparse.ArgumentParser(
        description="Extract dominant HEX color palettes from character hero reference images"
    )
    parser.add_argument("project", help="Project path (e.g., 'leviathan/')")

    # Filters
    filter_group = parser.add_argument_group("filters")
    filter_group.add_argument("--character", type=str, metavar="KEY",
                              help="Extract for a specific character only (e.g., JINX)")

    # Extraction options
    extract_group = parser.add_argument_group("extraction options")
    extract_group.add_argument("--num-colors", type=int, default=8,
                               help="Number of dominant colors to extract (default: 8)")
    extract_group.add_argument("--white-threshold", type=int, default=220,
                               help="RGB threshold for white background exclusion (default: 220)")
    extract_group.add_argument("--sample-step", type=int, default=3,
                               help="Sample every Nth pixel for performance (default: 3)")
    extract_group.add_argument("--merge-threshold", type=float, default=35.0,
                               help="Euclidean distance below which colors are merged (default: 35.0)")

    # Output options
    output_group = parser.add_argument_group("output options")
    output_group.add_argument("--dry-run", action="store_true",
                              help="Preview extraction without writing to breakdown.json")
    output_group.add_argument("--no-update", action="store_true",
                              help="Don't update breakdown.json (console output only)")
    output_group.add_argument("--json", action="store_true",
                              help="Output raw JSON to stdout (for piping)")

    args = parser.parse_args()

    # Reject values that would crash or silently empty the extraction.
    if args.num_colors < 1:
        parser.error("--num-colors must be at least 1")
    if args.sample_step < 1:
        parser.error("--sample-step must be at least 1")

    # In --json mode stdout must carry nothing but the JSON document, so all
    # human-readable progress output (including the per-image summary printed
    # by extract_palette_from_image) is redirected to stderr.
    status_stream = sys.stderr if args.json else sys.stdout

    with contextlib.redirect_stdout(status_stream):
        # Resolve project
        project_path = resolve_project_path(args.project)
        heroes_dir = project_path / "visual" / "refs" / "characters" / "heroes"
        breakdown_path = project_path / "visual" / "breakdown.json"

        if not heroes_dir.is_dir():
            print(f"ERROR: Heroes directory not found: {heroes_dir}", file=sys.stderr)
            sys.exit(1)

        # Find hero images
        hero_images = sorted(heroes_dir.glob("*.png"))
        if not hero_images:
            print(f"ERROR: No PNG images found in {heroes_dir}", file=sys.stderr)
            sys.exit(1)

        # Apply character filter (stems are compared case-insensitively)
        if args.character:
            char_upper = args.character.upper()
            hero_images = [p for p in hero_images if p.stem.upper() == char_upper]
            if not hero_images:
                available = ", ".join(sorted(p.stem for p in heroes_dir.glob("*.png")))
                print(f"ERROR: No hero image found for character: {args.character}", file=sys.stderr)
                print(f"Available: {available}", file=sys.stderr)
                sys.exit(1)

        # Load breakdown only when it can actually be written back
        # (--json mode prints results and never updates breakdown.json).
        breakdown = None
        if not (args.no_update or args.dry_run or args.json) and breakdown_path.exists():
            try:
                loaded = json.loads(breakdown_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"WARNING: Could not read breakdown.json: {e}", file=sys.stderr)
                print("Continuing with console output only.", file=sys.stderr)
            else:
                if isinstance(loaded, dict):
                    breakdown = loaded
                else:
                    print("WARNING: breakdown.json is not a JSON object.", file=sys.stderr)
                    print("Continuing with console output only.", file=sys.stderr)

        # Extract palettes
        print(f"\n{'=' * 60}")
        print("EXTRACTING COLOR PALETTES")
        print(f"Project: {project_path.name}")
        print(f"Heroes dir: {heroes_dir}")
        print(f"Images: {len(hero_images)}")
        print(f"Target colors: {args.num_colors}")
        print(f"White threshold: {args.white_threshold}")
        print(f"Merge threshold: {args.merge_threshold}")
        if args.dry_run:
            print("Mode: DRY RUN")
        print(f"{'=' * 60}")

        palettes: Dict[str, List[Dict]] = {}

        for image_path in hero_images:
            char_key = image_path.stem.upper()
            print(f"\n  Processing {char_key} ({image_path.name})...")

            if args.dry_run:
                print(f"  [DRY RUN] Would extract {args.num_colors} colors from {image_path}")
                continue

            try:
                palette = extract_palette_from_image(
                    image_path,
                    num_colors=args.num_colors,
                    white_threshold=args.white_threshold,
                    sample_step=args.sample_step,
                    merge_threshold=args.merge_threshold,
                )
            except OSError as e:
                # One unreadable/corrupt image should not abort the whole batch.
                print(f"  WARNING: Could not read {image_path.name}: {e}", file=sys.stderr)
                continue

            if palette:
                palettes[char_key] = palette
                if not args.json:
                    print_palette(char_key, palette)
            else:
                print(f"  WARNING: No palette extracted for {char_key}")

        if args.dry_run:
            print(f"\n{'=' * 60}")
            print("DRY RUN complete. No files modified.")
            print(f"{'=' * 60}\n")
            return

    # JSON output mode: stdout is no longer redirected here.
    if args.json:
        json.dump(palettes, sys.stdout, indent=2)
        print()  # trailing newline
        return

    # Update breakdown.json
    if args.no_update:
        print("--no-update: breakdown.json not modified.")
    elif breakdown is not None:
        if not palettes:
            print("No palettes extracted — breakdown.json not modified.")
        else:
            updated = update_breakdown(breakdown, palettes)
            if updated > 0:
                breakdown_path.write_text(
                    json.dumps(breakdown, indent=2, ensure_ascii=False) + "\n",
                    encoding="utf-8",
                )
                print(f"Updated breakdown.json: {updated} character palette(s) written.")
            else:
                print("No matching characters found in breakdown.json to update.")
    elif not breakdown_path.exists():
        print(f"breakdown.json not found at {breakdown_path} — skipping update.")

    # Summary
    print(f"\n{'=' * 60}")
    print(f"COMPLETE: {len(palettes)} palette(s) extracted")
    for char_key, palette in palettes.items():
        hex_list = " ".join(e["hex"] for e in palette)
        print(f"  {char_key}: {hex_list}")
    print(f"{'=' * 60}\n")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
