"""
batcher.py — Group consecutive plan shots into multi-prompt batches.

Batching rules:
  - Max 6 shots per batch (Kling API limit)
  - Max 15s total duration per batch (Kling API limit)
  - Min 3s per shot (rounds up shorter shots)
  - Does NOT break on scene boundaries (multi-prompt handles transitions)
  - Does NOT break on character changes (pass character superset)
  - Stores _api_duration and _original_duration on each shot dict
"""

import logging

logger = logging.getLogger(__name__)


def batch_shots_for_multi_prompt(
    shots: list[dict],
    max_duration: int = 15,
    max_shots: int = 6,
    min_shot_duration: int = 3,
) -> list[list[dict]]:
    """Group consecutive plan shots into multi-prompt batches.

    Shots are packed greedily in their original order. A new batch starts
    whenever adding the next shot would push the current batch past
    ``max_duration`` seconds or past ``max_shots`` shots. A shot whose API
    duration alone exceeds ``max_duration`` cannot be split here. It still
    gets a batch of its own, and a warning is logged.

    Args:
        shots: List of plan shot dicts. The target duration is read from
            ``routing_data.target_editorial_duration_s``. When ``routing_data``
            or that key is missing (or explicitly ``None``), 5 seconds is
            assumed.
        max_duration: Maximum total duration per batch in seconds.
        max_shots: Maximum number of shots per batch.
        min_shot_duration: Minimum duration per shot (Kling API minimum).

    Returns:
        List of batches, where each batch is a list of shot dicts.
        Each shot dict gets two extra keys, written in place (the input
        dicts are mutated, not copied):
          _api_duration: the duration sent to the API (>= min_shot_duration)
          _original_duration: the original target duration from the plan
    """
    batches: list[list[dict]] = []
    current: list[dict] = []
    current_dur = 0

    for index, shot in enumerate(shots):
        # routing_data and the target may be absent OR explicitly None;
        # treat both as "no target" instead of crashing on .get() / max().
        routing = shot.get("routing_data") or {}
        target = routing.get("target_editorial_duration_s")
        if target is None:
            target = 5
        api_dur = max(min_shot_duration, target)

        if api_dur > max_duration:
            # Cannot be satisfied by batching; surface it instead of silently
            # emitting a batch that exceeds the API limit.
            logger.warning(
                "Shot at index %d has API duration %ss, exceeding max_duration=%ss; "
                "it will be placed in its own batch",
                index,
                api_dur,
                max_duration,
            )

        # Close the current batch if this shot would overflow either limit.
        if current and (current_dur + api_dur > max_duration or len(current) >= max_shots):
            batches.append(current)
            current = []
            current_dur = 0

        shot["_api_duration"] = api_dur
        shot["_original_duration"] = target
        current.append(shot)
        current_dur += api_dur

    if current:
        batches.append(current)

    return batches


def collect_batch_characters(batch: list[dict]) -> list[str]:
    """Union of all character IDs across all shots in a batch.

    Character IDs are read from ``asset_data.characters[*].char_id``. Shots
    whose ``asset_data`` or ``characters`` is missing or ``None`` contribute
    nothing. Empty or missing ``char_id`` values are skipped.

    Returns sorted list for deterministic ordering.
    """
    char_ids: set[str] = set()
    for shot in batch:
        # asset_data / characters may be absent or explicitly None.
        asset_data = shot.get("asset_data") or {}
        for char in asset_data.get("characters") or []:
            cid = char.get("char_id")
            if cid:
                char_ids.add(cid)
    return sorted(char_ids)


def batch_total_duration(batch: list[dict]) -> int:
    """Return the summed ``_api_duration`` of every shot in *batch*.

    Shots lacking ``_api_duration`` are counted as 5 seconds. An empty
    batch totals 0.
    """
    total = 0
    for entry in batch:
        total += entry.get("_api_duration", 5)
    return total


def filter_single_shot_fallbacks(
    batches: list[list[dict]],
) -> tuple[list[list[dict]], list[dict]]:
    """Split batches into multi-prompt batches and single-shot leftovers.

    Multi-prompt generation needs at least 2 shots, so any batch with fewer
    shots is flattened into the fallback list. That list is handled by the
    existing single-shot pipeline. Empty batches contribute nothing.

    Returns:
        (multi_prompt_batches, single_shot_fallbacks)
    """
    multi_prompt: list[list[dict]] = []
    fallbacks: list[dict] = []
    for group in batches:
        if len(group) < 2:
            fallbacks.extend(group)
            continue
        multi_prompt.append(group)
    return multi_prompt, fallbacks
