"""Tests for take-rooted lineage filtering (`get_lineage(take_id=...)`).

The default beat-rooted behavior is covered mainly by tests in
test_engine_routes.py; only a back-compat check for it lives here. This file
focuses on the take_id chain-walk path added 2026-05-12.
"""
from __future__ import annotations

from typing import Optional
from unittest.mock import patch

import pytest

from recoil.api.adapters.lineage import _walk_parent_chain, get_lineage


# ── Unit tests for the chain walker ──────────────────────────────────────────


def _take(tid: str, parent: Optional[str] = None, **extra) -> dict:
    """Build a take dict with the minimum the walker cares about."""
    d = {"take_id": tid}
    if parent is not None:
        d["parent_take_id"] = parent
    d.update(extra)
    return d


def test_walk_parent_chain_single_take_no_parent():
    """A lone parentless take yields a one-element chain containing itself."""
    result = _walk_parent_chain([_take("T1")], "T1", "BEAT")
    assert [entry["take_id"] for entry in result] == ["T1"]


def test_walk_parent_chain_linear_three_deep():
    """Walking T3 → T2 → T1 (root) returns the chain ordered root → target."""
    takes = [_take("T1"), _take("T2", parent="T1"), _take("T3", parent="T2")]
    ids = [entry["take_id"] for entry in _walk_parent_chain(takes, "T3", "BEAT")]
    assert ids == ["T1", "T2", "T3"]


def test_walk_parent_chain_siblings_excluded():
    """Takes sharing a parent are not part of each other's ancestry chain."""
    takes = [
        _take("T_KEY"),
        _take("T_V1", parent="T_KEY"),
        _take("T_V2", parent="T_KEY"),  # sibling of T_V1 — must be skipped
    ]
    ids = [entry["take_id"] for entry in _walk_parent_chain(takes, "T_V1", "BEAT")]
    assert ids == ["T_KEY", "T_V1"]
    # Spelled out separately so a regression explicitly names the leaked sibling.
    assert "T_V2" not in ids


def test_walk_parent_chain_missing_target_returns_empty():
    """An unknown target id yields an empty chain."""
    takes = [_take("T1"), _take("T2", parent="T1")]
    assert _walk_parent_chain(takes, "T_DOES_NOT_EXIST", "BEAT") == []


def test_walk_parent_chain_cycle_safe():
    """Malformed data with a parent cycle shouldn't loop forever.

    A regression here would hang this test rather than fail it, since no
    timeout is enforced.
    """
    takes = [
        _take("A", parent="B"),
        _take("B", parent="A"),  # cycle: A → B → A
    ]
    chain = _walk_parent_chain(takes, "A", "BEAT")
    ids = [t["take_id"] for t in chain]
    # Each node is visited exactly once despite the cycle.
    assert sorted(ids) == ["A", "B"]
    # The chain is ordered root-of-visit → target, so the requested take is last.
    assert ids[-1] == "A"


def test_walk_parent_chain_parent_pointing_at_missing_id():
    """A dangling parent_take_id ends the walk at the last resolvable take."""
    orphan = _take("T2", parent="T_MISSING")
    ids = [entry["take_id"] for entry in _walk_parent_chain([orphan], "T2", "BEAT")]
    assert ids == ["T2"]


def test_walk_parent_chain_synthesized_id_for_legacy_takes():
    """Takes without `take_id` get synthesized IDs from take_number/index."""
    legacy_takes = [
        {"take_number": 1},  # no take_id; expected to resolve to "BEAT_T001"
        {"take_number": 2, "parent_take_id": "BEAT_T001"},
    ]
    chain = _walk_parent_chain(legacy_takes, "BEAT_T002", "BEAT")
    # The parent pointer must resolve against the synthesized "BEAT_T001",
    # and both takes must come back in root → target order.
    assert [entry["take_number"] for entry in chain] == [1, 2]


# ── Integration tests for get_lineage(take_id=...) ───────────────────────────


@pytest.fixture
def _stub_shot():
    """Yield a canned three-take shot while `_resolve_shot` is patched to return it.

    Layout: keyframe T_KEY (no parent) with two children, T_V1 and T_V2, that
    both point at T_KEY and so are siblings. Only T_V1 carries an
    ``inputs_snapshot``. The patched resolver returns ``(shot, "tartarus")``,
    so no disk IO happens.
    """
    keyframe = {
        "take_id": "T_KEY",
        "take_number": 1,
        "file_path": "output/previs/keyframe.png",
        "compiled_prompt": "Establishing wide shot.",
        "cost_usd": 0.04,
        "model": "gemini-3.1-flash-image-preview",
    }
    v1_snapshot = {
        "source": "snapshot",
        "prompt_flat": "Establishing wide shot.",
        "prompt_layers": [],
        "refs_used": [],
        "routing": {
            "pipeline": "i2v",
            "model": "kling-o3",
            "tier": "production",
            "reason": "test",
        },
        "parent_take_id": "T_KEY",
        "builder_name": "video_i2v",
        "bible_version": "test123",
        "config_hash": "",
        "generation_params": None,
        "bible_files": None,
    }
    video_v1 = {
        "take_id": "T_V1",
        "take_number": 2,
        "parent_take_id": "T_KEY",
        "file_path": "output/video/v1.mp4",
        "cost_usd": 0.5,
        "model": "kling-o3",
        "inputs_snapshot": v1_snapshot,
    }
    video_v2 = {
        "take_id": "T_V2",
        "take_number": 3,
        "parent_take_id": "T_KEY",
        "file_path": "output/video/v2.mp4",
        "cost_usd": 0.5,
        "model": "kling-o3",
    }
    shot = {
        "shot_id": "BEAT",
        "model": "gemini-3.1-flash-image-preview",
        "takes": [keyframe, video_v1, video_v2],
    }
    resolver_target = "recoil.api.adapters.lineage._resolve_shot"
    with patch(resolver_target, return_value=(shot, "tartarus")):
        yield shot


def test_get_lineage_take_rooted_returns_manifest(_stub_shot):
    """With take_id set, get_lineage returns an edge-less manifest for that take."""
    lineage = get_lineage("BEAT", project_id="tartarus", take_id="T_V1")
    assert lineage is not None
    assert lineage.edges == []
    node_kinds = {node.kind for node in lineage.nodes}
    for expected_kind in ("parent_take", "prompt", "params"):
        assert expected_kind in node_kinds


def test_get_lineage_take_rooted_manifest_has_parent_take(_stub_shot):
    """T_V1's manifest has a parent_take node pointing back at T_KEY."""
    lineage = get_lineage("BEAT", project_id="tartarus", take_id="T_V1")
    assert lineage is not None
    # Use next() with a default so a missing node fails as a readable assertion
    # instead of a bare StopIteration escaping the test.
    parent = next(
        (n for n in (lineage.nodes or []) if n.kind == "parent_take"),
        None,
    )
    assert parent is not None, "take-rooted manifest has no parent_take node"
    assert parent.parent_take_id == "T_KEY"


def test_get_lineage_beat_rooted_unchanged_when_take_id_omitted(_stub_shot):
    """Back-compat: no take_id → full beat graph with sibling nodes."""
    lineage = get_lineage("BEAT", project_id="tartarus")
    assert lineage is not None
    # Restrict to per-take nodes, i.e. those whose id contains "step".
    step_kinds = [
        node.kind
        for node in (lineage.nodes or [])
        if "step" in (node.id or "")
    ]
    # The first take is tagged "step"; the other two are "sibling".
    assert step_kinds.count("step") == 1
    assert step_kinds.count("sibling") == 2


def test_get_lineage_unresolved_take_id_falls_back_to_beat_graph(_stub_shot):
    """If take_id doesn't match any take, return beat-rooted graph (not 404).

    A race between generation completion and frontend fetch could surface a
    take_id the shot file doesn't know about yet. Better to show the beat
    overview than an empty graph.
    """
    lineage = get_lineage("BEAT", project_id="tartarus", take_id="T_NOT_IN_SHOT")
    assert lineage is not None
    # Full beat graph: every take in the stub (T_KEY, T_V1, T_V2) has a step node.
    expected_ids = {"T_KEY_step", "T_V1_step", "T_V2_step"}
    step_node_ids = {
        node.id for node in (lineage.nodes or []) if "step" in (node.id or "")
    }
    assert expected_ids.issubset(step_node_ids)
