"""CP-9 Phase 6 integration test 5 — round-trip Beat with eval'd Takes.

Builds a real Beat with 2 Takes, each of which has 2 WorkflowSteps. Each step
gets a real GenerationReceipt with a PanelOfJudges scorecard written via
attach_eval_hooks (in-place mutation on receipt.eval_scores). After
Take.compute_aggregate_score, every Take has a non-None aggregate_score.

Then round-trip: Beat.to_dict() → Beat.from_dict(). Verify both:
  1. Every step receipt's eval_scores dict compares equal after the round trip.
  2. Every Take's aggregate_score round-trips identically.

This is the cross-CP integration test — CP-5 receipts + CP-6 workflow +
CP-7 Take/Beat + CP-9 eval all serializing through one round trip.
"""

import sys
import pathlib

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.dispatch import _reset_bootstrap_for_tests  # noqa: E402
from recoil.pipeline.core.dispatch_context import DispatchContext  # noqa: E402
from recoil.pipeline.core.eval import (  # noqa: E402
    PanelOfJudges,
    attach_eval_hooks,
    _reset_eval_registry_for_tests,
)
from recoil.pipeline.core.registry import _reset_for_tests  # noqa: E402
from recoil.pipeline.core.take import Beat  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402

from recoil.pipeline.core.tests._eval_test_helpers import (  # noqa: E402
    FakeEvalNode,
    StubStepRunner,
)


@pytest.fixture(autouse=True)
def reset_registries():
    """Clear the core, dispatch-bootstrap, and eval registries around every test.

    Runs the same three resetters, in the same order, both before the test
    body and again during teardown, so no registration leaks between tests.
    """
    resetters = (
        _reset_for_tests,
        _reset_bootstrap_for_tests,
        _reset_eval_registry_for_tests,
    )
    for reset in resetters:
        reset()
    yield
    for reset in resetters:
        reset()


def _two_step_wf(workflow_id: str) -> Workflow:
    """Build a two-step keyframe → video workflow for shot EP001_SH02.

    Step ``kf`` is an ``image_t2i`` step using model ``nbp``; step ``vid`` is a
    ``video_i2v`` step using model ``seeddance-2.0`` that declares a dependency
    on ``kf``. Both steps use the placeholder prompt ``"p"`` and aspect ratio
    ``9_16``.
    """

    def _payload(model: str) -> dict:
        # A fresh dict per step so the two steps never share payload state.
        return {
            "shot_id": "EP001_SH02",
            "prompt": "p",
            "model": model,
            "aspect_ratio": "9_16",
        }

    keyframe = WorkflowStep(
        step_id="kf",
        modality="image_t2i",
        payload=_payload("nbp"),
    )
    video = WorkflowStep(
        step_id="vid",
        modality="video_i2v",
        payload=_payload("seeddance-2.0"),
        depends_on=["kf"],
    )
    return Workflow(workflow_id=workflow_id, steps=[keyframe, video])


def test_round_trip_beat_with_evald_takes_preserves_eval_scores_and_aggregate_score():
    """Beat with 2 Takes × 2 Steps × eval scores survives a dict round trip.

    The round trip is ``Beat.to_dict()`` → ``Beat.from_dict()``; no JSON
    encoding step is exercised here. Checks that every layer is preserved:
    Beat identity and metadata, Take identity/status/aggregate_score, step
    identity/status, each receipt's ``eval_scores`` dict, and the
    ``eval_cost_usd`` entry in each receipt's provenance.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_roundtrip",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    # One single-judge panel per take. The two judges exist only so the takes
    # get distinguishable scores (0.80 vs 0.50); with one judge per panel the
    # panel_score equals that judge's score, so there is no cross-judge
    # aggregation happening here.
    judge_t0 = FakeEvalNode(judge_id="j_t0", score=0.80, cost_usd=0.002)
    judge_t1 = FakeEvalNode(judge_id="j_t1", score=0.50, cost_usd=0.002)

    panel_t0 = PanelOfJudges(panel_id="panel_take0", judges=[judge_t0])
    panel_t1 = PanelOfJudges(panel_id="panel_take1", judges=[judge_t1])

    beat = Beat(beat_id="EP001_SH02", beat_metadata={"scene_id": "ep001_sc02"})

    take_0 = beat.new_take(workflow=_two_step_wf("wf_take0"))
    _, post_t0, _ = attach_eval_hooks(take_0.workflow, panel_t0)
    take_0.execute(context=ctx, post_step=post_t0)

    take_1 = beat.new_take(workflow=_two_step_wf("wf_take1"))
    _, post_t1, _ = attach_eval_hooks(take_1.workflow, panel_t1)
    take_1.execute(context=ctx, post_step=post_t1)

    # Both takes succeeded; eval_scores populated on each step's receipt.
    assert take_0.status == "succeeded"
    assert take_1.status == "succeeded"
    for step in take_0.workflow.steps:
        assert "panel_take0" in step.receipt.eval_scores
        assert step.receipt.eval_scores["panel_take0"]["panel_score"] == pytest.approx(0.80)
    for step in take_1.workflow.steps:
        assert "panel_take1" in step.receipt.eval_scores
        assert step.receipt.eval_scores["panel_take1"]["panel_score"] == pytest.approx(0.50)

    # Compute aggregate scores on each take.
    take_0.compute_aggregate_score()
    take_1.compute_aggregate_score()
    # Each take has 2 steps, each with one panel scoring panel_score = 0.80 / 0.50
    # → mean across both steps = same value (single panel per receipt).
    assert take_0.aggregate_score == pytest.approx(0.80)
    assert take_1.aggregate_score == pytest.approx(0.50)

    # Round-trip the entire Beat through its dict form.
    beat_dict = beat.to_dict()
    beat_rt = Beat.from_dict(beat_dict)

    # Beat-level identity preserved.
    assert beat_rt.beat_id == beat.beat_id
    assert beat_rt.beat_metadata == beat.beat_metadata
    assert len(beat_rt.takes) == 2

    for original_take, rt_take in zip(beat.takes, beat_rt.takes):
        # Take-level identity preserved.
        assert rt_take.take_id == original_take.take_id
        assert rt_take.take_index == original_take.take_index
        assert rt_take.status == original_take.status
        # Aggregate score round-trips on the Take.
        assert rt_take.aggregate_score == original_take.aggregate_score
        # Workflow + every step + every receipt + every eval_scores dict
        # round-trips intact.
        assert len(rt_take.workflow.steps) == len(original_take.workflow.steps)
        for original_step, rt_step in zip(
            original_take.workflow.steps, rt_take.workflow.steps
        ):
            assert rt_step.step_id == original_step.step_id
            assert rt_step.status == original_step.status
            assert rt_step.receipt is not None
            # eval_scores dicts compare equal across the round trip.
            assert rt_step.receipt.eval_scores == original_step.receipt.eval_scores
            # provenance also round-trips (eval_cost_usd lives there).
            assert (
                rt_step.receipt.provenance["eval_cost_usd"]
                == original_step.receipt.provenance["eval_cost_usd"]
            )


def test_round_trip_beat_select_primary_score_works_after_round_trip():
    """select_primary("score") picks the same take before and after a round trip.

    After ``Beat.to_dict()`` → ``Beat.from_dict()``, the rehydrated Beat's
    per-take ``aggregate_score`` values must match the originals, and
    ``select_primary(strategy="score")`` must return the same
    ``primary_take_id`` as on the pre-round-trip Beat.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_roundtrip",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    judge_winner = FakeEvalNode(judge_id="winner_j", score=0.92, cost_usd=0.001)
    judge_loser = FakeEvalNode(judge_id="loser_j", score=0.31, cost_usd=0.001)
    panel_winner = PanelOfJudges(panel_id="panel_winner", judges=[judge_winner])
    panel_loser = PanelOfJudges(panel_id="panel_loser", judges=[judge_loser])

    beat = Beat(beat_id="EP001_SH99")

    # The loser is created first, so picking the winner cannot be explained
    # by take order alone; the score strategy has to compare scores.
    take_loser = beat.new_take(workflow=_two_step_wf("wf_loser"))
    _, post_loser, _ = attach_eval_hooks(take_loser.workflow, panel_loser)
    take_loser.execute(context=ctx, post_step=post_loser)

    take_winner = beat.new_take(workflow=_two_step_wf("wf_winner"))
    _, post_winner, _ = attach_eval_hooks(take_winner.workflow, panel_winner)
    take_winner.execute(context=ctx, post_step=post_winner)

    # Pre-round-trip select_primary on score.
    primary_id = beat.select_primary(strategy="score")
    assert primary_id == take_winner.take_id

    # Round-trip the Beat.
    beat_rt = Beat.from_dict(beat.to_dict())

    # Check persistence of aggregate_score directly, before calling
    # select_primary on the rehydrated Beat. The final selection alone would
    # also pass if scores were recomputed from the round-tripped receipts, so
    # it cannot prove the stored values survived.
    assert len(beat_rt.takes) == len(beat.takes)
    assert [t.aggregate_score for t in beat_rt.takes] == [
        t.aggregate_score for t in beat.takes
    ]

    primary_id_rt = beat_rt.select_primary(strategy="score")
    assert primary_id_rt == take_winner.take_id


def test_round_trip_beat_with_no_aggregate_score_remains_none():
    """A take whose aggregate_score was never computed round-trips as None.

    With no eval hooks attached, every step receipt has empty ``eval_scores``
    and the take's ``aggregate_score`` stays None. After
    ``Beat.to_dict()`` → ``Beat.from_dict()`` both must still hold: the value
    is neither populated nor lost.
    """
    sr = StubStepRunner()
    ctx = DispatchContext(
        caller_id="phase6_roundtrip",
        step_runner=sr,
        receipts_log_path="DISABLED",
    )

    beat = Beat(beat_id="EP001_NOEVAL")
    take = beat.new_take(workflow=_two_step_wf("wf_noeval"))
    take.execute(context=ctx)  # NO hooks attached → no eval scores

    assert take.status == "succeeded"
    assert take.aggregate_score is None
    for step in take.workflow.steps:
        assert step.receipt.eval_scores == {}

    # Round-trip preserves None. Pin the take count and identity first so
    # indexing [0] cannot silently inspect the wrong take.
    beat_rt = Beat.from_dict(beat.to_dict())
    assert len(beat_rt.takes) == 1
    rt_take = beat_rt.takes[0]
    assert rt_take.take_id == take.take_id
    assert rt_take.aggregate_score is None
    for step in rt_take.workflow.steps:
        assert step.receipt.eval_scores == {}
