"""CP-7 Phase 2 — tests for Beat construction/validation, new_take, add_take, and primary-take resolution."""

import sys
import pathlib

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402
ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core.take import Beat, Take  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402


def _wf(workflow_id="wf1") -> Workflow:
    """Build a minimal one-step workflow to attach to a Take.

    The single step is an ``image_t2i`` step with id ``"kf"`` and a fixed
    payload.  Tests pass a distinct ``workflow_id`` when they need the
    workflows of different takes to be distinguishable.
    """
    keyframe_step = WorkflowStep(
        step_id="kf",
        modality="image_t2i",
        payload={"shot_id": "X", "prompt": "p", "model": "nbp"},
    )
    return Workflow(workflow_id=workflow_id, steps=[keyframe_step])


def test_beat_minimal():
    """A Beat built with only an id starts with no takes and no primary take."""
    b = Beat(beat_id="EP001_SH02")
    # Separate asserts so a failure pinpoints the exact field that is wrong,
    # instead of a single opaque `a and b and c` being False.
    assert b.beat_id == "EP001_SH02"
    assert b.takes == []
    assert b.primary_take_id is None


def test_beat_rejects_invalid_construction():
    """Beat refuses an empty id (ValueError) and malformed ``takes`` (TypeError)."""
    # An empty beat_id is a bad value, not a bad type.
    with pytest.raises(ValueError):
        Beat(beat_id="")

    # `takes` must be a list, and every element must be a Take instance.
    malformed_takes = ("not a list", [{"take_id": "t0"}])
    for takes in malformed_takes:
        with pytest.raises(TypeError):
            Beat(beat_id="b1", takes=takes)  # type: ignore


def test_beat_rejects_duplicate_take_id():
    """Two takes sharing a take_id cannot coexist in one Beat."""
    # Same id, different index and workflow: only the id collides.
    colliding = [
        Take(take_id="dup", take_index=idx, workflow=_wf(f"wf{idx + 1}"))
        for idx in range(2)
    ]
    with pytest.raises(ValueError):
        Beat(beat_id="b1", takes=colliding)


def test_beat_rejects_primary_take_id_unknown():
    """primary_take_id must name one of the takes passed at construction."""
    only_take = Take(take_id="t0", take_index=0, workflow=_wf())
    bogus_primary = "not_a_real_take_id"
    with pytest.raises(ValueError):
        Beat(beat_id="b1", takes=[only_take], primary_take_id=bogus_primary)


def test_beat_new_take_assigns_index_and_id():
    """new_take appends takes with sequential indices and ``<beat_id>_take_<n>`` ids.

    Also checks that each Take returned by new_take is the same object that
    was stored on the beat.  The original test bound the returns to
    ``t0``/``t1``/``t2`` and never used them.
    """
    b = Beat(beat_id="EP001_SH02")
    created = [b.new_take(workflow=_wf(f"wf{n}")) for n in (1, 2, 3)]

    # Identity, not equality: callers hold on to the returned Take and expect
    # it to be the live object inside `b.takes`.
    assert len(b.takes) == len(created)
    assert all(stored is returned for stored, returned in zip(b.takes, created))

    assert [t.take_index for t in b.takes] == [0, 1, 2]
    assert [t.take_id for t in b.takes] == [
        "EP001_SH02_take_0", "EP001_SH02_take_1", "EP001_SH02_take_2",
    ]


def test_beat_new_take_passes_metadata():
    """take_metadata given to new_take is exposed on the created Take."""
    metadata = {"model": "kling-o3"}
    beat = Beat(beat_id="b1")
    created = beat.new_take(workflow=_wf(), take_metadata=metadata)
    assert created.take_metadata == {"model": "kling-o3"}


def test_beat_add_take_external_and_dup_check():
    """add_take accepts an externally built Take and rejects duplicates and non-Takes."""
    beat = Beat(beat_id="b1")

    # A caller-constructed Take with a custom id is stored as-is.
    external = Take(take_id="custom_id", take_index=0, workflow=_wf())
    beat.add_take(external)
    assert beat.takes == [external]

    # Reusing the same take_id is a value error, even with a different index/workflow.
    clash = Take(take_id="custom_id", take_index=1, workflow=_wf("wf2"))
    with pytest.raises(ValueError):
        beat.add_take(clash)

    # A plain dict is not a Take.
    with pytest.raises(TypeError):
        beat.add_take({"take_id": "t0"})  # type: ignore


def test_beat_primary_take_resolution():
    """primary_take resolves primary_take_id against the beat's current takes."""
    beat = Beat(beat_id="b1")
    take = beat.new_take(workflow=_wf())

    # No primary chosen yet.
    assert beat.primary_take is None

    # Once the id is set, the property returns the matching Take object.
    beat.primary_take_id = take.take_id
    assert beat.primary_take is take

    # Drop all takes so primary_take_id is stale; resolution must yield None
    # rather than raising.
    beat.takes = []
    assert beat.primary_take is None
