from __future__ import annotations

from types import SimpleNamespace

from recoil.pipeline._lib import dispatch_payload as dp
from recoil.pipeline._lib.prose_validator import Severity, verify_authored_prose
from recoil.pipeline._lib.shot_primitive import ShotPrimitive


def _primitive(*, target_s: float, segment_durations: list[float]) -> ShotPrimitive:
    """Build a fixture ``ShotPrimitive`` with back-to-back timing segments.

    Each entry of ``segment_durations`` becomes one contiguous timing segment
    (``EP002_SH01``, ``EP002_SH02``, ...) starting where the previous one
    ended. ``target_s`` is passed through as the editorial target duration and
    is deliberately independent of the segment sum, so tests can exercise
    mismatches between the two.
    """
    timing: list[dict] = []
    start = 0.0
    for index, length in enumerate(segment_durations, start=1):
        end = start + length
        timing.append(
            {
                "shot_id": f"EP002_SH{index:02d}",
                "start_s": start,
                "end_s": end,
                "duration_s": length,
                "intent": f"Beat {index}.",
            }
        )
        # Next segment begins exactly where this one stops (no gaps/overlap).
        start = end
    return ShotPrimitive(
        shot_id="EP002_PASS_009",
        scene_index=2,
        shot_type="OTS",
        target_editorial_duration_s=target_s,
        intent="Jade and Wren cross the med bay under pressure.",
        camera_side="B",
        screen_direction="left-to-right",
        timing_segments=timing,
    )


def _strategy() -> SimpleNamespace:
    return SimpleNamespace(name="directed_prose", required_inputs=[])


def _block_checks(results) -> set[str]:
    """Return the set of check names whose result severity is ``Severity.BLOCK``.

    Results at any other severity are ignored. Duplicate check names collapse
    into a single entry.
    """
    blocking: set[str] = set()
    for result in results:
        if result.severity == Severity.BLOCK:
            blocking.add(result.check)
    return blocking


def test_normalized_timecodes_use_target_duration_when_segment_sum_differs() -> None:
    """Authored windows are remapped onto the 15s target when segments sum to 18s.

    The prose is written against 9s windows (0:00-0:27). After normalization,
    the three beats should land on 5s slices of the 15s target, and the
    verifier should report no BLOCK-level checks.
    """
    shot = _primitive(target_s=15.0, segment_durations=[6.0, 6.0, 6.0])
    prose = (
        "[0:00-0:09] Jade grips the rail as the camera pushes with her breath.\n"
        "[0:09-0:18] Wren locks his jaw as the lens tracks his hand.\n"
        "[0:18-0:27] The pair brace as the frame pulls wider around them."
    )

    result_text = dp._normalize_authored_timecodes(prose, shot)
    blocking = _block_checks(verify_authored_prose(result_text, shot, _strategy()))

    assert result_text.startswith("[0:00-0:05]")
    for window in ("[0:05-0:10]", "[0:10-0:15]"):
        assert window in result_text
    assert not blocking


def test_normalized_timecodes_preserve_equal_duration_basis() -> None:
    """Authored windows are remapped onto 3s slices when segments sum to the target.

    Two 3s segments exactly fill the 6s target. The prose's 12s windows should
    normalize to 0:00-0:03 and 0:03-0:06, with no BLOCK-level verifier checks.
    """
    shot = _primitive(target_s=6.0, segment_durations=[3.0, 3.0])
    prose = (
        "[0:00-0:12] Jade breathes tight as the camera pushes with her.\n"
        "[0:12-0:24] Wren braces as the lens settles on his hand."
    )

    result_text = dp._normalize_authored_timecodes(prose, shot)
    blocking = _block_checks(verify_authored_prose(result_text, shot, _strategy()))

    assert result_text.startswith("[0:00-0:03]")
    assert "[0:03-0:06]" in result_text
    assert not blocking


def test_normalized_timecodes_leave_count_mismatch_for_verifier_block() -> None:
    """A beat-count mismatch is left untouched so the verifier can block it.

    The shot has two timing segments but the prose has a single timecoded
    line. Normalization must return the prose unchanged, and the verifier must
    flag ``prose_verify_segment_count`` as BLOCK.
    """
    shot = _primitive(target_s=6.0, segment_durations=[3.0, 3.0])
    prose = "[0:00-0:06] Jade and Wren brace as the camera pushes in."

    result_text = dp._normalize_authored_timecodes(prose, shot)
    blocking = _block_checks(verify_authored_prose(result_text, shot, _strategy()))

    assert result_text == prose
    assert "prose_verify_segment_count" in blocking
