#!/usr/bin/env python3
"""
batch_threepass.py — Full three-pass LoRA training candidate generation.

Runs engine_shootout.py --threepass across all angles × expression tiers,
producing a complete set of LoRA training candidates from a single hero image.

Pipeline: Qwen MA → NBP (bg+expr) → SeedVR2
  - Expression angles (front, close-ups): full 5-expression range, 3-pass
  - Mild expression angles (3/4 views): 3 expressions, 3-pass
  - Body angles (low, high, full_body, full_body_three_quarter): neutral only, 3-pass w/ proportionality prompting
  - Profile angles: neutral only, 2-pass (skip NBP)
  - Back angles: neutral only, 2-pass (skip NBP); with --no-smart-back they get the full expression range, 3-pass
  - Environments rotate across angles to prevent LoRA overfitting

Usage:
    python3 batch_threepass.py leviathan/ --character JINX
    python3 batch_threepass.py leviathan/ --character JINX --dry-run
    python3 batch_threepass.py leviathan/ --character JINX --expressions neutral
    python3 batch_threepass.py leviathan/ --character JINX --angles front,profile_right
    python3 batch_threepass.py leviathan/ --character JINX --angles full_body,low_angle,high_angle
    python3 batch_threepass.py leviathan/ --character JINX --lighting "warm amber glow"
    python3 batch_threepass.py leviathan/ --character JINX --no-smart-back

Cost:  ~$0.036-0.10/job (2-pass for profile/back angles, 3-pass for expression/body angles)
       Default 14 angles: 30 jobs, ~$2.68, ~42 min est. (plus --delay between jobs)
"""

import argparse
import json
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

# Shared config + cost tracking
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "lib"))
from config_loader import load_project_config
from cost_tracker import CostTracker

# ── Expression Tiers ──────────────────────────────────────────────────
# Every expression prompt follows one shape:
#     "<emotion anchor> — <physical descriptor>, <descriptor>, <descriptor>"
# Research finding: a single emotion word plus 2-3 physical descriptors works
# best for both Gemini (NBP) and diffusion models.

EXPRESSION_TIERS = {
    # Baseline identity reference.
    "neutral": [
        "neutral — relaxed features, steady gaze, lips closed naturally",
    ],
    # Everyday states that add variation without distorting the face.
    "moderate": [
        "tired — slightly heavy eyelids, soft unfocused gaze, relaxed jaw",
        "focused — narrowed eyes, set jaw, intent forward stare",
        "wary — guarded gaze, slight tension around the mouth, watchful eyes",
    ],
    # Strong emotions; only the first one ("exhausted") is used by the
    # default sets below.
    "intense": [
        "exhausted — heavy-lidded eyes, slight frown, drained hollow gaze",
        "furious — bared teeth, flared nostrils, intense glare with furrowed brow",
        "grief — glassy wet eyes, trembling lower lip, downturned mouth",
    ],
}

# LoRA-training default: every neutral and moderate prompt, plus only the
# first intense prompt for range. The neutral/moderate values are the very
# same list objects held in EXPRESSION_TIERS.
DEFAULT_EXPRESSION_SET = {
    tier: (prompts[:1] if tier == "intense" else prompts)
    for tier, prompts in EXPRESSION_TIERS.items()
}

# Frontal / close-up angles: the full 5-prompt range
# (neutral, tired, focused, wary, exhausted).
FULL_EXPRESSIONS = [
    *EXPRESSION_TIERS["neutral"],
    *EXPRESSION_TIERS["moderate"],
    *EXPRESSION_TIERS["intense"][:1],
]

# 3/4 angles: a milder 3-prompt set (neutral, tired, focused) that breaks
# the coupling between camera angle and expression.
MILD_EXPRESSIONS = [
    *EXPRESSION_TIERS["neutral"],
    *EXPRESSION_TIERS["moderate"][:2],
]

# Single neutral prompt for profile, back and body angles.
NEUTRAL_EXPRESSIONS = EXPRESSION_TIERS["neutral"]

# ── Angle Set ─────────────────────────────────────────────────────────
# Standard 14-angle Qwen Multi-Angle set. ANGLE_LABELS maps each angle key to
# the short label used in progress output; its insertion order is also the
# canonical generation order, from which ALL_ANGLES is derived.

ANGLE_LABELS = {
    "front": "Front",
    "three_quarter_right": "3/4R",
    "profile_right": "ProfR",
    "back_right": "BackR",
    "back": "Back",
    "back_left": "BackL",
    "profile_left": "ProfL",
    "three_quarter_left": "3/4L",
    "low_angle": "Low",
    "high_angle": "High",
    "full_body": "FB",
    "full_body_three_quarter": "FB-3/4",
    "closeup_front": "CU-F",
    "closeup_three_quarter": "CU-3/4",
}

ALL_ANGLES = list(ANGLE_LABELS)

# Angle categories driving the smart expression distribution.
# Research: frontal expressions are "significantly better recognized" than
# profile (Ibarretxe-Bilbao 2021). The LoRA learns identity, not expressions,
# so expression range is concentrated where faces read best.
EXPRESSION_ANGLES = {"closeup_front", "closeup_three_quarter", "front"}
# 3/4 views get a reduced, milder expression set.
MILD_EXPRESSION_ANGLES = {"three_quarter_left", "three_quarter_right"}
# Body angles: neutral only, but still 3-pass (NBP adds proportionality prompting).
BODY_EXPRESSION_ANGLES = {"full_body", "full_body_three_quarter", "high_angle", "low_angle"}
# Profiles: neutral only, 2-pass.
NEUTRAL_ONLY_ANGLES = {"profile_left", "profile_right"}
# Back views: neutral only, 2-pass unless --no-smart-back is given.
BACK_ANGLES = {"back_left", "back", "back_right"}

# ── Environment Rotation ─────────────────────────────────────────────
# Environments cycle across jobs (one pool step per generated job) to keep
# the LoRA from overfitting to a single backdrop. load_environment_pool()
# pulls them from breakdown.json habitat_zones; this generic pool is used
# instead when breakdown.json is missing or unreadable, has no habitat_zones,
# or none of its zones provides a visual_dna description.
FALLBACK_ENVIRONMENT_POOL = [
    "dimly lit industrial corridor with exposed pipes and warm overhead sodium lamps",
    "rain-slicked city rooftop at dusk with neon reflections and cool blue ambient light",
    "sparse concrete bunker interior with a single harsh fluorescent tube overhead",
    "cluttered underground workshop lit by warm desk lamps and welding sparks",
    "abandoned subway platform with flickering emergency lights and wet tile walls",
    "open desert highway at golden hour with long shadows and warm orange backlight",
]


def _sanitize_visual_dna(visual_dna: str) -> list[str]:
    """Sanitize a visual_dna string for use as an image generation environment prompt.

    Handles two classes of problems:
    1. Dual-scene descriptions ("Shuttle bay: X. Planet surface: Y") — splits into
       separate entries so NBP doesn't render a literal split-screen composite.
    2. Non-visual/thematic language ("split personality", "clinical horror",
       "ceremonial precision") — strips phrases that aren't camera-visible.

    Returns a list of 1+ sanitized environment strings.
    """
    import re

    # ── Split dual-scene entries ──
    # Pattern: "Label: description. Label: description" — two named scenes
    # in one string. Split at ". [A-Z]" boundaries that look like scene labels.
    scenes = re.split(r'\.\s+(?=[A-Z][a-z]+(?:\s+[a-z]+)*\s*:)', visual_dna)
    if len(scenes) == 1:
        # Also try "X vs Y" / "X versus Y" splits
        vs_parts = re.split(r'\s+vs\.?\s+|\s+versus\s+', visual_dna, flags=re.IGNORECASE)
        if len(vs_parts) > 1:
            scenes = vs_parts

    # ── Strip non-visual language from each scene ──
    # These are abstract/thematic/narrative phrases that image models
    # interpret literally or nonsensically.
    NON_VISUAL_PHRASES = [
        r'[Ss]plit personality:?\s*',
        r'[Cc]linical horror',
        r'[Cc]eremonial precision',
        r'[Ff]orgotten territory(?:\s+with)?\b',
        r'FIRST\s+',  # ALL-CAPS narrative emphasis (e.g. "FIRST NATURAL LIGHT")
    ]

    sanitized = []
    for scene in scenes:
        scene = scene.strip().rstrip('.')
        if not scene:
            continue
        # Remove scene label prefix if present ("Shuttle bay: X" → "X"
        # only if what follows the colon is a full description)
        label_match = re.match(r'^([A-Z][a-z]+(?:\s+[a-z]+)*)\s*:\s*(.+)', scene)
        if label_match and len(label_match.group(2)) > 30:
            scene = label_match.group(2)
        for pattern in NON_VISUAL_PHRASES:
            scene = re.sub(pattern, '', scene)
        # Clean up residual artifacts
        scene = re.sub(r',\s*,', ',', scene)  # double commas
        scene = re.sub(r',?\s*\b(with|of|and|in|at)\s*$', '', scene)  # trailing prepositions
        scene = re.sub(r'\s{2,}', ' ', scene).strip().strip(',').strip()
        if scene:
            sanitized.append(scene)

    return sanitized if sanitized else [visual_dna]


def load_environment_pool(project_path):
    """Load environment descriptions from breakdown.json habitat_zones.

    Reads ``<project_path>/visual/breakdown.json`` and collects the
    ``visual_dna`` string of every habitat zone. Each one is run through
    _sanitize_visual_dna(), which strips non-visual language and may split a
    dual-scene description into several entries.

    Args:
        project_path: Project directory (str or Path).

    Returns:
        A non-empty list of environment prompt strings. Falls back to a copy
        of FALLBACK_ENVIRONMENT_POOL when breakdown.json is missing or
        unreadable (I/O, encoding or JSON error), has no usable
        habitat_zones, or none of the zones carries a non-empty string
        visual_dna.
    """
    breakdown_path = Path(project_path) / "visual" / "breakdown.json"
    if not breakdown_path.is_file():
        print("  NOTE: No breakdown.json found — using fallback environments")
        return list(FALLBACK_ENVIRONMENT_POOL)

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(breakdown_path, encoding="utf-8") as f:
            breakdown = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        print(f"  WARNING: Could not read breakdown.json: {e}")
        return list(FALLBACK_ENVIRONMENT_POOL)

    zones = breakdown.get("habitat_zones", {}) if isinstance(breakdown, dict) else {}
    # habitat_zones is normally a mapping of zone key -> zone data; also accept
    # a plain list of zone objects, and treat any other shape as "no zones"
    # instead of crashing.
    if isinstance(zones, dict):
        zone_entries = list(zones.values())
    elif isinstance(zones, list):
        zone_entries = zones
    else:
        zone_entries = []
    if not zone_entries:
        print("  NOTE: No habitat_zones in breakdown.json — using fallback environments")
        return list(FALLBACK_ENVIRONMENT_POOL)

    pool = []
    sanitized_count = 0
    for zone_data in zone_entries:
        if not isinstance(zone_data, dict):
            continue
        visual_dna = zone_data.get("visual_dna", "")
        # Only non-empty strings can be sanitized and used as prompts.
        if not isinstance(visual_dna, str) or not visual_dna:
            continue
        entries = _sanitize_visual_dna(visual_dna)
        if len(entries) > 1 or entries[0] != visual_dna:
            sanitized_count += 1
        pool.extend(entries)

    if not pool:
        print("  NOTE: habitat_zones found but no visual_dna — using fallback environments")
        return list(FALLBACK_ENVIRONMENT_POOL)

    msg = f"  Environments: {len(pool)} from {len(zone_entries)} habitat zones"
    if sanitized_count:
        msg += f" ({sanitized_count} sanitized)"
    print(msg)
    return pool


def load_character_traits(project_path, character):
    """Load character-specific visual traits from breakdown.json.

    Looks up ``characters[<character>]["visual_description"]`` in
    ``<project_path>/visual/breakdown.json``. The upper-cased key is tried
    first (the CLI convention, e.g. JINX). If that is absent, the key exactly
    as given is tried, then a case-insensitive match.

    Args:
        project_path: Project directory (str or Path).
        character: Character key from the CLI (any case).

    Returns:
        The visual_description string to inject into the identity prompt,
        or None if the file is missing or unreadable, the character is not
        found, or its description is empty or not a string.
    """
    breakdown_path = Path(project_path) / "visual" / "breakdown.json"
    if not breakdown_path.is_file():
        return None

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(breakdown_path, encoding="utf-8") as f:
            breakdown = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError):
        return None

    characters = breakdown.get("characters", {}) if isinstance(breakdown, dict) else {}
    if not isinstance(characters, dict):
        return None

    char_data = characters.get(character.upper())
    if char_data is None:
        char_data = characters.get(character)
    if char_data is None:
        wanted = character.casefold()
        char_data = next(
            (data for key, data in characters.items() if str(key).casefold() == wanted),
            None,
        )
    if not isinstance(char_data, dict):
        return None

    visual_desc = char_data.get("visual_description", "")
    # The value is later passed as a subprocess argument, so it must be a str.
    if isinstance(visual_desc, str) and visual_desc:
        print("  Char traits: loaded from breakdown.json")
        return visual_desc
    return None


def main():
    """Plan and run a batch of engine_shootout.py three-pass jobs.

    Parses CLI arguments and resolves the project directory. The path is
    taken as given, or else relative to the directory three levels above
    this file. Loads the environment pool and character traits from
    breakdown.json, builds one job per (angle, expression) pair, and prints
    the plan.

    Unless --dry-run is given, each job then runs engine_shootout.py as a
    subprocess. The batch stops early when the project's budget_cap_usd is
    reached (checked before every job) or when a job exits with code 2.

    Exits with status 1 for an unknown angle or tier, a missing project
    directory, or a missing engine_shootout.py. argparse exits with status 2
    for invalid arguments, including a negative --delay.
    """
    # Estimated per-job cost (USD), shared by the up-front plan and the summary.
    #   3-pass = Qwen MA ($0.035) + NBP ($0.065) + SeedVR2 ($0.001) ≈ $0.10
    #   2-pass = Qwen MA ($0.035) + SeedVR2 ($0.001) ≈ $0.036
    cost_3pass_job = 0.10
    cost_2pass_job = 0.036

    parser = argparse.ArgumentParser(description="Batch three-pass LoRA training candidate generation")
    parser.add_argument("project", help="Project path (e.g., leviathan/)")
    parser.add_argument("--character", required=True, help="Character key (e.g., JINX)")
    parser.add_argument("--angles", default=None, help="Comma-separated angle keys (default: all 14)")
    parser.add_argument("--expressions", default=None,
                        help="Comma-separated tier names: neutral,moderate,intense (default: neutral+moderate+1 intense)")
    parser.add_argument("--lighting", default=None, help="Lighting description override")
    parser.add_argument("--dry-run", action="store_true", help="Show plan without generating")
    parser.add_argument("--hero", default=None, help="Explicit hero image path")
    parser.add_argument("--delay", type=int, default=5, help="Seconds between runs (default: 5)")
    parser.add_argument("--no-smart-back", action="store_true",
                        help="Disable smart back-angle handling (run full expressions on back angles)")
    parser.add_argument("--no-env-rotation", action="store_true",
                        help="Disable environment rotation (use breakdown.json default for all)")

    args = parser.parse_args()
    if args.delay < 0:
        # time.sleep() raises ValueError on negative input; reject it before
        # any paid job runs rather than crashing after the first one.
        parser.error("--delay must be >= 0")

    smart_back = not args.no_smart_back
    env_rotation = not args.no_env_rotation

    # Resolve project path: as given (absolute or relative to CWD), else
    # relative to the recoil root three directory levels above this file.
    project_path = Path(args.project).resolve()
    if not project_path.is_dir():
        recoil_root = Path(__file__).resolve().parent.parent.parent
        project_path = recoil_root / args.project
    if not project_path.is_dir():
        print(f"ERROR: Project path not found: {args.project}", file=sys.stderr)
        sys.exit(1)

    # Cost tracking + budget cap. The tracker is re-created before each job
    # and for the summary so costs logged by the subprocesses are picked up.
    project_config = load_project_config(project_path)
    tracker = CostTracker(project_path)
    budget_cap = project_config.get("budget_cap_usd")

    # Environment pool (habitat zones or fallback) and character traits,
    # both sourced from breakdown.json.
    environment_pool = load_environment_pool(project_path)
    character_traits = load_character_traits(project_path, args.character)

    # Resolve angles
    if args.angles:
        angles = [a.strip() for a in args.angles.split(",")]
        for a in angles:
            if a not in ALL_ANGLES:
                print(f"ERROR: Unknown angle '{a}'. Available: {', '.join(ALL_ANGLES)}", file=sys.stderr)
                sys.exit(1)
    else:
        angles = ALL_ANGLES

    # Resolve expression sets. --expressions replaces the smart full/mild
    # distribution with one user-chosen list used for both.
    if args.expressions:
        tier_names = [t.strip() for t in args.expressions.split(",")]
        expressions = []
        for t in tier_names:
            if t not in EXPRESSION_TIERS:
                print(f"ERROR: Unknown tier '{t}'. Available: neutral, moderate, intense", file=sys.stderr)
                sys.exit(1)
            expressions.extend(EXPRESSION_TIERS[t])
        full_expressions = expressions
        mild_expressions = expressions
    else:
        full_expressions = FULL_EXPRESSIONS
        mild_expressions = MILD_EXPRESSIONS

    # Build job list with smart expression distribution.
    jobs = []  # (angle, expression, skip_pass3, environment_override)
    env_index = 0

    for angle in angles:
        # Determine expression set and pipeline for this angle
        if angle in EXPRESSION_ANGLES:
            angle_exprs = full_expressions       # full range
            skip_p3 = False                      # Full 3-pass pipeline
        elif angle in MILD_EXPRESSION_ANGLES:
            angle_exprs = mild_expressions       # neutral + mild
            skip_p3 = False                      # Full 3-pass pipeline
        elif angle in BODY_EXPRESSION_ANGLES:
            angle_exprs = NEUTRAL_EXPRESSIONS    # neutral only (body shots)
            skip_p3 = False                      # 3-pass — NBP adds proportionality
        elif angle in NEUTRAL_ONLY_ANGLES:
            angle_exprs = NEUTRAL_EXPRESSIONS    # neutral only
            skip_p3 = True                       # 2-pass (no NBP)
        elif angle in BACK_ANGLES and smart_back:
            angle_exprs = NEUTRAL_EXPRESSIONS    # neutral only
            skip_p3 = True                       # 2-pass (no NBP)
        else:
            angle_exprs = full_expressions       # fallback (--no-smart-back)
            skip_p3 = False

        for expr in angle_exprs:
            # Environment rotation: one pool step per job.
            if env_rotation:
                env_override = environment_pool[env_index % len(environment_pool)]
                env_index += 1
            else:
                env_override = None

            jobs.append((angle, expr, skip_p3, env_override))

    jobs_3pass = sum(1 for _, _, skip, _ in jobs if not skip)
    jobs_2pass = len(jobs) - jobs_3pass
    total_cost = jobs_3pass * cost_3pass_job + jobs_2pass * cost_2pass_job
    total_time_est = jobs_3pass * 1.5 + jobs_2pass * 0.9  # ~1.5 min 3-pass, ~0.9 min 2-pass

    # Resolve script path
    script_dir = Path(__file__).resolve().parent
    shootout_script = script_dir / "engine_shootout.py"
    if not shootout_script.is_file():
        print(f"ERROR: engine_shootout.py not found at {shootout_script}", file=sys.stderr)
        sys.exit(1)

    print(f"\n{'='*60}")
    print(f"BATCH THREE-PASS — {args.character.upper()}")
    print(f"{'='*60}")
    print("  Pipeline:    Qwen MA → NBP (expression+body angles) | Qwen MA → SeedVR2 (profile/back)")
    print(f"  Project:     {args.project}")
    print(f"  Angles:      {len(angles)}")
    print(f"  Total jobs:  {len(jobs)} ({jobs_3pass} Qwen→NBP + {jobs_2pass} Qwen→SeedVR2)")
    print(f"  Est. cost:   ~${total_cost:.2f}")
    print(f"  Est. time:   ~{total_time_est:.0f} min")
    if args.lighting:
        print(f"  Lighting:    {args.lighting}")
    if env_rotation:
        print(f"  Env rotate:  ON ({len(environment_pool)} environments cycling)")
    print(f"{'='*60}")

    print("\n  Expression Distribution:")
    print(f"    Frontal+closeups:  {len(full_expressions)} expressions (full range)")
    print(f"    3/4 views:         {len(mild_expressions)} expressions (neutral + mild)")
    print(f"    Body/low/high:     {len(NEUTRAL_EXPRESSIONS)} expression (neutral, 3-pass w/ proportionality)")
    print(f"    Profile:           {len(NEUTRAL_EXPRESSIONS)} expression (neutral only)")
    if smart_back:
        print(f"    Back angles:       {len(NEUTRAL_EXPRESSIONS)} expression (neutral only)")
    else:
        # --no-smart-back routes back angles through the full-range fallback.
        print(f"    Back angles:       {len(full_expressions)} expressions (full range, --no-smart-back)")

    print("\n  Angles:")
    for angle in angles:
        if angle in EXPRESSION_ANGLES:
            tag = f" [EXPR: {len(full_expressions)} expressions, Qwen→NBP]"
        elif angle in MILD_EXPRESSION_ANGLES:
            tag = f" [MILD: {len(mild_expressions)} expressions, Qwen→NBP]"
        elif angle in BODY_EXPRESSION_ANGLES:
            tag = " [BODY: neutral, 3-pass w/ proportionality]"
        elif angle in NEUTRAL_ONLY_ANGLES:
            tag = " [NEUTRAL: 1 expression, 2-pass]"
        elif angle in BACK_ANGLES and smart_back:
            tag = " [BACK: neutral only, 2-pass]"
        else:
            tag = f" [FULL: {len(full_expressions)} expressions, Qwen→NBP]"
        print(f"    - {angle} ({ANGLE_LABELS.get(angle, angle)}){tag}")

    if env_rotation:
        print(f"\n  Environment Pool ({len(environment_pool)}):")
        for i, env in enumerate(environment_pool):
            print(f"    [{i+1}] {env[:70]}...")

    if args.dry_run:
        print(f"\n  DRY RUN — {len(jobs)} jobs planned, nothing generated.")
        print("\n  Jobs:")
        for i, (angle, expr, skip_p3, env_ov) in enumerate(jobs):
            label = ANGLE_LABELS.get(angle, angle)
            expr_short = expr.split("—")[0].strip() if "—" in expr else expr[:20]
            mode_tag = "SVR2" if skip_p3 else "NBP"
            # With rotation on, job i uses pool entry i % len(pool).
            env_tag = f"env{(i % len(environment_pool)) + 1}" if env_ov else "default"
            print(f"    [{i+1:2d}/{len(jobs)}] {label:6s} | {expr_short:12s} | {mode_tag} | {env_tag}")
        sys.exit(0)

    print("\n  Starting in 3 seconds...\n")
    time.sleep(3)

    # Run jobs
    succeeded = 0
    failed = 0
    skipped = 0
    est_spent = 0.0  # estimated cost of the jobs that actually succeeded
    start_time = time.time()

    for i, (angle, expr, skip_p3, env_override) in enumerate(jobs):
        # Budget check before each job
        if budget_cap:
            # Reload tracker to pick up costs logged by subprocess
            tracker = CostTracker(project_path)
            ok, remaining, spent = tracker.check_budget(budget_cap)
            if not ok:
                print(f"\n  BUDGET CAP HIT: ${spent:.2f} spent of ${budget_cap:.2f} cap (${remaining:.2f} remaining)")
                print(f"  Stopping batch. {len(jobs) - i} jobs remaining.")
                print("  Increase budget_cap_usd in project_config.json to continue.")
                skipped = len(jobs) - i
                break

        label = ANGLE_LABELS.get(angle, angle)
        expr_short = expr.split("—")[0].strip() if "—" in expr else expr[:20]
        mode_tag = "Qwen→SeedVR2" if skip_p3 else "Qwen→NBP"
        print(f"\n  [{i+1}/{len(jobs)}] {label} / {expr_short} ({mode_tag})")
        print(f"  {'─'*50}")

        cmd = [
            sys.executable, str(shootout_script),
            args.project,
            "--character", args.character,
            "--angle", angle,
            "--expression", expr,
            "--threepass",
        ]
        if skip_p3:
            cmd.append("--skip-pass3")
        if args.lighting:
            cmd.extend(["--lighting", args.lighting])
        if args.hero:
            cmd.extend(["--hero", args.hero])
        if env_override:
            cmd.extend(["--environment", env_override])
        if character_traits:
            cmd.extend(["--character-traits", character_traits])

        try:
            result = subprocess.run(
                cmd,
                capture_output=False,
                text=True,
                timeout=600,  # 10 min max per run
            )
            if result.returncode == 0:
                succeeded += 1
                est_spent += cost_2pass_job if skip_p3 else cost_3pass_job
            elif result.returncode == 2:
                # Exit code 2 = budget cap hit inside engine_shootout
                print("  ** BUDGET CAP — stopping batch **")
                skipped = len(jobs) - i - 1
                failed += 1
                break
            else:
                failed += 1
                print(f"  ** FAILED (exit code {result.returncode}) **")
        except subprocess.TimeoutExpired:
            failed += 1
            print("  ** TIMEOUT (10 min limit) **")
        except Exception as e:
            # Batch boundary: report and move on to the next job.
            failed += 1
            print(f"  ** ERROR: {str(e)[:80]} **")

        # Delay between runs to avoid rate limiting
        if i < len(jobs) - 1:
            time.sleep(args.delay)

    elapsed = time.time() - start_time

    # Report tracked cost from cost_log
    tracker = CostTracker(project_path)  # reload to get subprocess writes
    total_tracked = tracker.total()

    print(f"\n{'='*60}")
    print(f"BATCH COMPLETE — {args.character.upper()}")
    print(f"{'='*60}")
    print(f"  Succeeded:   {succeeded}/{len(jobs)}")
    print(f"  Failed:      {failed}/{len(jobs)}")
    if skipped:
        print(f"  Skipped:     {skipped}/{len(jobs)} (budget cap)")
    print(f"  Total time:  {elapsed/60:.1f} min")
    print(f"  Est. cost:   ~${est_spent:.2f}")
    print(f"  Tracked:     ${total_tracked:.2f} total project spend")
    if budget_cap:
        print(f"  Budget:      ${total_tracked:.2f} / ${budget_cap:.2f}")
    print(f"  Reviewer:    http://127.0.0.1:8420/shootout_reviewer.html?project={args.project.strip('/')}&character={args.character.upper()}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with
    # status 0 unless main() itself exits early via sys.exit().
    raise SystemExit(main())
