#!/usr/bin/env python3
"""
prompt_doctor.py — Pre-Generation Prompt Analysis

Analyzes storyboard prompts BEFORE generation to catch problems that would
waste regeneration cycles (and money). Uses Gemini to evaluate prompts for:

  - Anatomical risk: hands, faces, complex poses likely to deform
  - Ambiguity: vague descriptions that produce unpredictable results
  - Complexity overload: too many elements for a single generation
  - LoRA compatibility: prompt patterns known to conflict with LoRA identity
  - Prompt adjustments: specific rewrites to reduce failure probability

The idea: cheaper to fix a prompt ($0.001) than regenerate a keyframe ($0.035+).

Usage:
    # Analyze a single shot
    python3 prompt_doctor.py check \\
        --storyboard storyboard_ep_001.json \\
        --shot-id 3

    # Deep-analyze an episode (by default only the shots the heuristic scan
    # flags as high risk; add --all-shots to analyze every shot)
    python3 prompt_doctor.py batch \\
        --project leviathan --episode 1

    # Generate a risk report (no API calls — heuristic analysis only)
    python3 prompt_doctor.py scan \\
        --project leviathan --episode 1

Exit codes: 0 = no high-risk findings, 1 = high/critical risk found, 2 = error

Requires:
    pip install google-generativeai
    export GOOGLE_API_KEY="your-key-here"
"""

import argparse
import json
import os
import re
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from cost_tracker import CostTracker
from recoil.core.model_profiles import get_model


# ── Configuration ─────────────────────────────────────────────────────────

# Prompt doctor uses Flash — text-only, structured analysis, high volume.
# No vision needed, just reasoning about prompt text.
# Resolved once at import time; run_prompt_analysis() falls back to this value
# when no model override is supplied.
# NOTE(review): presumably get_model() returns a model-name string — confirm.
DEFAULT_MODEL = get_model("flash", "text")

# Risk patterns (heuristic scan, no API needed).
#
# Each entry is a 4-tuple: (regex, pattern name, risk level, advice text).
# scan_prompt() lower-cases the prompt before matching, so every literal token
# in a pattern is written in lower case (an upper-case token could never match).
ANATOMICAL_RISK_PATTERNS = [
    (r"\bhand[s]?\b", "hands", "high",
     "Hands are the #1 failure mode in AI generation. Ensure hands are described "
     "with specific positions (e.g., 'left hand gripping pipe' not just 'hands visible')."),
    (r"\bfinger[s]?\b", "fingers", "high",
     "Explicit finger descriptions increase deformation risk. "
     "Prefer hand-level descriptions ('gripping', 'clenched') over finger-level."),
    (r"\bpointing\b", "pointing gesture", "medium",
     "Pointing gestures reliably produce extra/deformed fingers. "
     "Consider rephrasing to avoid explicit pointing."),
    (r"\bholding\b.*\bholding\b", "dual holding", "high",
     "Two 'holding' actions in one prompt = two hand-object interactions = high failure rate. "
     "Simplify to one held object or use a wider shot."),
    (r"\b(embrace|hugging|holding each other)\b", "physical contact", "high",
     "Character-to-character physical contact is extremely difficult for AI. "
     "Limbs merge, extra arms appear. Consider framing to minimize contact area."),
    (r"\b(running|jumping|leaping|fighting)\b", "dynamic action", "medium",
     "Dynamic full-body poses have higher deformation rates than static poses. "
     "Ensure the shot type (MS/LS) gives enough frame for the action."),
    (r"\b(from behind|over.shoulder)\b", "rear angle", "low",
     "Back/OTS angles generally reduce anatomical risk (face/hands less visible)."),
    # "ecu" is lower case on purpose: the scanned text is lower-cased first.
    (r"\b(close.up|ecu|extreme close)\b.*\bface\b", "face close-up", "medium",
     "Extreme facial close-ups amplify any facial asymmetry or uncanny valley effects. "
     "Ensure the LoRA training data includes this angle."),
    (r"\b(two|both|multiple)\b.*\b(character|person|people|figure)\b", "multi-character", "high",
     "Multi-character compositions are the second-hardest generation task. "
     "Identity blending, proportion mismatch, and limb confusion are common. "
     "Verify dual LoRA scales are set correctly."),
    (r"\b(reflection|mirror|glass)\b", "reflection", "high",
     "Reflective surfaces produce doubled/distorted character appearances. "
     "Consider removing reflection descriptions or compositing in post."),
    (r"\b(text|writing|sign|label|display|readout)\b", "text rendering", "medium",
     "AI models struggle with legible text. If the text content matters, "
     "plan for post-compositing. If decorative, add 'illegible' qualifier."),
]

# Prompt length thresholds (word counts, as measured by str.split()).
PROMPT_LENGTH_WARNING = 200  # words — longer prompts dilute attention
PROMPT_LENGTH_CRITICAL = 300  # words — model likely ignores late elements


# ── Heuristic Scanner (no API) ────────────────────────────────────────────

def scan_prompt(prompt_text: str, shot: dict) -> List[dict]:
    """Scan one prompt for known risk patterns using regex heuristics.

    No API call is made — the analysis is instant and free.

    Args:
        prompt_text: The prompt text to scan.
        shot: The storyboard shot the prompt belongs to. ``shot_type``,
            ``generation_approach`` and ``triptych_prompt`` are read from it.

    Returns:
        A list of finding dicts. Every finding has ``type``, ``pattern``,
        ``risk`` ("high", "medium" or "low") and ``advice``; prompt-length
        findings additionally carry ``word_count``.
    """
    findings = []
    text_lower = prompt_text.lower()

    # Pattern matching. The text is lower-cased above, so match
    # case-insensitively: otherwise a pattern written with upper-case tokens
    # (such as a shot-type abbreviation) could never match.
    for pattern, name, risk, advice in ANATOMICAL_RISK_PATTERNS:
        if re.search(pattern, text_lower, re.IGNORECASE):
            findings.append({
                "type": "anatomical_risk",
                "pattern": name,
                "risk": risk,
                "advice": advice,
            })

    # Prompt length check (whitespace-separated word count)
    word_count = len(prompt_text.split())
    if word_count > PROMPT_LENGTH_CRITICAL:
        findings.append({
            "type": "complexity",
            "pattern": "prompt_too_long",
            "risk": "high",
            "advice": f"Prompt is {word_count} words — model attention degrades "
                      f"significantly past ~200 words. Elements described late in "
                      f"the prompt may be ignored. Consider splitting into layers "
                      f"or reducing descriptive density.",
            "word_count": word_count,
        })
    elif word_count > PROMPT_LENGTH_WARNING:
        findings.append({
            "type": "complexity",
            "pattern": "prompt_long",
            "risk": "medium",
            "advice": f"Prompt is {word_count} words — approaching attention limit. "
                      f"Prioritize the most important visual elements early in the prompt.",
            "word_count": word_count,
        })

    # Shot type vs content mismatch
    shot_type = shot.get("shot_type") or ""
    if shot_type in ("ECU", "CU") and re.search(r"\b(full.body|legs|feet|walking)\b", text_lower):
        findings.append({
            "type": "framing_mismatch",
            "pattern": "wide_elements_in_close_shot",
            "risk": "medium",
            "advice": f"Shot is {shot_type} but prompt describes full-body elements. "
                      f"These won't be visible in frame and may confuse the model. "
                      f"Remove references to body parts below the framing cutoff.",
        })

    if shot_type in ("LS", "WIDE") and re.search(r"\b(pore|eyelash|freckle|wrinkle)\b", text_lower):
        findings.append({
            "type": "framing_mismatch",
            "pattern": "micro_details_in_wide_shot",
            "risk": "low",
            "advice": f"Shot is {shot_type} but prompt describes micro-level facial "
                      f"details that won't be visible. Remove to reduce noise.",
        })

    # Generation approach considerations. This looks at the shot's triptych
    # prompt rather than prompt_text, so it is a shot-level check.
    approach = shot.get("generation_approach", "")
    if approach == "triptych_split_flf":
        triptych = shot.get("triptych_prompt", "")
        if triptych and len(triptych.split()) > 250:
            findings.append({
                "type": "complexity",
                "pattern": "triptych_overloaded",
                "risk": "medium",
                "advice": "Triptych prompt is very dense. Each panel shares attention "
                          "with the others. Simplify per-panel descriptions for cleaner "
                          "panel separation.",
            })

    # Check for conflicting descriptors
    if re.search(r"\bbright\b", text_lower) and re.search(r"\bdark\b", text_lower):
        findings.append({
            "type": "contradiction",
            "pattern": "bright_and_dark",
            "risk": "low",
            "advice": "Prompt contains both 'bright' and 'dark' — ensure these "
                      "apply to different elements (e.g., bright accent in dark scene). "
                      "If contradictory, the model will average them.",
        })

    return findings


def scan_storyboard(storyboard: dict) -> dict:
    """Scan every prompt field of every shot in a storyboard for prompt risks.

    For each shot the ``first_frame``, ``last_frame``, ``triptych_prompt`` and
    ``hero_frame`` prompts are scanned (empty or missing fields are skipped)
    and each finding is tagged with the field it came from under ``source``.

    Args:
        storyboard: Parsed storyboard dict; shots are read from ``shots``.

    Returns:
        A report dict with ``total_shots``, ``shots_with_findings``,
        ``risk_counts`` (number of findings per risk level) and ``shots``,
        which maps ``str(shot id)`` to that shot's name, type, generation
        approach and findings. Shots without findings are omitted.
    """
    prompt_fields = ("first_frame", "last_frame", "triptych_prompt", "hero_frame")

    shots = storyboard.get("shots") or []
    all_findings = {}
    risk_counts = {"high": 0, "medium": 0, "low": 0}

    for position, shot in enumerate(shots, start=1):
        # Fall back to the 1-based position when a shot has no id, so that
        # id-less shots do not all collide under the single key "None".
        shot_id = shot.get("id")
        if shot_id is None:
            shot_id = position

        shot_findings = []
        for field in prompt_fields:
            prompt_text = shot.get(field, "")
            if not prompt_text:
                continue
            for finding in scan_prompt(prompt_text, shot):
                finding["source"] = field
                shot_findings.append(finding)

        if not shot_findings:
            continue

        # Drop exact repeats of a pattern within the same source. The same
        # pattern found in different prompt fields is kept once per field, so
        # the report shows which prompt needs fixing.
        seen = set()
        deduped = []
        for finding in shot_findings:
            key = (finding["pattern"], finding.get("source", ""))
            if key not in seen:
                seen.add(key)
                deduped.append(finding)

        all_findings[str(shot_id)] = {
            "shot_name": shot.get("name", ""),
            "shot_type": shot.get("shot_type", ""),
            "generation_approach": shot.get("generation_approach", ""),
            "findings": deduped,
        }

        for finding in deduped:
            risk = finding.get("risk", "low")
            if risk in risk_counts:
                risk_counts[risk] += 1

    return {
        "total_shots": len(shots),
        "shots_with_findings": len(all_findings),
        "risk_counts": risk_counts,
        "shots": all_findings,
    }


# ── Gemini Deep Analysis ─────────────────────────────────────────────────

# Template for the Gemini deep-analysis request. run_prompt_analysis() fills it
# with str.format(), supplying: shot_type, camera_angle, generation_approach,
# characters, emotion, first_frame, last_frame and triptych_section.
# The doubled braces ({{ and }}) are str.format() escapes that render as the
# literal braces of the JSON schema the model is asked to return.
PROMPT_DOCTOR_PROMPT = """You are a prompt engineer specializing in AI image generation (Flux 2, Z-Image, SDXL family).

Analyze this generation prompt for a {shot_type} shot and identify issues that would cause generation failures.

SHOT CONTEXT:
- Shot type: {shot_type}
- Camera angle: {camera_angle}
- Generation approach: {generation_approach}
- Characters in shot: {characters}
- Emotion: {emotion}

FIRST FRAME PROMPT:
{first_frame}

LAST FRAME PROMPT:
{last_frame}

{triptych_section}

Evaluate for:
1. ANATOMICAL RISK: Hands, faces, body poses likely to deform. Rate 1-10 (10=safe, 1=will definitely fail).
2. COMPLEXITY: Too many elements competing for attention? Rate 1-10.
3. CLARITY: Is each element described specifically enough to produce consistent results? Rate 1-10.
4. NEGATIVE PROMPT GAPS: What should be in the negative prompt to prevent common failures?
5. PROMPT REWRITES: For any element scoring below 6, provide a specific rewrite that reduces risk.

Return ONLY valid JSON:
{{
  "anatomical_risk": <int 1-10>,
  "complexity": <int 1-10>,
  "clarity": <int 1-10>,
  "overall_risk": "low|medium|high|critical",
  "negative_prompt_additions": ["<suggested negative prompt terms>"],
  "rewrites": [
    {{
      "target": "first_frame|last_frame|triptych_prompt",
      "original_phrase": "<exact phrase to change>",
      "suggested_phrase": "<safer alternative>",
      "reason": "<why this reduces failure rate>"
    }}
  ],
  "warnings": ["<any other concerns>"],
  "notes": "<brief assessment>"
}}

Be practical. Don't flag things that work fine in practice. Focus on patterns you KNOW cause failures in diffusion models."""


def _strip_code_fences(text: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```json) from *text*."""
    text = text.strip()
    if not text.startswith("```"):
        return text

    newline = text.find("\n")
    if newline == -1:
        # Single-line fence: drop only the three opening backticks.
        text = text[3:]
    else:
        # Drop the whole opening fence line (it may carry a language tag).
        text = text[newline + 1:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _extract_token_usage(response) -> dict:
    """Read input/output token counts from a Gemini response, best effort."""
    usage = {"tokens_in": 0, "tokens_out": 0}
    try:
        meta = response.usage_metadata
        if meta:
            usage["tokens_in"] = getattr(meta, "prompt_token_count", 0) or 0
            usage["tokens_out"] = getattr(meta, "candidates_token_count", 0) or 0
    except Exception:
        # Usage metadata could not be read; fall back to fixed placeholder
        # counts so the call is still recorded with a non-zero size.
        usage["tokens_in"] = 500
        usage["tokens_out"] = 300
    return usage


def run_prompt_analysis(shot: dict, storyboard: dict,
                        tracker: Optional[CostTracker] = None,
                        episode: Optional[int] = None,
                        model_override: Optional[str] = None) -> dict:
    """Run Gemini deep analysis on a shot's prompts.

    Args:
        shot: Storyboard shot dict whose prompts are analyzed.
        storyboard: The storyboard the shot belongs to. Currently unused;
            kept so existing callers keep working.
        tracker: Optional cost tracker; when given, one entry is logged per
            call (successful or failed).
        episode: Episode number forwarded to the tracker.
        model_override: Model name to use instead of ``DEFAULT_MODEL``.

    Returns:
        On success, a dict with ``shot_id``, ``shot_name``, ``shot_type``,
        ``generation_approach``, ``analysis`` (the parsed JSON object returned
        by the model) and ``model``. On failure, a dict with ``shot_id``,
        ``result`` set to ``"error"`` and ``error`` (the exception text).

    Exits the process with status 2 when ``google-generativeai`` is not
    installed or ``GOOGLE_API_KEY`` is not set.
    """
    model_name = model_override or DEFAULT_MODEL

    try:
        import google.generativeai as genai
    except ImportError:
        print("ERROR: google-generativeai not installed", file=sys.stderr)
        sys.exit(2)

    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        sys.exit(2)
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)

    # Optional prompt sections, only included when the shot defines them.
    extra_sections = []
    if shot.get("triptych_prompt"):
        extra_sections.append(f"TRIPTYCH PROMPT:\n{shot['triptych_prompt']}")
    if shot.get("hero_frame"):
        extra_sections.append(f"HERO FRAME:\n{shot['hero_frame']}")
    triptych_section = "\n\n".join(extra_sections)

    chars_in_shot = shot.get("characters_in_shot") or []

    prompt = PROMPT_DOCTOR_PROMPT.format(
        shot_type=shot.get("shot_type", ""),
        camera_angle=shot.get("camera_angle", "eye"),
        generation_approach=shot.get("generation_approach", "standard_flf"),
        # str() guards against non-string entries, which would break join().
        characters=", ".join(str(c) for c in chars_in_shot) if chars_in_shot else "none",
        emotion=shot.get("emotion", ""),
        # "or" (not a .get default) so empty-string prompts also show "(empty)".
        first_frame=shot.get("first_frame") or "(empty)",
        last_frame=shot.get("last_frame") or "(empty)",
        triptych_section=triptych_section or "(no triptych)",
    )

    shot_id = shot.get("id")
    usage = {"tokens_in": 0, "tokens_out": 0}
    t0 = time.time()

    try:
        response = model.generate_content(prompt)
        # Capture usage before parsing so a malformed reply is still costed.
        usage = _extract_token_usage(response)

        result = json.loads(_strip_code_fences(response.text))
        if not isinstance(result, dict):
            # Callers read the analysis with .get(); anything else is unusable.
            raise ValueError(
                f"expected a JSON object, got {type(result).__name__}"
            )

    except Exception as e:
        elapsed_ms = int((time.time() - t0) * 1000)
        if tracker:
            tracker.log(
                category="qc", provider="gemini", model=model_name,
                tokens_in=usage["tokens_in"], tokens_out=usage["tokens_out"],
                duration_ms=elapsed_ms,
                episode=episode, shot_id=shot_id,
                detail=f"Prompt doctor — error: {str(e)[:100]}",
                success=False,
            )
        return {
            "shot_id": shot_id,
            "result": "error",
            "error": str(e),
        }

    elapsed_ms = int((time.time() - t0) * 1000)

    if tracker:
        tracker.log(
            category="qc", provider="gemini", model=model_name,
            tokens_in=usage["tokens_in"], tokens_out=usage["tokens_out"],
            duration_ms=elapsed_ms,
            episode=episode, shot_id=shot_id,
            detail=f"Prompt doctor — {result.get('overall_risk', '?')} risk",
            success=True,
        )

    return {
        "shot_id": shot_id,
        "shot_name": shot.get("name", ""),
        "shot_type": shot.get("shot_type", ""),
        "generation_approach": shot.get("generation_approach", ""),
        "analysis": result,
        "model": model_name,
    }


# ── Commands ──────────────────────────────────────────────────────────────

def find_project_root() -> Path:
    """Locate the Recoil project root by searching upward from this file.

    The directory containing this file and then its ancestors are examined
    (at most ten directories in total); the first one that holds both a
    ``tools`` and an ``editors`` sub-directory is returned. If none
    qualifies, an error is printed to stderr and the process exits with
    status 2.
    """
    start = Path(__file__).resolve().parent
    search_path = [start, *start.parents][:10]

    for directory in search_path:
        has_tools = (directory / "tools").is_dir()
        if has_tools and (directory / "editors").is_dir():
            return directory

    print("ERROR: Could not locate project root.", file=sys.stderr)
    sys.exit(2)


def load_storyboard(path: Path) -> dict:
    """Load and parse a storyboard JSON file.

    Args:
        path: Location of the storyboard file.

    Returns:
        The parsed JSON content.

    Prints an error to stderr and exits the process with status 2 when the
    file does not exist or does not contain valid UTF-8 encoded JSON.
    """
    if not path.is_file():
        print(f"ERROR: Storyboard not found: {path}", file=sys.stderr)
        sys.exit(2)
    try:
        # Decode explicitly as UTF-8: the platform default encoding would
        # corrupt or reject non-ASCII prompt text on some systems.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"ERROR: Invalid JSON: {e}", file=sys.stderr)
        sys.exit(2)


def cmd_scan(args):
    """Handle the ``scan`` sub-command: regex-only analysis, no API calls.

    Loads the episode's storyboard from the project directory, prints the
    full report as JSON on stdout and a human-readable summary on stderr.
    Returns 1 when at least one high-risk finding exists, otherwise 0.
    """
    def note(line):
        # Summary lines go to stderr so stdout stays pure JSON.
        print(line, file=sys.stderr)

    root = find_project_root()
    episode_tag = str(args.episode).zfill(3)
    storyboard_file = (
        root / args.project / "storyboards" / f"storyboard_ep_{episode_tag}.json"
    )

    report = scan_storyboard(load_storyboard(storyboard_file))
    report.update({
        "episode": args.episode,
        "project": args.project,
        "mode": "heuristic_scan",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })

    print(json.dumps(report, indent=2))

    # Summary
    counts = report["risk_counts"]
    note(f"\nEpisode {args.episode} Prompt Scan")
    note(f"  {report['total_shots']} shots scanned")
    note(f"  {report['shots_with_findings']} shots with findings")
    note(f"  High risk: {counts['high']}  Medium: {counts['medium']}  Low: {counts['low']}")

    if report["shots_with_findings"] > 0:
        note("\nShots with findings:")
        for shot_key, entry in report["shots"].items():
            levels = [finding["risk"] for finding in entry["findings"]]
            if "high" in levels:
                worst = "HIGH"
            elif "medium" in levels:
                worst = "MED"
            else:
                worst = "LOW"
            note(f"  Shot {shot_key} ({entry['shot_name']}): "
                 f"{len(entry['findings'])} findings [{worst}]")

    if counts["high"] > 0:
        return 1
    return 0


def cmd_check(args):
    """Handle the ``check`` sub-command: deep Gemini analysis of one shot.

    Prints a JSON document (heuristic findings plus the Gemini analysis) on
    stdout and a short summary on stderr.

    Returns:
        2 when the shot is not found or the Gemini analysis failed,
        1 when the overall risk is "high" or "critical", otherwise 0.
    """
    storyboard = load_storyboard(Path(args.storyboard))
    shot = next(
        (s for s in storyboard.get("shots", []) if s.get("id") == args.shot_id),
        None,
    )

    if not shot:
        print(f"ERROR: Shot {args.shot_id} not found in storyboard", file=sys.stderr)
        return 2

    # Find tracker: the nearest ancestor directory of the storyboard that
    # contains both a "visual" directory and a "treatment.md" file.
    tracker = None
    sb_path = Path(args.storyboard).resolve()
    for parent in sb_path.parents:
        if (parent / "visual").is_dir() and (parent / "treatment.md").is_file():
            tracker = CostTracker(parent)
            break

    # Run heuristic scan first (free)
    first_frame = shot.get("first_frame", "")
    heuristic = scan_prompt(first_frame, shot) if first_frame else []

    # Run Gemini analysis
    print(f"Analyzing shot {args.shot_id} ({shot.get('name', '')})...",
          file=sys.stderr)

    analysis = run_prompt_analysis(
        shot, storyboard,
        tracker=tracker,
        episode=args.episode,
        model_override=args.model,
    )

    failed = analysis.get("result") == "error"
    details = analysis.get("analysis")
    if not isinstance(details, dict):
        details = {}

    # Merge heuristic + Gemini findings
    output = {
        "shot_id": args.shot_id,
        "shot_name": shot.get("name", ""),
        "heuristic_findings": heuristic,
        "gemini_analysis": details,
        "model": analysis.get("model", ""),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if failed:
        output["error"] = analysis.get("error", "unknown error")

    print(json.dumps(output, indent=2))

    if failed:
        # A failed analysis carries no risk data; reporting it as "low risk"
        # with exit code 0 would hide the failure from callers.
        print(f"ERROR: Prompt analysis failed: {output['error']}", file=sys.stderr)
        return 2

    # Normalize so a non-string or differently-cased value cannot crash or
    # slip past the high/critical check below.
    risk = str(details.get("overall_risk") or "low").lower()
    print(f"\nRisk level: {risk.upper()}", file=sys.stderr)

    rewrites = details.get("rewrites") or []
    if rewrites:
        print(f"  {len(rewrites)} prompt rewrite(s) suggested", file=sys.stderr)

    return 1 if risk in ("high", "critical") else 0


def cmd_batch(args):
    """Handle the ``batch`` sub-command: deep Gemini analysis for an episode.

    Runs the free heuristic scan first; unless ``--all-shots`` is given, only
    shots with at least one high-risk heuristic finding are sent to Gemini.
    The combined report is saved under ``storyboards/reviews`` in the project
    directory and printed as JSON on stdout; a summary goes to stderr.

    Returns:
        2 when the storyboard has no shots,
        1 when any analyzed shot is rated "high" or "critical",
        2 when no such shot exists but at least one analysis failed,
        otherwise 0.
    """
    project_root = find_project_root()
    project_dir = project_root / args.project
    ep_str = str(args.episode).zfill(3)

    sb_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"
    storyboard = load_storyboard(sb_path)
    shots = storyboard.get("shots", [])

    if not shots:
        print("ERROR: No shots", file=sys.stderr)
        return 2

    tracker = CostTracker(project_dir)

    # Run heuristic scan first (free, identifies which shots need deep analysis).
    # The report is keyed by str(shot id); keep the keys as strings so ids that
    # are not numeric cannot crash an int() conversion.
    heuristic_report = scan_storyboard(storyboard)
    high_risk_shots = {
        shot_key
        for shot_key, data in heuristic_report.get("shots", {}).items()
        if any(f["risk"] == "high" for f in data["findings"])
    }

    # Deep analysis: all shots if --all-shots, otherwise only high-risk shots
    analyze_all = args.all_shots
    results = {}
    risk_summary = {"low": 0, "medium": 0, "high": 0, "critical": 0}
    error_count = 0

    for i, shot in enumerate(shots):
        shot_id = shot.get("id", i + 1)

        if not analyze_all and str(shot_id) not in high_risk_shots:
            continue

        print(f"  Shot {shot_id} ({shot.get('name', '')}): ",
              end="", file=sys.stderr, flush=True)

        analysis = run_prompt_analysis(
            shot, storyboard,
            tracker=tracker,
            episode=args.episode,
            model_override=args.model,
        )
        results[str(shot_id)] = analysis

        if analysis.get("result") == "error":
            # A failed call has no risk data; count it separately instead of
            # letting it default to "low".
            error_count += 1
            print(f"ERROR ({analysis.get('error', 'unknown error')})",
                  file=sys.stderr)
        else:
            details = analysis.get("analysis")
            if not isinstance(details, dict):
                details = {}
            risk = str(details.get("overall_risk") or "low").lower()
            rewrites = len(details.get("rewrites") or [])
            print(f"{risk.upper()} ({rewrites} rewrites)", file=sys.stderr)
            if risk in risk_summary:
                risk_summary[risk] += 1

        # Rate limit
        time.sleep(2)

    output = {
        "episode": args.episode,
        "project": args.project,
        "mode": "deep_analysis" if analyze_all else "targeted_analysis",
        "heuristic_summary": {
            "total_shots": heuristic_report["total_shots"],
            "shots_with_findings": heuristic_report["shots_with_findings"],
            "risk_counts": heuristic_report["risk_counts"],
        },
        "deep_analysis": {
            "shots_analyzed": len(results),
            "errors": error_count,
            "risk_summary": risk_summary,
            "shots": results,
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    # Save report
    reviews_dir = project_dir / "storyboards" / "reviews"
    reviews_dir.mkdir(parents=True, exist_ok=True)
    output_path = reviews_dir / f"prompt_doctor_ep_{ep_str}.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(json.dumps(output, indent=2))

    analyzed = len(results)
    print(f"\nEpisode {args.episode} Prompt Doctor", file=sys.stderr)
    print(f"  Heuristic scan: {heuristic_report['shots_with_findings']}"
          f"/{heuristic_report['total_shots']} shots flagged", file=sys.stderr)
    print(f"  Deep analysis: {analyzed} shots", file=sys.stderr)
    print(f"  Risk: {risk_summary}", file=sys.stderr)
    if error_count:
        print(f"  Errors: {error_count} shot(s) could not be analyzed",
              file=sys.stderr)
    print(f"Saved to: {output_path}", file=sys.stderr)

    if risk_summary["high"] + risk_summary["critical"] > 0:
        return 1
    return 2 if error_count else 0


# ── CLI ───────────────────────────────────────────────────────────────────

def build_parser():
    """Assemble the command-line interface.

    The returned parser stores the chosen sub-command in ``command`` and
    offers three modes: ``scan`` (heuristics only), ``check`` (deep analysis
    of one shot) and ``batch`` (episode-wide analysis). The module docstring
    is shown verbatim as the help epilog.
    """
    root = argparse.ArgumentParser(
        description="Prompt Doctor — pre-generation prompt analysis and risk assessment",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    modes = root.add_subparsers(dest="command", help="Analysis mode")

    def add_episode_target(cmd):
        # Options shared by the project-based modes (scan and batch).
        cmd.add_argument("--project", required=True, help="Project name")
        cmd.add_argument("--episode", type=int, required=True, help="Episode number")

    # scan: heuristic only, no API
    scan_cmd = modes.add_parser("scan", help="Heuristic prompt scan (free, instant)")
    add_episode_target(scan_cmd)

    # check: single shot deep analysis
    check_cmd = modes.add_parser("check", help="Deep analysis of a single shot")
    check_options = (
        ("--storyboard", {"required": True, "help": "Storyboard JSON path"}),
        ("--shot-id", {"type": int, "required": True, "help": "Shot ID"}),
        ("--episode", {"type": int, "help": "Episode number"}),
        ("--model", {"help": "Override Gemini model"}),
    )
    for flag, options in check_options:
        check_cmd.add_argument(flag, **options)

    # batch: episode-wide analysis
    batch_cmd = modes.add_parser("batch", help="Analyze all shots in an episode")
    add_episode_target(batch_cmd)
    batch_cmd.add_argument("--model", help="Override Gemini model")
    batch_cmd.add_argument(
        "--all-shots",
        action="store_true",
        help="Analyze all shots (default: only heuristic high-risk)",
    )

    return root


def main():
    """Parse the command line and run the selected sub-command.

    Returns the handler's exit code. With no sub-command the help text is
    printed and 0 is returned; an unrecognized sub-command prints the help
    text and returns 2.
    """
    parser = build_parser()
    args = parser.parse_args()

    chosen = args.command
    if not chosen:
        parser.print_help()
        return 0

    dispatch = {
        "scan": cmd_scan,
        "check": cmd_check,
        "batch": cmd_batch,
    }
    if chosen not in dispatch:
        parser.print_help()
        return 2

    run = dispatch[chosen]
    return run(args)


# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())
