#!/usr/bin/env python3
"""
generate_turnarounds.py — Character turnaround grid generation tool.

Workflow:
  1. Seedream edit: 4-panel grid (left to right: front, profile, back, three-quarter) with identity ref
  2. Programmatic validation: aspect ratio, panel count, truncation
  3. Validation (optional, --validate): Gemini vision check ("flash") or interactive
     human review ("opus") of angle correctness, eye consistency, wardrobe, background
  4. Smart split: detect black divider columns and crop panels precisely
     (falls back to an equal-width split when detection does not find every panel)
  5. SeedVR2 upscale: 2x super-resolution on each panel (no identity drift)
  6. Save to character ref directory

Usage:
    # Hero is auto-resolved via ProjectPaths.resolve_ref("char", subject, "identity", "hero")
    python3 tools/generate_turnarounds.py --character sadie \
        --project afterimage \
        --wardrobe "Oversized white t-shirt that rides up showing stomach, cotton underwear, barefoot" \
        --hair "Blonde hair in a loose messy bun with stray wisps"

    # Explicit --hero override (bypasses resolver)
    python3 tools/generate_turnarounds.py --character dusty \
        --hero /path/to/dusty_identity_hero.jpeg \
        --project afterimage \
        --wardrobe "White shirt sleeves rolled to elbows, dark trousers, dark shoes" \
        --hair "Messy black hair falling across forehead" \
        --validate flash --framing both
"""

import argparse
import logging
import os
import sys
import time
from pathlib import Path

import numpy as np
from PIL import Image
from io import BytesIO

# Put the directory two levels above this file first on sys.path, presumably so
# the lazy `recoil.*` imports further down resolve when run as a script — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

logger = logging.getLogger(__name__)

# Maximum Seedream grid generations per framing before that framing is reported as failed.
MAX_GRID_RETRIES = 3
# Maximum SeedVR2 attempts per panel; after that the panel simply gets no "_hq" file.
MAX_UPSCALE_RETRIES = 2


# ── Programmatic Validation ──────────────────────────────────────────

def validate_grid_programmatic(grid_img: Image.Image) -> tuple[bool, str]:
    """Run cheap, offline sanity checks on a freshly generated grid.

    The grid must be landscape (width / height of at least 1.2) and at
    least 2000 px wide. No API calls are made.

    Returns:
        ``(passed, reason)`` where ``reason`` is ``"OK"`` on success or a
        human-readable explanation of the first failed check.
    """
    width, height = grid_img.width, grid_img.height
    aspect = width / height

    failure = None
    if aspect < 1.2:
        failure = f"Grid is not landscape (ratio {aspect:.2f}, need >1.2). Got {width}x{height}"
    elif width < 2000:
        failure = f"Grid too narrow ({width}px, need >=2000)"

    return (failure is None, failure or "OK")


def validate_panels_programmatic(panels: list[Image.Image],
                                 expected_panels: int = 4,
                                 min_width: int = 300,
                                 min_width_ratio: float = 0.70) -> tuple[bool, str]:
    """Check panel count and truncation. Zero cost.

    Args:
        panels: Split panels, left to right. Only ``.width`` is inspected.
        expected_panels: Required number of panels (default 4, the standard
            turnaround grid).
        min_width: Every panel must be at least this many pixels wide.
        min_width_ratio: narrowest / widest panel width must be at least
            this; a lower ratio suggests a panel was cut off while splitting.

    Returns:
        ``(passed, reason)`` — ``reason`` is ``"OK"`` on success.
    """
    if len(panels) != expected_panels:
        return False, f"Expected {expected_panels} panels, got {len(panels)}"
    if not panels:
        # Only reachable with expected_panels == 0; min()/max() below need data.
        return False, "No panels to validate"

    widths = [p.width for p in panels]
    min_w, max_w = min(widths), max(widths)
    if max_w > 0 and min_w / max_w < min_width_ratio:
        return False, (
            f"Panel truncation detected: narrowest={min_w}px, widest={max_w}px "
            f"(ratio {min_w/max_w:.2f}, need >={min_width_ratio:.2f})"
        )

    too_small = [i for i, w in enumerate(widths) if w < min_width]
    if too_small:
        return False, f"Panel(s) {too_small} too narrow (<{min_width}px): widths={widths}"

    return True, "OK"


def validate_upscale_programmatic(source: Image.Image, upscaled: Image.Image,
                                   factor: int) -> tuple[bool, str]:
    """Sanity-check a SeedVR2 result against its source panel. Zero cost.

    The upscale must reach at least 90% of ``factor`` times the source size
    in both dimensions, and must not be a flat/blank frame (some RGB channel
    needs a standard deviation of at least 10).

    Returns:
        ``(passed, reason)`` — ``reason`` is ``"OK"`` on success.
    """
    min_w = source.width * factor * 0.9
    min_h = source.height * factor * 0.9
    if not (upscaled.width >= min_w and upscaled.height >= min_h):
        target = f"{source.width*factor}x{source.height*factor}"
        return False, f"Upscale too small: {upscaled.width}x{upscaled.height}, expected ~{target}"

    from PIL import ImageStat
    spread = max(ImageStat.Stat(upscaled.convert("RGB")).stddev)
    # Near-zero spread in every channel means the output is effectively blank.
    if spread < 10:
        return False, "Upscaled image appears blank (max channel stddev < 10)"

    return True, "OK"


# ── Interactive Validation ───────────────────────────────────────────

def validate_interactive(grid_path: Path, wardrobe: str, hair: str,
                          framing: str) -> bool:
    """Show a grid to a human reviewer and ask for a pass/fail verdict.

    Best-effort opens ``grid_path`` in the system viewer (``open`` on macOS,
    ``xdg-open`` elsewhere), prints the path and a review checklist, then
    reads one answer from stdin.

    Returns:
        True if the reviewer answered "y", False for any other answer.
        Answering "q" (or closing stdin) exits the process with status 0.
    """
    import subprocess

    # argv list, no shell: the path embeds the CLI-supplied character name.
    viewer = "open" if sys.platform == "darwin" else "xdg-open"
    try:
        subprocess.run([viewer, str(grid_path)], check=False)
    except OSError:
        # No viewer available; the reviewer can open the path printed below.
        pass

    body_check = "Full body head to toe" if framing == "full_body" else "Waist up, face clearly visible"
    print(f"\n{'='*50}")
    print("  INTERACTIVE VALIDATION")
    print(f"{'='*50}")
    print(f"  Image: {grid_path}")
    print("  Check the opened image:")
    print("  [ ] 4 distinct panels, each a different angle")
    # Same left-to-right order as the PANEL numbering in the grid prompt.
    print("  [ ] Panel 1: frontal, 2: profile, 3: back, 4: 3/4")
    print(f"  [ ] {body_check} in all panels")
    print(f"  [ ] Hair: {hair}")
    print(f"  [ ] Wardrobe: {wardrobe}")
    print("  [ ] Eyes same color in all panels")
    print("  [ ] Clean gray background")
    try:
        response = input("\n  Pass? (y=accept, n=regenerate, q=quit): ").strip().lower()
    except EOFError:
        # Closed stdin: quit rather than regenerate (and pay) with no reviewer.
        response = "q"
    if response == "q":
        print("  Quitting.")
        sys.exit(0)
    return response == "y"


# ── Grid splitting ────────────────────────────────────────────────────

def find_panel_spans(col_means: np.ndarray, threshold: float = 30,
                     min_width: int = 100) -> tuple[list[tuple[int, int]], int]:
    """Locate bright content runs separated by dark divider columns.

    Args:
        col_means: 1-D per-column mean brightness values.
        threshold: Columns with a mean strictly below this count as divider.
        min_width: A content run must be strictly wider than this many
            columns to count as a panel; narrower runs are dropped as slivers.

    Returns:
        ``(spans, n_dividers)``: ``spans`` lists half-open ``(left, right)``
        column ranges of the kept content runs, left to right; ``n_dividers``
        is the number of maximal dark runs, including any at either edge.
    """
    dark = np.asarray(col_means) < threshold
    spans: list[tuple[int, int]] = []
    n_dividers = 0
    run_start = None  # first column of the current bright run, if inside one
    prev_dark = False
    for x, is_dark in enumerate(dark):
        if is_dark:
            if not prev_dark:
                n_dividers += 1
            if run_start is not None:
                if x - run_start > min_width:
                    spans.append((run_start, x))
                run_start = None
        elif run_start is None:
            run_start = x
        prev_dark = bool(is_dark)
    # Close a bright run that reaches the right edge.
    if run_start is not None and len(dark) - run_start > min_width:
        spans.append((run_start, len(dark)))
    return spans, n_dividers


def smart_split_grid(img: Image.Image, expected_panels: int = 4,
                     threshold: int = 30,
                     min_panel_width: int = 100) -> list[Image.Image]:
    """Split a multi-panel grid image by detecting black divider columns.

    Columns whose mean brightness (over all rows and RGB channels) is below
    ``threshold`` are treated as divider. The bright runs between them that
    are wider than ``min_panel_width`` px become full-height panels. Dark
    borders at either edge are excluded.

    Seedream panels are uneven, so divider detection is the primary method.
    Only when it does not yield exactly ``expected_panels`` panels does this
    fall back to an equal-width split, which is logged as a warning.

    Returns:
        ``expected_panels`` cropped panel images, left to right, in the
        original image mode.
    """
    # Normalize to RGB: an alpha channel would lift the mean of black
    # divider columns above the threshold, and single-channel images have
    # no channel axis for the mean below.
    rgb = np.asarray(img.convert("RGB"))
    col_means = rgb.mean(axis=(0, 2))
    spans, n_dividers = find_panel_spans(col_means, threshold, min_panel_width)
    if len(spans) == expected_panels:
        return [img.crop((left, 0, right, img.height)) for left, right in spans]

    logger.warning(
        "Expected %d panels, got %d (dividers found: %d). "
        "Falling back to equal-width split.",
        expected_panels, len(spans), n_dividers,
    )
    panel_w = img.width // expected_panels
    return [
        img.crop((i * panel_w, 0, (i + 1) * panel_w, img.height))
        for i in range(expected_panels)
    ]


# ── Grid prompt building ─────────────────────────────────────────────

# Angle keys in left-to-right panel order. run_turnaround pairs the split
# panels with these names (via zip) to build output file names, so any panel
# beyond len(ANGLE_NAMES) gets no file. The ANGLE_DESCRIPTIONS_* dicts below
# use the same keys in the same order; keep them aligned so the prompt's
# PANEL numbering matches the saved file names.
ANGLE_NAMES = ["front", "profile", "back", "three_quarter"]

# Per-angle prompt text used by build_grid_prompt for every framing other than
# "medium" (i.e. full-body turnarounds).
ANGLE_DESCRIPTIONS_FULL = {
    "front": "FRONTAL: Full body head to toe. Facing camera straight on, neutral expression, arms at sides. Both eyes fully visible.",
    "profile": "FULL SIDE PROFILE: Full body head to toe. Body facing directly to the right, perpendicular to camera. Complete 90-degree side view. Nose pointing to the right. Only ONE eye visible. Face clearly lit — NOT in shadow, NOT silhouetted.",
    "back": "BACK: Full body head to toe. Turned completely away from camera, back of head and body visible. No face visible.",
    "three_quarter": "THREE-QUARTER VIEW: Full body head to toe. Body turned about 30-45 degrees to the right. BOTH eyes still visible. Gaze directed away from camera, looking off to the right. One shoulder closer to camera. Face angled but NOT in profile — both cheeks visible.",
}

# Per-angle prompt text used by build_grid_prompt when framing == "medium"
# (waist-up shots).
ANGLE_DESCRIPTIONS_MEDIUM = {
    "front": "FRONTAL: Facing camera, neutral expression, head and shoulders centered. Both eyes fully visible.",
    "profile": "FULL SIDE PROFILE: Turned 90 degrees right, perpendicular to camera. Complete side view. Nose pointing right. Only ONE eye visible. Face clearly lit — NOT in shadow, NOT silhouetted.",
    "back": "BACK: Turned away, back of head and upper back visible. No face visible.",
    "three_quarter": "THREE-QUARTER VIEW: Turned about 30-45 degrees right. BOTH eyes still visible, gaze directed away from camera, looking off to the right. Face angled but NOT in profile — we can still see both cheeks.",
}


def build_grid_prompt(character_desc: str, wardrobe: str, hair: str,
                       framing: str = "full_body",
                       style_suffix: str = "Shot on Kodak Portra 400. Photorealistic.") -> str:
    """Build the Seedream grid prompt for a given framing type.

    Args:
        character_desc: Optional free text placed before the hair
            description (may be empty).
        wardrobe: Outfit description.
        hair: Hair description.
        framing: ``"medium"`` selects waist-up angle descriptions; any other
            value uses the full-body ones.
        style_suffix: Rendering-style text appended at the very end.

    Returns:
        The prompt string. Panels are numbered in ``ANGLE_NAMES`` order, so
        PANEL N matches the N-th angle name used for the split files.
    """
    if framing == "medium":
        angles = ANGLE_DESCRIPTIONS_MEDIUM
        body_instruction = "Medium shot, waist up, focus on face and upper body."
    else:
        angles = ANGLE_DESCRIPTIONS_FULL
        body_instruction = "Full body, head to toe."

    def _clause(text: str) -> str:
        # Strip whitespace and trailing periods so the template alone decides
        # punctuation (descriptions already end in "." which produced "..").
        return text.strip().rstrip(".")

    angle_lines = "\n".join(
        f"PANEL {i} — {_clause(angles[name])}."
        for i, name in enumerate(ANGLE_NAMES, start=1)
    )
    # Skip empty parts so an empty --desc does not leave a stray space.
    subject = " ".join(part for part in (character_desc.strip(), _clause(hair)) if part)
    details = " ".join(f"{part}." for part in (subject, _clause(wardrobe)) if part)
    return (
        f"Figure 1 is a character reference — preserve this person's face and identity exactly.\n\n"
        f"A single WIDE LANDSCAPE image containing exactly FOUR panels arranged side by side. "
        f"Each panel shows the SAME person — same face, same body, same proportions, same outfit, same hair. "
        f"{body_instruction} Plain 18% gray background. Flat, neutral, even studio lighting from all sides — no dramatic shadows, no directional light, no moody lighting. Every panel lit identically.\n"
        f"{angle_lines}\n"
        f"{details + ' ' if details else ''}"
        f"Each panel is a DIFFERENT angle. Thin black dividers between panels. "
        f"{style_suffix}"
    )


# ── Generation functions ─────────────────────────────────────────────

def resolve_seedream_endpoint(model: str) -> str:
    """Map a ``--model`` value to a fal.ai Seedream edit endpoint id.

    Accepts a short alias (``"seedream-v4.5"``, ``"seedream-v4"``) or a full
    fal endpoint id containing ``/``, which is returned verbatim.

    Raises:
        ValueError: if ``model`` is neither a known alias nor an endpoint id.
    """
    if "/" in model:
        return model
    aliases = {
        "seedream-v4.5": "fal-ai/bytedance/seedream/v4.5/edit",
        "seedream-v4": "fal-ai/bytedance/seedream/v4/edit",
    }
    try:
        return aliases[model]
    except KeyError:
        raise ValueError(
            f"Unknown Seedream model {model!r}; expected one of "
            f"{sorted(aliases)} or a full fal endpoint id"
        ) from None


def generate_grid(hero_url: str, prompt: str, model: str = "seedream-v4.5") -> tuple:
    """Generate a turnaround grid via the Seedream edit endpoint.

    Args:
        hero_url: Pre-uploaded fal.ai URL for the hero ref (upload once,
            reuse across retries).
        prompt: Full grid prompt.
        model: Short alias or full endpoint id (see
            ``resolve_seedream_endpoint``). Previously this was ignored and
            v4.5 was always used.

    Forces landscape_16_9 image_size to prevent 1:1 grids.

    Returns:
        ``(PIL.Image, raw_bytes)``. The image is fully decoded.

    Raises:
        ValueError: unknown model alias.
        RuntimeError: the API returned no images.
        Other exceptions from fal_client / requests on API or download failure.
    """
    import fal_client
    import requests

    endpoint = resolve_seedream_endpoint(model)
    logger.info("Generating turnaround grid via %s (%s)...", model, endpoint)
    t0 = time.time()
    result = fal_client.subscribe(
        endpoint,
        arguments={
            "prompt": prompt,
            "image_urls": [hero_url],
            "num_images": 1,
            "image_size": "landscape_16_9",
            "enable_safety_checker": False,
        },
    )
    elapsed = time.time() - t0

    images = result.get("images") or []
    if not images:
        raise RuntimeError("Seedream returned no images")

    resp = requests.get(images[0]["url"], timeout=120)
    resp.raise_for_status()

    img = Image.open(BytesIO(resp.content))
    # Decode now so a corrupt download fails here, inside the caller's retry
    # handling, instead of later while splitting the grid.
    img.load()
    logger.info("Grid generated: %dx%d in %.1fs", img.width, img.height, elapsed)
    return img, resp.content


def upscale_panel(panel_url: str, factor: int = 2) -> tuple:
    """Upscale a panel via SeedVR2.

    Args:
        panel_url: Pre-uploaded fal.ai URL (upload once, reuse across retries).
        factor: Upscale factor passed to the endpoint.

    Returns:
        ``(PIL.Image, raw_bytes)``. The image is fully decoded.

    Raises:
        RuntimeError: the response carried no image URL.
        Other exceptions from fal_client / requests on API or download failure.
    """
    import requests
    import fal_client
    t0 = time.time()
    result = fal_client.subscribe(
        "fal-ai/seedvr/upscale/image",
        arguments={"image_url": panel_url, "upscale_factor": factor},
    )
    elapsed = time.time() - t0

    img_url = (result.get("image") or {}).get("url")
    if not img_url:
        raise RuntimeError("SeedVR2 returned no image")
    resp = requests.get(img_url, timeout=120)
    resp.raise_for_status()

    img = Image.open(BytesIO(resp.content))
    # Decode now so a corrupt download surfaces inside the caller's retry
    # handling rather than during validation.
    img.load()
    logger.info("Upscaled: %dx%d in %.1fs", img.width, img.height, elapsed)
    return img, resp.content


# ── Main workflow with validation + retry ────────────────────────────

def run_turnaround(character: str, hero_path: Path, out_dir: Path,
                    prompt: str, framing: str, model: str,
                    validate_mode: str, upscale_factor: int,
                    skip_upscale: bool, wardrobe: str, hair: str,
                    tag: str = "", expected_panels: int = 4) -> dict:
    """Run a single turnaround set (one framing type).

    1. Generate a grid, retrying up to MAX_GRID_RETRIES times until it passes
       the programmatic grid/panel checks and the selected validation mode
       ("flash" vision critic, "opus" interactive review, anything else: none).
    2. Save the grid and its split panels into ``out_dir``.
    3. Unless ``skip_upscale``, upscale each panel with SeedVR2 (up to
       MAX_UPSCALE_RETRIES attempts per panel).

    Files are named ``{character}_turnaround_grid{suffix}.png``,
    ``{character}_turn_{angle}{suffix}.png`` and ``..._hq.png``, where
    ``suffix`` is ``_{tag}`` plus ``_medium`` for medium framing.

    Returns:
        dict with "success", "framing" and "cost" (estimated USD). On success
        it also has "grid" (Path), "panels" (list[Path]) and "hq_panels"
        (list[Path]; shorter than "panels" when an upscale failed).
    """
    import fal_client as _fal

    def _save_png(path: Path, img: Image.Image, raw: bytes) -> None:
        # Write API bytes verbatim only when they already are PNG; otherwise
        # re-encode so the ".png" extension matches the file contents.
        if img.format == "PNG":
            path.write_bytes(raw)
        else:
            img.save(path, "PNG")

    framing_suffix = "_medium" if framing == "medium" else ""
    suffix = f"_{tag}{framing_suffix}"
    total_cost = 0.0

    # ── Step 1: Generate grid with retry loop ────────────────────────
    print(f"\n  [1/3] Generating {framing} grid (max {MAX_GRID_RETRIES} attempts)...")

    grid_img = None
    grid_bytes = b""
    panels = None

    # Upload hero ref once (reuse URL across retry attempts)
    logger.info("Uploading hero ref: %s", hero_path.name)
    hero_url = _fal.upload_file(str(hero_path))

    for attempt in range(1, MAX_GRID_RETRIES + 1):
        print(f"    Attempt {attempt}/{MAX_GRID_RETRIES}...", end=" ", flush=True)

        try:
            grid_img, grid_bytes = generate_grid(hero_url, prompt, model)
        except Exception as e:
            print(f"GENERATION FAILED: {e}")
            continue

        total_cost += 0.04  # estimated price per generated grid

        # Gate 1A: Programmatic grid check
        passed, reason = validate_grid_programmatic(grid_img)
        if not passed:
            print(f"GRID CHECK FAIL: {reason}")
            continue

        # Gate 1B: Smart split + panel check
        panels = smart_split_grid(grid_img, expected_panels=expected_panels)
        # NOTE(review): called without a panel count, so exactly 4 panels are
        # enforced here; with --panels != 4 every attempt fails this gate.
        passed, reason = validate_panels_programmatic(panels)
        if not passed:
            print(f"PANEL CHECK FAIL: {reason}")
            panels = None
            continue

        # Gate 1C: Vision validation (if requested)
        if validate_mode == "flash":
            try:
                # Add pipeline root to path for critic imports
                _pipeline_root = str(Path(__file__).resolve().parent.parent / "pipeline")
                if _pipeline_root not in sys.path:
                    sys.path.insert(0, _pipeline_root)
                from recoil.pipeline._lib.critics.turnaround_critic import validate_turnaround_grid
                passed, reason = validate_turnaround_grid(
                    grid_img, panels, wardrobe, hair, framing)
                if not passed:
                    print(f"VISION CHECK FAIL: {reason}")
                    panels = None
                    continue
            except ImportError:
                print("(flash validator not yet built, skipping)")

        elif validate_mode == "opus":
            # Save temp grid for interactive review
            temp_grid = out_dir / f"{character}_turnaround_grid{suffix}_temp.png"
            _save_png(temp_grid, grid_img, grid_bytes)
            try:
                approved = validate_interactive(temp_grid, wardrobe, hair, framing)
            finally:
                # Also runs when the reviewer quits (SystemExit), so the temp
                # file is never left behind.
                temp_grid.unlink(missing_ok=True)
            if not approved:
                print("    Rejected by reviewer, regenerating...")
                panels = None
                continue

        # All checks passed
        print(f"OK ({grid_img.width}x{grid_img.height})")
        break
    else:
        print(f"\n  ERROR: All {MAX_GRID_RETRIES} grid attempts failed for {framing}.")
        return {"success": False, "framing": framing, "cost": total_cost}

    # Save grid
    grid_path = out_dir / f"{character}_turnaround_grid{suffix}.png"
    _save_png(grid_path, grid_img, grid_bytes)

    # ── Step 2: Save split panels ────────────────────────────────────
    print(f"\n  [2/3] Splitting grid into {len(panels)} panels...")
    if len(panels) > len(ANGLE_NAMES):
        logger.warning(
            "Grid has %d panels but only %d angle names; extra panels are not saved.",
            len(panels), len(ANGLE_NAMES),
        )
    panel_paths = []
    for panel, name in zip(panels, ANGLE_NAMES):
        p = out_dir / f"{character}_turn_{name}{suffix}.png"
        panel.save(p, "PNG")
        panel_paths.append(p)
        print(f"    {name}: {panel.width}x{panel.height}")

    # ── Step 3: SeedVR2 upscale with retry ───────────────────────────
    hq_paths = []
    if not skip_upscale:
        print(f"\n  [3/3] Upscaling panels ({upscale_factor}x SeedVR2)...")
        for panel_path, panel_img, name in zip(panel_paths, panels, ANGLE_NAMES):
            up_ok = False
            panel_url = None  # uploaded on the first attempt, reused after
            for _attempt in range(1, MAX_UPSCALE_RETRIES + 1):
                print(f"    [{name}] ", end="", flush=True)
                try:
                    # Upload inside the retry guard: a transient upload error
                    # must not abort the remaining panels.
                    if panel_url is None:
                        panel_url = _fal.upload_file(str(panel_path))
                    up_img, up_bytes = upscale_panel(panel_url, upscale_factor)
                except Exception as e:
                    print(f"UPSCALE FAILED: {e}")
                    continue

                passed, reason = validate_upscale_programmatic(panel_img, up_img, upscale_factor)
                if not passed:
                    print(f"UPSCALE CHECK FAIL: {reason}")
                    continue

                up_path = out_dir / f"{character}_turn_{name}{suffix}_hq.png"
                _save_png(up_path, up_img, up_bytes)
                # Estimated price: 0.001 per output megapixel.
                total_cost += 0.001 * (up_img.width * up_img.height / 1_000_000)
                print(f"{up_img.width}x{up_img.height}")
                hq_paths.append(up_path)
                up_ok = True
                break

            if not up_ok:
                print(f"    WARNING: upscale failed for {name} after {MAX_UPSCALE_RETRIES} attempts")
    else:
        print("\n  [3/3] Upscale skipped")

    return {
        "success": True,
        "framing": framing,
        "grid": grid_path,
        "panels": panel_paths,
        "hq_panels": hq_paths,
        "cost": total_cost,
    }


def main():
    """CLI entry point: parse arguments, resolve the hero ref and run one
    turnaround set per requested framing.

    With --dry-run, only prints the resolved settings and prompt(s) for each
    framing. Otherwise runs ``run_turnaround`` per framing, prints a summary
    and, on macOS, opens the successful grids.
    """
    parser = argparse.ArgumentParser(
        description="Generate character turnaround refs (grid + split + upscale + validate)")
    parser.add_argument("--character", required=True, help="Character name (e.g. sadie, dusty)")
    parser.add_argument("--hero", default=None,
                        help="Explicit path to hero/identity reference image (overrides resolver)")
    parser.add_argument("--project", required=True, help="Project name (e.g. afterimage)")
    parser.add_argument("--wardrobe", required=True, help="Wardrobe description")
    parser.add_argument("--hair", required=True, help="Hair description")
    parser.add_argument("--desc", default="", help="Additional character description")
    parser.add_argument("--upscale-factor", type=int, default=2, help="SeedVR2 upscale factor")
    parser.add_argument("--model", default="seedream-v4.5", help="Seedream model ID")
    parser.add_argument("--framing", default="full_body",
                        choices=["full_body", "medium", "both"],
                        help="Turnaround framing type")
    parser.add_argument("--validate", default="none",
                        choices=["flash", "opus", "none"],
                        help="Validation mode: flash (Gemini), opus (interactive), none")
    parser.add_argument("--tag", default=None,
                        help="Output tag for filenames (e.g. 'leather_jacket'). "
                             "Auto-generates from timestamp if not provided. "
                             "Prevents overwriting previous runs.")
    parser.add_argument("--dry-run", action="store_true", help="Show prompt without generating")
    parser.add_argument("--skip-upscale", action="store_true", help="Skip SeedVR2 upscale")
    parser.add_argument("--style-suffix", default="Shot on Kodak Portra 400. Photorealistic.",
                        help="Replaces the photorealistic suffix (e.g. 'Warm 2D cartoon style.')")
    parser.add_argument("--custom-prompt", default=None,
                        help="Override build_grid_prompt entirely with a custom prompt string")
    parser.add_argument("--panels", type=int, default=4,
                        help="Expected panel count in generated grid (default 4)")
    args = parser.parse_args()

    # Generate tag if not provided — ensures no overwrites
    if args.tag is None:
        args.tag = time.strftime("%Y%m%d_%H%M%S")

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Resolve hero via ProjectPaths resolver (or explicit --hero override)
    if args.hero:
        hero_path = Path(args.hero).resolve()
        if not hero_path.exists():
            print(f"ERROR: Hero image not found: {hero_path}")
            sys.exit(1)
    else:
        # Imported only here so an explicit --hero works without the resolver.
        from recoil.core.paths import ProjectPaths, RefNotFoundError
        try:
            pp = ProjectPaths.for_project(args.project)
            resolved = pp.resolve_ref("char", args.character, "identity", "hero")
            hero_path = resolved.path
        except (RefNotFoundError, FileNotFoundError) as exc:
            print(f"ERROR: Could not resolve hero for {args.character!r} in project {args.project!r}: {exc}")
            print("  Use --hero /path/to/image to specify explicitly.")
            sys.exit(1)

    # Output directory
    projects_root = Path(__file__).resolve().parent.parent.parent / "projects"
    out_dir = projects_root / args.project / "output" / "refs" / "characters" / args.character
    out_dir.mkdir(parents=True, exist_ok=True)

    # Determine which framings to run
    framings = ["full_body", "medium"] if args.framing == "both" else [args.framing]

    def _prompt_for(framing: str) -> str:
        # --custom-prompt replaces the generated prompt for every framing.
        if args.custom_prompt:
            return args.custom_prompt
        return build_grid_prompt(args.desc, args.wardrobe, args.hair, framing,
                                 style_suffix=args.style_suffix)

    if args.dry_run:
        for framing in framings:
            prompt = _prompt_for(framing)
            # Same naming rule as run_turnaround: "_{tag}", plus "_medium"
            # for medium framing.
            framing_suffix = "_medium" if framing == "medium" else ""
            print(f"\n=== DRY RUN — {args.character} {framing} turnarounds ===")
            print(f"Hero: {hero_path}")
            print(f"Output: {out_dir}")
            print(f"Tag: {args.tag}")
            print(f"Grid file: {args.character}_turnaround_grid_{args.tag}{framing_suffix}.png")
            print(f"Model: {args.model}")
            print(f"Validate: {args.validate}")
            print(f"Panels: {args.panels}")
            print(f"Upscale: {'skip' if args.skip_upscale else f'{args.upscale_factor}x SeedVR2'}")
            print(f"\nPrompt ({len(prompt)} chars):\n{prompt}")
        return

    print(f"\n{'='*50}")
    print(f"  {args.character.upper()} TURNAROUNDS")
    print(f"{'='*50}")
    print(f"  Hero: {hero_path.name}")
    print(f"  Output: {out_dir}")
    print(f"  Tag: {args.tag}")
    print(f"  Model: {args.model}")
    print(f"  Framings: {', '.join(framings)}")
    print(f"  Validate: {args.validate}")

    all_results = []

    for framing in framings:
        print(f"\n{'─'*50}")
        print(f"  {framing.upper()} SET")
        print(f"{'─'*50}")

        result = run_turnaround(
            character=args.character,
            hero_path=hero_path,
            out_dir=out_dir,
            prompt=_prompt_for(framing),
            framing=framing,
            model=args.model,
            validate_mode=args.validate,
            upscale_factor=args.upscale_factor,
            skip_upscale=args.skip_upscale,
            wardrobe=args.wardrobe,
            hair=args.hair,
            tag=args.tag,
            expected_panels=args.panels,
        )
        all_results.append(result)

    # Summary
    total_cost = sum(r["cost"] for r in all_results)
    successes = [r for r in all_results if r["success"]]
    failures = [r for r in all_results if not r["success"]]

    print(f"\n{'='*50}")
    print(f"  {args.character.upper()} TURNAROUNDS COMPLETE")
    print(f"{'='*50}")
    for r in successes:
        print(f"  {r['framing']}: PASS — grid + {len(r['panels'])} panels + {len(r['hq_panels'])} HQ")
    for r in failures:
        print(f"  {r['framing']}: FAIL — all retry attempts exhausted")
    print(f"  Est. cost: ${total_cost:.3f}")
    print(f"{'='*50}")

    # Open grid results
    if sys.platform == "darwin" and successes:
        import subprocess
        # argv list, no shell: the paths embed CLI-supplied names, and "$(...)"
        # would still expand inside the old double-quoted shell string.
        subprocess.run(["open", *(str(r["grid"]) for r in successes)], check=False)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
