# api/routes/manual.py
"""Manual workbench endpoints — Phase 6.

Ported from review_server.py:
  - _api_manual_shots (line 8390)
  - _api_manual_escalate (line 8337)
  - _api_manual_export (line 8537)
  - _api_manual_reimport (line 8610)
  - _api_manual_resolve (line 8781)
  - _api_reveal_in_finder (line 5763)
"""

import json
import re
import shutil
import subprocess
import time
from pathlib import Path

from fastapi import APIRouter, Body, Depends, Query
from fastapi.responses import JSONResponse

from ..deps import get_project, get_paths, get_store, _paths_for_project
from ..state import submit_task

# Router that every manual-workbench endpoint in this module registers on (tag: "manual").
router = APIRouter(tags=["manual"])

# Legacy model IDs that older clients may still send, grouped by the canonical
# model name each one resolves to. LEGACY_MODEL_MAP is the flattened
# legacy-ID -> canonical-name lookup derived from this table.
_CANONICAL_TO_LEGACY = {
    "kling-v3": ("kling-3.0", "kling-3.0-fal", "kling-v3-fal"),
    "kling-o3": ("kling-o3-fal",),
}
LEGACY_MODEL_MAP = {
    legacy_id: canonical
    for canonical, legacy_ids in _CANONICAL_TO_LEGACY.items()
    for legacy_id in legacy_ids
}


# ── Helper: resolve output-relative path to absolute ─────────────

def _resolve_output_path(rel_path: str, paths: dict) -> str:
    """Resolve an output/-relative path to an absolute path.
    Uses project output only -- no cross-project fallback."""
    if rel_path.startswith("output/"):
        stripped = rel_path.replace("output/", "", 1)
        return str(paths["output_dir"] / stripped)
    return str(Path(rel_path).resolve())


# ── Endpoints ─────────────────────────────────────────────────────

@router.get("/api/manual/shots/{ep_id}")
def manual_shots(
    ep_id: str,
    mode: str = Query("flagged"),
    project: str = Depends(get_project),
    paths: dict = Depends(get_paths),
    store=Depends(get_store),
):
    """Shots for manual intervention.

    ep_id can be an episode (EP001, ep_001) or "all" to get every shot.
    Query params:
      mode=flagged (default) -- only manual_escalated shots, most recently
        escalated first
      mode=all -- all shots with at least one take, in natural shot_id order
        (SH9 before SH10), each carrying a ``manual_escalated`` flag
    Any other mode value is treated as "flagged".

    Plan metadata (prompt skeleton, shot type, camera, characters, action) is
    read from ``plans_dir/ep_NNN_plan.json``; a missing or unreadable plan only
    leaves those fields empty.

    Returns 400 when ep_id is neither "all" nor a recognised episode id.
    """
    if mode not in ("flagged", "all"):
        mode = "flagged"

    if ep_id.lower() == "all":
        all_shots = store.get_all_shots()
        episode_id = "ALL"
    else:
        ep_match = re.match(r"(?:EP|ep[_]?)(\d+)", ep_id)
        if not ep_match:
            return JSONResponse({"error": f"Invalid episode format: {ep_id}"}, status_code=400)
        episode_id = f"EP{int(ep_match.group(1)):03d}"
        all_shots = store.get_shots_by_episode(episode_id)

    # Plan shots keyed by episode number: each plan file is read at most once
    # per request, even when mode=all spans many episodes.
    plan_cache = {}

    def _get_plan_shots(shot_ep_num):
        """Return {shot_id: plan shot dict} for one episode (empty on any problem)."""
        if shot_ep_num not in plan_cache:
            plan_shots = {}
            plan_path = paths["plans_dir"] / f"ep_{shot_ep_num:03d}_plan.json"
            if plan_path.exists():
                try:
                    plan = json.loads(plan_path.read_text(encoding="utf-8"))
                except (ValueError, OSError):
                    # ValueError covers JSONDecodeError and UnicodeDecodeError; a
                    # broken plan means missing metadata, not a failed request.
                    plan = {}
                plan_entries = plan.get("shots") if isinstance(plan, dict) else None
                for entry in plan_entries or []:
                    if isinstance(entry, dict):
                        plan_shots[entry.get("shot_id", "")] = entry
            plan_cache[shot_ep_num] = plan_shots
        return plan_cache[shot_ep_num]

    def _take_path(take):
        """Best displayable location of a take: local file first, then URL."""
        return take.get("file_path") or take.get("url")

    def _natural_key(shot_id):
        """Sort key that compares digit runs numerically ("SH9" < "SH10")."""
        parts = re.split(r"(\d+)", shot_id)
        # re.split with one capture group puts the digit runs at odd indices.
        return [int(part) if idx % 2 else part for idx, part in enumerate(parts)]

    result_shots = []
    for shot in all_shots:
        # `or` guards: stored fields may be present but None.
        gate = shot.get("gate_results") or {}
        takes = shot.get("takes") or []

        # Mode-specific filter
        if mode == "flagged" and not gate.get("manual_escalated"):
            continue
        if mode == "all" and not takes:
            continue

        shot_id = shot.get("shot_id") or ""

        # Episode number from shot_id (0 when the id has no EPnnn prefix)
        sid_match = re.match(r"EP(\d+)", shot_id)
        shot_ep_num = int(sid_match.group(1)) if sid_match else 0
        plan_data = _get_plan_shots(shot_ep_num).get(shot_id, {})

        # Latest output, used as the thumbnail
        latest_output = _take_path(takes[-1]) if takes else None

        # Target frame: the "before" image for comparison.
        # Priority: hero/approved take -> second-to-last take (the latest is
        # the "after") -> the shot's output_path.
        approved_take = next((t for t in takes if t.get("approved") or t.get("is_hero")), None)
        target_frame = _take_path(approved_take) if approved_take else None
        if not target_frame and len(takes) >= 2:
            target_frame = _take_path(takes[-2])
        if not target_frame:
            target_frame = shot.get("output_path")

        prompt_data = plan_data.get("prompt_data") or {}
        asset_data = plan_data.get("asset_data") or {}
        characters = asset_data.get("characters") or []

        shot_obj = {
            "shot_id": shot_id,
            "episode_id": shot.get("episode_id") or (f"EP{shot_ep_num:03d}" if shot_ep_num else ""),
            "status": shot.get("status", ""),
            "pipeline": shot.get("pipeline", ""),
            "model": shot.get("model", ""),
            "latest_output": latest_output,
            "target_frame": target_frame,
            "hero_frame": gate.get("hero_frame"),
            "video_path": gate.get("video_path"),
            "prompt": prompt_data.get("prompt_skeleton", {}),
            "shot_type": plan_data.get("shot_type", ""),
            "camera": plan_data.get("camera", ""),
            "characters": [c.get("char_id", "") for c in characters if isinstance(c, dict)],
            "action": plan_data.get("action_description", ""),
            "manual_escalated_at": gate.get("manual_escalated_at"),
            "manual_resolved": gate.get("manual_resolved", False),
            "manual_fixes": gate.get("manual_fixes", []),
            "failure_type": gate.get("failure_type"),
            "takes": takes,
        }

        # In all mode, expose manual_escalated flag for frontend badge
        if mode == "all":
            shot_obj["manual_escalated"] = gate.get("manual_escalated", False)

        result_shots.append(shot_obj)

    if mode == "flagged":
        # Most recent escalation first. The key is always present but may be
        # None, which cannot be compared with float timestamps -> treat as 0.
        result_shots.sort(key=lambda s: s["manual_escalated_at"] or 0, reverse=True)
    else:
        # Narrative order: natural sort on shot_id
        result_shots.sort(key=lambda s: _natural_key(s["shot_id"]))

    payload = {
        "episode": episode_id,
        "shots": result_shots,
        "total": len(result_shots),
    }
    if mode == "all":
        payload["flagged_count"] = sum(1 for s in result_shots if s["manual_escalated"])
    payload["unresolved"] = sum(1 for s in result_shots if not s["manual_resolved"])
    return JSONResponse(payload)


@router.post("/api/manual/escalate")
def manual_escalate(
    body: dict = Body(default={}),
    project: str = Depends(get_project),
    store=Depends(get_store),
):
    """Flag a shot for manual intervention.

    Body: {"shot_id": "EP001_SH01", "failure_type": "artifacts"}
    Sets gate_results.manual_escalated = true without changing formal status.

    - Shot already escalated and still open: no-op ("already": true), except
      that a different failure_type is recorded ("failure_type_updated": true).
    - Shot whose earlier escalation was resolved: re-opened. The escalation
      timestamp is refreshed and manual_resolved is cleared, so the shot counts
      as unresolved again and can be resolved anew ("reopened": true).

    Returns 400 for a missing shot_id or invalid failure_type, 404 for an
    unknown shot.
    """
    shot_id = body.get("shot_id")
    if not shot_id:
        return JSONResponse({"error": "Missing shot_id"}, status_code=400)

    failure_type = body.get("failure_type")
    valid_failure_types = {"composition", "artifacts", "motion", "safety_filter", "character"}
    # isinstance first: an unhashable value (e.g. a list) would otherwise make
    # the set membership test raise TypeError.
    if failure_type and (not isinstance(failure_type, str) or failure_type not in valid_failure_types):
        return JSONResponse({"error": f"Invalid failure_type '{failure_type}'"}, status_code=400)

    shot = store.get_shot(shot_id)
    if shot is None:
        return JSONResponse({"error": f"Shot not found: {shot_id}"}, status_code=404)

    gate = shot.get("gate_results") or {}
    already_escalated = bool(gate.get("manual_escalated"))

    if already_escalated and not gate.get("manual_resolved"):
        # Open escalation: only a changed failure_type is worth writing. The
        # partial gate_results dict here relies on update_shot merging keys
        # (otherwise manual_escalated would be lost).
        if failure_type and gate.get("failure_type") != failure_type:
            store.update_shot(shot_id, gate_results={
                "failure_type": failure_type,
            })
            return JSONResponse({"ok": True, "shot_id": shot_id, "already": True, "failure_type_updated": True})
        return JSONResponse({"ok": True, "shot_id": shot_id, "already": True})

    gate_update = {
        "manual_escalated": True,
        "manual_escalated_at": time.time(),
    }
    if already_escalated:
        # The previous escalation was resolved; re-open it so it shows up as
        # unresolved and manual_resolve accepts it again.
        gate_update["manual_resolved"] = False
    if failure_type:
        gate_update["failure_type"] = failure_type

    store.update_shot(shot_id, gate_results=gate_update)

    response = {"ok": True, "shot_id": shot_id}
    if already_escalated:
        response["reopened"] = True
    return JSONResponse(response)


@router.post("/api/manual/export")
def manual_export(
    body: dict = Body(default={}),
    project: str = Depends(get_project),
):
    """Export bundle(s) for manual web UI work.

    Body: {"shot_ids": ["EP001_SH01"], "target_model": "kling-v3"}
    Wraps build_bundle() in a background task via submit_task.

    - shot_ids may also be a single shot-id string.
    - IDs that do not match ``EPnnn_SHnn`` are skipped and listed under
      "skipped" in the response.
    - Duplicate shots are exported once.
    - target_model defaults to "kling-v3" (also when null) and is mapped
      through LEGACY_MODEL_MAP.

    The background task returns one result per episode, carrying either
    "bundle_path" or "error".

    Returns 400 for missing or malformed shot_ids / target_model, or when no ID
    is valid. Returns 503 when the bundle builder cannot be imported.
    """
    shot_ids_raw = body.get("shot_ids") or []
    if isinstance(shot_ids_raw, str):
        # A bare string would otherwise be iterated character by character.
        shot_ids_raw = [shot_ids_raw]
    if not isinstance(shot_ids_raw, list) or not all(isinstance(s, str) for s in shot_ids_raw):
        return JSONResponse({"error": "shot_ids must be a list of shot ID strings"}, status_code=400)
    if not shot_ids_raw:
        return JSONResponse({"error": "Missing shot_ids"}, status_code=400)

    target_model = body.get("target_model") or "kling-v3"
    if not isinstance(target_model, str):
        return JSONResponse({"error": "target_model must be a string"}, status_code=400)
    target_model = LEGACY_MODEL_MAP.get(target_model, target_model)

    # Group shot numbers by episode for build_bundle calls (order preserved,
    # duplicates dropped).
    episodes = {}
    skipped = []
    for sid in shot_ids_raw:
        ep_match = re.match(r"EP(\d+)_SH(\d+)", sid)
        if not ep_match:
            skipped.append(sid)
            continue
        shot_nums = episodes.setdefault(int(ep_match.group(1)), [])
        shot_num = int(ep_match.group(2))
        if shot_num not in shot_nums:
            shot_nums.append(shot_num)

    if not episodes:
        return JSONResponse({"error": "No valid shot IDs provided"}, status_code=400)

    try:
        from tools.build_upload_bundle import build_bundle
    except ImportError as e:
        return JSONResponse({"error": f"build_upload_bundle not available: {e}"}, status_code=503)

    def _bg_export():
        results = []
        for ep_num, shot_nums in episodes.items():
            try:
                bundle_path = build_bundle(
                    episode=ep_num,
                    shot_ids=shot_nums,
                    model=target_model,
                    project=project,
                )
            except Exception as e:
                # Task boundary: record the failure per episode and keep going.
                print(f"  [ERR] Bundle export failed for EP{ep_num:03d}: {e}")
                results.append({
                    "episode": ep_num,
                    "shots": shot_nums,
                    "error": str(e),
                })
                continue
            if bundle_path:
                results.append({
                    "episode": ep_num,
                    "shots": shot_nums,
                    "bundle_path": str(bundle_path),
                })
            else:
                # Surface an empty result instead of dropping the episode silently.
                results.append({
                    "episode": ep_num,
                    "shots": shot_nums,
                    "error": "build_bundle produced no bundle",
                })
        return results

    task_id = submit_task(
        entity_id=",".join(shot_ids_raw),
        action="manual_export",
        fn=_bg_export,
    )

    response = {
        "ok": True,
        "task_id": task_id,
        "status": "exporting",
        "episodes": list(episodes.keys()),
        "model": target_model,
    }
    if skipped:
        response["skipped"] = skipped
    return JSONResponse(response)


@router.post("/api/manual/reimport")
def manual_reimport(
    body: dict = Body(default={}),
    project: str = Depends(get_project),
    paths: dict = Depends(get_paths),
    store=Depends(get_store),
):
    """Re-import a manually fixed asset.

    Accepts EITHER:
      - file_data (base64) + file_name: from browser drag & drop
      - file_path: absolute disk path (CLI/scripting use)

    Body: {"shot_id": "EP001_SH01", "file_data": "base64...", "file_name": "fix.png",
           "failure_type": "artifacts", "fix_type": "manual_intervention",
           "notes": "optional notes"}

    file_data may be bare base64 or a ``data:<mime>;base64,<payload>`` URL.
    The extension of file_name / file_path routes the asset to
    ``video_dir/ep_NNN`` or ``frames_dir/ep_NNN``; ".png" is assumed when there
    is no extension.

    On success the new file becomes the shot's output_path and
    gate_results.hero_frame. A manual_fixes entry and a "manual_fix" take are
    also recorded.

    Status codes:
      400 -- invalid input, unsupported type, undecodable data, or a source
             path outside the home directory
      404 -- unknown shot or missing source file
      409 -- a manual_fix take was imported less than 5 seconds ago
      500 -- the destination is outside ``project_dir`` or cannot be written
    """
    import base64 as _b64

    shot_id = body.get("shot_id")
    file_data = body.get("file_data")  # base64 from browser
    file_name = body.get("file_name")  # original filename from browser
    file_path = body.get("file_path")  # absolute path (CLI fallback)
    failure_type = body.get("failure_type")

    if not shot_id:
        return JSONResponse({"error": "Missing shot_id"}, status_code=400)
    if not file_data and not file_path:
        return JSONResponse({"error": "Missing file_data or file_path"}, status_code=400)
    if not failure_type:
        return JSONResponse(
            {"error": "Missing failure_type -- must be one of: composition, artifacts, motion, safety_filter, character"},
            status_code=400,
        )

    valid_failure_types = {"composition", "artifacts", "motion", "safety_filter", "character"}
    # isinstance first so an unhashable value gets a 400 instead of a TypeError.
    if not isinstance(failure_type, str) or failure_type not in valid_failure_types:
        return JSONResponse(
            {"error": f"Invalid failure_type '{failure_type}'. Must be one of: {', '.join(sorted(valid_failure_types))}"},
            status_code=400,
        )

    shot = store.get_shot(shot_id)
    if shot is None:
        return JSONResponse({"error": f"Shot not found: {shot_id}"}, status_code=404)

    # Determine episode from shot_id
    ep_match = re.match(r"EP(\d+)_SH(\d+)", shot_id)
    if not ep_match:
        return JSONResponse({"error": f"Invalid shot_id format: {shot_id}"}, status_code=400)
    ep_num = int(ep_match.group(1))

    # Type-routed destination
    VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".avi"}
    IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}

    # Suffix from the original name; both input modes fall back to ".png"
    # when the name has no extension.
    source_name = file_name if file_data else file_path
    suffix = (Path(source_name).suffix if source_name else "") or ".png"

    suffix_lower = suffix.lower()
    if suffix_lower in VIDEO_EXTENSIONS:
        dest_dir = paths["video_dir"] / f"ep_{ep_num:03d}"
        asset_type = "video"
    elif suffix_lower in IMAGE_EXTENSIONS:
        dest_dir = paths["frames_dir"] / f"ep_{ep_num:03d}"
        asset_type = "frame"
    else:
        return JSONResponse({"error": f"Unsupported file type: {suffix}"}, status_code=400)

    # Dedup guard: reject rapid duplicate imports
    takes = shot.get("takes") or []
    if takes:
        last_take = takes[-1]
        if (last_take.get("layer") == "manual_fix" and
                time.time() - (last_take.get("timestamp") or 0) < 5):
            return JSONResponse({
                "error": "Duplicate import detected. Wait a moment before reimporting.",
            }, status_code=409)

    # Validate / decode the input BEFORE touching the filesystem, so a rejected
    # request leaves no directories or partial files behind.
    raw = None
    src = None
    if file_data:
        payload = file_data
        if isinstance(payload, str) and payload.startswith("data:") and "," in payload:
            # Data URL ("data:image/png;base64,<payload>"): keep only the payload,
            # otherwise the header characters get decoded as base64 too.
            payload = payload.split(",", 1)[1]
        try:
            raw = _b64.b64decode(payload)
        except (ValueError, TypeError) as exc:  # binascii.Error is a ValueError
            return JSONResponse({"error": f"Failed to decode file_data: {exc}"}, status_code=400)
        if not raw:
            return JSONResponse({"error": "Failed to decode file_data: no data"}, status_code=400)
    else:
        src = Path(file_path)
        # Path traversal guard runs before the existence check, so the response
        # does not reveal whether files outside the home directory exist.
        try:
            src.resolve().relative_to(Path.home())
        except ValueError:
            return JSONResponse({"error": "File path must be within home directory"}, status_code=400)
        if not src.is_file():
            return JSONResponse({"error": f"Source file not found: {file_path}"}, status_code=404)

    # Incremental filename: count existing manual_fix takes, then step past any
    # name that already exists on disk.
    fix_num = sum(1 for t in takes if t.get("layer") == "manual_fix") + 1
    dest_name = f"{shot_id}_manual_fix_{fix_num:02d}{suffix}"
    dest = dest_dir / dest_name
    while dest.exists():
        fix_num += 1
        dest_name = f"{shot_id}_manual_fix_{fix_num:02d}{suffix}"
        dest = dest_dir / dest_name

    # Compute the store-relative path before writing, so a misconfigured layout
    # fails without leaving an orphaned file.
    try:
        rel_path = str(dest.relative_to(paths["project_dir"]))
    except ValueError:
        return JSONResponse({"error": f"Destination is outside the project directory: {dest}"}, status_code=500)

    try:
        dest_dir.mkdir(parents=True, exist_ok=True)
        if raw is not None:
            dest.write_bytes(raw)  # browser upload
        else:
            shutil.copy2(str(src), str(dest))  # disk path (CLI/scripting)
    except OSError as exc:
        return JSONResponse({"error": f"Failed to write {dest_name}: {exc}"}, status_code=500)

    # One timestamp for the fix entry, the take and the take_id.
    now = time.time()

    fix_entry = {
        "timestamp": now,
        "failure_type": failure_type,
        "fix_type": body.get("fix_type", "manual_intervention"),
        "model_used": shot.get("pipeline") or shot.get("model") or "",
        "enrichment_used": False,
        "notes": body.get("notes", ""),
        "reimported_path": rel_path,
        "asset_type": asset_type,  # "video" or "frame"
    }

    # Build a new list rather than appending in place, in case the store hands
    # back its own cached dict (not visible from here -- defensive).
    gate = shot.get("gate_results") or {}
    manual_fixes = list(gate.get("manual_fixes") or []) + [fix_entry]

    take_entry = {
        "file_path": rel_path,
        "layer": "manual_fix",
        "asset_type": asset_type,  # "video" or "frame"
        "take_id": f"{shot_id}_MF{int(now) % 100000:05d}",
        "timestamp": now,
    }

    # Single update_shot call: promote hero + append take
    store.update_shot(shot_id,
        gate_results={
            "manual_fixes": manual_fixes,
            "hero_frame": rel_path,
        },
        output_path=rel_path,
        append_take=take_entry,
    )

    return JSONResponse({
        "ok": True,
        "shot_id": shot_id,
        "dest_path": rel_path,
        "fix_entry": fix_entry,
    })


@router.post("/api/manual/resolve")
def manual_resolve(
    body: dict = Body(default={}),
    project: str = Depends(get_project),
    store=Depends(get_store),
):
    """Mark a shot as resolved from manual intervention.

    Body: {"shot_id": "EP001_SH01", "failure_type": "artifacts",
           "fix_type": "manual_intervention", "notes": "optional",
           "action": "return_to_pipeline" | "export_video_bundle"}

    Optional keys:
      - overrides: dict of shot_type / camera / action changes
      - new_prompt, new_model, enrichment_used
      - prompt_edited / model_changed / source: used to infer fix_type when
        it is not given

    The hero asset is the newer of two candidates; the take wins ties:
      - the most recent previously re-imported manual fix
      - the most recent take that has a file_path

    Its asset type is the recorded ``asset_type``, else it is inferred from the
    file extension. The asset type and ``action`` select the new status:

        action               video hero       frame hero          no hero
        return_to_pipeline   video_complete   keyframe_approved   keyframe_pending
        export_video_bundle  video_complete   video_pending       video_pending

    Returns 400 for invalid input or a shot that was never escalated, 404 for
    an unknown shot. Returns {"already_resolved": true} when the shot is
    already resolved.
    """
    shot_id = body.get("shot_id")
    failure_type = body.get("failure_type")
    action = body.get("action", "return_to_pipeline")

    if not shot_id:
        return JSONResponse({"error": "Missing shot_id"}, status_code=400)
    if not failure_type:
        return JSONResponse({"error": "Missing failure_type"}, status_code=400)

    valid_failure_types = {"composition", "artifacts", "motion", "safety_filter", "character"}
    if not isinstance(failure_type, str) or failure_type not in valid_failure_types:
        return JSONResponse(
            {"error": f"Invalid failure_type '{failure_type}'. Must be one of: {', '.join(sorted(valid_failure_types))}"},
            status_code=400,
        )

    # Validate the action up front, before any work is done.
    if action not in ("return_to_pipeline", "export_video_bundle"):
        return JSONResponse({"error": f"Unknown action: {action}"}, status_code=400)

    # Collect overrides (shot_type, camera, action changes); null -> none.
    overrides = body.get("overrides") or {}
    if not isinstance(overrides, dict):
        return JSONResponse({"error": "overrides must be an object"}, status_code=400)

    shot = store.get_shot(shot_id)
    if shot is None:
        return JSONResponse({"error": f"Shot not found: {shot_id}"}, status_code=404)

    gate = shot.get("gate_results") or {}
    if not gate.get("manual_escalated"):
        return JSONResponse({"error": f"Shot {shot_id} was not escalated for manual intervention"}, status_code=400)

    if gate.get("manual_resolved"):
        return JSONResponse({"ok": True, "shot_id": shot_id, "already_resolved": True})

    # Auto-infer fix type if not explicitly provided
    fix_type = body.get("fix_type")
    if not fix_type:
        if body.get("prompt_edited"):
            fix_type = "prompt_edit"
        elif body.get("model_changed"):
            fix_type = "model_switch"
        elif body.get("source") == "bundle":
            fix_type = "manual_intervention"
        else:
            fix_type = "unknown"

    now = time.time()
    fix_entry = {
        "timestamp": now,
        "failure_type": failure_type,
        "fix_type": fix_type,
        "model_used": shot.get("pipeline") or shot.get("model") or "",
        "enrichment_used": body.get("enrichment_used", False),
        "notes": body.get("notes", ""),
        "auto_tagged": False,
    }
    if overrides:
        fix_entry["overrides"] = overrides
    if body.get("new_prompt"):
        fix_entry["new_prompt"] = body["new_prompt"]
    if body.get("new_model"):
        fix_entry["new_model"] = body["new_model"]

    # Choose the hero from the state that existed BEFORE this request. The
    # resolve entry above is stamped "now" and never has a reimported_path, so
    # letting it take part in the newest-asset comparison would always discard
    # the real hero.
    prior_fixes = [f for f in (gate.get("manual_fixes") or []) if isinstance(f, dict)]
    video_suffixes = {".mp4", ".mov", ".webm", ".avi"}

    def _asset_type_of(entry, path):
        """Recorded asset_type, else "video"/"frame" inferred from the extension."""
        return entry.get("asset_type") or (
            "video" if Path(path).suffix.lower() in video_suffixes else "frame"
        )

    last_reimport = next((f for f in reversed(prior_fixes) if f.get("reimported_path")), None)
    takes = shot.get("takes") or []
    last_take = takes[-1] if takes and takes[-1].get("file_path") else None

    if last_reimport and last_take:
        reimport_newer = (last_reimport.get("timestamp") or 0) > (last_take.get("timestamp") or 0)
        hero_entry = last_reimport if reimport_newer else last_take
    else:
        hero_entry = last_reimport or last_take

    if hero_entry is None:
        hero_path, hero_asset_type = None, "frame"
    else:
        hero_path = hero_entry.get("reimported_path") or hero_entry.get("file_path")
        hero_asset_type = _asset_type_of(hero_entry, hero_path)

    manual_fixes = prior_fixes + [fix_entry]

    # Status transition matrix based on action + asset type (see docstring)
    if hero_asset_type == "video":
        target_status = "video_complete"
    elif action == "return_to_pipeline":
        # Without a hero this was a prompt-only fix: re-generate the keyframe.
        target_status = "keyframe_approved" if hero_path else "keyframe_pending"
    else:  # export_video_bundle with a frame hero (or none) still needs video
        target_status = "video_pending"

    update_gate = {
        "manual_resolved": True,
        "manual_resolved_at": now,
        "manual_fixes": manual_fixes,
    }
    if overrides:
        update_gate["manual_overrides"] = overrides
    if hero_path:
        update_gate["hero_frame"] = hero_path

    shot_updates = {}
    if hero_path:
        shot_updates["output_path"] = hero_path
    if overrides.get("shot_type"):
        shot_updates["shot_type_override"] = overrides["shot_type"]
    if body.get("new_model"):
        shot_updates["pipeline"] = body["new_model"]

    # Two store calls (not a single atomic write). force_reset_status is used
    # for the status change because manual resolve transitions from
    # failed/escalated states which aren't in VALID_TRANSITIONS.
    store.force_reset_status(shot_id, target_status, reason=f"manual resolve: {action} ({failure_type})")
    store.update_shot(shot_id, gate_results=update_gate, **shot_updates)

    return JSONResponse({
        "ok": True,
        "shot_id": shot_id,
        "action": action,
        "status": target_status,
        "hero_path": hero_path,
        "fix_entry": fix_entry,
    })


@router.post("/api/reveal-in-finder")
def reveal_in_finder(
    body: dict = Body(default={}),
    project: str = Depends(get_project),
    paths: dict = Depends(get_paths),
):
    """Reveal a file in macOS Finder. Body: {"path": "output/frames/..."}

    The path is mapped with _resolve_output_path:
      - "output/..." resolves under the project's output dir
      - anything else resolves against the server's working directory

    Status codes:
      400 -- missing or non-string path
      404 -- the file does not exist
      501 -- the server is not running on macOS
      500 -- the ``open`` command cannot be launched
    """
    import sys

    rel_path = body.get("path", "")
    if not rel_path:
        return JSONResponse({"error": "Missing path"}, status_code=400)
    if not isinstance(rel_path, str):
        return JSONResponse({"error": "path must be a string"}, status_code=400)
    abs_path = Path(_resolve_output_path(rel_path, paths))
    if not abs_path.exists():
        return JSONResponse({"error": f"File not found: {rel_path}"}, status_code=404)
    # `open -R` targets macOS Finder. Elsewhere `open` may be missing or an
    # unrelated program, so refuse rather than run it.
    if sys.platform != "darwin":
        return JSONResponse({"error": "Reveal in Finder is only supported on macOS"}, status_code=501)
    try:
        # Fire-and-forget: the request does not wait for the child process.
        subprocess.Popen(["open", "-R", str(abs_path)])
    except OSError as exc:
        return JSONResponse({"error": f"Failed to reveal file: {exc}"}, status_code=500)
    return JSONResponse({"status": "revealed", "path": str(abs_path)})
