"""CP-8 Phase 6 — Take/Beat audio E2E tests.

Validates the Take/Beat layer over the audio_t2a + lipsync_post chain.
Exercises the retry-as-new-take pattern (audio fails → Beat.new_take with
fresh workflow → succeeds → primary picks the new take). Validates
take_metadata round-trip and Take.status compression on audio workflows.

ALL HTTP transport mocked via the `_transport=` payload key.
"""

import json
import sys
import pathlib
from types import SimpleNamespace

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[3]))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402
ensure_pipeline_importable()


from recoil.pipeline.core.dispatch import _reset_bootstrap_for_tests  # noqa: E402
from recoil.pipeline.core.dispatch_context import DispatchContext  # noqa: E402
from recoil.pipeline.core.registry import _reset_for_tests  # noqa: E402
from recoil.pipeline.core.take import Beat, Take  # noqa: E402
from recoil.pipeline.core.workflow import Workflow, WorkflowStep  # noqa: E402


# ── Stub StepRunner (audio + lipsync runners do NOT consult StepRunner;
#    dispatch() requires it for bootstrap + _dispatch_path stamping) ─────


class _StubStepRunner:
    """Minimal StepRunner stand-in.

    The audio and lipsync runners never call into the StepRunner; dispatch()
    still requires one for bootstrap and for ``_dispatch_path`` stamping.
    Each call made on the stub is appended to ``calls`` as
    ``(kind, copy_of_kwargs)``.
    """

    def __init__(self):
        self._dispatch_path = "unknown"
        self.calls: list[tuple[str, dict]] = []

    def execute_keyframe(self, **kw):  # pragma: no cover - not exercised
        """Record a keyframe call and return a canned successful still result."""
        self.calls.append(("keyframe", dict(kw)))
        result = SimpleNamespace(
            shot_id=kw.get("shot_id", "X"),
            success=True,
            final_state="keyframe_generated",
            output_path="/tmp/x.png",
            cost_usd=0.04,
            error=None,
            take_index=0,
            gate_verdict=None,
            model="nbp",
            pipeline="still",
        )
        return result

    def execute_video(self, **kw):  # pragma: no cover - not exercised
        """Record a video call and return a canned successful i2v result."""
        self.calls.append(("video", dict(kw)))
        result = SimpleNamespace(
            shot_id=kw.get("shot_id", "X"),
            success=True,
            final_state="video_complete",
            output_path="/tmp/v.mp4",
            cost_usd=0.20,
            error=None,
            take_index=0,
            gate_verdict=None,
            model="seeddance-2.0",
            pipeline="i2v",
        )
        return result


# ── Fake HTTP transports (audio + lipsync) ──────────────────────────────


class _FakeResponse:
    """In-memory HTTP response usable as a context manager.

    ``read()`` returns the fixed body. ``status`` defaults to 200. A falsy
    ``headers`` argument (None or empty) is replaced by a fresh
    ``{"request-id": "req_test_123"}`` dict.
    """

    def __init__(self, body: bytes, status: int = 200, headers=None):
        self.status = status
        self._body = body
        if not headers:
            headers = {"request-id": "req_test_123"}
        self.headers = headers

    def read(self) -> bytes:
        """Return the canned response body."""
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False


def _make_audio_transport(audio_bytes: bytes = b"FAKE_MP3_BYTES"):
    """Return a fake audio HTTP transport that always answers with ``audio_bytes``.

    The returned callable matches the transport signature
    ``(url, *, headers, body, timeout)`` and ignores every argument.
    """
    def _respond(url, *, headers, body, timeout):
        return _FakeResponse(audio_bytes)

    return _respond


def _make_audio_failing_transport():
    """Return an audio transport that raises ElevenLabs ``AuthError`` on every call.

    Used to force audio_t2a dispatch to fail. The provider module is imported
    inside this function, at call time.
    """
    from recoil.execution.providers import elevenlabs

    def _reject(url, *, headers, body, timeout):
        raise elevenlabs.AuthError("401 invalid key")

    return _reject


def _make_lipsync_5step_transport(
    *,
    job_id: str = "job_xyz",
    output_bytes: bytes = b"FAKE_LIPSYNCED_MP4",
    duration_s: float = 4.2,
):
    """Return a fake sync.so transport that walks the 5-step happy path.

    The URL/method pairs are routed in this order:
      * any URL containing ``/v2/upload`` → JSON with an uploaded file URL;
      * POST to a URL ending in ``/v2/generate`` → JSON with ``job_id``;
      * GET to ``/v2/generate/<id>`` → ``PROCESSING`` on the first poll, then
        ``COMPLETED`` with an ``outputUrl`` and ``duration_s``;
      * any URL containing ``output_v.mp4`` → raw ``output_bytes``.
    Any other URL raises ``AssertionError``.
    """
    polls_seen = 0

    def _json_response(obj):
        return _FakeResponse(json.dumps(obj).encode("utf-8"))

    def _transport(url, *, headers, body, method="GET", timeout=60.0):
        nonlocal polls_seen
        if "/v2/upload" in url:
            return _json_response({"url": "https://cdn.sync.so/file_xyz"})
        if url.endswith("/v2/generate") and method == "POST":
            return _json_response({"id": job_id})
        if "/v2/generate/" in url and method == "GET":
            polls_seen += 1
            if polls_seen > 1:
                return _json_response({
                    "status": "COMPLETED",
                    "outputUrl": "https://cdn.sync.so/output_v.mp4",
                    "duration_s": duration_s,
                })
            # First poll simulates a job that is still running.
            return _json_response({"status": "PROCESSING"})
        if "output_v.mp4" in url:
            return _FakeResponse(output_bytes)
        raise AssertionError(f"unexpected URL in fake transport: {url}")

    return _transport


def _make_lipsync_failing_transport():
    """Return a sync.so transport that raises ``LipSyncError`` on every call.

    The very first request (the upload step) therefore fails, so the
    lipsync adapter fails. The provider module is imported at call time.
    """
    from recoil.execution.providers import sync_so

    def _refuse(url, *, headers, body, method="GET", timeout=60.0):
        raise sync_so.LipSyncError("transport refused")

    return _refuse


# ── Fixtures + payload helpers ──────────────────────────────────────────


def _make_video_input(tmp_path, name="carrier.mp4"):
    """Write a placeholder carrier video named ``name`` under ``tmp_path``.

    Returns the path of the written file.
    """
    carrier = tmp_path.joinpath(name)
    carrier.write_bytes(b"FAKE CARRIER VIDEO BYTES")
    return carrier


def _audio_payload(tmp_path, *, shot_id="EP001_SH02", **overrides):
    """Build an audio_t2a dispatch payload for ``shot_id``.

    The defaults use a succeeding fake transport and write output under
    ``tmp_path / "audio_out_<shot_id>"``. Any key in ``overrides`` replaces
    or extends the defaults, e.g. ``_transport=...`` to inject a failure.
    """
    defaults = {
        "shot_id": shot_id,
        "text": "Hello world",
        "voice_id": "voice_xyz",
        "model": "eleven_multilingual_v2",
        "output_dir": str(tmp_path / f"audio_out_{shot_id}"),
        "_transport": _make_audio_transport(),
    }
    return {**defaults, **overrides}


def _lipsync_payload(tmp_path, *, shot_id="EP001_SH02", **overrides):
    """Build a lipsync_post dispatch payload for ``shot_id``.

    A placeholder carrier video and a placeholder audio file are written under
    ``tmp_path``. By default ``audio_path`` points at that placeholder audio,
    so take-level tests need no upstream workflow step. Workflow builders
    that rely on the pre_step hook remove ``audio_path`` afterwards; see
    ``_audio_then_lipsync_workflow``. Keys in ``overrides`` replace or extend
    the defaults.
    """
    carrier = _make_video_input(tmp_path, name=f"carrier_{shot_id}.mp4")
    placeholder_audio = tmp_path / f"existing_audio_{shot_id}.mp3"
    placeholder_audio.write_bytes(b"FAKE_EXISTING_AUDIO")
    defaults = {
        "shot_id": shot_id,
        "video_path": str(carrier),
        "audio_path": str(placeholder_audio),
        "model": "lipsync-2.0",
        "output_dir": str(tmp_path / f"lipsync_out_{shot_id}"),
        "_transport": _make_lipsync_5step_transport(),
        # Zero poll interval keeps the fake PROCESSING → COMPLETED cycle instant.
        "poll_interval_s": 0.0,
    }
    return {**defaults, **overrides}


def _audio_to_lipsync_hook(step, workflow):
    """pre_step hook: fill a lipsync step's ``audio_path`` from the ``tts`` receipt.

    This hook was copied from Phase 5 (test_audio_lipsync_workflow.py). It does
    nothing unless all of the following hold:
      * the step's modality is ``lipsync_post``;
      * the payload has no truthy ``audio_path``;
      * the workflow has a ``tts`` step with a receipt;
      * that receipt reports success with a truthy ``output_path``.
    It mutates ``step.payload`` in place and returns None.
    """
    if step.modality == "lipsync_post" and not step.payload.get("audio_path"):
        upstream = workflow.get_step("tts")
        receipt = None if upstream is None else upstream.receipt
        if receipt is not None and receipt.run_result.success:
            resolved = receipt.run_result.output_path
            if resolved:
                step.payload["audio_path"] = resolved


def _audio_only_workflow(workflow_id="wf_audio", shot_id="EP001_SH02",
                          tmp_path=None, **payload_overrides):
    """Build a one-step Workflow holding only the ``tts`` audio_t2a step.

    ``payload_overrides`` are forwarded to ``_audio_payload``, for example
    ``_transport=...``. ``tmp_path`` must be supplied by the caller; it is
    used to build the output directory.
    """
    tts_payload = _audio_payload(tmp_path, shot_id=shot_id, **payload_overrides)
    tts_step = WorkflowStep(
        step_id="tts",
        modality="audio_t2a",
        payload=tts_payload,
    )
    return Workflow(workflow_id=workflow_id, steps=[tts_step])


def _audio_then_lipsync_workflow(workflow_id="wf_combo", shot_id="EP001_SH02",
                                  tmp_path=None, *,
                                  audio_overrides=None,
                                  lipsync_overrides=None):
    """Build a two-step Workflow: ``tts`` (audio_t2a) → ``lipsync`` (lipsync_post).

    The lipsync step declares ``depends_on=["tts"]``. By default its payload is
    built WITHOUT ``audio_path``, so the pre_step hook resolves it from the
    upstream tts step's receipt. If the caller explicitly passes
    ``audio_path`` in ``lipsync_overrides``, that value is kept. The hook
    leaves a pre-set ``audio_path`` alone too.

    Args:
        workflow_id: id for the returned Workflow.
        shot_id: shot id used in both payloads and their file names.
        tmp_path: directory used for fixture files and output dirs.
        audio_overrides: extra/replacement keys for the audio payload.
        lipsync_overrides: extra/replacement keys for the lipsync payload.
    """
    audio_overrides = audio_overrides or {}
    lipsync_overrides = lipsync_overrides or {}
    ls_payload = _lipsync_payload(tmp_path, shot_id=shot_id, **lipsync_overrides)
    if "audio_path" not in lipsync_overrides:
        # Drop the placeholder audio_path that _lipsync_payload seeds, so the
        # hook wires in the real tts output. An explicit override is honored
        # rather than silently discarded.
        ls_payload.pop("audio_path", None)
    return Workflow(
        workflow_id=workflow_id,
        steps=[
            WorkflowStep(
                step_id="tts", modality="audio_t2a",
                payload=_audio_payload(tmp_path, shot_id=shot_id,
                                       **audio_overrides),
            ),
            WorkflowStep(
                step_id="lipsync", modality="lipsync_post",
                payload=ls_payload,
                depends_on=["tts"],
            ),
        ],
    )


# ── Tests ───────────────────────────────────────────────────────────────


def test_audio_take_executes_via_take_execute(tmp_path):
    """A lone audio_t2a step in a Take inside a Beat runs to 'succeeded'.

    Runs ``take.execute(context=ctx)`` with no pre_step hook. Checks the
    take-level status and the step's receipt.
    """
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_VO01")
    solo_workflow = _audio_only_workflow(workflow_id="wf_audio_solo",
                                         tmp_path=tmp_path)
    take = beat.new_take(workflow=solo_workflow)

    take.execute(context=ctx)

    assert take.status == "succeeded"
    tts_step = take.workflow.steps[0]
    assert tts_step.status == "succeeded"
    assert tts_step.receipt.modality == "audio_t2a"
    assert tts_step.receipt.run_result.success is True


def test_combined_audio_lipsync_take_both_succeed(tmp_path):
    """tts → lipsync (depends_on=tts): both steps succeed → take 'succeeded'."""
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")
    combo = _audio_then_lipsync_workflow(tmp_path=tmp_path)
    take = beat.new_take(workflow=combo)

    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)

    assert take.status == "succeeded"
    steps = take.workflow.steps
    tts_step, lipsync_step = steps[0], steps[1]
    assert tts_step.status == "succeeded"
    assert lipsync_step.status == "succeeded"
    assert tts_step.receipt.modality == "audio_t2a"
    assert lipsync_step.receipt.modality == "lipsync_post"


def test_combined_audio_succeeds_lipsync_fails_status_partial(tmp_path):
    """Audio succeeds but the lipsync transport raises → take 'partial'.

    One succeeded step plus one failed step is the canonical partial case.
    """
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")
    broken_lipsync = {"_transport": _make_lipsync_failing_transport()}
    take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(
            tmp_path=tmp_path, lipsync_overrides=broken_lipsync,
        ),
    )

    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)

    steps = take.workflow.steps
    assert steps[0].status == "succeeded"
    assert steps[1].status == "failed"
    assert take.status == "partial"


def test_combined_audio_fails_take_status_failed(tmp_path):
    """Audio fails → lipsync is skipped (CP-6 cascade) → take 'failed'.

    No step succeeded, so status compression yields 'failed', not 'partial'.
    """
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")
    broken_audio = {"_transport": _make_audio_failing_transport()}
    take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(
            tmp_path=tmp_path, audio_overrides=broken_audio,
        ),
    )

    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)

    steps = take.workflow.steps
    assert steps[0].status == "failed"
    assert steps[1].status == "skipped"
    assert take.status == "failed"


def test_retry_as_new_take_pattern(tmp_path):
    """Retry-as-new-take: failed take 0, successful take 1, primary is take 1.

    This is the canonical CP-8 retry idiom. A re-attempt is a NEW Take with a
    fresh workflow; the failed Take is never re-executed.
    """
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_VO01")

    # Attempt 0 uses a transport that always raises.
    first_workflow = _audio_only_workflow(
        workflow_id="wf_attempt0", tmp_path=tmp_path,
        _transport=_make_audio_failing_transport(),
    )
    take_0 = beat.new_take(workflow=first_workflow)
    take_0.execute(context=ctx)
    assert take_0.status == "failed"

    # Reset registry + dispatch bootstrap between takes, matching the CP-7
    # e2e idiom in test_take_e2e_scenarios.py.
    _reset_for_tests()
    _reset_bootstrap_for_tests()

    # Attempt 1 is a brand-new Take with the default (succeeding) transport.
    retry_workflow = _audio_only_workflow(
        workflow_id="wf_attempt1", tmp_path=tmp_path,
        shot_id="EP001_VO01_retry",
    )
    take_1 = beat.new_take(workflow=retry_workflow)
    take_1.execute(context=ctx)
    assert take_1.status == "succeeded"

    assert beat.select_primary() == take_1.take_id
    assert beat.primary_take is take_1
    # The earlier failure stays on beat.takes; it is not discarded.
    assert len(beat.takes) == 2
    assert [beat.takes[0].status, beat.takes[1].status] == ["failed", "succeeded"]


def test_take_metadata_round_trip_audio_stem_and_carrier_video(tmp_path):
    """take_metadata (audio_stem, carrier_video, ...) survives to_dict/from_dict."""
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    carrier = "/projects/tartarus/output/EP001_SH02_v0.mp4"
    expected_meta = {
        "audio_stem": "EP001_VO01_v3",
        "carrier_video": carrier,
        "model": "eleven_multilingual_v2",
        "operator": "claude",
    }
    beat = Beat(beat_id="EP001_SH02")
    take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(tmp_path=tmp_path),
        take_metadata=dict(expected_meta),
    )
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)

    serialized = take.to_dict()
    serialized_meta = serialized["take_metadata"]
    assert serialized_meta["audio_stem"] == "EP001_VO01_v3"
    assert serialized_meta["carrier_video"] == carrier

    revived = Take.from_dict(serialized)
    for key, value in expected_meta.items():
        assert revived.take_metadata[key] == value
    # A fully executed take keeps its compressed status across the round trip.
    assert revived.status == "succeeded"


def test_three_takes_first_two_fail_third_succeeds(tmp_path):
    """Three takes: two failing attempts, then a success; primary is index 2."""
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_VO01")

    # Takes 0 and 1: the audio transport raises. Reset after each attempt.
    for failing_wf_id in ("wf0", "wf1"):
        doomed = beat.new_take(
            workflow=_audio_only_workflow(
                workflow_id=failing_wf_id, tmp_path=tmp_path,
                _transport=_make_audio_failing_transport(),
            ),
        )
        doomed.execute(context=ctx)
        assert doomed.status == "failed"
        _reset_for_tests()
        _reset_bootstrap_for_tests()

    # Take 2: default transport succeeds.
    winner = beat.new_take(
        workflow=_audio_only_workflow(
            workflow_id="wf2", tmp_path=tmp_path, shot_id="EP001_VO01_t2",
        ),
    )
    winner.execute(context=ctx)
    assert winner.status == "succeeded"

    assert beat.select_primary() == winner.take_id
    assert beat.primary_take.take_index == 2


def test_round_trip_preserves_receipts_on_succeeded_take(tmp_path):
    """A succeeded audio+lipsync take keeps both receipts through serialization."""
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")
    take = beat.new_take(workflow=_audio_then_lipsync_workflow(tmp_path=tmp_path))
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take.status == "succeeded"

    revived = Take.from_dict(take.to_dict())
    assert revived.status == "succeeded"
    revived_steps = revived.workflow.steps
    tts_receipt = revived_steps[0].receipt
    lipsync_receipt = revived_steps[1].receipt
    assert tts_receipt is not None
    assert lipsync_receipt is not None
    assert tts_receipt.modality == "audio_t2a"
    assert lipsync_receipt.modality == "lipsync_post"
    assert tts_receipt.run_result.success is True
    assert lipsync_receipt.run_result.success is True


def test_round_trip_preserves_receipts_on_partial_take(tmp_path):
    """A partial take (audio ok, lipsync failed) keeps both receipts,
    including the failure receipt, through serialization."""
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")
    broken_lipsync = {"_transport": _make_lipsync_failing_transport()}
    take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(
            tmp_path=tmp_path, lipsync_overrides=broken_lipsync,
        ),
    )
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take.status == "partial"

    revived = Take.from_dict(take.to_dict())
    assert revived.status == "partial"
    revived_steps = revived.workflow.steps
    tts_receipt = revived_steps[0].receipt
    lipsync_receipt = revived_steps[1].receipt
    assert tts_receipt is not None
    assert lipsync_receipt is not None
    assert tts_receipt.run_result.success is True
    assert lipsync_receipt.run_result.success is False
    # The failed step's error message must survive the round trip.
    assert lipsync_receipt.run_result.error is not None


def test_select_primary_first_success_skips_partial(tmp_path):
    """select_primary ('first_success') ignores partial takes.

    With takes [partial, succeeded], the succeeded take becomes primary.
    """
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path="DISABLED",
    )
    beat = Beat(beat_id="EP001_SH02")

    # Take 0: audio succeeds, lipsync transport raises → partial.
    partial_take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(
            workflow_id="wf_partial", tmp_path=tmp_path,
            lipsync_overrides={"_transport": _make_lipsync_failing_transport()},
        ),
    )
    partial_take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert partial_take.status == "partial"

    _reset_for_tests()
    _reset_bootstrap_for_tests()

    # Take 1: everything succeeds.
    full_take = beat.new_take(
        workflow=_audio_then_lipsync_workflow(
            workflow_id="wf_ok", tmp_path=tmp_path, shot_id="EP001_SH02_t1",
        ),
    )
    full_take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert full_take.status == "succeeded"

    # The partial take must NOT be picked.
    assert beat.select_primary() == full_take.take_id
    assert beat.primary_take is full_take


def test_jsonl_audit_log_emits_one_line_per_step(tmp_path):
    """An audio+lipsync take makes 2 dispatches, which write 2 JSONL receipt
    lines in step order."""
    log_path = tmp_path / "receipts.jsonl"
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=_StubStepRunner(),
        receipts_log_path=str(log_path),
    )
    beat = Beat(beat_id="EP001_SH02")
    take = beat.new_take(workflow=_audio_then_lipsync_workflow(tmp_path=tmp_path))
    take.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take.status == "succeeded"

    raw_lines = log_path.read_text().splitlines()
    records = [json.loads(raw) for raw in raw_lines if raw.strip()]
    assert len(records) == 2
    assert [rec["modality"] for rec in records] == ["audio_t2a", "lipsync_post"]


def test_take_status_compression_two_step(tmp_path):
    """Take.status compression on a 2-step audio+lipsync workflow:

      A. both steps succeeded                          → 'succeeded'
      B. audio succeeded, lipsync failed               → 'partial'
      C. audio failed, lipsync skipped (no successes)  → 'failed'

    Registry and dispatch bootstrap are reset before each case.
    """
    runner = _StubStepRunner()

    # Case A: both steps succeed.
    _reset_for_tests()
    _reset_bootstrap_for_tests()
    ctx = DispatchContext(
        caller_id="cp8_p6_test",
        step_runner=runner,
        receipts_log_path="DISABLED",
    )
    take_a = Take(
        take_id="t_a", take_index=0,
        workflow=_audio_then_lipsync_workflow(
            workflow_id="wf_A", tmp_path=tmp_path, shot_id="A",
        ),
    )
    take_a.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    assert take_a.status == "succeeded"

    # Case B: lipsync transport raises → one success + one failure.
    _reset_for_tests()
    _reset_bootstrap_for_tests()
    take_b = Take(
        take_id="t_b", take_index=0,
        workflow=_audio_then_lipsync_workflow(
            workflow_id="wf_B", tmp_path=tmp_path, shot_id="B",
            lipsync_overrides={"_transport": _make_lipsync_failing_transport()},
        ),
    )
    take_b.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    steps_b = take_b.workflow.steps
    assert steps_b[0].status == "succeeded"
    assert steps_b[1].status == "failed"
    assert take_b.status == "partial"

    # Case C: audio transport raises → lipsync skipped, nothing succeeded.
    _reset_for_tests()
    _reset_bootstrap_for_tests()
    take_c = Take(
        take_id="t_c", take_index=0,
        workflow=_audio_then_lipsync_workflow(
            workflow_id="wf_C", tmp_path=tmp_path, shot_id="C",
            audio_overrides={"_transport": _make_audio_failing_transport()},
        ),
    )
    take_c.execute(context=ctx, pre_step=_audio_to_lipsync_hook)
    steps_c = take_c.workflow.steps
    assert steps_c[0].status == "failed"
    assert steps_c[1].status == "skipped"
    assert take_c.status == "failed"
