from __future__ import annotations

import json

import pytest

from recoil.core.paths import ProjectPaths
from recoil.pipeline._lib import plan_overrides
from recoil.pipeline._lib.plan_overrides import PlanOverridesError


@pytest.fixture
def project_paths(tmp_path, monkeypatch):
    """Provide ``ProjectPaths`` rooted in a temp dir and route lookups to it.

    For the duration of the test, ``plan_overrides.ProjectPaths.for_project``
    is replaced so that any project name resolves to the returned paths.
    """
    temp_paths = ProjectPaths.from_root(tmp_path / "tartarus")

    def _fixed_for_project(cls, project):
        # The project argument is ignored: every lookup gets the temp layout.
        return temp_paths

    monkeypatch.setattr(
        plan_overrides.ProjectPaths, "for_project", classmethod(_fixed_for_project)
    )
    return temp_paths


def _plan() -> dict:
    """Build a fresh minimal plan for episode ``EP001`` containing one shot.

    The single shot ``EP001_SH02`` carries source hash ``"h2"`` and a
    pre-existing ``prompt_data`` entry so merge behaviour can be observed.
    """
    shot = {
        "shot_id": "EP001_SH02",
        "source_text_hash": "h2",
        "source_text": "The test shot.",
        "prompt_data": {"focal_length": "85mm"},
    }
    return {"episode_id": "EP001", "shots": [shot]}


def _override(
    *,
    shot_id: str = "EP001_SH02",
    target_span_hash: str = "h2",
    fields: dict | None = None,
) -> dict:
    """Build one override record targeting a shot.

    When ``fields`` is omitted, a new ``{"prompt_data": {"shot_type": "ECU"}}``
    patch is created on every call, so callers never share a mutable default.
    An explicitly passed ``fields`` (including ``{}``) is used as-is.
    """
    if fields is None:
        fields = {"prompt_data": {"shot_type": "ECU"}}
    return {
        "shot_id": shot_id,
        "target_span_hash": target_span_hash,
        "fields": fields,
        "authored_at": "2026-06-24T00:00:00Z",
    }


def _write_overrides(project_paths: ProjectPaths, payload: object) -> None:
    """Write ``payload`` to the ``ep_001.json`` overrides file.

    The file lives at ``<plans_dir parent>/plan_overrides/ep_001.json``, and its
    directory is created if missing. ``bytes`` are written verbatim, which lets
    tests plant malformed JSON. Any other value is serialized with
    ``json.dumps`` and written as UTF-8 text.
    """
    target = project_paths.plans_dir.parent.joinpath("plan_overrides", "ep_001.json")
    target.parent.mkdir(parents=True, exist_ok=True)
    if not isinstance(payload, bytes):
        target.write_text(json.dumps(payload), encoding="utf-8")
        return
    target.write_bytes(payload)


def test_fresh_override_applies():
    """A hash-matching override is merged into the plan and raises no flags."""
    plan = _plan()
    live_hashes = {"EP001_SH02": "h2"}

    returned_plan, flags = plan_overrides.apply_overrides(plan, [_override()], live_hashes)

    # The same plan object is handed back, carrying the applied field.
    assert returned_plan is plan
    assert plan["shots"][0]["prompt_data"]["shot_type"] == "ECU"
    assert flags == []
    assert plan["override_flags"] == []


def test_stale_override_flagged_not_applied():
    """An override whose target hash differs from the live hash is skipped and flagged."""
    plan = _plan()
    changed_hashes = {"EP001_SH02": "h2_CHANGED"}

    _, flags = plan_overrides.apply_overrides(plan, [_override()], changed_hashes)

    stale_flag = {
        "shot_id": "EP001_SH02",
        "reason": "stale_span",
        "target_span_hash": "h2",
        "live_hash": "h2_CHANGED",
    }
    # The shot must be left untouched; the mismatch is reported both in the
    # returned flags and on the plan itself.
    assert "shot_type" not in plan["shots"][0]["prompt_data"]
    assert flags == [stale_flag]
    assert plan["override_flags"] == [stale_flag]


def test_orphan_override_flagged_not_applied():
    """An override naming a shot absent from the plan and live hashes is an orphan."""
    plan = _plan()
    unknown_shot_override = _override(shot_id="EP001_SH99")

    _, flags = plan_overrides.apply_overrides(
        plan, [unknown_shot_override], {"EP001_SH02": "h2"}
    )

    orphan_flag = {"shot_id": "EP001_SH99", "reason": "orphan", "target_span_hash": "h2"}
    assert "shot_type" not in plan["shots"][0]["prompt_data"]
    assert flags == [orphan_flag]
    assert plan["override_flags"] == flags


def test_currency_unavailable_flagged_distinct_from_orphan():
    """A ``None`` live hash produces a ``currency_unavailable`` flag.

    The shot exists, but its live hash is ``None``. The override is therefore
    not applied, and it is not reported as an orphan or a stale span.
    """
    plan = _plan()
    hashes_without_value = {"EP001_SH02": None}

    _, flags = plan_overrides.apply_overrides(plan, [_override()], hashes_without_value)

    assert "shot_type" not in plan["shots"][0]["prompt_data"]
    assert flags == [
        {"shot_id": "EP001_SH02", "reason": "currency_unavailable", "target_span_hash": "h2"}
    ]
    # Stated explicitly: this reason must stay distinct from the other two.
    assert flags[0]["reason"] not in {"orphan", "stale_span"}
    assert plan["override_flags"] == flags


def test_empty_overrides_noop_sets_empty_flags():
    """With no overrides, the plan only gains an empty ``override_flags`` list."""
    plan = _plan()
    # A JSON round-trip yields an independent deep snapshot of this plain-dict plan.
    snapshot = json.loads(json.dumps(plan))

    returned_plan, flags = plan_overrides.apply_overrides(plan, [], {"EP001_SH02": "h2"})

    assert returned_plan is plan
    assert flags == []
    assert plan == dict(snapshot, override_flags=[])


def test_deep_merge_preserves_existing_nested_fields():
    """Override ``prompt_data`` keys merge into the shot's existing keys, not replace them."""
    plan = _plan()
    patch = {"prompt_data": {"shot_type": "ECU"}}

    _, flags = plan_overrides.apply_overrides(
        plan, [_override(fields=patch)], {"EP001_SH02": "h2"}
    )

    assert flags == []
    merged_prompt = plan["shots"][0]["prompt_data"]
    # focal_length comes from the original plan; shot_type comes from the override.
    assert merged_prompt == {"focal_length": "85mm", "shot_type": "ECU"}


def test_load_overrides_absent_returns_empty(project_paths):
    """No overrides file on disk yields an empty list rather than an error."""
    # project_paths is requested only for its monkeypatching side effect.
    loaded = plan_overrides.load_overrides("tartarus", 1)
    assert loaded == []


def test_load_overrides_validates_and_returns_overrides(project_paths):
    """A well-formed schema-v1 document yields its ``overrides`` list unchanged."""
    expected_override = _override()
    document = {
        "schema_version": 1,
        "episode_id": "EP001",
        "overrides": [expected_override],
    }
    _write_overrides(project_paths, document)

    loaded = plan_overrides.load_overrides("tartarus", 1)
    assert loaded == [expected_override]


# Each case gets an explicit id so a failing run names the validation rule
# that regressed, instead of an anonymous ``payload0``..``payload7``.
@pytest.mark.parametrize(
    "payload",
    [
        pytest.param(b"{bad", id="malformed-json"),
        pytest.param([], id="top-level-not-object"),
        pytest.param(
            {"schema_version": 2, "episode_id": "EP001", "overrides": []},
            id="unsupported-schema-version",
        ),
        pytest.param(
            {"schema_version": 1, "episode_id": "EP009", "overrides": []},
            id="episode-id-mismatch",
        ),
        pytest.param(
            {"schema_version": 1, "episode_id": "EP001", "overrides": "bad"},
            id="overrides-not-list",
        ),
        pytest.param(
            {
                "schema_version": 1,
                "episode_id": "EP001",
                "overrides": [{"target_span_hash": "h2", "fields": {}}],
            },
            id="override-missing-shot-id",
        ),
        pytest.param(
            {
                "schema_version": 1,
                "episode_id": "EP001",
                "overrides": [{"shot_id": "EP001_SH02", "target_span_hash": 123, "fields": {}}],
            },
            id="target-span-hash-not-str",
        ),
        pytest.param(
            {
                "schema_version": 1,
                "episode_id": "EP001",
                "overrides": [
                    {"shot_id": "EP001_SH02", "target_span_hash": "h2", "fields": "bad"}
                ],
            },
            id="fields-not-object",
        ),
    ],
)
def test_load_overrides_invalid_schema_raises(project_paths, payload):
    """Every malformed overrides document is rejected with ``PlanOverridesError``."""
    _write_overrides(project_paths, payload)

    with pytest.raises(PlanOverridesError):
        plan_overrides.load_overrides("tartarus", 1)
