"""CP-9 Phase 5 — Take.aggregate_score field + Take.compute_aggregate_score()
helper + JSON round-trip preservation.

Per § 12d corrections from `eval-primitive-audit.md`:
  - Take dataclass appends `aggregate_score: Optional[float] = None` (7th field).
  - compute_aggregate_score aggregates `step.receipt.eval_scores` panel_scores
    across all workflow steps using mean. Returns None when no panels scored.
  - Take.to_dict / Take.from_dict round-trip aggregate_score (None default).

Tests build fake Workflow / WorkflowStep / GenerationReceipt fixtures so we
don't touch Phase 4 runners or live providers.
"""

import sys
import pathlib

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402
ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.receipts import GenerationReceipt  # noqa: E402
from recoil.pipeline.core.registry import RunResult  # noqa: E402
from recoil.pipeline.core.take import Take  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402


# ─────────────────────────────────────────────────────────────────────────
# Fixture helpers — fabricate Workflow / Step / Receipt with optional
# eval_scores. No live API, no Phase 4 runner invocation.
# ─────────────────────────────────────────────────────────────────────────


def _make_run_result(
    *, success=True, modality="image_t2i", shot_id="X"
) -> RunResult:
    """Fabricate a RunResult for fixtures.

    A truthy ``success`` yields a fixed output path and no error; a falsy
    one yields no output path and the error string ``"boom"``. The id is
    derived from ``shot_id`` and ``modality`` plus a fixed numeric suffix.
    """
    if success:
        out_path, err = "/tmp/x.png", None
    else:
        out_path, err = None, "boom"
    result_id = f"{shot_id}_{modality}_1700000000"
    return RunResult(
        id=result_id,
        modality=modality,
        output_path=out_path,
        output_url=None,
        metadata={"cost_usd": 0.04, "model": "nbp"},
        success=success,
        error=err,
    )


def _make_receipt(
    *,
    modality="image_t2i",
    shot_id="X",
    success=True,
    eval_scores=None,
) -> GenerationReceipt:
    """Fabricate a GenerationReceipt wrapping a fake RunResult.

    ``eval_scores`` is shallow-copied into a fresh dict (an empty dict when
    it is None or otherwise falsy), so the receipt never shares the caller's
    top-level mapping.
    """
    fake_result = _make_run_result(
        success=success, modality=modality, shot_id=shot_id
    )
    scores_copy = dict(eval_scores or {})
    receipt_id = f"rcpt_1_{shot_id}_{modality}"
    return GenerationReceipt(
        receipt_id=receipt_id,
        modality=modality,
        caller_id="test",
        project="tartarus",
        episode=1,
        shot_id=shot_id,
        timestamp_utc="2026-04-28T03:14:15Z",
        run_result=fake_result,
        provenance={"dispatch_path": "test"},
        eval_scores=scores_copy,
    )


def _make_step(
    *,
    step_id="kf",
    modality="image_t2i",
    receipt=None,
    status="pending",
) -> WorkflowStep:
    """Fabricate a WorkflowStep with a fixed payload.

    Without a receipt the step is returned exactly as constructed and
    ``status`` is not applied. With a receipt, the receipt is attached and
    the status is set to ``status`` — except that the default ``"pending"``
    is replaced by ``"succeeded"``.
    """
    new_step = WorkflowStep(
        step_id=step_id,
        modality=modality,
        payload={"shot_id": "X", "prompt": "p", "model": "nbp"},
    )
    if receipt is None:
        return new_step
    new_step.receipt = receipt
    new_step.status = "succeeded" if status == "pending" else status
    return new_step


def _make_take(
    *,
    take_id="t0",
    take_index=0,
    workflow_id="wf0",
    steps=None,
) -> Take:
    """Fabricate a Take around a new Workflow.

    When ``steps`` is None the workflow gets a single default step from
    ``_make_step()``; any other value (including an empty list) is passed
    through unchanged.
    """
    step_list = [_make_step()] if steps is None else steps
    workflow = Workflow(workflow_id=workflow_id, steps=step_list)
    return Take(take_id=take_id, take_index=take_index, workflow=workflow)


def _scorecard(panel_score):
    """Tiny PanelOfJudges scorecard shape — only `panel_score` is load-bearing
    for CP-9 aggregation."""
    return {"panel_score": panel_score, "judges": []}


# ─────────────────────────────────────────────────────────────────────────
# Field default + basic shape
# ─────────────────────────────────────────────────────────────────────────


def test_take_aggregate_score_default_None():
    """A Take built without an explicit aggregate_score (the 7th field)
    reports None for it."""
    fresh_take = _make_take()
    assert fresh_take.aggregate_score is None


def test_take_construct_with_aggregate_score():
    """aggregate_score can be supplied at construction time (e.g. when a
    take is rebuilt from disk or scored externally) and is kept as given."""
    preset_score = 0.8
    workflow = Workflow(workflow_id="wf", steps=[_make_step()])
    take = Take(
        take_id="t",
        take_index=0,
        workflow=workflow,
        aggregate_score=preset_score,
    )
    assert take.aggregate_score == preset_score


# ─────────────────────────────────────────────────────────────────────────
# compute_aggregate_score — happy paths
# ─────────────────────────────────────────────────────────────────────────


def test_compute_aggregate_score_no_eval_scores_returns_None():
    """A step whose receipt carries an empty eval_scores dict produces no
    aggregate: the call returns None and the field stays None."""
    empty_receipt = _make_receipt(eval_scores={})
    take = _make_take(steps=[_make_step(receipt=empty_receipt)])
    computed = take.compute_aggregate_score()
    assert computed is None
    assert take.aggregate_score is None


def test_compute_aggregate_score_single_panel_single_step():
    """With exactly one panel (panel_score=0.7) on one step, the aggregate
    equals that panel's score, both as return value and stored field."""
    eval_scores = {"panel_a": _scorecard(0.7)}
    scored_step = _make_step(receipt=_make_receipt(eval_scores=eval_scores))
    take = _make_take(steps=[scored_step])
    expected = pytest.approx(0.7)
    assert take.compute_aggregate_score() == expected
    assert take.aggregate_score == expected


def test_compute_aggregate_score_two_panels_one_step():
    """Two panels (0.6 and 0.8) on a single step average to 0.7."""
    eval_scores = {
        panel: _scorecard(value)
        for panel, value in (("panel_a", 0.6), ("panel_b", 0.8))
    }
    receipt = _make_receipt(eval_scores=eval_scores)
    take = _make_take(steps=[_make_step(receipt=receipt)])
    assert take.compute_aggregate_score() == pytest.approx(0.7)


def test_compute_aggregate_score_one_panel_two_steps():
    """The same single panel scored on two different steps (0.4 and 0.6)
    aggregates to the mean of the two, 0.5."""
    step_specs = (
        ("kf", "image_t2i", 0.4),
        ("vid", "video_i2v", 0.6),
    )
    steps = []
    for step_id, modality, panel_score in step_specs:
        receipt = _make_receipt(
            modality=modality,
            shot_id="A",
            eval_scores={"panel_a": _scorecard(panel_score)},
        )
        steps.append(
            _make_step(step_id=step_id, modality=modality, receipt=receipt)
        )
    take = _make_take(steps=steps)
    assert take.compute_aggregate_score() == pytest.approx(0.5)


def test_compute_aggregate_score_two_panels_two_steps():
    """Four panel scores spread over two steps are averaged together."""
    step_specs = (
        ("kf", "image_t2i", 0.2, 0.4),
        ("vid", "video_i2v", 0.6, 0.8),
    )
    steps = []
    for step_id, modality, score_a, score_b in step_specs:
        receipt = _make_receipt(
            modality=modality,
            shot_id="A",
            eval_scores={
                "panel_a": _scorecard(score_a),
                "panel_b": _scorecard(score_b),
            },
        )
        steps.append(
            _make_step(step_id=step_id, modality=modality, receipt=receipt)
        )
    take = _make_take(steps=steps)
    # (0.2 + 0.4 + 0.6 + 0.8) / 4 == 0.5
    assert take.compute_aggregate_score() == pytest.approx(0.5)


# ─────────────────────────────────────────────────────────────────────────
# compute_aggregate_score — defensive paths (skip rules)
# ─────────────────────────────────────────────────────────────────────────


def test_compute_aggregate_score_skips_None_panel_scores():
    """A scorecard whose panel_score is None does not count toward the
    mean; only the remaining 0.6 is aggregated."""
    eval_scores = {"panel_a": _scorecard(None)}
    eval_scores["panel_b"] = _scorecard(0.6)
    receipt = _make_receipt(eval_scores=eval_scores)
    take = _make_take(steps=[_make_step(receipt=receipt)])
    computed = take.compute_aggregate_score()
    assert computed == pytest.approx(0.6)


def test_compute_aggregate_score_skips_non_dict_scorecard_values():
    """Malformed eval_scores entries (a str, an int, a list) are ignored;
    only the one well-formed scorecard contributes."""
    eval_scores = {
        "panel_a": "garbage",
        "panel_b": 42,
        "panel_c": ["not", "a", "dict"],
    }
    # The single valid entry, added last so key order matches the fixture's
    # intended a/b/c/d layout.
    eval_scores["panel_d"] = _scorecard(0.5)
    receipt = _make_receipt(eval_scores=eval_scores)
    take = _make_take(steps=[_make_step(receipt=receipt)])
    computed = take.compute_aggregate_score()
    assert computed == pytest.approx(0.5)


def test_compute_aggregate_score_skips_non_numeric_panel_scores():
    """panel_score values that cannot be turned into a float (here a
    non-numeric string and a list) are skipped; the valid 0.9 remains."""
    eval_scores = {
        # float("not_a_number") raises ValueError — expected to be swallowed.
        "panel_a": {"panel_score": "not_a_number"},
        # float([0.5]) raises TypeError — expected to be swallowed.
        "panel_b": {"panel_score": [0.5]},
    }
    eval_scores["panel_c"] = _scorecard(0.9)  # the only usable score
    receipt = _make_receipt(eval_scores=eval_scores)
    take = _make_take(steps=[_make_step(receipt=receipt)])
    computed = take.compute_aggregate_score()
    assert computed == pytest.approx(0.9)


def test_compute_aggregate_score_step_without_receipt_skipped():
    """A step with no receipt (not executed yet) is ignored, so the
    aggregate comes only from the step that has one."""
    scored_receipt = _make_receipt(eval_scores={"panel_a": _scorecard(0.5)})
    executed_step = _make_step(step_id="kf", receipt=scored_receipt)
    pending_step = _make_step(step_id="vid")  # receipt stays unset
    take = _make_take(steps=[executed_step, pending_step])
    assert take.compute_aggregate_score() == pytest.approx(0.5)


def test_compute_aggregate_score_all_steps_unscored_returns_None():
    """When none of the steps carries a scoring panel, the call returns
    None and aggregate_score is None afterwards."""
    unscored_steps = [_make_step(step_id=sid) for sid in ("kf", "vid")]
    take = _make_take(steps=unscored_steps)
    computed = take.compute_aggregate_score()
    assert computed is None
    assert take.aggregate_score is None


def test_compute_aggregate_score_string_numeric_panel_score_coerced():
    """A panel_score given as the numeric string "0.42" is expected to be
    coerced to a float and aggregated as 0.42."""
    eval_scores = {"panel_a": _scorecard("0.42")}
    scored_step = _make_step(receipt=_make_receipt(eval_scores=eval_scores))
    take = _make_take(steps=[scored_step])
    computed = take.compute_aggregate_score()
    assert computed == pytest.approx(0.42)


# ─────────────────────────────────────────────────────────────────────────
# State / idempotency
# ─────────────────────────────────────────────────────────────────────────


def test_compute_aggregate_score_returns_and_stores():
    """The value returned by compute_aggregate_score() equals the value
    left on aggregate_score, and both are approximately 0.55."""
    receipt = _make_receipt(eval_scores={"panel_a": _scorecard(0.55)})
    take = _make_take(steps=[_make_step(receipt=receipt)])
    returned = take.compute_aggregate_score()
    stored = take.aggregate_score
    assert returned == stored
    assert stored == pytest.approx(0.55)


def test_compute_aggregate_score_idempotent_recomputable():
    """Repeated calls give the same result, and a call made after
    eval_scores has changed reflects the new contents."""
    receipt = _make_receipt(eval_scores={"panel_a": _scorecard(0.3)})
    take = _make_take(steps=[_make_step(receipt=receipt)])
    for _ in range(2):
        assert take.compute_aggregate_score() == pytest.approx(0.3)
    # Add a second panel in place: the eval_scores dict on the attached
    # receipt is mutable, so no new receipt is needed.
    live_scores = take.workflow.steps[0].receipt.eval_scores
    live_scores["panel_b"] = _scorecard(0.9)
    assert take.compute_aggregate_score() == pytest.approx(0.6)


def test_compute_aggregate_score_resets_None_when_panels_disappear():
    """Defensive case: once every panel score is removed from a take that
    was already scored, recomputing sets aggregate_score back to None."""
    receipt = _make_receipt(eval_scores={"panel_a": _scorecard(0.5)})
    take = _make_take(steps=[_make_step(receipt=receipt)])
    take.compute_aggregate_score()
    assert take.aggregate_score == pytest.approx(0.5)
    # Empty the live eval_scores dict on the attached receipt.
    live_scores = take.workflow.steps[0].receipt.eval_scores
    live_scores.clear()
    recomputed = take.compute_aggregate_score()
    assert recomputed is None
    assert take.aggregate_score is None


# ─────────────────────────────────────────────────────────────────────────
# JSON round-trip
# ─────────────────────────────────────────────────────────────────────────


def test_take_to_dict_round_trips_aggregate_score_None():
    """to_dict() emits an explicit aggregate_score key holding None, and
    from_dict() restores it as None."""
    serialized = _make_take().to_dict()
    assert "aggregate_score" in serialized
    assert serialized["aggregate_score"] is None
    restored = Take.from_dict(serialized)
    assert restored.aggregate_score is None


def test_take_to_dict_round_trips_aggregate_score_set():
    """An explicitly set float aggregate_score survives a
    to_dict()/from_dict() cycle."""
    workflow = Workflow(workflow_id="wf", steps=[_make_step()])
    original = Take(
        take_id="t",
        take_index=0,
        workflow=workflow,
        aggregate_score=0.73,
    )
    serialized = original.to_dict()
    assert serialized["aggregate_score"] == pytest.approx(0.73)
    restored = Take.from_dict(serialized)
    assert restored.aggregate_score == pytest.approx(0.73)


def test_take_from_dict_legacy_dict_without_aggregate_score():
    """A legacy (CP-7 sidecar) dict that has no aggregate_score key at all
    still loads, with aggregate_score coming back as None."""
    legacy_workflow = Workflow(workflow_id="wf", steps=[_make_step()])
    # Deliberately omits the "aggregate_score" key.
    legacy_payload = {
        "take_id": "t_legacy",
        "take_index": 0,
        "workflow": legacy_workflow.to_dict(),
        "status": "succeeded",
        "created_at": "2026-04-28T03:14:15Z",
        "take_metadata": {},
    }
    restored = Take.from_dict(legacy_payload)
    assert restored.aggregate_score is None
