"""Lineage adapter — derives a graph from a beat's takes + receipts.

Three modes:

1. **Beat-rooted** (no `take_id` provided): emit one node per take with the
   first take as `step` and the rest as `sibling`. Useful as a beat overview.

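2. **Manifest mode** (`take_id` provided and resolves to a take): emit a flat
   list of input nodes describing what was injected into the take, built from
   the take's `inputs_snapshot` or, for legacy takes without one, from a
   best-effort reconstruction (a single placeholder note when nothing can be
   recovered). `edges=[]` is the manifest signal. Used by the inspector's
   LineageStrip.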

3. **Take-rooted** (legacy, `take_id` provided and present in the beat's
   takes list, but not resolvable to a take for the manifest): walk the
   `parent_take_id` linkage backward from the target take and emit ONLY the
   chain of takes that produced it. Falls back to beat-rooted mode if the
   take_id isn't found in the beat's takes list.

## parent_take_id schema convention

A take dict may include `parent_take_id: Optional[str]` pointing at another
take in the same beat (typically the keyframe that fed an i2v take). When
absent or null, the take is treated as a chain root (a prompt-driven start).
The generation pipeline is responsible for setting it at take creation;
older takes without the field still render as single-step chains.

This adapter is READ-ONLY. It never writes to projects/.
"""

from __future__ import annotations

import logging
import re
from typing import Any, Optional

from recoil.api.adapters.beats import _load_shot, _shots_dir
from recoil.api.schemas.engine import (
    SCHEMA_VERSION,
    Lineage,
    LineageEdge,
    LineageNode,
)
from recoil.core.paths import projects_root
from recoil.pipeline._lib.take_inputs import reconstruct_inputs


def _coerce_eval_detail(gate: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Normalize a gate_1/gate_2 dict into LineageNode.eval_detail shape."""
    if not isinstance(gate, dict):
        return None
    passed = gate.get("passed")
    details = gate.get("details") or {}
    judges = []
    if isinstance(details, dict):
        for name, sub in details.items():
            if not isinstance(sub, dict):
                continue
            judges.append(
                {
                    "name": str(name),
                    "verdict": "pass" if sub.get("pass") else "fail",
                    "score": float(sub.get("score") or 0.0),
                    "note": str(sub.get("reason") or ""),
                }
            )
    return {
        "verdict": "pass" if passed else "fail",
        "score": 1.0 if passed else 0.0,
        "judges": judges,
    }


def _resolve_take_id(raw: dict, beat_id: str, idx: int) -> str:
    """Stable id for a take dict — `take_id` field if present, else synthesized
    from take_number / take_num / list index. Matches the synthesis the
    rendering loop does so chain-walk + node-emit agree on identity.
    """
    if raw.get("take_id"):
        return str(raw["take_id"])
    n = raw.get("take_number") or raw.get("take_num") or idx
    return f"{beat_id}_T{int(n):03d}"


def _walk_parent_chain(takes: list, target_take_id: str, beat_id: str) -> list[dict]:
    """Return the ancestry chain from root → target_take_id.

    Walks `parent_take_id` backward. Empty list when target isn't in takes.
    Cycle-safe (stops on first revisited id). Takes without a parent_take_id
    field are treated as chain roots — so a take that never had a parent set
    still renders as a single-step chain.
    """
    by_id: dict[str, dict] = {}
    for idx, raw in enumerate(takes):
        if not isinstance(raw, dict):
            continue
        by_id[_resolve_take_id(raw, beat_id, idx)] = raw

    if target_take_id not in by_id:
        return []

    chain_ids: list[str] = []
    visited: set[str] = set()
    current: Optional[str] = target_take_id
    while current and current not in visited:
        if current not in by_id:
            # parent_take_id points at an id we don't have. Stop the walk
            # without adding the missing id — chain is whatever we've found
            # so far rooted at the nearest known ancestor.
            break
        visited.add(current)
        chain_ids.append(current)
        raw = by_id[current]
        parent = raw.get("parent_take_id")
        current = str(parent) if parent else None

    chain_ids.reverse()
    return [by_id[tid] for tid in chain_ids]


def _resolve_shot(
    beat_id: str, project_id: Optional[str]
) -> tuple[Optional[dict], Optional[str]]:
    """Find the shot file for `beat_id` and report which project slug owns it.

    With a `project_id`, only that project's shots directory is checked.
    Without one, every directory under `projects_root()` is tried in
    iteration order and the first project containing `{beat_id}.json` wins.

    Returns `(shot_dict, slug)` as loaded by `_load_shot`, or `(None, None)`
    when no matching shot file exists.
    """
    if project_id:
        slugs = [project_id]
    else:
        base = projects_root()
        slugs = []
        if base.exists():
            slugs = [entry.name for entry in base.iterdir() if entry.is_dir()]

    shot_filename = f"{beat_id}.json"
    for slug in slugs:
        shot_path = _shots_dir(slug) / shot_filename
        if shot_path.exists():
            return _load_shot(shot_path), slug
    return None, None


def _path_to_media(p: str) -> Optional[str]:
    pl = p.lower()
    if pl.endswith((".mp4", ".mov", ".webm")):
        return "video"
    if pl.endswith((".png", ".jpg", ".jpeg", ".webp")):
        return "image"
    if pl.endswith((".mp3", ".wav", ".m4a")):
        return "audio"
    return None


def _resolve_take_by_id(
    takes: list, beat_id: str, target: str
) -> tuple[Optional[dict], Optional[int]]:
    """3-step server-side cascade. Returns (take_dict, index) or (None, None).

    1. Exact match on take['take_id'].
    2. Prefix strip: target like '{beat_id}_T{N:03d}' → integer N → match
       on take_number.
    3. Numeric extraction: any 'T(\\d+)' suffix → integer N → match on
       take_number.

    When step 2 or 3 finds multiple takes with the same take_number
    (legacy-data hazard: regen takes carry a take_num that overlaps with
    a pre-existing fixture take), prefer the one with `file_path` set —
    that's the take whose artifact the inspector is actually displaying.
    """
    prefix = f"{beat_id}_T"
    target_n_prefix: Optional[int] = None
    if target.startswith(prefix):
        try:
            target_n_prefix = int(target[len(prefix) :])
        except ValueError:
            pass
    m = re.search(r"[Tt](\d+)", target)
    target_n_regex: Optional[int] = int(m.group(1)) if m else None

    def _better(
        current: tuple[Optional[dict], Optional[int]], candidate: tuple[dict, int]
    ) -> tuple[dict, int]:
        if current[0] is None:
            return candidate
        cur_has_file = bool(
            current[0].get("file_path") or current[0].get("output_path")
        )
        cand_has_file = bool(
            candidate[0].get("file_path") or candidate[0].get("output_path")
        )
        # Promote the candidate when it has a real artifact and the
        # incumbent doesn't (the inspector displays artifact-bearing takes).
        if cand_has_file and not cur_has_file:
            return candidate
        return current

    prefix_match: tuple[Optional[dict], Optional[int]] = (None, None)
    regex_match: tuple[Optional[dict], Optional[int]] = (None, None)

    for idx, raw in enumerate(takes):
        if not isinstance(raw, dict):
            continue
        if raw.get("take_id") == target:
            return raw, idx
        take_num = int(raw.get("take_number") or raw.get("take_num") or -1)
        if target_n_prefix is not None and take_num == target_n_prefix:
            prefix_match = _better(prefix_match, (raw, idx))
        if target_n_regex is not None and take_num == target_n_regex:
            regex_match = _better(regex_match, (raw, idx))

    return prefix_match if prefix_match[0] is not None else regex_match


# Pattern matching the label format emitted by previz_context._build_character_refs:
# "Reference: {CHAR_UPPER} -- {role-text}"
# Group 1 captures the character identifier (letters, digits and underscores,
# not starting with a digit); group 2 captures the role text with surrounding
# whitespace trimmed. The whole label must match (anchored at both ends).
_REF_LABEL_RE = re.compile(r"^Reference:\s*([A-Za-z_][A-Za-z0-9_]*)\s*--\s*(.+?)\s*$")


def _resolve_ref_url_from_label(
    label: str, project_id: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Best-effort recovery for refs whose snapshot dropped the path.

    `build_previz_inputs_snapshot()` writes refs_used with `url=""` (the path
    is consumed when the bytes are sent to the model). The label survives in
    the canonical format, so we reverse it back to a canonical ref file.

    Returns (resolved_url, media_kind) or (None, None) when unresolvable.
    """
    if not label or not project_id:
        return None, None
    m = _REF_LABEL_RE.match(label)
    if not m:
        return None, None
    char_id = m.group(1).lower()
    role_text = m.group(2).lower()

    if "hero" in role_text:
        angle = "hero"
    elif "front" in role_text:
        angle = "front"
    elif "three quarter" in role_text or "three-quarter" in role_text:
        angle = "three_quarter"
    elif "profile" in role_text:
        angle = "profile"
    elif "back" in role_text:
        angle = "back"
    else:
        return None, None

    try:
        from recoil.core.ref_resolver import resolve_character_refs

        proj_root = projects_root() / project_id
        refs_root = proj_root / "output" / "refs"
        if not refs_root.exists():
            return None, None
        resolved = resolve_character_refs(refs_root, char_id)
        path = resolved.get(angle)
        if path and path.is_file():
            rel = path.relative_to(proj_root)
            url = f"/api/media/{project_id}/{rel.as_posix()}"
            ext = path.suffix.lower()
            media = "image" if ext in (".png", ".jpg", ".jpeg", ".webp") else None
            return url, media
    except Exception:
        return None, None
    return None, None


def _resolve_ref_url(ref_url: str, project_id: Optional[str]) -> Optional[str]:
    if not ref_url:
        return None
    if ref_url.startswith(("http://", "https://", "data:", "/api/", "/", "file://")):
        return ref_url
    return f"/api/media/{project_id}/{ref_url}" if project_id else ref_url


def _build_manifest_lineage(
    beat_id: str,
    take: dict,
    take_id: str,
    shot: dict,
    project_id: Optional[str],
) -> Lineage:
    """Build an input-manifest Lineage from a take's `inputs_snapshot`.

    Manifest mode signal: `edges=[]`. Nodes are appended in declaration
    order: parent_take → refs → prompt → bible → model_config (+ optional
    failure-note footer). Every node has `col=0` and `row` equal to its
    position in the node list.

    For legacy takes without `inputs_snapshot`, calls `reconstruct_inputs()`
    for a best-effort manifest. If reconstruction fails, or the snapshot has
    no flat prompt, no refs and no routing model, returns a single
    `kind="note"` placeholder node.

    Args:
        beat_id: Beat the take belongs to; stamped on every node.
        take: The take dict (reads `inputs_snapshot`, `model`, `disposition`,
            `error`, `gate_verdict`).
        take_id: Canonical take id; used as the node-id prefix and as
            `root_take`.
        shot: The shot dict; only consulted when reconstructing a legacy
            snapshot.
        project_id: Project slug used to build media URLs for ref nodes.

    Returns:
        A `Lineage` with `edges=[]`.
    """
    snap = take.get("inputs_snapshot")
    if not snap:
        # Legacy take — best-effort reconstruct.
        try:
            snap = reconstruct_inputs(
                take=take,
                shot=shot,
                bible=shot.get("bible") or {},
                project_config=shot.get("project_config") or {},
                prompt_sections=shot.get("prompt_sections"),
                reference_images=shot.get("reference_images"),
            )
        except (TypeError, AttributeError, KeyError, ValueError):
            # exc_info keeps the actual cause in the log; without it the
            # warning says only *that* reconstruction failed, not why.
            logging.warning(
                "reconstruct_inputs failed for take in beat %s",
                beat_id,
                exc_info=True,
            )
            snap = None

    nodes: list[LineageNode] = []

    # `routing` may be present but null in a snapshot; normalise once so the
    # emptiness check below and the params node agree (and neither calls
    # `.get` on None).
    routing = (snap.get("routing") or {}) if snap else {}

    if not snap or not (
        snap.get("prompt_flat") or snap.get("refs_used") or routing.get("model")
    ):
        # Reconstruction yielded nothing — explicit placeholder.
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_legacy_note",
                beat_id=beat_id,
                kind="note",
                label="Legacy take",
                sub="No input manifest recorded",
                col=0,
                row=0,
            )
        )
        return Lineage(
            schema_version=SCHEMA_VERSION,
            beat_id=beat_id,
            root_take=take_id,
            nodes=nodes,
            edges=[],
        )

    # 1. parent_take node (if a parent take_id is recorded)
    parent_id = snap.get("parent_take_id")
    if parent_id:
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_parent",
                beat_id=beat_id,
                kind="parent_take",
                label="Parent take",
                sub=str(parent_id),
                col=0,
                row=len(nodes),
                parent_take_id=str(parent_id),
            )
        )

    # 2. ref nodes — one per ref_used. Node ids use the position in
    # refs_used, so skipped (non-dict) entries leave gaps in the numbering.
    for i, ref in enumerate(snap.get("refs_used") or []):
        if not isinstance(ref, dict):
            continue
        ref_type = str(ref.get("type") or "")
        ref_url = str(ref.get("url") or "")
        ref_label = str(ref.get("label") or ref.get("id") or ref_type or "ref")
        media_kind = _path_to_media(ref_url) if ref_url else None
        resolved_url = _resolve_ref_url(ref_url, project_id)
        # build_previz_inputs_snapshot() writes refs_used with empty url.
        # Recover from the canonical label format so the strip can show
        # thumbnails.
        if not resolved_url:
            recovered_url, recovered_media = _resolve_ref_url_from_label(
                ref_label, project_id
            )
            if recovered_url:
                resolved_url = recovered_url
                media_kind = recovered_media
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_ref_{i}",
                beat_id=beat_id,
                kind="ref",
                label=ref_label,
                sub=ref_type or None,
                col=0,
                row=len(nodes),
                media_kind=media_kind,
                url=resolved_url,
                ref_role=ref_type or None,
                ref_hash=str(ref.get("id") or "") or None,
            )
        )

    # 3. prompt node — full body of the flat prompt
    prompt_flat = snap.get("prompt_flat") or ""
    builder = snap.get("builder_name")
    if prompt_flat:
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_prompt",
                beat_id=beat_id,
                kind="prompt",
                label="Prompt",
                sub=(str(builder) if builder else None),
                col=0,
                row=len(nodes),
                media_kind="text",
                prompt_body=prompt_flat,
            )
        )

    # 4. bible nodes — one per recorded bible file, or a single
    # version-only node when just `bible_version` is known.
    bible_files = snap.get("bible_files")
    bible_version = snap.get("bible_version")
    if bible_files:
        for i, bf in enumerate(bible_files):
            if not isinstance(bf, dict):
                continue
            nodes.append(
                LineageNode(
                    schema_version=SCHEMA_VERSION,
                    id=f"{take_id}_bible_{i}",
                    beat_id=beat_id,
                    kind="bible",
                    label=str(bf.get("file") or "bible"),
                    sub=", ".join(bf.get("sections") or []) or None,
                    col=0,
                    row=len(nodes),
                    bible_file=str(bf.get("file") or ""),
                    bible_sections=list(bf.get("sections") or []),
                )
            )
    elif bible_version:
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_bible_version",
                beat_id=beat_id,
                kind="bible",
                label="Bible",
                sub=f"version {str(bible_version)[:8]}",
                col=0,
                row=len(nodes),
            )
        )

    # 5. params node — model config as key/value rows.
    # The take's own `model` is authoritative — `routing.model` from a
    # reconstructed or hardcoded snapshot can disagree (e.g. video take
    # inherited a previz still-keyframe snapshot whose routing is pinned
    # to gemini-3.1-flash-image-preview). Prefer the take.
    effective_model = str(take.get("model") or routing.get("model") or "")
    params_rows: list[list[str]] = []
    if routing.get("pipeline"):
        params_rows.append(["pipeline", str(routing["pipeline"])])
    if routing.get("tier"):
        params_rows.append(["tier", str(routing["tier"])])
    gen_params = snap.get("generation_params") or {}
    for k in ("duration", "aspect_ratio", "seed", "guidance_scale"):
        # `is not None` (not truthiness) so legitimate zeros are kept.
        if gen_params.get(k) is not None:
            params_rows.append([k, str(gen_params[k])])
    if snap.get("config_hash"):
        params_rows.append(["config_hash", str(snap["config_hash"])[:8]])
    if effective_model or params_rows:
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_params",
                beat_id=beat_id,
                kind="params",
                label="Model config",
                sub=effective_model or None,
                col=0,
                row=len(nodes),
                media_kind="text",
                model=effective_model or None,
                params_body=params_rows,
            )
        )

    # 6. failure footer — reason is capped at 200 characters.
    if take.get("disposition") == "rejected" or take.get("error"):
        gv = take.get("gate_verdict") or {}
        reason = (str(gv.get("reason") or take.get("error") or "Generation failed"))[
            :200
        ]
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=f"{take_id}_failure",
                beat_id=beat_id,
                kind="note",
                label="Generation failed",
                sub=reason,
                col=0,
                row=len(nodes),
                failed=True,
            )
        )

    return Lineage(
        schema_version=SCHEMA_VERSION,
        beat_id=beat_id,
        root_take=take_id,
        nodes=nodes,
        edges=[],  # MANIFEST SIGNAL — no graph, flat list.
    )


def _build_beat_lineage(
    beat_id: str,
    resolved_project_id: Optional[str],
    raw_takes: list,
    is_chain: bool,
    shot: dict,
) -> Lineage:
    """Build the beat-graph Lineage (take-graph with siblings + outputs + eval edges).

    Extracted from the original get_lineage body. `_resolve_shot` and null-checks
    are handled by the caller.

    Emits one synthetic prompt node, then for each dict entry of `raw_takes`
    a step/sibling node plus optional output and eval nodes, wired with
    "data" edges.

    Args:
        beat_id: Beat id; stamped on every node and used to synthesize ids.
        resolved_project_id: Project slug for output media URLs; when falsy,
            output nodes get `url=None`.
        raw_takes: Takes to render (non-dict entries are skipped). In chain
            mode this is the ancestry chain, root first.
        is_chain: True to render a linear chain (every take is a `step`
            linked to its predecessor); False for the beat overview (take at
            index 0 is `step`, the rest `sibling`, all fed by the prompt).
        shot: Shot dict; only its `model` is read, as a fallback for takes
            that don't record one.

    Returns:
        A `Lineage` whose `root_take` is the first take with an output path,
        else the first dict take, else `{beat_id}_T000`.
    """
    nodes: list[LineageNode] = []
    edges: list[LineageEdge] = []

    # Synthetic root: the prompt that drove the chain (or the beat).
    # Prefer the first `compiled_prompt`; fall back to the first `prompt_used`.
    prompt_id = f"{beat_id}_prompt"
    primary_prompt = None
    for raw in raw_takes:
        if isinstance(raw, dict) and raw.get("compiled_prompt"):
            primary_prompt = raw.get("compiled_prompt")
            break
    if primary_prompt is None:
        for raw in raw_takes:
            if isinstance(raw, dict) and raw.get("prompt_used"):
                primary_prompt = raw.get("prompt_used")
                break
    nodes.append(
        LineageNode(
            schema_version=SCHEMA_VERSION,
            id=prompt_id,
            beat_id=beat_id,
            kind="prompt",
            label="prompt",
            sub=None,
            col=0,
            row=0,
            media_kind="text",
            prompt_body=str(primary_prompt) if primary_prompt else None,
        )
    )

    root_take_id: Optional[str] = None
    # Id of the most recently emitted step node. Chain edges link to this
    # rather than re-deriving the predecessor from raw_takes[idx - 1], which
    # would blow up if that entry had been skipped as a non-dict.
    prev_step_id: Optional[str] = None
    for idx, raw in enumerate(raw_takes):
        if not isinstance(raw, dict):
            continue
        tid = _resolve_take_id(raw, beat_id, idx)
        if root_take_id is None and (raw.get("file_path") or raw.get("output_path")):
            root_take_id = tid

        # step node for the take itself. In chain mode every node is a step
        # (linear ancestry); in beat mode the first is `step`, rest `sibling`.
        step_node_id = f"{tid}_step"
        step_kind = "step" if is_chain or idx == 0 else "sibling"
        step_col = idx + 1 if is_chain else (1 if idx == 0 else 2)
        # Same row rule in both modes: one row per position in raw_takes.
        step_row = idx
        # Label by the take's actual number, not the loop index — so a chain
        # showing take 8 reads as "take 8", not "take 0" (which is what the
        # loop-local idx would render).
        take_num = raw.get("take_number") or raw.get("take_num")
        step_label: Optional[str] = None
        if take_num is not None:
            try:
                step_label = f"take {int(take_num)}"
            except (TypeError, ValueError):
                # Malformed number: fall through to the id-derived label
                # instead of failing the whole graph.
                step_label = None
        if step_label is None:
            step_label = tid.rsplit("_", 1)[-1] if "_" in tid else f"take {idx}"
        try:
            cost = float(raw.get("cost") or raw.get("cost_usd") or 0.0)
        except (TypeError, ValueError):
            # A non-numeric cost is display-only data; render it as 0.0.
            cost = 0.0
        nodes.append(
            LineageNode(
                schema_version=SCHEMA_VERSION,
                id=step_node_id,
                beat_id=beat_id,
                kind=step_kind,
                label=step_label,
                sub=str(raw.get("model") or shot.get("model") or "?"),
                col=step_col,
                row=step_row,
                media_kind=None,
                model=str(raw.get("model") or shot.get("model") or ""),
                cost=cost,
                # NOTE(review): manifest mode flags failure via
                # `disposition == "rejected"`, this path via `rejected` —
                # confirm which field is canonical.
                failed=bool(raw.get("rejected") or raw.get("error")),
            )
        )
        # Edge: prompt → first-step (chain root), then step[i-1] → step[i] for
        # subsequent chain links. Beat mode: every step from prompt.
        edge_source = (
            prev_step_id if is_chain and prev_step_id is not None else prompt_id
        )
        edges.append(
            LineageEdge(**{"from": edge_source, "to": step_node_id, "kind": "data"})
        )
        prev_step_id = step_node_id

        # output node
        out_path = raw.get("file_path") or raw.get("output_path")
        if out_path:
            out_id = f"{tid}_out"
            out_url = (
                f"/api/media/{resolved_project_id}/{out_path}"
                if resolved_project_id
                else None
            )
            out_col = step_col + 1
            nodes.append(
                LineageNode(
                    schema_version=SCHEMA_VERSION,
                    id=out_id,
                    beat_id=beat_id,
                    kind="output",
                    label=str(out_path).rsplit("/", 1)[-1],
                    col=out_col,
                    row=step_row,
                    media_kind=_path_to_media(str(out_path)),
                    url=out_url,
                )
            )
            edges.append(
                LineageEdge(**{"from": step_node_id, "to": out_id, "kind": "data"})
            )

        # eval node — gate_1 wins over gate_2 when both are truthy. It sits
        # one column right of the output node, or of the step when there is
        # no output.
        gate = raw.get("gate_1") or raw.get("gate_2")
        if isinstance(gate, dict):
            eval_id = f"{tid}_eval"
            eval_col = step_col + (2 if out_path else 1)
            nodes.append(
                LineageNode(
                    schema_version=SCHEMA_VERSION,
                    id=eval_id,
                    beat_id=beat_id,
                    kind="eval",
                    label="gate",
                    col=eval_col,
                    row=step_row,
                    media_kind="text",
                    eval_detail=_coerce_eval_detail(gate),
                    failed=not bool(gate.get("passed")),
                )
            )
            edges.append(
                LineageEdge(**{"from": step_node_id, "to": eval_id, "kind": "data"})
            )

    if root_take_id is None:
        # No completed takes — pick the synthetic id for the first dict take.
        for idx, raw in enumerate(raw_takes):
            if isinstance(raw, dict):
                root_take_id = _resolve_take_id(raw, beat_id, idx)
                break

    return Lineage(
        schema_version=SCHEMA_VERSION,
        beat_id=beat_id,
        root_take=root_take_id or f"{beat_id}_T000",
        nodes=nodes,
        edges=edges,
    )


def get_lineage(
    beat_id: str,
    project_id: str,
    take_id: Optional[str] = None,
) -> Optional[Lineage]:
    """Build a Lineage for a beat (or for one take).

    Returns None when no shot file can be found for `beat_id`.

    Manifest mode — `take_id` is given and `_resolve_take_by_id` resolves it:
    returns the result of `_build_manifest_lineage` for that take (a flat
    node list with `edges=[]`; for takes lacking usable inputs this is a
    single placeholder note). Used by the inspector's LineageStrip.

    Graph mode — no `take_id`, or the `take_id` could not be resolved:
    returns the take-graph from `_build_beat_lineage`. If the
    `parent_take_id` walk found an ancestry chain for `take_id`, only that
    chain is rendered (chain layout); otherwise all of the beat's takes are
    rendered with siblings + outputs + eval edges. Used by the
    LineageExplorer DAG canvas.
    """
    shot, slug = _resolve_shot(beat_id, project_id)
    if shot is None:
        return None

    takes = shot.get("takes") or []
    if not isinstance(takes, list):
        takes = []

    chain_mode = False
    if take_id:
        # Narrow to the target's ancestry when the chain walk knows the id.
        ancestry = _walk_parent_chain(takes, take_id, beat_id)
        if ancestry:
            takes, chain_mode = ancestry, True

        # Resolution runs against the (possibly narrowed) take list.
        match, match_idx = _resolve_take_by_id(takes, beat_id, take_id)
        if match is not None:
            recorded_id = match.get("take_id")
            if recorded_id:
                manifest_take_id = str(recorded_id)
            else:
                manifest_take_id = _resolve_take_id(match, beat_id, match_idx or 0)
            return _build_manifest_lineage(
                beat_id=beat_id,
                take=match,
                take_id=manifest_take_id,
                shot=shot,
                project_id=slug,
            )
        # Unresolved take_id: fall through to the graph builder below.

    return _build_beat_lineage(
        beat_id=beat_id,
        resolved_project_id=slug,
        raw_takes=takes,
        is_chain=chain_mode,
        shot=shot,
    )


# Public API: `get_lineage` is the only name exported via `import *`; every
# other definition in this module is an underscore-prefixed private helper.
__all__ = ["get_lineage"]
