"""Workflow object model — typed DAG of dispatch() calls.

CP-6 deliverable. Wraps the CP-5 dispatch surface in a declarative data
model. Today's executor is linear; tomorrow's Console v2 visual node-graph
adds branching execution semantics with zero data migration (per JT
locked Open Q #2 in loraverse SYNTHESIS).

Schema is locked at CP-6. CP-7's Take wraps a step's GenerationReceipt;
CP-8 adds audio steps; CP-9's PanelOfJudges hooks into pre/post-step.
Adding fields post-CP-6 is safe; renaming/removing is a contract break.

Public surface (frozen at CP-6):
    WorkflowStep   — one node in a workflow DAG
    Workflow       — a DAG of WorkflowSteps
    StepStatus     — Literal["pending","running","succeeded","failed","skipped"]
    Workflow.run   — linear executor (Phase 3)
    Workflow.to_dict / from_dict — JSON round-trip

JSON round-trip: Workflow.from_dict(w.to_dict()) == w (when no execution
state attached) and step-for-step equal (when executed — receipts round-trip
via GenerationReceipt.to_dict / from_dict).
"""

from __future__ import annotations

import time
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Literal, Optional

from recoil.pipeline.core.receipts import GenerationReceipt, utc_now_iso8601


# Closed set of execution states a WorkflowStep can be in. This is a static
# typing alias only; nothing in this module enforces it at runtime.
StepStatus = Literal["pending", "running", "succeeded", "failed", "skipped"]


@dataclass
class WorkflowStep:
    """One node in a Workflow DAG. Wraps one dispatch() call.

    Fields:
        step_id: Stable id, unique within the parent Workflow. Free-form
            string. Recommended format: `{role}_{shot_id}` (e.g.
            "keyframe_EP001_SH02", "video_EP001_SH02").
        modality: Canonical modality string passed to dispatch(). Must
            match one of the registered modalities at execution time
            (image_t2i / video_i2v live; audio_t2a / lipsync_post stub
            in CP-6 — they raise via dispatch and surface as failed step).
        payload: Per-modality payload dict passed to dispatch(). Schema
            documented on each runner class. Extra orchestration keys
            (client_id, etc.) pass through harmlessly.
        depends_on: List of step_ids that must complete (status="succeeded")
            before this step runs. Default: empty (root step).
        status: Execution state. CP-6 transitions:
            pending → running → succeeded | failed
            pending → skipped (when an upstream dependency failed)
        receipt: GenerationReceipt produced by this step's dispatch call.
            None until the executor runs the step, and also None when
            dispatch itself raised (the failure is then recorded in
            `error`). When present, the receipt's run_result.success
            carries the success/failure flag.
        error: String error message for a failed or skipped step. None
            while the step has not failed.
        started_us / finished_us: Microsecond epoch timestamps captured
            by the executor. None until the step runs.
    """

    step_id: str
    modality: str
    payload: dict[str, Any]
    depends_on: list[str] = field(default_factory=list)
    status: StepStatus = "pending"
    receipt: Optional[GenerationReceipt] = None
    error: Optional[str] = None
    started_us: Optional[int] = None
    finished_us: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate the constructor arguments.

        Raises:
            ValueError: step_id or modality is not a non-empty string.
            TypeError: payload is not a dict, or depends_on is not a list.
        """
        if not isinstance(self.step_id, str) or not self.step_id:
            raise ValueError(
                f"WorkflowStep.step_id must be a non-empty string, got {self.step_id!r}"
            )
        if not isinstance(self.modality, str) or not self.modality:
            raise ValueError(
                f"WorkflowStep.modality must be a non-empty string, got {self.modality!r}"
            )
        if not isinstance(self.payload, dict):
            raise TypeError(
                f"WorkflowStep.payload must be a dict, got {type(self.payload).__name__}"
            )
        if not isinstance(self.depends_on, list):
            raise TypeError(
                f"WorkflowStep.depends_on must be a list, got {type(self.depends_on).__name__}"
            )

    def to_dict(self) -> dict[str, Any]:
        """Serialize for JSON. receipt round-trips via GenerationReceipt.to_dict.

        `payload` and `depends_on` are shallow copies: the returned containers
        are new objects, but any nested values are shared with this step.
        """
        return {
            "step_id": self.step_id,
            "modality": self.modality,
            "payload": dict(self.payload),
            "depends_on": list(self.depends_on),
            "status": self.status,
            "receipt": self.receipt.to_dict() if self.receipt is not None else None,
            "error": self.error,
            "started_us": self.started_us,
            "finished_us": self.finished_us,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "WorkflowStep":
        """Rebuild a step from the dict shape produced by `to_dict`.

        `step_id` and `modality` are required (KeyError if absent). Every
        other key is optional; a missing key and an explicit null are treated
        the same way for `payload`, `depends_on` and `status`, each falling
        back to its default ({} / [] / "pending").

        Raises:
            KeyError: `step_id` or `modality` is missing from `d`.
            ValueError / TypeError: propagated from `__post_init__` validation.
        """
        receipt_dict = d.get("receipt")
        receipt = (
            GenerationReceipt.from_dict(receipt_dict)
            if receipt_dict is not None
            else None
        )
        return cls(
            step_id=d["step_id"],
            modality=d["modality"],
            payload=dict(d.get("payload") or {}),
            depends_on=list(d.get("depends_on") or []),
            # `or` (not a .get default) so a JSON null status becomes
            # "pending" instead of leaking status=None into the step.
            status=d.get("status") or "pending",
            receipt=receipt,
            error=d.get("error"),
            started_us=d.get("started_us"),
            finished_us=d.get("finished_us"),
        )


@dataclass
class Workflow:
    """An ordered collection of WorkflowSteps linked by `depends_on` edges.

    Fields:
        workflow_id: Non-empty, free-form identifier for the workflow.
            Recommended format: `{project}_ep{NNN}_{role}` (e.g.
            "tartarus_ep001_full_shot_pipeline").
        steps: The WorkflowStep nodes, in declared order. The linear executor
            walks them in exactly this order, so the declared order has to be
            a valid topological sort of the `depends_on` graph; that is
            checked when the workflow is run, not at construction time.
        global_provenance: Cross-step context (project, episode, tags, ...).
            At execution time it is merged into each step's
            `provenance_overrides` on the dispatch context.
        created_at: Creation timestamp string; defaults to the value returned
            by `utc_now_iso8601()` at construction time.

    Construction validates the id, the container/element types of `steps`,
    and step_id uniqueness. An empty `steps` list is accepted here so that
    serialization round-trips work; emptiness is rejected when run.
    """

    workflow_id: str
    steps: list[WorkflowStep]
    global_provenance: dict[str, Any] = field(default_factory=dict)
    created_at: str = field(default_factory=utc_now_iso8601)

    def __post_init__(self) -> None:
        """Validate constructor arguments.

        Raises:
            ValueError: workflow_id is not a non-empty string, or two steps
                share a step_id.
            TypeError: steps is not a list, or an element is not a
                WorkflowStep.
        """
        wf_id = self.workflow_id
        if not (isinstance(wf_id, str) and wf_id):
            raise ValueError(
                f"Workflow.workflow_id must be a non-empty string, got {wf_id!r}"
            )

        members = self.steps
        if not isinstance(members, list):
            raise TypeError(
                f"Workflow.steps must be a list, got {type(members).__name__}"
            )

        # Type-check every element first; only then look for duplicate ids,
        # so a wrongly-typed element is always reported before a duplicate.
        for position, member in enumerate(members):
            if isinstance(member, WorkflowStep):
                continue
            raise TypeError(
                f"Workflow.steps[{position}] must be a WorkflowStep, "
                f"got {type(member).__name__}"
            )

        # Duplicates are rejected at construction so they never reach run().
        registered: set[str] = set()
        for member in members:
            sid = member.step_id
            if sid in registered:
                raise ValueError(
                    f"Duplicate step_id {sid!r} in workflow {self.workflow_id!r}"
                )
            registered.add(sid)

    def get_step(self, step_id: str) -> WorkflowStep:
        """Look up a step by id; raise KeyError (listing known ids) if absent."""
        found = next((s for s in self.steps if s.step_id == step_id), None)
        if found is None:
            raise KeyError(
                f"No step {step_id!r} in workflow {self.workflow_id!r}. "
                f"Steps: {[s.step_id for s in self.steps]}"
            )
        return found

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; each step goes through WorkflowStep.to_dict."""
        serialized_steps = [s.to_dict() for s in self.steps]
        provenance_copy = dict(self.global_provenance)
        return {
            "workflow_id": self.workflow_id,
            "steps": serialized_steps,
            "global_provenance": provenance_copy,
            "created_at": self.created_at,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Workflow":
        """Rebuild a Workflow from the dict shape produced by `to_dict`.

        `workflow_id` is required (KeyError if absent). Missing or null
        `steps` / `global_provenance` become empty containers, and a missing
        or empty `created_at` is replaced with the current timestamp.
        """
        wf_id = d["workflow_id"]
        raw_steps = d.get("steps") or []
        rebuilt_steps = [WorkflowStep.from_dict(raw) for raw in raw_steps]
        provenance = dict(d.get("global_provenance") or {})
        stamp = d.get("created_at") or utc_now_iso8601()
        return cls(
            workflow_id=wf_id,
            steps=rebuilt_steps,
            global_provenance=provenance,
            created_at=stamp,
        )


# ──────────────────────────────────────────────────────────────────────
# Phase 3: Linear executor + hooks
# Phase 4 wired _run_step_via_dispatch into the real dispatch() call.
# ──────────────────────────────────────────────────────────────────────


class WorkflowValidationError(ValueError):
    """Structural problem found in a Workflow before any step is executed.

    Raised by the DAG validation that runs at the start of Workflow.run: an
    empty workflow, a dependency on a missing step or on the step itself, or
    steps declared out of topological order. Subclasses ValueError, so
    ``except ValueError`` handlers catch it as well.
    """


# Hook signatures — every hook receives (step, workflow) and returns None:
#   pre_step(step, workflow)   — called before dispatch for every executed step
#                                (never for skipped steps)
#   post_step(step, workflow)  — called after every dispatched step regardless
#                                of outcome; NOT called for skipped steps or
#                                when pre_step itself raised
#   on_failure(step, workflow) — called after post_step when the step failed
#                                (in ADDITION to post_step)
HookFn = Callable[["WorkflowStep", "Workflow"], None]


def _validate_dag(workflow: "Workflow") -> None:
    """Check that the workflow can be executed by a declared-order walk.

    Raises WorkflowValidationError when any of these hold:
      * the workflow has no steps;
      * a step names a dependency that is not a step_id in the workflow;
      * a step lists itself as a dependency;
      * a step depends on a step declared later in the list.

    The last rule doubles as cycle detection: in any cycle at least one edge
    must point at a step that is declared later than the one holding the edge.
    Reference errors (missing / self) for the whole workflow are reported
    before any ordering error.
    """
    nodes = workflow.steps
    if not nodes:
        raise WorkflowValidationError(f"Workflow {workflow.workflow_id!r} has no steps")

    known_ids = {node.step_id for node in nodes}

    # Pass 1 — every edge must point at an existing, different step.
    for node in nodes:
        for upstream in node.depends_on:
            if upstream not in known_ids:
                raise WorkflowValidationError(
                    f"Step {node.step_id!r} depends on missing step {upstream!r}"
                )
            if upstream == node.step_id:
                raise WorkflowValidationError(
                    f"Step {node.step_id!r} cannot depend on itself"
                )

    # Pass 2 — every edge must point backwards in declared order.
    declared_so_far: set[str] = set()
    for node in nodes:
        not_yet_declared = [
            upstream for upstream in node.depends_on
            if upstream not in declared_so_far
        ]
        if not_yet_declared:
            # Report the first offending dependency, in depends_on order.
            raise WorkflowValidationError(
                f"Step {node.step_id!r} depends on {not_yet_declared[0]!r} which is "
                f"declared later in the workflow (or part of a cycle). "
                f"CP-6's linear executor requires steps to appear in "
                f"topological order."
            )
        declared_so_far.add(node.step_id)


def _run_step_via_dispatch(
    step: "WorkflowStep",
    workflow: "Workflow",
    context,  # DispatchContext
) -> "GenerationReceipt":
    """Dispatch a single step using a per-step copy of the caller's context.

    The caller's DispatchContext is never mutated. A copy is made with
    `dataclasses.replace`, changing only `provenance_overrides`; every other
    field is carried over as-is. The new overrides are layered, lowest
    precedence first:

      1. workflow.global_provenance        (cross-step context)
      2. context.provenance_overrides      (caller-supplied)
      3. workflow_id / workflow_step_id    (identity of this step)

    Step identity is applied last so neither of the other layers can
    overwrite it.

    Returns whatever dispatch() returns for (step.modality, step.payload).

    Raises:
        TypeError: `context` is not a DispatchContext.
        Anything dispatch() itself raises propagates unchanged.
    """
    # Imported lazily so loading this module does not create an import cycle.
    from recoil.pipeline.core.dispatch import dispatch as _dispatch
    from recoil.pipeline.core.dispatch_context import DispatchContext as _DispatchContext

    if not isinstance(context, _DispatchContext):
        raise TypeError(
            f"Workflow.run requires a DispatchContext, got {type(context).__name__}"
        )

    # Layer the provenance sources; later updates win on key collisions.
    provenance: dict[str, Any] = dict(workflow.global_provenance)
    provenance.update(context.provenance_overrides)
    provenance.update(
        workflow_id=workflow.workflow_id,
        workflow_step_id=step.step_id,
    )

    step_scoped_context = replace(context, provenance_overrides=provenance)
    return _dispatch(step.modality, step.payload, context=step_scoped_context)


def _workflow_run(
    workflow: "Workflow",
    *,
    context,
    pre_step: Optional[HookFn] = None,
    post_step: Optional[HookFn] = None,
    on_failure: Optional[HookFn] = None,
) -> "Workflow":
    """Execute the workflow's steps in declared order, mutating them in place.

    Sequence:
      1. Validate the DAG (raises WorkflowValidationError before any state
         is touched).
      2. Reset execution state (status, error, receipt, timestamps) on ALL
         steps up front, so a re-run that aborts midway never leaves
         un-reached steps showing results from a previous run.
      3. For each step, in declared order:
         - If any dependency has not succeeded in this run, mark the step
           "skipped", record the unmet dependencies in `error`, and fire no
           hooks for it.
         - Otherwise mark it "running", stamp started_us, fire pre_step,
           then dispatch.
         - Dispatch raised: status="failed", `error` holds
           "<ExcType>: <message>", receipt stays None.
         - Dispatch returned a receipt: attach it; status is "succeeded"
           when receipt.run_result.success is truthy, otherwise "failed"
           with `error` taken from receipt.run_result.error.
         - Fire post_step, then on_failure if the step failed.
      4. Return the same workflow object.

    Hook exceptions are never suppressed; they propagate with their original
    type and abort the run. When pre_step raises, the current step is first
    finalized as "failed" (error and finished_us set) so it does not stay
    "running"; post_step / on_failure are not fired for that step and the
    remaining steps are left "pending".

    Note that a `context` of the wrong type is reported by the per-step
    dispatch helper as a TypeError, which this executor records as an
    ordinary step failure rather than raising to the caller.
    """
    _validate_dag(workflow)

    def _now_us() -> int:
        # Wall-clock time as integer microseconds since the epoch.
        return int(time.time() * 1_000_000)

    def _fail_with_exception(target: "WorkflowStep", exc: Exception) -> None:
        # Put `target` into a terminal failed state describing `exc`.
        target.status = "failed"
        target.error = f"{type(exc).__name__}: {exc}"
        target.finished_us = _now_us()

    # Wipe results of any earlier run for every step before executing anything.
    for stale in workflow.steps:
        stale.status = "pending"
        stale.error = None
        stale.receipt = None
        stale.started_us = None
        stale.finished_us = None

    completed: set[str] = set()

    for current in workflow.steps:
        # A step runs only when every dependency succeeded during this run.
        unmet = [dep for dep in current.depends_on if dep not in completed]
        if unmet:
            current.status = "skipped"
            current.error = f"upstream dependency failed: {unmet}"
            continue

        current.status = "running"
        current.started_us = _now_us()

        if pre_step is not None:
            try:
                pre_step(current, workflow)
            except Exception as hook_exc:  # noqa: BLE001
                # Finalize the step, then let the hook's exception abort the run.
                _fail_with_exception(current, hook_exc)
                raise

        try:
            receipt = _run_step_via_dispatch(current, workflow, context)
        except Exception as dispatch_exc:  # noqa: BLE001 — capture all dispatch-side failures
            _fail_with_exception(current, dispatch_exc)
            # Hooks run inside the handler: if one raises, its exception is
            # implicitly chained to the dispatch failure that preceded it.
            if post_step is not None:
                post_step(current, workflow)
            if on_failure is not None:
                on_failure(current, workflow)
            continue

        current.receipt = receipt
        current.finished_us = _now_us()

        if receipt.run_result.success:
            current.status = "succeeded"
            completed.add(current.step_id)
        else:
            current.status = "failed"
            current.error = receipt.run_result.error

        if post_step is not None:
            post_step(current, workflow)
        # Status is re-read after post_step, so on_failure follows whatever
        # status the step has at this point.
        if on_failure is not None and current.status == "failed":
            on_failure(current, workflow)

    return workflow


# Bind the executor as a method on Workflow, so `workflow.run(context=...)`
# invokes `_workflow_run(workflow, context=...)`. The assignment is done here,
# after the class body, because _workflow_run is defined below the class.
Workflow.run = _workflow_run  # type: ignore[assignment]
