"""
scene_planner.py — Scene-aware generation ordering and shot classification.

Plan-first: reads routing_data and prompt_data from plan shot records
when available. Falls back to parsing legacy Recoil storyboard dicts.

Key behaviors:
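- Plans shots grouped by scene (scene_break_before field or location changes);
  the grouping itself is supplied by the caller (e.g. recoil_bridge.get_all_scenes)
- Orders ENV shots FIRST within each scene (provides_scene_ref)
- Classifies shots into complexity tiers (simple/standard/complex)
- Tags shots with scene ref dependencies
- Routes shots to sub-pipelines (still/i2v/t2v/multi_shot)
- Generates coverage shot records (alternate camera angles) for plan shots
- Splits long scenes into sub-batches for multi-shot generation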
"""

from __future__ import annotations

from typing import Optional

from recoil.core.model_profiles import get_model, get_modality, supports_start_end_frame, supports_multi_shot


# Shot types that indicate ENV-only (no characters).
# NOTE(review): not referenced anywhere in this module — ENV detection here is
# driven by routing_data / characters_in_shot (see is_env_shot). Confirm no
# other module imports it before removing.
_ENV_INDICATORS = {"ENV", "ESTABLISHING", "INSERT"}

# Shot types that suggest the simple tier. classify_shot_tier consults this
# only after its ENV and 2+-character checks, i.e. for single-character shots.
_SIMPLE_SHOT_TYPES = {"INSERT", "WIDE", "LS", "ELS", "EWS", "VLS", "WS", "FS"}

# Shot types that suggest the complex tier. classify_shot_tier promotes these
# to "complex" only when a character carries a non-neutral emotion.
_COMPLEX_INDICATORS = {"ECU", "BCU"}  # Extreme close-ups need expression refs

# ── Coverage cycling map ──────────────────────────────────────────────
# Uppercased original shot_type → ordered coverage angles to generate.
# generate_coverage_plan takes the first `depth` entries (all when depth == 0).
# "OTS_REVERSE" is a pseudo-type: generate_single_coverage emits it as "OTS"
# with the first two asset_data characters swapped and reverse_ots set.
_COVERAGE_OPTIONS: dict[str, list[str]] = {
    "EWS": ["MS", "MCU", "CU"],
    "WS": ["MS", "MCU", "CU"],
    "LS": ["MS", "MCU"],
    "VLS": ["MS", "MCU"],
    "FS": ["MCU", "CU", "MS"],
    "MS": ["CU", "MCU", "WS"],
    "MLS": ["CU", "MS", "WS"],
    "MWS": ["CU", "MS"],
    "MFS": ["CU", "MS"],
    "MCU": ["WS", "MS", "CU"],
    "CU": ["WS", "MS", "MCU"],
    "BCU": ["CU", "MS"],
    "ECU": ["CU", "MS"],
    "OTS": ["OTS_REVERSE", "CU", "MS"],
}

# Shot types that never receive coverage; get_coverage_options checks this first.
_COVERAGE_EXCLUDE = {"INSERT", "DETAIL_INSERT", "TITLE", "TRANSITION", "ENV", "ESTABLISHING"}


def get_coverage_options(shot: dict) -> list[str]:
    """Return the coverage shot types available for *shot*.

    An empty list is returned when:
    - the shot's (uppercased) prompt_data.shot_type is in _COVERAGE_EXCLUDE,
    - coverage is explicitly disabled (``shot["coverage"] is False``),
    - the shot has no characters (routing_data.num_characters is 0/missing
      and asset_data.characters is empty) — OTS shots are exempt from this,
    - or the shot type has no entry in _COVERAGE_OPTIONS.

    The result is a fresh list, so callers may mutate it without corrupting
    the module-level _COVERAGE_OPTIONS table.
    """
    # `or {}` / `or ""` tolerate sections or fields that are present but null.
    shot_type = ((shot.get("prompt_data") or {}).get("shot_type") or "").upper()
    if shot_type in _COVERAGE_EXCLUDE:
        return []
    if shot.get("coverage") is False:
        return []
    num_chars = (shot.get("routing_data") or {}).get("num_characters") or 0
    if num_chars == 0:
        # routing_data may be absent/zero; fall back to the asset character list.
        asset_chars = (shot.get("asset_data") or {}).get("characters") or []
        if not asset_chars and shot_type != "OTS":
            return []
    return list(_COVERAGE_OPTIONS.get(shot_type, []))


def generate_single_coverage(
    shot: dict,
    coverage_type: str,
    coverage_num: int = 1,
    episode_id: str | None = None,
) -> dict:
    """Build one coverage record (an alternate angle) derived from *shot*.

    The source shot is deep-copied and never modified. The copy gets:
    - ``shot_id`` = ``<source id>_COV_<NN>`` (NN = zero-padded coverage_num),
    - ``is_coverage`` = True and ``coverage_of`` = the source shot_id,
    - prompt_data.original_shot_type (source type, uppercased) and
      prompt_data.shot_type = coverage_type.

    The pseudo-type "OTS_REVERSE" is written as "OTS" with
    prompt_data.reverse_ots = True, and the first two asset_data characters
    swapped when at least two exist.

    episode_id, when truthy, overrides the copy's episode_id; otherwise a
    missing/empty episode_id is derived from a leading ``EP<digits>`` in the
    source shot_id.
    """
    import copy
    import re

    source_id = shot.get("shot_id", "")
    source_type = shot.get("prompt_data", {}).get("shot_type", "").upper()
    is_reverse = coverage_type == "OTS_REVERSE"

    record = copy.deepcopy(shot)
    record["shot_id"] = f"{source_id}_COV_{coverage_num:02d}"
    record["is_coverage"] = True
    record["coverage_of"] = source_id

    prompt = record.get("prompt_data", {})
    prompt["original_shot_type"] = source_type
    prompt["shot_type"] = "OTS" if is_reverse else coverage_type
    record["prompt_data"] = prompt

    if is_reverse:
        # Swap in place on the copied list; nothing is attached when absent.
        cast = record.get("asset_data", {}).get("characters", [])
        if len(cast) >= 2:
            cast[0], cast[1] = cast[1], cast[0]
        prompt["reverse_ots"] = True

    if episode_id:
        record["episode_id"] = episode_id
    elif not record.get("episode_id"):
        found = re.match(r"(EP\d+)", source_id)
        if found is not None:
            record["episode_id"] = found.group(1)

    return record


def generate_coverage_plan(
    shots: list[dict],
    episode_id: str | None = None,
    depth: int = 1,
) -> list[dict]:
    """Generate coverage shots for an entire episode (batch mode).

    Args:
        shots: Shot records. Records already marked ``is_coverage`` are skipped.
        episode_id: Forwarded to generate_single_coverage for every record.
        depth: Coverage angles per shot, taken in _COVERAGE_OPTIONS order;
            0 means every available option; negative values yield none.

    Returns:
        New coverage records in input order. The source shots are copied,
        not modified, by generate_single_coverage.

    A shot with ``coverage: True`` for which get_coverage_options returns no
    options still receives a single "MS" coverage angle.
    """
    coverage_shots: list[dict] = []
    for shot in shots:
        if shot.get("is_coverage"):
            continue
        options = get_coverage_options(shot)
        # Explicitly requested coverage falls back to a medium shot. This must
        # run before the "no options" skip, otherwise it can never trigger.
        if not options and shot.get("coverage") is True:
            options = ["MS"]
        if not options:
            continue
        take_count = len(options) if depth == 0 else min(depth, len(options))
        for i in range(take_count):
            cov_shot = generate_single_coverage(
                shot, coverage_type=options[i], coverage_num=i + 1, episode_id=episode_id,
            )
            coverage_shots.append(cov_shot)
    return coverage_shots


def _is_plan_shot(shot: dict) -> bool:
    """Check if a shot dict is from a Starsend plan."""
    return "routing_data" in shot and "prompt_data" in shot


def classify_shot_tier(shot: dict) -> str:
    """Classify a shot into a complexity tier.

    Returns one of: "simple", "standard", "complex"

    Plan-first: reads routing_data.num_characters / is_env_only,
    prompt_data.shot_type and asset_data.characters[].emotion_keyword.
    Falls back to the legacy characters_in_shot / shot_type / emotion fields.

    Rules (first match wins):
    - ENV shots (no characters) → simple (direct to Pro)
    - Two+ character shots → complex (spatial syntax, 7-ref cap)
    - Wide/LS shots → simple (no facial detail needed)
    - ECU/BCU with emotion → complex (needs expression ref)
    - Everything else → standard (full 3-pass)

    Fields that are missing or present-but-null fall back to their defaults
    (no characters, shot type "MS", neutral emotion) instead of raising.
    """
    if _is_plan_shot(shot):
        routing = shot["routing_data"] or {}
        prompt = shot["prompt_data"] or {}
        num_chars = routing.get("num_characters") or 0
        is_env = routing.get("is_env_only", False)
        shot_type = (prompt.get("shot_type") or "MS").upper()

        if is_env or num_chars == 0:
            return "simple"
        if num_chars >= 2:
            return "complex"
        if shot_type in _SIMPLE_SHOT_TYPES:
            return "simple"
        if shot_type in _COMPLEX_INDICATORS:
            # A null/empty emotion_keyword means "not specified", i.e. neutral —
            # it must not promote the shot the way a real emotion does.
            chars = (shot.get("asset_data") or {}).get("characters") or []
            has_emotion = any(
                (c.get("emotion_keyword") or "neutral").strip().lower() != "neutral"
                for c in chars
            )
            if has_emotion:
                return "complex"
        return "standard"

    # Legacy path
    chars = shot.get("characters_in_shot") or []
    shot_type = (shot.get("shot_type") or "MS").upper()
    emotion = shot.get("emotion", "")

    if not chars:
        return "simple"
    if len(chars) >= 2:
        return "complex"
    if shot_type in _SIMPLE_SHOT_TYPES:
        return "simple"
    if shot_type in _COMPLEX_INDICATORS and emotion:
        return "complex"
    return "standard"


def is_env_shot(shot: dict) -> bool:
    """Report whether *shot* is ENV-only (no characters in frame).

    Plan records: true when routing_data.is_env_only is truthy (that value
    is returned as-is) or routing_data.num_characters is 0/missing.
    Legacy records: true when characters_in_shot is empty or missing.
    """
    if _is_plan_shot(shot):
        flags = shot["routing_data"]
        env_flag = flags.get("is_env_only", False)
        return env_flag or flags.get("num_characters", 0) == 0
    legacy_cast = shot.get("characters_in_shot", [])
    return len(legacy_cast) == 0


def plan_scene(scene_shots: list[dict], scene_index: int) -> list[dict]:
    """Order one scene's shots for generation and tag them with planning data.

    ENV-only shots are moved ahead of character shots (each group keeps its
    original relative order) so they render first and can serve as the
    scene reference. Every returned shot is a shallow copy carrying:
    - _tier: classify_shot_tier() result
    - _is_env: is_env_shot() result
    - _scene_index: the given scene_index
    - _generation_order: 0-based position in the returned list
    - _provides_scene_ref: True only for the first ENV shot
    - _needs_scene_ref: True for every shot after position 0 except the
      reference provider

    Returns an empty list for an empty scene.
    """
    if not scene_shots:
        return []

    # Evaluate the ENV flag once per shot and reuse it for ordering and tagging.
    flagged = [(shot, is_env_shot(shot)) for shot in scene_shots]
    ordering = [pair for pair in flagged if pair[1]] + [pair for pair in flagged if not pair[1]]

    planned_shots = []
    ref_provider_assigned = False
    for position, (shot, env_flag) in enumerate(ordering):
        planned = dict(shot)  # shallow copy: the input storyboard stays untouched
        planned["_tier"] = classify_shot_tier(shot)
        planned["_is_env"] = env_flag
        planned["_scene_index"] = scene_index
        planned["_generation_order"] = position

        if env_flag and not ref_provider_assigned:
            planned["_provides_scene_ref"] = True
            planned["_needs_scene_ref"] = False
            ref_provider_assigned = True
        else:
            planned["_provides_scene_ref"] = False
            planned["_needs_scene_ref"] = position > 0

        planned_shots.append(planned)

    return planned_shots


def plan_episode(scenes: list[list[dict]]) -> list[list[dict]]:
    """Plan every scene of an episode, preserving scene order.

    Args:
        scenes: Scene groups as produced upstream (e.g. recoil_bridge.get_all_scenes).

    Returns:
        One plan_scene() result per input scene; each scene's position in
        *scenes* is passed as its scene_index.
    """
    planned_scenes: list[list[dict]] = []
    for index, scene_group in enumerate(scenes):
        planned_scenes.append(plan_scene(scene_group, index))
    return planned_scenes


def get_scene_summary(planned_scene: list[dict]) -> dict:
    """Summarize a scene returned by plan_scene.

    Returns a dict with total/ENV/character shot counts, a per-tier tally
    (always containing simple/standard/complex; shots lacking _tier count as
    "standard", unknown tier labels get their own key) and whether any shot
    provides the scene reference.
    """
    tier_counts = {"simple": 0, "standard": 0, "complex": 0}
    env_total = 0

    for entry in planned_scene:
        label = entry.get("_tier", "standard")
        tier_counts[label] = tier_counts.get(label, 0) + 1
        env_total += 1 if entry.get("_is_env") else 0

    shot_total = len(planned_scene)
    return {
        "total_shots": shot_total,
        "env_shots": env_total,
        "char_shots": shot_total - env_total,
        "tiers": tier_counts,
        "has_scene_ref_provider": any(entry.get("_provides_scene_ref") for entry in planned_scene),
    }


# ── Shot types that suggest static/still rendering ───────────────────
# NOTE(review): route_shot does not currently consult this set — only ENV
# shots and manipulated-prop plan shots are sent to "still". Confirm intent.
_STATIC_SHOT_TYPES = {"INSERT", "ECU"}

# ── Shot types with complex camera movement ──────────────────────────
# route_shot matches these as substrings of the uppercased camera_movement.
_COMPLEX_CAMERA_TYPES = {"CRANE", "STEADICAM", "DOLLY", "TRACKING"}


def route_shot(shot: dict, scene_shots: list[dict] | None = None) -> dict:
    """Determine which sub-pipeline handles this shot.

    Plan-first: reads routing_data / prompt_data / asset_data when available,
    otherwise the legacy storyboard fields. Quality over cost — no
    budget-aware downgrade. Fields that are missing or present-but-null fall
    back to defaults (5s duration, no characters, static camera) instead of
    raising.

    Model IDs come from config ``routing.models`` (keys: still, i2v,
    t2v_default, t2v_long, multi_shot), falling back to ``get_model``.

    Args:
        shot: Plan or legacy shot record.
        scene_shots: Optional shots of the enclosing scene. When given and
            is_batch_eligible() accepts them, the scene is batched.

    Returns:
        {
            "pipeline": "still" | "i2v" | "t2v" | "multi_shot",
            "model": model ID string,
            "reason": str explaining the routing decision
        }

    Priority (first match wins; adapted from the Gemini consultation,
    Feb 27, 2026):
    1. Scene eligible for batching → multi_shot
    2. Requires match-cut / start+end frame → i2v
    2.5 Plan shot with a manipulated prop (ADR-P06) → still
    3. ENV shots (no characters) → still
    4. >15s or complex camera movement:
       solo (≤1 character) without dialogue → t2v, t2v_long model (ADR-M03);
       otherwise → t2v, t2v_default model
    5. Dialogue or 2+ characters → i2v (keyframe preserves identity)
    6. Default → i2v
    """
    from recoil.core.paths import get_config
    routing_models = (get_config().get("routing") or {}).get("models") or {}

    model_still = routing_models.get("still", get_model("production", "image"))
    model_i2v = routing_models.get("i2v", get_model("i2v", "video"))
    model_t2v_default = routing_models.get("t2v_default", get_model("t2v_default", "video"))
    # NOTE(review): resolved but not selected by any rule below (dialogue now routes to i2v).
    model_t2v_dialogue = routing_models.get("t2v_dialogue", get_model("t2v_dialogue", "video"))  # noqa: F841
    model_t2v_long = routing_models.get("t2v_long", get_model("t2v_long", "video"))
    model_multi_shot = routing_models.get("multi_shot", get_model("multi_shot", "video"))

    plan_shot = _is_plan_shot(shot)

    # Extract fields — plan-first with legacy fallback. `or <default>` guards
    # against keys that are present but null, which would otherwise crash
    # .upper() or the numeric comparisons below.
    if plan_shot:
        routing = shot["routing_data"] or {}
        prompt = shot["prompt_data"] or {}
        num_chars = routing.get("num_characters") or 0
        is_env = bool(routing.get("is_env_only", False)) or num_chars == 0
        camera_movement = (prompt.get("camera_movement") or "static").upper()
        duration = routing.get("target_editorial_duration_s") or 5
        has_dialogue = bool(routing.get("has_dialogue", False))
        needs_match_cut = bool(routing.get("narrative_requires_match_cut", False))
    else:
        num_chars = len(shot.get("characters_in_shot") or [])
        is_env = num_chars == 0
        camera_movement = (shot.get("camera_movement") or "").upper()
        duration = shot.get("duration_s") or 5
        has_dialogue = bool(shot.get("dialogue"))
        needs_match_cut = bool(shot.get("start_frame") or shot.get("end_frame"))

    # 1. Multi-shot batching (scene level)
    if scene_shots and is_batch_eligible(scene_shots):
        return {
            "pipeline": "multi_shot",
            "model": model_multi_shot,
            "reason": f"Scene has {len(scene_shots)} batch-eligible shots",
        }

    # 2. Match-cut / start+end frame precision → I2V
    if needs_match_cut:
        return {
            "pipeline": "i2v",
            "model": model_i2v,
            "reason": "Shot requires match-cut / start+end frame I2V control",
        }

    # 2.5 Prop manipulation → force still pipeline (ADR-P06)
    if plan_shot:
        prop_interaction = (shot.get("asset_data") or {}).get("prop_interaction", "none")
        if prop_interaction == "manipulated":
            return {
                "pipeline": "still",
                "model": model_still,
                "reason": "Manipulated prop interaction → still (Action Burst)",
            }

    # 3. ENV shots → still pipeline.
    # NOTE(review): _STATIC_SHOT_TYPES is not consulted here; only ENV shots land on "still".
    if is_env:
        return {
            "pipeline": "still",
            "model": model_still,
            "reason": "ENV shot → still",
        }

    # 4. Complex camera movement or long duration (ADR-M03: ENV or solo B-roll only).
    # ENV shots already returned above, so eligibility reduces to solo shots without dialogue.
    long_take = duration > 15
    complex_camera = any(cm in camera_movement for cm in _COMPLEX_CAMERA_TYPES)
    veo_eligible = num_chars <= 1 and not has_dialogue
    if long_take and veo_eligible:
        return {
            "pipeline": "t2v",
            "model": model_t2v_long,
            "reason": f"Long duration ({duration}s > 15s), solo/ENV → Veo",
        }
    if complex_camera and veo_eligible:
        return {
            "pipeline": "t2v",
            "model": model_t2v_long,
            "reason": f"Complex camera ({camera_movement}), solo/ENV → Veo",
        }
    # Long/complex camera but not eligible for the long-take model → default t2v model.
    if long_take or complex_camera:
        return {
            "pipeline": "t2v",
            "model": model_t2v_default,
            "reason": "Long/complex camera but multi-char or dialogue → Kling 3.0 (SeedDance not yet available)",
        }

    # 5. Dialogue or multi-character → I2V (keyframe first for identity consistency)
    if has_dialogue:
        return {
            "pipeline": "i2v",
            "model": model_i2v,
            "reason": "Dialogue shot → I2V (keyframe preserves identity)",
        }
    if num_chars >= 2:
        return {
            "pipeline": "i2v",
            "model": model_i2v,
            "reason": f"Multi-character ({num_chars} chars) → I2V (keyframe preserves identity)",
        }

    # 6. Default → I2V (keyframe → video preserves identity across shots)
    return {
        "pipeline": "i2v",
        "model": model_i2v,
        "reason": "Default single-character shot → I2V (keyframe preserves identity)",
    }


def is_batch_eligible(scene_shots: list[dict]) -> bool:
    """Check if a scene qualifies for multi-shot batching.

    Handles both plan and legacy shot formats (mixed lists are fine).

    Criteria (from Gemini consultation Round 3):
    1. 3-8 shots in the scene
    2. All shots share the same location (plan: asset_data.location_id,
       legacy: location; a missing location counts as "")
    3. No shots require I2V precision (match-cut/start+end frame)
    4. Max 2 unique characters, compared case-insensitively (cross-attention
       bleed causes wardrobe/face swapping with 3+; 9 ref slots = 2 chars ×
       3 refs + 1 scene + 2 pose)

    Sections or fields that are present but null are treated as empty
    instead of raising.
    """
    if not 3 <= len(scene_shots) <= 8:
        return False

    locations = set()
    all_chars: set[str] = set()

    for s in scene_shots:
        if _is_plan_shot(s):
            asset_data = s.get("asset_data") or {}
            locations.add(asset_data.get("location_id", ""))
            for c in asset_data.get("characters") or []:
                # Characters without a char_id all collapse into "" (one slot).
                all_chars.add(str(c.get("char_id") or "").lower())
            requires_match_cut = (s.get("routing_data") or {}).get(
                "narrative_requires_match_cut", False
            )
        else:
            locations.add(s.get("location", ""))
            for c in s.get("characters_in_shot") or []:
                all_chars.add(str(c).lower())
            requires_match_cut = s.get("start_frame") or s.get("end_frame")

        # Any single match-cut shot disqualifies the whole scene.
        if requires_match_cut:
            return False

    return len(locations) <= 1 and len(all_chars) <= 2


def partition_long_scene(
    scene_shots: list[dict], max_batch: int = 8, min_batch: int = 3,
) -> list[list[dict]]:
    """Split a scene longer than *max_batch* shots into ordered sub-batches.

    Shot order is preserved. Batches are filled greedily up to *max_batch*.
    If that leaves a final batch shorter than *min_batch*, shots are moved
    from the end of the previous batch into it, so every batch stays within
    ``[min_batch, max_batch]`` (3-8 with the defaults, matching
    is_batch_eligible). If borrowing would push the previous batch below
    *min_batch* (only possible for small max_batch values), the short tail
    is merged into it instead, which may then exceed *max_batch*.

    Scenes of at most *max_batch* shots are returned unchanged as a single
    batch (the input list itself), even when shorter than *min_batch*.

    Raises:
        ValueError: if max_batch < 1.
    """
    if max_batch < 1:
        raise ValueError(f"max_batch must be >= 1, got {max_batch}")
    if len(scene_shots) <= max_batch:
        return [scene_shots]

    batches = [
        scene_shots[start:start + max_batch]
        for start in range(0, len(scene_shots), max_batch)
    ]

    # len(scene_shots) > max_batch guarantees at least two batches here.
    # Merging a short tail into the previous batch could yield up to
    # max_batch + 2 shots (e.g. 17 → [8, 9]), which is_batch_eligible rejects,
    # so borrow shots from the previous batch instead.
    tail = batches[-1]
    if len(tail) < min_batch:
        deficit = min_batch - len(tail)
        prev = batches[-2]
        if len(prev) - deficit >= min_batch:
            batches[-2] = prev[:-deficit]
            batches[-1] = prev[-deficit:] + tail
        else:
            batches[-2] = prev + tail
            batches.pop()

    return batches


if __name__ == "__main__":
    # Ad-hoc CLI: print the EP001 scene plan with per-scene cost estimates.
    import sys
    from pathlib import Path

    sys.path.insert(0, str(Path(__file__).parent.parent))

    from recoil.pipeline._lib.recoil_bridge import load_storyboard, get_all_scenes

    sb = load_storyboard(1)
    scenes = get_all_scenes(sb)

    print(f"=== EP001 Scene Plan ({len(scenes)} scenes) ===\n")

    planned = plan_episode(scenes)

    # The per-tier cost table does not change between scenes, so read the
    # config once instead of re-importing and re-reading it every iteration.
    from recoil.core.paths import get_config
    tier_config = get_config().get("complexity_tiers", {})
    cost_map = {k: v.get("estimated_cost", 0.30) for k, v in tier_config.items()}
    total_cost = 0.0

    for i, scene in enumerate(planned):
        summary = get_scene_summary(scene)
        print(f"Scene {i}: {summary['total_shots']} shots "
              f"(ENV: {summary['env_shots']}, CHAR: {summary['char_shots']})")
        print(f"  Tiers: {summary['tiers']}")
        print(f"  Has ENV ref provider: {summary['has_scene_ref_provider']}")

        # Estimate cost; tiers missing from the config default to $0.30.
        scene_cost = sum(cost_map.get(s.get("_tier", "standard"), 0.30) for s in scene)
        total_cost += scene_cost
        print(f"  Estimated cost: ${scene_cost:.2f}")

        # Show generation order (first 5 shots only)
        for shot in scene[:5]:
            marker = "ENV→" if shot.get("_provides_scene_ref") else "    "
            shot_label = shot.get("shot_id", shot.get("id", "?"))
            # `or ""` guards against a present-but-null source_text before slicing.
            shot_name = shot.get("name", (shot.get("source_text") or "")[:40])
            print(f"  {marker} [{shot['_generation_order']}] {shot_label}: "
                  f"{shot_name} ({shot['_tier']})")
        if len(scene) > 5:
            print(f"  ... and {len(scene) - 5} more")
        print()

    print(f"Total estimated cost: ${total_cost:.2f}")
