#!/usr/bin/env python3
"""
visual_qc.py — Visual Quality Control Feedback Loop

Reviews AI-generated images against reference sheets, descriptions, and prompts
using Claude or Gemini vision APIs. Returns structured scoring with pass/fail
per image and specific fix notes.

Pipeline integration points:
  1. Reference Workshop — review character ref sheets for consistency
  2. Storyboard Editor — review generated panels against locked refs
  3. Video Generation — review sampled frames for identity drift

Usage:
    # Review a single character reference set (up to 5 angles; the first image is the front anchor)
    python3 visual_qc.py ref-check \\
        --images front.png profile.png three_quarter.png full_body.png back.png \\
        --description "Calloused hands, amber debt counter on left wrist..." \\
        --prompt "Young woman salvager, patched cargo pants..." \\
        --character JINX --variant lower_deck_salvager

    # Review a storyboard panel against locked references
    python3 visual_qc.py panel-check \\
        --panel panel_shot_03.png \\
        --refs front.png profile.png three_quarter.png full_body.png back.png \\
        --shot-description "MCU of Jinx crouching in maintenance shaft" \\
        --character JINX

    # Review video frames for identity drift
    python3 visual_qc.py video-check \\
        --frames frame_000.png frame_030.png frame_060.png frame_090.png \\
        --refs front.png profile.png three_quarter.png full_body.png back.png \\
        --panel storyboard_panel.png \\
        --character JINX

    # Batch review from breakdown.json
    python3 visual_qc.py batch-check \\
        --breakdown /leviathan/visual/breakdown.json \\
        --refs-dir /leviathan/visual/refs/ \\
        --character JINX

    # Use Gemini instead of Claude (for large batches). Global options such as
    # --model and --project are placed before the subcommand.
    python3 visual_qc.py --model gemini ref-check ...

Exit codes: 0 = all pass, 1 = some failures, 2 = error
"""

import argparse
import base64
import json
import os
import sys
import time
from pathlib import Path
from datetime import datetime, timezone
from typing import List

from cost_tracker import CostTracker


# ── API Configuration ─────────────────────────────────────────────────────

def get_anthropic_client():
    """Build an Anthropic API client from the ANTHROPIC_API_KEY environment variable.

    Terminates the process with exit status 2 when the ``anthropic`` package
    cannot be imported or when the API key is unset or empty.
    """
    try:
        import anthropic
        key = os.environ.get("ANTHROPIC_API_KEY")
        if key:
            return anthropic.Anthropic(api_key=key)
        print("ERROR: ANTHROPIC_API_KEY not set", file=sys.stderr)
        sys.exit(2)
    except ImportError:
        print("ERROR: anthropic package not installed. Run: pip install anthropic", file=sys.stderr)
        sys.exit(2)


def get_gemini_client():
    """Configure the Gemini SDK from GOOGLE_API_KEY and return a model handle.

    The handle is a ``genai.GenerativeModel`` bound to ``gemini-2.0-flash``.
    Terminates the process with exit status 2 when ``google.generativeai``
    cannot be imported or when the API key is unset or empty.
    """
    try:
        import google.generativeai as genai
        key = os.environ.get("GOOGLE_API_KEY")
        if key:
            genai.configure(api_key=key)
            return genai.GenerativeModel("gemini-2.0-flash")
        print("ERROR: GOOGLE_API_KEY not set", file=sys.stderr)
        sys.exit(2)
    except ImportError:
        print("ERROR: google-generativeai package not installed. Run: pip install google-generativeai", file=sys.stderr)
        sys.exit(2)


# ── Image Encoding ────────────────────────────────────────────────────────

def encode_image_base64(path: str) -> str:
    """Return the raw bytes of the file at *path* as a standard base64 ASCII string."""
    raw = Path(path).read_bytes()
    encoded = base64.standard_b64encode(raw)
    return encoded.decode("utf-8")


def get_media_type(path: str) -> str:
    """Map a file's extension (case-insensitive) to an image MIME type.

    Recognizes .png, .jpg/.jpeg, .webp and .gif; anything else (including no
    extension) falls back to ``image/png``.
    """
    suffix = Path(path).suffix.lower()
    if suffix in (".jpg", ".jpeg"):
        return "image/jpeg"
    if suffix in (".png", ".webp", ".gif"):
        # For these the MIME subtype is simply the extension without the dot.
        return "image/" + suffix[1:]
    return "image/png"


# ── Prompt Templates ──────────────────────────────────────────────────────

# Prompt for reviewing a character reference set (used by ref-check and
# batch-check). The FRONT image is the anchor every other angle is compared to.
# Filled with str.format(character=..., variant=..., description=..., prompt=...,
# angle_labels=...), which is why the literal JSON braces are doubled ({{ }}).
# NOTE(review): the angle_accuracy schema always lists all five angles, even
# when fewer images are supplied.
REF_CHECK_PROMPT = """You are a visual QC reviewer for an AI video production pipeline. Review these character reference images for quality and consistency.

The FRONT image is the ANCHOR — it is the locked visual ground truth. Every other angle must be identical to the front in all respects except camera position.

CHARACTER: {character}
WARDROBE VARIANT: {variant}
VISUAL DESCRIPTION: {description}
GENERATION PROMPT: {prompt}

The images are labeled: {angle_labels}

Evaluate each image against the FRONT anchor. Return a JSON object with this exact structure:

{{
  "overall_pass": true/false,
  "identity_consistency": {{
    "score": 1-10,
    "notes": "Are all images unmistakably the same person? Same face shape, skin tone, bone structure, age, build. Any drift = fail."
  }},
  "wardrobe_consistency": {{
    "score": 1-10,
    "notes": "Exact same clothing across all angles? Same garment type, cut, fit, damage/wear state, layering, fasteners, pockets. Any deviation from the front anchor = fail."
  }},
  "color_consistency": {{
    "score": 1-10,
    "notes": "Identical colors across all angles? Fabric color, material sheen, metal tones, leather/rubber hue. Compare each angle to the front anchor — any color shift = fail."
  }},
  "hair_makeup_consistency": {{
    "score": 1-10,
    "notes": "Exact same hair color, length, style, parting, texture? Same makeup, scars, tattoos, facial hair? Compare each angle to the front anchor — any change = fail."
  }},
  "props_accessories_consistency": {{
    "score": 1-10,
    "notes": "Same accessories, jewelry, belts, holsters, tools, weapons, implants, wrist devices? Same position on body? Any missing or added items vs front anchor = fail."
  }},
  "angle_accuracy": {{
    "front": {{ "correct": true/false, "note": "" }},
    "profile": {{ "correct": true/false, "note": "" }},
    "three_quarter": {{ "correct": true/false, "note": "" }},
    "full_body": {{ "correct": true/false, "note": "" }},
    "back": {{ "correct": true/false, "note": "" }}
  }},
  "artifacts": {{
    "extra_fingers": false,
    "face_asymmetry": false,
    "text_artifacts": false,
    "body_proportion_errors": false,
    "other": []
  }},
  "per_image": [
    {{
      "angle": "front",
      "pass": true/false,
      "deviation_from_anchor": "What differs from the front anchor (or 'N/A — this IS the anchor')",
      "issues": []
    }}
  ],
  "fix_suggestions": [
    "Specific actionable suggestions for re-generation if any images fail"
  ]
}}

SCORING RULES — be extremely strict:
- identity_consistency: 9+ to pass. All angles must be unmistakably the same person.
- wardrobe_consistency: 9+ to pass. Identical clothing in every angle. A different neckline, missing pocket, or changed sleeve length = fail.
- color_consistency: 9+ to pass. No color drift between angles. A warmer tone, shifted fabric hue, or different metal finish = fail.
- hair_makeup_consistency: 9+ to pass. Identical hair and makeup. A different part, changed length, or missing scar = fail.
- props_accessories_consistency: 9+ to pass. Same items in same positions. A missing belt, added earring, or shifted holster = fail.
- overall_pass is true ONLY if ALL five consistency scores are 9+ AND all per-image checks pass AND no major artifacts.

These reference sheets are the visual ground truth for every downstream frame and video. If the angles don't match perfectly, they're useless.

Return ONLY valid JSON, no markdown formatting."""


# Prompt for reviewing one generated storyboard panel (panel-check).
# The template tells the model that the image list is {ref_count} reference
# images followed by the panel as the last image.
# Filled with str.format(character=..., shot_description=..., ref_count=...);
# the literal JSON braces are doubled ({{ }}) for that reason.
PANEL_CHECK_PROMPT = """You are a visual QC reviewer for an AI video production pipeline. Review this generated storyboard panel against the locked character references.

CHARACTER: {character}
SHOT DESCRIPTION: {shot_description}

The first {ref_count} images are the LOCKED REFERENCE IMAGES (identity truth).
The last image is the GENERATED PANEL to review.

Evaluate the generated panel against the references. Return a JSON object:

{{
  "pass": true/false,
  "identity_match": {{
    "score": 1-10,
    "notes": "Does the character in the panel match the reference identity?"
  }},
  "scene_accuracy": {{
    "score": 1-10,
    "notes": "Does the scene match the shot description?"
  }},
  "lighting_match": {{
    "score": 1-10,
    "notes": "Is the lighting consistent with the scene description?"
  }},
  "artifacts": {{
    "extra_fingers": false,
    "face_deformation": false,
    "body_proportion_errors": false,
    "text_artifacts": false,
    "other": []
  }},
  "fix_suggestions": [
    "Specific prompt adjustments to fix issues"
  ]
}}

pass is true only if identity_match >= 7 AND scene_accuracy >= 7 AND no major artifacts.

Return ONLY valid JSON, no markdown formatting."""


# Prompt for reviewing sampled video frames for identity drift (video-check).
# The template text describes the image list as: {ref_count} reference images,
# then exactly one storyboard panel, then the sampled frames.
# Filled with str.format(character=..., shot_description=..., ref_count=...,
# frame_labels=...); the literal JSON braces are doubled ({{ }}) for that reason.
VIDEO_CHECK_PROMPT = """You are a visual QC reviewer for an AI video production pipeline. Review these sampled video frames for identity drift and quality.

CHARACTER: {character}
SHOT DESCRIPTION: {shot_description}

Images in order:
- First {ref_count} images: LOCKED REFERENCE IMAGES (identity truth)
- Next image: STORYBOARD PANEL (target composition)
- Remaining images: SAMPLED VIDEO FRAMES (to review) at {frame_labels}

Evaluate the video frames for temporal consistency and identity preservation. Return a JSON object:

{{
  "pass": true/false,
  "identity_drift": {{
    "score": 1-10,
    "worst_frame": "frame label with most drift",
    "notes": "Does identity hold across all frames?"
  }},
  "temporal_coherence": {{
    "score": 1-10,
    "notes": "Are there flickering, morphing, or discontinuity artifacts?"
  }},
  "composition_match": {{
    "score": 1-10,
    "notes": "Do frames match the storyboard panel composition?"
  }},
  "per_frame": [
    {{
      "frame": "frame label",
      "identity_score": 1-10,
      "quality_score": 1-10,
      "issues": []
    }}
  ],
  "fix_suggestions": [
    "Specific parameter adjustments for re-generation"
  ]
}}

pass is true only if identity_drift >= 7 AND temporal_coherence >= 7.

Return ONLY valid JSON, no markdown formatting."""


# ── API Calls ─────────────────────────────────────────────────────────────

def call_claude_vision(client, prompt: str, image_paths: List[str]) -> tuple:
    """Send images plus a text prompt to Claude and parse the JSON reply.

    The images are attached as base64 blocks in the given order, followed by
    the prompt text, all in a single user message.

    Args:
        client: an ``anthropic.Anthropic`` client (any object exposing a
            compatible ``messages.create`` works).
        prompt: instruction text; expected to request a JSON-only answer.
        image_paths: image files to attach, in order.

    Returns:
        (parsed_result, usage_info) where usage_info has provider, model,
        tokens_in, tokens_out, and duration_ms.

    Raises:
        json.JSONDecodeError: if the reply, after removing an optional
            Markdown code fence, is not valid JSON.
    """
    content = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": get_media_type(path),
                "data": encode_image_base64(path),
            },
        }
        for path in image_paths
    ]
    content.append({"type": "text", "text": prompt})

    # Monotonic clock: wall-clock adjustments must not distort durations.
    t0 = time.monotonic()
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=4096,
        messages=[{"role": "user", "content": content}],
    )
    elapsed_ms = int((time.monotonic() - t0) * 1000)

    # Extract token usage (missing or None fields count as 0)
    usage = getattr(response, "usage", None)
    tokens_in = (getattr(usage, "input_tokens", 0) or 0) if usage else 0
    tokens_out = (getattr(usage, "output_tokens", 0) or 0) if usage else 0

    # Join every text block rather than assuming content[0] is text.
    text = "".join(
        getattr(block, "text", "")
        for block in response.content
        if getattr(block, "type", "text") == "text"
    ).strip()

    # The model sometimes wraps JSON in a Markdown fence (``` or ```json)
    # despite the prompt. Strip it without assuming the fence sits on its own
    # line; the old split("\n", 1)[1] raised IndexError for "```{...}```".
    if text.startswith("```"):
        _, newline, rest = text.partition("\n")
        text = (rest if newline else text.lstrip("`")).strip()
        if text.endswith("```"):
            text = text[:-3].strip()

    usage_info = {
        "provider": "anthropic",
        "model": "claude-sonnet",
        "tokens_in": tokens_in,
        "tokens_out": tokens_out,
        "duration_ms": elapsed_ms,
    }

    return json.loads(text), usage_info


def call_gemini_vision(model, prompt: str, image_paths: List[str]) -> tuple:
    """Send images plus a text prompt to Gemini and parse the JSON reply.

    Args:
        model: a ``genai.GenerativeModel`` (any object with a compatible
            ``generate_content`` works).
        prompt: instruction text; expected to request a JSON-only answer.
        image_paths: image files to attach, in order, before the prompt.

    Returns:
        (parsed_result, usage_info) where usage_info has provider, model,
        tokens_in, tokens_out, and duration_ms.

    Raises:
        json.JSONDecodeError: if the reply, after removing an optional
            Markdown code fence, is not valid JSON.
    """
    opened = []
    try:
        if image_paths:
            # Deferred import: Pillow is only needed when images are attached.
            import PIL.Image
            for path in image_paths:
                opened.append(PIL.Image.open(path))
        # Monotonic clock: wall-clock adjustments must not distort durations.
        t0 = time.monotonic()
        response = model.generate_content([*opened, prompt])
        elapsed_ms = int((time.monotonic() - t0) * 1000)
    finally:
        # PIL.Image.open is lazy and keeps the file open until close(); release
        # every handle (including on partial failure) once the request is done.
        for img in opened:
            img.close()

    # Extract token usage from response metadata (missing/None counts as 0)
    meta = getattr(response, "usage_metadata", None)
    tokens_in = (getattr(meta, "prompt_token_count", 0) or 0) if meta else 0
    tokens_out = (getattr(meta, "candidates_token_count", 0) or 0) if meta else 0

    text = response.text.strip()
    # Strip an optional Markdown fence (``` or ```json) without assuming the
    # fence sits on its own line; split("\n", 1)[1] raised IndexError otherwise.
    if text.startswith("```"):
        _, newline, rest = text.partition("\n")
        text = (rest if newline else text.lstrip("`")).strip()
        if text.endswith("```"):
            text = text[:-3].strip()

    usage_info = {
        "provider": "gemini",
        "model": "gemini-2.0-flash",
        "tokens_in": tokens_in,
        "tokens_out": tokens_out,
        "duration_ms": elapsed_ms,
    }

    return json.loads(text), usage_info


def call_vision(prompt: str, image_paths: List[str], model_choice: str = "claude") -> tuple:
    """Dispatch a vision request to Gemini or Claude.

    ``model_choice == "gemini"`` selects Gemini; every other value falls
    through to Claude. A fresh client is built for each call.

    Returns:
        (parsed_result, usage_info) tuple from the underlying API call.
    """
    if model_choice != "gemini":
        return call_claude_vision(get_anthropic_client(), prompt, image_paths)
    return call_gemini_vision(get_gemini_client(), prompt, image_paths)


# ── QC Commands ───────────────────────────────────────────────────────────

def cmd_ref_check(args, tracker=None):
    """Review one character reference set (up to five angle images).

    Images are labeled by position as front, profile, three_quarter,
    full_body, back. The first image is the anchor the prompt compares every
    other angle against.

    Args:
        args: parsed CLI namespace with ``images``, ``character``,
            ``variant``, ``description``, ``prompt`` and ``model``.
        tracker: optional cost tracker; one entry is logged per review.

    Returns:
        0 if the model reports ``overall_pass``, 1 if it does not, 2 on
        invalid input or a model response that is not a JSON object.
    """
    angle_order = ("front", "profile", "three_quarter", "full_body", "back")
    images = args.images or []
    if not images:
        print("ERROR: --images required (1-5 image paths)", file=sys.stderr)
        return 2
    if len(images) > len(angle_order):
        # Extra images would otherwise reach the model with no angle label.
        print(
            f"ERROR: --images accepts at most {len(angle_order)} paths "
            f"({', '.join(angle_order)}); got {len(images)}",
            file=sys.stderr,
        )
        return 2

    for img in images:
        if not os.path.exists(img):
            print(f"ERROR: Image not found: {img}", file=sys.stderr)
            return 2

    labels_str = ", ".join(
        f"Image {i + 1} = {angle}" for i, angle in enumerate(angle_order[:len(images)])
    )

    prompt = REF_CHECK_PROMPT.format(
        character=args.character or "Unknown",
        variant=args.variant or "default",
        description=args.description or "(no description provided)",
        prompt=args.prompt or "(no prompt provided)",
        angle_labels=labels_str,
    )

    print(f"Reviewing {len(images)} reference images for {args.character}...", file=sys.stderr)
    result, usage = call_vision(prompt, images, args.model)
    is_object = isinstance(result, dict)

    # Log cost (the API call was billed even if the reply is unusable)
    passed = bool(result.get("overall_pass", False)) if is_object else False
    if tracker:
        tracker.log(
            category="qc",
            provider=usage["provider"],
            model=usage["model"],
            tokens_in=usage["tokens_in"],
            tokens_out=usage["tokens_out"],
            duration_ms=usage["duration_ms"],
            detail=f"Visual QC ref-check: {args.character}/{args.variant or 'default'} ({len(images)} images)",
            success=passed,
        )

    # Output
    print(json.dumps(result, indent=2))

    if not is_object:
        print("ERROR: model response is not a JSON object", file=sys.stderr)
        return 2

    def score_of(key):
        """Return result[key]["score"], or 0 when the section is missing/malformed."""
        section = result.get(key)
        return section.get("score", 0) if isinstance(section, dict) else 0

    # Report
    print(
        f"\n{'PASS' if passed else 'FAIL'} — Identity consistency: {score_of('identity_consistency')}/10",
        file=sys.stderr,
    )
    # overall_pass depends on all five consistency scores, so show every one.
    categories = (
        ("identity_consistency", "Identity"),
        ("wardrobe_consistency", "Wardrobe"),
        ("color_consistency", "Color"),
        ("hair_makeup_consistency", "Hair/makeup"),
        ("props_accessories_consistency", "Props"),
    )
    print(
        "Scores: " + ", ".join(f"{label} {score_of(key)}/10" for key, label in categories),
        file=sys.stderr,
    )

    failing = [
        item for item in (result.get("per_image") or [])
        if isinstance(item, dict) and not item.get("pass", True)
    ]
    if failing:
        print("Failing images:", file=sys.stderr)
        for item in failing:
            issues = "; ".join(str(i) for i in (item.get("issues") or []))
            detail = issues or item.get("deviation_from_anchor", "")
            print(f"  - {item.get('angle', '?')}: {detail}", file=sys.stderr)

    suggestions = result.get("fix_suggestions") or []
    if suggestions:
        print("Fix suggestions:", file=sys.stderr)
        for s in suggestions:
            print(f"  - {s}", file=sys.stderr)

    return 0 if passed else 1


def cmd_panel_check(args, tracker=None):
    """Grade one generated storyboard panel against the locked references.

    The reference images are sent first and the panel last, matching the
    image order that PANEL_CHECK_PROMPT describes to the model.

    Args:
        args: parsed CLI namespace with ``panel``, ``refs``, ``character``,
            ``shot_description`` and ``model``.
        tracker: optional cost tracker; receives one log entry per review.

    Returns:
        0 when the model reports ``pass``, 1 when it does not, 2 when an
        input image is missing.
    """
    panel_path = args.panel
    if not (panel_path and os.path.exists(panel_path)):
        print(f"ERROR: Panel image not found: {panel_path}", file=sys.stderr)
        return 2

    ref_paths = args.refs or []
    first_missing = next((r for r in ref_paths if not os.path.exists(r)), None)
    if first_missing is not None:
        print(f"ERROR: Reference image not found: {first_missing}", file=sys.stderr)
        return 2

    prompt = PANEL_CHECK_PROMPT.format(
        character=args.character or "Unknown",
        shot_description=args.shot_description or "(no description)",
        ref_count=len(ref_paths),
    )

    print(f"Reviewing panel for {args.character}...", file=sys.stderr)
    result, usage = call_vision(prompt, ref_paths + [panel_path], args.model)

    verdict = result.get("pass", False)
    if tracker:
        tracker.log(
            category="qc",
            provider=usage["provider"],
            model=usage["model"],
            tokens_in=usage["tokens_in"],
            tokens_out=usage["tokens_out"],
            duration_ms=usage["duration_ms"],
            detail=f"Visual QC panel-check: {args.character}",
            success=verdict,
        )

    print(json.dumps(result, indent=2))

    identity = result.get("identity_match", {}).get("score", 0)
    scene = result.get("scene_accuracy", {}).get("score", 0)
    label = "PASS" if verdict else "FAIL"
    print(f"\n{label} — Identity: {identity}/10, Scene: {scene}/10", file=sys.stderr)

    return 0 if verdict else 1


def cmd_video_check(args, tracker=None):
    """Review sampled video frames for identity drift against locked references.

    Images are sent as: reference images, then the optional storyboard panel,
    then the sampled frames.

    Args:
        args: parsed CLI namespace with ``refs``, ``frames``, ``panel``,
            ``character``, ``shot_description`` and ``model``.
        tracker: optional cost tracker; one entry is logged per review.

    Returns:
        0 if the model reports a pass, 1 on a fail, 2 on invalid input.
    """
    refs = args.refs or []
    frames = args.frames or []
    panel = args.panel

    if not frames:
        print("ERROR: --frames requires at least one frame image", file=sys.stderr)
        return 2

    all_paths = list(refs)
    if panel:
        all_paths.append(panel)
    all_paths.extend(frames)

    for p in all_paths:
        if not os.path.exists(p):
            print(f"ERROR: Image not found: {p}", file=sys.stderr)
            return 2

    frame_labels = [f"frame_{i}" for i in range(len(frames))]

    prompt = VIDEO_CHECK_PROMPT.format(
        character=args.character or "Unknown",
        shot_description=args.shot_description or "(no description)",
        ref_count=len(refs),
        frame_labels=", ".join(frame_labels),
    )
    if not panel:
        # The template says a storyboard panel follows the refs. Without a
        # panel, the model would treat the first frame as the panel and every
        # frame label would be off by one, so drop that line.
        prompt = prompt.replace(
            "- Next image: STORYBOARD PANEL (target composition)\n",
            "- No storyboard panel supplied: judge composition_match against the shot description\n",
        )

    print(f"Reviewing {len(frames)} video frames for {args.character}...", file=sys.stderr)
    result, usage = call_vision(prompt, all_paths, args.model)

    # Log cost
    passed = result.get("pass", False)
    if tracker:
        tracker.log(
            category="qc",
            provider=usage["provider"],
            model=usage["model"],
            tokens_in=usage["tokens_in"],
            tokens_out=usage["tokens_out"],
            duration_ms=usage["duration_ms"],
            detail=f"Visual QC video-check: {args.character} ({len(frames)} frames)",
            success=passed,
        )

    print(json.dumps(result, indent=2))

    drift_score = result.get("identity_drift", {}).get("score", 0)
    coherence = result.get("temporal_coherence", {}).get("score", 0)
    print(f"\n{'PASS' if passed else 'FAIL'} — Identity drift: {drift_score}/10, Coherence: {coherence}/10", file=sys.stderr)

    return 0 if passed else 1


def cmd_batch_check(args, tracker=None):
    """Review every wardrobe variant's reference set for one character.

    Reads ``args.breakdown`` (JSON) and looks up
    ``characters[args.character]``. For each entry in its ``wardrobe`` mapping,
    reviews the ``<angle>.png`` images found under
    ``<refs_dir>/characters/<character>/<variant>/``. ``refs_dir`` defaults to
    a ``refs`` directory next to the breakdown file. Variants without images
    are skipped. A combined JSON report is printed to stdout.

    Returns:
        0 if every reviewed variant passed (including when none were
        reviewed), 1 if any variant failed or errored, 2 on invalid input.
    """
    if not args.breakdown or not os.path.exists(args.breakdown):
        print(f"ERROR: breakdown.json not found: {args.breakdown}", file=sys.stderr)
        return 2

    with open(args.breakdown, "r", encoding="utf-8") as f:
        try:
            breakdown = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"ERROR: Corrupted breakdown.json: {e}", file=sys.stderr)
            return 2

    char_key = args.character
    # Tolerate null / non-object sections instead of crashing with AttributeError.
    characters = breakdown.get("characters") if isinstance(breakdown, dict) else None
    char_data = characters.get(char_key) if isinstance(characters, dict) else None
    if not char_data or not isinstance(char_data, dict):
        print(f"ERROR: Character {char_key} not found in breakdown.json", file=sys.stderr)
        return 2

    refs_dir = Path(args.refs_dir) if args.refs_dir else Path(args.breakdown).parent / "refs"
    description = char_data.get("visual_description") or ""
    prompts = char_data.get("prompts")
    # The reference prompt is shared by all variants; it was previously looked
    # up but never sent (a placeholder string went to the model instead).
    ref_prompt = (prompts.get("reference") if isinstance(prompts, dict) else None) or ""
    wardrobe = char_data.get("wardrobe") or {}
    angles = ("front", "profile", "three_quarter", "full_body", "back")

    results = {}
    any_fail = False

    for variant_name, variant_data in wardrobe.items():
        variant_dir = refs_dir / "characters" / char_key / variant_name
        if not variant_dir.exists():
            print(f"SKIP: {variant_name} — no images at {variant_dir}", file=sys.stderr)
            continue

        found_angles = [a for a in angles if (variant_dir / f"{a}.png").exists()]
        images = [str(variant_dir / f"{a}.png") for a in found_angles]
        if not images:
            print(f"SKIP: {variant_name} — no PNG images found", file=sys.stderr)
            continue

        labels_str = ", ".join(f"Image {i+1} = {a}" for i, a in enumerate(found_angles))
        variant_desc = variant_data.get("description") if isinstance(variant_data, dict) else None

        prompt = REF_CHECK_PROMPT.format(
            character=char_key,
            variant=variant_name,
            description=variant_desc or description or "(no description provided)",
            prompt=ref_prompt or "(no prompt provided)",
            angle_labels=labels_str,
        )

        print(f"\nReviewing {char_key}/{variant_name} ({len(images)} images)...", file=sys.stderr)

        try:
            result, usage = call_vision(prompt, images, args.model)
        except Exception as e:  # one bad variant must not abort the batch
            print(f"  ERROR: {e}", file=sys.stderr)
            results[variant_name] = {"error": str(e)}
            any_fail = True
            continue

        results[variant_name] = result
        is_object = isinstance(result, dict)
        passed = bool(result.get("overall_pass", False)) if is_object else False
        section = result.get("identity_consistency") if is_object else None
        id_score = section.get("score", 0) if isinstance(section, dict) else 0
        print(f"  {'PASS' if passed else 'FAIL'} — Identity: {id_score}/10", file=sys.stderr)
        if not passed:
            any_fail = True

        # Log cost. This is best-effort: a tracker failure must not replace the
        # real QC result with an error entry.
        if tracker:
            try:
                tracker.log(
                    category="qc",
                    provider=usage["provider"],
                    model=usage["model"],
                    tokens_in=usage["tokens_in"],
                    tokens_out=usage["tokens_out"],
                    duration_ms=usage["duration_ms"],
                    detail=f"Visual QC batch-check: {char_key}/{variant_name} ({len(images)} images)",
                    success=passed,
                )
            except Exception as e:
                print(f"  WARNING: cost logging failed: {e}", file=sys.stderr)

    if not results:
        print(f"WARNING: no reference images were reviewed for {char_key}", file=sys.stderr)

    # Output all results
    output = {
        "character": char_key,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "variants": results,
    }
    print(json.dumps(output, indent=2))

    return 1 if any_fail else 0


# ── CLI ───────────────────────────────────────────────────────────────────

def build_parser():
    """Build the argparse CLI: global options plus the four QC subcommands.

    ``--model`` and ``--project`` are accepted both before the subcommand
    (``--model gemini ref-check ...``) and after it
    (``ref-check ... --model gemini``). The per-subcommand copies default to
    ``argparse.SUPPRESS``, so omitting them never overwrites a value given
    before the subcommand.
    """
    parser = argparse.ArgumentParser(
        description="Visual QC — Review AI-generated images against references",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    parser.add_argument("--model", choices=["claude", "gemini"], default="claude",
                        help="Vision API to use (default: claude)")
    parser.add_argument("--project", default=None,
                        help="Project path for cost tracking (e.g., leviathan/)")

    # Same global options, re-declared on every subcommand so they may also
    # follow the subcommand name.
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--model", choices=["claude", "gemini"], default=argparse.SUPPRESS,
                        help="Vision API to use (default: claude)")
    common.add_argument("--project", default=argparse.SUPPRESS,
                        help="Project path for cost tracking (e.g., leviathan/)")

    sub = parser.add_subparsers(dest="command", help="QC mode")

    # ref-check
    ref = sub.add_parser("ref-check", parents=[common], help="Review character reference image set")
    ref.add_argument("--images", nargs="+", required=True,
                     help="1-5 reference image paths, in order: front, profile, three_quarter, full_body, back")
    ref.add_argument("--character", required=True, help="Character key (e.g., JINX)")
    ref.add_argument("--variant", default="default", help="Wardrobe variant name")
    ref.add_argument("--description", default="", help="Visual description from breakdown.json")
    ref.add_argument("--prompt", default="", help="MJ prompt used to generate the images")

    # panel-check
    panel = sub.add_parser("panel-check", parents=[common], help="Review storyboard panel against refs")
    panel.add_argument("--panel", required=True, help="Generated panel image path")
    panel.add_argument("--refs", nargs="+", default=[], help="Locked reference image paths")
    panel.add_argument("--character", required=True, help="Character key")
    panel.add_argument("--shot-description", default="", help="Shot description from storyboard")

    # video-check
    video = sub.add_parser("video-check", parents=[common], help="Review video frames for identity drift")
    video.add_argument("--frames", nargs="+", required=True, help="Sampled frame paths")
    video.add_argument("--refs", nargs="+", default=[], help="Locked reference image paths")
    video.add_argument("--panel", default=None, help="Storyboard panel (target composition)")
    video.add_argument("--character", required=True, help="Character key")
    video.add_argument("--shot-description", default="", help="Shot description")

    # batch-check
    batch = sub.add_parser("batch-check", parents=[common], help="Batch review from breakdown.json")
    batch.add_argument("--breakdown", required=True, help="Path to breakdown.json")
    batch.add_argument("--refs-dir", default=None, help="Path to refs/ directory")
    batch.add_argument("--character", required=True, help="Character key to review")

    return parser


def main():
    """CLI entry point: parse arguments, run the chosen QC command, return an exit code.

    Exit codes: 0 = all pass, 1 = some failures, 2 = error. Exceptions raised
    by a command handler (API errors, malformed model JSON, unreadable images)
    are reported and mapped to 2. Otherwise they would escape as a traceback,
    which Python reports as exit status 1 — the same code as a QC failure.
    """
    parser = build_parser()
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return 0

    # Initialize cost tracker if project path provided (best-effort)
    tracker = None
    if args.project:
        try:
            tracker = CostTracker(args.project)
        except Exception as e:
            print(f"WARNING: Could not initialize cost tracker: {e}", file=sys.stderr)

    commands = {
        "ref-check": cmd_ref_check,
        "panel-check": cmd_panel_check,
        "video-check": cmd_video_check,
        "batch-check": cmd_batch_check,
    }

    handler = commands.get(args.command)
    if not handler:
        parser.print_help()
        return 2

    try:
        return handler(args, tracker=tracker)
    except Exception as e:  # top-level boundary: keep "error" distinct from "QC failed"
        print(f"ERROR: {type(e).__name__}: {e}", file=sys.stderr)
        return 2


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
