#!/usr/bin/env python3
"""
Apply script-doctor annotations to episode files automatically.

Handles REWRITE and DELETE actions. Skips FLAG annotations, which need manual
review (they are counted in the summary but not applied). Backs up each episode
before modifying it, then re-validates each modified episode with
episode_metrics.py.

Usage:
    python3 apply_annotations.py <project_path>
    python3 apply_annotations.py <project_path> --dry-run
    python3 apply_annotations.py <project_path> --p1-only
"""

import argparse
import json
import os
import re
import shutil
import subprocess
import sys
from datetime import datetime
from pathlib import Path


def extract_replacement(note: str) -> str | None:
    """Extract replacement text from annotation note field."""
    # Look for REPLACEMENT: "..." pattern
    match = re.search(r'REPLACEMENT:\s*"(.+)"', note, re.DOTALL)
    if match:
        return match.group(1)
    # Try single quotes
    match = re.search(r"REPLACEMENT:\s*'(.+)'", note, re.DOTALL)
    if match:
        return match.group(1)
    return None


def load_annotations(project_path: str) -> dict:
    """Load the script-doctor annotations for a project.

    Reads ``<project_path>/state/script_doctor_annotations.json``. The file is
    decoded explicitly as UTF-8, so non-ASCII characters in notes and selected
    text load identically regardless of the platform's locale encoding. The
    ``utf-8-sig`` codec is used so that a leading byte-order mark, which
    ``json.load`` would otherwise reject, is tolerated. Plain UTF-8 files
    decode the same either way.

    Args:
        project_path: Root directory of the project.

    Returns:
        The parsed JSON document.

    Raises:
        SystemExit: With code 1, after printing an error, if the file is missing.
    """
    ann_path = os.path.join(project_path, "state", "script_doctor_annotations.json")
    if not os.path.exists(ann_path):
        print(f"ERROR: Annotations file not found: {ann_path}")
        sys.exit(1)
    with open(ann_path, encoding="utf-8-sig") as f:
        return json.load(f)


def group_by_episode(annotations: list) -> dict:
    """Bucket annotations by their ``"episode"`` value.

    Every annotation must carry an ``"episode"`` key (a missing key raises
    KeyError). Annotations keep their input order inside each bucket, and the
    returned dict is ordered by ascending episode key.
    """
    buckets = {}
    for item in annotations:
        buckets.setdefault(item["episode"], []).append(item)
    return {episode: buckets[episode] for episode in sorted(buckets)}


def backup_episode(ep_path: str, backup_dir: str):
    """Copy an episode file into ``backup_dir`` unless a copy already exists.

    ``backup_dir`` is created on demand. The copy keeps the source's base name
    and is made with ``shutil.copy2``. An existing backup is never overwritten,
    so the first snapshot taken into a given directory is the one that survives.
    """
    os.makedirs(backup_dir, exist_ok=True)
    target = os.path.join(backup_dir, os.path.basename(ep_path))
    if os.path.exists(target):
        # Already backed up in this directory: keep the earlier snapshot.
        return
    shutil.copy2(ep_path, target)


def validate_episode(ep_path: str, timeout: float = 30) -> dict:
    """Run ``episode_metrics.py --json`` on one episode and return its verdict.

    The metrics script is expected at ``tools/episode_metrics.py`` two
    directories above this file. Its JSON output (stdout, or stderr when stdout
    is empty) is parsed and returned unchanged.

    This is best-effort. Instead of raising, it returns
    ``{"is_valid": False, "error": <message>}`` in these cases:
      * the script cannot be launched or times out;
      * the output is not JSON (the error then includes the tail of the
        script's output, e.g. a traceback);
      * the output is JSON but not an object (callers read it with ``.get``).

    Args:
        ep_path: Path of the episode file to check.
        timeout: Seconds to wait for the metrics script (default 30).

    Returns:
        The metrics dict produced by the script, or an ``is_valid: False``
        error dict as described above.
    """
    recoil_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    metrics_script = os.path.join(recoil_root, "tools", "episode_metrics.py")
    try:
        result = subprocess.run(
            [sys.executable, metrics_script, ep_path, "--json"],
            capture_output=True, text=True, timeout=timeout,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # Covers: launch failure, bad argument (e.g. NUL byte in the path),
        # TimeoutExpired, and output that cannot be decoded as text.
        return {"is_valid": False, "error": str(e)}

    raw = result.stdout or result.stderr
    try:
        parsed = json.loads(raw)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        # Keep the end of the script's output so the reason for the failure
        # is visible, not just "Expecting value".
        tail = raw.strip()[-300:]
        return {"is_valid": False, "error": f"{e}; output: {tail}"}
    if not isinstance(parsed, dict):
        return {"is_valid": False,
                "error": f"unexpected metrics output: {raw.strip()[:300]}"}
    return parsed


def _preview(text: str, limit: int = 60) -> str:
    """Shorten ``text`` to ``limit`` characters for log output, appending "..." when cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _delete_first(content: str, selected: str) -> str:
    """Remove the first occurrence of ``selected`` from ``content``.

    Blank lines are tidied only at the splice point. If the cut leaves three
    or more consecutive newlines there, they collapse to exactly two (one
    blank line); fewer are kept as-is. Blank-line runs elsewhere in the file
    are not touched. ``selected`` must be a non-empty substring of ``content``.
    """
    start = content.index(selected)
    head = content[:start]
    tail = content[start + len(selected):]
    head_trimmed = head.rstrip("\n")
    tail_trimmed = tail.lstrip("\n")
    newlines = (len(head) - len(head_trimmed)) + (len(tail) - len(tail_trimmed))
    return head_trimmed + "\n" * min(newlines, 2) + tail_trimmed


def apply_annotations_to_episode(ep_path: str, annotations: list, dry_run: bool = False) -> dict:
    """Apply one episode's annotations to its file, in list order.

    Actions:
      * ``FLAG``: never applied; recorded in ``skipped`` for manual review.
      * ``REWRITE``: replaces the first occurrence of ``selected_text`` with
        the text quoted after ``REPLACEMENT:`` in ``note``.
      * ``DELETE``: removes the first occurrence of ``selected_text`` and
        tidies the blank lines left at that spot only.
      * any other or missing ``action``: recorded in ``failed``.

    Annotations are applied one after another to the evolving text, so a later
    annotation must match the text as it stands after earlier edits. Only the
    FIRST occurrence of ``selected_text`` is edited. If the passage appears
    more than once, that may not be the one the reviewer meant.

    The file is read and written as UTF-8.

    Args:
        ep_path: Path to the episode file.
        annotations: Annotation dicts for this episode. The keys used are
            ``action``, ``selected_text``, ``note``, ``finding_id`` and
            ``severity``.
        dry_run: If True, compute the outcome but leave the file untouched.

    Returns:
        A dict with these keys:
          * ``changed``: True if the resulting text differs from the file's
            original content.
          * ``applied``, ``skipped``, ``failed``: lists of per-annotation
            records.
    """
    with open(ep_path, encoding="utf-8") as f:
        content = f.read()

    original = content
    applied = []
    skipped = []
    failed = []

    for ann in annotations:
        action = ann.get("action")
        # `or ""` also covers explicit JSON nulls, which would otherwise raise
        # TypeError in the substring checks and the regex search.
        selected = ann.get("selected_text") or ""
        note = ann.get("note") or ""
        finding_id = ann.get("finding_id", "?")
        severity = ann.get("severity", "?")

        if action == "FLAG":
            skipped.append({
                "finding_id": finding_id,
                "severity": severity,
                "reason": "FLAG — requires manual review",
                "note": note
            })
            continue

        if action not in ("REWRITE", "DELETE"):
            # Previously such annotations vanished without being counted.
            failed.append({
                "finding_id": finding_id,
                "reason": f"Unsupported action: {action!r}",
            })
            continue

        # "" is a substring of every string. Without this guard a REWRITE
        # would prepend its replacement to the file, and a DELETE would be
        # reported as applied.
        if not selected:
            failed.append({
                "finding_id": finding_id,
                "reason": "Annotation has no selected_text",
            })
            continue

        if action == "REWRITE":
            replacement = extract_replacement(note)
            if not replacement:
                failed.append({
                    "finding_id": finding_id,
                    "reason": "Could not extract replacement text from note",
                    "note": note
                })
                continue

            if selected in content:
                content = content.replace(selected, replacement, 1)
                applied.append({
                    "finding_id": finding_id,
                    "action": "REWRITE",
                    "old": _preview(selected),
                    "new": _preview(replacement)
                })
            else:
                failed.append({
                    "finding_id": finding_id,
                    "reason": "Selected text not found in episode",
                    "selected_text": selected[:80]
                })

        else:  # DELETE
            if selected in content:
                content = _delete_first(content, selected)
                applied.append({
                    "finding_id": finding_id,
                    "action": "DELETE",
                    "deleted": _preview(selected)
                })
            else:
                failed.append({
                    "finding_id": finding_id,
                    "reason": "Selected text not found for deletion",
                    "selected_text": selected[:80]
                })

    changed = content != original

    if changed and not dry_run:
        with open(ep_path, 'w', encoding="utf-8") as f:
            f.write(content)

    return {
        "changed": changed,
        "applied": applied,
        "skipped": skipped,
        "failed": failed
    }


def main():
    """Command-line entry point: apply annotations, validate, and log the run.

    Steps:
      1. Load ``state/script_doctor_annotations.json`` from the project,
         keeping only P1 findings if ``--p1-only`` is given.
      2. Apply the annotations one episode at a time to ``episodes/ep_NNN.md``.
      3. On live runs, back each episode up to
         ``backups/pre_revision_<timestamp>/`` before editing it, and
         re-validate it if it changed.
      4. Print a summary and write it to ``state/revision_log.json``.
         Dry runs also write this log, marked with ``"dry_run": true``.
    """
    parser = argparse.ArgumentParser(description="Apply script-doctor annotations")
    parser.add_argument("project", help="Project path")
    parser.add_argument("--dry-run", action="store_true", help="Preview without writing")
    parser.add_argument("--p1-only", action="store_true", help="Apply P1 severity only")
    args = parser.parse_args()

    project_path = args.project
    data = load_annotations(project_path)
    all_annotations = data["annotations"]

    if args.p1_only:
        all_annotations = [a for a in all_annotations if a.get("severity") == "P1"]

    grouped = group_by_episode(all_annotations)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_dir = os.path.join(project_path, "backups", f"pre_revision_{timestamp}")

    print("=" * 60)
    print(f"APPLYING ANNOTATIONS — {'DRY RUN' if args.dry_run else 'LIVE'}")
    print("=" * 60)
    print(f"  Total annotations: {len(all_annotations)}")
    print(f"  Episodes affected: {len(grouped)}")
    if args.p1_only:
        print("  Filter: P1 only")
    # Dry runs never write backups, so don't imply that the directory exists.
    backup_note = " (not created in dry run)" if args.dry_run else ""
    print(f"  Backup dir: {backup_dir}{backup_note}")
    print()

    total_applied = 0
    total_skipped = 0
    total_failed = 0
    validation_failures = []

    for ep_num, annotations in grouped.items():
        ep_path = os.path.join(project_path, "episodes", f"ep_{ep_num:03d}.md")
        if not os.path.exists(ep_path):
            print(f"  [SKIP] Episode {ep_num}: file not found")
            continue

        # Snapshot before editing so a bad revision can be rolled back.
        if not args.dry_run:
            backup_episode(ep_path, backup_dir)

        result = apply_annotations_to_episode(ep_path, annotations, args.dry_run)

        n_applied = len(result["applied"])
        n_skipped = len(result["skipped"])
        n_failed = len(result["failed"])
        total_applied += n_applied
        total_skipped += n_skipped
        total_failed += n_failed

        if not result["changed"]:
            status = "unchanged"
        elif args.dry_run:
            status = "WOULD MODIFY"  # nothing was written on a dry run
        else:
            status = "MODIFIED"
        print(f"  Ep {ep_num:3d}: {n_applied} applied, {n_skipped} skipped, {n_failed} failed — {status}")

        for failure in result["failed"]:
            print(f"         FAIL: {failure['finding_id']} — {failure['reason']}")

        # Only episodes actually rewritten on disk are re-validated.
        if result["changed"] and not args.dry_run:
            val = validate_episode(ep_path)
            if not val.get("is_valid", False):
                wc = val.get("word_count", "?")
                validation_failures.append((ep_num, wc))
                # Show the validator's own error (e.g. the script could not
                # run) instead of an unexplained "? words".
                detail = f" — {val['error']}" if val.get("error") else ""
                print(f"         VALIDATION FAIL: {wc} words{detail}")

    print()
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"  Applied:  {total_applied}")
    print(f"  Skipped:  {total_skipped} (FLAG — manual review)")
    # Failures cover more than unmatched text (e.g. a note without a
    # REPLACEMENT), so point at the per-finding reasons instead.
    print(f"  Failed:   {total_failed} (see FAIL lines above)")
    if validation_failures:
        print(f"\n  VALIDATION FAILURES ({len(validation_failures)}):")
        for ep_num, wc in validation_failures:
            print(f"    Episode {ep_num}: {wc} words")
    elif args.dry_run:
        # Nothing was written, so nothing was validated; don't claim a pass.
        print("\n  Validation skipped (dry run)")
    else:
        print("\n  All modified episodes pass validation")
    print("=" * 60)

    # Save revision log
    log_path = os.path.join(project_path, "state", "revision_log.json")
    log = {
        "timestamp": datetime.now().isoformat(),
        "dry_run": args.dry_run,
        "total_applied": total_applied,
        "total_skipped": total_skipped,
        "total_failed": total_failed,
        "validation_failures": validation_failures
    }
    with open(log_path, 'w') as f:
        json.dump(log, f, indent=2)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
