"""World-state `setting` pass (REC-111).

Derives one per-segment `setting` line — sublocation + who-is-where + prop
state — for an r2v_multi shot's timing segments via a single batched LLM call.
The OUTPUT is LLM-generated; the integration seam around it is deterministic.

Authored ONCE here, then consumed by the shot_spec author (Phase 1) and, later,
the board builder. The call goes through the same `_call_anthropic`-style seam
that `jit_prompt` uses, with the model resolved from
`get_model("world_state", "text")` unless the caller passes `model` explicitly.

Fail-soft contract: transient LLM errors / timeouts / malformed output → the
input segments are returned UNCHANGED (pre-existing `setting` keys preserved,
never stripped, none added) and `world_state_pass_skipped reason=...` is logged
at WARNING. Three consecutive skips raise `WorldStatePassOutage` because that
is a systemic settings-derivation outage.

`apply_settings_to_skeleton` is intentionally NOT built — callers pass segments
in memory; there is no new persistence format in this build.
"""

from __future__ import annotations

import json
import logging
from typing import Any

from recoil.core.model_profiles import get_model

# Module-level logger; skipped passes are reported through it at WARNING.
logger = logging.getLogger(__name__)

# Per-segment `setting` line hard cap. Lines over this are truncated at a word
# boundary (see _truncate_to_limit). WORLD_STATE_SYSTEM also spells this number
# out literally ("under 140 characters"), so keep the two in sync.
SETTING_CHAR_LIMIT = 140

# Process-local breaker state: the count of consecutive skipped passes,
# incremented by _record_skip and zeroed by _reset_breaker_state. Dispatch runs
# through StepRunner in one process and one thread, so no locking is needed
# here.
_consecutive_skips: int = 0


class WorldStatePassOutage(RuntimeError):
    """Signal that settings derivation is failing systemically.

    Raised by `derive_settings` once enough passes in a row have been
    skipped that the failure can no longer be treated as transient.
    """


def _reset_breaker_state() -> None:
    """Reset process-local consecutive skip state for tests."""

    global _consecutive_skips
    _consecutive_skips = 0

# System prompt for the batched call. Each tuple entry below is one line of the
# final prompt; entries are joined with newlines (no trailing newline). The
# "140" in the last rule is written out literally and mirrors the per-line
# character cap enforced in code.
WORLD_STATE_SYSTEM = "\n".join(
    (
        "You are the world-state continuity author for a multi-shot video prompt.",
        (
            "You receive an ordered list of N segments for ONE continuous scene. "
            "For each segment, write ONE `setting` line stating the physical world "
            "state at that moment: the sublocation, who is where, and the state of "
            "any props the actions touch."
        ),
        "",
        "Rules:",
        (
            "- Output EXACTLY N lines, one per segment, in the SAME order as the "
            "input. No numbering, no blank lines, no preamble, no trailing "
            "commentary."
        ),
        (
            "- Derive who/where/prop-state from each segment's actions IN ORDER. "
            "Persistent state carries forward: a pod opened in segment 2 stays "
            "open in segment 3 unless a later action changes it."
        ),
        (
            "- When a segment names a `sublocation`, begin its line with that "
            "sublocation's name (e.g. `Pod platform: ...`)."
        ),
        (
            "- Flat, declarative, present tense. No camera language, no "
            "metaphors, no atmosphere or mood essays, no interiority. Keep each "
            "line under 140 characters."
        ),
        (
            "- Example: `Pod platform: Jade beside the open cryo-pod; Wren "
            "seated inside; the drop below.`"
        ),
    )
)


def derive_settings(
    segments: list[dict],
    *,
    location_id: str | None,
    char_ids: list[str],
    sublocations: dict | None = None,
    model: str | None = None,
) -> list[dict]:
    """Return a NEW segment list where each segment gains a `setting` line.

    One batched LLM call authors all N lines. The input list is never mutated.
    On transient failure (transport error, line-count mismatch) the input
    `segments` are returned unchanged — pre-existing `setting` keys preserved,
    none added. Three consecutive failures raise `WorldStatePassOutage`.
    """

    seg_list = list(segments or [])
    if not seg_list:
        return seg_list

    try:
        # Dedicated role (default haiku): setting lines are a <=140-char
        # mechanical formatting task, and the Max-plan OAuth auth has far
        # more rate-limit headroom on haiku than opus/sonnet (verified
        # 2026-06-11: opus/sonnet 429 under interactive-session load while
        # haiku passes). Revisit via model_roles.json text.world_state.
        model_id = model or get_model("world_state", "text")
        user_prompt = _build_user_prompt(
            seg_list,
            location_id=location_id,
            char_ids=char_ids,
            sublocations=sublocations,
        )
        raw = _call_world_state_model(model_id, WORLD_STATE_SYSTEM, user_prompt)
        lines = [line.strip() for line in (raw or "").splitlines() if line.strip()]
        if len(lines) != len(seg_list):
            skip_count = _record_skip()
            logger.warning(
                "world_state_pass_skipped reason=line_count_mismatch "
                "expected=%d got=%d skip_count=%d",
                len(seg_list),
                len(lines),
                skip_count,
            )
            if skip_count >= 3:
                raise WorldStatePassOutage(
                    "malformed output: settings derivation is systemically "
                    f"failing after {skip_count} consecutive skips"
                )
            return segments
        new_segments: list[dict] = []
        for seg, line in zip(seg_list, lines):
            new_seg = dict(seg) if isinstance(seg, dict) else {}
            subloc = new_seg.get("sublocation")
            text = line
            if subloc:
                # Anchoring contract: a sublocation-tagged segment setting
                # must lead with its sublocation. Repair the prefix rather
                # than trusting prompt-only enforcement of an LLM seam.
                label = str(subloc).replace("_", " ").strip()
                if not text.lower().lstrip().startswith(label.lower()):
                    text = f"{label.capitalize()}: {text.lstrip()}"
            new_seg["setting"] = _truncate_to_limit(text)
            new_segments.append(new_seg)
        _reset_breaker_state()
        return new_segments
    except WorldStatePassOutage:
        raise
    except Exception as exc:  # noqa: BLE001 — transient failures stay fail-soft
        reason = type(exc).__name__
        skip_count = _record_skip()
        logger.warning(
            "world_state_pass_skipped reason=%s skip_count=%d",
            reason,
            skip_count,
        )
        if skip_count >= 3:
            raise WorldStatePassOutage(
                f"{reason}: settings derivation is systemically failing after "
                f"{skip_count} consecutive skips"
            ) from exc
        return segments


def _record_skip() -> int:
    """Count one more consecutive skipped pass and return the new total."""
    global _consecutive_skips
    _consecutive_skips = _consecutive_skips + 1
    return _consecutive_skips


def _build_user_prompt(
    segments: list[dict],
    *,
    location_id: str | None,
    char_ids: list[str],
    sublocations: dict | None,
) -> str:
    """Serialize the ordered segments + scene context for the batched call."""

    allowed = sorted(sublocations.keys()) if isinstance(sublocations, dict) else []
    payload = {
        "location_id": location_id,
        "char_ids": list(char_ids or []),
        "allowed_sublocations": allowed,
        "segment_count": len(segments),
        "segments": [
            _segment_context(i, seg)
            for i, seg in enumerate(segments)
        ],
        "instructions": (
            f"Output EXACTLY {len(segments)} setting lines, one per segment, "
            "in order."
        ),
    }
    return json.dumps(payload, ensure_ascii=True, indent=2, default=str)


def _seg_get(seg: Any, key: str) -> Any:
    if isinstance(seg, dict):
        return seg.get(key)
    return getattr(seg, key, None)


def _segment_context(index: int, seg: Any) -> dict[str, Any]:
    """Build the per-segment entry that goes into the prompt payload.

    The entry always starts with `index`, `shot_id`, `sublocation` and
    `intent` (in that order), followed by every other field found on the
    segment. For a dict segment those are its items; for any other object
    they are its public, non-callable attributes as listed by `dir()`. The
    four leading keys are never overwritten by a same-named segment field.
    """

    if isinstance(seg, dict):
        extras = dict(seg)
    else:
        extras = {}
        for attr in dir(seg):
            if attr.startswith("_"):
                continue
            if callable(getattr(seg, attr, None)):
                continue
            extras[attr] = getattr(seg, attr)

    # Compact, frequently-used keys lead; whatever else the caller attached
    # (action/dialogue/character fields, ...) follows in its own order.
    context: dict[str, Any] = {
        "index": index,
        "shot_id": _seg_get(seg, "shot_id"),
        "sublocation": _seg_get(seg, "sublocation"),
        "intent": _seg_get(seg, "intent"),
    }
    for name, val in extras.items():
        if name not in context:
            context[name] = val
    return context


def _truncate_to_limit(text: str, limit: int = SETTING_CHAR_LIMIT) -> str:
    """Trim a setting line to <= limit chars, cutting at a word boundary.

    Args:
        text: The line to trim; falsy input is treated as the empty string and
            surrounding whitespace is stripped first.
        limit: Maximum length of the returned string.

    Returns:
        `text` itself when it already fits. Otherwise the longest prefix of
        whole words that fits in `limit` characters, or — when the first
        `limit` characters contain no usable space — a hard cut at `limit`.
    """

    text = (text or "").strip()
    if len(text) <= limit:
        return text
    window = text[:limit]
    # If the character right after the window is whitespace, the window already
    # ends on a complete word: keep all of it instead of backing up to the
    # previous space and dropping a word that fits. A one-character slice is
    # used (not an index) so unusual limits can never raise.
    if text[limit:limit + 1].isspace():
        return window.rstrip()
    boundary = window.rfind(" ")
    if boundary > 0:
        return window[:boundary].rstrip()
    # No space to cut at (one long token): fall back to a hard cut.
    return window.rstrip()


def _call_world_state_model(model: str, system_prompt: str, user_prompt: str) -> str:
    """Send the batched prompt to the model and return its raw text reply.

    This is the LLM call seam that tests monkeypatch. It mirrors
    `jit_prompt._call_anthropic`: when the configured transport is `"cli"`
    the request goes through `claude_cli_call`; otherwise it goes through the
    Anthropic client and the text of the first content block is returned.
    The reply is expected to be exactly N newline-separated setting lines.
    """

    from recoil.core.claude_cli import claude_cli_call, claude_transport

    via_cli = claude_transport() == "cli"
    if not via_cli:
        # Imported only on this path, so the CLI transport never loads the
        # client module.
        from recoil.core.anthropic_client import anthropic_client

        reply = anthropic_client().messages.create(
            model=model,
            max_tokens=1024,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}],
        )
        return reply.content[0].text

    return claude_cli_call(user_prompt, system_prompt=system_prompt, model=model)


# Public surface of this module. The underscore-prefixed helpers (including
# the breaker reset used by tests and the monkeypatchable call seam) are
# deliberately left out.
__all__ = [
    "derive_settings",
    "WORLD_STATE_SYSTEM",
    "SETTING_CHAR_LIMIT",
    "WorldStatePassOutage",
]
