"""CP-6 Phase 3 — Workflow.run executor + hook tests.

These tests monkey-patch `_run_step_via_dispatch` so they exercise the
executor's control flow in isolation from dispatch(). Phase 4's tests
exercise the real dispatch integration.
"""

import sys
import pathlib
from types import SimpleNamespace

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.pipeline.core import workflow as wf_mod  # noqa: E402
from recoil.pipeline.core.workflow import (  # noqa: E402
    Workflow,
    WorkflowStep,
    WorkflowValidationError,
)


def _step(step_id="s1", modality="image_t2i", depends_on=None) -> WorkflowStep:
    """Build a WorkflowStep carrying a fixed stub payload for executor tests.

    ``depends_on`` falls back to a fresh empty list when omitted or falsy.
    """
    upstream = depends_on if depends_on else []
    stub_payload = {"shot_id": "X", "prompt": "p", "model": "nbp"}
    return WorkflowStep(
        step_id=step_id,
        modality=modality,
        payload=stub_payload,
        depends_on=upstream,
    )


def _fake_receipt(success=True, error=None, modality="image_t2i", shot_id="X"):
    """Minimal duck-typed receipt for executor tests. Has .run_result.success/error."""
    rr = SimpleNamespace(
        success=success,
        error=error,
        modality=modality,
        output_path="/tmp/x" if success else None,
    )
    return SimpleNamespace(
        run_result=rr,
        modality=modality,
        receipt_id=f"rcpt_fake_{shot_id}_{modality}",
        provenance={},
    )


@pytest.fixture
def patch_dispatch(monkeypatch):
    """Provide a helper that swaps in a test-supplied dispatch function.

    The returned callable takes a replacement ``fn`` and installs it as
    ``wf_mod._run_step_via_dispatch`` through pytest's ``monkeypatch``
    fixture. Tests may call it more than once to change the stub mid-test.
    """
    return lambda fn: monkeypatch.setattr(wf_mod, "_run_step_via_dispatch", fn)


# ── DAG validation ─────────────────────────────────────────────────────


def test_run_rejects_empty_workflow(patch_dispatch):
    """run() on a workflow with no steps raises WorkflowValidationError."""

    def _always_ok(step, wf, ctx):
        return _fake_receipt()

    patch_dispatch(_always_ok)
    empty = Workflow(workflow_id="wf1", steps=[])
    with pytest.raises(WorkflowValidationError):
        empty.run(context=object())


def test_run_rejects_missing_dependency(patch_dispatch):
    """run() raises when a step depends on a step_id absent from the workflow."""

    def _always_ok(step, wf, ctx):
        return _fake_receipt()

    patch_dispatch(_always_ok)
    dangling = _step(step_id="s1", depends_on=["nonexistent"])
    workflow = Workflow(workflow_id="wf1", steps=[dangling])
    with pytest.raises(WorkflowValidationError):
        workflow.run(context=object())


def test_run_rejects_self_dependency(patch_dispatch):
    """run() raises when a step lists its own step_id in depends_on."""

    def _always_ok(step, wf, ctx):
        return _fake_receipt()

    patch_dispatch(_always_ok)
    self_referential = _step(step_id="s1", depends_on=["s1"])
    workflow = Workflow(workflow_id="wf1", steps=[self_referential])
    with pytest.raises(WorkflowValidationError):
        workflow.run(context=object())


def test_run_rejects_forward_reference(patch_dispatch):
    """run() raises when a step depends on a step declared later in the list."""

    def _always_ok(step, wf, ctx):
        return _fake_receipt()

    patch_dispatch(_always_ok)
    # s1 names s2 as a dependency, but s2 only appears after it.
    out_of_order = [
        _step(step_id="s1", depends_on=["s2"]),
        _step(step_id="s2"),
    ]
    workflow = Workflow(workflow_id="wf1", steps=out_of_order)
    with pytest.raises(WorkflowValidationError):
        workflow.run(context=object())


def test_run_accepts_topologically_ordered_workflow(patch_dispatch):
    """A linear s1 -> s2 -> s3 chain runs and every step ends up succeeded."""

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    patch_dispatch(_ok_for_step)
    chain = [
        _step(step_id="s1"),
        _step(step_id="s2", depends_on=["s1"]),
        _step(step_id="s3", depends_on=["s2"]),
    ]
    workflow = Workflow(workflow_id="wf1", steps=chain)
    workflow.run(context=object())

    statuses = [s.status for s in workflow.steps]
    assert statuses == ["succeeded"] * len(statuses)


# ── Status transitions ────────────────────────────────────────────────


def test_run_marks_succeeded_on_success(patch_dispatch):
    """A successful receipt yields status 'succeeded', a receipt, and no error."""

    def _succeeds(step, wf, ctx):
        return _fake_receipt(success=True, modality=step.modality)

    patch_dispatch(_succeeds)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    workflow.run(context=object())

    only = workflow.steps[0]
    assert only.status == "succeeded"
    assert only.receipt is not None
    assert only.error is None


def test_run_marks_failed_on_runresult_failure(patch_dispatch):
    """A receipt with run_result.success=False marks the step failed.

    The receipt is still attached and the run_result error text appears in
    the step's error.
    """

    def _api_down(step, wf, ctx):
        return _fake_receipt(
            success=False,
            error="upstream API down",
            modality=step.modality,
        )

    patch_dispatch(_api_down)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    workflow.run(context=object())

    only = workflow.steps[0]
    message = only.error or ""
    assert only.status == "failed"
    assert only.receipt is not None
    assert "upstream API down" in message


def test_run_marks_failed_on_dispatch_exception(patch_dispatch):
    """An exception raised by dispatch marks the step failed without a receipt.

    run() itself is expected not to raise here; the exception's type name and
    message must both show up in the step's error.
    """

    def _boom(step, wf, ctx):
        raise RuntimeError("dispatch crashed")

    patch_dispatch(_boom)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    workflow.run(context=object())

    only = workflow.steps[0]
    message = only.error or ""
    assert only.status == "failed"
    assert only.receipt is None  # exception path: no receipt
    for fragment in ("RuntimeError", "dispatch crashed"):
        assert fragment in message


def test_run_skips_dependent_when_upstream_fails(patch_dispatch):
    """When s1 fails, its direct and transitive dependents are skipped."""

    def _first_fails(step, wf, ctx):
        is_s1 = step.step_id == "s1"
        return _fake_receipt(
            success=not is_s1,
            error="boom" if is_s1 else None,
            modality=step.modality,
        )

    patch_dispatch(_first_fails)
    chain = [
        _step(step_id="s1"),
        _step(step_id="s2", depends_on=["s1"]),
        _step(step_id="s3", depends_on=["s2"]),
    ]
    workflow = Workflow(workflow_id="wf1", steps=chain)
    workflow.run(context=object())

    head = workflow.steps[0]
    middle = workflow.steps[1]
    tail = workflow.steps[2]
    assert head.status == "failed"
    assert middle.status == "skipped"
    assert tail.status == "skipped"
    # Both skipped steps must explain why they never ran.
    for skipped in (middle, tail):
        assert "upstream dependency failed" in (skipped.error or "")


def test_run_continues_to_independent_branch_when_one_branch_fails(patch_dispatch):
    """A failure in branch 'a' skips a2 but leaves the b -> b2 branch running."""

    def _branch_a_fails(step, wf, ctx):
        in_failing_branch = step.step_id == "a"
        return _fake_receipt(
            success=not in_failing_branch,
            error="branch a boom" if in_failing_branch else None,
            modality=step.modality,
        )

    patch_dispatch(_branch_a_fails)
    branches = [
        _step(step_id="a"),
        _step(step_id="b"),  # independent of a
        _step(step_id="b2", depends_on=["b"]),
        _step(step_id="a2", depends_on=["a"]),
    ]
    workflow = Workflow(workflow_id="wf1", steps=branches)
    workflow.run(context=object())

    # Expected status by position: a, b, b2 (depends on b), a2 (depends on a).
    expected = ("failed", "succeeded", "succeeded", "skipped")
    for index, status in enumerate(expected):
        assert workflow.steps[index].status == status


def test_run_records_started_and_finished_us(patch_dispatch):
    """A step that ran has both timestamps set, with finished_us >= started_us."""

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    patch_dispatch(_ok_for_step)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    workflow.run(context=object())

    only = workflow.steps[0]
    started, finished = only.started_us, only.finished_us
    assert started is not None
    assert finished is not None
    assert finished >= started


# ── Hook contract ─────────────────────────────────────────────────────


def test_pre_and_post_step_hooks_fire_in_order(patch_dispatch):
    """pre_step and post_step wrap each step, in step order."""

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    events = []

    def _record_pre(s, w):
        events.append(f"pre:{s.step_id}")

    def _record_post(s, w):
        events.append(f"post:{s.step_id}")

    patch_dispatch(_ok_for_step)
    two_steps = [_step(step_id="s1"), _step(step_id="s2")]
    workflow = Workflow(workflow_id="wf1", steps=two_steps)
    workflow.run(context=object(), pre_step=_record_pre, post_step=_record_post)

    assert events == ["pre:s1", "post:s1", "pre:s2", "post:s2"]


def test_on_failure_hook_fires_only_on_failure(patch_dispatch):
    """on_failure fires for the failing s1 only; independent s2 still succeeds."""

    def _s1_fails(step, wf, ctx):
        is_s1 = step.step_id == "s1"
        return _fake_receipt(
            success=not is_s1,
            error="nope" if is_s1 else None,
            modality=step.modality,
        )

    failures = []

    def _record_failure(s, w):
        failures.append(f"fail:{s.step_id}")

    patch_dispatch(_s1_fails)
    two_steps = [_step(step_id="s1"), _step(step_id="s2")]
    workflow = Workflow(workflow_id="wf1", steps=two_steps)
    workflow.run(context=object(), on_failure=_record_failure)

    assert failures == ["fail:s1"]
    # s2 declared no dependency on s1, so it ran normally.
    assert workflow.steps[1].status == "succeeded"


def test_skipped_steps_do_not_fire_hooks(patch_dispatch):
    """No hook of any kind fires for a step skipped due to a failed upstream."""

    def _s1_fails(step, wf, ctx):
        is_s1 = step.step_id == "s1"
        return _fake_receipt(
            success=not is_s1,
            error="nope" if is_s1 else None,
            modality=step.modality,
        )

    seen = {"pre": [], "post": [], "fail": []}

    def _on_pre(s, w):
        seen["pre"].append(s.step_id)

    def _on_post(s, w):
        seen["post"].append(s.step_id)

    def _on_fail(s, w):
        seen["fail"].append(s.step_id)

    patch_dispatch(_s1_fails)
    dependent_pair = [
        _step(step_id="s1"),
        _step(step_id="s2", depends_on=["s1"]),
    ]
    workflow = Workflow(workflow_id="wf1", steps=dependent_pair)
    workflow.run(
        context=object(),
        pre_step=_on_pre,
        post_step=_on_post,
        on_failure=_on_fail,
    )

    # s1 ran (and failed); s2 was skipped, so only s1 may appear in any log.
    assert seen["pre"] == ["s1"]
    assert seen["post"] == ["s1"]
    assert seen["fail"] == ["s1"]


def test_post_step_fires_on_dispatch_exception_too(patch_dispatch):
    """After a dispatch exception, post_step then on_failure both fire.

    Each hook must already observe the step's status as 'failed'.
    """

    def _boom(step, wf, ctx):
        raise RuntimeError("crashed")

    calls = []

    def _record_post(s, w):
        calls.append(("post", s.step_id, s.status))

    def _record_fail(s, w):
        calls.append(("fail", s.step_id, s.status))

    patch_dispatch(_boom)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    workflow.run(context=object(), post_step=_record_post, on_failure=_record_fail)

    assert calls == [("post", "s1", "failed"), ("fail", "s1", "failed")]


def test_hook_exception_propagates(patch_dispatch):
    """An exception raised inside pre_step escapes run() unchanged."""

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    def _bad_hook(s, w):
        raise ValueError("hook crashed")

    patch_dispatch(_ok_for_step)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    with pytest.raises(ValueError, match="hook crashed"):
        workflow.run(context=object(), pre_step=_bad_hook)


def test_pre_step_exception_finalizes_step_to_failed(patch_dispatch):
    """A raising pre_step leaves the step in terminal 'failed' state.

    The step must be finalized (status, error, finished_us) before the hook
    exception propagates, so callers never observe a leaked 'running' status.
    """

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    def _bad_hook(s, w):
        raise ValueError("hook crashed")

    patch_dispatch(_ok_for_step)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    with pytest.raises(ValueError, match="hook crashed"):
        workflow.run(context=object(), pre_step=_bad_hook)

    only = workflow.steps[0]
    message = only.error or ""
    # Terminal state, not "running".
    assert only.status == "failed"
    assert "ValueError" in message
    assert "hook crashed" in message
    assert only.finished_us is not None
    assert only.finished_us >= only.started_us
    # Dispatch never ran, so no receipt was attached.
    # NOTE(review): the intended contract is that post_step / on_failure do
    # not fire on pre_step exceptions, but this test does not assert it.
    assert only.receipt is None


def test_run_returns_self(patch_dispatch):
    """run() returns the very Workflow instance it was called on."""

    def _ok_for_step(step, wf, ctx):
        return _fake_receipt(modality=step.modality)

    patch_dispatch(_ok_for_step)
    workflow = Workflow(workflow_id="wf1", steps=[_step()])
    returned = workflow.run(context=object())
    assert returned is workflow


def test_rerun_clears_stale_execution_state(patch_dispatch):
    """Calling Workflow.run() a second time must reset error/receipt/timing
    so a previously-failed step that now succeeds doesn't carry stale error
    alongside status='succeeded'. Per CP-7 follow-up #9 documented semantics."""

    # First run: dispatch fails
    def _first_fails(step, wf, ctx):
        raise RuntimeError("first run boom")

    patch_dispatch(_first_fails)
    wf = Workflow(workflow_id="wf1", steps=[_step()])
    # A dispatch exception is recorded on the step rather than raised by
    # run() (see test_run_marks_failed_on_dispatch_exception), so the call is
    # deliberately NOT wrapped in a blanket try/except: an unexpected raise
    # here is a regression and must fail the test instead of being swallowed.
    wf.run(context=object())
    first_step = wf.steps[0]
    assert first_step.status == "failed"
    assert "first run boom" in (first_step.error or "")
    first_finished = first_step.finished_us
    assert first_finished is not None

    # Second run: dispatch succeeds — stale error/timing must be cleared
    patch_dispatch(lambda step, wf, ctx: _fake_receipt(modality=step.modality))
    wf.run(context=object())
    rerun_step = wf.steps[0]
    assert rerun_step.status == "succeeded"
    assert rerun_step.error is None  # cleared, not stale "first run boom"
    assert rerun_step.receipt is not None
    assert rerun_step.finished_us != first_finished  # fresh timestamp


def test_rerun_with_hook_exception_resets_unreached_steps(patch_dispatch):
    """Per CP-6 post-build review (Opus R1, §S2): calling Workflow.run() a
    second time where a hook raises mid-loop must NOT leave un-reached steps
    in stale "succeeded" state from the first run.

    Run 1: both steps succeed → A.status=succeeded, B.status=succeeded.
    Run 2: pre_step raises on A → loop aborts before B is reached.
            B must show "pending" with no receipt, NOT stale "succeeded"
            from run 1, otherwise CP-7 Take consumers see phantom-succeeded
            steps when iterating wf.steps after catching a hook exception.
    """
    patch_dispatch(lambda step, wf, ctx: _fake_receipt(modality=step.modality))
    wf = Workflow(
        workflow_id="wf_rerun_hook",
        steps=[_step(step_id="a"), _step(step_id="b", depends_on=["a"])],
    )

    # Run 1: clean success.
    wf.run(context=object())
    assert wf.steps[0].status == "succeeded"
    assert wf.steps[1].status == "succeeded"
    first_b_receipt = wf.steps[1].receipt
    assert first_b_receipt is not None

    # Run 2: pre_step raises on step A → loop aborts before reaching B.
    def _boom(step, wf_):
        if step.step_id == "a":
            raise RuntimeError("hook boom")

    # pytest.raises (rather than try/except/pass) makes the test fail if the
    # hook exception does NOT propagate; otherwise the "aborted mid-loop"
    # precondition for the assertions below would go unverified.
    with pytest.raises(RuntimeError, match="hook boom"):
        wf.run(context=object(), pre_step=_boom)

    # A is finalized to "failed" with the hook exception.
    assert wf.steps[0].status == "failed"
    assert "hook boom" in (wf.steps[0].error or "")
    # B is "pending" (reset upfront before the loop), NOT stale "succeeded".
    assert wf.steps[1].status == "pending", (
        f"Expected B reset to pending after hook-aborted re-run, got "
        f"{wf.steps[1].status!r} — stale state would mislead Take consumers."
    )
    assert wf.steps[1].receipt is None  # cleared from run 1
    assert wf.steps[1].started_us is None
    assert wf.steps[1].finished_us is None
    # Sanity: the prior receipt is gone, not just shadowed.
    assert wf.steps[1].receipt is not first_b_receipt
