#!/usr/bin/env python3
"""
Motion Complexity Classifier for Storyboard Shots.

Analyzes each shot in a storyboard and tags it with a motion complexity
rating (LOW / MEDIUM / HIGH). Informs generation routing and cost estimation.

Usage:
  python3 motion_complexity.py <storyboard.json>            # Report only
  python3 motion_complexity.py <storyboard.json> --enrich   # Write motion_complexity into the JSON file (in place)
  python3 motion_complexity.py <storyboard.json> --enrich --apply-suggestions  # Also overwrite generation_approach
  python3 motion_complexity.py <storyboard.json> --json     # Machine-readable output
  python3 motion_complexity.py <storyboard.json> --suggest  # Suggest generation_approach changes

Exit codes:
  0 = success
  2 = file/parse error
"""

import argparse
import json
import os
import re
import sys
from copy import deepcopy


# ── Complexity Keywords ──
# Each keyword adds to the complexity score. Thresholds determine final rating.

# Camera-move phrases searched for in a shot's combined, lowercased text.
# Each value is the weight added to the score when the phrase is found as a
# whole word/phrase; classify_shot caps the summed camera-keyword contribution
# at 3. Entry order matters only as a tie-breaker: _score_keywords sorts
# longest-first with a stable sort, so equal-length keys keep this order.
CAMERA_MOTION_KEYWORDS = {
    # High complexity camera moves
    "track": 2,
    "tracking": 2,
    "follows": 2,
    "orbits": 3,
    "orbiting": 3,
    "crane": 2,
    "crane up": 3,
    "crane down": 3,
    "steadicam": 2,
    # Medium complexity
    "dolly": 1,
    "dolly in": 1,
    "dolly out": 1,
    "push in": 1,
    "pull back": 1,
    "pan": 1,
    "handheld": 1,
    # Low / none (weight 0: adds nothing to the score)
    "static": 0,
}

# Phrases describing a character turning or changing facing, mapped to the
# weight each adds when found in the shot text. classify_shot caps the summed
# contribution of this table at 3.
CHARACTER_ROTATION_KEYWORDS = {
    "turns": 2,
    "turning": 2,
    "spins": 3,
    "spinning": 3,
    "whips around": 3,
    "faces away": 2,
    "faces toward": 2,
    "rotates": 2,
    "pivots": 2,
    "looks back": 1,
    "looks over shoulder": 1,
}

# Verbs describing physical actions and contact, mapped to the weight each
# adds when found in the shot text. classify_shot caps the summed contribution
# of this table at 4 (the highest cap of the four keyword tables).
PHYSICS_KEYWORDS = {
    "throws": 2,
    "catches": 2,
    "falls": 2,
    "falling": 2,
    "crashes": 3,
    "crashing": 3,
    "collapses": 2,
    "lunges": 2,
    "leaps": 2,
    "jumps": 2,
    "drops": 1,
    "slides": 1,
    "slams": 2,
    "smashes": 3,
    "kicks": 2,
    "punches": 2,
    "strikes": 2,
    "grabs": 1,
    "shoves": 2,
    "wrestles": 3,
    "lifts": 1,
    "carries": 1,
}

# Words describing environmental elements and effects, mapped to the weight
# each adds when found in the shot text. classify_shot caps the summed
# contribution of this table at 3. "fog" has weight 0, so it adds nothing.
ENVIRONMENTAL_KEYWORDS = {
    "crowd": 2,
    "explosion": 3,
    "debris": 2,
    "particles cascading": 2,
    "smoke": 1,
    "fire": 2,
    "water": 1,
    "rain": 1,
    "sparks": 1,
    "dust": 1,
    "steam": 1,
    "fog": 0,
    "wind": 1,
    "waves": 2,
    "lightning": 2,
    "erupts": 2,
    "shatters": 2,
}

# Thresholds: inclusive upper bounds on a shot's total score. classify_shot
# compares with <=, so a score equal to a bound still falls in that band.
LOW_MAX = 2      # Score 0-2 = LOW
MEDIUM_MAX = 5   # Score 3-5 = MEDIUM
                  # Score 6+  = HIGH


def _text_for_shot(shot):
    """Combine all text fields for keyword analysis."""
    fields = [
        shot.get("action", ""),
        shot.get("motion_prompt", ""),
        shot.get("first_frame", ""),
        shot.get("last_frame", ""),
        shot.get("hero_frame", "") or "",
        shot.get("triptych_prompt", "") or "",
        shot.get("hero_action", "") or "",
        shot.get("anticipation_action", "") or "",
        shot.get("aftermath_action", "") or "",
        shot.get("description", ""),
        shot.get("atmosphere", ""),
    ]
    return " ".join(f.lower() for f in fields if f)


def _score_keywords(text, keyword_dict):
    """Score a text against a keyword dictionary. Returns (total_score, matched_keywords)."""
    score = 0
    matched = []
    for keyword, weight in sorted(keyword_dict.items(), key=lambda x: -len(x[0])):
        # Use word boundary matching to avoid partial matches
        pattern = r"\b" + re.escape(keyword) + r"\b"
        if re.search(pattern, text):
            score += weight
            matched.append((keyword, weight))
    return score, matched


def classify_shot(shot):
    """Classify a single shot's motion complexity.

    The total score is the sum of these contributions:
      - camera_movement field: 2 for "track"/"crane", 1 for
        "dolly"/"pan"/"handheld", nothing otherwise (defaults to "static")
      - camera keywords found in the shot text, capped at 3
      - characters_in_shot: 3 for three or more entries, 2 for exactly two
      - character-rotation keywords, capped at 3
      - physics keywords, capped at 4
      - environmental keywords, capped at 3

    Scores up to LOW_MAX are "LOW", up to MEDIUM_MAX "MEDIUM", above "HIGH".

    Args:
        shot: shot dict from the storyboard. A missing or null
            characters_in_shot counts as no characters; a bare string
            counts as a single character.

    Returns:
        dict with keys:
          - complexity: "LOW" | "MEDIUM" | "HIGH"
          - score: int (raw complexity score)
          - factors: list of (factor_name, detail, weight) tuples
          - suggested_approach: recommended generation_approach
    """
    text = _text_for_shot(shot)
    total_score = 0
    factors = []

    # Camera movement (from enum field + text analysis)
    camera_move = shot.get("camera_movement", "static")
    if camera_move in ("track", "crane"):
        total_score += 2
        factors.append(("camera_movement", camera_move, 2))
    elif camera_move in ("dolly", "pan", "handheld"):
        total_score += 1
        factors.append(("camera_movement", camera_move, 1))

    # Camera movement keywords in text (capped so the enum field and the
    # prose describing the same move are not both counted in full)
    gained, found = _capped_keyword_factors(
        text, CAMERA_MOTION_KEYWORDS, "camera_keyword", cap=3, max_listed=3
    )
    total_score += gained
    factors.extend(found)

    # Multi-character. `or []` covers an explicit JSON null; wrapping a bare
    # string stops len() from counting its letters as characters.
    chars = shot.get("characters_in_shot") or []
    if isinstance(chars, str):
        chars = [chars]
    if len(chars) >= 3:
        total_score += 3
        factors.append(("multi_character", f"{len(chars)} characters", 3))
    elif len(chars) == 2:
        total_score += 2
        factors.append(("multi_character", "2 characters", 2))

    # Character rotation, physics interactions, environmental complexity
    for keywords, label, cap, max_listed in (
        (CHARACTER_ROTATION_KEYWORDS, "character_rotation", 3, 2),
        (PHYSICS_KEYWORDS, "physics", 4, 3),
        (ENVIRONMENTAL_KEYWORDS, "environment", 3, 2),
    ):
        gained, found = _capped_keyword_factors(
            text, keywords, label, cap=cap, max_listed=max_listed
        )
        total_score += gained
        factors.extend(found)

    # Classify
    if total_score <= LOW_MAX:
        complexity = "LOW"
    elif total_score <= MEDIUM_MAX:
        complexity = "MEDIUM"
    else:
        complexity = "HIGH"

    # Suggest generation approach
    current_approach = shot.get("generation_approach", "")
    suggested = _suggest_approach(complexity, shot, current_approach)

    return {
        "complexity": complexity,
        "score": total_score,
        "factors": factors,
        "suggested_approach": suggested,
    }


def _capped_keyword_factors(text, keywords, label, cap, max_listed):
    """Score one keyword table and return (capped_score, factor_tuples).

    Returns (0, []) when the table contributes no score; otherwise the raw
    score limited to `cap`, plus at most `max_listed` matches as
    (label, keyword, weight) tuples for reporting.
    """
    raw_score, hits = _score_keywords(text, keywords)
    if raw_score <= 0:
        return 0, []
    return min(raw_score, cap), [(label, kw, w) for kw, w in hits[:max_listed]]


def _suggest_approach(complexity, shot, current_approach):
    """Suggest a generation_approach based on motion complexity.

    LOW  → held_frame_push or held_frame_static preferred
    MEDIUM → standard_flf or triptych_split_flf
    HIGH → triptych_split_flf (more keyframes = more control)
    """
    # Respect existing approach if it makes sense
    if complexity == "LOW":
        if current_approach in ("held_frame_push", "held_frame_static"):
            return current_approach
        # Has motion prompt? Push is better than static
        if shot.get("motion_prompt"):
            return "held_frame_push"
        return "held_frame_static"

    elif complexity == "MEDIUM":
        if current_approach in ("standard_flf", "triptych_split_flf"):
            return current_approach
        return "standard_flf"

    else:  # HIGH
        return "triptych_split_flf"


def classify_storyboard(storyboard):
    """Run classify_shot over every shot and aggregate the outcome.

    Returns:
        (results, summary) tuple. `results` holds one dict per shot: the
        shot's id and name followed by the classify_shot fields. `summary`
        holds the shot count, the number of LOW/MEDIUM/HIGH shots, how many
        shots already have a generation_approach that differs from the
        suggestion, and the average score rounded to one decimal (0 for an
        empty storyboard).
    """
    shots = storyboard.get("shots", [])

    results = []
    for shot in shots:
        verdict = classify_shot(shot)
        results.append(
            {"shot_id": shot.get("id"), "shot_name": shot.get("name", ""), **verdict}
        )

    # Single pass over (shot, result) pairs gathers every summary figure
    tally = {"LOW": 0, "MEDIUM": 0, "HIGH": 0}
    differing = 0
    score_total = 0
    for shot, entry in zip(shots, results):
        tally[entry["complexity"]] += 1
        score_total += entry["score"]
        existing = shot.get("generation_approach", "")
        # Shots with no approach set are not counted as changes
        if existing and entry["suggested_approach"] != existing:
            differing += 1

    shot_count = len(results)
    summary = {
        "total_shots": shot_count,
        "low": tally["LOW"],
        "medium": tally["MEDIUM"],
        "high": tally["HIGH"],
        "approach_changes_suggested": differing,
        # max(..., 1) avoids dividing by zero when there are no shots
        "avg_score": round(score_total / max(shot_count, 1), 1),
    }

    return results, summary


def enrich_storyboard(storyboard, apply_suggestions=False):
    """Return a copy of the storyboard with complexity fields on each shot.

    Every shot gains `motion_complexity` (the rating) and
    `motion_complexity_score` (the raw score). The input is not modified.

    Args:
        storyboard: parsed storyboard dict
        apply_suggestions: if True, also overwrite each shot's
            generation_approach with the suggested one

    Returns:
        enriched storyboard (deep copy)
    """
    result = deepcopy(storyboard)
    for entry in result.get("shots", []):
        verdict = classify_shot(entry)
        entry.update(
            motion_complexity=verdict["complexity"],
            motion_complexity_score=verdict["score"],
        )
        if not apply_suggestions:
            continue
        entry["generation_approach"] = verdict["suggested_approach"]
    return result


def main():
    """Command-line entry point: classify a storyboard file and report on it.

    Output mode, in order of precedence:
      1. --enrich: rewrite the storyboard file in place with the
         motion_complexity fields (plus generation_approach when
         --apply-suggestions is given), print a short confirmation, stop.
      2. --json: print {"results": ..., "summary": ...} as JSON, stop.
      3. otherwise: print the human-readable report; --suggest adds the
         per-shot generation_approach suggestions to it.

    --apply-suggestions has no effect without --enrich.

    Error messages go to stderr so stdout stays clean for --json consumers.
    Exits with status 2 when the file is missing, unreadable, not valid
    JSON, not a JSON object, or cannot be written back.
    """
    parser = argparse.ArgumentParser(
        description="Classify motion complexity for storyboard shots"
    )
    parser.add_argument("storyboard", help="Path to storyboard JSON")
    parser.add_argument(
        "--enrich",
        action="store_true",
        help="Write motion_complexity fields to storyboard JSON",
    )
    parser.add_argument(
        "--apply-suggestions",
        action="store_true",
        help="Also update generation_approach based on complexity (use with --enrich)",
    )
    parser.add_argument("--json", action="store_true", help="Machine-readable output")
    parser.add_argument(
        "--suggest",
        action="store_true",
        help="Show generation_approach suggestions",
    )

    args = parser.parse_args()

    if not os.path.exists(args.storyboard):
        print(f"ERROR: Storyboard not found: {args.storyboard}", file=sys.stderr)
        sys.exit(2)

    try:
        # JSON is UTF-8 by specification; do not depend on the locale default
        with open(args.storyboard, encoding="utf-8") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON: {e}", file=sys.stderr)
        sys.exit(2)
    except (OSError, UnicodeDecodeError) as e:
        # Path exists but is a directory, unreadable, or not UTF-8 text
        print(f"ERROR: Cannot read storyboard: {e}", file=sys.stderr)
        sys.exit(2)

    if not isinstance(storyboard, dict):
        print(
            "ERROR: Storyboard JSON must be an object with a 'shots' list",
            file=sys.stderr,
        )
        sys.exit(2)

    results, summary = classify_storyboard(storyboard)

    if args.enrich:
        enriched = enrich_storyboard(storyboard, apply_suggestions=args.apply_suggestions)
        # Serialize before opening for write so a serialization failure
        # cannot leave the original file truncated.
        payload = json.dumps(enriched, indent=2)
        try:
            with open(args.storyboard, "w", encoding="utf-8") as f:
                f.write(payload)
        except OSError as e:
            print(f"ERROR: Cannot write storyboard: {e}", file=sys.stderr)
            sys.exit(2)
        print(f"Enriched {len(results)} shots with motion_complexity.")
        if args.apply_suggestions:
            # Count what was actually rewritten, including shots that had no
            # generation_approach before (the summary figure skips those).
            updated = sum(
                1
                for before, after in zip(
                    storyboard.get("shots", []), enriched.get("shots", [])
                )
                if before.get("generation_approach") != after.get("generation_approach")
            )
            print(f"Updated {updated} generation_approach values.")
        return

    if args.json:
        output = {"results": results, "summary": summary}
        print(json.dumps(output, indent=2))
        return

    # Human-readable report
    print("=== Motion Complexity Analysis ===")
    print(f"Storyboard: {os.path.basename(args.storyboard)}")
    print(f"Total shots: {summary['total_shots']}")
    print(f"Average score: {summary['avg_score']}")
    print()
    print(f"  LOW:    {summary['low']}  (static/minimal motion — held frame candidates)")
    print(f"  MEDIUM: {summary['medium']}  (moderate motion — standard FLF)")
    print(f"  HIGH:   {summary['high']}  (complex motion — triptych recommended)")
    print()

    # Show high-complexity shots in detail
    high_shots = [r for r in results if r["complexity"] == "HIGH"]
    if high_shots:
        print("HIGH COMPLEXITY SHOTS (plan for extra regen cycles):")
        for r in high_shots:
            # Only the first four factors are shown to keep the line short
            factors_str = ", ".join(f"{f[0]}:{f[1]}" for f in r["factors"][:4])
            print(f"  #{r['shot_id']} '{r['shot_name']}' (score {r['score']}): {factors_str}")
        print()

    if args.suggest:
        print("GENERATION APPROACH SUGGESTIONS:")
        for shot, r in zip(storyboard.get("shots", []), results):
            # "?" stands in for shots that have no generation_approach yet
            current = shot.get("generation_approach", "?")
            suggested = r["suggested_approach"]
            marker = " ← CHANGE" if current != suggested else ""
            print(
                f"  #{r['shot_id']} [{r['complexity']}] "
                f"current: {current} → suggested: {suggested}{marker}"
            )
        print()
        print(
            f"{summary['approach_changes_suggested']} approach changes suggested. "
            f"Use --enrich --apply-suggestions to apply."
        )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
