"""CP-8 Phase 6 — Scene/Beat/Take audio E2E scenarios.

Validates the full Scene → Beat → Take → Workflow → audio_t2a + lipsync_post
chain in realistic compositions:
    - Scenes with mixed visual + audio beats
    - Failed-take/sibling-take isolation
    - global_provenance reaching receipts
    - Cost compute via model_profiles for both modalities
    - Isolated dispatches across two projects (each receipt keeps its own project)
    - shot_id-based file_stem keying so takes sharing an output_dir don't clobber outputs

All HTTP transport is mocked via the `_transport` payload key.
"""

import json
import sys
import pathlib
from types import SimpleNamespace

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[3]))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402
ensure_pipeline_importable()


from recoil.pipeline.core.dispatch import _reset_bootstrap_for_tests  # noqa: E402
from recoil.pipeline.core.dispatch_context import DispatchContext  # noqa: E402
from recoil.pipeline.core.registry import _reset_for_tests  # noqa: E402
from recoil.pipeline.core.take import Beat, Scene, Take  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402


# ── Stub StepRunner for image_t2i ───────────────────────────────────────


class _StubStepRunner:
    def __init__(self):
        self.calls: list[tuple[str, dict]] = []
        self._dispatch_path = "unknown"

    def execute_keyframe(self, **kw):
        self.calls.append(("keyframe", dict(kw)))
        return SimpleNamespace(
            shot_id=kw.get("shot_id", "X"), success=True,
            final_state="keyframe_generated", output_path="/tmp/x.png",
            cost_usd=0.04, error=None, take_index=0, gate_verdict=None,
            model="nbp", pipeline="still",
        )

    def execute_video(self, **kw):  # pragma: no cover - not used in this file
        self.calls.append(("video", dict(kw)))
        return SimpleNamespace(
            shot_id=kw.get("shot_id", "X"), success=True,
            final_state="video_complete", output_path="/tmp/v.mp4",
            cost_usd=0.20, error=None, take_index=0, gate_verdict=None,
            model="seeddance-2.0", pipeline="i2v",
        )


# ── Fake HTTP transports ────────────────────────────────────────────────


class _FakeResponse:
    def __init__(self, body: bytes, status: int = 200, headers=None):
        self._body = body
        self.status = status
        self.headers = headers or {"request-id": "req_test_123"}

    def read(self) -> bytes:
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


def _make_audio_transport(audio_bytes: bytes = b"FAKE_MP3_BYTES"):
    """Return a fake TTS transport that answers every request with a
    ``_FakeResponse`` whose body is *audio_bytes*."""

    def _always_ok(url, *, headers, body, timeout):
        # Request details are deliberately ignored: one canned reply for all.
        return _FakeResponse(audio_bytes)

    return _always_ok


def _make_audio_failing_transport():
    """Return a fake TTS transport whose every call raises the ElevenLabs
    provider's ``AuthError`` with a 401-style message."""
    from recoil.execution.providers import elevenlabs as provider

    def _reject(url, *, headers, body, timeout):
        raise provider.AuthError("401 invalid key")

    return _reject


def _make_lipsync_5step_transport(
    *, job_id: str = "job_xyz",
    output_bytes: bytes = b"FAKE_LIPSYNCED_MP4",
    duration_s: float = 4.2,
):
    """Build a stateful fake sync.so transport: upload → generate → poll → download.

    Routing, checked in this order:
      * any URL containing ``/v2/upload``         → JSON ``{"url": ...}``
      * POST to a URL ending in ``/v2/generate``  → JSON ``{"id": job_id}``
      * GET on a URL containing ``/v2/generate/`` → ``PROCESSING`` on the
        first poll, then ``COMPLETED`` with the output URL and *duration_s*
      * a URL containing ``output_v.mp4``         → raw *output_bytes*
    Any other URL raises ``AssertionError``.
    """
    polls_seen = 0

    def _as_json(payload):
        return _FakeResponse(json.dumps(payload).encode("utf-8"))

    def _transport(url, *, headers, body, method="GET", timeout=60.0):
        nonlocal polls_seen
        if "/v2/upload" in url:
            return _as_json({"url": "https://cdn.sync.so/file_xyz"})
        if url.endswith("/v2/generate") and method == "POST":
            return _as_json({"id": job_id})
        if "/v2/generate/" in url and method == "GET":
            polls_seen += 1
            if polls_seen > 1:
                return _as_json({
                    "status": "COMPLETED",
                    "outputUrl": "https://cdn.sync.so/output_v.mp4",
                    "duration_s": duration_s,
                })
            return _as_json({"status": "PROCESSING"})
        if "output_v.mp4" in url:
            return _FakeResponse(output_bytes)
        raise AssertionError(f"unexpected URL in fake transport: {url}")

    return _transport


# ── Fixtures + payload helpers ──────────────────────────────────────────


def _make_video_input(tmp_path, name="carrier.mp4"):
    v = tmp_path / name
    v.write_bytes(b"FAKE CARRIER VIDEO BYTES")
    return v


def _audio_payload(tmp_path, *, shot_id="EP001_SH02", **overrides):
    """Build an audio_t2a payload for *shot_id* backed by a fresh fake
    transport. The output dir is per-shot under *tmp_path*; *overrides*
    win over every default."""
    defaults = dict(
        shot_id=shot_id,
        text="Hello world from Phase 6 e2e test",
        voice_id="voice_xyz",
        model="eleven_multilingual_v2",
        output_dir=str(tmp_path / f"audio_out_{shot_id}"),
        _transport=_make_audio_transport(),
    )
    return {**defaults, **overrides}


def _lipsync_payload(tmp_path, *, shot_id="EP001_SH02", with_audio=True,
                     **overrides):
    """Build a lipsync_post payload for *shot_id* with its own fake transport.

    Always writes a placeholder carrier video under *tmp_path*. When
    *with_audio* is true it also writes a placeholder audio file and sets
    ``audio_path``. Otherwise ``audio_path`` is left absent, so a pre_step
    hook can fill it in. *overrides* are applied last.
    """
    carrier = _make_video_input(tmp_path, name=f"carrier_{shot_id}.mp4")
    payload = dict(
        shot_id=shot_id,
        video_path=str(carrier),
        model="lipsync-2.0",
        output_dir=str(tmp_path / f"lipsync_out_{shot_id}"),
        _transport=_make_lipsync_5step_transport(),
        poll_interval_s=0.0,
    )
    if with_audio:
        audio_file = tmp_path / f"existing_audio_{shot_id}.mp3"
        audio_file.write_bytes(b"FAKE_EXISTING_AUDIO")
        payload["audio_path"] = str(audio_file)
    return {**payload, **overrides}


def _image_payload(*, shot_id="EP001_SH02"):
    # Build A Phase 2 (2026-05-09): aspect_ratio is now required; no silent default.
    return {"shot_id": shot_id, "prompt": "p", "model": "nbp", "aspect_ratio": "9_16"}


def _audio_to_lipsync_hook(step, workflow):
    """Phase-5 canonical pre_step hook — re-defined here per spec direction."""
    if step.modality != "lipsync_post":
        return
    if "audio_path" in step.payload and step.payload["audio_path"]:
        return
    tts = workflow.get_step("tts")
    if tts is None or tts.receipt is None:
        return
    if not tts.receipt.run_result.success:
        return
    audio_path = tts.receipt.run_result.output_path
    if audio_path:
        step.payload["audio_path"] = audio_path


# ── Tests ───────────────────────────────────────────────────────────────


def test_scene_with_mixed_visual_and_audio_beats(tmp_path):
    """A Scene mixing one image_t2i, one audio_t2a and one lipsync_post Beat
    (one single-step Take each) runs every Take to success when its beats
    are iterated in order."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    scene = Scene(scene_id="ep001_sc02")

    # (beat_id, workflow_id, step_id, modality, payload) for each beat, in order.
    beat_specs = [
        ("EP001_SH01", "wf_visual", "kf", "image_t2i",
         _image_payload(shot_id="EP001_SH01")),
        ("EP001_VO01", "wf_audio", "tts", "audio_t2a",
         _audio_payload(tmp_path, shot_id="EP001_VO01")),
        ("EP001_LS01", "wf_lipsync", "lipsync", "lipsync_post",
         _lipsync_payload(tmp_path, shot_id="EP001_LS01")),
    ]
    created_takes = []
    for beat_id, workflow_id, step_id, modality, payload in beat_specs:
        beat = Beat(beat_id=beat_id)
        created_takes.append(beat.new_take(workflow=Workflow(
            workflow_id=workflow_id,
            steps=[WorkflowStep(step_id=step_id, modality=modality,
                                payload=payload)],
        )))
        scene.add_beat(beat)

    # Drive execution through the scene, not through the local handles.
    for beat in scene.beats:
        for take in beat.takes:
            take.execute(context=ctx)

    assert [take.status for take in created_takes] == ["succeeded"] * 3


def test_full_episode_smoke_three_modalities(tmp_path):
    """Full-episode smoke test. A Scene holds three Beats (visual, audio,
    lipsync), each with its own single-step Take. Every Take succeeds, and
    each receipt records the modality its step ran under."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        project="tartarus",
        episode=1,
        receipts_log_path="DISABLED",
    )
    scene = Scene(scene_id="ep001_sc02",
                  scene_metadata={"location": "control_room"})

    def _add_single_step_beat(beat_id, workflow_id, step_id, modality, payload):
        # One Beat -> one Take -> one WorkflowStep, appended to the scene.
        beat = Beat(beat_id=beat_id)
        beat.new_take(workflow=Workflow(workflow_id=workflow_id, steps=[
            WorkflowStep(step_id=step_id, modality=modality, payload=payload),
        ]))
        scene.add_beat(beat)

    _add_single_step_beat("EP001_SH01", "wf_v", "kf", "image_t2i",
                          _image_payload(shot_id="EP001_SH01"))
    _add_single_step_beat("EP001_VO01", "wf_a", "tts", "audio_t2a",
                          _audio_payload(tmp_path, shot_id="EP001_VO01"))
    _add_single_step_beat("EP001_LS01", "wf_l", "lipsync", "lipsync_post",
                          _lipsync_payload(tmp_path, shot_id="EP001_LS01"))

    observed = []
    for beat in scene.beats:
        for take in beat.takes:
            take.execute(context=ctx)
            observed.append(take.status)

    assert observed == ["succeeded"] * 3
    receipt_modalities = [
        beat.takes[0].workflow.steps[0].receipt.modality for beat in scene.beats
    ]
    assert receipt_modalities == ["image_t2i", "audio_t2a", "lipsync_post"]


def test_scene_serialization_round_trips_audio_takes(tmp_path):
    """A Scene holding an executed tts -> lipsync Take survives
    ``to_dict()`` / ``from_dict()``. The status, both step receipts, their
    modalities and the tts output_path all come back intact."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    scene = Scene(scene_id="ep001_sc02")
    beat = Beat(beat_id="EP001_VO01")
    tts_step = WorkflowStep(step_id="tts", modality="audio_t2a",
                            payload=_audio_payload(tmp_path))
    lipsync_step = WorkflowStep(
        step_id="lipsync",
        modality="lipsync_post",
        payload=_lipsync_payload(tmp_path, with_audio=False),
        depends_on=["tts"],
    )
    take = beat.new_take(workflow=Workflow(
        workflow_id="wf_audio_rt", steps=[tts_step, lipsync_step],
    ))
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    scene.add_beat(beat)

    snapshot = scene.to_dict()
    revived = Scene.from_dict(snapshot)

    assert revived.scene_id == "ep001_sc02"
    assert len(revived.beats) == 1
    revived_take = revived.beats[0].takes[0]
    assert revived_take.status == "succeeded"
    revived_steps = revived_take.workflow.steps
    assert revived_steps[0].receipt is not None
    assert revived_steps[1].receipt is not None
    assert revived_steps[0].receipt.modality == "audio_t2a"
    assert revived_steps[1].receipt.modality == "lipsync_post"
    # The tts output_path string must come back verbatim.
    original_output = take.workflow.steps[0].receipt.run_result.output_path
    assert revived_steps[0].receipt.run_result.output_path == original_output


def test_failed_take_does_not_corrupt_sibling_takes(tmp_path):
    """A Beat with two Takes, where the first fails and the second succeeds.
    Afterwards ``beat.takes`` holds both, in order, and each keeps its own
    status and receipt."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_VO01")

    def _tts_workflow(workflow_id, payload):
        # Single audio_t2a step named "tts".
        return Workflow(workflow_id=workflow_id, steps=[
            WorkflowStep(step_id="tts", modality="audio_t2a", payload=payload),
        ])

    # Take 0: the transport raises AuthError, so the take must fail.
    failing = beat.new_take(workflow=_tts_workflow(
        "wf_fail",
        _audio_payload(tmp_path, _transport=_make_audio_failing_transport()),
    ))
    failing.execute(context=ctx)
    assert failing.status == "failed"

    # Reset the registry and dispatch bootstrap before the second take.
    _reset_for_tests()
    _reset_bootstrap_for_tests()

    # Take 1: healthy transport and a distinct shot_id (separate output dir).
    healthy = beat.new_take(workflow=_tts_workflow(
        "wf_ok", _audio_payload(tmp_path, shot_id="EP001_VO01_t1"),
    ))
    healthy.execute(context=ctx)
    assert healthy.status == "succeeded"

    # Both takes are retained, in creation order, with independent state.
    assert len(beat.takes) == 2
    first, second = beat.takes
    assert first is failing
    assert second is healthy
    assert first.status == "failed"
    assert second.status == "succeeded"
    # Each take's receipt still reflects its own outcome.
    assert first.workflow.steps[0].receipt is not None
    assert first.workflow.steps[0].receipt.run_result.success is False
    assert second.workflow.steps[0].receipt.run_result.success is True


def test_global_provenance_reaches_each_step_receipt(tmp_path):
    """Keys passed via ``Workflow(global_provenance=...)`` are stamped onto
    every step receipt's provenance after the Take runs, alongside the
    workflow's own id."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    stamped = {
        "shot_id": "EP001_VO01",
        "scene_id": "ep001_sc02",
        "scene_intent": "VO over a closed door",
    }
    wf = Workflow(
        workflow_id="wf_gp",
        steps=[
            WorkflowStep(step_id="tts", modality="audio_t2a",
                         payload=_audio_payload(tmp_path)),
            WorkflowStep(step_id="lipsync", modality="lipsync_post",
                         payload=_lipsync_payload(tmp_path, with_audio=False),
                         depends_on=["tts"]),
        ],
        # Pass a copy so `stamped` stays a pristine expectation even if the
        # workflow mutates the mapping it receives.
        global_provenance=dict(stamped),
    )
    take = Take(take_id="t0", take_index=0, workflow=wf)
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take.status == "succeeded"

    for step in take.workflow.steps:
        provenance = step.receipt.provenance
        for key, expected in stamped.items():
            assert provenance[key] == expected
        # The workflow's own id is stamped next to the global keys.
        assert provenance["workflow_id"] == "wf_gp"


def test_pre_step_hook_pattern_documented(tmp_path):
    """The Phase-5 ``_audio_to_lipsync_hook`` is the canonical resolver
    pattern. Inside ``Take.execute``, a lipsync step that starts without an
    ``audio_path`` gets one from the upstream tts step's receipt, and that
    path points at a real file."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    lipsync_payload = _lipsync_payload(tmp_path, with_audio=False)
    # Pre-condition: the hook, not the payload builder, must supply audio_path.
    assert "audio_path" not in lipsync_payload

    workflow = Workflow(workflow_id="wf_hook_take", steps=[
        WorkflowStep(step_id="tts", modality="audio_t2a",
                     payload=_audio_payload(tmp_path)),
        WorkflowStep(step_id="lipsync", modality="lipsync_post",
                     payload=lipsync_payload, depends_on=["tts"]),
    ])
    take = Take(take_id="t0", take_index=0, workflow=workflow)
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)

    tts_output = take.workflow.get_step("tts").receipt.run_result.output_path
    resolved = take.workflow.get_step("lipsync").payload["audio_path"]
    assert resolved == tts_output
    assert pathlib.Path(resolved).is_file()
    assert take.status == "succeeded"


def test_cost_compute_via_model_profiles_for_both_modalities(tmp_path):
    """``receipt.run_result.metadata['cost_usd']`` must be > 0 for both the
    audio_t2a and lipsync_post steps. model_profiles supplies
    cost_per_1k_chars and cost_per_second respectively."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    workflow = Workflow(workflow_id="wf_cost", steps=[
        WorkflowStep(step_id="tts", modality="audio_t2a",
                     payload=_audio_payload(tmp_path)),
        WorkflowStep(step_id="lipsync", modality="lipsync_post",
                     payload=_lipsync_payload(tmp_path, with_audio=False),
                     depends_on=["tts"]),
    ])
    take = Take(take_id="t0", take_index=0, workflow=workflow)
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take.status == "succeeded"

    def _cost_of(step):
        # A missing or None cost becomes 0.0, so the assertions below flag it.
        return float(step.receipt.run_result.metadata.get("cost_usd") or 0.0)

    steps = take.workflow.steps
    audio_cost = _cost_of(steps[0])
    lipsync_cost = _cost_of(steps[1])

    assert audio_cost > 0.0, (
        "audio_t2a cost_usd must be > 0 — model_profiles eleven_multilingual_v2 "
        f"has cost_per_1k_chars=0.30. Got {audio_cost!r}"
    )
    assert lipsync_cost > 0.0, (
        "lipsync_post cost_usd must be > 0 — model_profiles lipsync-2.0 has "
        f"cost_per_second supplied. Got {lipsync_cost!r}"
    )


def test_isolated_dispatches_across_two_episodes(tmp_path):
    """Two DispatchContexts (project 'A' and project 'B', both episode=1)
    each dispatch one audio_t2a Take. Each receipt carries its own project
    and shot_id, with no cross-contamination."""
    runs = (
        # (project, take_id, workflow_id, shot_id)
        ("A", "t_a", "wf_A", "A_SH01"),
        ("B", "t_b", "wf_B", "B_SH01"),
    )
    takes_by_project = {}
    for index, (project, take_id, workflow_id, shot_id) in enumerate(runs):
        if index:
            # Reset registry and dispatch bootstrap between the two runs only.
            _reset_for_tests()
            _reset_bootstrap_for_tests()
        ctx = DispatchContext(
            caller_id="cp8_p6_e2e",
            step_runner=_StubStepRunner(),
            project=project,
            episode=1,
            receipts_log_path="DISABLED",
        )
        take = Take(take_id=take_id, take_index=0, workflow=Workflow(
            workflow_id=workflow_id, steps=[
                WorkflowStep(step_id="tts", modality="audio_t2a",
                             payload=_audio_payload(tmp_path, shot_id=shot_id)),
            ],
        ))
        take.execute(context=ctx)
        assert take.status == "succeeded"
        takes_by_project[project] = take

    for project, _take_id, _workflow_id, shot_id in runs:
        receipt = takes_by_project[project].workflow.steps[0].receipt
        assert receipt.project == project
        assert receipt.shot_id == shot_id


def test_file_stem_includes_shot_id_no_clobber(tmp_path):
    """Two takes with shot_id 'A' and 'B' share one output_dir yet write
    different files. The default file_stem is f'{shot_id}_audio', so
    distinct shot_ids give distinct output paths."""
    ctx = DispatchContext(
        caller_id="cp8_p6_e2e",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    shared_dir = str(tmp_path / "shared_audio_out")

    def _run_tts(take_id, workflow_id, shot_id):
        # Execute one audio_t2a take into the shared dir; return its output path.
        take = Take(take_id=take_id, take_index=0, workflow=Workflow(
            workflow_id=workflow_id, steps=[
                WorkflowStep(step_id="tts", modality="audio_t2a",
                             payload=_audio_payload(
                                 tmp_path, shot_id=shot_id,
                                 output_dir=shared_dir,
                             )),
            ],
        ))
        take.execute(context=ctx)
        assert take.status == "succeeded"
        return take.workflow.steps[0].receipt.run_result.output_path

    out_a = _run_tts("t_A", "wf_A", "A")
    out_b = _run_tts("t_B", "wf_B", "B")

    assert out_a != out_b, (
        "shot_id-based file_stem must distinguish concurrent takes — "
        f"got the same output path for shot_ids 'A' and 'B': {out_a!r}"
    )
    # Both files still exist, and each stem carries its shot_id.
    for raw_path, marker in ((out_a, "A"), (out_b, "B")):
        written = pathlib.Path(raw_path)
        assert written.is_file()
        assert marker in written.stem


def test_empty_beat_select_primary_returns_none(tmp_path):
    """A Beat with no takes follows the CP-7 contract (see
    test_beat_primary.test_first_success_empty_beat_returns_none).
    ``select_primary()`` returns None instead of raising, and both primary
    accessors stay None."""
    empty = Beat(beat_id="EP001_VO99")
    assert empty.takes == []
    assert empty.select_primary() is None
    for attribute in ("primary_take_id", "primary_take"):
        assert getattr(empty, attribute) is None
