"""CP-9 Phase 3 — PanelOfJudges.

Coverage:
  - Construction validation (panel_id, judges, aggregation, cost_cap_usd).
  - Median + mean aggregation paths.
  - Outlier flagging (high + low + tightly-clustered no-warning).
  - Judge-exception swallowing → warning + skip.
  - Judge non-EvalResult return → warning + skip.
  - All judges failing → panel_score None, every error carried in
    panel_warnings.
  - Cost-cap probe: aborts before next judge when accumulated >= cap;
    cap=0 aborts immediately; cap=None unlimited.
  - Re-run idempotency (no internal state leak).
  - Scorecard shape (panel_id / panel_score / panel_warnings / judges /
    aggregation / panel_cost_usd).

All judges are local stubs — no live API calls.
"""

import sys
import pathlib
from pathlib import Path

import pytest

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

ensure_pipeline_importable()

from recoil.pipeline.core.eval import (  # noqa: E402
    EvalContext,
    EvalResult,
    PanelOfJudges,
)


class _FakeJudge:
    """Deterministic judge stub.

    Each ``evaluate`` call bumps ``calls`` and returns an ``EvalResult``
    built from the score and cost supplied at construction time.
    """

    def __init__(
        self,
        judge_id: str,
        score: float = 0.5,
        cost: float = 0.0,
        model_used: str = "fake-model",
    ) -> None:
        self.judge_id = judge_id
        self.model_used = model_used
        self._score, self._cost = score, cost
        # Invocation counter — lets tests prove a judge was (or was not) run.
        self.calls = 0

    def evaluate(self, context: EvalContext) -> EvalResult:
        self.calls = self.calls + 1
        fields = dict(
            score=self._score,
            reasoning=f"{self.judge_id}-says",
            judge_id=self.judge_id,
            model_used=self.model_used,
            cost_usd=self._cost,
        )
        return EvalResult(**fields)


class _RaisingJudge:
    """Judge stub whose ``evaluate`` always fails with ``RuntimeError``."""

    judge_id = "raiser"
    model_used = "raiser-model"

    def evaluate(self, context: EvalContext) -> EvalResult:
        failure = RuntimeError("simulated provider blowup")
        raise failure


class _NonEvalResultJudge:
    """Judge stub that breaks its contract by returning a plain ``str``."""

    judge_id = "wrongtype"
    model_used = "wrongtype-model"

    def evaluate(self, context: EvalContext):  # type: ignore[no-untyped-def]
        bogus = "not_an_eval_result"
        return bogus


def _ctx(tmp_path: Path) -> EvalContext:
    """Build an ``EvalContext`` that targets a tiny placeholder PNG in *tmp_path*.

    The file holds only the PNG signature and no image payload. It exists so
    the context points at a real path.
    """
    img = tmp_path / "x.png"
    # Complete 8-byte PNG file signature; no chunks follow.
    img.write_bytes(b"\x89PNG\r\n\x1a\n")
    return EvalContext(
        target_artifact_path=img, target_take=None,
        prompt="p", rubric="r", judge_id="ctx_caller",
    )


# ── Construction validation ──────────────────────────────────────────

def test_panel_construction_happy_path_single_judge() -> None:
    """A one-judge panel keeps its id and picks up the default settings."""
    panel = PanelOfJudges(panel_id="p1", judges=[_FakeJudge("a")])
    assert panel.panel_id == "p1"
    # Defaults: median aggregation, no cost cap.
    assert panel.aggregation == "median"
    assert panel.cost_cap_usd is None


def test_panel_construction_zero_judges_raises_ValueError() -> None:
    """An empty judge list is rejected at construction time."""
    no_judges = []
    with pytest.raises(ValueError, match="non-empty list"):
        PanelOfJudges(panel_id="p", judges=no_judges)


def test_panel_construction_empty_panel_id_raises() -> None:
    """A blank panel_id is rejected."""
    only_judge = _FakeJudge("a")
    with pytest.raises(ValueError, match="non-empty"):
        PanelOfJudges(panel_id="", judges=[only_judge])


def test_panel_construction_unsupported_aggregation_raises() -> None:
    """Aggregation names outside the supported set are refused."""
    kwargs = {
        "panel_id": "p",
        "judges": [_FakeJudge("a")],
        "aggregation": "trimmed_mean",  # not supported
    }
    with pytest.raises(ValueError, match="aggregation"):
        PanelOfJudges(**kwargs)


def test_panel_construction_negative_cost_cap_raises() -> None:
    """A negative cost cap is invalid; only values >= 0 or None are allowed."""
    bad_cap = -0.01
    with pytest.raises(ValueError, match=">= 0 or None"):
        PanelOfJudges(
            panel_id="p",
            judges=[_FakeJudge("a")],
            cost_cap_usd=bad_cap,
        )


def test_panel_construction_cost_cap_zero_allowed() -> None:
    """A zero cost cap is accepted and stored (it means: abort immediately)."""
    zero_cap = 0.0
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_FakeJudge("a")],
        cost_cap_usd=zero_cap,
    )
    assert panel.cost_cap_usd == zero_cap


def test_panel_construction_non_evalnode_judge_raises_TypeError() -> None:
    """Objects that do not satisfy the EvalNode protocol are refused."""
    class _NotAJudge:
        pass

    impostor = _NotAJudge()
    with pytest.raises(TypeError, match="EvalNode protocol"):
        PanelOfJudges(panel_id="p", judges=[impostor])  # type: ignore[list-item]


def test_panel_construction_aggregation_mean_ok() -> None:
    """Aggregation ``mean`` is accepted and stored as given."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_FakeJudge("a")],
        aggregation="mean",
    )
    assert panel.aggregation == "mean"


# ── Aggregation paths ────────────────────────────────────────────────

def test_panel_score_single_judge_median(tmp_path: Path) -> None:
    """With a single judge the median is simply that judge's score."""
    solo = _FakeJudge("a", score=0.7)
    panel = PanelOfJudges(panel_id="p", judges=[solo])
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_score"] == 0.7
    assert card["aggregation"] == "median"
    assert len(card["judges"]) == 1
    assert solo.calls == 1


def test_panel_score_three_judges_median(tmp_path: Path) -> None:
    """The median of three unordered scores is the middle value."""
    scores = {"a": 0.6, "b": 0.8, "c": 0.7}
    judges = [_FakeJudge(jid, score=value) for jid, value in scores.items()]
    panel = PanelOfJudges(panel_id="p", judges=judges)
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_score"] == 0.7  # median of [0.6, 0.7, 0.8]


def test_panel_score_three_judges_mean(tmp_path: Path) -> None:
    """Mean aggregation averages all judge scores."""
    judges = [
        _FakeJudge(jid, score=value)
        for jid, value in (("a", 0.6), ("b", 0.8), ("c", 0.7))
    ]
    panel = PanelOfJudges(panel_id="p", judges=judges, aggregation="mean")
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_score"] == pytest.approx(0.7, abs=1e-9)


def test_panel_score_two_judges_median_uses_average(tmp_path: Path) -> None:
    """An even-sized panel's median is the mean of its two middle scores.

    statistics.median([0.4, 0.6]) == 0.5, which confirms statistics.median
    semantics.
    """
    low, high = _FakeJudge("a", score=0.4), _FakeJudge("b", score=0.6)
    panel = PanelOfJudges(panel_id="p", judges=[low, high])
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_score"] == pytest.approx(0.5)


# ── Outlier flagging ─────────────────────────────────────────────────

def test_panel_score_outlier_flagged_high(tmp_path: Path) -> None:
    """A judge scoring far above the panel median is flagged as an outlier.

    The judge ids are distinctive words. A single-letter id such as "c"
    would let the "which judge was flagged" check pass on any warning
    that merely contains that letter (e.g. the word "score").
    """
    judges = [
        _FakeJudge("alpha", score=0.5),
        _FakeJudge("bravo", score=0.5),
        _FakeJudge("zulu", score=0.95),
    ]
    p = PanelOfJudges(panel_id="p", judges=judges)
    sc = p.score(receipt=None, context=_ctx(tmp_path))
    # median = 0.5; 0.95 - 0.5 = 0.45 >= 0.3 → flagged
    flagged = [w for w in sc["panel_warnings"] if w.startswith("outlier:")]
    assert len(flagged) == 1
    assert "zulu" in flagged[0]
    # The two in-consensus judges must not be the ones named.
    assert "alpha" not in flagged[0]
    assert "bravo" not in flagged[0]


def test_panel_score_outlier_flagged_low(tmp_path: Path) -> None:
    """A judge scoring far below the panel median is flagged as an outlier.

    The judge ids are distinctive words, so the "which judge was flagged"
    check cannot pass on an incidental letter in the warning text.
    """
    judges = [
        _FakeJudge("alpha", score=0.95),
        _FakeJudge("bravo", score=0.95),
        _FakeJudge("zulu", score=0.5),
    ]
    p = PanelOfJudges(panel_id="p", judges=judges)
    sc = p.score(receipt=None, context=_ctx(tmp_path))
    # median = 0.95; 0.95 - 0.5 = 0.45 below the median → flagged
    flagged = [w for w in sc["panel_warnings"] if w.startswith("outlier:")]
    assert len(flagged) == 1
    assert "zulu" in flagged[0]
    # The two in-consensus judges must not be the ones named.
    assert "alpha" not in flagged[0]
    assert "bravo" not in flagged[0]


def test_panel_score_no_outliers_no_warnings(tmp_path: Path) -> None:
    """Tightly clustered scores produce no outlier warnings."""
    judges = [
        _FakeJudge(jid, score=value)
        for jid, value in zip("abc", (0.6, 0.65, 0.7))
    ]
    panel = PanelOfJudges(panel_id="p", judges=judges)
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert not any(w.startswith("outlier:") for w in card["panel_warnings"])


# ── Judge failure modes ──────────────────────────────────────────────

def test_panel_score_judge_exception_skipped(tmp_path: Path) -> None:
    """A judge that raises is skipped and reported, and the run still finishes."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[
            _FakeJudge("a", score=0.5),
            _RaisingJudge(),
            _FakeJudge("c", score=0.7),
        ],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert len(card["judges"]) == 2
    # Only the two surviving scores are aggregated: median of [0.5, 0.7].
    assert card["panel_score"] == pytest.approx(0.6)
    errors = [w for w in card["panel_warnings"] if "errored" in w]
    assert len(errors) == 1
    message = errors[0]
    assert "RuntimeError" in message
    assert "raiser" in message


def test_panel_score_judge_returns_non_evalresult_skipped(tmp_path: Path) -> None:
    """A judge returning something other than EvalResult is dropped with a warning."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_FakeJudge("a", score=0.5), _NonEvalResultJudge()],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert len(card["judges"]) == 1
    assert card["panel_score"] == 0.5
    wrong_type = [w for w in card["panel_warnings"] if "non_EvalResult" in w]
    assert len(wrong_type) == 1


def test_panel_score_all_judges_errored_returns_panel_score_None(tmp_path: Path) -> None:
    """If every judge raises, the scorecard has no score, no results and zero cost."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_RaisingJudge() for _ in range(2)],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_score"] is None
    assert card["judges"] == []
    error_count = sum(1 for w in card["panel_warnings"] if "errored" in w)
    assert error_count == 2
    assert card["panel_cost_usd"] == 0.0


# ── Cost cap ─────────────────────────────────────────────────────────

def test_panel_score_cost_cap_aborts_before_next_judge(tmp_path: Path) -> None:
    """The cost cap is probed before each judge.

    cap=$0.05 and each judge costs $0.04. After judge 1 the total is $0.04,
    which is under the cap, so judge 2 runs. After judge 2 the total is
    $0.08, which is at or over the cap, so the panel aborts before judge 3.
    """
    first, second, third = (
        _FakeJudge(jid, score=value, cost=0.04)
        for jid, value in (("a", 0.5), ("b", 0.6), ("c", 0.7))
    )
    panel = PanelOfJudges(
        panel_id="p",
        judges=[first, second, third],
        cost_cap_usd=0.05,
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert len(card["judges"]) == 2
    assert third.calls == 0  # third judge never invoked
    aborts = [
        w for w in card["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert len(aborts) == 1
    assert "judge_2" in aborts[0]
    assert card["panel_cost_usd"] == pytest.approx(0.08)


def test_panel_score_cost_cap_zero_aborts_immediately(tmp_path: Path) -> None:
    """With cap=0 the very first probe sees 0 >= 0 and aborts.

    No judge is ever invoked.
    """
    only = _FakeJudge("a", score=0.5, cost=0.04)
    panel = PanelOfJudges(panel_id="p", judges=[only], cost_cap_usd=0.0)
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert only.calls == 0
    assert card["judges"] == []
    assert card["panel_score"] is None
    aborts = [
        w for w in card["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert len(aborts) == 1
    assert "judge_0" in aborts[0]


def test_panel_score_cost_cap_unlimited_default_runs_all(tmp_path: Path) -> None:
    """Without a cost cap every judge runs, however expensive it is."""
    judges = [
        _FakeJudge(jid, score=value, cost=10.0)
        for jid, value in zip("abc", (0.5, 0.6, 0.7))
    ]
    panel = PanelOfJudges(panel_id="p", judges=judges)  # cost_cap_usd=None
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert len(card["judges"]) == 3
    assert {j.calls for j in judges} == {1}
    assert card["panel_cost_usd"] == pytest.approx(30.0)
    assert all(
        not w.startswith("cost_cap_aborted") for w in card["panel_warnings"]
    )


# ── Scorecard shape ──────────────────────────────────────────────────

def test_panel_score_panel_cost_usd_sum(tmp_path: Path) -> None:
    """panel_cost_usd is the sum of every judge's cost."""
    cost_a, cost_b = 0.01, 0.02
    panel = PanelOfJudges(
        panel_id="p",
        judges=[
            _FakeJudge("a", score=0.5, cost=cost_a),
            _FakeJudge("b", score=0.6, cost=cost_b),
        ],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_cost_usd"] == pytest.approx(0.03)


def test_panel_score_judges_list_in_result_full_dicts(tmp_path: Path) -> None:
    """Each entry in ``judges`` carries the full set of EvalResult fields."""
    panel = PanelOfJudges(
        panel_id="p", judges=[_FakeJudge("a", score=0.5, cost=0.01)],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    first_entry = card["judges"][0]
    wanted = {
        "score", "reasoning", "judge_id", "model_used", "cost_usd", "metadata",
    }
    assert set(first_entry.keys()) == wanted


def test_panel_score_aggregation_in_result(tmp_path: Path) -> None:
    """The scorecard echoes the configured aggregation."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_FakeJudge("a", score=0.5)],
        aggregation="mean",
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["aggregation"] == "mean"


def test_panel_score_panel_id_in_result(tmp_path: Path) -> None:
    """The scorecard echoes the panel's id."""
    panel_name = "critic-panel-v1"
    panel = PanelOfJudges(panel_id=panel_name, judges=[_FakeJudge("a", score=0.5)])
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    assert card["panel_id"] == panel_name


def test_panel_score_scorecard_keys_complete(tmp_path: Path) -> None:
    """The scorecard has exactly the expected top-level keys and no extras."""
    panel = PanelOfJudges(
        panel_id="p", judges=[_FakeJudge("a", score=0.5, cost=0.01)],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    wanted = {
        "panel_id", "panel_score", "panel_warnings", "judges",
        "aggregation", "panel_cost_usd",
    }
    assert set(card.keys()) == wanted


def test_panel_score_no_judges_run_carries_errors_in_warnings(
    tmp_path: Path,
) -> None:
    """When every judge fails, panel_warnings holds one error entry per judge."""
    panel = PanelOfJudges(
        panel_id="p",
        judges=[_RaisingJudge() for _ in range(3)],
    )
    card = panel.score(receipt=None, context=_ctx(tmp_path))
    warning_list = card["panel_warnings"]
    assert len(warning_list) == 3
    assert all("RuntimeError" in entry for entry in warning_list)


# ── Re-run idempotency ──────────────────────────────────────────────

def test_panel_score_rerunnable_no_state_leak(tmp_path: Path) -> None:
    """Repeated ``score`` calls on one panel are independent.

    The cap equals a single judge's cost, so it only bites if accumulated
    cost leaks across runs. A fresh second run probes at $0.00 (< cap) and
    invokes the judge. A leaked $0.02 (>= cap) would abort before the judge,
    which would leave the judge with one call and no panel score.
    """
    judges = [_FakeJudge("a", score=0.5, cost=0.02)]
    p = PanelOfJudges(panel_id="p", judges=judges, cost_cap_usd=0.02)
    sc1 = p.score(receipt=None, context=_ctx(tmp_path))
    sc2 = p.score(receipt=None, context=_ctx(tmp_path))
    # Each run is independent — costs do NOT accumulate across runs.
    assert sc1["panel_cost_usd"] == pytest.approx(0.02)
    assert sc2["panel_cost_usd"] == pytest.approx(0.02)
    # Equality alone would also pass if both runs produced None.
    assert sc1["panel_score"] == sc2["panel_score"] == 0.5
    # Judge invoked once per run — the second run was not cap-aborted.
    assert judges[0].calls == 2
