"""Shot intent primitives for prompt authoring.

`ShotPrimitive` is deliberately model-agnostic. It carries the normalized
editorial intent and resolved reference shape that later authoring phases use
to choose a strategy.
"""

from __future__ import annotations

import math
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Mapping


@dataclass
class ShotPrimitive:
    """Canonical shot intent used by prompt-authoring strategies.

    A plain mutable record. The builders in this module
    (`primitive_from_storyboard_shot`, `primitive_from_payload_context`)
    populate it, and the latter reassigns several fields after construction.
    The first five fields are required; their order is part of the generated
    constructor signature.
    """

    # Identifier of the shot; the storyboard builder substitutes
    # "UNKNOWN_SHOT" when none can be found.
    shot_id: str
    # Index of the owning scene; the storyboard builder falls back to 0.
    scene_index: int
    # Shot type label as text; empty string when unknown.
    shot_type: str
    # Target editorial duration (the `_s` suffix suggests seconds); for a
    # batch the payload builder stores the sum over all batch shots.
    target_editorial_duration_s: float
    # Condensed free-text description of what the shot should convey.
    intent: str
    # Optional hints. The builders in this module leave each of these as None
    # when no non-empty value is found on the shot or in the scene defaults.
    camera_side: str | None = None
    screen_direction: str | None = None
    has_dialogue: bool | None = None
    # Character identifiers, de-duplicated in first-seen order by the builders.
    char_ids: list[str] = field(default_factory=list)
    location_id: str | None = None
    # Either a copy of the shot's own `timing_segments`, or per-shot dicts
    # produced by `_timing_segments_from_batch`.
    timing_segments: list[Any] = field(default_factory=list)
    # Only copied from the raw shot's `strategy` key by the builders here.
    strategy: str | None = None
    # Reference material; the builders may store `start_frame`, `end_frame`,
    # `provenance` and `manifest` entries in it.
    refs: dict[str, Any] = field(default_factory=dict)


# Matches any run of whitespace; `_condense` uses it to collapse each run
# into a single space.
_WS_RE = re.compile(r"\s+")


def primitive_from_storyboard_shot(
    shot: Any, *, scene_defaults: Mapping[str, Any] | None
) -> ShotPrimitive:
    """Build a `ShotPrimitive` from one storyboard/plan shot.

    The shot is only read, never modified. Each field is resolved by walking
    a precedence chain and keeping the first candidate that is neither `None`
    nor an empty string. In general the shot's own value wins, then the
    shot's nested section (`prompt_data`, `routing_data`, `spatial_data`,
    `asset_data`), then the same section inside `scene_defaults`, then the
    top level of `scene_defaults`.

    Args:
        shot: A mapping, or any object `_raw_mapping` can turn into one.
        scene_defaults: Scene-level fallback values; `None` means no defaults.

    Returns:
        A freshly built `ShotPrimitive`.
    """

    raw = _raw_mapping(shot)
    defaults = scene_defaults or {}

    def section(container: Any, name: str) -> Any:
        # A missing or empty nested section reads as an empty mapping.
        return _mapping_value(container, name) or {}

    shot_prompt = section(raw, "prompt_data")
    shot_routing = section(raw, "routing_data")
    shot_spatial = section(raw, "spatial_data")
    shot_assets = section(raw, "asset_data")
    scene_prompt = section(defaults, "prompt_data")
    scene_routing = section(defaults, "routing_data")
    scene_spatial = section(defaults, "spatial_data")
    scene_assets = section(defaults, "asset_data")

    def layered(key: str, shot_section: Any, scene_section: Any) -> Any:
        # raw shot -> shot section -> scene section -> scene top level.
        return _prefer(
            _mapping_value(raw, key),
            _mapping_value(shot_section, key),
            _mapping_value(scene_section, key),
            _mapping_value(defaults, key),
        )

    resolved_id = _string(_first_present(raw, shot, "shot_id")) or "UNKNOWN_SHOT"

    scene_candidates = (
        _first_present(raw, shot, "scene_index"),
        _mapping_value(defaults, "scene_index"),
    )
    resolved_scene = _int(_prefer(*scene_candidates, default=0))

    type_candidates = (
        _first_present(raw, shot, "shot_type"),
        _mapping_value(shot_prompt, "shot_type"),
        _mapping_value(scene_prompt, "shot_type"),
        _mapping_value(defaults, "shot_type"),
    )
    resolved_type = _string(_prefer(*type_candidates, default=""))

    # The explicit editorial target beats the generic `duration_s` at each
    # level (shot first, then scene defaults).
    duration_candidates = (
        _first_present(raw, shot, "target_editorial_duration_s"),
        _mapping_value(shot_routing, "target_editorial_duration_s"),
        _first_present(raw, shot, "duration_s"),
        _mapping_value(scene_routing, "target_editorial_duration_s"),
        _mapping_value(defaults, "target_editorial_duration_s"),
        _mapping_value(defaults, "duration_s"),
    )
    resolved_duration = _float(_prefer(*duration_candidates, default=0.0))

    resolved_refs = _refs_from_raw(raw)
    provenance = _mapping_value(raw, "provenance")
    if provenance is not None:
        resolved_refs["provenance"] = _shallow_copy(provenance)

    segment_source = _prefer(
        _mapping_value(raw, "timing_segments"),
        _mapping_value(defaults, "timing_segments"),
        default=[],
    )

    location_candidates = (
        _mapping_value(raw, "location_id"),
        _mapping_value(shot_assets, "location_id"),
        _first_present(raw, shot, "location_id"),
        _mapping_value(scene_assets, "location_id"),
        _mapping_value(defaults, "location_id"),
    )

    fields: dict[str, Any] = {
        "shot_id": resolved_id,
        "scene_index": resolved_scene,
        "shot_type": resolved_type,
        "target_editorial_duration_s": resolved_duration,
        "intent": _derive_intent(raw, shot_id=resolved_id, shot_type=resolved_type),
        "camera_side": _optional_string(
            layered("camera_side", shot_spatial, scene_spatial)
        ),
        "screen_direction": _optional_string(
            layered("screen_direction", shot_spatial, scene_spatial)
        ),
        "has_dialogue": _optional_bool(
            layered("has_dialogue", shot_routing, scene_routing)
        ),
        "char_ids": _char_ids_from_sources(raw, shot, shot_assets, defaults),
        "location_id": _optional_string(_prefer(*location_candidates)),
        "timing_segments": _list_copy(segment_source),
        "strategy": _optional_string(_mapping_value(raw, "strategy")),
        "refs": resolved_refs,
    }
    return ShotPrimitive(**fields)


def primitive_from_payload_context(
    ctx: Any,
    *,
    ref_manifest: Mapping[str, Any] | None,
    start_frame: Any = None,
    end_frame: Any = None,
    segment_timestamps: list[Any] | None = None,
) -> ShotPrimitive:
    """Build a primitive from the live dispatch PayloadContext shape.

    The attributes `shot`, `batch_shots`, `duration_s` and `shot_id` are read
    from `ctx` with `getattr`, so any of them may be absent.

    Args:
        ctx: Context object carrying the shot (or batch of shots) to dispatch.
        ref_manifest: Optional manifest, copied into `refs["manifest"]`.
        start_frame: Optional start frame, stored in `refs["start_frame"]`.
        end_frame: Optional end frame, stored in `refs["end_frame"]`.
        segment_timestamps: Optional timestamps handed to
            `_timing_segments_from_batch`.

    Returns:
        A `ShotPrimitive` built from the lead shot, then overridden with
        context-level and batch-level values.
    """

    batch = list(getattr(ctx, "batch_shots", None) or [])
    lead = getattr(ctx, "shot", None)
    if batch:
        sources = batch
    elif lead is not None:
        sources = [lead]
    else:
        sources = []
    # Without an explicit shot, the first batch entry seeds the primitive.
    if lead is None and sources:
        lead = sources[0]

    ctx_duration = getattr(ctx, "duration_s", None)
    scene_defaults: dict[str, Any] = {}
    if ctx_duration is not None:
        scene_defaults["duration_s"] = ctx_duration

    primitive = primitive_from_storyboard_shot(lead or {}, scene_defaults=scene_defaults)

    # Context-level values take priority over what the shot itself declared.
    override_id = _optional_string(getattr(ctx, "shot_id", None))
    if override_id:
        primitive.shot_id = override_id
    if ctx_duration is not None:
        primitive.target_editorial_duration_s = _float(ctx_duration)

    merged_refs: dict[str, Any] = dict(primitive.refs)
    if ref_manifest is not None:
        merged_refs["manifest"] = dict(ref_manifest)
    for ref_key, frame in (("start_frame", start_frame), ("end_frame", end_frame)):
        if frame is not None:
            merged_refs[ref_key] = _pathish(frame)
    primitive.refs = merged_refs

    if not sources:
        return primitive

    # Resolve each source's raw mapping once and reuse it below.
    paired = [(src, _raw_mapping(src)) for src in sources]

    gathered_ids: list[str] = []
    for src, src_raw in paired:
        gathered_ids.extend(
            _char_ids_from_sources(
                src_raw,
                src,
                _mapping_value(src_raw, "asset_data") or {},
                {},
            )
        )
    primitive.char_ids = _dedupe(gathered_ids)

    if batch or segment_timestamps:
        primitive.timing_segments = _timing_segments_from_batch(
            sources,
            segment_timestamps=segment_timestamps,
        )

    if batch:
        # A batch covers all of its shots: total duration, joined intents.
        primitive.target_editorial_duration_s = sum(
            _duration_for_source(src) for src in sources
        )
        per_shot_intents: list[str] = []
        for src, src_raw in paired:
            src_type = _prefer(
                _first_present(src_raw, src, "shot_type"),
                _mapping_value(
                    _mapping_value(src_raw, "prompt_data") or {},
                    "shot_type",
                ),
            )
            per_shot_intents.append(
                _derive_intent(
                    src_raw,
                    shot_id=_string(_first_present(src_raw, src, "shot_id")),
                    shot_type=_string(src_type),
                )
            )
        combined = _condense(" | ".join(per_shot_intents))
        primitive.intent = combined or primitive.intent

    return primitive


def _raw_mapping(value: Any) -> Mapping[str, Any]:
    """Return a mapping view of `value`, or an empty dict when none exists.

    Tried in order: `value` itself when it is a mapping, its `raw` attribute
    when that is a mapping, then the result of calling `value.model_dump()`
    when that is a mapping.
    """
    if isinstance(value, Mapping):
        return value
    if value is None:
        return {}

    raw_attr = getattr(value, "raw", None)
    if isinstance(raw_attr, Mapping):
        return raw_attr

    dumper = getattr(value, "model_dump", None)
    dumped = dumper() if callable(dumper) else None
    return dumped if isinstance(dumped, Mapping) else {}


def _mapping_value(mapping: Any, key: str) -> Any:
    """Return `mapping.get(key)`, or `None` when `mapping` is not a mapping."""
    return mapping.get(key) if isinstance(mapping, Mapping) else None


def _first_present(raw: Mapping[str, Any], obj: Any, key: str) -> Any:
    """Look `key` up in `raw` first, then as an attribute of `obj`.

    A `None` entry in `raw` counts as missing. The attribute value is
    returned as-is (it may itself be `None`).
    """
    from_raw = raw.get(key)
    if from_raw is not None:
        return from_raw
    if obj is None:
        return None
    return getattr(obj, key, None)


def _prefer(*values: Any, default: Any = None) -> Any:
    """Return the first value that is neither `None` nor `""`, else `default`.

    Other falsy values (`0`, `False`, empty containers) count as present.
    """
    return next(
        (candidate for candidate in values if candidate is not None and candidate != ""),
        default,
    )


def _string(value: Any) -> str:
    """Coerce `value` to stripped text; `None` becomes the empty string.

    Objects exposing a `value` attribute (enum members, for instance) are
    unwrapped to that attribute before conversion.
    """
    if value is None:
        return ""
    unwrapped = getattr(value, "value", value)
    return str(unwrapped).strip()


def _optional_string(value: Any) -> str | None:
    """Like `_string`, but report blank results as `None`."""
    return _string(value) or None


def _int(value: Any) -> int:
    """Coerce `value` to an int without ever raising; fall back to 0.

    Decimal text such as `"3.0"` or `"1e3"` is accepted and truncated the
    same way the equivalent float would be. Unparseable text, `None`, NaN
    and infinities all yield 0.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")). ValueError also covers NaN and
        # text that `int` rejects but `float` may still understand.
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return 0


def _float(value: Any) -> float:
    """Coerce `value` to a finite float without ever raising; fall back to 0.0.

    NaN and infinities are rejected as well. Callers add and subtract these
    numbers (duration totals, segment start/end times), and a single
    non-finite value would corrupt every result derived from it.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: integers too large for a float, e.g. 10**400.
        return 0.0
    return number if math.isfinite(number) else 0.0


def _optional_bool(value: Any) -> bool | None:
    if value is None or value == "":
        return None
    return bool(value)


def _condense(text: str, *, limit: int = 500) -> str:
    """Collapse whitespace runs, trim edge spaces/pipes, and cap the length.

    Text longer than `limit` is cut to `limit - 3` characters, right-stripped,
    and suffixed with `"..."`.
    """
    flattened = _WS_RE.sub(" ", text)
    trimmed = flattened.strip(" |")
    if len(trimmed) > limit:
        return trimmed[: limit - 3].rstrip() + "..."
    return trimmed


def _derive_intent(raw: Mapping[str, Any], *, shot_id: str, shot_type: str) -> str:
    """Produce a condensed intent line for a raw shot mapping.

    Preference order:
      1. the shot's explicit `intent`;
      2. `source_text` plus the prompt skeleton's `action_line` and
         `emotion_line`, joined with `" | "`;
      3. a generic label assembled from `shot_id`, `shot_type` and the words
         "shot intent".
    """
    explicit = _optional_string(_mapping_value(raw, "intent"))
    if explicit:
        return _condense(explicit)

    skeleton = (
        _mapping_value(_mapping_value(raw, "prompt_data") or {}, "prompt_skeleton")
        or {}
    )
    fragments: list[str] = []
    for container, key in (
        (raw, "source_text"),
        (skeleton, "action_line"),
        (skeleton, "emotion_line"),
    ):
        fragment = _optional_string(_mapping_value(container, key))
        if fragment:
            fragments.append(fragment)

    summary = _condense(" | ".join(fragments))
    if summary:
        return summary

    label = " ".join(filter(None, (shot_id, shot_type, "shot intent")))
    return _condense(label) or "shot intent"


def _char_ids_from_sources(
    raw: Mapping[str, Any],
    obj: Any,
    asset_data: Mapping[str, Any],
    defaults: Mapping[str, Any],
) -> list[str]:
    """Collect character ids from every place a shot may declare them.

    All `char_ids` entries are gathered first, then all `characters` entries
    (plus `characters_present` from the defaults). Within each group the order
    is shot, shot asset data, scene asset data, scene top level. Each
    candidate is normalized with `_char_id`, and duplicates are dropped in
    first-seen order.
    """
    scene_assets = _mapping_value(defaults, "asset_data") or {}
    from_attribute = getattr(obj, "characters", None) if obj is not None else None

    declared = (
        # Explicit id lists.
        _mapping_value(raw, "char_ids"),
        _mapping_value(asset_data, "char_ids"),
        _mapping_value(scene_assets, "char_ids"),
        _mapping_value(defaults, "char_ids"),
        # Character records or names.
        _mapping_value(raw, "characters"),
        _mapping_value(asset_data, "characters"),
        from_attribute,
        _mapping_value(scene_assets, "characters"),
        _mapping_value(defaults, "characters"),
        _mapping_value(defaults, "characters_present"),
    )

    pool: list[Any] = []
    for entry in declared:
        if not entry:
            continue
        pool.extend(_as_list(entry))

    return _dedupe(_char_id(member) for member in pool)


def _char_id(value: Any) -> str | None:
    """Extract a character identifier from a string, mapping or object.

    Mappings and objects are probed for `char_id`, then `id`, then `name`.
    A field that is present but `None` or blank is skipped, so an object with
    `char_id=None` and `name="Bob"` resolves to `"Bob"`, matching what the
    mapping branch does for the equivalent dict.

    Returns:
        The identifier as stripped text, or `None` when nothing usable exists.
    """
    if value is None:
        return None
    if isinstance(value, Mapping):
        return _optional_string(value.get("char_id") or value.get("id") or value.get("name"))

    has_id_field = False
    for attr in ("char_id", "id", "name"):
        if not hasattr(value, attr):
            continue
        has_id_field = True
        text = _optional_string(getattr(value, attr))
        if text:
            return text
    if has_id_field:
        # A character-like object whose id fields are all empty has no
        # usable id; never fall back to the object's own str() form.
        return None
    return _optional_string(value)


def _as_list(value: Any) -> list[Any]:
    """Return `value` as a list.

    Lists are returned unchanged (same object), tuples are converted, `None`
    becomes an empty list, and anything else is wrapped in a one-item list.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [] if value is None else [value]


def _dedupe(values: Any) -> list[str]:
    """Normalize each item to text and drop blanks and repeats.

    The first occurrence of each text is kept, in its original position.
    """
    ordered: dict[str, None] = {}
    for item in values:
        text = _optional_string(item)
        if text:
            # Dict keys keep insertion order; repeats leave it untouched.
            ordered.setdefault(text, None)
    return list(ordered)


def _refs_from_raw(raw: Mapping[str, Any]) -> dict[str, Any]:
    """Build the initial refs dict for a raw shot mapping.

    Starts from a copy of the shot's `refs` mapping (when it is one). Any
    top-level `start_frame` / `end_frame` that is not `None` is then laid on
    top, after passing through `_pathish`.
    """
    declared = _mapping_value(raw, "refs")
    refs: dict[str, Any] = dict(declared) if isinstance(declared, Mapping) else {}

    for frame_key in ("start_frame", "end_frame"):
        frame = _mapping_value(raw, frame_key)
        if frame is None:
            continue
        refs[frame_key] = _pathish(frame)
    return refs


def _pathish(value: Any) -> Any:
    """Convert `Path` instances to `str`; pass every other value through."""
    return str(value) if isinstance(value, Path) else value


def _shallow_copy(value: Any) -> Any:
    """Return a one-level copy of a mapping (as `dict`) or list.

    Any other value is returned unchanged. Nested objects are shared with the
    original.
    """
    for kind, clone in ((Mapping, dict), (list, list)):
        if isinstance(value, kind):
            return clone(value)
    return value


def _list_copy(value: Any) -> list[Any]:
    """Return a new list from a list or tuple; an empty list for anything else."""
    return list(value) if isinstance(value, (list, tuple)) else []


def _duration_for_source(source: Any) -> float:
    """Resolve one batch source's duration as a float, defaulting to 0.0.

    Candidates, in priority order: the object's `duration_s` attribute, the
    raw mapping's `duration_s`, the raw mapping's
    `target_editorial_duration_s`, and `routing_data`'s
    `target_editorial_duration_s`.
    """
    raw = _raw_mapping(source)
    routing = _mapping_value(raw, "routing_data") or {}
    candidates = (
        getattr(source, "duration_s", None),
        _mapping_value(raw, "duration_s"),
        _mapping_value(raw, "target_editorial_duration_s"),
        _mapping_value(routing, "target_editorial_duration_s"),
    )
    return _float(_prefer(*candidates, default=0.0))


def _timing_segments_from_batch(
    sources: list[Any], *, segment_timestamps: list[Any] | None
) -> list[dict[str, Any]]:
    """Describe each source shot as one timed segment of a batch.

    Args:
        sources: The shots in the batch, in playback order.
        segment_timestamps: Either a sequence of `(start, end)` pairs, or a
            flat sequence of start times, or `None`/empty.

    Returns:
        One dict per source with `shot_id`, `start_s`, `end_s`, `duration_s`
        and `intent`, plus `sublocation` / `setting` when the shot's
        `spatial_data` provides them.

    Timing rules:
        * `(start, end)` pairs are used as given. Pairs beyond the number of
          sources are ignored. Sources without a pair are laid end to end
          after the last known end time, so they no longer all restart at 0
          and overlap the beginning of the batch.
        * Flat start times are used when there is at least one per source;
          each segment ends where the next begins, and the last one runs for
          its own duration.
        * Otherwise the shots are laid end to end from 0 using their own
          durations.
    """
    if not sources:
        return []

    count = len(sources)
    durations = [_duration_for_source(src) for src in sources]
    starts: list[float] = []
    ends: list[float] = []

    if segment_timestamps and all(_is_pair(item) for item in segment_timestamps):
        for item in segment_timestamps:
            starts.append(_float(item[0]))
            ends.append(_float(item[1]))
        # Surplus pairs have no matching source.
        del starts[count:]
        del ends[count:]
        # Too few pairs: continue the timeline from the last known end.
        cursor = ends[-1] if ends else 0.0
        for duration in durations[len(starts):]:
            starts.append(cursor)
            cursor += max(0.0, duration)
            ends.append(cursor)
    else:
        raw_starts = [_float(item) for item in (segment_timestamps or [])]
        if len(raw_starts) < count:
            # Not enough explicit starts: ignore them and accumulate durations.
            raw_starts = []
            cursor = 0.0
            for duration in durations:
                raw_starts.append(round(cursor, 2))
                cursor += duration
        starts = raw_starts[:count]
        for i, start in enumerate(starts):
            if i + 1 < count:
                ends.append(starts[i + 1])
            else:
                ends.append(start + durations[i])

    # Both lists now hold exactly one entry per source.
    segments: list[dict[str, Any]] = []
    for i, (src, start, end) in enumerate(zip(sources, starts, ends)):
        raw = _raw_mapping(src)
        shot_id = _string(_first_present(raw, src, "shot_id")) or f"segment_{i + 1}"
        shot_type = _string(
            _prefer(
                _first_present(raw, src, "shot_type"),
                _mapping_value(_mapping_value(raw, "prompt_data") or {}, "shot_type"),
            )
        )
        segment: dict[str, Any] = {
            "shot_id": shot_id,
            "start_s": start,
            "end_s": end,
            # Guard against end < start from inconsistent timestamps.
            "duration_s": max(0.0, end - start),
            "intent": _derive_intent(raw, shot_id=shot_id, shot_type=shot_type),
        }
        spatial_data = _mapping_value(raw, "spatial_data") or {}
        if isinstance(spatial_data, Mapping):
            for key in ("sublocation", "setting"):
                if key in spatial_data:
                    segment[key] = spatial_data[key]
        segments.append(segment)
    return segments


def _is_pair(value: Any) -> bool:
    """Tell whether `value` is a list or tuple holding exactly two items."""
    if not isinstance(value, (list, tuple)):
        return False
    return len(value) == 2
