#!/usr/bin/env python3
"""
spatial_inference.py — Spatial Continuity Inference Engine

Reads a storyboard and infers/populates spatial continuity data:
- spatial.camera_side (A/B — 180° line)
- spatial.screen_direction (left-to-right, right-to-left, etc.)
- spatial.blocking (per-character position and facing)
- edge_continuity.spatial_note + inherit_layers
- scene_break_before (inferred from location changes)
- same_angle_from / continuity_from (inferred from consecutive shot patterns)
- Dialogue return detection (A→B→A patterns)

Usage:
    python3 spatial_inference.py PROJECT/ --episode N --dry-run
    python3 spatial_inference.py PROJECT/ --episode N --apply
    python3 spatial_inference.py PROJECT/ --episode N --interactive

Exit codes:
    0 = success
    1 = errors
    2 = file/parse error
"""

import argparse
import json
import re
import sys
from pathlib import Path
from typing import List, Optional, Tuple


# ── Direction Inference Patterns ──────────────────────────────────────

# Prose cues that indicate facing or perspective direction.
# All four patterns are case-insensitive and unanchored (no \b), so they are
# meant to be used with .search() and will also match inside longer words.
# infer_facing() tests them in the order AWAY, TOWARD, LEFT, RIGHT and returns
# on the first hit.

# Character faces / looks / turns toward screen-left.
_FACING_LEFT_CUES = re.compile(
    r"(?:faces?\s+left|facing\s+left|looks?\s+left|turns?\s+left|"
    r"screen[- ]left|camera[- ]left|toward\s+(?:the\s+)?left)",
    re.IGNORECASE,
)
# Mirror image of _FACING_LEFT_CUES for screen-right.
_FACING_RIGHT_CUES = re.compile(
    r"(?:faces?\s+right|facing\s+right|looks?\s+right|turns?\s+right|"
    r"screen[- ]right|camera[- ]right|toward\s+(?:the\s+)?right)",
    re.IGNORECASE,
)
# Character is seen from behind or is moving/turning away from the camera.
# NOTE(review): "back of" / "back to" is a broad cue (e.g. "walks back to the
# car" also matches) — confirm this is the intended sensitivity.
_FACING_AWAY_CUES = re.compile(
    r"(?:from\s+behind|rear\s+view|back\s+(?:of|to)|away\s+from\s+camera|"
    r"walks?\s+away|turns?\s+away|facing\s+away)",
    re.IGNORECASE,
)
# Character faces or looks straight at the camera / lens.
_FACING_TOWARD_CUES = re.compile(
    r"(?:faces?\s+(?:the\s+)?camera|toward\s+(?:the\s+)?camera|"
    r"looks?\s+(?:directly\s+)?(?:at|into)\s+(?:the\s+)?(?:camera|lens)|"
    r"direct\s+gaze)",
    re.IGNORECASE,
)

# Camera side B indicators (reverse angle, OTS, etc.)
# The OTS / POV abbreviations are wrapped in \b so they only match as whole
# words. Unanchored and case-insensitive, "OTS" also matched inside ordinary
# words such as "boots", "robots" or "shots", and "POV" inside "poverty",
# which wrongly flipped those shots to side B.
_SIDE_B_CUES = re.compile(
    r"(?:reverse\s+angle|over[- ]the[- ]shoulder|\bOTS\b|\bPOV\b|"
    r"from\s+(?:behind|the\s+back)|counter[- ]shot|"
    r"opposing\s+angle|flip\s+side)",
    re.IGNORECASE,
)

# Motion/direction cues for screen_direction.
# Both patterns are case-insensitive and unanchored. infer_screen_direction()
# checks the left-to-right pattern first, so text that matches both patterns
# resolves to "left-to-right".

# Subject or camera travels toward screen-right (includes entering from the left).
_MOTION_LTR_CUES = re.compile(
    r"(?:runs?\s+(?:toward\s+)?right|moves?\s+right|chases?\s+right|"
    r"walks?\s+right|left\s+to\s+right|enters?\s+from\s+(?:the\s+)?left|"
    r"track(?:s|ing)?\s+right|dolly\s+right)",
    re.IGNORECASE,
)
# Subject or camera travels toward screen-left (includes entering from the right).
_MOTION_RTL_CUES = re.compile(
    r"(?:runs?\s+(?:toward\s+)?left|moves?\s+left|chases?\s+left|"
    r"walks?\s+left|right\s+to\s+left|enters?\s+from\s+(?:the\s+)?right|"
    r"track(?:s|ing)?\s+left|dolly\s+left)",
    re.IGNORECASE,
)

# Scale order for same_angle_from / continuity_from detection.
# Maps a shot_type string to a rank; a lower rank is treated as a tighter
# framing (a drop in rank is reported as a "Punch in", a rise as a "Pull back").
# Callers look shot types up with a default of -1 and skip scale comparisons
# for any type that is not listed here.
SCALE_ORDER = {"ECU": 0, "CU": 1, "MCU": 2, "MS": 3, "LS": 4, "WIDE": 5}


# ── Scene Detection ──────────────────────────────────────────────────

# Screenplay slug-line prefix (INT. / EXT. / INT/EXT.), optionally indented.
_SLUGLINE_RE = re.compile(r"\s*(?:INT\.|EXT\.|INT/EXT\.)", re.IGNORECASE)


def detect_scenes(shots: List[dict]) -> List[List[dict]]:
    """Group consecutive shots into scenes.

    A new scene starts at a shot when either:
      * the shot has a truthy ``scene_break_before`` flag, or
      * the shot's ``script_excerpt`` begins with a slug line
        (``INT.``, ``EXT.`` or ``INT/EXT.``, case-insensitive, leading
        whitespace allowed).

    The very first shot never produces an empty leading scene. Missing or
    null ``script_excerpt`` values are treated as "no slug line".

    Args:
        shots: Storyboard shot dicts in playback order.

    Returns:
        A list of scenes, each a non-empty list of the original shot dicts
        (not copies), in the original order.
    """
    scenes: List[List[dict]] = []
    current_scene: List[dict] = []

    for shot in shots:
        is_break = bool(shot.get("scene_break_before", False))

        # Only look for a slug line when there is an open scene to close;
        # a break before the very first shot would be a no-op anyway.
        if not is_break and current_scene:
            excerpt = shot.get("script_excerpt") or ""
            if isinstance(excerpt, str) and _SLUGLINE_RE.match(excerpt):
                is_break = True

        if is_break and current_scene:
            scenes.append(current_scene)
            current_scene = []

        current_scene.append(shot)

    if current_scene:
        scenes.append(current_scene)

    return scenes


# ── Inference Functions ──────────────────────────────────────────────

def infer_camera_side(shot: dict, scene_shots: List[dict]) -> str:
    """Decide which side of the 180-degree line the camera is on.

    The shot's first_frame, last_frame, action, subject and description
    fields are joined and scanned for side-B cues. A hit, or a shot_type of
    exactly "POV", yields "B"; anything else yields "A".

    ``scene_shots`` is accepted for interface compatibility and is not read.
    """
    prose_fields = ("first_frame", "last_frame", "action", "subject", "description")
    prose = " ".join(shot.get(field, "") for field in prose_fields)

    has_side_b_cue = _SIDE_B_CUES.search(prose) is not None
    # POV shot type is always side B, even without a prose cue.
    if has_side_b_cue or shot.get("shot_type") == "POV":
        return "B"
    return "A"


def infer_screen_direction(shot: dict) -> Optional[str]:
    """Work out the on-screen travel direction for a shot.

    Scans first_frame, action and camera_movement for motion cues; the
    left-to-right pattern is tried before the right-to-left one. If neither
    matches and camera_movement is exactly "track" or "dolly" (any case),
    left-to-right is assumed. Returns None when nothing applies.
    """
    prose = " ".join(
        shot.get(field, "") for field in ("first_frame", "action", "camera_movement")
    )

    ordered_patterns = (
        (_MOTION_LTR_CUES, "left-to-right"),
        (_MOTION_RTL_CUES, "right-to-left"),
    )
    for pattern, direction in ordered_patterns:
        if pattern.search(prose):
            return direction

    # A bare tracking/dolly move with no stated direction defaults to
    # left-to-right (reading direction).
    if shot.get("camera_movement", "").lower() in {"track", "dolly"}:
        return "left-to-right"

    return None


def infer_facing(shot: dict) -> Optional[str]:
    """Pick a facing direction from the shot's prose, or None if no cue matches.

    The first_frame, last_frame, action and description fields are joined
    and tested against the facing patterns in priority order:
    away-from-camera, toward-camera, left, right.
    """
    prose = " ".join(
        shot.get(field, "")
        for field in ("first_frame", "last_frame", "action", "description")
    )

    # Earlier entries win when several cues are present in the same text.
    ranked_cues = (
        (_FACING_AWAY_CUES, "away-from-camera"),
        (_FACING_TOWARD_CUES, "toward-camera"),
        (_FACING_LEFT_CUES, "left"),
        (_FACING_RIGHT_CUES, "right"),
    )
    return next(
        (label for pattern, label in ranked_cues if pattern.search(prose)),
        None,
    )


# (cue substring, resulting position) pairs for left/right placement.
_LATERAL_POSITION_CUES = (
    ("screen-left", "screen-left"),
    ("left side", "screen-left"),
    ("screen-right", "screen-right"),
    ("right side", "screen-right"),
)
# (cue substring, resulting position) pairs for depth placement.
_DEPTH_POSITION_CUES = (
    ("foreground", "foreground"),
    ("background", "background"),
)


def _earliest_position_cue(window: str, cues) -> Optional[str]:
    """Return the position of whichever cue appears first in window, or None."""
    hits = [(window.find(cue), position) for cue, position in cues if cue in window]
    return min(hits)[1] if hits else None


def infer_position(shot: dict, characters: List[str], char_name: str,
                   scene_context: dict) -> str:
    """Infer a character's screen position.

    Resolution order:
      1. An explicit cue in the 80 characters of prose starting at the first
         mention of the character. Left/right cues take precedence over
         foreground/background; within each group the cue closest to the
         name wins, so "Anna screen-right, Ben screen-left" places Anna on
         the right.
      2. Multi-character shots: first listed character is screen-left, the
         second screen-right, any other background. A name not found in
         ``characters`` is treated as the first character.
      3. Single character in a CU/ECU: center.
      4. The position previously established in ``scene_context``.
      5. center.

    Name matching is case-insensitive, so callers may pass either the
    original or a lowercased name. Null/missing prose fields are treated as
    empty text.

    Args:
        shot: Shot dict; first_frame, action, description and shot_type are read.
        characters: Names of all characters in the shot, in listed order.
        char_name: The character to place.
        scene_context: Mapping of character name to established position.

    Returns:
        One of "screen-left", "screen-right", "foreground", "background",
        "center", or whatever value ``scene_context`` holds for the character.
    """
    name = char_name.lower()
    text = " ".join(
        str(shot.get(key) or "") for key in ("first_frame", "action", "description")
    ).lower()

    # Direct position mentions near the character's name.
    if name in text:
        window = text[text.index(name):][:80]
        for cues in (_LATERAL_POSITION_CUES, _DEPTH_POSITION_CUES):
            position = _earliest_position_cue(window, cues)
            if position:
                return position

    # Multi-character: first character screen-left, second screen-right.
    if len(characters) >= 2:
        # Compare case-insensitively: callers pass a lowercased name while
        # the list usually carries the storyboard's original capitalisation.
        lowered = [str(member).lower() for member in characters]
        idx = lowered.index(name) if name in lowered else 0
        if idx == 0:
            return "screen-left"
        if idx == 1:
            return "screen-right"
        return "background"

    # Single character: tight shots are always centered.
    if shot.get("shot_type", "MS") in ("ECU", "CU"):
        return "center"

    # Wider single shots reuse the position established earlier in the scene.
    established = scene_context.get(char_name) or scene_context.get(name)
    return established or "center"


def infer_blocking(shot: dict, scene_context: dict) -> dict:
    """Infer blocking (position and facing) for every character in a shot.

    Character names are lowercased for the returned mapping and for
    ``scene_context``. The same lowercased list is handed to
    ``infer_position`` so that its order-based placement (first = left,
    second = right) can actually find each name; passing the original-case
    list made every capitalised name fall back to "screen-left".

    Facing comes from prose cues via ``infer_facing`` and therefore applies
    to the shot as a whole. When there is no cue, characters in a
    multi-character shot face each other based on their position; everyone
    else faces the camera.

    Args:
        shot: Shot dict; ``characters_in_shot`` and the prose fields are read.
        scene_context: Mapping of lowercased character name to the last
            position assigned in this scene. Updated in place.

    Returns:
        ``{name: {"position": ..., "facing": ...}}`` keyed by lowercased
        name, or an empty dict when the shot lists no characters.
    """
    characters = shot.get("characters_in_shot", [])
    if not characters:
        return {}

    names = [char_name.lower() for char_name in characters]
    # The prose cue does not depend on the character, so compute it once.
    prose_facing = infer_facing(shot)
    is_multi = len(names) >= 2

    blocking = {}
    for name in names:
        position = infer_position(shot, names, name, scene_context)

        facing = prose_facing
        # For multi-character dialogue: if no facing cue, infer from position.
        if not facing and is_multi:
            if position == "screen-left":
                facing = "right"  # Looking at the other character
            elif position == "screen-right":
                facing = "left"
        if not facing:
            facing = "toward-camera"

        blocking[name] = {
            "position": position,
            "facing": facing,
        }

        # Remember the position so later single shots can reuse it.
        scene_context[name] = position

    return blocking


def infer_spatial_note(shot: dict, prev_shot: dict) -> str:
    """Describe the spatial relationship between a shot and its predecessor.

    Returns an empty string when there is no previous shot or the two shots
    share no characters. Otherwise the first applicable description wins:
    a camera-side change (reverse angle), a tighter or wider shot scale
    (punch in / pull back), then a camera-angle change. An empty string is
    returned if none applies.
    """
    if not prev_shot:
        return ""

    common = set(shot.get("characters_in_shot", [])) & set(
        prev_shot.get("characters_in_shot", [])
    )
    if not common:
        return ""

    # 1. Crossing the line: camera side differs between the two shots.
    side_now = (shot.get("spatial") or {}).get("camera_side", "A")
    side_before = (prev_shot.get("spatial") or {}).get("camera_side", "A")
    if side_now != side_before:
        return f"Reverse angle — camera crosses to side {side_now}"

    # 2. Scale change, only when both shot types have a known rank.
    type_now = shot.get("shot_type", "")
    type_before = prev_shot.get("shot_type", "")
    rank_now = SCALE_ORDER.get(type_now, -1)
    rank_before = SCALE_ORDER.get(type_before, -1)
    if min(rank_now, rank_before) >= 0 and rank_now != rank_before:
        names = ", ".join(sorted(common))
        if rank_now < rank_before:
            return f"Punch in — {type_before} to {type_now} on {names}, same axis"
        return f"Pull back — {type_before} to {type_now}, establishing context around {names}"

    # 3. Camera angle change.
    angle_now = shot.get("camera_angle", "eye")
    angle_before = prev_shot.get("camera_angle", "eye")
    if angle_now != angle_before:
        return f"Angle shift — {angle_before} to {angle_now}"

    return ""


def infer_same_angle_from(shot: dict, prev_shot: dict) -> Optional[dict]:
    """Flag a shot that repeats the previous shot's setup (img2img candidate).

    A reference to the previous shot is returned when both shots share at
    least one character and have identical camera_angle (default "eye") and
    identical shot_type (default ""). Otherwise None.
    """
    if not prev_shot:
        return None

    cast_now = set(shot.get("characters_in_shot", []))
    cast_before = set(prev_shot.get("characters_in_shot", []))
    if not cast_now & cast_before:
        return None

    # Compare the (angle, scale) setup of both shots in one go.
    setup_now = (shot.get("camera_angle", "eye"), shot.get("shot_type", ""))
    setup_before = (prev_shot.get("camera_angle", "eye"), prev_shot.get("shot_type", ""))
    if setup_now != setup_before:
        return None

    return {
        "shot_id": prev_shot["id"],
        "frame": "hero",
        "strength": 0.35,
    }


def infer_dialogue_return(shot: dict, scene_shots: List[dict],
                          current_index: int) -> Optional[dict]:
    """Find an earlier setup this shot cuts back to (A->B->A pattern).

    Walks backward through ``scene_shots`` starting two positions before
    ``current_index`` (the immediately preceding shot is left to
    infer_same_angle_from) and returns a reference to the nearest shot with
    exactly the same character set, camera_angle and shot_type. Returns
    None when the shot lists no characters or no earlier shot matches.
    """
    cast = set(shot.get("characters_in_shot", []))
    setup = (cast, shot.get("camera_angle", "eye"), shot.get("shot_type", ""))

    if not cast:
        return None

    position = current_index - 2
    while position >= 0:
        earlier = scene_shots[position]
        earlier_setup = (
            set(earlier.get("characters_in_shot", [])),
            earlier.get("camera_angle", "eye"),
            earlier.get("shot_type", ""),
        )
        if setup == earlier_setup:
            return {
                "shot_id": earlier["id"],
                "frame": "hero",
                "strength": 0.35,
            }
        position -= 1

    return None


def infer_continuity_from(shot: dict, prev_shot: dict) -> Optional[dict]:
    """Flag a punch-in detail shot that can be cropped from the previous shot.

    Applies when both shots share a character, both shot types have a known
    rank in SCALE_ORDER, and the current shot is at least two ranks tighter
    than the previous one. The crop region is "face" for ECU, "upper_third"
    for CU and "center" for anything else. Returns None otherwise.
    """
    if not prev_shot:
        return None

    cast_now = set(shot.get("characters_in_shot", []))
    cast_before = set(prev_shot.get("characters_in_shot", []))
    if cast_now.isdisjoint(cast_before):
        return None

    shot_type = shot.get("shot_type", "")
    rank_now = SCALE_ORDER.get(shot_type, -1)
    rank_before = SCALE_ORDER.get(prev_shot.get("shot_type", ""), -1)

    # Need two known ranks and a jump of at least two steps tighter.
    if rank_now < 0 or rank_before < 0 or rank_before - rank_now < 2:
        return None

    crop_regions = {"ECU": "face", "CU": "upper_third"}
    return {
        "shot_id": prev_shot["id"],
        "frame": "hero",
        "region": crop_regions.get(shot_type, "center"),
        "method": "img2img_crop",
        "strength": 0.30,
    }


# ── Main Backfill Logic ──────────────────────────────────────────────

def backfill_storyboard(storyboard: dict, dry_run: bool = False) -> Tuple[dict, List[str]]:
    """Infer and populate spatial data for all shots in a storyboard.

    Shots are grouped into scenes with ``detect_scenes``; within each scene
    every shot is visited in order and only fields that are missing/empty
    are filled in:

      * ``scene_break_before`` on the first shot of every scene after the first
      * ``spatial`` (camera_side, screen_direction, blocking) for shots that
        list characters
      * ``edge_continuity.spatial_note`` and ``edge_continuity.inherit_layers``
      * ``same_angle_from`` (previous shot, else an earlier dialogue setup)
      * ``continuity_from`` (punch-in crop source)

    Args:
        storyboard: Parsed storyboard dict with a ``shots`` list. Each shot
            must have an ``id``.
        dry_run: When True the input is left untouched and all inference is
            done on a deep copy. When False the input is modified in place.

    Returns:
        ``(modified_storyboard, list_of_changes)``. In a dry run the first
        element is the modified copy; otherwise it is the input object.
        If there are no shots the change list is ``["No shots found"]``.
    """
    import copy  # local import keeps this block self-contained

    if dry_run:
        # A dry run must never mutate the caller's data.
        storyboard = copy.deepcopy(storyboard)

    shots = storyboard.get("shots", [])
    if not shots:
        return storyboard, ["No shots found"]

    changes: List[str] = []

    # enumerate() gives the scene number directly; looking it up with
    # list.index() compared whole scenes by value on every first shot.
    for scene_index, scene_shots in enumerate(detect_scenes(shots)):
        scene_context = {}  # Track established positions per character

        for i, shot in enumerate(scene_shots):
            shot_id = shot["id"]
            # Zero-pad integer ids; fall back to plain text for other id types
            # instead of failing on the ":02d" format.
            tag = f"S{shot_id:02d}" if isinstance(shot_id, int) else f"S{shot_id}"
            prev_shot = scene_shots[i - 1] if i > 0 else None
            characters = shot.get("characters_in_shot", [])

            # ── Infer scene_break_before ──
            # The first shot of every detected scene but the first is a break.
            if i == 0 and scene_index > 0 and not shot.get("scene_break_before"):
                shot["scene_break_before"] = True
                changes.append(f"{tag}: Set scene_break_before=true")

            # Continuity fields only make sense relative to a previous shot
            # in the same scene.
            continues_scene = bool(prev_shot) and not shot.get("scene_break_before", False)

            # ── Infer spatial object ──
            if not shot.get("spatial") and characters:
                camera_side = infer_camera_side(shot, scene_shots)
                screen_dir = infer_screen_direction(shot)
                blocking = infer_blocking(shot, scene_context)

                spatial = {"camera_side": camera_side}
                if screen_dir:
                    spatial["screen_direction"] = screen_dir
                if blocking:
                    spatial["blocking"] = blocking

                shot["spatial"] = spatial
                changes.append(
                    f"{tag}: Added spatial — "
                    f"side={camera_side}, "
                    f"dir={screen_dir or 'none'}, "
                    f"chars={list(blocking.keys()) if blocking else 'none'}"
                )

            # ── Infer edge_continuity.spatial_note ──
            if continues_scene:
                existing_edge = shot.get("edge_continuity")
                if not existing_edge or not existing_edge.get("spatial_note"):
                    note = infer_spatial_note(shot, prev_shot)
                    if note:
                        if not existing_edge:
                            shot["edge_continuity"] = {"spatial_note": note}
                        else:
                            shot["edge_continuity"]["spatial_note"] = note
                        changes.append(f"{tag}: Added edge_continuity.spatial_note: {note}")

            # ── Inherit scene layers (environment/lighting/color) ──
            if continues_scene:
                edge = shot.get("edge_continuity") or {}
                if "inherit_layers" not in edge:
                    edge["inherit_layers"] = ["environment", "lighting", "color_objects"]
                    shot["edge_continuity"] = edge
                    changes.append(f"{tag}: Set inherit_layers for scene consistency")

            # ── Infer same_angle_from ──
            if continues_scene and not shot.get("same_angle_from"):
                same_angle = infer_same_angle_from(shot, prev_shot)
                if same_angle:
                    shot["same_angle_from"] = same_angle
                    changes.append(f"{tag}: Set same_angle_from shot #{same_angle['shot_id']}")

            # ── Infer dialogue return (A→B→A pattern) ──
            # Only tried when the previous shot did not already match.
            if continues_scene and i >= 2 and not shot.get("same_angle_from"):
                dialogue_return = infer_dialogue_return(shot, scene_shots, i)
                if dialogue_return:
                    shot["same_angle_from"] = dialogue_return
                    changes.append(
                        f"{tag}: Set same_angle_from shot #{dialogue_return['shot_id']} (dialogue return)"
                    )

            # ── Infer continuity_from ──
            if continues_scene and not shot.get("continuity_from"):
                continuity = infer_continuity_from(shot, prev_shot)
                if continuity:
                    shot["continuity_from"] = continuity
                    changes.append(
                        f"{tag}: Set continuity_from shot #{continuity['shot_id']} ({continuity['region']})"
                    )

    return storyboard, changes


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: backfill spatial data for one episode.

    Loads ``<project_dir>/storyboards/storyboard_ep_NNN.json``, runs
    ``backfill_storyboard`` and, depending on the mode flag, only reports the
    changes (--dry-run), writes them back (--apply), or asks once for
    confirmation before writing (--interactive).

    Exit codes as implemented:
        0 = success, nothing to do, dry run, or cancelled by the user
        1 = project directory or storyboard file not found
        2 = no mode flag given, or the storyboard could not be read/parsed

    NOTE(review): the module docstring lists 2 as "file/parse error" while
    the not-found cases exit with 1; left unchanged here so existing callers
    keep seeing the same codes.
    """
    parser = argparse.ArgumentParser(
        description="Backfill spatial continuity data into existing storyboards"
    )
    parser.add_argument("project_dir", help="Project directory (e.g., leviathan/)")
    parser.add_argument("--episode", "-e", type=int, required=True, help="Episode number")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print inferred spatial data without modifying storyboard")
    parser.add_argument("--apply", action="store_true",
                        help="Write spatial data into storyboard JSON")
    parser.add_argument("--interactive", action="store_true",
                        help="Print each shot's inferred data and ask for confirmation")
    args = parser.parse_args()

    if not (args.dry_run or args.apply or args.interactive):
        print("ERROR: Specify --dry-run, --apply, or --interactive", file=sys.stderr)
        sys.exit(2)

    # Resolve project directory
    project_dir = Path(args.project_dir).resolve()
    if not project_dir.exists():
        # Try relative to engine parent
        engine_dir = Path(__file__).resolve().parent.parent.parent
        project_dir = engine_dir / args.project_dir
    if not project_dir.exists():
        print(f"ERROR: Project directory not found: {args.project_dir}", file=sys.stderr)
        sys.exit(1)

    ep_str = f"{args.episode:03d}"
    storyboard_path = project_dir / "storyboards" / f"storyboard_ep_{ep_str}.json"

    if not storyboard_path.exists():
        print(f"ERROR: Storyboard not found: {storyboard_path}", file=sys.stderr)
        sys.exit(1)

    # Read as UTF-8 explicitly: the platform default encoding is not UTF-8
    # everywhere, and storyboard prose regularly contains non-ASCII text.
    try:
        with open(storyboard_path, encoding="utf-8") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON: {e}", file=sys.stderr)
        sys.exit(2)
    except (OSError, UnicodeDecodeError) as e:
        print(f"ERROR: Cannot read {storyboard_path}: {e}", file=sys.stderr)
        sys.exit(2)

    if not isinstance(storyboard, dict):
        print("ERROR: Invalid storyboard: top-level JSON value must be an object",
              file=sys.stderr)
        sys.exit(2)

    shots = storyboard.get("shots", [])
    print(f"\n{'=' * 58}")
    print(f"  SPATIAL BACKFILL — Episode {args.episode}")
    print(f"{'=' * 58}")
    print(f"  Storyboard: {storyboard_path.name}")
    print(f"  Shots: {len(shots)}")

    # Count existing spatial data
    has_spatial = sum(1 for s in shots if s.get("spatial"))
    has_edge = sum(1 for s in shots if s.get("edge_continuity"))
    has_scene_break = sum(1 for s in shots if s.get("scene_break_before"))
    print(f"  Existing: spatial={has_spatial}, edge_continuity={has_edge}, scene_breaks={has_scene_break}")
    print(f"{'=' * 58}\n")

    # Run backfill
    modified, changes = backfill_storyboard(storyboard, dry_run=args.dry_run)

    if not changes:
        print("  No changes needed — all spatial data already populated.")
        sys.exit(0)

    # Report changes
    print(f"  Changes ({len(changes)}):\n")
    for change in changes:
        print(f"    {change}")
    print()

    if args.dry_run:
        print("  DRY RUN — no files modified.")
        print(f"  Run with --apply to write {len(changes)} changes to {storyboard_path.name}")
        sys.exit(0)

    if args.interactive:
        # In interactive mode, already printed changes above
        try:
            response = input("  Apply these changes? [y/N] ").strip().lower()
        except EOFError:
            # No stdin available (e.g. piped/CI run): treat as "no".
            response = ""
        if response != "y":
            print("  Cancelled.")
            sys.exit(0)

    # Write. Reaching this point means --apply was given or the interactive
    # prompt was confirmed; every other path has already exited.
    with open(storyboard_path, "w", encoding="utf-8") as f:
        json.dump(modified, f, indent=2)
    print(f"  Written to {storyboard_path}")

    # Summary
    new_shots = modified.get("shots", [])
    new_spatial = sum(1 for s in new_shots if s.get("spatial"))
    new_edge = sum(1 for s in new_shots if s.get("edge_continuity"))
    new_breaks = sum(1 for s in new_shots if s.get("scene_break_before"))
    print(f"\n{'=' * 58}")
    print("  BACKFILL COMPLETE")
    print(f"{'=' * 58}")
    print(f"  spatial:          {has_spatial} → {new_spatial}")
    print(f"  edge_continuity:  {has_edge} → {new_edge}")
    print(f"  scene_breaks:     {has_scene_break} → {new_breaks}")
    print(f"{'=' * 58}\n")

    sys.exit(0)


if __name__ == "__main__":
    main()
