#!/usr/bin/env python3
"""Rewrite LoRA training captions to minimal format.

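Z-Image Turbo trains best with minimal captions (1-10 words);
heavily detailed captions cause disfigured features.

Format: TRIGGER, class word[, framing][, angle][, high/low angle], where the
optional framing and angle parts are inferred from each image's filename.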
"""

import re
import sys
from pathlib import Path

# Character roster: upper-case key -> LoRA trigger token and class word.
# The trigger and class word become the first two fields of every caption;
# callers in this script upper-case user input before looking a name up here.
CHARACTERS = {
    "JINX": dict(trigger="JNXCHAR", class_word="woman"),
    "KIAN": dict(trigger="KIANCHAR", class_word="armored android"),
    "VAREK": dict(trigger="VRKCHAR", class_word="man in uniform"),
}

# Angle/framing keywords to extract from filenames
ANGLE_KEYWORDS = [
    "profile_left", "profile_right", "profile",
    "three_quarter_left", "three_quarter_right", "three-quarter_left", "three-quarter_right",
    "three_quarter", "three-quarter",
    "over_shoulder", "over-shoulder", "ots",
    "high_angle", "low_angle",
    "front_face", "front_facing", "front",
    "back_view", "back",
    "upper_body",
    "full_body", "full body",
]

FRAMING_KEYWORDS = [
    "extreme_closeup", "extreme_close", "ecu",
    "closeup", "close_up", "close-up",
    "medium_closeup", "medium_close",
    "medium_shot", "medium",
    "wide_shot", "wide",
]


def _has_standalone(name: str, word: str) -> bool:
    """Return True if *word* occurs in *name* without adjacent letters.

    Digits, underscores and other non-letters count as separators, so
    ``ots`` matches ``jinx_ots_02`` but not ``shots``, and ``back`` matches
    ``back_view`` but not ``background``.
    """
    pattern = rf"(?<![a-z]){re.escape(word)}(?![a-z])"
    return re.search(pattern, name) is not None


def extract_angle_from_filename(filename: str) -> str:
    """Extract framing and camera-angle info from an image filename.

    The name is lower-cased and ``-``/space are normalized to ``_`` before
    matching. At most one framing term, one angle term and one elevation
    term (high/low angle) are emitted, in that order.

    Args:
        filename: Image filename or stem, e.g. ``"jinx_medium_closeup_01"``.

    Returns:
        A comma-separated description such as ``"close-up, front view"``,
        or ``""`` when no keyword is recognized.
    """
    name = filename.lower().replace("-", "_").replace(" ", "_")

    parts = []

    # "close" must stand alone ("close_up", "close_2") so that words like
    # "closed" (e.g. "eyes_closed") are not mistaken for a close-up.
    is_close = "closeup" in name or _has_standalone(name, "close")

    # Framing. Medium close-up is tested before plain close-up; otherwise it
    # could never be reached because every medium close-up is also "close".
    if "extreme_close" in name or _has_standalone(name, "ecu"):
        parts.append("extreme close-up")
    elif "medium" in name and is_close:
        parts.append("medium close-up")
    elif is_close:
        parts.append("close-up")
    elif "medium" in name:
        parts.append("medium shot")
    elif "full_body" in name:
        parts.append("full body")
    elif "upper_body" in name:
        parts.append("upper body")
    elif "wide" in name:
        parts.append("wide shot")

    # Angle. Short abbreviations ("ots", "back") must stand alone so they do
    # not fire inside words such as "shots" or "background".
    if "profile_left" in name:
        parts.append("left profile")
    elif "profile_right" in name:
        parts.append("right profile")
    elif "profile" in name:
        parts.append("profile")
    elif "three_quarter_right" in name:
        parts.append("three-quarter right")
    elif "three_quarter_left" in name:
        parts.append("three-quarter left")
    elif "three_quarter" in name:
        parts.append("three-quarter")
    elif "over_shoulder" in name or _has_standalone(name, "ots"):
        parts.append("over the shoulder")
    elif _has_standalone(name, "back"):
        parts.append("from behind")
    elif "front" in name:
        parts.append("front view")

    # Camera elevation is independent of the angle above.
    if "high_angle" in name:
        parts.append("high angle")
    elif "low_angle" in name:
        parts.append("low angle")

    return ", ".join(parts)


def rewrite_captions(training_dir: Path, char_name: str, dry_run: bool = False):
    """Rewrite all caption ``.txt`` files for one character.

    For every image in ``training_dir / CHAR_NAME`` (upper-cased), a sibling
    ``.txt`` file with the same stem is written containing
    ``TRIGGER, class word[, framing/angle]``; the framing/angle part comes
    from ``extract_angle_from_filename(image.stem)``.

    Args:
        training_dir: Directory holding one sub-directory per character.
        char_name: Key into ``CHARACTERS``; matched case-insensitively.
        dry_run: If True, only print the planned captions; write nothing.

    Problems (unknown character, missing directory) are reported on stderr
    and the function returns without touching any files.
    """
    key = char_name.upper()
    char_config = CHARACTERS.get(key)
    if not char_config:
        print(f"ERROR: Unknown character {char_name}", file=sys.stderr)
        return

    trigger = char_config["trigger"]
    class_word = char_config["class_word"]
    char_dir = training_dir / key

    # is_dir() rather than exists(): iterdir() below fails on a plain file.
    if not char_dir.is_dir():
        print(f"ERROR: Directory not found: {char_dir}", file=sys.stderr)
        return

    # Collect image files only; skip sub-directories that merely end in ".png".
    image_exts = {".jpeg", ".jpg", ".png", ".webp"}
    images = sorted(
        f for f in char_dir.iterdir()
        if f.is_file() and f.suffix.lower() in image_exts
    )

    print(f"\n{'='*60}")
    print(f"  {key} — {len(images)} images")
    print(f"  Trigger: {trigger}, Class: {class_word}")
    print(f"{'='*60}")

    for img in images:
        txt_path = img.with_suffix(".txt")
        angle_info = extract_angle_from_filename(img.stem)

        # Minimal caption: trigger and class word, plus framing/angle if found.
        fields = [trigger, class_word]
        if angle_info:
            fields.append(angle_info)
        caption = ", ".join(fields)
        word_count = len(caption.split())

        # The old caption is read only for a word-count comparison, so decode
        # leniently: a stray non-UTF-8 byte must not abort the whole run.
        if txt_path.exists():
            old_caption = txt_path.read_text(encoding="utf-8", errors="replace").strip()
            status = f"{len(old_caption.split())}w → {word_count}w"
        else:
            status = "NEW"
        print(f"  [{status:>12}] {img.name}")
        print(f"               → {caption}")

        if not dry_run:
            txt_path.write_text(caption + "\n", encoding="utf-8")

    # In dry-run mode nothing was written; say so instead of claiming writes.
    verb = "would be written" if dry_run else "written"
    print(f"\n  Done: {len(images)} captions {verb}")


def main():
    """Parse CLI arguments and rewrite captions for one or all characters.

    Caption directories are expected at
    ``<project_dir>/visual/lora_training/<CHARACTER>/``. Exits with status 1
    if the training directory is missing, or after the run if any selected
    character's directory was missing. An unknown ``--character`` value is
    rejected by argparse as a usage error.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Rewrite LoRA captions to minimal format")
    parser.add_argument("project_dir", type=Path, help="Project directory")
    # Upper-case and validate up front, so a typo is a clean usage error
    # rather than a FileNotFoundError while counting images further down.
    parser.add_argument("--character", "-c", type=str.upper, choices=list(CHARACTERS),
                        help="Single character (default: all)")
    parser.add_argument("--dry-run", action="store_true", help="Preview without writing")
    args = parser.parse_args()

    training_dir = args.project_dir / "visual" / "lora_training"
    if not training_dir.is_dir():
        print(f"ERROR: Training dir not found: {training_dir}", file=sys.stderr)
        sys.exit(1)

    chars = [args.character] if args.character else list(CHARACTERS.keys())

    for char in chars:
        rewrite_captions(training_dir, char, dry_run=args.dry_run)

    # Count only directories that exist: rewrite_captions has already reported
    # missing ones, and calling iterdir() on them would raise.
    image_exts = {".jpeg", ".jpg", ".png", ".webp"}
    char_dirs = [training_dir / c for c in chars]
    missing = [d for d in char_dirs if not d.is_dir()]
    total = sum(
        1
        for d in char_dirs
        if d.is_dir()
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in image_exts
    )

    mode = "DRY RUN" if args.dry_run else "COMPLETE"
    print(f"\n{'='*60}")
    print(f"  CAPTION REWRITE — {mode}")
    print(f"  Characters: {len(chars)}")
    print(f"  Total captions: {total}")
    print(f"{'='*60}\n")

    # Signal partial failure to scripts/CI instead of exiting 0.
    if missing:
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported as a module.
    main()
