#!/usr/bin/env python3
"""One-shot migration: backfill `episode_id` into shot JSON files.

For every `projects/<project>/state/visual/shots/<SHOT>.json`, if the file
has no `episode_id` (or only an empty one), derive it from the filename prefix
(substring before the first underscore) and write the field back. `REF_*`
files are reference assets rather than shots and are left untouched.

Idempotent: rerunning is a no-op for files that already have `episode_id`.

Usage:
    python recoil/pipeline/tools/backfill_episode_id.py [--dry-run] [--project <id>]
"""
from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path

from recoil.core.atomic_write import atomic_write_json
from recoil.core.paths import projects_root, ProjectPaths

# Filenames eligible for scanning: a leading ASCII letter followed by ASCII
# letters, digits or underscores, ending in ".json". Names that start with a
# digit, underscore or dot, or that contain any other character, do not match.
SHOT_FILE_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*\.json$")


def _derive(stem: str) -> str:
    # Same convention as `_derive_episode_id` stage 4: prefix before first _
    return stem.split("_")[0]


def _shot_dirs(project_filter: str | None) -> list[Path]:
    """Collect the shots directory of each project under the projects root.

    Args:
        project_filter: When truthy, only the project directory whose name
            equals this value is considered; otherwise every project is.

    Returns:
        The shots directories that exist, in sorted project-directory order.
        Empty when the projects root itself does not exist.
    """
    root = projects_root()
    if not root.exists():
        return []

    found: list[Path] = []
    for candidate in sorted(root.iterdir()):
        wanted = not project_filter or candidate.name == project_filter
        if not (candidate.is_dir() and wanted):
            continue
        shots_dir = ProjectPaths.from_root(candidate).shots_dir
        # A project may not have a shots directory; leave it out if so.
        if shots_dir.is_dir():
            found.append(shots_dir)
    return found


def backfill(dry_run: bool, project_filter: str | None) -> tuple[int, int, int]:
    """Add a filename-derived `episode_id` to shot JSON files that lack one.

    Args:
        dry_run: When true, report what would change without writing any file.
        project_filter: Forwarded to `_shot_dirs`; restricts the scan to a
            single project when set.

    Returns:
        `(scanned, updated, skipped)`. `scanned` counts every file whose name
        matches `SHOT_FILE_RE`. `updated` counts files that were rewritten
        (or, in a dry run, would have been). `skipped` counts files left
        alone because they already hold a truthy `episode_id` or are `REF_*`
        reference assets. Files passed over with a warning on stderr
        (unreadable, not a JSON object, no derivable id) are counted only in
        `scanned`.
    """
    scanned = 0
    updated = 0
    skipped = 0
    for shots_dir in _shot_dirs(project_filter):
        for path in sorted(shots_dir.iterdir()):
            if not SHOT_FILE_RE.match(path.name):
                continue
            scanned += 1
            stem = path.stem
            # REF_* files are reference assets, not shots — beats.py explicitly
            # excludes them from episode derivation. Skip before reading so we
            # don't inject a phantom episode_id="REF" into the episode tree.
            if stem.startswith("REF_"):
                skipped += 1
                continue
            try:
                # Decode as UTF-8 explicitly: the default would be the
                # platform locale encoding, which can reject or silently
                # mis-decode non-ASCII text that would then be written back.
                data = json.loads(path.read_text(encoding="utf-8"))
            except Exception as exc:
                # Deliberately broad: one unreadable or malformed file must
                # not abort the migration of every other file.
                print(f"  ! skip (read error) {path}: {exc}", file=sys.stderr)
                continue
            if not isinstance(data, dict):
                print(f"  ! skip (not a JSON object) {path}", file=sys.stderr)
                continue
            # A missing, null or empty episode_id all count as "not set".
            if data.get("episode_id"):
                skipped += 1
                continue
            ep_id = _derive(stem)
            if not ep_id:
                print(f"  ! skip (no derivable id) {path}", file=sys.stderr)
                continue
            if dry_run:
                print(f"  + would set episode_id={ep_id!r} in {path}")
            else:
                data["episode_id"] = ep_id
                atomic_write_json(path, data)
                print(f"  + set episode_id={ep_id!r} in {path}")
            updated += 1
    return scanned, updated, skipped


def main() -> int:
    """Parse the command-line flags, run the backfill and print a summary.

    Returns:
        Always 0.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--dry-run", action="store_true")
    cli.add_argument(
        "--project",
        default=None,
        help="Only process this project (default: all)",
    )
    opts = cli.parse_args()

    scanned, updated, skipped = backfill(opts.dry_run, opts.project)

    # Word the summary so a dry run never claims that files were changed.
    if opts.dry_run:
        verb = "would update"
    else:
        verb = "updated"
    print(f"\nScanned {scanned} files; {verb} {updated}; skipped {skipped} "
          f"already-set.")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
