"""Authoring strategy registry and resolution policy for prompt authoring."""

from __future__ import annotations

import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Mapping

from recoil.pipeline._lib.shot_primitive import ShotPrimitive

logger = logging.getLogger(__name__)

# Registry key shapes: (model_id, modality, strategy_name) and (model_id, modality).
StrategyKey = tuple[str, str, str]
DefaultKey = tuple[str, str]

# Generation modalities.
R2V_MULTI = "r2v_multi"
VIDEO_I2V = "video_i2v"

# Strategy names.
DETERMINISTIC_TEMPLATE = "deterministic_template"
DIRECTED_PROSE = "directed_prose"
SHOT_SPEC = "shot_spec"
START_END_FRAME = "start_end_frame"

# Model id for registry entries that apply to any model.
_WILDCARD_MODEL = "*"
_SUPPORTED_MODALITIES = (R2V_MULTI, VIDEO_I2V)
# System prompts live in ``prompts/strategies/`` next to this module.
_PROMPT_DIR = Path(__file__).parent / "prompts" / "strategies"


class StrategyResolutionError(ValueError):
    """Signal that no safe model/modality/strategy route could be resolved.

    Raised for unsupported modalities, unregistered combinations, and routes
    that the policy considers unsafe for the primitive.
    """


class AuthorInputError(ValueError):
    """Signal that a primitive lacks inputs its chosen strategy requires."""


@dataclass
class AuthorStrategy:
    """Registry entry describing how prompts are authored for one modality.

    ``applies`` is left out of ``repr`` and equality, so two entries compare
    equal on their declarative fields alone.
    """

    name: str  # strategy identifier, e.g. "directed_prose"
    modality: str  # generation modality this entry is registered for
    system_prompt_path: Path  # system prompt file used by this strategy
    required_inputs: list[str]  # dotted primitive paths that must be present
    is_default: bool  # whether the entry is flagged as a default
    applies: Callable[[ShotPrimitive], bool] = field(compare=False, repr=False)


def _truthy_path(primitive: ShotPrimitive, dotted_path: str) -> bool:
    """Return whether the value at ``dotted_path`` is present and non-empty.

    Lists, tuples, dicts and sets count only when non-empty; any other value
    counts unless it is ``None`` or the empty string.
    """
    found = get_primitive_value(primitive, dotted_path)
    is_container = isinstance(found, (list, tuple, dict, set))
    return bool(found) if is_container else (found is not None and found != "")


def _positive_duration(primitive: ShotPrimitive) -> bool:
    """Report whether ``target_editorial_duration_s`` parses to a number > 0.

    Values that ``float()`` rejects (``None``, non-numeric strings) count as
    not positive.
    """
    try:
        seconds = float(primitive.target_editorial_duration_s)
    except (TypeError, ValueError):
        return False
    return seconds > 0


def _start_end_applies(primitive: ShotPrimitive) -> bool:
    """Return True when the primitive has both frame refs, an intent and a duration."""
    needed_paths = ("refs.start_frame", "refs.end_frame", "intent")
    if not all(_truthy_path(primitive, path) for path in needed_paths):
        return False
    return _positive_duration(primitive)


def _directed_prose_applies(primitive: ShotPrimitive) -> bool:
    """Return True when the primitive has framing, duration, cast and timing.

    Checks run in a fixed order and stop at the first missing input.
    """
    framing_ok = all(
        _truthy_path(primitive, path)
        for path in ("shot_type", "camera_side", "screen_direction")
    )
    if not framing_ok or not _positive_duration(primitive):
        return False
    return _truthy_path(primitive, "char_ids") and _truthy_path(
        primitive, "timing_segments"
    )


def _deterministic_applies(_primitive: ShotPrimitive) -> bool:
    """Accept every primitive; the deterministic template has no prerequisites."""
    return True


def _strategy(
    name: str,
    modality: str,
    required_inputs: list[str],
    applies: Callable[[ShotPrimitive], bool],
    *,
    is_default: bool = False,
) -> AuthorStrategy:
    """Build a registry entry whose system prompt is ``<name>.md`` in the prompt dir.

    ``required_inputs`` is copied so registry entries never share one list.
    """
    prompt_file = _PROMPT_DIR / f"{name}.md"
    inputs_copy = list(required_inputs)
    return AuthorStrategy(
        name=name,
        modality=modality,
        system_prompt_path=prompt_file,
        required_inputs=inputs_copy,
        is_default=is_default,
        applies=applies,
    )


# Primitive fields needed by the directed-prose (and shot-spec) strategies.
DIRECTED_PROSE_REQUIRED_INPUTS = [
    "shot_type", "camera_side", "screen_direction",
    "target_editorial_duration_s", "char_ids", "timing_segments",
]
# Dotted primitive paths needed by the start/end-frame strategy.
START_END_FRAME_REQUIRED_INPUTS = [
    "refs.start_frame", "refs.end_frame", "intent", "target_editorial_duration_s",
]

# Preferred strategy per (model_id, modality); resolve_strategy falls back to
# the deterministic template for pairs that are not listed here.
DEFAULT_AUTHOR_STRATEGY: dict[DefaultKey, str] = {
    ("seeddance-2.0", R2V_MULTI): DIRECTED_PROSE,
    ("kling-v3", VIDEO_I2V): START_END_FRAME,
}

# Registry keyed by (model_id, modality, strategy_name). Entries under
# _WILDCARD_MODEL apply to any model; _lookup_strategy tries the exact model first.
AUTHOR_STRATEGIES: dict[StrategyKey, AuthorStrategy] = {
    ("seeddance-2.0", R2V_MULTI, DIRECTED_PROSE): _strategy(
        DIRECTED_PROSE,
        R2V_MULTI,
        DIRECTED_PROSE_REQUIRED_INPUTS,
        _directed_prose_applies,
        is_default=True,
    ),
    # shot_spec reuses the directed-prose inputs and applicability check.
    ("seeddance-2.0", R2V_MULTI, SHOT_SPEC): _strategy(
        SHOT_SPEC, R2V_MULTI, DIRECTED_PROSE_REQUIRED_INPUTS, _directed_prose_applies
    ),
    ("kling-v3", VIDEO_I2V, START_END_FRAME): _strategy(
        START_END_FRAME,
        VIDEO_I2V,
        START_END_FRAME_REQUIRED_INPUTS,
        _start_end_applies,
        is_default=True,
    ),
    # Model-agnostic fallbacks: no required inputs, always applicable.
    (_WILDCARD_MODEL, R2V_MULTI, DETERMINISTIC_TEMPLATE): _strategy(
        DETERMINISTIC_TEMPLATE, R2V_MULTI, [], _deterministic_applies, is_default=True
    ),
    (_WILDCARD_MODEL, VIDEO_I2V, DETERMINISTIC_TEMPLATE): _strategy(
        DETERMINISTIC_TEMPLATE, VIDEO_I2V, [], _deterministic_applies, is_default=True
    ),
}


def resolve_strategy(
    primitive: ShotPrimitive,
    *,
    explicit: str | None = None,
    model_id: str,
    requested_modality: str | None = None,
) -> tuple[str, AuthorStrategy]:
    """Resolve authoring modality and strategy using the Phase-2 precedence.

    Precedence (first applicable rule wins):

    1. ``PROSE_AUTHOR_FALLBACK=1`` forces the deterministic template.
    2. A named strategy: ``explicit``, else ``RECOIL_AUTHOR_STRATEGY``, else
       ``primitive.strategy``.
    3. The primitive's reference shape. Start+end frames route to
       ``video_i2v``/``start_end_frame``. A multi-ref stack routes to
       ``r2v_multi``/``directed_prose`` unless another modality was requested.
    4. ``DEFAULT_AUTHOR_STRATEGY`` for ``(model_id, modality)``, else the
       deterministic template.

    Args:
        primitive: Shot primitive being authored.
        explicit: Strategy name chosen by the caller, if any.
        model_id: Video model identifier; must be non-empty.
        requested_modality: Desired modality; ``"i2v"`` is accepted as an
            alias for ``video_i2v``.

    Returns:
        ``(modality, strategy)`` for the resolved route.

    Raises:
        StrategyResolutionError: ``model_id`` is empty, the requested modality
            is unsupported, no strategy is registered for the route, or the
            route would be unsafe for the primitive (for example, dropping the
            end frame of a start+end pair).
    """

    if not model_id:
        raise StrategyResolutionError("model_id is required to resolve author strategy")

    requested = _normalize_modality(requested_modality)
    shape = _refs_shape_resolution(primitive)

    if os.environ.get("PROSE_AUTHOR_FALLBACK") == "1":
        modality = requested or (shape[0] if shape else _default_modality(model_id))
        strategy = _lookup_strategy(model_id, modality, DETERMINISTIC_TEMPLATE)
        # The short-circuit must not bypass the safety policy. A requested
        # modality that conflicts with start+end refs would silently drop the
        # end frame, a route every other branch below refuses.
        _block_incompatible(primitive, modality, strategy, requested)
        logger.warning(
            "prose_author_fallback reason=env_short_circuit strategy=%s "
            "primitive_id=%s video_model=%s modality=%s",
            strategy.name,
            primitive.shot_id,
            model_id,
            modality,
        )
        return modality, strategy

    named = explicit or os.environ.get("RECOIL_AUTHOR_STRATEGY") or primitive.strategy
    if named:
        modality = _modality_for_named_strategy(model_id, named, requested, shape)
        strategy = _lookup_strategy(model_id, modality, named)
        _block_incompatible(primitive, modality, strategy, requested)
        return modality, strategy

    if shape:
        modality, strategy_name = shape
        if requested and requested != modality:
            # Start+end frames are a hard constraint; a multi-ref stack is only
            # a preference, so it yields to the requested modality below.
            if strategy_name == START_END_FRAME:
                raise StrategyResolutionError(
                    f"requested modality {requested!r} is incompatible with "
                    f"primitive ref shape {modality!r}/{strategy_name!r}"
                )
        else:
            strategy = _lookup_strategy(model_id, modality, strategy_name)
            return modality, strategy

    modality = requested or _default_modality(model_id)
    strategy_name = DEFAULT_AUTHOR_STRATEGY.get(
        (model_id, modality),
        DETERMINISTIC_TEMPLATE,
    )
    strategy = _lookup_strategy(model_id, modality, strategy_name)
    _block_incompatible(primitive, modality, strategy, requested)
    return modality, strategy


def get_primitive_value(primitive: ShotPrimitive, dotted_path: str) -> Any:
    """Walk ``dotted_path`` through attributes and mapping keys of ``primitive``.

    Each segment is a key lookup on mappings and an attribute lookup
    otherwise. Returns ``None`` as soon as any segment is missing or ``None``.
    """

    current: Any = primitive
    for segment in dotted_path.split("."):
        current = (
            current.get(segment)
            if isinstance(current, Mapping)
            else getattr(current, segment, None)
        )
        if current is None:
            return None
    return current


def missing_required_inputs(
    primitive: ShotPrimitive, strategy: AuthorStrategy
) -> list[str]:
    """List the strategy's required primitive paths that are absent or empty.

    A container is absent when empty. ``target_editorial_duration_s`` is
    absent unless it parses to a positive number. Any other value is absent
    when it is ``None`` or ``""``. Paths keep their ``required_inputs`` order.
    """

    def _absent(path: str) -> bool:
        found = get_primitive_value(primitive, path)
        if isinstance(found, (list, tuple, dict, set)):
            return not found
        if path == "target_editorial_duration_s":
            return not _positive_duration(primitive)
        return found is None or found == ""

    return [path for path in strategy.required_inputs if _absent(path)]


def require_strategy_inputs(
    primitive: ShotPrimitive, strategy: AuthorStrategy
) -> None:
    """Raise ``AuthorInputError`` naming every input the strategy lacks."""

    absent = missing_required_inputs(primitive, strategy)
    if not absent:
        return
    listed = ", ".join(absent)
    raise AuthorInputError(
        f"{primitive.shot_id}: strategy {strategy.name!r} missing "
        f"required inputs: {listed}"
    )


def _normalize_modality(modality: str | None) -> str | None:
    """Canonicalize a requested modality; ``"i2v"`` is an alias for ``video_i2v``."""
    if modality is None:
        return None
    canonical = VIDEO_I2V if modality == "i2v" else modality
    if canonical in _SUPPORTED_MODALITIES:
        return canonical
    raise StrategyResolutionError(f"unsupported author modality {modality!r}")


def _default_modality(model_id: str) -> str:
    """Return the model's sole default modality, else ``r2v_multi``.

    Falls back to ``R2V_MULTI`` when ``DEFAULT_AUTHOR_STRATEGY`` lists the
    model zero times or more than once.
    """
    candidates = [
        modality
        for registered_model, modality in DEFAULT_AUTHOR_STRATEGY
        if registered_model == model_id
    ]
    return candidates[0] if len(candidates) == 1 else R2V_MULTI


def _lookup_strategy(
    model_id: str,
    modality: str,
    strategy_name: str,
) -> AuthorStrategy:
    """Find the registry entry for the model, falling back to the wildcard model."""
    for candidate_model in (model_id, _WILDCARD_MODEL):
        found = AUTHOR_STRATEGIES.get((candidate_model, modality, strategy_name))
        if found is not None:
            return found
    raise StrategyResolutionError(
        f"no author strategy registered for "
        f"({model_id!r}, {modality!r}, {strategy_name!r})"
    )


def _refs_shape_resolution(primitive: ShotPrimitive) -> tuple[str, str] | None:
    """Infer ``(modality, strategy)`` from the primitive's refs, or ``None``.

    Start+end frame refs take priority over a multi-ref stack.
    """
    if _has_start_end_refs(primitive):
        return VIDEO_I2V, START_END_FRAME
    return (R2V_MULTI, DIRECTED_PROSE) if _has_multi_ref_stack(primitive) else None


def _has_start_end_refs(primitive: ShotPrimitive) -> bool:
    """Return True when both ``refs.start_frame`` and ``refs.end_frame`` are truthy."""
    frame_paths = ("refs.start_frame", "refs.end_frame")
    return all(bool(get_primitive_value(primitive, path)) for path in frame_paths)


def _has_multi_ref_stack(primitive: ShotPrimitive) -> bool:
    """Return True when the primitive references more than one identity or segment.

    A stack is detected from more than one ``char_ids`` entry, more than one
    ``timing_segments`` entry, or more than one ``identity_*`` key in
    ``refs["manifest"]``.
    """
    if len(primitive.char_ids or []) > 1:
        return True
    if len(primitive.timing_segments or []) > 1:
        return True
    refs = primitive.refs
    # Accept any Mapping (not only dict), matching how get_primitive_value
    # reads ``refs`` for the start/end-frame checks.
    manifest = refs.get("manifest") if isinstance(refs, Mapping) else None
    if not isinstance(manifest, Mapping):
        return False
    identity_count = sum(1 for key in manifest if str(key).startswith("identity_"))
    return identity_count > 1


def _modality_for_named_strategy(
    model_id: str,
    strategy_name: str,
    requested: str | None,
    shape: tuple[str, str] | None,
) -> str:
    """Pick the modality to pair with an explicitly named strategy.

    Order: the requested modality; the single registered modality for this
    model, then for the wildcard model; the primitive's ref shape if it
    names the same strategy; and finally the modality each strategy is pinned
    to by policy (shot_spec shares directed_prose's r2v_multi requirement).

    Raises:
        StrategyResolutionError: if ``strategy_name`` is not a known strategy.
    """
    if requested:
        return requested
    for registry_model in (model_id, _WILDCARD_MODEL):
        registered = [
            modality
            for entry_model, modality, name in AUTHOR_STRATEGIES
            if entry_model == registry_model and name == strategy_name
        ]
        if len(registered) == 1:
            return registered[0]
    if shape and shape[1] == strategy_name:
        return shape[0]
    if strategy_name == START_END_FRAME:
        return VIDEO_I2V
    # shot_spec is pinned to r2v_multi exactly like directed_prose (see
    # _block_incompatible); without this it was misreported as "unknown".
    if strategy_name in (DIRECTED_PROSE, SHOT_SPEC):
        return R2V_MULTI
    if strategy_name == DETERMINISTIC_TEMPLATE:
        return shape[0] if shape else _default_modality(model_id)
    raise StrategyResolutionError(f"unknown author strategy {strategy_name!r}")


def _block_incompatible(
    primitive: ShotPrimitive,
    modality: str,
    strategy: AuthorStrategy,
    requested: str | None,
) -> None:
    """Raise ``StrategyResolutionError`` if the resolved route is unsafe.

    Checks, in order: the requested modality was honoured; the strategy is
    registered for this modality; strategies pinned to a modality use it;
    start+end refs are routed to video_i2v with a strategy that keeps both
    frames.
    """
    if requested and modality != requested:
        raise StrategyResolutionError(
            f"requested modality {requested!r} resolved to incompatible "
            f"modality {modality!r}"
        )
    if strategy.modality != modality:
        raise StrategyResolutionError(
            f"strategy {strategy.name!r} is registered for {strategy.modality!r}, "
            f"not {modality!r}"
        )

    pinned = {
        START_END_FRAME: (VIDEO_I2V, "start_end_frame strategy requires video_i2v"),
        DIRECTED_PROSE: (R2V_MULTI, "directed_prose strategy requires r2v_multi"),
        SHOT_SPEC: (R2V_MULTI, "shot_spec strategy requires r2v_multi"),
    }.get(strategy.name)
    if pinned is not None and modality != pinned[0]:
        raise StrategyResolutionError(pinned[1])

    if not _has_start_end_refs(primitive):
        return
    if modality != VIDEO_I2V:
        raise StrategyResolutionError(
            f"{primitive.shot_id}: start+end refs require video_i2v; "
            f"refusing to drop the end frame for {modality!r}"
        )
    if strategy.name not in {START_END_FRAME, DETERMINISTIC_TEMPLATE}:
        raise StrategyResolutionError(
            f"{primitive.shot_id}: start+end refs are incompatible with "
            f"strategy {strategy.name!r}"
        )


# Public API. get_primitive_value is exported alongside the required-input
# helpers because it is the module's documented, non-underscored path reader.
__all__ = [
    "AUTHOR_STRATEGIES",
    "DEFAULT_AUTHOR_STRATEGY",
    "AuthorInputError",
    "AuthorStrategy",
    "StrategyResolutionError",
    "get_primitive_value",
    "missing_required_inputs",
    "require_strategy_inputs",
    "resolve_strategy",
]
