#!/usr/bin/env python3
"""
reconcile_take_paths.py

Reconcile stale video take paths in shot state JSON files.

Shot state files live at:
  projects/{project}/_pipeline/state/visual/shots/*.json

Each shot file has a `takes` array where each take may have a `file_path`
field pointing to a video (.mp4), stored relative to the project root. When
video files are moved (e.g. into _orphans/) or renamed to the PASS naming
convention, those stored paths become stale.

This script finds stale paths, tries to resolve them, and optionally patches
the JSON files in place.

Matching strategy (in priority order):
  1. Exact basename match in a non-_orphans location
  2. Exact basename match in _orphans/
  3. PASS-convention match: EP{NNN}_PASS_*_SH{N}_*_take{T}.mp4, where EP{NNN}
     and shot number {N} come from the shot's `shot_id` (e.g. EP001_SH32) and
     take number {T} is parsed from the stale `shot_NNN_take{T}.mp4` basename

Also reconciles `gate_results.video_path` when present.

Usage:
  python3 tools/reconcile_take_paths.py --project tartarus [--dry-run | --apply]
  python3 tools/reconcile_take_paths.py --all [--dry-run | --apply]
"""

import argparse
import json
import os
import re
import sys
from pathlib import Path

# Make the `recoil` package importable when this file is run directly as a
# script: _HERE is the directory containing this file and _RECOIL_ROOT is its
# parent, which is prepended to sys.path (once) so the import below resolves.
# This must run before that import, hence the E402 suppression on it.
_HERE = os.path.dirname(os.path.abspath(__file__))
_RECOIL_ROOT = os.path.dirname(_HERE)
if _RECOIL_ROOT not in sys.path:
    sys.path.insert(0, _RECOIL_ROOT)

from recoil.core.paths import projects_root, ProjectPaths  # noqa: E402

# ── Configuration ────────────────────────────────────────────────────────────

# Directory holding one subdirectory per project (see reconcile_project and
# the --all scan in main); kept as str for use with os.path helpers.
PROJECTS_ROOT = str(projects_root())



# ── File index ───────────────────────────────────────────────────────────────

def build_video_index(project_root):
    """
    Index every .mp4 under the project's renders directory by basename.

    Args:
        project_root: Project directory; stored rel_paths are relative to it.

    Returns:
        dict mapping basename -> list of (rel_path, is_orphan) tuples, where
        is_orphan is True when the substring "_orphans" occurs anywhere in
        rel_path. Each list is ordered non-orphan entries first, then by
        rel_path. Returns {} when the renders directory does not exist.
        Directories whose names start with "." are not descended into.
    """
    renders = str(ProjectPaths.from_root(Path(project_root)).renders_dir)
    if not os.path.isdir(renders):
        return {}

    by_name = {}
    for current_dir, subdirs, files in os.walk(renders):
        # Prune in place so os.walk never enters hidden dirs (.ruff_cache, ...).
        subdirs[:] = [name for name in subdirs if not name.startswith(".")]
        for name in files:
            if name.endswith(".mp4"):
                rel = os.path.relpath(os.path.join(current_dir, name), project_root)
                by_name.setdefault(name, []).append((rel, "_orphans" in rel))

    # False < True, so live (non-orphan) copies sort ahead of orphaned ones.
    for group in by_name.values():
        group.sort(key=lambda entry: (entry[1], entry[0]))
    return by_name


# ── Shot number extraction ────────────────────────────────────────────────────

def parse_old_shot_num(file_path):
    """
    Extract (shot_num, take_num) from an old-convention video path.

    Only the basename is inspected; it must look like "shot_032_take9.mp4"
    (leading zeros and a letter suffix after the shot digits are allowed).
    For example 'output/video/ep_001/shot_032_take9.mp4' -> (32, 9).

    Returns:
        (int, int) on success, otherwise (None, None).
    """
    found = re.match(
        r"shot_0*(\d+)(?:[A-Za-z]*)_take(\d+)\.mp4$", os.path.basename(file_path)
    )
    if found is None:
        return None, None
    shot_digits, take_digits = found.groups()
    return int(shot_digits), int(take_digits)


def parse_shot_id_num(shot_id):
    """
    Split a shot id such as 'EP001_SH32' or 'EP001_SH32A' into its parts.

    Returns:
        (episode, shot_num, suffix), e.g. ('EP001', 32, '') or
        ('EP001', 32, 'A'); (None, None, None) when shot_id does not match
        the EP<digits>_SH<digits><letters> form exactly.
    """
    parsed = re.match(r"(EP\d+)_SH(\d+)([A-Za-z]*)$", shot_id)
    if not parsed:
        return None, None, None
    episode, number, suffix = parsed.groups()
    return episode, int(number), suffix


# ── PASS-convention matching ──────────────────────────────────────────────────

def find_pass_match(shot_id, stale_path, video_index):
    """
    Try to match a PASS-convention filename for the given shot.
    PASS filenames look like: EP001_PASS_014_SH30_31_A_JADE_take1.mp4

    An indexed basename matches (case-insensitively) when it:
      - starts with "{episode}_PASS_" (episode taken from shot_id, e.g. EP001)
      - contains "_SH{shot_num}" (shot number from shot_id, leading zeros
        allowed) that is NOT immediately followed by another digit, so shot 3
        does not match SH30 / SH31
      - ends with "_take{N}.mp4", N being the take number parsed from
        stale_path (e.g. shot_032_take9.mp4 -> 9)

    Args:
        shot_id: Shot identifier such as "EP001_SH32".
        stale_path: The stale stored path; its basename must follow the old
            "shot_NNN_takeN.mp4" convention for a take number to be found.
        video_index: basename -> [(rel_path, is_orphan), ...] mapping as
            built by build_video_index.

    Returns:
        The rel_path of the best candidate (non-orphan first, then the
        lexicographically smallest path), or None when shot_id/stale_path
        cannot be parsed or no indexed file matches.
    """
    # The shot-id letter suffix (e.g. "A" in EP001_SH32A) is not used for
    # matching; non-underscore characters after the shot digits are tolerated.
    ep, shot_num, _suffix = parse_shot_id_num(shot_id)
    if ep is None:
        return None

    _, take_num = parse_old_shot_num(stale_path)
    if take_num is None:
        return None

    # (?![0-9]) stops SH3 from matching SH30; (?:_.*)? allows zero or more
    # extra segments between the SH token and the take suffix.
    sh_pattern = re.compile(
        rf"^{re.escape(ep)}_PASS_.*_SH0*{shot_num}(?![0-9])[^_]*(?:_.*)?_take{take_num}\.mp4$",
        re.IGNORECASE,
    )

    candidates = [
        entry
        for basename, entries in video_index.items()
        if sh_pattern.match(basename)
        for entry in entries
    ]
    if not candidates:
        return None

    # Key is (is_orphan, rel_path): False sorts first, so live copies win.
    return min(candidates, key=lambda entry: (entry[1], entry[0]))[0]


# ── Path resolution ───────────────────────────────────────────────────────────

def resolve_stale_path(shot_id, stale_path, project_root, video_index):
    """
    Look up a live file corresponding to a stale stored path.

    Resolution order:
      1. Same basename outside _orphans  -> method "exact_non_orphan"
      2. Same basename inside _orphans   -> method "orphan_relocation"
      3. PASS-convention filename match  -> method "pass_convention"

    project_root is accepted for interface compatibility but not used here.

    Returns:
        (new_rel_path, method), or (None, None) if nothing matches.
    """
    same_name = video_index.get(os.path.basename(stale_path))
    if same_name:
        # build_video_index orders each group non-orphan first.
        top = same_name[0]
        return top[0], ("orphan_relocation" if top[1] else "exact_non_orphan")

    pass_match = find_pass_match(shot_id, stale_path, video_index)
    return (pass_match, "pass_convention") if pass_match else (None, None)


# ── Shot file processing ──────────────────────────────────────────────────────

def _stale_path_entry(shot_id, shot_file_rel, field, stored_path, project_root, video_index):
    """
    Build a result dict for one stored path if it is a stale .mp4 reference.

    Returns None when stored_path is not a string ending in ".mp4" (including
    a missing or null value), or when it exists relative to project_root.
    """
    if not isinstance(stored_path, str) or not stored_path.endswith(".mp4"):
        return None
    if os.path.exists(os.path.join(project_root, stored_path)):
        return None  # live, nothing to reconcile
    new_path, method = resolve_stale_path(shot_id, stored_path, project_root, video_index)
    return {
        "shot_file": shot_file_rel,
        "field": field,
        "old_path": stored_path,
        "new_path": new_path,
        "method": method,
        "resolved": new_path is not None,
    }


def _atomic_write_json(path, data):
    """Write data as indented JSON to path via a sibling .tmp file + os.replace."""
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            f.write("\n")
            f.flush()
            # Make the contents durable before the rename publishes them.
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written .tmp next to the shot file; re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def process_shot_file(shot_path, project_root, video_index, apply):
    """
    Find (and optionally patch) stale .mp4 references in one shot JSON file.

    Checks every takes[].file_path and gate_results.video_path that is a
    string ending in ".mp4" and does not exist relative to project_root.
    Missing, null or non-string values, and non-dict takes, are ignored.

    Args:
        shot_path: Path of the shot state JSON file.
        project_root: Directory the stored paths are relative to.
        video_index: basename -> [(rel_path, is_orphan), ...] mapping.
        apply: When True, resolved paths are written back to the file
            (temp file + os.replace). When False the file is never modified.

    Returns:
        A list of result dicts, one per stale path found:
          { 'shot_file', 'field', 'old_path', 'new_path', 'method', 'resolved' }

    Raises:
        OSError: if the file cannot be read or written.
        ValueError: (json.JSONDecodeError / UnicodeDecodeError) if the file
            is not valid UTF-8 JSON.
    """
    with open(shot_path, encoding="utf-8") as f:
        shot = json.load(f)

    # Fall back to the file stem when shot_id is missing, empty or null.
    shot_id = shot.get("shot_id") or os.path.splitext(os.path.basename(shot_path))[0]
    shot_file_rel = os.path.relpath(shot_path, project_root)
    results = []
    patched = False

    # --- takes array ("takes": null is treated as empty) ---
    for take in shot.get("takes") or []:
        if not isinstance(take, dict):
            continue
        result = _stale_path_entry(
            shot_id, shot_file_rel, "takes[].file_path",
            take.get("file_path"), project_root, video_index,
        )
        if result is None:
            continue
        results.append(result)
        if apply and result["new_path"]:
            take["file_path"] = result["new_path"]
            patched = True

    # --- gate_results.video_path ---
    gate_results = shot.get("gate_results")
    if isinstance(gate_results, dict):
        result = _stale_path_entry(
            shot_id, shot_file_rel, "gate_results.video_path",
            gate_results.get("video_path"), project_root, video_index,
        )
        if result is not None:
            results.append(result)
            if apply and result["new_path"]:
                gate_results["video_path"] = result["new_path"]
                patched = True

    # Only rewrite the file when at least one path was actually replaced.
    if patched:
        _atomic_write_json(shot_path, shot)

    return results


# ── Project-level scan ────────────────────────────────────────────────────────

def _print_report(project_name, apply, shot_file_count, all_results):
    """Print summary counts, then a two-line entry per stale path found."""
    stale_total = len(all_results)
    matched = sum(1 for r in all_results if r["resolved"])
    unresolvable = stale_total - matched

    mode_label = "APPLY" if apply else "DRY-RUN"
    print(f"\n[{project_name}] {mode_label}")
    print(f"  Shot files scanned : {shot_file_count}")
    print(f"  Stale paths found  : {stale_total}")
    print(f"  Matched            : {matched}")
    print(f"  Unresolvable       : {unresolvable}")

    if not all_results:
        return

    print()
    for r in all_results:
        # "MISS " keeps the same width as "MATCH" so columns line up.
        status = "MATCH" if r["resolved"] else "MISS "
        new_display = r["new_path"] if r["new_path"] else "(no match)"
        detail = r["field"] + " | " + (r["method"] or "-")
        if apply and r["resolved"]:
            action = "UPDATED"
        elif r["resolved"]:
            action = "would update"
        else:
            action = "UNRESOLVABLE"
        print(f"  [{status}] {action:12s}  {r['old_path']}")
        print(f"           -> {new_display}  ({detail})")


def reconcile_project(project_name, apply, verbose=True):
    """
    Reconcile every shot state file of one project and print a report.

    Args:
        project_name: Directory name under PROJECTS_ROOT.
        apply: When True, resolved paths are written back to the shot files.
        verbose: When False, suppresses the "no shots state dir" and
            "no mp4 files" notices; the report itself is always printed.

    Returns:
        False if the project directory is missing or any shot file could not
        be read, parsed or written (the error is reported on stderr and the
        remaining files are still processed); True otherwise.
    """
    project_root = os.path.join(PROJECTS_ROOT, project_name)
    if not os.path.isdir(project_root):
        print(f"ERROR: project not found: {project_root}", file=sys.stderr)
        return False

    pp = ProjectPaths.from_root(Path(project_root))
    shots_dir = str(pp.shots_dir)
    if not os.path.isdir(shots_dir):
        if verbose:
            print(f"  [{project_name}] No shots state dir found, skipping.")
        return True

    video_index = build_video_index(project_root)
    if verbose and not video_index:
        print(f"  [{project_name}] Warning: no mp4 files found in video output tree.")

    shot_files = sorted(
        os.path.join(shots_dir, fn)
        for fn in os.listdir(shots_dir)
        if fn.endswith(".json")
    )

    ok = True
    all_results = []
    for shot_path in shot_files:
        try:
            results = process_shot_file(shot_path, project_root, video_index, apply)
        except (OSError, ValueError) as exc:
            # One unreadable/corrupt file (JSONDecodeError is a ValueError)
            # must not abort the other files or, under --all, other projects.
            print(f"ERROR: could not process {shot_path}: {exc}", file=sys.stderr)
            ok = False
            continue
        all_results.extend(results)

    _print_report(project_name, apply, len(shot_files), all_results)
    return ok


# ── Entry point ───────────────────────────────────────────────────────────────

def main():
    """
    Command-line entry point: parse arguments, reconcile, and exit.

    Exactly one of --project NAME / --all is required; --dry-run (default)
    and --apply are mutually exclusive. Under --all, every directory in
    PROJECTS_ROOT whose name does not start with "_" or "." is processed.
    Exits 0 when every processed project reports success, 1 otherwise
    (including a missing PROJECTS_ROOT under --all).
    """
    cli = argparse.ArgumentParser(
        description="Reconcile stale video take paths in shot state JSON files."
    )
    target = cli.add_mutually_exclusive_group(required=True)
    target.add_argument("--project", metavar="NAME", help="Project name under projects/")
    target.add_argument("--all", action="store_true", help="Process all projects")

    # Both flags share the `apply` destination; dry-run is the default.
    write_mode = cli.add_mutually_exclusive_group()
    write_mode.add_argument(
        "--dry-run",
        dest="apply",
        action="store_false",
        default=False,
        help="Print what would change (default)",
    )
    write_mode.add_argument(
        "--apply",
        dest="apply",
        action="store_true",
        help="Write changes to disk (atomic)",
    )
    cli.set_defaults(apply=False)

    opts = cli.parse_args()

    if not opts.all:
        sys.exit(0 if reconcile_project(opts.project, opts.apply) else 1)

    if not os.path.isdir(PROJECTS_ROOT):
        print(f"ERROR: PROJECTS_ROOT not found: {PROJECTS_ROOT}", file=sys.stderr)
        sys.exit(1)

    projects = sorted(
        entry
        for entry in os.listdir(PROJECTS_ROOT)
        if not entry.startswith(("_", "."))
        and os.path.isdir(os.path.join(PROJECTS_ROOT, entry))
    )
    # Build the full list first so every project is processed even after a failure.
    outcomes = [reconcile_project(entry, opts.apply) for entry in projects]
    sys.exit(0 if all(outcomes) else 1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script; main() always ends the
    # process by raising SystemExit (sys.exit or argparse).
    main()
