"""CP-9 Phase 6 integration test 4 — PanelOfJudges cost-cap aborts mid-run.

Per Phase 1 audit § 4 Locked Decision #3 (and BUILD_SPEC § Phase 3): when
the cumulative cost across judges meets-or-exceeds `cost_cap_usd`, the panel
hard-aborts BEFORE invoking the next judge. The partial scorecard preserves
the judges that ran, and `panel_warnings` carries the locked warning string
``cost_cap_aborted_at_judge_N`` (where N is the index of the FIRST judge
NOT invoked).

This test exercises the cap inside the integration path: the panel runs
inside the post_step hook attached to a real Workflow, dispatched against
a real StubStepRunner. The scorecard lands on the real
GenerationReceipt.eval_scores via in-place mutation.

Mocking: only at EvalNode + StepRunner boundaries. Workflow / dispatch /
attach_eval_hooks / PanelOfJudges are real.
"""

import sys
import pathlib

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.dispatch import _reset_bootstrap_for_tests  # noqa: E402
from recoil.pipeline.core.dispatch_context import DispatchContext  # noqa: E402
from recoil.pipeline.core.eval import (  # noqa: E402
    PanelOfJudges,
    attach_eval_hooks,
    _reset_eval_registry_for_tests,
)
from recoil.pipeline.core.registry import _reset_for_tests  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402

from recoil.pipeline.core.tests._eval_test_helpers import (  # noqa: E402
    FakeEvalNode,
    StubStepRunner,
)


@pytest.fixture(autouse=True)
def reset_registries():
    """Clear registry/bootstrap test state around every test in this module.

    Each ``_reset_*_for_tests`` helper is invoked once before the test body
    runs and once again after it finishes, in the same order both times, so
    state registered by one test cannot leak into the next.
    """
    resetters = (
        _reset_for_tests,
        _reset_bootstrap_for_tests,
        _reset_eval_registry_for_tests,
    )
    for reset in resetters:
        reset()
    yield
    for reset in resetters:
        reset()


def _single_step_wf() -> Workflow:
    """Return a freshly constructed one-step ``image_t2i`` workflow.

    Every call builds new ``Workflow`` / ``WorkflowStep`` objects, so a test
    can run the result and inspect ``steps[0]`` without touching objects
    produced by any other call.
    """
    keyframe_payload = {
        "shot_id": "EP001_SH02",
        "prompt": "p",
        "model": "nbp",
        "aspect_ratio": "9_16",
    }
    keyframe_step = WorkflowStep(
        step_id="kf",
        modality="image_t2i",
        payload=keyframe_payload,
    )
    return Workflow(workflow_id="wf_cost_cap", steps=[keyframe_step])


def test_cost_cap_aborts_after_first_judge_partial_scorecard_persists():
    """cost_cap=$0.01, 3 judges costing $0.05 each.

    Probe sequence:
      - before judge 0: accumulated=0.00 < 0.01 → run judge 0; acc → 0.05
      - before judge 1: accumulated=0.05 >= 0.01 → ABORT, append
        ``cost_cap_aborted_at_judge_1``
    Result: judge 0 ran; judges 1 and 2 did not. The scorecard carries judge
    0's result AND the warning string with index = 1 (the FIRST judge not
    invoked).
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_cost_cap",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    judges = [
        FakeEvalNode(judge_id="cap_j0", score=0.5, cost_usd=0.05),
        FakeEvalNode(judge_id="cap_j1", score=0.6, cost_usd=0.05),
        FakeEvalNode(judge_id="cap_j2", score=0.7, cost_usd=0.05),
    ]
    panel = PanelOfJudges(
        panel_id="cap_panel",
        judges=judges,
        cost_cap_usd=0.01,
    )

    # NOTE(review): the hooks are built from a separate throwaway Workflow,
    # and post_step is then passed explicitly to the workflow that runs.
    _, post_step, _ = attach_eval_hooks(_single_step_wf(), panel)

    wf = _single_step_wf()
    wf.run(context=ctx, post_step=post_step)

    step = wf.steps[0]
    assert step.status == "succeeded"
    assert step.receipt is not None

    # The scorecard is on the receipt via in-place mutation.
    scorecard = step.receipt.eval_scores["cap_panel"]

    # Only judge 0 ran. Judges 1 and 2 (0-based, the same indexing the
    # warning string uses) were never invoked.
    assert len(scorecard["judges"]) == 1
    assert scorecard["judges"][0]["judge_id"] == "cap_j0"
    assert len(judges[0].calls) == 1
    assert judges[1].calls == []
    assert judges[2].calls == []

    # Warning string format LOCKED: f"cost_cap_aborted_at_judge_{idx}"
    # where idx is the index of the FIRST judge NOT invoked. Here that's 1.
    abort_warnings = [
        w for w in scorecard["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert len(abort_warnings) == 1
    assert abort_warnings[0] == "cost_cap_aborted_at_judge_1"

    # panel_score is the median of [0.5] = 0.5 (single-judge fallback)
    assert scorecard["panel_score"] == pytest.approx(0.5)
    # panel_cost_usd reflects only the cost of the judges that actually ran
    assert scorecard["panel_cost_usd"] == pytest.approx(0.05)


def test_cost_cap_zero_aborts_before_any_judge_warning_index_zero():
    """cost_cap=0.0 + first probe sees accumulated=0.00 >= 0.00 → abort
    immediately. Warning string has index 0 (the FIRST judge not invoked is
    judge 0). Scorecard carries panel_score=None (no judges ran).

    The step itself must still succeed and carry a receipt: the cost cap
    only truncates evaluation, it does not fail the generation step.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_cost_cap",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    judge = FakeEvalNode(judge_id="zero_j", score=0.5, cost_usd=0.05)
    panel = PanelOfJudges(panel_id="zero_panel", judges=[judge], cost_cap_usd=0.0)

    # NOTE(review): the hooks are built from a separate throwaway Workflow,
    # and post_step is then passed explicitly to the workflow that runs.
    _, post_step, _ = attach_eval_hooks(_single_step_wf(), panel)

    wf = _single_step_wf()
    wf.run(context=ctx, post_step=post_step)

    step = wf.steps[0]
    # Check preconditions first so a regression fails with a clear assertion
    # instead of an AttributeError on ``None.eval_scores``.
    assert step.status == "succeeded"
    assert step.receipt is not None

    scorecard = step.receipt.eval_scores["zero_panel"]
    assert scorecard["panel_score"] is None  # no judges → None
    assert scorecard["judges"] == []
    assert judge.calls == []
    abort_warnings = [
        w for w in scorecard["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert abort_warnings == ["cost_cap_aborted_at_judge_0"]


def test_cost_cap_high_enough_runs_all_no_warnings():
    """Negative control — when cap is well above projected cost, no abort
    fires and no cost-cap warning appears in panel_warnings.

    All three judges ($0.001 each) must be invoked exactly once. The
    reported panel_cost_usd must cover all of them ($0.003), so a silently
    truncated cost would be caught as well.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_cost_cap",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    judges = [
        FakeEvalNode(judge_id="ok_j0", score=0.5, cost_usd=0.001),
        FakeEvalNode(judge_id="ok_j1", score=0.5, cost_usd=0.001),
        FakeEvalNode(judge_id="ok_j2", score=0.5, cost_usd=0.001),
    ]
    panel = PanelOfJudges(
        panel_id="ok_panel",
        judges=judges,
        cost_cap_usd=10.0,
    )

    # NOTE(review): the hooks are built from a separate throwaway Workflow,
    # and post_step is then passed explicitly to the workflow that runs.
    _, post_step, _ = attach_eval_hooks(_single_step_wf(), panel)

    wf = _single_step_wf()
    wf.run(context=ctx, post_step=post_step)

    step = wf.steps[0]
    # Check preconditions first so a regression fails with a clear assertion
    # instead of an AttributeError on ``None.eval_scores``.
    assert step.status == "succeeded"
    assert step.receipt is not None

    scorecard = step.receipt.eval_scores["ok_panel"]
    assert len(scorecard["judges"]) == 3
    abort_warnings = [
        w for w in scorecard["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert abort_warnings == []
    assert all(len(j.calls) == 1 for j in judges)
    # Every judge ran, so the panel cost is the full 3 × $0.001.
    assert scorecard["panel_cost_usd"] == pytest.approx(0.003)


def test_cost_cap_aborts_at_judge_2_when_first_two_fit_under_cap():
    """cap=$0.10, 3 judges costing $0.06 each.

    Probe sequence:
      - before judge 0: 0.00 < 0.10 → run; acc → 0.06
      - before judge 1: 0.06 < 0.10 → run; acc → 0.12
      - before judge 2: 0.12 >= 0.10 → ABORT at index 2
    Warning index is 2 (the FIRST judge NOT invoked). Judges 0 and 1 must
    each have been invoked exactly once; judge 2 never.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_cost_cap",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    judges = [
        FakeEvalNode(judge_id="j0", score=0.4, cost_usd=0.06),
        FakeEvalNode(judge_id="j1", score=0.6, cost_usd=0.06),
        FakeEvalNode(judge_id="j2", score=0.8, cost_usd=0.06),
    ]
    panel = PanelOfJudges(
        panel_id="mid_panel",
        judges=judges,
        cost_cap_usd=0.10,
    )

    # NOTE(review): the hooks are built from a separate throwaway Workflow,
    # and post_step is then passed explicitly to the workflow that runs.
    _, post_step, _ = attach_eval_hooks(_single_step_wf(), panel)

    wf = _single_step_wf()
    wf.run(context=ctx, post_step=post_step)

    step = wf.steps[0]
    # Check preconditions first so a regression fails with a clear assertion
    # instead of an AttributeError on ``None.eval_scores``.
    assert step.status == "succeeded"
    assert step.receipt is not None

    scorecard = step.receipt.eval_scores["mid_panel"]
    assert len(scorecard["judges"]) == 2
    assert [j["judge_id"] for j in scorecard["judges"]] == ["j0", "j1"]
    # The judges that fit under the cap ran exactly once; the third never did.
    assert len(judges[0].calls) == 1
    assert len(judges[1].calls) == 1
    assert judges[2].calls == []
    abort_warnings = [
        w for w in scorecard["panel_warnings"] if w.startswith("cost_cap_aborted")
    ]
    assert abort_warnings == ["cost_cap_aborted_at_judge_2"]
    # panel_score is median of [0.4, 0.6] = 0.5
    assert scorecard["panel_score"] == pytest.approx(0.5)
    assert scorecard["panel_cost_usd"] == pytest.approx(0.12)
