"""Plan override SSOT helpers.

Override files live at ``plan_overrides/ep_NNN.json`` beside the canonical
plans directory and use schema version 1:

    {
      "schema_version": 1,
      "episode_id": "EP001",
      "overrides": [
        {
          "shot_id": "EP001_SH02",
          "target_span_hash": "sha256:...",
          "fields": {"prompt_data": {"shot_type": "ECU"}},
          "authored_at": "2026-06-24T00:00:00Z",
          "note": "optional operator note"
        }
      ]
    }

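``fields`` maps top-level plan-shot keys to override values. When both the
override value and the shot's existing value are dicts they are merged
recursively, so a prompt_data override can add one key without replacing the
other prompt_data fields; any other value (including lists) replaces the
existing value outright. An override is applied only when its
``target_span_hash`` matches the shot's live span hash; otherwise it is
reported as a flag instead.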
"""

from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any

from recoil.core.paths import ProjectPaths


class PlanOverridesError(ValueError):
    """Signal an unreadable, corrupt, or schema-violating plan override file.

    Derives from ``ValueError``, so handlers that already catch
    ``ValueError`` will also catch override-file problems.
    """


def overrides_path(project: str, episode: int) -> Path:
    """Return the override file location for one episode of ``project``.

    The file is ``ep_NNN.json`` (episode zero-padded to at least three digits)
    inside a ``plan_overrides`` directory that is a sibling of the project's
    plans directory. The path is computed only; existence is not checked.
    """
    plans_parent = ProjectPaths.for_project(project).plans_dir.parent
    file_name = f"ep_{episode:03d}.json"
    return plans_parent / "plan_overrides" / file_name


def load_overrides(project: str, episode: int) -> list[dict]:
    """Load and validate an episode override file.

    Missing files return an empty list. Present but malformed files raise
    PlanOverridesError so operator-authored corruption is never treated as
    "no overrides".

    Args:
        project: Project name, resolved via ``ProjectPaths.for_project``.
        episode: Episode number; selects ``ep_NNN.json`` and the expected
            ``episode_id`` (``EPNNN``).

    Returns:
        The validated ``overrides`` list from the file, or ``[]`` when the
        file (or its directory) does not exist.

    Raises:
        PlanOverridesError: The file exists but cannot be read, is not valid
            UTF-8, is not valid JSON, or violates schema version 1 (top-level
            type, ``schema_version``, ``episode_id``, or override entry shape).
    """
    path = overrides_path(project, episode)

    # EAFP instead of exists()+read: avoids a race where the file disappears
    # between the check and the read. The "missing" errors mirror the cases
    # Path.exists() reports as False.
    try:
        raw = path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return []
    except OSError as exc:
        raise PlanOverridesError(f"Could not read plan overrides at {path}: {exc}") from exc
    except UnicodeDecodeError as exc:
        # UnicodeDecodeError is neither OSError nor JSONDecodeError; without
        # this branch a non-UTF-8 file would escape as a bare ValueError.
        raise PlanOverridesError(
            f"Plan overrides at {path} are not valid UTF-8: {exc}"
        ) from exc

    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise PlanOverridesError(f"Invalid JSON in {path}: {exc}") from exc

    if not isinstance(data, dict):
        raise PlanOverridesError(f"Plan overrides at {path} must be a JSON object")

    schema_version = data.get("schema_version")
    # bool is an int subclass (True == 1), so a JSON `true` would otherwise
    # be accepted as schema version 1.
    if isinstance(schema_version, bool) or schema_version != 1:
        raise PlanOverridesError(
            f"Unsupported plan override schema_version {schema_version!r} at {path}"
        )

    expected_episode_id = _episode_id(episode)
    if _normalize_episode_id(data.get("episode_id")) != expected_episode_id:
        raise PlanOverridesError(
            f"Plan overrides at {path} target {data.get('episode_id')!r}, "
            f"expected {expected_episode_id}"
        )

    overrides = data.get("overrides")
    if not isinstance(overrides, list):
        raise PlanOverridesError(f"Plan overrides at {path} must contain an overrides list")

    for index, override in enumerate(overrides):
        _validate_override(override, path, index)

    return overrides


def apply_overrides(
    plan: dict,
    overrides: list[dict],
    live_spans: dict[str, str | None],
) -> tuple[dict, list[dict]]:
    """Apply fresh overrides to ``plan`` and flag stale or missing currency.

    ``live_spans`` is supplied by the caller from the just-derived in-memory
    plan. This function does not read the plan from disk, inspect manifests, or
    call any currency helpers.
    """
    flags: list[dict] = []
    shots_by_id = {
        shot.get("shot_id"): shot
        for shot in plan.get("shots", [])
        if isinstance(shot, dict) and isinstance(shot.get("shot_id"), str)
    }

    for override in overrides:
        shot_id = override["shot_id"]
        target_span_hash = override["target_span_hash"]

        if shot_id not in live_spans or shot_id not in shots_by_id:
            flags.append(
                {
                    "shot_id": shot_id,
                    "reason": "orphan",
                    "target_span_hash": target_span_hash,
                }
            )
            continue

        live_hash = live_spans[shot_id]
        if live_hash is None:
            flags.append(
                {
                    "shot_id": shot_id,
                    "reason": "currency_unavailable",
                    "target_span_hash": target_span_hash,
                }
            )
            continue

        if live_hash != target_span_hash:
            flags.append(
                {
                    "shot_id": shot_id,
                    "reason": "stale_span",
                    "target_span_hash": target_span_hash,
                    "live_hash": live_hash,
                }
            )
            continue

        _deep_merge(shots_by_id[shot_id], override["fields"])

    plan["override_flags"] = flags
    return plan, flags


def _episode_id(episode: int) -> str:
    return f"EP{episode:03d}"


def _normalize_episode_id(value: Any) -> str | None:
    if not isinstance(value, str):
        return None
    normalized = value.strip().upper()
    if normalized.startswith("EP_"):
        normalized = "EP" + normalized[3:]
    return normalized


def _validate_override(override: Any, path: Path, index: int) -> None:
    if not isinstance(override, dict):
        raise PlanOverridesError(f"Override {index} in {path} must be an object")

    shot_id = override.get("shot_id")
    if not isinstance(shot_id, str) or not shot_id:
        raise PlanOverridesError(f"Override {index} in {path} must include a non-empty shot_id")

    target_span_hash = override.get("target_span_hash")
    if not isinstance(target_span_hash, str):
        raise PlanOverridesError(
            f"Override {index} in {path} must include a string target_span_hash"
        )

    fields = override.get("fields")
    if not isinstance(fields, dict):
        raise PlanOverridesError(f"Override {index} in {path} must include fields object")


def _deep_merge(target: dict, fields: dict) -> None:
    for key, value in fields.items():
        existing = target.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            _deep_merge(existing, value)
        else:
            target[key] = copy.deepcopy(value)
