"""
coverage_density.py — Dramatic intensity scoring and coverage tier assignment.

Analyzes shot plans to determine which shots need editorial coverage (reactions,
cutaways, wide safeties) based on emotional intensity, structural position, and
routing complexity. Coverage density is proportional to dramatic intensity.

Consultation: consultations/starsend/coverage-density-dramatic-peaks/SYNTHESIS.md
"""

from collections import defaultdict
from dataclasses import dataclass, field


# ── Emotion Keyword Lexicon ──
# Arousal scores 0.0-1.0 mapped to tier boundaries:
#   Tier 0 (Valley): < 0.30
#   Tier 1 (Rising): 0.30 - 0.64
#   Tier 2 (Peak):   0.65 - 0.89
#   Tier 3 (Climax): >= 0.90
#
# The band a keyword sits in is the tier it reaches on emotion alone.
# compute_composite() returns max(emotion, amplified), so structural position
# and routing can only lift a shot above its keyword's band; the climax cap
# and manual overrides applied later in analyze_coverage_density() can still
# move the final tier.
# Keys are matched against the lower-cased words of a shot's emotion line
# (see score_emotion()); the highest score among the matched entries is used.

EMOTION_LEXICON = {
    # Climax territory (0.90+)
    "terror": 0.95, "rage": 0.95, "horror": 0.93, "fury": 0.92,
    "anguish": 0.91, "ecstasy": 0.90, "lethal": 0.92,

    # Peak territory (0.65-0.89)
    "fear": 0.85, "panic": 0.88, "shock": 0.80, "betrayal": 0.80,
    "defiance": 0.78, "confrontation": 0.76, "desperation": 0.82,
    "grief": 0.78, "anger": 0.82, "intensity": 0.75,
    "determination": 0.72, "despair": 0.75, "threat": 0.74,
    "dread": 0.73, "desire": 0.70, "urgency": 0.72,
    "guilt": 0.68, "jealousy": 0.68, "suspicion": 0.67,
    "revelation": 0.70, "longing": 0.65, "passion": 0.72,
    "terrifying": 0.85, "violent": 0.80, "imposing": 0.75,
    "piercing": 0.72, "overwhelming": 0.78, "brutal": 0.85,
    "helplessness": 0.76, "activation": 0.78,

    # Rising territory (0.30-0.64)
    "tension": 0.60, "anxiety": 0.58, "concern": 0.50,
    "unease": 0.48, "curiosity": 0.45, "frustration": 0.55,
    "resolve": 0.52, "resignation": 0.42, "wary": 0.50,
    "skepticism": 0.40, "hope": 0.45, "relief": 0.40,
    "surprise": 0.55, "confusion": 0.45, "discomfort": 0.48,
    "reluctance": 0.38, "amusement": 0.35, "affection": 0.38,
    "warmth": 0.35, "pride": 0.45, "sorrow": 0.50,
    "nostalgia": 0.38, "bitterness": 0.55, "impatience": 0.52,
    "paranoia": 0.55, "oppressive": 0.50, "trapped": 0.52,
    "relentless": 0.50, "temptation": 0.58, "greedy": 0.55,
    "realization": 0.60, "discovery": 0.62, "eerie": 0.55,
    "stakes": 0.58, "focus": 0.42, "dormant": 0.38,

    # Valley territory (<0.30)
    "calm": 0.15, "neutral": 0.10, "contemplation": 0.22,
    "indifference": 0.12, "serenity": 0.10, "quiet": 0.12,
    "stillness": 0.08, "peace": 0.10, "boredom": 0.15,
    "fatigue": 0.18, "acceptance": 0.20, "melancholy": 0.28,
}

# Bigram compounds (checked before unigrams). When a compound matches, its two
# words are not scored individually, so a compound can pull a phrase below its
# strongest single word (e.g. "quiet determination" 0.55 vs "determination"
# 0.72) as well as push it above ("cold fury" 0.93 vs "fury" 0.92).
EMOTION_BIGRAMS = {
    "quiet determination": 0.55,
    "rising dread": 0.80,
    "cold fury": 0.93,
    "bitter resignation": 0.45,
    "high stakes": 0.72,
    "cold release": 0.55,
    "inhuman focus": 0.70,
}

# Score used when the emotion line is empty or contains no known keyword.
# Note that it falls in the Rising band (0.30-0.64), not in Valley.
DEFAULT_EMOTION_SCORE = 0.35


# ── Tier config ──

# Inclusive lower bounds (compared with >=) of tiers 1, 2 and 3; any composite
# below the first boundary is tier 0. Consumed by composite_to_tier().
TIER_BOUNDARIES = (0.30, 0.65, 0.90)  # tier 0/1, 1/2, 2/3

# Default number of Tier 3 shots kept per scene by apply_climax_cap(); the
# lower-scoring excess climaxes are demoted to Tier 2.
MAX_CLIMAX_PER_SCENE = 2
# Maximum number of shots merged into a single CoverageMoment.
MAX_CLUSTER_SIZE = 4

# Coverage angles per tier. "multi_char" applies when a moment's shots feature
# two or more distinct character ids, otherwise "solo". Tiers 0 and 1 get no
# extra coverage.
COVERAGE_CONFIG = {
    0: {"multi_char": [], "solo": []},
    1: {"multi_char": [], "solo": []},
    2: {"multi_char": ["reaction"], "solo": ["cutaway"]},
    3: {"multi_char": ["reaction", "cutaway"], "solo": ["cutaway", "wide"]},
}


# ── Dataclass ──

@dataclass
class CoverageMoment:
    """A run of adjacent Tier 2+ shots that is covered as one unit.

    Every shot in the run shares a single list of extra coverage angles.
    """

    # Moment identifier, e.g. "M1".
    moment_id: str
    # Ids of all shots in the run, in plan order.
    shot_ids: list[str]
    # Id of the run's anchor shot (the strongest shot of the run).
    anchor_shot_id: str
    # Full shot dict of the anchor.
    anchor_shot: dict
    # Tier used to pick the coverage angles.
    tier: int
    # Mapping of shot_id -> composite score.
    composite_scores: dict
    # Names of the extra angles to shoot, e.g. ["reaction", "cutaway"].
    coverage_types: list[str]
    # Human-readable scene label such as "Scene 1"; empty when not set.
    scene_label: str = field(default="")


# ── Scoring functions ──

def score_emotion(emotion_text: str) -> float:
    """Score the arousal of a free-text emotion line on a 0.0-1.0 scale.

    The text is lower-cased and split into alphabetic words. Every
    non-letter character (comma, period, dash, slash, ...) acts as a
    separator, so ``"Fear."`` or ``"cold-fury"`` still match the lexicon.

    Two-word compounds from ``EMOTION_BIGRAMS`` are matched first. Words a
    compound consumes are not scored on their own, so a compound such as
    ``"quiet determination"`` (0.55) replaces ``"determination"`` (0.72)
    rather than competing with it. Remaining words are looked up in
    ``EMOTION_LEXICON``.

    Returns:
        The highest score among all matched compounds and words, or
        ``DEFAULT_EMOTION_SCORE`` when the text is empty or nothing matches.
    """
    if not emotion_text:
        return DEFAULT_EMOTION_SCORE

    # Treat every non-letter as a word separator. Splitting on commas only
    # let trailing punctuation ("dread.", "rage!") hide known keywords.
    normalized = "".join(ch if ch.isalpha() else " " for ch in emotion_text.lower())
    words = normalized.split()

    scores: list[float] = []
    consumed: set[int] = set()
    for i, (first, second) in enumerate(zip(words, words[1:])):
        bigram_score = EMOTION_BIGRAMS.get(f"{first} {second}")
        if bigram_score is not None:
            scores.append(bigram_score)
            consumed.update((i, i + 1))

    # Unigrams, skipping words already claimed by a compound.
    scores.extend(
        EMOTION_LEXICON[word]
        for i, word in enumerate(words)
        if i not in consumed and word in EMOTION_LEXICON
    )

    return max(scores, default=DEFAULT_EMOTION_SCORE)


def score_structure(shot_index: int, scene_shot_count: int) -> float:
    """Return how far through its scene a shot sits: 0.0 first, 1.0 last.

    A scene with zero or one shot has no meaningful position, so it gets a
    neutral 0.5.
    """
    last_index = scene_shot_count - 1
    if last_index <= 0:
        return 0.5
    return shot_index / last_index


def score_routing(routing_data: dict) -> float:
    """Turn routing metadata into an intensity bonus between 0.0 and 0.50.

    Bonuses:
      - 0.20 when two or more characters are in the shot,
      - 0.15 when the shot has dialogue,
      - 0.15 for a handheld, push-in or crane camera move.

    The sum is capped at 0.50. ENV-only shots always score 0.0.
    """
    # Read every signal first, then let ENV-only shots short-circuit.
    signals = (
        (routing_data.get("num_characters", 0) >= 2, 0.20),
        (bool(routing_data.get("has_dialogue", False)), 0.15),
        (routing_data.get("camera_complexity", "") in ("handheld", "push_in", "crane"), 0.15),
    )
    if routing_data.get("is_env_only", False):
        return 0.0  # ENV shots never get coverage

    bonus = 0.0
    for active, weight in signals:
        if active:
            bonus += weight
    return min(bonus, 0.50)


def compute_composite(e_score: float, s_score: float, r_score: float) -> float:
    """Blend emotion, structure and routing scores into one composite.

    The emotion score is amplified by ``1 + s_score * 0.5`` (up to +50% for
    ``s_score`` in 0.0-1.0). Then 15% of the routing score is added. The
    result is never lower than ``e_score`` itself:

        composite = max(e_score, e_score * (1 + s_score * 0.5) + r_score * 0.15)

    This envelope means extreme emotion can never be dragged down by
    structure.
    """
    structural_gain = 1 + s_score * 0.5
    amplified = e_score * structural_gain + r_score * 0.15
    # Keep the raw emotion score unless amplification strictly improves on it.
    return amplified if amplified > e_score else e_score


def composite_to_tier(composite: float) -> int:
    """Return the tier (0-3) whose band contains ``composite``.

    ``TIER_BOUNDARIES`` holds the inclusive lower bounds of tiers 1, 2 and 3.
    The highest boundary reached wins; scores below all of them are tier 0.
    """
    for tier in (3, 2, 1):
        if composite >= TIER_BOUNDARIES[tier - 1]:
            return tier
    return 0


# ── Scene boundary detection ──

def detect_scene_boundaries(shots: list[dict]) -> list[list[dict]]:
    """Split a shot list into scenes at ``location_id`` changes.

    Shots are walked in order and the first shot always opens the first
    scene. After that, a shot opens a new scene only when its
    ``asset_data.location_id`` is non-empty and differs from the location
    of the scene currently open. Shots with a missing or empty location
    never open a scene; they join the scene that is already open.

    Routing data is not consulted here. An ENV-only establishing shot lands
    in the following scene only if it already carries that scene's
    ``location_id``.

    Returns:
        A list of scenes, each a non-empty list of the original shot dicts
        in plan order. Empty input yields ``[]``.
    """
    scenes: list[list[dict]] = []
    scene_loc = ""
    for position, shot in enumerate(shots):
        loc = shot.get("asset_data", {}).get("location_id", "")
        opens_scene = position == 0 or (bool(loc) and loc != scene_loc)
        if opens_scene:
            scenes.append([shot])
            scene_loc = loc
        else:
            scenes[-1].append(shot)
    return scenes


# ── Climax cap ──

def apply_climax_cap(
    tier_map: dict[str, int],
    score_map: dict[str, float],
    scenes: list[list[dict]],
    max_per_scene: int = MAX_CLIMAX_PER_SCENE,
) -> dict[str, int]:
    """Limit the number of Tier 3 (climax) shots in each scene.

    In every scene with more than ``max_per_scene`` tier-3 shots, the
    ``max_per_scene`` shots with the highest ``score_map`` value keep tier 3
    and the rest are set to tier 2. Missing scores count as 0. Ties keep
    the shot that appears earlier in the scene, because the sort is stable.

    ``tier_map`` is modified in place and also returned.
    """
    for scene_shots in scenes:
        climax_ids = [
            shot["shot_id"]
            for shot in scene_shots
            if tier_map.get(shot["shot_id"], 0) == 3
        ]
        if len(climax_ids) <= max_per_scene:
            continue
        by_score = sorted(climax_ids, key=lambda sid: score_map.get(sid, 0), reverse=True)
        tier_map.update(dict.fromkeys(by_score[max_per_scene:], 2))
    return tier_map


# ── Moment clustering ──

def cluster_coverage_moments(
    shots: list[dict],
    tier_map: dict[str, int],
    score_map: dict[str, float],
    scenes: list[list[dict]],
) -> list[CoverageMoment]:
    """Collapse runs of consecutive Tier 2+ shots into coverage moments.

    Scenes are processed independently, so a moment never spans two scenes.
    Within a scene, a shot below Tier 2 closes the open cluster and is not
    itself covered. A Tier 2+ shot joins the open cluster unless one of the
    following starts a new moment (see ``_extends_cluster``):

    - its ``location_id`` differs from the previous clustered shot;
    - its ``camera_side`` differs from the previous clustered shot;
    - both shots name characters and the two sets share no character id;
    - the cluster already holds ``MAX_CLUSTER_SIZE`` shots.

    Args:
        shots: Not read. Kept for backward compatibility; shots are taken
            from ``scenes``.
        tier_map: shot_id -> tier (missing ids count as tier 0).
        score_map: shot_id -> composite score, used to pick each anchor.
        scenes: Shots grouped by scene, in plan order.

    Returns:
        Moments in plan order, numbered ``M1``, ``M2``, ... across all scenes.
    """
    moments: list[CoverageMoment] = []

    for scene_idx, scene_shots in enumerate(scenes):
        current_cluster: list[dict] = []

        for shot in scene_shots:
            if tier_map.get(shot["shot_id"], 0) < 2:
                # Sub-peak shot: close whatever cluster is open.
                moments.extend(
                    _finalize_cluster(current_cluster, tier_map, score_map, len(moments), scene_idx)
                )
                current_cluster = []
                continue

            if current_cluster and not _extends_cluster(current_cluster, shot):
                moments.extend(
                    _finalize_cluster(current_cluster, tier_map, score_map, len(moments), scene_idx)
                )
                current_cluster = []
            current_cluster.append(shot)

        moments.extend(
            _finalize_cluster(current_cluster, tier_map, score_map, len(moments), scene_idx)
        )

    return moments


def _character_ids(shot: dict) -> set[str]:
    """Return the upper-cased, non-blank character ids listed on a shot.

    Characters may be dicts carrying ``char_id`` or bare values. Blank or
    missing ids (including ``None``) are dropped, so unnamed characters never
    count as an overlap between shots or as a distinct character.
    """
    ids: set[str] = set()
    for c in shot.get("asset_data", {}).get("characters", []):
        raw = c.get("char_id", "") if isinstance(c, dict) else c
        cid = ("" if raw is None else str(raw)).upper()
        if cid:
            ids.add(cid)
    return ids


def _extends_cluster(cluster: list[dict], shot: dict) -> bool:
    """Return True if ``shot`` may join ``cluster`` right after its last shot."""
    if len(cluster) >= MAX_CLUSTER_SIZE:
        return False
    prev = cluster[-1]
    if (
        shot.get("asset_data", {}).get("location_id", "")
        != prev.get("asset_data", {}).get("location_id", "")
    ):
        return False
    if (
        shot.get("spatial_data", {}).get("camera_side", "")
        != prev.get("spatial_data", {}).get("camera_side", "")
    ):
        return False
    prev_chars = _character_ids(prev)
    shot_chars = _character_ids(shot)
    # Unknown casting on either side is not evidence of a cut to new characters.
    if prev_chars and shot_chars:
        return bool(prev_chars & shot_chars)
    return True


def _finalize_cluster(
    cluster: list[dict],
    tier_map: dict[str, int],
    score_map: dict[str, float],
    base_counter: int,
    scene_idx: int,
) -> list[CoverageMoment]:
    """Convert a cluster of shots into a CoverageMoment.

    Returns ``[]`` for an empty cluster, otherwise a one-element list, so
    callers can ``extend`` unconditionally. The moment is numbered
    ``M{base_counter + 1}``. It is anchored on the highest-scoring shot;
    ``max`` keeps the first of equal scores. Its tier is the highest tier in
    the cluster, and its coverage angles come from ``COVERAGE_CONFIG`` for
    that tier. The "multi_char" variant is used when two or more distinct
    character ids appear anywhere in the cluster.
    """
    if not cluster:
        return []

    anchor = max(cluster, key=lambda s: score_map.get(s["shot_id"], 0))
    max_tier = max(tier_map.get(s["shot_id"], 0) for s in cluster)

    all_chars: set[str] = set()
    for s in cluster:
        all_chars |= _character_ids(s)

    config_key = "multi_char" if len(all_chars) >= 2 else "solo"
    coverage_types = list(COVERAGE_CONFIG.get(max_tier, {}).get(config_key, []))

    return [CoverageMoment(
        moment_id=f"M{base_counter + 1}",
        shot_ids=[s["shot_id"] for s in cluster],
        anchor_shot_id=anchor["shot_id"],
        anchor_shot=anchor,
        tier=max_tier,
        composite_scores={s["shot_id"]: score_map.get(s["shot_id"], 0) for s in cluster},
        coverage_types=coverage_types,
        scene_label=f"Scene {scene_idx + 1}",
    )]


# ── Main entry point ──

def analyze_coverage_density(
    shots: list[dict],
    overrides: dict[str, int] | None = None,
) -> tuple[dict[str, int], dict[str, float], list[CoverageMoment]]:
    """Analyze shots and return tier assignments + coverage moments.

    Pipeline:
      1. Group shots into scenes.
      2. Score every shot from emotion, structural position within its
         scene, and routing.
      3. Map scores to tiers.
      4. Apply manual overrides.
      5. Cap climaxes per scene.
      6. Cluster Tier 2+ shots into coverage moments.

    Manual overrides are authoritative. A hand-tagged shot always ends with
    exactly the tier it was given. When the per-scene climax cap must
    demote shots, it demotes automatically scored climaxes before any
    hand-tagged one.

    Args:
        shots: List of shot dicts from ep_NNN_plan.json
        overrides: Optional dict of shot_id -> tier override (manual tagging).
            Entries for unknown shot ids or tiers outside 0-3 are ignored.

    Returns:
        (tier_map, score_map, moments)
        - tier_map: {shot_id: tier_int}
        - score_map: {shot_id: composite_float}, rounded to 3 decimals
          (tiers are derived from the unrounded composite)
        - moments: list of CoverageMoment with coverage types
    """
    if not shots:
        return {}, {}, []

    scenes = detect_scene_boundaries(shots)

    # shot_id -> (index_in_scene, scene_size), for structural scoring.
    scene_index_map: dict[str, tuple[int, int]] = {}
    for scene_shots in scenes:
        for i, shot in enumerate(scene_shots):
            scene_index_map[shot["shot_id"]] = (i, len(scene_shots))

    tier_map: dict[str, int] = {}
    score_map: dict[str, float] = {}

    for shot in shots:
        sid = shot["shot_id"]
        routing = shot.get("routing_data", {})

        # ENV-only shots are always Valley
        if routing.get("is_env_only", False):
            tier_map[sid] = 0
            score_map[sid] = 0.0
            continue

        emotion_line = shot.get("prompt_data", {}).get("prompt_skeleton", {}).get("emotion_line", "")
        idx_in_scene, scene_size = scene_index_map.get(sid, (0, 1))

        composite = compute_composite(
            score_emotion(emotion_line),
            score_structure(idx_in_scene, scene_size),
            score_routing(routing),
        )
        score_map[sid] = round(composite, 3)
        tier_map[sid] = composite_to_tier(composite)

    # Keep only overrides that name a known shot and a valid tier.
    manual_tiers = {
        sid: tier
        for sid, tier in (overrides or {}).items()
        if sid in tier_map and 0 <= tier <= 3
    }
    tier_map.update(manual_tiers)

    # Rank hand-tagged climaxes above every scored one, so the cap trims
    # automatic picks first. Previously a low-scoring manual Tier 3 was the
    # first shot the cap demoted. score_map itself is left untouched.
    cap_ranking = dict(score_map)
    cap_ranking.update({sid: float("inf") for sid, tier in manual_tiers.items() if tier == 3})
    tier_map = apply_climax_cap(tier_map, cap_ranking, scenes)

    # Hand-tagged climaxes alone may still exceed the cap. Re-assert every
    # manual tag so overrides are final.
    tier_map.update(manual_tiers)

    moments = cluster_coverage_moments(shots, tier_map, score_map, scenes)
    return tier_map, score_map, moments
