#!/usr/bin/env python3
"""
Pre-Production Gate Validator.

Confirms all upstream visual pipeline gates are green before committing
GPU resources to full production frame/video generation.

This is the visual pipeline equivalent of validate_pre_treatment.py —
it gates the expensive thing (generation) behind the cheap things
(grammar check, previz review, visual bible validation).

Usage:
  python3 validate_pre_production.py <project_path>
  python3 validate_pre_production.py <project_path> --json
  python3 validate_pre_production.py <project_path> --episode N
  python3 validate_pre_production.py <project_path> --prompt

Exit codes:
  0 = PASS — all gates green, safe to generate
  1 = FAIL — upstream gates not satisfied
  2 = file/parse error
"""

import argparse
import glob
import json
import os
import subprocess
import sys


def check_visual_bible(project_path):
    """Verify visual_bible.md exists and passes validate_visual_bible.py.

    The sibling validator is run as a subprocess with ``--json``. This gate
    treats its exit code as follows: 1 = hard errors (blocking), 2 = warnings
    (non-blocking), 0 = clean. Any other exit code (for example a negative
    code when the validator was killed by a signal) is reported as a warning
    instead of being silently treated as a pass.

    Args:
        project_path: Path to the project folder.

    Returns:
        (errors, warnings): lists of human-readable messages.
    """
    errors = []
    warnings = []

    vb_path = os.path.join(project_path, "visual_bible.md")
    if not os.path.exists(vb_path):
        errors.append("visual_bible.md not found — run /visual-design first")
        return errors, warnings

    # Resolve relative to this script's real location, not the caller's cwd.
    validator = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "validate_visual_bible.py"
    )
    if not os.path.exists(validator):
        warnings.append("validate_visual_bible.py not found — skipping visual bible check")
        return errors, warnings

    try:
        result = subprocess.run(
            [sys.executable, validator, vb_path, project_path, "--json"],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError, PermissionError, etc.; the check is
        # best-effort, so an unrunnable validator downgrades to a warning.
        warnings.append("Could not run visual bible validator")
        return errors, warnings

    if result.returncode == 1:
        errors.append(
            "visual_bible.md has hard errors — run: "
            f"python3 {validator} {vb_path} {project_path} --prompt"
        )
    elif result.returncode == 2:
        warnings.append("visual_bible.md has warnings (non-blocking)")
    elif result.returncode != 0:
        warnings.append(
            f"visual bible validator exited unexpectedly (code {result.returncode})"
        )

    return errors, warnings


def check_breakdown(project_path):
    """Verify visual/breakdown.json exists and is a usable JSON object.

    Args:
        project_path: Path to the project folder.

    Returns:
        (errors, warnings): a missing, unreadable, malformed, or non-object
        breakdown is an error; an object with no (or empty) ``episodes`` entry
        is only a warning.
    """
    errors = []
    warnings = []

    breakdown_path = os.path.join(project_path, "visual", "breakdown.json")
    if not os.path.exists(breakdown_path):
        errors.append("visual/breakdown.json not found — run /breakdown first")
        return errors, warnings

    try:
        with open(breakdown_path, encoding="utf-8") as f:
            breakdown = json.load(f)
    except OSError as exc:
        errors.append(f"breakdown.json could not be read: {exc}")
        return errors, warnings
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        errors.append("breakdown.json is not valid JSON")
        return errors, warnings

    # A list/scalar root would otherwise crash on .get() below.
    if not isinstance(breakdown, dict):
        errors.append("breakdown.json must contain a JSON object at the top level")
    elif not breakdown.get("episodes"):
        warnings.append("breakdown.json has no episodes — may be empty")

    return errors, warnings


def _storyboard_episode_number(sb_data):
    """Return the integer episode number of a parsed storyboard, or None.

    A missing ``episode`` key defaults to 0. Returns None when the storyboard
    is not a JSON object or its ``episode`` value cannot be converted to int.
    """
    if not isinstance(sb_data, dict):
        return None
    try:
        return int(sb_data.get("episode", 0))
    except (TypeError, ValueError):
        return None


def _run_storyboard_validator(validator, sb_file, ep_path, project_path):
    """Run validate_storyboard.py on one storyboard and summarize the result.

    Returns (errors, warnings) with messages prefixed by the storyboard file
    name. Exit code 0 passes. Exit code 1 always yields at least one error,
    even when the JSON report is unparseable or lists no errors. Any other
    exit code becomes a warning because this gate cannot interpret it.
    """
    sb_name = os.path.basename(sb_file)
    errors = []
    warnings = []

    try:
        result = subprocess.run(
            [
                sys.executable,
                validator,
                sb_file,
                ep_path,
                "--json",
                "--project",
                project_path,
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (subprocess.TimeoutExpired, OSError):
        warnings.append(f"{sb_name}: could not run validator")
        return errors, warnings

    if result.returncode == 0:
        return errors, warnings
    if result.returncode != 1:
        warnings.append(
            f"{sb_name}: validator exited unexpectedly (code {result.returncode})"
        )
        return errors, warnings

    try:
        report = json.loads(result.stdout)
    except json.JSONDecodeError:
        report = None
    reported = report.get("errors") if isinstance(report, dict) else None
    messages = [str(e) for e in reported] if isinstance(reported, list) else []

    grammar_count = sum("GRAMMAR:" in m for m in messages)
    structural_count = len(messages) - grammar_count
    if grammar_count:
        errors.append(f"{sb_name}: {grammar_count} grammar error(s)")
    if structural_count:
        errors.append(f"{sb_name}: {structural_count} structural error(s)")
    if not messages:
        # Exit code 1 means validation failed; never let the gate pass just
        # because the report was unparseable or listed no errors.
        errors.append(f"{sb_name}: validation failed (exit code 1)")

    return errors, warnings


def check_storyboard(project_path, episode=None):
    """Verify storyboard(s) exist and pass grammar validation.

    Args:
        project_path: Path to the project folder.
        episode: Episode number to check, or None to check every
            ``storyboards/storyboard_ep_*.json`` file. Episode 0 is a valid
            specific episode, not "all".

    Returns:
        (errors, warnings): lists of human-readable messages.
    """
    errors = []
    warnings = []

    sb_dir = os.path.join(project_path, "storyboards")
    if not os.path.isdir(sb_dir):
        errors.append("storyboards/ directory not found — run /storyboard first")
        return errors, warnings

    if episode is not None:
        # Exact path: no globbing needed.
        sb_path = os.path.join(sb_dir, f"storyboard_ep_{episode:03d}.json")
        if not os.path.isfile(sb_path):
            errors.append(f"Storyboard for episode {episode} not found: {sb_path}")
            return errors, warnings
        sb_files = [sb_path]
    else:
        # Escape the project-controlled prefix so characters such as "[" in
        # the project path are not interpreted as glob syntax.
        pattern = os.path.join(glob.escape(sb_dir), "storyboard_ep_*.json")
        sb_files = sorted(glob.glob(pattern))
        if not sb_files:
            errors.append("No storyboard files found in storyboards/")
            return errors, warnings

    validator = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "validate_storyboard.py"
    )
    have_validator = os.path.exists(validator)
    if not have_validator:
        # Surface the skipped grammar gate instead of passing silently.
        warnings.append(
            "validate_storyboard.py not found — skipping storyboard grammar check"
        )
    ep_dir = os.path.join(project_path, "episodes")

    for sb_file in sb_files:
        sb_name = os.path.basename(sb_file)

        try:
            with open(sb_file, encoding="utf-8") as f:
                sb_data = json.load(f)
        except OSError as exc:
            errors.append(f"{sb_name}: could not be read ({exc})")
            continue
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            errors.append(f"{sb_name}: not valid JSON")
            continue

        ep_num = _storyboard_episode_number(sb_data)
        if ep_num is None:
            errors.append(
                f"{sb_name}: no integer 'episode' number — cannot match episode script"
            )
            continue

        ep_path = os.path.join(ep_dir, f"ep_{ep_num:03d}.md")
        if not os.path.exists(ep_path):
            warnings.append(
                f"{sb_name}: matching episode script not found — skipping validation"
            )
            continue

        if have_validator:
            sb_errors, sb_warnings = _run_storyboard_validator(
                validator, sb_file, ep_path, project_path
            )
            errors.extend(sb_errors)
            warnings.extend(sb_warnings)

    return errors, warnings


def check_previz(project_path, episode=None):
    """Check if previz exists for the episode(s) being generated.

    Looks for previz_manifest.json in storyboards/assets/ep_NNN/previz/.
    This is a warning (not a hard gate) — previz is recommended before
    production but not required, so the returned error list is always empty.

    Args:
        project_path: Path to the project folder.
        episode: Episode number to check, or None to accept a manifest for
            any episode. Episode 0 is a valid specific episode, not "all".

    Returns:
        (errors, warnings): lists of human-readable messages.
    """
    errors = []
    warnings = []

    assets_dir = os.path.join(project_path, "storyboards", "assets")
    if not os.path.isdir(assets_dir):
        warnings.append(
            "No storyboards/assets/ directory — "
            "consider running previz before full generation"
        )
        return errors, warnings

    if episode is not None:
        manifest = os.path.join(
            assets_dir, f"ep_{episode:03d}", "previz", "previz_manifest.json"
        )
        if not os.path.isfile(manifest):
            warnings.append(
                f"No previz for episode {episode} — "
                "consider running: python3 generate_previz.py <project> --episode "
                f"{episode}"
            )
    else:
        # Escape the project-controlled prefix so characters such as "[" in
        # the project path are not interpreted as glob syntax.
        pattern = os.path.join(
            glob.escape(assets_dir), "ep_*", "previz", "previz_manifest.json"
        )
        if not glob.glob(pattern):
            warnings.append(
                "No previz found — consider running previz before full generation"
            )

    return errors, warnings


def check_lora(project_path):
    """Check if the LoRA registry exists and has trained models.

    LoRAs are advisory for this gate: every finding is a warning and the
    returned error list is always empty.

    Args:
        project_path: Path to the project folder.

    Returns:
        (errors, warnings). A character counts as trained only when its
        ``t2i.status`` is ``"completed"``. Anything else counts as pending:
        a missing or null ``t2i``, a non-object entry, or another status.
    """
    errors = []
    warnings = []

    registry_path = os.path.join(project_path, "visual", "lora_registry.json")
    if not os.path.exists(registry_path):
        warnings.append(
            "No LoRA registry — character consistency will depend on prompts only"
        )
        return errors, warnings

    try:
        with open(registry_path, encoding="utf-8") as f:
            registry = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        warnings.append("LoRA registry is not valid JSON")
        return errors, warnings

    # Guard against a non-object root or "characters" value, which would
    # otherwise crash on .get()/.items().
    characters = registry.get("characters", {}) if isinstance(registry, dict) else None
    if not isinstance(characters, dict):
        warnings.append(
            "LoRA registry has no 'characters' mapping — cannot check LoRA status"
        )
        return errors, warnings

    trained = []
    pending = []
    for name, entry in characters.items():
        t2i = entry.get("t2i") if isinstance(entry, dict) else None
        status = t2i.get("status") if isinstance(t2i, dict) else None
        if status == "completed":
            trained.append(name)
        else:
            pending.append(name)

    if not trained:
        warnings.append("No trained LoRAs — character identity may drift")
    if pending:
        warnings.append(f"Pending LoRA training: {', '.join(pending)}")

    return errors, warnings


def validate(project_path, episode=None):
    """Run every pre-production gate against a project.

    Gates run in a fixed order: Visual Bible, Breakdown, Storyboard, Previz,
    LoRA. A gate is marked "FAIL" when it reports at least one error;
    warnings never fail a gate.

    Args:
        project_path: Path to the project folder.
        episode: Episode number forwarded to the storyboard and previz gates.

    Returns:
        (is_valid, errors, warnings, stats), where ``is_valid`` is True when
        no gate produced an error and ``stats`` summarizes per-gate status
        and totals.
    """
    collected_errors = []
    collected_warnings = []
    gate_results = {}

    for gate_name, run_gate in (
        ("Visual Bible", check_visual_bible),
        ("Breakdown", check_breakdown),
        ("Storyboard", lambda path: check_storyboard(path, episode)),
        ("Previz", lambda path: check_previz(path, episode)),
        ("LoRA", check_lora),
    ):
        gate_errors, gate_warnings = run_gate(project_path)
        collected_errors += gate_errors
        collected_warnings += gate_warnings
        gate_results[gate_name] = "FAIL" if gate_errors else "PASS"

    is_valid = not collected_errors
    stats = {
        "gates": gate_results,
        "total_errors": len(collected_errors),
        "total_warnings": len(collected_warnings),
        "episode": episode,
    }
    return is_valid, collected_errors, collected_warnings, stats


def _print_prompt(is_valid, errors, warnings):
    """Print fix-oriented output for ``--prompt`` mode."""
    if is_valid:
        print("All pre-production gates PASS. Safe to generate.")
    else:
        print("PRE-PRODUCTION GATE FAILURES:")
        print()
        for e in errors:
            print(f"  ✗ {e}")
        print()
        print("FIX THESE BEFORE GENERATING:")
        # Collect only the applicable steps so the numbering never skips.
        fixes = []
        if any("visual_bible" in e for e in errors):
            fixes.append("Run /visual-design [project] to create/fix visual_bible.md")
        if any("breakdown" in e for e in errors):
            fixes.append("Run /breakdown [project] to create breakdown.json")
        if any("grammar" in e.lower() or "storyboard" in e.lower() for e in errors):
            fixes.append(
                "Fix storyboard grammar errors (run validate_storyboard.py --prompt)"
            )
        for number, fix in enumerate(fixes, start=1):
            print(f"  {number}. {fix}")
    # Warnings are shown on PASS too, so advisory issues are not dropped.
    if warnings:
        print()
        print("WARNINGS (non-blocking):")
        for w in warnings:
            print(f"  ! {w}")


def _print_report(project, episode, is_valid, errors, warnings, stats):
    """Print the default human-readable gate report."""
    print("=== Pre-Production Gate ===")
    # normpath strips a trailing separator so "proj/" still shows "proj".
    print(f"Project: {os.path.basename(os.path.normpath(project))}")
    if episode is not None:
        print(f"Episode: {episode}")
    print()

    for gate, status in stats["gates"].items():
        marker = "✓" if status == "PASS" else "✗"
        print(f"  {marker} {gate}: {status}")
    print()

    if errors:
        print(f"ERRORS ({len(errors)}):")
        for e in errors:
            print(f"  ✗ {e}")
        print()

    if warnings:
        print(f"WARNINGS ({len(warnings)}):")
        for w in warnings:
            print(f"  ! {w}")
        print()

    if is_valid:
        print("RESULT: PASS — safe to proceed with generation")
    else:
        print("RESULT: FAIL — fix errors before generating")


def main():
    """Command-line entry point.

    Exits with 0 when all gates pass, 1 when any gate has errors, and 2 when
    the project directory does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Pre-production gate — validates all upstream visual pipeline gates"
    )
    parser.add_argument("project", help="Path to project folder")
    parser.add_argument("--json", action="store_true", help="Machine-readable output")
    parser.add_argument(
        "--episode",
        type=int,
        help="Check gates for a specific episode",
    )
    parser.add_argument(
        "--prompt",
        action="store_true",
        help="Output fix instructions",
    )

    args = parser.parse_args()

    if not os.path.isdir(args.project):
        # stderr keeps stdout clean for --json consumers.
        print(f"ERROR: Project directory not found: {args.project}", file=sys.stderr)
        sys.exit(2)

    is_valid, errors, warnings, stats = validate(args.project, episode=args.episode)

    if args.json:
        output = {
            "is_valid": is_valid,
            "errors": errors,
            "warnings": warnings,
            "stats": stats,
        }
        print(json.dumps(output, indent=2))
    elif args.prompt:
        _print_prompt(is_valid, errors, warnings)
    else:
        _print_report(args.project, args.episode, is_valid, errors, warnings, stats)

    sys.exit(0 if is_valid else 1)


if __name__ == "__main__":
    # main() ends the process itself via sys.exit() (0 = pass, 1 = fail, 2 = missing project dir).
    main()
