# recoil/pipeline/_lib/scene_clusterer.py
"""Scene-clusterer for r2v_multi batching (Phase 4 — §26).

Groups a CanonicalShot list into multi-shot r2v batches respecting:
  - shared location (no batch crosses location boundaries)
  - scene_index continuity (adjacent scene_index values may share a batch;
    a jump of more than 1 between shots closes the batch)
  - max_batch_size cap (default 4 for pacing; fal.ai r2v_multi accepts up to 6)
  - min_batch_size threshold (below this we fall back to per-shot i2v)
  - drastic angle / scale jumps that defeat continuity (heuristic)

The output Batch list is consumed by EpisodeRunner.run_episode_batches.
Each Batch.shots is preserved in declaration order — order matters for
the r2v_multi prompt's [Xs-Ys] timestamp annotations.

Consumes:  list[CanonicalShot] (Phase 0).
Produces:  list[Batch].

Data flow (also see BUILD_SPEC Phase 4 diagram):
    CanonicalPlan.shots
        → cluster_shots_into_batches(...)
        → list[Batch]
        → per-Batch Beat (or per-shot Beat if Batch is below threshold)
        → build_dispatch_payload(modality="r2v_multi" | "video_i2v")
        → dispatch
        → segment-frame extraction in step_runner.execute_pass
        → frames bound back to individual Beats inside the Batch.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Iterable

from recoil.pipeline._lib.plan_loader import CanonicalShot

logger = logging.getLogger(__name__)

# Narrative-pacing cap on shots per batch (JT 2026-05-17 review). The fal.ai
# r2v_multi hard cap is 6 — callers can raise the limit via the
# `max_batch_size` kwarg if needed.
DEFAULT_MAX_BATCH_SIZE = 4
# Batches with fewer shots than this are flagged to fall back to per-shot i2v.
DEFAULT_MIN_BATCH_SIZE = 3


@dataclass
class Batch:
    """One r2v_multi dispatch unit: an ordered group of CanonicalShots.

    Only ``batch_id``, ``shots`` and ``shared_location_id`` are required.
    The remaining fields are derived summaries and default to empty / zero
    values when not supplied.
    """

    # Stable identifier such as "BATCH_001" (1-indexed across the episode).
    batch_id: str
    # Member shots, kept in dispatch order.
    shots: list[CanonicalShot]
    # Location id that every shot in the batch shares.
    shared_location_id: str
    # Union of the characters appearing in the shots (used for ref resolution).
    shared_characters: list[str] = field(default_factory=list)
    # Sum of the per-shot durations (feeds the fal.ai duration field).
    total_duration_s: float = 0.0
    # (lowest, highest) scene_index found among the shots.
    scene_index_range: tuple[int, int] = (0, 0)
    # True when the batch holds fewer shots than min_batch_size and should be
    # dispatched per shot instead.
    below_threshold: bool = False


def cluster_shots_into_batches(
    shots: Iterable[CanonicalShot],
    *,
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    min_batch_size: int = DEFAULT_MIN_BATCH_SIZE,
) -> list[Batch]:
    """Cluster a CanonicalShot list into r2v_multi-compatible batches.

    Shots are walked once, in order; each shot either joins the open batch
    or closes it and starts a new one.

    Break conditions (any one closes the current batch):
      1. shot.location_id differs from the running batch.
      2. shot.scene_index differs by >1 from the most recent integer
         scene_index seen in the running batch (allows adjacent scenes
         inside the same beat; >1 indicates a narrative jump like
         SH05→SH06 where the scene changes hard). Shots whose scene_index
         is not an int neither trigger this check nor reset the reference
         value, so a jump across such a shot is still detected.
      3. len(current_batch) reaches max_batch_size.
      4. shot.shot_type implies a drastic angle jump relative to the
         previous shot (heuristic: WS following ECU, or vice versa, breaks
         continuity).

    Args:
        shots: Iterable of CanonicalShot (order preserved).
        max_batch_size: Cap per batch. Default 4 (JT narrative-pacing
                        preference). fal.ai r2v_multi accepts up to 6 —
                        pass max_batch_size=6 to use the full capacity.
        min_batch_size: Below this the batch is flagged
                        `below_threshold=True` so the orchestrator falls
                        back to per-shot video_i2v dispatch.

    Returns:
        list[Batch] in dispatch order (empty when `shots` is empty).
    """
    batches: list[Batch] = []
    current: list[CanonicalShot] = []
    current_loc: str | None = None
    # Last *integer* scene_index seen in the open batch. Kept separate from
    # "the previous shot's scene_index" so that a shot with a missing or
    # malformed scene_index cannot hide a hard scene jump.
    anchor_scene_idx: int | None = None

    def _flush() -> None:
        """Close the open batch (if any) and reset the running state."""
        nonlocal current, current_loc, anchor_scene_idx
        if not current:
            return
        batches.append(
            _build_batch(current, idx=len(batches) + 1, min_batch_size=min_batch_size)
        )
        # Rebind rather than clear(): the flushed Batch keeps the old list.
        current = []
        current_loc = None
        anchor_scene_idx = None

    for shot in shots:
        loc = shot.location_id
        sidx = shot.scene_index
        stype = shot.shot_type

        if current:
            scene_jump = (
                isinstance(sidx, int)
                and anchor_scene_idx is not None
                and abs(sidx - anchor_scene_idx) > 1
            )
            if (
                loc != current_loc
                or scene_jump
                or _is_drastic_angle_jump(current[-1].shot_type, stype)
                or len(current) >= max_batch_size
            ):
                _flush()

        if not current:
            # This shot opens a batch and fixes the batch's location.
            current_loc = loc
        current.append(shot)
        if isinstance(sidx, int):
            anchor_scene_idx = sidx

    _flush()

    for b in batches:
        if b.below_threshold:
            logger.warning(
                "scene_clusterer: batch %s only has %d shot(s) "
                "(min=%d) — falling back to per-shot video_i2v.",
                b.batch_id,
                len(b.shots),
                min_batch_size,
            )
    return batches


def _character_key(entry: object) -> str:
    """Return the stripped, upper-cased character id of one entry ('' if none)."""
    if isinstance(entry, str):
        return entry.strip().upper()
    if isinstance(entry, dict):
        ident = entry.get("char_id") or entry.get("name") or ""
    else:
        ident = getattr(entry, "char_id", None) or getattr(entry, "name", None) or ""
    return str(ident).strip().upper()


def _build_batch(
    shots: list[CanonicalShot],
    *,
    idx: int,
    min_batch_size: int,
) -> Batch:
    """Summarise a non-empty run of shots into a Batch.

    Args:
        shots: Shots of the batch in dispatch order; must be non-empty
               (the first shot supplies the location id).
        idx: Batch number, rendered zero-padded into the id ("BATCH_001").
        min_batch_size: Batches with fewer shots than this are flagged
                        `below_threshold`.

    Returns:
        A Batch carrying the shots plus the de-duplicated character list,
        the summed duration and the (min, max) scene_index range.
    """
    location = shots[0].location_id or ""
    # A dict is used as an insertion-ordered set: first appearance wins.
    ordered_chars: dict[str, None] = {}
    duration_total = 0.0
    scene_indices: list[int] = []

    for member in shots:
        for entry in member.characters or []:
            # E2 fix (R4) — defensive coercion. Character entries are str in
            # canonical plans, but the real ep_001_plan.json sometimes holds
            # dict entries ({"char_id": "JADE", ...}) or None placeholders
            # from planner glitches; every shape is reduced to an
            # upper-cased str id.
            if entry is None:
                continue
            key = _character_key(entry)
            if key:
                ordered_chars.setdefault(key, None)
        if member.duration_s is not None:
            duration_total += float(member.duration_s)
        if isinstance(member.scene_index, int):
            scene_indices.append(member.scene_index)

    if scene_indices:
        scene_span = (min(scene_indices), max(scene_indices))
    else:
        # No usable scene_index on any shot.
        scene_span = (0, 0)

    return Batch(
        batch_id=f"BATCH_{idx:03d}",
        shots=shots,
        shared_location_id=location,
        shared_characters=list(ordered_chars),
        total_duration_s=duration_total,
        scene_index_range=scene_span,
        below_threshold=len(shots) < min_batch_size,
    )


def single_batch_from_shots(
    shots: Iterable[CanonicalShot],
    *,
    hard_cap: int = 6,
) -> list[Batch]:
    """Wrap the given shots in a single r2v_multi batch.

    The caller vouches that the shots form one deliberate creative unit
    (e.g. the shots of a single coverage pass). No location / scene_index /
    angle heuristics are applied, and the batch is never flagged for the
    min-size fallback to per-shot i2v — Flora bills per run, so a pass is
    one run whether it holds 1 shot or several rapid cuts.

    The only limit is fal.ai's r2v_multi per-call cap (`hard_cap`, default
    6). When the shots exceed it, they are handed to
    `cluster_shots_into_batches` with `max_batch_size=hard_cap` instead of
    being silently truncated; that path uses the clusterer's default
    `min_batch_size`, so its batches can be flagged `below_threshold`.

    Returns:
        An empty list for no shots, otherwise one Batch (id "BATCH_001"),
        or the clustered batches when the cap is exceeded.
    """
    unit = [*shots]
    count = len(unit)
    if count == 0:
        return []
    if count <= hard_cap:
        return [_build_batch(unit, idx=1, min_batch_size=1)]
    return cluster_shots_into_batches(unit, max_batch_size=hard_cap)


# For each opening shot type, the following shot types that count as a
# continuity-breaking jump. Keys and members are upper-case; lookups are
# normalised to match. (Presumably WS/LS = wide/long shot and ECU = extreme
# close-up — confirm against the planner's shot_type vocabulary.)
_DRASTIC_JUMP: dict[str, frozenset[str]] = {
    "WS": frozenset({"ECU"}),
    "LS": frozenset({"ECU"}),
    "ECU": frozenset({"WS", "LS"}),
}


def _normalize_shot_type(shot_type: object) -> str:
    """Return the stripped, upper-cased shot type, or '' for non-string input."""
    return shot_type.strip().upper() if isinstance(shot_type, str) else ""


def _is_drastic_angle_jump(opening: str | None, next_type: str | None) -> bool:
    """Tell whether cutting from `opening` to `next_type` is a drastic jump.

    Both values are compared case-insensitively and ignoring surrounding
    whitespace, so "ws" → " ECU " is treated like "WS" → "ECU".

    Args:
        opening: shot_type of the earlier shot (may be None or empty).
        next_type: shot_type of the shot that follows (may be None or empty).

    Returns:
        True only when the pair is listed in `_DRASTIC_JUMP`; False when
        either value is missing, empty, not a string, or unlisted.
    """
    prev_key = _normalize_shot_type(opening)
    next_key = _normalize_shot_type(next_type)
    if not prev_key or not next_key:
        return False
    return next_key in _DRASTIC_JUMP.get(prev_key, frozenset())


# Names exported by `from ... import *`; the underscore-prefixed helpers in
# this module are not listed.
__all__ = [
    "Batch",
    "cluster_shots_into_batches",
    "single_batch_from_shots",
    "DEFAULT_MAX_BATCH_SIZE",
    "DEFAULT_MIN_BATCH_SIZE",
]
