"""Backfill parent_take_id on video takes in existing shot JSONs.

For each shot JSON, find video takes that are missing parent_take_id and link
them to the most recent image take that appears before them in the takes list.

Algorithm
---------
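1. Classify each take as image or video:
   - image_takes: pipeline in {"keyframe", "image", "image_t2i", "still", "previz"}
                  OR pipeline missing/null AND the take has a non-empty
                  file_path (falling back to output_path) that doesn't end
                  in ".mp4"
   - video_takes: pipeline in {"video", "i2v", "t2v", "video_i2v",
                  "multi_shot", "sequence_sequential"}
                  OR file_path (falling back to output_path) ends with
                  ".mp4" (case-insensitive)
   A take matching both rules is treated as a video take; takes matching
   neither (e.g. audio) are left untouched.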

2. For each video take that has no "parent_take_id" key (a key that is
   already present, even if null, is preserved as-is):
   - Find the most recent image take with a lower list index.
   - Resolve its ID using _resolve_take_id (inlined from lineage.py).
   - Set parent_take_id.  If no image take precedes it, leave the take
     unchanged (no parent_take_id key is added).

3. Write atomically: temp file → os.replace().

Usage
-----
    python3 backfill_parent_take_id.py [--dry-run] [--project <slug>] [--verbose]
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import tempfile
from pathlib import Path

# Allow import of recoil.core.paths without a full package install.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent.parent))


# ---------------------------------------------------------------------------
# _resolve_take_id — inlined from recoil/api/adapters/lineage.py
# Do NOT import from lineage; keep this self-contained.
# ---------------------------------------------------------------------------

def _resolve_take_id(raw: dict, beat_id: str, idx: int) -> str:
    """Return a stable identifier for the take dict *raw*.

    A truthy ``take_id`` field is used verbatim (stringified).  Otherwise an
    id of the form ``<beat_id>_T<NNN>`` is synthesized from the first truthy
    value among ``take_number`` and ``take_num``, falling back to the list
    position *idx*.  Because truthiness is used, a ``take_number`` of 0 also
    falls through to the next candidate.

    Raises:
        ValueError / TypeError: if the chosen number cannot be passed to int().
    """
    explicit_id = raw.get("take_id")
    if explicit_id:
        return str(explicit_id)

    candidates = (raw.get("take_number"), raw.get("take_num"))
    number = next((value for value in candidates if value), idx)
    return f"{beat_id}_T{int(number):03d}"


# ---------------------------------------------------------------------------
# Classification helpers
# ---------------------------------------------------------------------------

# Pipeline names whose takes count as still/image outputs.
_IMAGE_PIPELINES = {
    "image",
    "image_t2i",
    "keyframe",
    "previz",
    "still",
}

# Pipeline names whose takes count as video outputs.
_VIDEO_PIPELINES = {
    "i2v",
    "multi_shot",
    "sequence_sequential",
    "t2v",
    "video",
    "video_i2v",
}


def _is_video_take(take: dict) -> bool:
    """Return True if *take* looks like a video output.

    A take is video when its ``pipeline`` is one of ``_VIDEO_PIPELINES``, or
    when its path (``file_path``, falling back to ``output_path``) ends in
    ``.mp4`` (case-insensitive).

    Non-string ``pipeline`` values (e.g. a list or dict from a malformed
    JSON file) are treated as "no known pipeline" instead of raising
    ``TypeError`` on the set lookup, so one bad take cannot abort a batch run.
    """
    pipeline = take.get("pipeline")
    if isinstance(pipeline, str) and pipeline in _VIDEO_PIPELINES:
        return True
    file_path = str(take.get("file_path") or take.get("output_path") or "")
    return file_path.lower().endswith(".mp4")


def _is_image_take(take: dict) -> bool:
    """Return True if *take* looks like a still/image output.

    A take is an image when its ``pipeline`` is one of ``_IMAGE_PIPELINES``.
    When ``pipeline`` is missing or null, the take is classified by its path
    (``file_path``, falling back to ``output_path``): a non-empty path that
    does not end in ``.mp4`` (case-insensitive) counts as an image.  Any
    other pipeline value, including non-string values from malformed JSON,
    yields False rather than raising ``TypeError``.
    """
    pipeline = take.get("pipeline")
    if isinstance(pipeline, str) and pipeline in _IMAGE_PIPELINES:
        return True
    if pipeline is not None:
        # Known-but-non-image pipeline (or garbage): not an image.
        return False
    # No pipeline key — classify by extension.
    file_path = str(take.get("file_path") or take.get("output_path") or "")
    return bool(file_path) and not file_path.lower().endswith(".mp4")


# ---------------------------------------------------------------------------
# Core backfill logic for a single shot dict
# ---------------------------------------------------------------------------

def backfill_shot(shot: dict, verbose: bool = False) -> tuple[dict, int]:
    """Link un-parented video takes in *shot* to their preceding image take.

    Returns ``(updated_shot, changes_count)``.

    For every video take that has no ``parent_take_id`` key, the most recent
    image take earlier in ``shot["takes"]`` is located and its id (via
    ``_resolve_take_id``, using the shot's ``shot_id`` as the beat id) is
    stored as ``parent_take_id``.  Takes that already carry the key (even
    set to null) are preserved, and video takes with no preceding image take
    are left unchanged.  A take classified as both image and video is
    treated as video.

    Does not mutate ``shot`` in place.  When at least one take changes, the
    result is a shallow copy of the top-level dict with a new ``takes`` list
    in which updated takes are new dicts and all other entries are the
    originals.  When nothing changes — including when ``shot`` is not a dict
    or its ``takes`` is not a list — the original object is returned with 0.

    Args:
        shot: Parsed shot JSON.  Any JSON value is accepted; only dicts with
            a list under ``"takes"`` are processed.
        verbose: If True, print one line per take that gets linked.
    """
    if not isinstance(shot, dict):
        # A shot file whose top level is not an object (e.g. a JSON array)
        # has nothing to backfill; returning it unchanged keeps one bad file
        # from aborting a batch run with AttributeError.
        return shot, 0

    shot_id = shot.get("shot_id", "UNKNOWN")
    raw_takes: list = shot.get("takes") or []
    if not isinstance(raw_takes, list):
        return shot, 0

    new_takes: list = []
    changes = 0

    # Index of the last image take seen so far; -1 means none yet.
    last_image_idx: int = -1

    for idx, take in enumerate(raw_takes):
        if not isinstance(take, dict):
            new_takes.append(take)
            continue

        if not _is_video_take(take):
            # Image takes become the candidate parent for later video takes;
            # unclassified takes (e.g. audio, unknown) pass through untouched.
            if _is_image_take(take):
                last_image_idx = idx
            new_takes.append(take)
            continue

        if "parent_take_id" in take or last_image_idx == -1:
            # Either existing linkage (even null) to preserve, or no image
            # take precedes this video take — leave untouched.
            new_takes.append(take)
            continue

        # Link to the most recent preceding image take.
        image_take = raw_takes[last_image_idx]
        parent_id = _resolve_take_id(image_take, shot_id, last_image_idx)
        updated = dict(take)
        updated["parent_take_id"] = parent_id
        new_takes.append(updated)
        changes += 1
        if verbose:
            vid_id = _resolve_take_id(take, shot_id, idx)
            print(
                f"  {shot_id}: video take {vid_id} → parent_take_id={parent_id}"
            )

    if changes == 0:
        return shot, 0

    updated_shot = dict(shot)
    updated_shot["takes"] = new_takes
    return updated_shot, changes


# ---------------------------------------------------------------------------
# File I/O helpers
# ---------------------------------------------------------------------------

def _shots_dir(projects_root: Path, project_slug: str) -> Path:
    """Return the shots directory for *project_slug*.

    The location comes entirely from
    ``recoil.core.paths.ProjectPaths.for_project(project_slug).shots_dir``;
    *projects_root* is currently not consulted.

    NOTE(review): confirm ProjectPaths resolves against the same root that
    main() enumerates project slugs from.
    """
    from recoil.core.paths import ProjectPaths

    project_paths = ProjectPaths.for_project(project_slug)
    return project_paths.shots_dir


def _write_atomic(path: Path, data: dict) -> None:
    """Replace *path* with *data* serialized as pretty-printed JSON.

    The JSON (2-space indent, non-ASCII kept as-is, trailing newline) is
    written to a temporary file in the same directory, flushed and fsync'd,
    given the permission bits of the existing *path*, and then moved over
    *path* with ``os.replace``.  Keeping the temp file in the same directory
    keeps the rename on one filesystem, where ``os.replace`` is atomic on
    POSIX, so readers see either the old file or the complete new one.

    On any failure (including KeyboardInterrupt) the temporary file is
    removed on a best-effort basis and the exception is re-raised; *path*
    is left as it was.
    """
    fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
            f.write("\n")
            f.flush()
            # Make the bytes durable before the rename publishes them;
            # otherwise a crash can leave an empty/truncated file in place.
            os.fsync(f.fileno())
        # mkstemp creates the file with mode 0600; carry over the original
        # file's mode so a rewrite doesn't silently make it owner-only.
        try:
            os.chmod(tmp_path, os.stat(path).st_mode & 0o7777)
        except FileNotFoundError:
            pass  # New file: keep mkstemp's default mode.
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of temp file on failure or interruption.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Command-line entry point: backfill parent_take_id across shot JSONs.

    Resolves the projects root, selects either the ``--project`` slug or
    every sub-directory whose name does not start with ``_`` or ``.``, runs
    ``backfill_shot`` on each ``*.json`` in each project's shots directory,
    and writes changed files with ``_write_atomic`` unless ``--dry-run`` is
    set.  Prints a one-line summary at the end.

    Unreadable or invalid-JSON shot files are reported as warnings and
    skipped.  Files that cannot be processed or written are also reported
    and skipped so the rest of the batch still runs, but they make the
    script exit with status 1.  Failing to resolve or list the projects
    root exits with status 1 immediately.
    """
    parser = argparse.ArgumentParser(
        description="Backfill parent_take_id on video takes in shot JSONs."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would change without writing any files.",
    )
    parser.add_argument(
        "--project",
        metavar="SLUG",
        default=None,
        help="Limit to one project slug.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print each take being updated.",
    )
    args = parser.parse_args()

    try:
        # Imported inside the try so a missing/broken package is reported
        # with the same clean error as a resolution failure.
        from recoil.core.paths import projects_root as _projects_root  # noqa: PLC0415
        projects_root = _projects_root()
    except Exception as exc:
        print(f"ERROR: cannot resolve projects root: {exc}", file=sys.stderr)
        sys.exit(1)

    # Determine which project slugs to process.
    if args.project:
        slugs = [args.project]
    else:
        try:
            slugs = sorted(
                p.name for p in projects_root.iterdir()
                if p.is_dir() and not p.name.startswith(("_", "."))
            )
        except OSError as exc:
            print(
                f"ERROR: cannot list projects in {projects_root}: {exc}",
                file=sys.stderr,
            )
            sys.exit(1)

    total_files = 0
    total_changes = 0
    total_shots_updated = 0
    failures = 0

    for slug in slugs:
        shots_dir = _shots_dir(projects_root, slug)
        if not shots_dir.exists():
            if args.verbose:
                print(f"[{slug}] no shots dir, skipping")
            continue

        for shot_path in sorted(shots_dir.glob("*.json")):
            total_files += 1
            try:
                with open(shot_path, encoding="utf-8") as f:
                    shot = json.load(f)
            except (json.JSONDecodeError, OSError) as e:
                print(f"WARN: could not read {shot_path}: {e}", file=sys.stderr)
                continue

            try:
                updated_shot, changes = backfill_shot(shot, verbose=args.verbose)
            except (TypeError, ValueError) as e:
                # e.g. a take_number that cannot be converted with int().
                print(f"WARN: could not process {shot_path}: {e}", file=sys.stderr)
                failures += 1
                continue

            if changes == 0:
                continue

            if args.dry_run:
                print(
                    f"[DRY RUN] {slug}/{shot_path.name}: "
                    f"{changes} video take(s) would be updated"
                )
            else:
                try:
                    _write_atomic(shot_path, updated_shot)
                except OSError as e:
                    print(f"WARN: could not write {shot_path}: {e}", file=sys.stderr)
                    failures += 1
                    continue
                if args.verbose:
                    print(
                        f"[WROTE] {slug}/{shot_path.name}: "
                        f"{changes} video take(s) updated"
                    )

            # Counted only once the change is written (or would be, in dry run).
            total_changes += changes
            total_shots_updated += 1

    suffix = " (dry run)" if args.dry_run else ""
    print(
        f"\nDone{suffix}. "
        f"Scanned {total_files} shot file(s) across {len(slugs)} project(s). "
        f"Updated {total_shots_updated} shot(s), "
        f"{total_changes} video take(s) linked."
    )

    if failures:
        print(
            f"ERROR: {failures} shot file(s) could not be processed or written; "
            f"see warnings above.",
            file=sys.stderr,
        )
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script,
    # not when it is imported as a module.
    main()
