#!/usr/bin/env python3
"""
train_lora.py — Generalized LoRA Training Pipeline

Manages the full LoRA lifecycle: dataset preparation, training submission
(fal.ai), status polling, weight download, and registry auto-update.

Usage:
    # Prepare dataset (caption images, create ZIP from candidates selection)
    python3 train_lora.py leviathan/ prepare JINX [--dry-run] [--from-candidates]

    # Validate dataset against best practices (pre-flight check)
    python3 train_lora.py leviathan/ validate KIAN [--force]

    # Submit T2I training (non-blocking, saves request_id to registry)
    python3 train_lora.py leviathan/ submit JINX --type t2i [--steps 1000] [--lr 0.00005]

    # Submit WAN 2.2 video training (non-blocking)
    python3 train_lora.py leviathan/ submit JINX --type video [--steps 1000] [--lr 0.0007]

    # Check status — on completion: downloads weights + auto-updates registry
    python3 train_lora.py leviathan/ status [CHAR] [--wait]

    # Initialize registry from breakdown.json character list
    python3 train_lora.py leviathan/ init [--seed-existing]

    # Show current registry
    python3 train_lora.py leviathan/ show

Env vars:
    FAL_KEY — fal.ai API key (required for submit/status)

Dependencies:
    pip install fal-client
    pip install Pillow numpy   # optional: resolution normalization / face-aware cropping
"""

import argparse
import json
import os
import shutil
import sys
import time
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from cost_tracker import CostTracker


# ── Training type → pricing_rates.json model key ─────────────────────────

# Maps every accepted training type (including the "flux" and "wan" aliases)
# to its model key in pricing_rates.json.
# NOTE(review): "flux" and "wan" have no entry in MODEL_CONFIGS below — confirm
# callers normalize them (e.g. to "t2i" / "video") before indexing MODEL_CONFIGS.
TRAINING_MODEL_KEYS = {
    "t2i": "training_flux",
    "flux": "training_flux",
    "z_image": "training_z_image",
    "z_image_base": "training_z_image_base",
    "video": "training_wan",
    "wan": "training_wan",
}


# ── Model Configs ────────────────────────────────────────────────────────

# Per-training-type fal.ai trainer settings, keyed by the --type value:
#   endpoint          — fal.ai trainer application id
#   default_steps     — presumably used when --steps is not given
#   default_lr        — presumably used when --lr is not given
#   result_keys       — presumably the keys read from the fal.ai result payload
#                       to obtain the trained weight URLs
#   special_args      — extra trainer-specific arguments (presumably merged into
#                       the training request)
#   registry_key      — section of a character's registry entry that holds this
#                       LoRA ("t2i", "z_image_t2i", "z_image_base_t2i", "video" —
#                       the same sections get_inference_config reads)
#   no_trigger_word   — optional flag; presumably the trainer is submitted
#                       without a trigger word when True
#   cost_per_1k_steps — USD per 1000 training steps, for cost estimates
MODEL_CONFIGS = {
    "t2i": {
        # NOTE(review): cmd_show labels this section "T2I (Flux 2)", but this
        # endpoint is not the flux-2 trainer (see NOTE below) — confirm the label.
        "endpoint": "fal-ai/flux-lora-fast-training",
        "default_steps": 1000,
        "default_lr": 0.00005,
        "result_keys": ["diffusers_lora_file"],
        "special_args": {},
        "registry_key": "t2i",
        "cost_per_1k_steps": 2.00,  # $2/1000 steps (fal-ai/flux-lora-fast-training — confirmed from fal.ai model page)
        # NOTE: fal-ai/flux-2-trainer is $8/1K steps — different, more expensive endpoint
    },
    "z_image": {
        "endpoint": "fal-ai/z-image-trainer",
        "default_steps": 2000,
        "default_lr": 0.0001,
        "result_keys": ["diffusers_lora_file"],
        "special_args": {
            "training_type": "content",
        },
        "registry_key": "z_image_t2i",
        "cost_per_1k_steps": 2.26,  # $2.26/1000 steps (Z-Image — confirmed from fal.ai billing)
    },
    "z_image_base": {
        "endpoint": "fal-ai/z-image-base-trainer",
        "default_steps": 2000,
        "default_lr": 0.0005,
        "result_keys": ["diffusers_lora_file"],
        "special_args": {},
        "registry_key": "z_image_base_t2i",
        "no_trigger_word": True,
        "cost_per_1k_steps": 0.85,  # $0.85/1000 steps (Z-Image Base — confirmed from fal.ai model page)
    },
    "video": {
        "endpoint": "fal-ai/wan/v2.2/image-to-video/lora/training",
        "default_steps": 1000,
        "default_lr": 0.0007,
        # Two weight files; presumably mapped to the registry's low_noise_path /
        # high_noise_path — confirm in the status/download handler.
        "result_keys": ["diffusers_lora_file", "high_noise_lora"],
        "special_args": {
            "use_face_detection": True,
            "use_masks": True,
            "is_style": False,
        },
        "registry_key": "video",
        "cost_per_1k_steps": 2.00,  # ~$2/1000 steps (WAN 2.2 video, estimate)
    },
}

REGISTRY_FILENAME = "lora_registry.json"  # lives in <project>/visual/ (see registry_path)
REGISTRY_VERSION = 1  # top-level "version" written into newly created registry files


# ── Path Resolution ──────────────────────────────────────────────────────

def resolve_project_path(project_arg: str) -> Path:
    """Turn the CLI project argument into an existing project directory.

    Candidates are tried in order, first existing directory wins:
      1. ``<base>/<name>`` — *base* is the repository root when this script
         lives in ``recoil/tools/``, otherwise the current working directory;
      2. *project_arg* used literally (absolute or relative path);
      3. ``<cwd>/<name>``.

    *name* is *project_arg* with surrounding ``/`` and then ``\\`` stripped.
    If no candidate is a directory, an error is printed to stderr and the
    process exits with status 1.
    """
    here = Path(__file__).resolve().parent
    inside_recoil_tools = here.name == "tools" and here.parent.name == "recoil"
    base = here.parent.parent if inside_recoil_tools else Path.cwd()

    name = project_arg.strip("/").strip("\\")
    # Candidates are built lazily so Path.cwd() is only consulted when needed.
    lookups = (
        lambda: base / name,
        lambda: Path(project_arg),
        lambda: Path.cwd() / name,
    )
    for make_candidate in lookups:
        option = make_candidate()
        if option.is_dir():
            return option

    print(f"ERROR: Project directory not found: {project_arg}", file=sys.stderr)
    sys.exit(1)


def registry_path(project_dir: Path) -> Path:
    """Location of the project's LoRA registry: ``<project>/visual/<REGISTRY_FILENAME>``."""
    return project_dir.joinpath("visual", REGISTRY_FILENAME)


# ── Registry I/O (importable by other tools) ─────────────────────────────

def load_registry(project_dir: Path) -> dict:
    """Load LoRA registry for a project.

    Reads ``registry_path(project_dir)`` and returns its ``characters`` mapping.

    Returns:
        Dict keyed by lowercase character name with training data. An empty
        dict is returned (with a message on stderr) when the file is missing,
        cannot be decoded or parsed, or is not shaped like
        ``{"characters": {...}}``.
    """
    reg_path = registry_path(project_dir)
    if not reg_path.exists():
        print(f"WARNING: LoRA registry not found at {reg_path}", file=sys.stderr)
        return {}

    try:
        # JSON is UTF-8 by spec; "utf-8-sig" also tolerates a BOM added by editors.
        with open(reg_path, encoding="utf-8-sig") as f:
            data = json.load(f)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        print(f"ERROR: Invalid JSON in {reg_path}: {e}", file=sys.stderr)
        return {}

    # Callers call .get()/.items() on the result, so a hand-edited file holding
    # a list, or a null/non-object "characters", must not leak through.
    if not isinstance(data, dict) or not isinstance(data.get("characters", {}), dict):
        print(
            f"ERROR: Unexpected registry layout in {reg_path}: "
            "expected an object with a 'characters' object",
            file=sys.stderr,
        )
        return {}

    return data.get("characters", {})


def save_registry(project_dir: Path, registry: dict) -> None:
    """Save LoRA registry to disk.

    *registry* becomes the ``characters`` section of the registry file. Other
    top-level keys of an existing file (``version`` and any metadata) are
    preserved, and ``updated_at`` is set to the current UTC time.

    The JSON is written to a temporary sibling file and then moved over the
    registry with ``os.replace``, so an interrupted save cannot leave a
    truncated registry behind. An existing file that cannot be parsed is
    copied to ``<name>.corrupt`` before it is replaced.
    """
    reg_path = registry_path(project_dir)
    reg_path.parent.mkdir(parents=True, exist_ok=True)

    # Preserve version and any top-level metadata
    existing = {"version": REGISTRY_VERSION}
    if reg_path.exists():
        try:
            with open(reg_path, encoding="utf-8-sig") as f:
                loaded = json.load(f)
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            print(f"WARNING: Corrupt registry at {reg_path}: {e}", file=sys.stderr)
            # Keep the unreadable file so hand-recoverable data is not lost.
            backup = reg_path.with_name(reg_path.name + ".corrupt")
            shutil.copy2(reg_path, backup)
            print(f"  Backed up corrupt registry → {backup}", file=sys.stderr)
        else:
            if isinstance(loaded, dict):
                existing = loaded
            else:
                print(f"WARNING: Registry at {reg_path} is not a JSON object; replacing it",
                      file=sys.stderr)

    existing["characters"] = registry
    existing["updated_at"] = datetime.now(timezone.utc).isoformat()

    tmp_path = reg_path.with_name(reg_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(existing, f, indent=2)
        os.replace(tmp_path, reg_path)  # atomic swap: readers never see a partial file
    finally:
        # Only still present if writing or replacing failed.
        if tmp_path.exists():
            tmp_path.unlink()
    print(f"  Registry saved → {reg_path}")


def get_inference_config(registry: dict, char_name: str) -> dict:
    """Get flat inference config for a character.

    Args:
        registry: Character mapping as returned by ``load_registry`` (keys are
            lowercase character names).
        char_name: Character name; lowercased before lookup.

    Returns:
        Dict with keys matching what generate_storyboard_keyframes.py expects:
        ``trigger``, ``t2i_path``, ``z_image_t2i_path``,
        ``z_image_base_t2i_path``, ``video_lora_high``, ``video_lora_low``,
        ``scale_solo`` (default 1.0) and ``scale_dual`` (default 0.8).
        The optional keys ``flux2_scale_solo``, ``flux2_scale_dual``,
        ``scale_ecu`` and ``flux2_scale_ecu`` are present only when set in the
        entry's ``inference`` section. Unknown characters get ``None`` for the
        trigger and all paths, plus the default scales.
    """
    char = registry.get(char_name.lower()) or {}

    # Sections may be absent or explicitly null (e.g. hand-edited JSON);
    # treat both as empty rather than crashing on None.get().
    t2i = char.get("t2i") or {}
    z_image_t2i = char.get("z_image_t2i") or {}
    z_image_base_t2i = char.get("z_image_base_t2i") or {}
    video = char.get("video") or {}
    inference = char.get("inference") or {}

    config = {
        "trigger": char.get("trigger"),
        "t2i_path": t2i.get("path"),
        "z_image_t2i_path": z_image_t2i.get("path"),
        "z_image_base_t2i_path": z_image_base_t2i.get("path"),
        "video_lora_high": video.get("high_noise_path"),
        "video_lora_low": video.get("low_noise_path"),
        "scale_solo": inference.get("scale_solo", 1.0),
        "scale_dual": inference.get("scale_dual", 0.8),
    }
    # Pass through engine-specific scales (Flux 2) and ECU/CU reduced scales
    # only when present, so callers can distinguish "unset" from a value.
    for key in ("flux2_scale_solo", "flux2_scale_dual", "scale_ecu", "flux2_scale_ecu"):
        if key in inference:
            config[key] = inference[key]
    return config


# ── Subcommand: init ─────────────────────────────────────────────────────

def cmd_init(project_dir: Path, args) -> int:
    """Create a fresh LoRA registry from the characters in ``visual/breakdown.json``.

    Every character gets an untrained entry: its trigger word comes from the
    breakdown's ``trigger_word`` field (default ``<first 4 letters>CHAR``),
    with empty t2i/video sections, default inference scales (solo 1.0, dual
    0.8) and no dataset directory. When ``args.seed_existing`` is set the
    hard-coded data from ``_seed_existing_data`` is applied on top. The result
    replaces the ``characters`` section of the registry file.

    Returns:
        0 on success; 1 if breakdown.json is missing, is not valid JSON, or
        lists no characters.
    """
    source = project_dir / "visual" / "breakdown.json"
    if not source.exists():
        print(f"ERROR: breakdown.json not found at {source}", file=sys.stderr)
        return 1

    try:
        with open(source) as fh:
            breakdown = json.load(fh)
    except json.JSONDecodeError as err:
        print(f"ERROR: Invalid JSON in {source}: {err}", file=sys.stderr)
        return 1

    characters = breakdown.get("characters", {})
    if not characters:
        print("ERROR: No characters in breakdown.json", file=sys.stderr)
        return 1

    def _fresh_entry(key, info):
        # Registry keys are lowercase; the default trigger derives from that key.
        return {
            "trigger": info.get("trigger_word", f"{key.upper()[:4]}CHAR"),
            "t2i": {"path": None, "request_id": None},
            "video": {"high_noise_path": None, "low_noise_path": None, "request_id": None},
            "inference": {"scale_solo": 1.0, "scale_dual": 0.8},
            "dataset_dir": None,
        }

    registry = {
        name.lower(): _fresh_entry(name.lower(), info)
        for name, info in characters.items()
    }

    if args.seed_existing:
        _seed_existing_data(registry)

    save_registry(project_dir, registry)
    print(f"  Initialized {len(registry)} characters")
    return 0


def _seed_existing_data(registry: dict) -> None:
    """Seed known LoRA data from previous training runs."""
    # Jinx T2I LoRA
    if "jinx" in registry:
        registry["jinx"]["trigger"] = "JNXCHAR"
        registry["jinx"]["t2i"] = {
            "path": "https://v3b.fal.media/files/b/0a8d5e6b/CvgdMcO5AakNaF0EyouP7_pytorch_lora_weights_comfy_converted.safetensors",
            "trained_at": "2026-02-06T04:11:00Z",
            "steps": 1000,
            "dataset_images": 28,
            "request_id": None,
        }
        registry["jinx"]["video"] = {
            "high_noise_path": "https://v3b.fal.media/files/b/0a8d6e5a/aQbDPNKWsJXb3SVGoAaI3_adapter_model.safetensors",
            "low_noise_path": "https://v3b.fal.media/files/b/0a8d6e5b/kJsl9UPEpENap9TRgfXu4_adapter_model.safetensors",
            "trained_at": "2026-02-06T14:56:00Z",
            "steps": 1000,
            "request_id": None,
        }
        registry["jinx"]["dataset_dir"] = "~/Desktop/jinx_lora_training/images"

    # Kian — pending
    if "kian" in registry:
        registry["kian"]["trigger"] = "KIANCHAR"
        registry["kian"]["dataset_dir"] = "~/Desktop/kian_lora_training/images"


# ── Subcommand: show ──────────────────────────────────────────────────────

def cmd_show(project_dir: Path, args) -> int:
    """Print every character in the project's LoRA registry with its LoRA status.

    Characters are listed in sorted key order. Each entry shows the trigger
    word, one status line per LoRA section (t2i, z_image_t2i,
    z_image_base_t2i, video), the solo/dual inference scales (1.0 / 0.8 when
    unset) and the dataset directory when one is recorded. ``args`` is unused.

    Returns:
        0 after printing; 1 when the registry is missing or has no characters.
    """
    registry = load_registry(project_dir)
    if not registry:
        print("  No LoRA registry found (run `init` first)")
        return 1

    banner = "=" * 70
    print(f"\n{banner}")
    print(f"  LoRA REGISTRY — {project_dir.name}")
    print(banner)

    for key in sorted(registry):
        entry = registry[key]
        t2i_line = _format_status(entry.get("t2i", {}))
        turbo_line = _format_status(entry.get("z_image_t2i", {}))
        base_line = _format_status(entry.get("z_image_base_t2i", {}))
        video_line = _format_video_status(entry.get("video", {}))
        scales = entry.get("inference", {})

        print(f"\n  {key.upper()} (trigger: {entry.get('trigger', '?')})")
        print(f"    T2I (Flux 2):  {t2i_line}")
        print(f"    T2I (Z-Turbo): {turbo_line}")
        print(f"    T2I (Z-Base):  {base_line}")
        print(f"    Video:         {video_line}")
        print(f"    Scale: solo={scales.get('scale_solo', 1.0)}, dual={scales.get('scale_dual', 0.8)}")
        dataset_dir = entry.get("dataset_dir")
        if dataset_dir:
            print(f"    Data:  {dataset_dir}")

    print(f"\n{banner}")
    return 0


def _format_status(t2i: dict) -> str:
    """Format T2I training status."""
    if t2i.get("path"):
        trained = t2i.get("trained_at", "unknown date")
        steps = t2i.get("steps", "?")
        return f"READY ({steps} steps, trained {trained})"
    if t2i.get("request_id"):
        return f"TRAINING (request: {t2i['request_id'][:16]}...)"
    return "NOT TRAINED"


def _format_video_status(video: dict) -> str:
    """Format video LoRA training status."""
    if video.get("high_noise_path") and video.get("low_noise_path"):
        trained = video.get("trained_at", "unknown date")
        return f"READY (high+low noise, trained {trained})"
    if video.get("request_id"):
        return f"TRAINING (request: {video['request_id'][:16]}...)"
    return "NOT TRAINED"


# ── Subcommand: prepare ──────────────────────────────────────────────────

def cmd_prepare(project_dir: Path, args) -> int:
    """Prepare training dataset: caption images, create ZIP.

    With ``--from-candidates`` the work is delegated to
    ``_prepare_from_candidates`` (curated lora_candidates selection).
    Otherwise every PNG/JPEG file (extension matched case-insensitively) in
    ``visual/refs/characters/<char>`` is used; the lowercase directory is
    tried first, then the uppercase one. Each image gets a ``.txt`` caption
    next to it reading ``"<trigger>, <visual_description>"``, or just the
    trigger when breakdown.json has no description. Image/caption pairs are
    zipped into ``visual/lora/<CHAR>/training/dataset.zip``, and the
    character's registry ``dataset_dir`` is set to the reference directory.

    Args:
        project_dir: Project root.
        args: Parsed CLI namespace; reads ``character`` and ``dry_run`` plus
            the optional ``from_candidates`` / ``target_model`` attributes.

    Returns:
        0 on success (a dry run prints the plan and writes nothing); 1 when
        the registry, the character, or its reference images are missing.
    """
    char_name = args.character.lower()
    registry = load_registry(project_dir)

    if not registry:
        print("ERROR: No registry found. Run `init` first.", file=sys.stderr)
        return 1

    char_data = registry.get(char_name)
    if not char_data:
        print(f"ERROR: Character '{char_name}' not in registry", file=sys.stderr)
        return 1

    # `or` rather than a .get() default so an explicit null/empty trigger in
    # the registry still falls back instead of producing a None caption.
    trigger = char_data.get("trigger") or char_name.upper()[:4] + "CHAR"

    # --from-candidates: build from curated lora_candidates selection
    target_model = getattr(args, 'target_model', 'z_image') or 'z_image'
    if getattr(args, 'from_candidates', False):
        return _prepare_from_candidates(project_dir, char_name, trigger, char_data, registry, args, target_model)

    # Find reference images (lowercase directory name first, then uppercase)
    refs_dir = project_dir / "visual" / "refs" / "characters" / char_name
    if not refs_dir.is_dir():
        refs_dir = project_dir / "visual" / "refs" / "characters" / char_name.upper()
    if not refs_dir.is_dir():
        print(f"ERROR: No reference images at {refs_dir}", file=sys.stderr)
        return 1

    # Compare suffixes case-insensitively: glob("*.png") misses "IMG_01.PNG"
    # or ".JPG" files on case-sensitive filesystems.
    image_suffixes = {".png", ".jpg", ".jpeg"}
    images = sorted(
        p for p in refs_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    if not images:
        print(f"ERROR: No images found in {refs_dir}", file=sys.stderr)
        return 1

    # Load visual description from breakdown.json. This is best effort: any
    # read or layout problem just means captions carry the trigger word only.
    breakdown_path = project_dir / "visual" / "breakdown.json"
    visual_desc = ""
    if breakdown_path.exists():
        try:
            with open(breakdown_path, encoding="utf-8-sig") as f:
                breakdown = json.load(f)
        except ValueError:  # JSONDecodeError or UnicodeDecodeError
            breakdown = {}
        characters = breakdown.get("characters", {}) if isinstance(breakdown, dict) else {}
        if not isinstance(characters, dict):
            characters = {}
        char_breakdown = characters.get(char_name.upper())
        if char_breakdown is None:
            # cmd_init accepts breakdown keys in any casing, so match the same way here.
            char_breakdown = next(
                (v for k, v in characters.items() if str(k).lower() == char_name), {}
            )
        if isinstance(char_breakdown, dict):
            visual_desc = char_breakdown.get("visual_description", "")

    # Generate captions
    caption_text = f"{trigger}, {visual_desc}" if visual_desc else trigger
    print(f"  Character: {char_name.upper()}")
    print(f"  Trigger: {trigger}")
    print(f"  Images: {len(images)}")
    print(f"  Caption: {caption_text[:80]}...")

    if args.dry_run:
        print("\n  DRY RUN — would create:")
        for img in images:
            print(f"    {img.name} + {img.stem}.txt")
        return 0

    # Write caption files alongside images. Use explicit UTF-8 because
    # descriptions may contain characters the platform default encoding
    # cannot represent.
    for img in images:
        img.with_suffix(".txt").write_text(caption_text, encoding="utf-8")
    print(f"  Wrote {len(images)} caption files")

    # Create ZIP for upload
    zip_dir = project_dir / "visual" / "lora" / char_name.upper() / "training"
    zip_dir.mkdir(parents=True, exist_ok=True)
    zip_path = zip_dir / "dataset.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img in images:
            zf.write(img, img.name)
            txt_path = img.with_suffix(".txt")
            if txt_path.exists():
                zf.write(txt_path, txt_path.name)

    print(f"  ZIP: {zip_path} ({zip_path.stat().st_size / 1024 / 1024:.1f} MB)")

    # Update registry with dataset info
    char_data["dataset_dir"] = str(refs_dir)
    save_registry(project_dir, registry)
    return 0


# ── Rich caption helpers ──────────────────────────────────────────────────

# Angle name -> natural language phrase
# Manifest "angle" value -> natural-language phrase used in training captions.
# _build_rich_caption falls back to the raw value with "_" replaced by spaces
# for angles not listed here.
ANGLE_LABELS = {
    "front": "looking directly at camera",
    "profile_left": "in left profile",
    "profile_right": "in right profile",
    "three_quarter_left": "turned slightly left at a three-quarter angle",
    "three_quarter_right": "turned slightly right at a three-quarter angle",
    "over_shoulder": "seen from behind and to the right, looking past the near shoulder",
    "high_angle": "seen from above, camera looking down",
    "low_angle": "seen from below, camera looking up",
    "full_body": "standing, full body visible from head to feet",
    "upper_body": "framed from the waist up",
    "closeup_face": "in tight close-up, face filling the frame",
    "back": "facing away from camera, back of head visible",
}

# Manifest "lighting" value -> natural-language phrase used in training captions
# (same underscore-to-space fallback for unlisted values).
LIGHTING_LABELS = {
    "warm_ambient": "warm ambient light",
    "cool_overhead": "cool overhead light",
    "cool_clinical": "cool clinical blue-white light",
    "rim_backlit": "rim lighting from behind with glowing edges",
    "harsh_direct": "harsh direct light casting strong shadows",
    "harsh_topdown": "harsh top-down light with deep eye socket shadows",
    "soft_studio": "soft diffused studio light",
    "dramatic_split": "dramatic split lighting",
    "dramatic_side": "dramatic side lighting with strong shadows on one side",
    "emergency_red": "red emergency lighting",
    "amber_industrial": "warm amber industrial light",
}

# Per-target-model caption settings.
#   style — selects the template list in _build_rich_caption ("prose",
#           "minimal", "detailed"; "concise" is also handled there, but no
#           model below uses it).
#   min_words / max_words — intended caption length in words.
#           NOTE(review): _build_rich_caption does not read these; confirm
#           whether any validation step enforces them.
# Z-Image Turbo: ~75 token attention window. 2-7 words optimal (trigger + angle),
#   although max_words below allows up to 15.
# CivitAI "Minimalist/Clean Label Method": heavy captions = disfigured features.
# Flux: T5-based, handles longer prose. 30-80 words.
# Video: detailed scene description needed. 80-150 words.
MODEL_CAPTION_CONFIG = {
    "flux": {"min_words": 30, "max_words": 80, "style": "prose"},
    "t2i": {"min_words": 30, "max_words": 80, "style": "prose"},
    "z_image": {"min_words": 3, "max_words": 15, "style": "minimal"},
    "z_image_base": {"min_words": 3, "max_words": 15, "style": "minimal"},
    "video": {"min_words": 80, "max_words": 150, "style": "detailed"},
    "wan": {"min_words": 80, "max_words": 150, "style": "detailed"},
}

# Sentence templates for natural language variation (T5 models respond to prose).
# Placeholders: trigger, expression, angle, environment, lighting.
_CAPTION_TEMPLATES_PROSE = [
    "{trigger} {expression}, {angle}. {environment}. {lighting}.",
    "{trigger} with {expression}, {angle} in {environment}. {lighting}.",
    "{trigger}, {angle}. {expression}. The scene is lit by {lighting} in {environment}.",
    "{trigger} {angle}, {expression}. {environment} with {lighting}.",
    "A portrait of {trigger}, {expression}. {angle}. {environment}, {lighting}.",
    "{trigger} {angle} with {expression}. {lighting} illuminates the scene in {environment}.",
]

# Comma-heavy variants; used when a model's caption style is "concise".
_CAPTION_TEMPLATES_CONCISE = [
    "{trigger}, {expression}, {angle}. {environment}, {lighting}.",
    "{trigger} {angle}. {expression}. {environment}, {lighting}.",
    "{trigger}, {angle}, {expression}, {environment}, {lighting}.",
]

# Minimal captions for Z-Image Turbo: trigger + angle, optionally the
# expression (no class word, environment or lighting).
# CivitAI Minimalist/Clean Label Method: heavy captions poison LoRA training
_CAPTION_TEMPLATES_MINIMAL = [
    "{trigger}, {angle}",
    "{trigger} {angle}",
    "{trigger}, {expression}, {angle}",
]

# Longer scene-level captions for video ("detailed" style).
# NOTE(review): once filled in, these come to roughly 30-50 words — well below
# the 80-word minimum declared for "video"/"wan" above; confirm which is intended.
_CAPTION_TEMPLATES_DETAILED = [
    "{trigger} {angle}, {expression}. The setting is {environment}. {lighting} defines the mood. The subject holds still, the air thick around them.",
    "A detailed portrait of {trigger}, {angle}. {expression}. The environment is {environment}, rendered in {lighting}. Subtle movement in the air, the scene atmospheric and grounded.",
    "{trigger} is seen {angle}. {expression}. The environment around them is {environment}. {lighting} carves the contours of the scene, every surface catching light differently.",
]


def _build_rich_caption(trigger: str, manifest_entry: dict, target_model: str = "z_image") -> str:
    """Compose one training caption for an image from its manifest metadata.

    Only attributes that vary between training images are described after the
    trigger word: expression, camera angle, environment and lighting.
    Identity, wardrobe and style are deliberately left out so the LoRA learns
    them from the pixels (see docs/lora_training_best_practices.md).

    The template family comes from ``MODEL_CAPTION_CONFIG[target_model]["style"]``;
    unknown models use the ``z_image`` entry. Within the family, a template is
    picked from an MD5 hash of the metadata, so captions vary across images
    but are reproducible between runs.

    Missing fields fall back to "a neutral expression", "looking at camera",
    "an industrial interior" and "ambient light". Angle and lighting values are
    translated through ANGLE_LABELS / LIGHTING_LABELS; unlisted values have
    "_" replaced by spaces.
    """
    import hashlib

    expr = manifest_entry.get("expression", "")
    angle = manifest_entry.get("angle", "")
    location = manifest_entry.get("location", "")
    lighting = manifest_entry.get("lighting", "")

    if not expr:
        expression_phrase = "a neutral expression"
    else:
        starts_with_vowel = expr[0].lower() in "aeiou"
        expression_phrase = f"{'an' if starts_with_vowel else 'a'} {expr} expression"

    fallback_angle = angle.replace("_", " ") if angle else "looking at camera"
    angle_phrase = ANGLE_LABELS.get(angle, fallback_angle)

    env_phrase = location.lower().strip() if location else "an industrial interior"

    fallback_light = lighting.replace("_", " ") if lighting else "ambient light"
    light_phrase = LIGHTING_LABELS.get(lighting, fallback_light)

    style = MODEL_CAPTION_CONFIG.get(target_model, MODEL_CAPTION_CONFIG["z_image"])["style"]
    templates = {
        "detailed": _CAPTION_TEMPLATES_DETAILED,
        "minimal": _CAPTION_TEMPLATES_MINIMAL,
        "concise": _CAPTION_TEMPLATES_CONCISE,
    }.get(style, _CAPTION_TEMPLATES_PROSE)

    # Stable per-image choice: same metadata always maps to the same template.
    digest = hashlib.md5(f"{trigger}:{angle}:{expr}:{lighting}:{location}".encode()).hexdigest()
    chosen = templates[int(digest, 16) % len(templates)]

    return chosen.format(
        trigger=trigger,
        expression=expression_phrase,
        angle=angle_phrase,
        environment=env_phrase,
        lighting=light_phrase,
    )


def _normalize_resolutions(
    images: List[Path],
    project_dir: Path,
    char_name: str,
    dry_run: bool,
    target_size: int = 1024,
) -> List[Path]:
    """Normalize all images to consistent resolution with face-aware cropping.

    1. Detect most common resolution
    2. If all match, return as-is
    3. For mismatched images: face-aware crop centered on face, resize to target
    4. Save normalized copies to lora_training/ — never overwrite originals
    5. Return updated image paths (normalized where needed, originals where not)
    """
    try:
        from PIL import Image as PILImage
    except ImportError:
        print("  WARNING: Pillow not installed — skipping resolution normalization")
        return images

    # Detect resolutions
    res_map = {}  # (w,h) -> [paths]
    for img_path in images:
        try:
            with PILImage.open(img_path) as im:
                res_map.setdefault(im.size, []).append(img_path)
        except Exception as e:
            print(f"  WARNING: Could not read {img_path.name}: {e}")

    if len(res_map) <= 1:
        # All same resolution, no normalization needed
        if res_map:
            w, h = next(iter(res_map))
            print(f"  Resolution: all {w}x{h} — no normalization needed")
        return images

    # Find most common resolution
    most_common_res = max(res_map, key=lambda k: len(res_map[k]))
    print(f"  Resolution: mixed ({len(res_map)} sizes). Most common: {most_common_res[0]}x{most_common_res[1]}")
    print(f"  Target: {target_size}x{target_size}")

    # Set up normalized output directory
    norm_dir = project_dir / "visual" / "lora" / char_name.upper() / "training" / "_normalized"
    if not dry_run:
        norm_dir.mkdir(parents=True, exist_ok=True)

    normalized_images = []
    for img_path in images:
        try:
            with PILImage.open(img_path) as im:
                if im.size == (target_size, target_size):
                    normalized_images.append(img_path)
                    continue

                if dry_run:
                    print(f"    Would normalize: {img_path.name} ({im.size[0]}x{im.size[1]} -> {target_size}x{target_size})")
                    normalized_images.append(img_path)
                    continue

                # Face-aware crop then resize
                cropped = _face_aware_crop(im, target_size)
                norm_path = norm_dir / img_path.name
                cropped.save(norm_path, quality=95)
                print(f"    Normalized: {img_path.name} ({im.size[0]}x{im.size[1]} -> {target_size}x{target_size})")
                normalized_images.append(norm_path)
        except Exception as e:
            print(f"  WARNING: Could not normalize {img_path.name}: {e}")
            normalized_images.append(img_path)

    return normalized_images


def _face_aware_crop(im, target_size: int):
    """Square-crop *im* around a rough "face" point, then resize to target_size x target_size.

    The point is estimated without OpenCV. The image is converted to grayscale
    and, within the top 60% of rows, the pixels brighter than
    ``mean + 0.5 * std`` are averaged into a centroid. If no pixel passes the
    threshold, or numpy is not installed, the point defaults to the horizontal
    centre at one third of the height (an upper-biased crop, not a plain
    centre crop).

    The crop side is ``min(width, height)``. The window is centred on the
    point and clamped to the image bounds, then resized with Lanczos
    resampling.
    """
    from PIL import Image as PILImage

    width, height = im.size
    side = min(width, height)
    upper_centre = (width // 2, height // 3)

    try:
        luminance = im.convert("L")
        top_band = luminance.crop((0, 0, width, int(height * 0.6)))
        import numpy as np
        pixels = np.array(top_band)
        bright = pixels > (pixels.mean() + pixels.std() * 0.5)
        if bright.any():
            rows, cols = np.where(bright)
            cx, cy = int(cols.mean()), int(rows.mean())
        else:
            cx, cy = upper_centre
    except ImportError:
        cx, cy = upper_centre

    half = side // 2
    left = max(0, min(cx - half, width - side))
    top = max(0, min(cy - half, height - side))
    window = (left, top, left + side, top + side)
    return im.crop(window).resize((target_size, target_size), PILImage.LANCZOS)


def _prepare_from_candidates(
    project_dir: Path,
    char_name: str,
    trigger: str,
    char_data: dict,
    registry: dict,
    args,
    target_model: str = "z_image",
) -> int:
    """Build the training ZIP from a curated ``lora_candidates`` selection.

    Reads ``<project>/visual/lora_candidates/<CHAR>/selection.json`` and
    packages every file marked ``"selected"``. Each file gets a caption built
    from per-image metadata. Metadata comes from ``manifest.json`` first, then
    ``keystone_metadata.json``, then a ``results.json`` next to the image.
    When no metadata is found, the caption falls back to the bare trigger word.

    Every refusal gate (dataset validation, caption quality, caption length)
    runs before anything on disk is modified. The gates also apply to
    ``--dry-run``, so a dry run reports the same failures a real run would.

    On a real run this function:
      * writes ``visual/lora/<CHAR>/training/dataset.zip``;
      * writes a ``.txt`` caption next to each selected image;
      * copies the images and captions into that training dir;
      * moves files marked ``"rejected"`` into ``_rejected/``;
      * sets ``char_data["dataset_dir"]`` and saves the registry.

    Args:
        project_dir: Project root (e.g. ``leviathan/``).
        char_name: Character name (any case).
        trigger: Trigger word used as the caption base / fallback.
        char_data: This character's registry entry (mutated on success).
        registry: Full registry, saved on success.
        args: Parsed CLI args; reads ``dry_run`` and optional ``force``.
        target_model: Caption target passed to ``_build_rich_caption``.

    Returns:
        0 on success (or a passing dry run), 1 on any error or refused gate.
    """
    char_upper = char_name.upper()
    cand_dir = project_dir / "visual" / "lora_candidates" / char_upper
    sel_path = cand_dir / "selection.json"
    force = getattr(args, 'force', False)

    if not cand_dir.is_dir():
        print(f"ERROR: No candidates directory at {cand_dir}", file=sys.stderr)
        print(f"  Run: python3 batch_generate_refs.py {project_dir.name}/ --character {char_upper} --lora-prep 50",
              file=sys.stderr)
        return 1

    if not sel_path.is_file():
        print(f"ERROR: No selection.json found at {sel_path}", file=sys.stderr)
        print(f"  Curate candidates first: http://127.0.0.1:8420/_standalone/lora_picker.html"
              f"?project={project_dir.name}&character={char_upper}", file=sys.stderr)
        return 1

    try:
        with open(sel_path) as f:
            selection = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid selection.json: {e}", file=sys.stderr)
        return 1

    selections = selection.get("selections", {})
    selected_files = [fn for fn, status in selections.items() if status == "selected"]

    if not selected_files:
        print("ERROR: No images selected. Open the picker and select 15-25 candidates.", file=sys.stderr)
        return 1

    # Verify all selected files exist
    images = []
    missing = []
    for fn in sorted(selected_files):
        img_path = cand_dir / fn
        if img_path.is_file():
            images.append(img_path)
        else:
            missing.append(fn)

    if missing:
        print(f"WARNING: {len(missing)} selected files not found:", file=sys.stderr)
        for fn in missing[:5]:
            print(f"  - {fn}", file=sys.stderr)

    if not images:
        print("ERROR: None of the selected files exist on disk.", file=sys.stderr)
        return 1

    # Load manifest for per-image metadata (optional; later sources fill gaps)
    manifest_path = cand_dir / "manifest.json"
    manifest_lookup = {}  # filename -> metadata
    if manifest_path.is_file():
        try:
            with open(manifest_path) as f:
                manifest = json.load(f)
            for cand in manifest.get("candidates", []):
                manifest_lookup[cand["filename"]] = cand
        except (json.JSONDecodeError, KeyError) as e:
            print(f"  WARNING: Could not read manifest.json: {e}", file=sys.stderr)

    # ── A4: Load keystone metadata for manual images without manifest entries ──
    keystone_meta_path = cand_dir / "keystone_metadata.json"
    if keystone_meta_path.is_file():
        try:
            with open(keystone_meta_path) as f:
                keystone_meta = json.load(f)
            # Merge keystone metadata into manifest_lookup (only for images not already in manifest)
            for fn, meta in keystone_meta.items():
                if fn not in manifest_lookup:
                    manifest_lookup[fn] = meta
                    print(f"  Keystone metadata loaded for: {fn}")
        except (json.JSONDecodeError, KeyError) as e:
            print(f"  WARNING: Could not read keystone_metadata.json: {e}", file=sys.stderr)

    # ── A5: Harvest metadata from shootout results.json files ──
    # Each shootout dir contains a results.json with angle, expression, environment, lighting.
    # This is the primary metadata source for --from-candidates; manifest.json is optional.
    shootout_meta_count = 0
    for img in images:
        if img.name in manifest_lookup:
            continue  # Already have metadata from manifest or keystones
        results_json = img.parent / "results.json"
        if not results_json.is_file():
            continue
        try:
            with open(results_json) as f:
                rj = json.load(f)
        except json.JSONDecodeError as e:
            print(f"  WARNING: Could not read {results_json}: {e}", file=sys.stderr)
            continue
        # Handle both dict (single result) and list (batch result) formats
        if isinstance(rj, list):
            rj = rj[0] if rj else {}
        if isinstance(rj, dict) and rj:
            expr_raw = rj.get("expression", "")
            manifest_lookup[img.name] = {
                "angle": rj.get("angle", ""),
                "expression": expr_raw.split(" — ")[0] if " — " in expr_raw else expr_raw,
                "location": rj.get("environment", "")[:120],
                "lighting": rj.get("lighting", ""),
            }
            shootout_meta_count += 1
    if shootout_meta_count:
        print(f"  Shootout metadata: loaded {shootout_meta_count} entries from results.json files")

    # ── A3: Normalize resolutions (face-aware cropping) ──
    images = _normalize_resolutions(images, project_dir, char_name, args.dry_run)

    print(f"  Character: {char_upper}")
    print(f"  Trigger: {trigger}")
    print(f"  Selected: {len(images)} images (from {len(selections)} reviewed)")
    print(f"  Captions: natural language (trigger + expression + angle + env + lighting)")
    print(f"  Target model: {target_model}")

    # Run validation checks
    val_errors, val_warnings = _validate_training_dataset(project_dir, char_name, manifest_lookup, images)
    for check, msg in val_warnings:
        # Skip info-only entries in prepare output
        if check == "Coverage matrix":
            continue
        print(f"  WARNING: [{check}] {msg}")
    for check, msg in val_errors:
        print(f"  ERROR: [{check}] {msg}")
    if val_errors and not force:
        print(f"\n  {len(val_errors)} error(s) found. Fix before training or use --force to override.")
        return 1

    # Build every caption exactly once; the gates, dry run, ZIP and .txt files all reuse it.
    def caption_for(img: Path) -> str:
        meta = manifest_lookup.get(img.name, {})
        return _build_rich_caption(trigger, meta, target_model) if meta else trigger

    captioned = [(img, caption_for(img)) for img in images]

    # ── Caption quality gate — refuse to train with bare trigger words ──
    bare_count = sum(1 for img in images if not manifest_lookup.get(img.name))
    if bare_count:
        pct = bare_count / len(images) * 100
        if pct > 20 and not force:
            print(f"  ERROR: [Caption quality] {bare_count}/{len(images)} images ({pct:.0f}%) have NO metadata — captions would be bare trigger words.")
            print("         This produces unusable LoRAs. Ensure shootout dirs have results.json files.")
            print("         Use --force to override (not recommended).")
            return 1
        print(f"  WARNING: [Caption quality] {bare_count}/{len(images)} images have no metadata — will use bare trigger word")

    # ── Caption length gate (Z-Image: 2-7 words optimal, max 15) ──
    # Z-Image Turbo's ~75 token attention window applies to training captions too.
    # Long captions (60+ words) poison training — noise swamps identity signal.
    # See visual_production_findings.md: "Caption length is #1 LoRA training variable"
    caption_words = [len(caption.split()) for _, caption in captioned]
    if caption_words:
        avg_words = sum(caption_words) / len(caption_words)
        max_words = max(caption_words)

        if target_model in ("z_image", "z_image_base"):
            if max_words > 25:
                if not force:
                    print(f"  ERROR: [Caption length] Longest caption is {max_words} words (max 25 for Z-Image).")
                    print("         Z-Image Turbo LoRAs train best with 2-7 word captions.")
                    print("         Long captions poison training. Use --force to override.")
                    return 1
                print(f"  WARNING: [Caption length] Longest caption is {max_words} words (recommended max 25 for Z-Image). --force used.")
            if avg_words > 15:
                print(f"  WARNING: [Caption length] Average caption is {avg_words:.0f} words (recommended ≤15 for Z-Image).")
                print("           Optimal: 2-7 words (trigger + angle + expression).")
            else:
                print(f"  Caption stats: avg {avg_words:.0f} words, max {max_words} words (Z-Image budget: ≤15 avg, ≤25 max)")
        else:
            print(f"  Caption stats: avg {avg_words:.0f} words, max {max_words} words (target: {target_model})")

    if args.dry_run:
        print(f"\n  DRY RUN — target model: {target_model}")
        print("  Would package:")
        for img, caption in captioned:
            print(f"    {img.name} + {img.stem}.txt")
            print(f"      caption: {caption[:120]}...")
        return 0

    # ── D: Clean stale .txt files before writing new captions ──
    # Done only after every gate has passed, so a refused prepare never leaves
    # the dataset without its previous captions.
    stale_count = 0
    for img in images:
        stale_txt = img.with_suffix(".txt")
        if stale_txt.is_file():
            stale_txt.unlink()
            stale_count += 1
    if stale_count:
        print(f"  Cleaned {stale_count} stale .txt caption files")

    # Create ZIP with images + clean captions (the ZIP lives in the training dir)
    training_dir = project_dir / "visual" / "lora" / char_upper / "training"
    training_dir.mkdir(parents=True, exist_ok=True)
    zip_path = training_dir / "dataset.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img, caption in captioned:
            zf.write(img, img.name)
            zf.writestr(img.stem + ".txt", caption)

    print(f"  ZIP: {zip_path} ({zip_path.stat().st_size / 1024 / 1024:.1f} MB)")
    print(f"  Contains: {len(images)} images + {len(images)} caption files")

    # Also write caption files alongside selected images for reference
    for img, caption in captioned:
        img.with_suffix(".txt").write_text(caption, encoding="utf-8")
    print(f"  Caption files: {len(images)} .txt files written alongside images")

    # Copy selected images to lora/CHARACTER/training dir for clean reference
    for img in images:
        shutil.copy2(img, training_dir / img.name)
        txt_src = img.with_suffix(".txt")
        if txt_src.is_file():
            shutil.copy2(txt_src, training_dir / txt_src.name)
    char_data["dataset_dir"] = f"{project_dir.name}/visual/lora/{char_upper}/training"
    print(f"  Training set: {len(images)} images copied to {training_dir}")

    # Archive rejected images
    rejected = [fn for fn, status in selections.items() if status == "rejected"]
    if rejected:
        reject_dir = cand_dir / "_rejected"
        reject_dir.mkdir(exist_ok=True)
        moved = 0
        for fn in rejected:
            src = cand_dir / fn
            if src.is_file():
                target = reject_dir / fn
                # fn may contain a sub-directory component
                target.parent.mkdir(parents=True, exist_ok=True)
                src.rename(target)
                moved += 1
        if moved:
            print(f"  Archived: {moved} rejected images -> {reject_dir.name}/")

    save_registry(project_dir, registry)
    return 0


# ── Subcommand: submit ───────────────────────────────────────────────────

def cmd_submit(project_dir: Path, args) -> int:
    """Submit a training job to fal.ai without waiting for it (non-blocking).

    Dataset lookup:
      * The dataset ZIP is ``visual/lora/<CHAR>/training/dataset.zip``.
      * If that is missing, the legacy ``visual/<char>_training_dataset.zip``
        is used instead.

    Submission and registry update:
      * The ZIP is uploaded and submitted to the endpoint configured for
        ``args.type``.
      * The new request id is recorded in the registry entry named by the
        config's ``registry_key``, or by the type name when there is none.
      * Any finished-run fields already in that entry (weight URLs,
        ``trained_at``) are moved under ``"previous"``.
      * The step count and learning rate actually used are stored alongside
        the request id.

    The finished-run fields have to be moved because ``cmd_status`` treats an
    entry that has weight URLs as done. Leaving them in place would make it
    ignore the new job.

    Returns:
        0 on successful submission, 1 on any pre-submit error.
    """
    if not os.environ.get("FAL_KEY"):
        print("ERROR: FAL_KEY not set", file=sys.stderr)
        return 1

    import fal_client

    char_name = args.character.lower()
    char_upper = char_name.upper()
    train_type = args.type
    registry = load_registry(project_dir)

    if not registry:
        print("ERROR: No registry found. Run `init` first.", file=sys.stderr)
        return 1

    char_data = registry.get(char_name)
    if not char_data:
        print(f"ERROR: Character '{char_name}' not in registry", file=sys.stderr)
        return 1

    config = MODEL_CONFIGS[train_type]
    trigger = char_data.get("trigger", char_upper[:4] + "CHAR")
    steps = args.steps or config["default_steps"]
    lr = args.lr or config["default_lr"]

    # Find the dataset ZIP (new per-character path first, then the legacy flat path)
    new_zip = project_dir / "visual" / "lora" / char_upper / "training" / "dataset.zip"
    legacy_zip = project_dir / "visual" / f"{char_name}_training_dataset.zip"
    zip_path = next((p for p in (new_zip, legacy_zip) if p.exists()), None)
    if zip_path is None:
        print("ERROR: Dataset ZIP not found.", file=sys.stderr)
        print(f"  Checked: {new_zip}", file=sys.stderr)
        print(f"  Checked: {legacy_zip}", file=sys.stderr)
        print(f"  Run: python3 train_lora.py {project_dir.name}/ prepare {char_upper}", file=sys.stderr)
        return 1

    print(f"  Uploading dataset: {zip_path.name}...")
    data_url = fal_client.upload_file(str(zip_path))
    print(f"  Uploaded → {data_url[:60]}...")

    # Build training arguments
    train_args = {
        "images_data_url": data_url,
        "trigger_word": trigger,
        "steps": steps,
        "learning_rate": lr,
    }

    # Type-specific args
    if train_type == "t2i":
        train_args["create_masks"] = True
        train_args["is_style"] = False
    elif train_type == "z_image":
        train_args.update(config["special_args"])
        # Z-Image trainer uses image_data_url instead of images_data_url
        train_args["image_data_url"] = train_args.pop("images_data_url")
    elif train_type == "z_image_base":
        train_args.update(config["special_args"])
        # Base trainer uses image_data_url, no trigger_word param
        train_args["image_data_url"] = train_args.pop("images_data_url")
        train_args.pop("trigger_word", None)
        # Set default_caption as fallback (trigger is in caption files)
        train_args["default_caption"] = trigger
    elif train_type == "video":
        train_args.update(config["special_args"])

    print(f"\n  Submitting {train_type.upper()} training:")
    print(f"    Endpoint: {config['endpoint']}")
    print(f"    Trigger:  {trigger}")
    print(f"    Steps:    {steps}")
    print(f"    LR:       {lr}")
    cost_per_1k = config.get("cost_per_1k_steps", 1.0)
    cost_est = (steps / 1000) * cost_per_1k
    print(f"    Est cost: ~${cost_est:.2f} (${cost_per_1k}/1K steps)")

    # Non-blocking submit (raises on failure, leaving the registry untouched)
    handler = fal_client.submit(config["endpoint"], arguments=train_args)
    request_id = handler.request_id
    print(f"\n  Submitted! Request ID: {request_id}")

    # Save request_id to registry using the config's registry key
    reg_key = config.get("registry_key", train_type)
    entry = char_data.setdefault(reg_key, {})
    completion_keys = ("path", "high_noise_path", "low_noise_path", "trained_at")
    previous = {key: entry.pop(key) for key in completion_keys if key in entry}
    if previous:
        # Keep the prior run's URLs for reference; cmd_status needs them gone to poll the new job.
        previous["request_id"] = entry.get("request_id")
        entry["previous"] = previous
        print(f"  Previous {reg_key} result archived under 'previous'")
    entry["request_id"] = request_id
    entry["steps"] = steps
    entry["learning_rate"] = lr
    entry["submitted_at"] = datetime.now(timezone.utc).isoformat()

    save_registry(project_dir, registry)
    print("  Request ID saved to registry")
    print(f"\n  Check status: python3 train_lora.py {project_dir.name}/ status {char_upper}")
    return 0


# ── Subcommand: status ───────────────────────────────────────────────────

def cmd_status(project_dir: Path, args) -> int:
    """Check training status; on completion download weights and update the registry.

    Which jobs are checked:
      * With ``args.character``, only that character is checked.
      * Otherwise, every character with a submitted but unfinished job is
        checked.
      * A job counts as finished once its registry entry holds a weight URL.

    For each job that has completed on fal.ai:
      * the matching ``_handle_*_completion`` records the weight URL(s);
      * the weights are downloaded;
      * the training cost is logged exactly once per request id.

    The cost uses the step count stored at submit time, or the model's
    default step count when none was stored. ``args.wait`` blocks until each
    checked job completes.

    Returns:
        1 if FAL_KEY is missing, there is no registry, or the named character
        is unknown; otherwise 0.
    """
    if not os.environ.get("FAL_KEY"):
        print("ERROR: FAL_KEY not set", file=sys.stderr)
        return 1

    import fal_client

    registry = load_registry(project_dir)
    if not registry:
        print("ERROR: No registry found.", file=sys.stderr)
        return 1

    # (registry section, training type, field that marks it finished, completion handler)
    checks = (
        ("t2i", "t2i", "path", _handle_t2i_completion),
        ("z_image_t2i", "z_image", "path", _handle_z_image_completion),
        ("z_image_base_t2i", "z_image_base", "path", _handle_z_image_base_completion),
        ("video", "video", "high_noise_path", _handle_video_completion),
    )

    def has_pending(data: dict) -> bool:
        """True if any section has a request id but no finished-weights field yet."""
        for section, _, done_key, _ in checks:
            entry = data.get(section, {})
            if entry.get("request_id") and not entry.get(done_key):
                return True
        return False

    if args.character:
        char_name = args.character.lower()
        if char_name not in registry:
            print(f"ERROR: Character '{char_name}' not in registry", file=sys.stderr)
            return 1
        chars_to_check = [char_name]
    else:
        chars_to_check = [
            name for name, data in registry.items()
            if isinstance(data, dict) and has_pending(data)
        ]

    if not chars_to_check:
        print("  No pending training jobs found.")
        return 0

    tracker = CostTracker(project_dir)
    updated = False
    for char_name in chars_to_check:
        char_data = registry[char_name]

        for section, train_type, done_key, handle_completion in checks:
            entry = char_data.get(section, {})
            request_id = entry.get("request_id")
            if not request_id or entry.get(done_key):
                continue

            config = MODEL_CONFIGS[train_type]
            result = _check_training(
                fal_client, request_id, config["endpoint"],
                char_name, train_type, args.wait,
            )
            if not result:
                continue

            handle_completion(char_data, result)
            _download_all_weights_for_type(project_dir, char_name, train_type, char_data)

            # A completed result without weight URLs stays "pending", so guard
            # against logging the same job's cost on every status run.
            if entry.get("cost_logged_for") != request_id:
                tracker.log(
                    category="training",
                    provider="fal",
                    model=TRAINING_MODEL_KEYS[train_type],
                    steps=entry.get("steps") or config["default_steps"],
                    detail=f"Character: {char_name}, type: {train_type}",
                    success=True,
                )
                entry["cost_logged_for"] = request_id
            updated = True

    if updated:
        save_registry(project_dir, registry)

    return 0


def _check_training(fal_client, request_id: str, endpoint: str,
                     char_name: str, train_type: str, wait: bool) -> Optional[dict]:
    """Check a single training job status."""
    print(f"\n  {char_name.upper()} ({train_type}): checking {request_id[:16]}...")

    while True:
        try:
            status = fal_client.status(endpoint, request_id, with_logs=True)
            status_str = str(type(status).__name__)
            print(f"    Status: {status_str}")

            if hasattr(status, "logs") and status.logs:
                for log in status.logs[-3:]:
                    print(f"    Log: {log.get('message', str(log))}")

            # Check if completed
            if "Completed" in status_str:
                result = fal_client.result(endpoint, request_id)
                print(f"    COMPLETE!")
                return result

            if not wait:
                return None

            print(f"    Waiting 60s...")
            time.sleep(60)

        except Exception as e:
            print(f"    Error: {e}")
            return None


def _handle_t2i_completion(char_data: dict, result: dict) -> None:
    """Process completed T2I training result."""
    t2i = char_data.setdefault("t2i", {})

    # Extract LoRA file URL
    lora_url = None
    if "diffusers_lora_file" in result:
        lora_file = result["diffusers_lora_file"]
        lora_url = lora_file.get("url") if isinstance(lora_file, dict) else lora_file

    if lora_url:
        t2i["path"] = lora_url
        t2i["trained_at"] = datetime.now(timezone.utc).isoformat()
        print(f"    T2I LoRA: {lora_url[:60]}...")
    else:
        print(f"    WARNING: No LoRA file URL in result")
        print(f"    Keys: {list(result.keys())}")


def _handle_z_image_completion(char_data: dict, result: dict) -> None:
    """Process completed Z-Image T2I training result."""
    z_t2i = char_data.setdefault("z_image_t2i", {})

    lora_url = None
    if "diffusers_lora_file" in result:
        lora_file = result["diffusers_lora_file"]
        lora_url = lora_file.get("url") if isinstance(lora_file, dict) else lora_file

    if lora_url:
        z_t2i["path"] = lora_url
        z_t2i["trained_at"] = datetime.now(timezone.utc).isoformat()
        print(f"    Z-Image LoRA: {lora_url[:60]}...")
    else:
        print(f"    WARNING: No LoRA file URL in result")
        print(f"    Keys: {list(result.keys())}")


def _handle_z_image_base_completion(char_data: dict, result: dict) -> None:
    """Process completed Z-Image Base T2I training result."""
    zb_t2i = char_data.setdefault("z_image_base_t2i", {})

    lora_url = None
    if "diffusers_lora_file" in result:
        lora_file = result["diffusers_lora_file"]
        lora_url = lora_file.get("url") if isinstance(lora_file, dict) else lora_file

    if lora_url:
        zb_t2i["path"] = lora_url
        zb_t2i["trained_at"] = datetime.now(timezone.utc).isoformat()
        print(f"    Z-Image Base LoRA: {lora_url[:60]}...")
    else:
        print(f"    WARNING: No LoRA file URL in result")
        print(f"    Keys: {list(result.keys())}")


def _handle_video_completion(char_data: dict, result: dict) -> None:
    """Process completed video training result."""
    video = char_data.setdefault("video", {})

    # WAN 2.2 produces high-noise and low-noise LoRAs
    high_url = None
    low_url = None

    if "diffusers_lora_file" in result:
        f = result["diffusers_lora_file"]
        low_url = f.get("url") if isinstance(f, dict) else f

    if "high_noise_lora" in result:
        f = result["high_noise_lora"]
        high_url = f.get("url") if isinstance(f, dict) else f

    if high_url:
        video["high_noise_path"] = high_url
        print(f"    High-noise LoRA: {high_url[:60]}...")
    if low_url:
        video["low_noise_path"] = low_url
        print(f"    Low-noise LoRA: {low_url[:60]}...")

    if high_url or low_url:
        video["trained_at"] = datetime.now(timezone.utc).isoformat()
    else:
        print(f"    WARNING: No LoRA URLs in result")
        print(f"    Keys: {list(result.keys())}")


# ── Auto-download weights ────────────────────────────────────────────────

def _download_weights(project_dir: Path, char_name: str, train_type: str,
                      url: str) -> Optional[Path]:
    """Download a safetensors file from fal.ai to the local weights directory.

    Saves to: [project]/visual/lora/[CHARACTER]/weights/[character]_[type].safetensors

    Args:
        project_dir: Project root path (e.g., leviathan/)
        char_name: Lowercase character name (e.g., 'jinx')
        train_type: Training type key (e.g., 't2i', 'z_image_t2i', 'video_high_noise')
        url: Remote URL of the safetensors file

    Returns:
        Path to downloaded file, or None if download failed.
    """
    if not url:
        return None

    char_upper = char_name.upper()
    weights_dir = project_dir / "visual" / "lora" / char_upper / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)

    # Build filename matching existing convention:
    #   jinx_flux2_t2i.safetensors, jinx_z_image_t2i.safetensors, etc.
    filename = f"{char_name}_{train_type}.safetensors"
    dest = weights_dir / filename

    print(f"    Downloading weights to {dest.name}...", end=" ", flush=True)
    try:
        import urllib.request
        urllib.request.urlretrieve(url, str(dest))
        size_mb = dest.stat().st_size / (1024 * 1024)
        print(f"OK ({size_mb:.1f} MB)")
        print(f"    Saved: {dest}")
        return dest
    except Exception as e:
        print(f"FAILED")
        print(f"    WARNING: Weight download failed: {e}", file=sys.stderr)
        print(f"    URL is saved in registry — download manually later.", file=sys.stderr)
        # Clean up partial download
        if dest.exists():
            dest.unlink()
        return None


def _download_all_weights_for_type(project_dir: Path, char_name: str,
                                    train_type: str, char_data: dict) -> None:
    """Download weights after a training completion, based on training type.

    Handles the mapping from training types to the URL keys in char_data.
    """
    if train_type == "t2i":
        url = char_data.get("t2i", {}).get("path")
        if url:
            _download_weights(project_dir, char_name, "flux2_t2i", url)
    elif train_type == "z_image":
        url = char_data.get("z_image_t2i", {}).get("path")
        if url:
            _download_weights(project_dir, char_name, "z_image_t2i", url)
    elif train_type == "z_image_base":
        url = char_data.get("z_image_base_t2i", {}).get("path")
        if url:
            _download_weights(project_dir, char_name, "z_image_base_t2i", url)
    elif train_type == "video":
        high_url = char_data.get("video", {}).get("high_noise_path")
        low_url = char_data.get("video", {}).get("low_noise_path")
        if high_url:
            _download_weights(project_dir, char_name, "video_high_noise", high_url)
        if low_url:
            _download_weights(project_dir, char_name, "video_low_noise", low_url)


# ── Subcommand: validate ─────────────────────────────────────────────────

def _validate_training_dataset(
    project_dir: Path,
    char_name: str,
    manifest_lookup: dict,
    images: List[Path],
) -> tuple:
    """Validate a training dataset against best practices.

    Returns (errors: list, warnings: list) where each entry is a (check, message) tuple.
    """
    errors = []
    warnings = []

    # ── Image count ──
    count = len(images)
    if count < 5:
        errors.append(("Image count", f"{count} images (minimum 5 required)"))
    elif count < 15:
        warnings.append(("Image count", f"{count} images (optimal: 15-25)"))
    elif count <= 25:
        pass  # Optimal
    elif count <= 30:
        warnings.append(("Image count", f"{count} images (optimal: 15-25, acceptable up to 30)"))
    else:
        warnings.append(("Image count", f"{count} images (optimal: 15-25, consider pruning)"))

    # ── Image resolution ──
    try:
        from PIL import Image as PILImage
        resolutions = set()
        min_dim = float('inf')
        for img_path in images:
            try:
                with PILImage.open(img_path) as im:
                    resolutions.add(im.size)
                    min_dim = min(min_dim, min(im.size))
            except Exception:
                warnings.append(("Resolution", f"Could not read {img_path.name}"))

        if len(resolutions) > 1:
            res_strs = [f"{w}x{h}" for w, h in sorted(resolutions)]
            errors.append(("Resolution", f"Mixed resolutions found: {', '.join(res_strs)}"))
        elif resolutions:
            w, h = next(iter(resolutions))
            if min_dim < 512:
                errors.append(("Resolution", f"{w}x{h} — minimum 512px on shortest side"))

        if resolutions and len(resolutions) == 1:
            w, h = next(iter(resolutions))
            # Pass — report for info
    except ImportError:
        warnings.append(("Resolution", "Pillow not installed — skipping resolution check (pip install Pillow)"))

    # ── Angle diversity ──
    angles = set()
    angle_counts = {}
    for img in images:
        meta = manifest_lookup.get(img.name, {})
        angle = meta.get("angle", "")
        if angle:
            base_angle = angle.split("_")[0] if angle.startswith("closeup") else angle
            angles.add(base_angle)
            angle_counts[base_angle] = angle_counts.get(base_angle, 0) + 1

    if len(angles) < 2:
        warnings.append(("Angle diversity", f"{len(angles)} angle(s) found (recommend 4+)"))
    elif len(angles) < 4:
        warnings.append(("Angle diversity", f"{len(angles)} angles found: {', '.join(sorted(angles))} (recommend 4+)"))

    # ── Expression diversity ──
    expressions = set()
    expression_counts = {}
    for img in images:
        meta = manifest_lookup.get(img.name, {})
        expr = meta.get("expression", "")
        if expr:
            expressions.add(expr)
            expression_counts[expr] = expression_counts.get(expr, 0) + 1

    if len(expressions) < 2:
        warnings.append(("Expression diversity", f"{len(expressions)} expression(s) found (recommend 3+)"))
    elif len(expressions) < 3:
        warnings.append(("Expression diversity", f"{len(expressions)} expressions found: {', '.join(sorted(expressions))} (recommend 3+)"))

    # ── Expression dominance check ──
    if expression_counts:
        max_expr_count = max(expression_counts.values())
        max_expr_name = [k for k, v in expression_counts.items() if v == max_expr_count][0]
        if max_expr_count > 6:
            warnings.append(("Expression balance", f"'{max_expr_name}' appears {max_expr_count} times (max recommended: 5-6)"))
        # Check neutral dominance (~30-40% ideal)
        neutral_count = expression_counts.get("neutral", 0)
        neutral_pct = (neutral_count / count * 100) if count else 0
        if neutral_count > 0 and neutral_pct > 50:
            warnings.append(("Expression balance", f"Neutral is {neutral_pct:.0f}% of dataset (target: 30-40%)"))

    # ── Coverage matrix (framing analysis) ──
    framing_counts = {"closeup": 0, "medium": 0, "full_body": 0, "back_misc": 0}
    CLOSEUP_ANGLES = {"closeup_face", "closeup"}
    MEDIUM_ANGLES = {"upper_body", "three_quarter_left", "three_quarter_right", "front",
                     "profile_left", "profile_right", "over_shoulder", "high_angle", "low_angle"}
    FULL_BODY_ANGLES = {"full_body"}
    BACK_MISC_ANGLES = {"back"}

    for img in images:
        meta = manifest_lookup.get(img.name, {})
        angle = meta.get("angle", "")
        if angle in CLOSEUP_ANGLES or angle.startswith("closeup"):
            framing_counts["closeup"] += 1
        elif angle in FULL_BODY_ANGLES:
            framing_counts["full_body"] += 1
        elif angle in BACK_MISC_ANGLES:
            framing_counts["back_misc"] += 1
        elif angle in MEDIUM_ANGLES:
            framing_counts["medium"] += 1
        elif angle:
            framing_counts["medium"] += 1  # default bucket

    # Target percentages for 25-30 images
    coverage_info = []
    if count > 0:
        cu_pct = framing_counts["closeup"] / count * 100
        med_pct = framing_counts["medium"] / count * 100
        fb_pct = framing_counts["full_body"] / count * 100
        bk_pct = framing_counts["back_misc"] / count * 100
        coverage_info = [
            ("Close-up/headshot", framing_counts["closeup"], cu_pct, "33-40%"),
            ("Medium/waist-up", framing_counts["medium"], med_pct, "25-30%"),
            ("Full body", framing_counts["full_body"], fb_pct, "20-25%"),
            ("Back/misc", framing_counts["back_misc"], bk_pct, "10%"),
        ]

        # Flag gaps
        if framing_counts["closeup"] == 0:
            warnings.append(("Coverage gap", "No close-up/headshot images (target: 33-40%)"))
        if framing_counts["full_body"] == 0:
            warnings.append(("Coverage gap", "No full body images (target: 20-25%)"))
        if framing_counts["medium"] == 0:
            warnings.append(("Coverage gap", "No medium/waist-up images (target: 25-30%)"))

    # Store coverage info for reporting
    warnings.append(("Coverage matrix", coverage_info))

    # ── Near-duplicate detection ──
    combo_counts = {}
    for img in images:
        meta = manifest_lookup.get(img.name, {})
        angle = meta.get("angle", "unknown")
        expr = meta.get("expression", "unknown")
        combo = f"{angle}+{expr}"
        combo_counts.setdefault(combo, []).append(img.name)

    duplicates = [(combo, files) for combo, files in combo_counts.items() if len(files) > 1]
    if duplicates:
        dup_details = [f"{combo} ({len(files)}x)" for combo, files in duplicates[:5]]
        warnings.append(("Near-duplicates", f"{len(duplicates)} pairs with same angle+expression: {', '.join(dup_details)}"))

    # ── File format ──
    formats = set()
    for img in images:
        formats.add(img.suffix.lower())

    non_image = formats - {".png", ".jpg", ".jpeg"}
    if non_image:
        warnings.append(("File format", f"Non-standard formats found: {', '.join(non_image)}"))

    return errors, warnings


def _print_findings(findings, check: str, tag: str) -> bool:
    """Print every ``(check, msg)`` pair in *findings* whose name equals *check*.

    *tag* is the left-hand status marker, e.g. ``"[ERROR]"`` or ``"[WARN] "``.
    Returns True if at least one line was printed.
    """
    printed = False
    for name, msg in findings:
        if name == check:
            print(f"  {tag} {name}: {msg}")
            printed = True
    return printed


def cmd_validate(project_dir: Path, args) -> int:
    """Validate a character's curated training dataset without building a ZIP.

    Reads ``visual/lora_candidates/<CHAR>/selection.json`` (and, when present,
    ``manifest.json``) under *project_dir*, runs ``_validate_training_dataset``
    on the selected images that exist on disk, and prints a per-check report.

    Args:
        project_dir: Resolved project root directory.
        args: Parsed CLI namespace; reads ``args.character`` and the optional
            ``args.force`` flag.

    Returns:
        0 when there are no errors (warnings allowed) or when errors are
        overridden with ``--force``; 1 on setup problems (missing registry,
        character, candidates, selection) or unforced validation errors.
    """
    char_name = args.character.lower()
    char_upper = char_name.upper()
    registry = load_registry(project_dir)

    if not registry:
        print("ERROR: No registry found. Run `init` first.", file=sys.stderr)
        return 1

    if not registry.get(char_name):
        print(f"ERROR: Character '{char_name}' not in registry", file=sys.stderr)
        return 1

    # Load candidates
    cand_dir = project_dir / "visual" / "lora_candidates" / char_upper
    sel_path = cand_dir / "selection.json"

    if not cand_dir.is_dir():
        print(f"ERROR: No candidates directory at {cand_dir}", file=sys.stderr)
        return 1

    if not sel_path.is_file():
        print(f"ERROR: No selection.json found at {sel_path}", file=sys.stderr)
        return 1

    try:
        with open(sel_path) as f:
            selection = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid selection.json: {e}", file=sys.stderr)
        return 1

    selections = selection.get("selections", {})
    selected_files = [fn for fn, status in selections.items() if status == "selected"]

    if not selected_files:
        print("ERROR: No images selected.", file=sys.stderr)
        return 1

    # Resolve selections against disk. Missing files are remembered so the
    # report can flag them instead of silently shrinking the dataset.
    images = []
    missing = []
    for fn in sorted(selected_files):
        img_path = cand_dir / fn
        if img_path.is_file():
            images.append(img_path)
        else:
            missing.append(fn)

    if not images:
        print("ERROR: None of the selected files exist on disk.", file=sys.stderr)
        return 1

    # The manifest supplies per-image angle/expression metadata. It is
    # optional, but a broken one should be visible rather than ignored.
    manifest_path = cand_dir / "manifest.json"
    manifest_lookup = {}
    if manifest_path.is_file():
        try:
            with open(manifest_path) as f:
                manifest = json.load(f)
            for cand in manifest.get("candidates", []):
                manifest_lookup[cand["filename"]] = cand
        except (json.JSONDecodeError, KeyError) as e:
            print(f"WARNING: Could not fully parse {manifest_path} ({e}); "
                  "angle/expression checks may be incomplete", file=sys.stderr)

    # Run validation
    errors, warnings = _validate_training_dataset(project_dir, char_name, manifest_lookup, images)
    if missing:
        shown = ", ".join(missing[:5])
        more = f" (+{len(missing) - 5} more)" if len(missing) > 5 else ""
        warnings.append(("Missing files",
                         f"{len(missing)} selected file(s) not found on disk: {shown}{more}"))

    # Report
    print(f"\nVALIDATION REPORT — {char_upper}")
    print("=" * 50)

    # Image count (always reported). An error supersedes a warning for the
    # same check, so warnings are printed only when there is no error.
    count = len(images)
    if not _print_findings(errors, "Image count", "[ERROR]"):
        if not _print_findings(warnings, "Image count", "[WARN] "):
            print(f"  [PASS]  Image count: {count} (optimal: 15-25)")

    # Resolution: both severities are printed, so evaluate both calls
    # (no short-circuit) before deciding whether to print PASS.
    res_err = _print_findings(errors, "Resolution", "[ERROR]")
    res_warn = _print_findings(warnings, "Resolution", "[WARN] ")
    if not (res_err or res_warn):
        try:
            from PIL import Image as PILImage
            with PILImage.open(images[0]) as im:
                print(f"  [PASS]  Resolution: all {im.size[0]}x{im.size[1]}")
        except Exception:
            # Best-effort: PIL may be missing or the file unreadable here.
            print("  [PASS]  Resolution: consistent (could not read dimensions)")

    # Angle diversity (closeup_* variants collapse to "closeup")
    angles_in_data = set()
    for img in images:
        angle = manifest_lookup.get(img.name, {}).get("angle", "")
        if angle:
            angles_in_data.add(angle.split("_")[0] if angle.startswith("closeup") else angle)
    if not _print_findings(warnings, "Angle diversity", "[WARN] "):
        print(f"  [PASS]  Angle diversity: {len(angles_in_data)} angles found")

    # Expression diversity
    exprs_in_data = set()
    for img in images:
        expr = manifest_lookup.get(img.name, {}).get("expression", "")
        if expr:
            exprs_in_data.add(expr)
    if not _print_findings(warnings, "Expression diversity", "[WARN] "):
        print(f"  [PASS]  Expression diversity: {len(exprs_in_data)} expressions found")

    # Expression balance
    _print_findings(warnings, "Expression balance", "[WARN] ")

    # Coverage matrix (info-only entry whose msg is a list of rows)
    for check, msg in warnings:
        if check == "Coverage matrix" and isinstance(msg, list) and msg:
            print("\n  COVERAGE MATRIX (framing distribution)")
            print(f"  {'Framing':<22} {'Count':>5}  {'Actual':>7}  {'Target':>8}")
            print(f"  {'-'*22} {'-'*5}  {'-'*7}  {'-'*8}")
            for label, cnt, pct, target in msg:
                status = "OK" if cnt > 0 else "GAP"
                print(f"  {label:<22} {cnt:>5}  {pct:>6.0f}%  {target:>8}  {status}")
            print()

    # Coverage gaps
    _print_findings(warnings, "Coverage gap", "[WARN] ")

    # Near-duplicates
    if not _print_findings(warnings, "Near-duplicates", "[WARN] "):
        print("  [PASS]  Near-duplicates: none detected")

    # File format
    if not _print_findings(warnings, "File format", "[WARN] "):
        fmts = {img.suffix.lower() for img in images}
        print(f"  [PASS]  File format: all {'/'.join(sorted(fmts)).upper()}")

    # Catch-all: anything without a dedicated section above (e.g. a check
    # added to _validate_training_dataset later, or "Missing files") is still
    # printed, so the tally below never counts a problem the user can't see.
    error_sections = {"Image count", "Resolution"}
    warning_sections = {
        "Image count", "Resolution", "Angle diversity", "Expression diversity",
        "Expression balance", "Coverage matrix", "Coverage gap",
        "Near-duplicates", "File format",
    }
    for check, msg in errors:
        if check not in error_sections:
            print(f"  [ERROR] {check}: {msg}")
    for check, msg in warnings:
        if check not in warning_sections:
            print(f"  [WARN]  {check}: {msg}")

    print("=" * 50)
    # Don't count Coverage matrix (info-only) in warning tally
    real_warnings = [(c, m) for c, m in warnings if c != "Coverage matrix"]
    n_err = len(errors)
    n_warn = len(real_warnings)
    if n_err > 0:
        print(f"  {n_err} ERROR, {n_warn} WARN — BLOCKED (fix errors before training)")
        if not getattr(args, 'force', False):
            return 1
        print("  --force: proceeding despite errors")
    elif n_warn > 0:
        print(f"  0 ERROR, {n_warn} WARN — OK to proceed")
    else:
        print("  0 ERROR, 0 WARN — PASS")

    return 0


# ── Main ──────────────────────────────────────────────────────────────────

def main():
    """Build the CLI, parse ``sys.argv`` and run the selected subcommand.

    Each subcommand maps to a ``cmd_<name>(project_dir, args)`` handler. Its
    return value is passed back unchanged to serve as the process exit code.
    """
    cli = argparse.ArgumentParser(
        description="LoRA training pipeline: prepare, submit, poll, auto-register",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Subcommands:
  init      Initialize registry from breakdown.json
  prepare   Caption images + create ZIP for upload
  validate  Check dataset against best practices (pre-flight)
  submit    Submit training job to fal.ai (non-blocking)
  status    Check training status (auto-downloads on completion)
  show      Display current registry

Examples:
  python3 train_lora.py leviathan/ init --seed-existing
  python3 train_lora.py leviathan/ prepare JINX
  python3 train_lora.py leviathan/ validate KIAN
  python3 train_lora.py leviathan/ submit JINX --type t2i --steps 1000
  python3 train_lora.py leviathan/ status JINX --wait
  python3 train_lora.py leviathan/ show
        """,
    )
    cli.add_argument("project_dir", help="Project directory (e.g., leviathan/)")

    commands = cli.add_subparsers(dest="command", required=True)

    # Registration order determines the order shown in --help; keep it stable.

    # init: seed the registry
    init_cmd = commands.add_parser("init", help="Initialize registry from breakdown.json")
    init_cmd.add_argument(
        "--seed-existing",
        action="store_true",
        help="Populate known Jinx/Kian LoRA data",
    )

    # prepare: caption + package the dataset
    prepare_cmd = commands.add_parser("prepare", help="Caption images + create ZIP")
    prepare_cmd.add_argument("character", help="Character name (e.g., JINX)")
    prepare_cmd.add_argument("--dry-run", action="store_true", help="Show plan without writing")
    prepare_cmd.add_argument(
        "--from-candidates",
        action="store_true",
        help="Build dataset from curated lora_candidates/ selection instead of refs/",
    )
    prepare_cmd.add_argument(
        "--force",
        action="store_true",
        help="Proceed despite validation errors",
    )
    prepare_cmd.add_argument(
        "--target-model",
        dest="target_model",
        choices=["t2i", "flux", "z_image", "z_image_base", "video", "wan"],
        default="z_image",
        help="Target model for caption style/length (default: z_image)",
    )

    # submit: launch a fal.ai training job
    submit_cmd = commands.add_parser("submit", help="Submit training to fal.ai")
    submit_cmd.add_argument("character", help="Character name (e.g., JINX)")
    submit_cmd.add_argument(
        "--type",
        required=True,
        choices=["t2i", "z_image", "z_image_base", "video"],
        help="Training type: t2i (Flux 2), z_image (Z-Image Turbo), z_image_base (Z-Image Base), or video (WAN 2.2)",
    )
    submit_cmd.add_argument("--steps", type=int, help="Training steps (default: from config)")
    submit_cmd.add_argument("--lr", type=float, help="Learning rate (default: from config)")

    # status: poll running jobs
    status_cmd = commands.add_parser("status", help="Check training status")
    status_cmd.add_argument("character", nargs="?", help="Character name (omit to check all)")
    status_cmd.add_argument("--wait", action="store_true", help="Poll every 60s until done")

    # validate: pre-flight dataset checks
    validate_cmd = commands.add_parser("validate", help="Check dataset against best practices")
    validate_cmd.add_argument("character", help="Character name (e.g., KIAN)")
    validate_cmd.add_argument("--force", action="store_true", help="Continue despite errors")

    # show: print the registry (no extra options)
    commands.add_parser("show", help="Display current registry")

    parsed = cli.parse_args()
    root = resolve_project_path(parsed.project_dir)

    handlers = {
        "init": cmd_init,
        "prepare": cmd_prepare,
        "validate": cmd_validate,
        "submit": cmd_submit,
        "status": cmd_status,
        "show": cmd_show,
    }
    handler = handlers.get(parsed.command)
    if handler is None:
        # Unreachable in practice: the subparser group is required.
        return 0
    return handler(root, parsed)


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
