#!/usr/bin/env python3
"""
Storyboard Versioning with Undo.

Tracks storyboard revisions per shot, enabling:
  - Version history with timestamps and change summaries
  - Undo/revert to any previous version
  - Per-shot diff between versions
  - Changed-shot detection for selective previz regeneration

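Versions are stored in a _versions/ subfolder next to the storyboard. The folder (and its
manifest.json) is derived from the storyboard's directory only, so storyboards that share a
directory also share one manifest.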

Usage:
  python3 storyboard_version.py save <storyboard.json> [--message "reason"]
  python3 storyboard_version.py list <storyboard.json>
  python3 storyboard_version.py diff <storyboard.json> [v1] [v2] [--json]
  python3 storyboard_version.py revert <storyboard.json> <version>
  python3 storyboard_version.py changed <storyboard.json> [v1] [v2]

Exit codes:
  0 = success
  1 = operation failed
  2 = file/parse error
"""

import argparse
import hashlib
import json
import os
import shutil
import sys
from datetime import datetime, timezone


def _versions_dir(storyboard_path):
    """Get the _versions/ directory for a storyboard file."""
    sb_dir = os.path.dirname(os.path.abspath(storyboard_path))
    return os.path.join(sb_dir, "_versions")


def _manifest_path(storyboard_path):
    """Return the path of ``manifest.json`` inside the storyboard's ``_versions`` folder."""
    versions_folder = _versions_dir(storyboard_path)
    return os.path.join(versions_folder, "manifest.json")


def _load_manifest(storyboard_path):
    """Load the version manifest, or return a fresh empty one.

    The manifest lives at ``_versions/manifest.json`` next to the storyboard
    (see ``_versions_dir``). A missing file yields a new manifest. So does an
    unreadable-as-JSON file or a JSON value that is not an object; in those
    two cases a warning is printed to stderr.

    Args:
        storyboard_path: Path to the storyboard JSON file.

    Returns:
        A dict guaranteed to contain ``storyboard``, ``versions`` and
        ``current_version`` keys.
    """
    mpath = _manifest_path(storyboard_path)
    if os.path.exists(mpath):
        manifest = None
        try:
            # json.dump writes ASCII by default, so UTF-8 reads any manifest
            # this module produced.
            with open(mpath, encoding="utf-8") as f:
                manifest = json.load(f)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError.
            manifest = None
        if isinstance(manifest, dict):
            # Backfill required keys so callers can index them directly.
            manifest.setdefault("storyboard", os.path.basename(storyboard_path))
            manifest.setdefault("versions", [])
            manifest.setdefault("current_version", 0)
            return manifest
        # stderr keeps the warning out of machine-readable stdout output.
        print(f"WARNING: Corrupt manifest {mpath}, reinitializing", file=sys.stderr)
    return {
        "storyboard": os.path.basename(storyboard_path),
        "versions": [],
        "current_version": 0,
    }


def _save_manifest(storyboard_path, manifest):
    """Write *manifest* to ``_versions/manifest.json``, creating the folder if needed.

    The JSON is first written to a temporary sibling file and then moved over
    the real manifest with ``os.replace``. A failure part-way through the dump
    (e.g. a non-serializable value, or the process being interrupted)
    therefore leaves the previous manifest intact. Without this,
    ``_load_manifest`` would later treat a truncated file as corrupt and
    discard the version history.

    Args:
        storyboard_path: Path to the storyboard JSON file.
        manifest: The manifest dict to persist.
    """
    vdir = _versions_dir(storyboard_path)
    os.makedirs(vdir, exist_ok=True)
    mpath = _manifest_path(storyboard_path)
    tmp_path = mpath + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2)
        os.replace(tmp_path, mpath)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the real error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _hash_shot(shot):
    """Create a content hash for a shot based on key fields.

    Hashes the fields that affect generation output — changes to these
    fields mean the shot needs re-generation (or at least re-previz).
    """
    key_fields = {
        "shot_type": shot.get("shot_type", ""),
        "camera_angle": shot.get("camera_angle", ""),
        "camera_movement": shot.get("camera_movement", ""),
        "subject": shot.get("subject", ""),
        "action": shot.get("action", ""),
        "first_frame": shot.get("first_frame", ""),
        "last_frame": shot.get("last_frame", ""),
        "hero_frame": shot.get("hero_frame") or "",
        "triptych_prompt": shot.get("triptych_prompt") or "",
        "motion_prompt": shot.get("motion_prompt", ""),
        "lighting": shot.get("lighting", ""),
        "generation_approach": shot.get("generation_approach", ""),
        "characters_in_shot": sorted(shot.get("characters_in_shot", [])),
        "width": shot.get("width", 0),
        "height": shot.get("height", 0),
    }
    blob = json.dumps(key_fields, sort_keys=True)
    return hashlib.sha256(blob.encode()).hexdigest()[:12]


def _hash_storyboard(storyboard):
    """Create a per-shot hash map for the whole storyboard."""
    hashes = {}
    for shot in storyboard.get("shots", []):
        sid = shot.get("id", 0)
        hashes[str(sid)] = _hash_shot(shot)
    return hashes


def save_version(storyboard_path, message=None):
    """Save the current storyboard as a new version.

    Copies the storyboard into the ``_versions/`` folder as
    ``<name>_vNNN.json``. Records per-shot hashes and a change summary in the
    manifest; the summary is relative to the most recently saved entry. Marks
    the new version as current.

    Args:
        storyboard_path: Path to the storyboard JSON file.
        message: Optional free-text reason stored with the version.

    Returns:
        The new version number. Returns None, after printing an error, if the
        storyboard is missing, unreadable, not valid JSON, or not a JSON
        object.
    """
    if not os.path.exists(storyboard_path):
        print(f"ERROR: Storyboard not found: {storyboard_path}")
        return None

    try:
        # Binary mode lets json detect UTF-8/16/32 (and a BOM) itself instead
        # of relying on the platform's locale encoding.
        with open(storyboard_path, "rb") as f:
            storyboard = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in {storyboard_path}: {e}")
        return None
    except (OSError, UnicodeDecodeError) as e:
        # e.g. the path is a directory, permission denied, or invalid UTF-8.
        print(f"ERROR: Cannot read {storyboard_path}: {e}")
        return None

    if not isinstance(storyboard, dict):
        print(f"ERROR: Storyboard {storyboard_path} must be a JSON object")
        return None

    manifest = _load_manifest(storyboard_path)
    shot_hashes = _hash_storyboard(storyboard)

    vdir = _versions_dir(storyboard_path)
    os.makedirs(vdir, exist_ok=True)
    sb_name = os.path.splitext(os.path.basename(storyboard_path))[0]

    # Number after the highest recorded version (equal to len+1 for a normal
    # contiguous history). Then skip any snapshot already on disk: a manifest
    # reinitialized after corruption must not overwrite existing history.
    version = max((v.get("version", 0) for v in manifest["versions"]), default=0) + 1
    while os.path.exists(os.path.join(vdir, f"{sb_name}_v{version:03d}.json")):
        version += 1
    version_filename = f"{sb_name}_v{version:03d}.json"
    version_path = os.path.join(vdir, version_filename)

    # Compute change summary vs the most recently saved version
    changes = []
    if manifest["versions"]:
        prev_hashes = manifest["versions"][-1].get("shot_hashes", {})
        for sid, h in shot_hashes.items():
            if sid not in prev_hashes:
                changes.append(f"Shot #{sid}: ADDED")
            elif prev_hashes[sid] != h:
                changes.append(f"Shot #{sid}: MODIFIED")
        for sid in prev_hashes:
            if sid not in shot_hashes:
                changes.append(f"Shot #{sid}: REMOVED")
        if not changes:
            changes.append("No changes from previous version")

    # Save versioned copy
    shutil.copy2(storyboard_path, version_path)

    # Update manifest
    entry = {
        "version": version,
        "filename": version_filename,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "message": message or "",
        "shot_count": len(storyboard.get("shots", [])),
        "shot_hashes": shot_hashes,
        "changes": changes,
    }
    manifest["versions"].append(entry)
    manifest["current_version"] = version
    _save_manifest(storyboard_path, manifest)

    return version


def list_versions(storyboard_path):
    """Print the version history, one line per version plus up to five change notes.

    The entry whose number equals the manifest's ``current_version`` is
    flagged with "← CURRENT". When no version has been saved, a short notice
    is printed instead.
    """
    manifest = _load_manifest(storyboard_path)
    history = manifest.get("versions", [])
    if not history:
        print("No versions saved yet.")
        return

    current = manifest.get("current_version", 0)
    print(f"=== Version History: {manifest['storyboard']} ===")
    print()
    max_shown = 5
    for entry in history:
        number = entry["version"]
        flag = " ← CURRENT" if number == current else ""
        # ISO timestamp trimmed to seconds, with a space instead of the "T".
        stamp = entry.get("timestamp", "?")[:19].replace("T", " ")
        note = entry.get("message", "")
        note_part = f' — "{note}"' if note else ""
        print(f"  v{number:03d}  {stamp}  ({entry['shot_count']} shots){note_part}{flag}")
        summary = entry.get("changes", [])
        for line in summary[:max_shown]:
            print(f"        {line}")
        hidden = len(summary) - max_shown
        if hidden > 0:
            print(f"        ...and {hidden} more")
    print()


def diff_versions(storyboard_path, v1=None, v2=None):
    """Print and return per-shot differences between two saved versions.

    Compares the per-shot content hashes recorded in the manifest, not the
    snapshot files. If v1/v2 are not specified, v1 defaults to the
    second-to-last saved version and v2 to the last.

    Args:
        storyboard_path: Path to the storyboard JSON file.
        v1: Base version number, or None.
        v2: Target version number, or None.

    Returns:
        A list of ``{"shot_id": id, "type": "added"|"modified"|"removed"}``
        dicts. Added/modified shots come first, ordered by id, followed by
        removed shots. Numeric ids are returned as ints; any other id is
        returned as its string form. Returns [] when there is nothing to
        compare or a version is not found.
    """

    def shot_key(sid):
        # Numeric ids sort numerically and before non-numeric ids, which sort
        # lexically. A bare int() would crash on ids such as "12a".
        try:
            return (0, int(sid), "")
        except ValueError:
            return (1, 0, sid)

    def shot_id(sid):
        try:
            return int(sid)
        except ValueError:
            return sid

    manifest = _load_manifest(storyboard_path)
    versions = manifest.get("versions", [])

    if len(versions) < 2 and (v1 is None or v2 is None):
        print("Need at least 2 versions to diff.")
        return []

    if v1 is None:
        v1 = versions[-2]["version"]
    if v2 is None:
        v2 = versions[-1]["version"]

    entry1 = next((v for v in versions if v["version"] == v1), None)
    entry2 = next((v for v in versions if v["version"] == v2), None)

    if entry1 is None or entry2 is None:
        print(f"ERROR: Version(s) not found. Available: {[v['version'] for v in versions]}")
        return []

    hashes1 = entry1.get("shot_hashes", {})
    hashes2 = entry2.get("shot_hashes", {})

    changes = []

    # Shots present in v2: new ids are "added", changed hashes are "modified".
    for sid in sorted(hashes2, key=shot_key):
        if sid not in hashes1:
            changes.append({"shot_id": shot_id(sid), "type": "added"})
        elif hashes1[sid] != hashes2[sid]:
            changes.append({"shot_id": shot_id(sid), "type": "modified"})

    # Shots only present in v1 were removed.
    for sid in sorted(hashes1, key=shot_key):
        if sid not in hashes2:
            changes.append({"shot_id": shot_id(sid), "type": "removed"})

    print(f"=== Diff: v{v1:03d} → v{v2:03d} ===")
    print()
    if not changes:
        print("  No changes.")
    else:
        for c in changes:
            marker = {"modified": "~", "added": "+", "removed": "-"}[c["type"]]
            print(f"  {marker} Shot #{c['shot_id']}: {c['type'].upper()}")
    print()
    print(f"  {len(changes)} shot(s) changed.")

    return changes


def get_changed_shots(storyboard_path, v1=None, v2=None):
    """Get the ids of shots that were added or modified between two versions.

    Used by previz pipeline to regenerate only changed shots. Shots removed in
    v2 are not included. Version defaults match ``diff_versions``: v1 is the
    second-to-last saved version and v2 the last. Nothing is printed.

    Args:
        storyboard_path: Path to the storyboard JSON file.
        v1: Base version number, or None.
        v2: Target version number, or None.

    Returns:
        A sorted list of shot ids. Numeric ids are ints in ascending order;
        any non-numeric ids follow as strings, in lexical order. Returns []
        when fewer than two versions exist and a version is unspecified, or
        when a version is not found.
    """

    def shot_key(sid):
        # Numeric ids sort numerically and before non-numeric ids; a bare
        # int() would crash on ids such as "12a".
        try:
            return (0, int(sid), "")
        except ValueError:
            return (1, 0, sid)

    def shot_id(sid):
        try:
            return int(sid)
        except ValueError:
            return sid

    manifest = _load_manifest(storyboard_path)
    versions = manifest.get("versions", [])

    if len(versions) < 2 and (v1 is None or v2 is None):
        return []

    if v1 is None:
        v1 = versions[-2]["version"]
    if v2 is None:
        v2 = versions[-1]["version"]

    entry1 = next((v for v in versions if v["version"] == v1), None)
    entry2 = next((v for v in versions if v["version"] == v2), None)

    if entry1 is None or entry2 is None:
        return []

    hashes1 = entry1.get("shot_hashes", {})
    hashes2 = entry2.get("shot_hashes", {})

    changed = [
        sid for sid in hashes2
        if sid not in hashes1 or hashes1[sid] != hashes2[sid]
    ]
    return [shot_id(sid) for sid in sorted(changed, key=shot_key)]


def revert(storyboard_path, version):
    """Revert the storyboard to a previous version.

    First saves the current state as a new version (so you can undo the
    undo). Then copies the target snapshot over the active storyboard and
    marks the target as current in the manifest. Newer entries are kept.

    Args:
        storyboard_path: Path to the storyboard JSON file.
        version: Version number to restore.

    Returns:
        True on success. False, after printing an error where applicable, if
        the version is unknown, its snapshot file is missing, or the
        auto-save fails.
    """
    manifest = _load_manifest(storyboard_path)
    versions = manifest.get("versions", [])

    entry = next((v for v in versions if v["version"] == version), None)
    if entry is None:
        print(f"ERROR: Version {version} not found. Available: {[v['version'] for v in versions]}")
        return False

    # Check the snapshot BEFORE auto-saving, so a revert that cannot succeed
    # does not leave a stray "Auto-save before revert" version behind.
    vdir = _versions_dir(storyboard_path)
    version_file = os.path.join(vdir, entry["filename"])
    if not os.path.exists(version_file):
        print(f"ERROR: Version file missing: {version_file}")
        return False

    # Save current state first (undo safety net)
    current_version = save_version(
        storyboard_path,
        message=f"Auto-save before revert to v{version:03d}",
    )
    if current_version is None:
        return False

    # Restore target version
    shutil.copy2(version_file, storyboard_path)

    # Update manifest current_version (but don't remove newer entries).
    # Re-load because save_version just appended an entry.
    manifest = _load_manifest(storyboard_path)
    manifest["current_version"] = version
    _save_manifest(storyboard_path, manifest)

    print(f"Reverted to v{version:03d}. Current state saved as v{current_version:03d}.")
    return True


def main():
    """Command-line entry point: parse arguments and dispatch to a subcommand.

    Exit status:
    - ``save`` exits 2 when the storyboard cannot be read or parsed.
    - ``revert`` exits 1 when the revert fails.
    - Otherwise the function returns normally.

    With ``diff --json``, stdout carries only the JSON list of changes and the
    human-readable report goes to stderr.
    """
    import contextlib  # used only for ``diff --json``

    parser = argparse.ArgumentParser(
        description="Storyboard versioning with undo"
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # save
    save_p = subparsers.add_parser("save", help="Save current storyboard as new version")
    save_p.add_argument("storyboard", help="Path to storyboard JSON")
    save_p.add_argument("--message", "-m", help="Version message")

    # list
    list_p = subparsers.add_parser("list", help="List version history")
    list_p.add_argument("storyboard", help="Path to storyboard JSON")

    # diff
    diff_p = subparsers.add_parser("diff", help="Diff between versions")
    diff_p.add_argument("storyboard", help="Path to storyboard JSON")
    diff_p.add_argument("v1", nargs="?", type=int, help="First version (default: previous)")
    diff_p.add_argument("v2", nargs="?", type=int, help="Second version (default: latest)")
    diff_p.add_argument("--json", action="store_true", help="Output as JSON")

    # revert
    revert_p = subparsers.add_parser("revert", help="Revert to a previous version")
    revert_p.add_argument("storyboard", help="Path to storyboard JSON")
    revert_p.add_argument("version", type=int, help="Version number to revert to")

    # changed
    changed_p = subparsers.add_parser(
        "changed", help="List shot IDs that changed between versions"
    )
    changed_p.add_argument("storyboard", help="Path to storyboard JSON")
    changed_p.add_argument("v1", nargs="?", type=int, help="First version")
    changed_p.add_argument("v2", nargs="?", type=int, help="Second version")

    args = parser.parse_args()

    if args.command == "save":
        v = save_version(args.storyboard, message=args.message)
        if v is not None:
            print(f"Saved version {v}.")
        else:
            sys.exit(2)

    elif args.command == "list":
        list_versions(args.storyboard)

    elif args.command == "diff":
        if args.json:
            # diff_versions always prints a human-readable report; route it to
            # stderr so stdout is a parseable JSON document.
            with contextlib.redirect_stdout(sys.stderr):
                changes = diff_versions(args.storyboard, args.v1, args.v2)
            print(json.dumps(changes, indent=2))
        else:
            diff_versions(args.storyboard, args.v1, args.v2)

    elif args.command == "revert":
        if not revert(args.storyboard, args.version):
            sys.exit(1)

    elif args.command == "changed":
        changed = get_changed_shots(args.storyboard, args.v1, args.v2)
        if changed:
            print(f"Changed shots: {changed}")
        else:
            print("No changed shots.")


if __name__ == "__main__":
    main()
