"""Strategy-neutral grouping abstraction for episode generation."""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Literal

from recoil.core.model_profiles import get_profile
from recoil.pipeline._lib.derivation_sha import shotset_hash
from recoil.pipeline._lib.plan_loader import CanonicalPlan, CanonicalShot
from recoil.pipeline._lib.scene_clusterer import cluster_shots_into_batches

if TYPE_CHECKING:
    from recoil.pipeline.orchestrator.coverage_planner import CoveragePass

# Public names of the grouping strategies registered via @_register in this module.
GroupingStrategyName = Literal["continuity", "coverage", "oner", "solo"]
# Generation modality tag attached to each Group. The assemblers here choose
# "video_i2v" for solo shots, single-shot oner chunks and below-threshold
# continuity batches, and "r2v_multi" otherwise.
GroupModality = Literal["r2v_multi", "video_i2v"]
# Model whose profile is consulted for oner caps when a shot has no video_model.
ONER_DEFAULT_MODEL_ID = "seeddance-2.0"
# Default upper bound (seconds) on the summed effective duration of one oner
# chunk; model profiles can only lower it (see _oner_caps).
ONER_DEFAULT_MAX_DURATION_S = 15.0
# Default per-shot duration floor (seconds) used when sizing oner chunks; model
# profiles can only raise it. Shots with no duration count as this value.
ONER_DEFAULT_MIN_DURATION_S = 4.0
# Default maximum number of shots in one oner chunk; model profiles can only lower it.
ONER_DEFAULT_MAX_SHOTS = 6
# Prompt text attached to every oner Group through Group.prompt_directive.
ONER_PROMPT_DIRECTIVE = (
    "### Format C — Oner (Continuous Take)\n"
    "**What:** Generate one continuous 10-15 second shot. Extract edit points in post.\n"
    "**Like:** Shooting a long take on a Steadicam. Find the cuts later.\n"
    "**Use for:** Atmospheric moments, set pieces, climactic sequences where performance continuity matters.\n"
    "**Output:** One continuous clip. Editor finds the gold inside it."
)


@dataclass(frozen=True)
class GroupingIdentity:
    """Strategy-level identity of one shot group.

    NOTE(review): ``frozen=True`` blocks attribute reassignment, but
    ``shot_ids`` is a ``list``. The generated ``__hash__`` therefore raises
    ``TypeError`` when called, and the list itself can still be mutated in place.
    """

    # Name of the strategy that produced the group.
    strategy: GroupingStrategyName
    # Ordinal assigned by the strategy: 1-based for continuity/oner, parsed from
    # the pass id (or 1-based position) for coverage, always 0 for solo.
    ordinal: int
    # Member shot ids in the order the strategy listed them.
    shot_ids: list[str]
    # Originating coverage pass id; only the coverage strategy sets it.
    source_pass_id: str | None = None

    def to_dict(self) -> dict:
        """Serialize to a plain dict, including a ``shotset_hash`` of the shot ids.

        ``shot_ids`` in the result is a copy, so mutating it does not affect
        this identity.
        """
        ids_copy = list(self.shot_ids)
        digest = shotset_hash(self.shot_ids)
        return dict(
            strategy=self.strategy,
            ordinal=self.ordinal,
            shot_ids=ids_copy,
            source_pass_id=self.source_pass_id,
            shotset_hash=digest,
        )


@dataclass
class Group:
    """One unit of generation work produced by a grouping strategy.

    Within this module, only the coverage strategy sets ``coverage_pass``,
    ``generation_config`` and ``element_config``, and only the oner strategy
    sets ``prompt_directive``. All other strategies leave them at their defaults.
    """

    # Strategy / ordinal / shot-id identity of this group.
    identity: GroupingIdentity
    # Shots in the group. For coverage this may be a subset of identity.shot_ids
    # (ids not found among the input shots are dropped).
    shots: list[CanonicalShot]
    # Group label: clusterer batch id, "ONER_NNN", coverage pass id, or the shot id (solo).
    scene_id: str
    # Generation modality chosen by the assembler.
    modality: GroupModality
    # Source coverage pass (coverage strategy only).
    coverage_pass: CoveragePass | None = None
    # Shallow copy of the pass's generation_config (coverage strategy only).
    generation_config: dict = field(default_factory=dict)
    # Shallow copy of the pass's element_config (coverage strategy only).
    element_config: dict = field(default_factory=dict)
    # Extra prompt text for generation (the oner strategy sets ONER_PROMPT_DIRECTIVE).
    prompt_directive: str | None = None


@dataclass
class GroupingContext:
    """Episode-level inputs handed to every grouping assembler.

    Among the assemblers in this module, only the coverage strategy reads it;
    the others discard it.
    """

    # Project identifier; forwarded to the coverage planner.
    project: str
    # Episode number; forwarded to the coverage planner.
    episode: int
    # Canonical plan for the episode (not read by the assemblers in this module).
    canonical_plan: CanonicalPlan
    # Pre-selected coverage passes; when empty, the coverage strategy builds its own.
    selected_coverage_passes: list[CoveragePass]
    # Forwarded to coverage_planner.build_passes; semantics are defined there.
    tier_map: dict[str, int]
    # Forwarded to coverage_planner.build_passes; None presumably means
    # "no override" — confirm against coverage_planner.
    wildcard_override: bool | None


# Signature every grouping assembler implements: (shots, context) -> groups.
AssembleFn = Callable[[list[CanonicalShot], GroupingContext], list[Group]]


@dataclass(frozen=True)
class GroupingStrategy:
    """A named grouping assembler, as stored in ``GROUPING_REGISTRY``."""

    # Public strategy name; also the registry key.
    name: str
    # Callable that turns shots plus context into groups.
    assemble: AssembleFn


# Strategy name -> strategy. Populated at import time by the @_register
# decorators applied to the assemblers in this module.
GROUPING_REGISTRY: dict[str, GroupingStrategy] = {}


def _register(name: str) -> Callable[[AssembleFn], AssembleFn]:
    """Build a decorator that records an assembler in ``GROUPING_REGISTRY``.

    The decorated function is returned unchanged, so it stays directly
    callable. Registering the same ``name`` again replaces the earlier entry.
    """

    def _record(assembler: AssembleFn) -> AssembleFn:
        strategy = GroupingStrategy(name=name, assemble=assembler)
        GROUPING_REGISTRY[name] = strategy
        return assembler

    return _record


def get_grouping(name: str) -> GroupingStrategy:
    """Return the registered grouping strategy called ``name``.

    Raises:
        ValueError: ``name`` is not registered. The message lists the
            registered names in sorted order, and the lookup ``KeyError`` is
            chained as the cause.
    """
    try:
        strategy = GROUPING_REGISTRY[name]
    except KeyError as missing:
        available = ", ".join(sorted(GROUPING_REGISTRY))
        raise ValueError(
            f"Unknown grouping strategy {name!r}. Available: {available}"
        ) from missing
    return strategy


@_register("continuity")
def _assemble_continuity(
    shots: list[CanonicalShot],
    ctx: GroupingContext,
) -> list[Group]:
    """Group shots by the scene clusterer's batches.

    Each batch from ``cluster_shots_into_batches`` becomes one group. Ordinals
    are 1-based in batch order and ``scene_id`` is the batch id. Batches flagged
    ``below_threshold`` use ``"video_i2v"``; all others use ``"r2v_multi"``.
    ``ctx`` is not used.
    """
    del ctx  # the clusterer works from the shots alone
    groups: list[Group] = []
    for position, batch in enumerate(cluster_shots_into_batches(shots), start=1):
        member_ids = [member.shot_id for member in batch.shots]
        identity = GroupingIdentity(
            strategy="continuity",
            ordinal=position,
            shot_ids=member_ids,
        )
        modality = "r2v_multi" if not batch.below_threshold else "video_i2v"
        groups.append(
            Group(
                identity=identity,
                shots=list(batch.shots),
                scene_id=batch.batch_id,
                modality=modality,
            )
        )
    return groups


@_register("oner")
def _assemble_oner(
    shots: list[CanonicalShot],
    ctx: GroupingContext,
) -> list[Group]:
    """Build continuous-take ("oner") groups.

    Shots are first split into runs of consecutive equal ``scene_index``. Each
    run is then split further to respect the oner duration and shot-count caps.
    Every resulting chunk becomes one group with:

    - a global 1-based ordinal across all scenes;
    - ``scene_id`` set to ``ONER_NNN``;
    - ``ONER_PROMPT_DIRECTIVE`` as its prompt directive;
    - modality ``"video_i2v"`` for single-shot chunks and ``"r2v_multi"`` otherwise.

    ``ctx`` is not used.
    """
    del ctx  # oner grouping needs nothing beyond the shots themselves
    chunks = [
        chunk
        for scene_run in _scene_runs(shots)
        for chunk in _split_oner_scene_for_caps(scene_run)
    ]
    return [
        Group(
            identity=GroupingIdentity(
                strategy="oner",
                ordinal=position,
                shot_ids=[shot.shot_id for shot in chunk],
            ),
            shots=list(chunk),
            scene_id=f"ONER_{position:03d}",
            modality="r2v_multi" if len(chunk) != 1 else "video_i2v",
            prompt_directive=ONER_PROMPT_DIRECTIVE,
        )
        for position, chunk in enumerate(chunks, start=1)
    ]


@_register("coverage")
def _assemble_coverage(
    shots: list[CanonicalShot],
    ctx: GroupingContext,
) -> list[Group]:
    """Build one group per coverage pass.

    Passes come from ``ctx.selected_coverage_passes``. When that is empty, they
    are derived from ``shots`` via ``_build_coverage_passes``.

    For each pass:

    - The identity records every id returned by ``_coverage_shot_ids``.
    - ``Group.shots`` keeps only the ids present in ``shots``; unknown ids are
      dropped silently.
    - The ordinal is parsed from the pass id when possible, otherwise it is the
      pass's 1-based position.
    - ``source_pass_id`` is the pass's own ``source_pass_id`` when truthy,
      otherwise its ``pass_id``.
    - Modality is always ``"r2v_multi"``, and the pass's generation and element
      configs are shallow-copied.
    """
    coverage_passes = list(ctx.selected_coverage_passes)
    if not coverage_passes:
        # Nothing pre-selected: ask the coverage planner to derive passes.
        planner_input = [_shot_to_coverage_dict(shot) for shot in shots]
        coverage_passes = _build_coverage_passes(
            planner_input,
            ctx.project,
            ctx.episode,
            ctx.tier_map,
            ctx.wildcard_override,
        )

    known_shots = {shot.shot_id: shot for shot in shots}
    result: list[Group] = []
    for position, cov_pass in enumerate(coverage_passes, start=1):
        member_ids = _coverage_shot_ids(cov_pass)
        members = [known_shots[sid] for sid in member_ids if sid in known_shots]
        ordinal = _coverage_ordinal(cov_pass.pass_id, fallback=position)
        origin_id = getattr(cov_pass, "source_pass_id", None) or cov_pass.pass_id
        identity = GroupingIdentity(
            strategy="coverage",
            ordinal=ordinal,
            shot_ids=member_ids,
            source_pass_id=origin_id,
        )
        result.append(
            Group(
                identity=identity,
                shots=members,
                scene_id=cov_pass.pass_id,
                modality="r2v_multi",
                coverage_pass=cov_pass,
                generation_config=dict(cov_pass.generation_config or {}),
                element_config=dict(cov_pass.element_config or {}),
            )
        )
    return result


def _build_coverage_passes(
    shots_dicts: list[dict],
    project: str,
    episode: int,
    tier_map: dict[str, int],
    wildcard_override: bool | None,
):
    """Delegate pass construction to ``coverage_planner.build_passes``.

    The planner module is imported at call time instead of at module load.
    Presumably this avoids an import cycle, since ``coverage_planner`` is only
    imported under ``TYPE_CHECKING`` at the top of this file — confirm.
    """
    from recoil.pipeline.orchestrator import coverage_planner

    planner_args = (shots_dicts, project, episode, tier_map, wildcard_override)
    return coverage_planner.build_passes(*planner_args)


@_register("solo")
def _assemble_solo(
    shots: list[CanonicalShot],
    ctx: GroupingContext,
) -> list[Group]:
    """Put every shot in its own ``"video_i2v"`` group.

    Every solo identity uses ordinal 0, and ``scene_id`` is the shot id.
    ``ctx`` is not used.
    """
    del ctx  # each shot stands alone; no context needed
    groups: list[Group] = []
    for shot in shots:
        identity = GroupingIdentity(
            strategy="solo",
            ordinal=0,
            shot_ids=[shot.shot_id],
        )
        groups.append(
            Group(
                identity=identity,
                shots=[shot],
                scene_id=shot.shot_id,
                modality="video_i2v",
            )
        )
    return groups


def _coverage_ordinal(pass_id: str, *, fallback: int) -> int:
    match = re.search(r"(?:^|_)(?:PASS|COV)_(\d{1,3})(?:_|$)", pass_id)
    if match:
        return int(match.group(1))
    return fallback


def _scene_runs(shots: list[CanonicalShot]) -> list[list[CanonicalShot]]:
    runs: list[list[CanonicalShot]] = []
    current: list[CanonicalShot] = []
    current_scene: int | None = None
    for shot in shots:
        if current and shot.scene_index != current_scene:
            runs.append(current)
            current = []
        current.append(shot)
        current_scene = shot.scene_index
    if current:
        runs.append(current)
    return runs


def _split_oner_scene_for_caps(shots: list[CanonicalShot]) -> list[list[CanonicalShot]]:
    """Greedily chunk one scene's shots so each chunk respects the oner caps.

    Caps come from ``_oner_caps``. A new chunk starts before a shot when
    either of these holds:

    - adding the shot's effective duration would push the chunk's total above
      the max duration;
    - the chunk already holds the max number of shots.

    The first shot of a chunk is always accepted, so a single shot longer than
    the duration cap still forms its own chunk. Empty input returns ``[]``.
    """
    if not shots:
        return []
    min_duration_s, max_duration_s, max_shots = _oner_caps(shots)
    chunks: list[list[CanonicalShot]] = []
    chunk: list[CanonicalShot] = []
    elapsed = 0.0
    for shot in shots:
        length = _oner_effective_duration(shot, min_duration_s)
        if chunk and (elapsed + length > max_duration_s or len(chunk) >= max_shots):
            chunks.append(chunk)
            chunk, elapsed = [], 0.0
        chunk.append(shot)
        elapsed += length
    if chunk:
        chunks.append(chunk)
    return chunks


def _oner_caps(shots: list[CanonicalShot]) -> tuple[float, float, int]:
    """Return ``(min_duration_s, max_duration_s, max_shots)`` for a oner chunk.

    Starts from the ONER_DEFAULT_* values and tightens them with the profile of
    each shot's ``video_model``, falling back to ``ONER_DEFAULT_MODEL_ID``.
    Models unknown to ``get_profile`` (``KeyError``) are skipped. The adjustments
    are:

    - the duration floor can only rise;
    - the duration ceiling can only fall;
    - the shot cap takes the smallest positive ``multi_prompt_max_shots`` /
      ``max_reference_images`` value and never drops below 1.

    Nothing here enforces floor <= ceiling.
    """
    floor_s = ONER_DEFAULT_MIN_DURATION_S
    ceiling_s = ONER_DEFAULT_MAX_DURATION_S
    shot_limit = ONER_DEFAULT_MAX_SHOTS
    for shot in shots:
        model_id = shot.video_model or ONER_DEFAULT_MODEL_ID
        try:
            profile = get_profile(model_id)
        except KeyError:
            # Unknown model: it contributes no constraints.
            continue

        lower = profile.get("min_duration_seconds")
        upper = profile.get("max_duration_seconds")
        if isinstance(lower, (int, float)):
            floor_s = max(floor_s, float(lower))
        if isinstance(upper, (int, float)):
            ceiling_s = min(ceiling_s, float(upper))

        limits = []
        for key in ("multi_prompt_max_shots", "max_reference_images"):
            value = profile.get(key)
            if isinstance(value, int) and value > 0:
                limits.append(value)
        if limits:
            shot_limit = min(shot_limit, *limits)
    return floor_s, ceiling_s, max(1, shot_limit)


def _oner_effective_duration(shot: CanonicalShot, min_duration_s: float) -> float:
    if shot.duration_s is None:
        return min_duration_s
    return max(float(shot.duration_s), min_duration_s)


def _coverage_shot_ids(coverage_pass: CoveragePass) -> list[str]:
    shot_ids = [
        segment.source_shot_id
        for segment in getattr(coverage_pass, "segments", []) or []
        if segment.source_shot_id and segment.source_shot_id != "wildcard"
    ]
    if shot_ids:
        return shot_ids

    first, last = getattr(coverage_pass, "shot_range", ("", ""))
    return [shot_id for shot_id in (first, last) if shot_id and shot_id != "wildcard"]


def _shot_to_coverage_dict(shot: CanonicalShot) -> dict:
    data = dict(shot.raw or {})
    data["shot_id"] = shot.shot_id
    data.setdefault("scene_index", shot.scene_index)
    data.setdefault("pipeline", shot.pipeline)
    if shot.previs_model is not None:
        data.setdefault("model", shot.previs_model)
    if shot.video_model is not None:
        data.setdefault("video_model", shot.video_model)
    if shot.aspect_ratio is not None:
        data.setdefault("aspect_ratio", shot.aspect_ratio)
    if shot.quality is not None:
        data.setdefault("quality", shot.quality)
    if shot.cinematography is not None:
        data.setdefault("cinematography", shot.cinematography)

    routing = dict(data.get("routing_data") or {})
    if shot.duration_s is not None:
        routing.setdefault("target_editorial_duration_s", shot.duration_s)
    routing.setdefault("is_env_only", shot.is_env_only)
    routing.setdefault("has_dialogue", shot.has_dialogue)
    data["routing_data"] = routing

    asset = dict(data.get("asset_data") or {})
    asset.setdefault("location_id", shot.location_id)
    asset.setdefault(
        "characters",
        [_character_to_dict(character) for character in shot.characters],
    )
    data["asset_data"] = asset

    prompt = dict(data.get("prompt_data") or {})
    prompt.setdefault("shot_type", shot.shot_type)
    data["prompt_data"] = prompt

    return data


def _character_to_dict(character) -> dict:
    if isinstance(character, str):
        return {"char_id": character}
    if isinstance(character, dict):
        return dict(character)
    return {
        "char_id": character.char_id,
        "wardrobe_phase_id": character.wardrobe_phase_id,
        "emotion_keyword": character.emotion_keyword,
        "screen_position": character.screen_position,
        "visibility": character.visibility,
    }


# Public API. The concrete assemblers are private; reach them by name through
# get_grouping() or GROUPING_REGISTRY.
__all__ = [
    "GROUPING_REGISTRY",
    "Group",
    "GroupingContext",
    "GroupingIdentity",
    "GroupingStrategy",
    "get_grouping",
]
