"""Pipeline Inspector API — data aggregation for the Inspector React app.

Aggregates plan data, execution state, prompt sections, and routing reasons
into a single JSON response consumed by the Pipeline Inspector frontend.

All functions are importable independently — no module-level imports from
lib/ or orchestrator/ to avoid circular import issues.
"""

import json
import time
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def get_inspector_data(
    episode_num: int,
    project_paths: dict,
    bible: dict,
    project_config: dict,
) -> dict:
    """Aggregate all data needed for the Pipeline Inspector.

    Parameters:
        episode_num:    Episode number (e.g. 1 → ep_001)
        project_paths:  From _paths_for_project() in review_server.py; the
                        keys "plans_dir", "state_dir" and "project" are read
                        here and the whole dict is passed on to the
                        reference-image lookup
        bible:          Loaded global_bible.json dict
        project_config: Loaded pipeline_config.json dict

    Returns a dict with episode, project, loaded_at, shots, scenes and
    cost_summary. If no plan can be loaded, the same keys are returned with
    empty shots/scenes, a zeroed cost summary and an extra "error" message.
    """
    plans_dir = Path(project_paths["plans_dir"])
    state_dir = Path(project_paths["state_dir"])
    project = project_paths["project"]

    # ── Load plan ────────────────────────────────────────────────────
    plan = _load_plan(plans_dir, episode_num)
    if plan is None:
        return {
            "error": f"Plan not found for episode {episode_num}",
            "episode": f"ep_{episode_num:03d}",
            "project": project,
            "loaded_at": _now_iso(),
            "shots": [],
            "scenes": [],
            "cost_summary": _empty_cost_summary(),
        }

    # ── Build shots ──────────────────────────────────────────────────
    shots_out = []
    # Totals only include shots whose cost is > 0; the counts say how many
    # shots contributed to each total.
    cost_actual_total = 0.0
    cost_actual_count = 0
    cost_estimated_total = 0.0
    cost_estimated_count = 0

    for shot in plan.get("shots", []):
        shot_id = shot.get("shot_id", "")

        # Read execution state
        exec_state = _load_execution_state(state_dir, shot_id)

        # Compute prompt sections and refs FIRST (needed for reconstruction)
        prompt_sections = _compute_prompt_sections(shot, bible, project_config, episode_num)
        reference_images = _resolve_reference_images(shot, project_paths)

        # Compute routing reason (lazy import)
        routing_result = _compute_routing(shot)

        # Reconstruct inputs for legacy takes that lack inputs_snapshot.
        # The take dicts inside exec_state are updated in place so the
        # normalization step below picks up the reconstructed snapshot.
        for take in exec_state.get("takes", []):
            if isinstance(take, dict) and take.get("inputs_snapshot") is None:
                try:
                    from recoil.pipeline._lib.take_inputs import reconstruct_inputs
                    take["inputs_snapshot"] = reconstruct_inputs(
                        take, shot, bible, project_config,
                        prompt_sections=prompt_sections,
                        reference_images=reference_images,
                    )
                except Exception:
                    # Best-effort: a take that cannot be reconstructed keeps
                    # inputs_snapshot=None rather than failing the response.
                    pass

        # Normalize takes
        takes = _normalize_takes(exec_state.get("takes", []))

        # Cost computation
        cost_actual = _extract_actual_cost(exec_state)
        cost_estimated = _estimate_cost(shot, project_config)

        if cost_actual > 0:
            cost_actual_total += cost_actual
            cost_actual_count += 1
        if cost_estimated > 0:
            cost_estimated_total += cost_estimated
            cost_estimated_count += 1

        # Extract sub-dicts from plan shot
        prompt_data = shot.get("prompt_data", {})
        routing_data = shot.get("routing_data", {})
        spatial_data = shot.get("spatial_data", {})
        asset_data = shot.get("asset_data", {})
        audio_data = shot.get("audio_data", {})

        # Determine shot label
        scene_idx = shot.get("scene_index", 0)
        shot_idx = shot.get("shot_index")
        if shot_idx is None:
            # Infer only when the plan gives no index at all; an explicit 0
            # is a real value and must not be replaced by a guess.
            shot_idx = _infer_shot_index(shot_id)
        label = shot.get("label") or shot_id

        # Routing info merged: recorded execution state wins over the
        # freshly computed routing result.
        pipeline = exec_state.get("pipeline") or routing_result.get("pipeline")
        model = exec_state.get("model") or routing_result.get("model")
        tier = routing_data.get("camera_complexity") or _infer_tier(shot, project_config)

        # Coverage fields
        is_coverage = shot.get("is_coverage", False)
        coverage_of = shot.get("coverage_of", None)
        has_manual_override = bool(exec_state.get("manual_prompt_override"))

        shots_out.append({
            "shot_id": shot_id,
            "scene_index": scene_idx,
            "shot_index": shot_idx,
            "label": label,
            "shot_type": prompt_data.get("shot_type", "MS"),
            "characters": asset_data.get("characters", []),
            "location_id": asset_data.get("location_id", ""),
            "routing": {
                "pipeline": pipeline,
                "model": model,
                "tier": tier,
                "reason": routing_result.get("reason", "Unknown"),
            },
            "status": exec_state.get("status", "pending"),
            "cost_actual": cost_actual,
            "cost_estimated": cost_estimated,
            "prompt_sections": prompt_sections,
            "prompt_data": prompt_data,
            "routing_data": routing_data,
            "spatial_data": spatial_data,
            "asset_data": asset_data,
            "audio_data": audio_data,
            "reference_images": reference_images,
            "takes": takes,
            "is_coverage": is_coverage,
            "coverage_of": coverage_of,
            "has_manual_override": has_manual_override,
        })

    # ── Build scenes ─────────────────────────────────────────────────
    scenes = _build_scenes(shots_out)

    # ── Cost summary ─────────────────────────────────────────────────
    cost_summary = {
        "actual_total": round(cost_actual_total, 4),
        "actual_shot_count": cost_actual_count,
        "estimated_total": round(cost_estimated_total, 4),
        "estimated_shot_count": cost_estimated_count,
    }

    return {
        "episode": f"ep_{episode_num:03d}",
        "project": project,
        "loaded_at": _now_iso(),
        "shots": shots_out,
        "scenes": scenes,
        "cost_summary": cost_summary,
    }


def _estimate_cost(shot: dict, project_config: dict) -> float:
    """Look up estimated cost from complexity_tiers in project_config.

    Falls back to standard tier if tier not found, or 0.0 if config missing.
    """
    if not project_config:
        return 0.0

    tiers = project_config.get("complexity_tiers", {})
    if not tiers:
        return 0.0

    # Determine tier from routing_data or infer
    routing_data = shot.get("routing_data", {})
    complexity = routing_data.get("camera_complexity", "standard")

    # Normalize tier name
    tier_key = str(complexity).lower().strip()
    tier_info = tiers.get(tier_key) or tiers.get("standard", {})

    return float(tier_info.get("estimated_cost", 0.0))


def get_inspector_notes(project_paths: dict) -> dict:
    """Read inspector_notes.json from state_dir.

    Parameters:
        project_paths:  Dict containing "state_dir", the directory that
                        holds inspector_notes.json

    Returns the parsed file when it contains a JSON object. Returns the
    default structure {"notes": {}, "updated_at": None} when the file is
    missing, unreadable, not valid UTF-8, not valid JSON, or not an object.
    """
    notes_path = Path(project_paths["state_dir"]) / "inspector_notes.json"

    if notes_path.exists():
        try:
            data = json.loads(notes_path.read_text(encoding="utf-8"))
        except (ValueError, OSError):
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised by read_text on non-UTF-8 bytes.
            data = None
        if isinstance(data, dict):
            return data

    return {
        "notes": {},
        "updated_at": None,
    }


def save_inspector_notes(project_paths: dict, notes_data: dict) -> None:
    """Persist notes to inspector_notes.json in state_dir.

    The state directory is created if needed. notes_data is stamped in place
    with a fresh "updated_at" value before being serialized. The JSON is
    written to a temporary file in the same directory and then moved over
    the destination, so readers never see a half-written file. If anything
    fails, the temporary file is removed and the exception is re-raised.
    """
    target_dir = Path(project_paths["state_dir"])
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / "inspector_notes.json"

    notes_data["updated_at"] = _now_iso()

    # Temp file lives next to the destination so the final replace stays
    # within one directory.
    handle, scratch_name = tempfile.mkstemp(
        prefix="inspector_notes_", suffix=".tmp", dir=str(target_dir)
    )
    scratch = Path(scratch_name)
    try:
        with open(handle, "w", encoding="utf-8") as out:
            json.dump(notes_data, out, indent=2)
        scratch.replace(destination)
    except Exception:
        # Do not leave the scratch file behind; the original error wins.
        try:
            scratch.unlink(missing_ok=True)
        except OSError:
            pass
        raise


# ── Internal helpers ─────────────────────────────────────────────────


def _now_iso() -> str:
    """Return current UTC time as ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _empty_cost_summary() -> dict:
    return {
        "actual_total": 0.0,
        "actual_shot_count": 0,
        "estimated_total": 0.0,
        "estimated_shot_count": 0,
    }


def _load_plan(plans_dir: Path, episode_num: int) -> dict | None:
    """Load plan file, trying zero-padded then unpadded names."""
    candidates = [
        plans_dir / f"ep_{episode_num:03d}_plan.json",
        plans_dir / f"ep_{episode_num}_plan.json",
    ]
    for path in candidates:
        if path.exists():
            try:
                return json.loads(path.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, OSError):
                continue
    return None


def _load_execution_state(state_dir: Path, shot_id: str) -> dict:
    """Load execution state for a shot. Returns default dict if missing."""
    shots_dir = state_dir / "shots"
    state_path = shots_dir / f"{shot_id}.json"

    if state_path.exists():
        try:
            data = json.loads(state_path.read_text(encoding="utf-8"))
            if isinstance(data, dict):
                return data
        except (json.JSONDecodeError, OSError):
            pass

    return {
        "status": "pending",
        "pipeline": None,
        "model": None,
        "cost_incurred": 0,
        "takes": [],
    }


def _normalize_takes(takes: list) -> list:
    """Normalize take entries to a consistent structure.

    Handles variations: cost vs cost_usd vs cost_incurred, take_id vs take_number.
    """
    normalized = []
    for take in takes:
        if not isinstance(take, dict):
            continue

        # Normalize take_id
        take_id = (
            take.get("take_id")
            or take.get("take_number")
            or take.get("take_num")
        )
        if take_id is not None:
            take_id = str(take_id)

        # Normalize cost (use 'is not None' to preserve explicit zero values)
        cost = take.get("cost")
        if cost is None:
            cost = take.get("cost_usd")
        if cost is None:
            cost = take.get("cost_incurred")
        if cost is None:
            cost = 0.0

        file_path = take.get("file_path") or take.get("output_path")
        created_at = take.get("created_at") or take.get("timestamp") or take.get("started_at")

        # Infer pipeline from file_path when not explicitly set
        pipeline = take.get("pipeline")
        if not pipeline and file_path:
            fp_lower = str(file_path).lower()
            if "/previs/" in fp_lower or "/frames/" in fp_lower or any(
                fp_lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".webp")
            ):
                pipeline = "previz"
            elif "/video/" in fp_lower or any(
                fp_lower.endswith(ext) for ext in (".mp4", ".mov", ".webm")
            ):
                pipeline = "video"

        # Synthesize status from boolean flags if not set
        raw_status = take.get("status") or take.get("disposition")
        if not raw_status:
            if take.get("approved"):
                raw_status = "approved"
            elif take.get("rejected"):
                raw_status = "rejected"

        # Build thumbnail URL from file_path if it's an image, or video_url if video
        thumbnail = None
        video_url = None
        if file_path:
            fp_str = str(file_path)
            fp_lower = fp_str.lower()
            # Strip leading 'output/' — the server route adds /output/ prefix
            rel_path = fp_str[len("output/"):] if fp_str.startswith("output/") else fp_str
            url = f"/output/{rel_path}" if not fp_str.startswith("/") else fp_str
            if any(fp_lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".webp")):
                thumbnail = url
            elif any(fp_lower.endswith(ext) for ext in (".mp4", ".mov", ".webm")):
                video_url = url

        normalized.append({
            "take_id": take_id,
            "status": raw_status,
            "pipeline": pipeline,
            "model": take.get("model") or "unknown",
            "cost": float(cost),
            "created_at": created_at,
            "output_path": file_path,
            "thumbnail": thumbnail,
            "video_url": video_url,
            "prompt_used": take.get("prompt_used") or take.get("prompt") or take.get("authored_prompt"),
            "gate_results": take.get("gate_1") or take.get("gate_results"),
            "inputs_snapshot": take.get("inputs_snapshot"),
        })
    return normalized


def _compute_prompt_sections(
    shot: dict, bible: dict, project_config: dict, episode_num: int
) -> list:
    """Build prompt sections via prompt_engine. Returns [] on failure."""
    try:
        from recoil.pipeline._lib.prompt_engine import build_prompt_sections_from_plan
        return build_prompt_sections_from_plan(
            shot, bible, project_config, episode=episode_num
        )
    except Exception:
        return []


def _compute_routing(shot: dict) -> dict:
    """Compute routing via scene_planner.route_shot(). Returns fallback on failure."""
    try:
        from orchestrator.scene_planner import route_shot
        result = route_shot(shot)
        if isinstance(result, dict):
            return result
    except Exception:
        pass

    return {
        "pipeline": None,
        "model": None,
        "reason": "Routing unavailable",
    }


def _resolve_reference_images(shot: dict, project_paths: dict) -> list:
    """Check if hero refs exist on disk for characters and locations.

    Returns list of dicts matching frontend ReferenceImage: type, id, url, label.
    """
    refs = []
    asset_data = shot.get("asset_data", {})
    char_refs_dir = Path(project_paths.get("character_refs_dir", ""))
    loc_refs_dir = Path(project_paths.get("location_refs_dir", ""))

    def _find_hero(directory):
        """Find hero image file, trying exact names then glob fallback."""
        if not directory.is_dir():
            return None
        for name in ("hero.png", "hero.jpeg", "hero.jpg", "hero.webp"):
            candidate = directory / name
            if candidate.exists():
                return candidate
        # Glob fallback: look for *hero* or *Hero* files
        for pattern in ("*hero*.*", "*Hero*.*"):
            matches = sorted(directory.glob(pattern))
            img_matches = [m for m in matches if m.suffix.lower() in (".png", ".jpeg", ".jpg", ".webp")]
            if img_matches:
                return img_matches[0]
        return None

    # Character refs — IDs are uppercase in data, lowercase on disk
    for char in asset_data.get("characters", []):
        if isinstance(char, str):
            char_id = char
        else:
            char_id = char.get("char_id") or char.get("id") or ""
        if not char_id:
            continue
        char_lower = char_id.lower()
        hero = _find_hero(char_refs_dir / char_lower)
        if hero:
            refs.append({
                "type": "character",
                "id": char_lower,
                "label": char_id.replace("_", " ").title(),
                "url": f"/assets/char/{char_lower}/base/{hero.name}",
            })

    # Location refs — also lowercase on disk
    location_id = asset_data.get("location_id") or shot.get("location_id", "")
    if location_id:
        loc_lower = location_id.lower()
        hero = _find_hero(loc_refs_dir / loc_lower)
        if hero:
            refs.append({
                "type": "location",
                "id": loc_lower,
                "label": location_id.replace("_", " ").title(),
                "url": f"/assets/loc/{loc_lower}/base/{hero.name}",
            })

    return refs


def _extract_actual_cost(exec_state: dict) -> float:
    """Extract actual cost from execution state, checking multiple fields."""
    cost = exec_state.get("cost_incurred")
    if cost is None:
        cost = exec_state.get("cost")
    if cost is None:
        cost = exec_state.get("cost_usd")
    if cost is None:
        cost = 0.0
    return float(cost)


def _infer_shot_index(shot_id: str) -> int | None:
    """Infer shot index from shot_id like EP001_SH05 → 5."""
    try:
        parts = shot_id.split("_")
        for part in parts:
            if part.startswith("SH"):
                return int(part[2:])
    except (ValueError, IndexError):
        pass
    return None


def _infer_tier(shot: dict, project_config: dict) -> str | None:
    """Infer complexity tier from routing_data or return None."""
    routing_data = shot.get("routing_data", {})
    return routing_data.get("camera_complexity") or None


def _build_scenes(shots: list) -> list:
    """Group shots by scene_index and build scene summaries."""
    scene_map: dict[int, dict] = {}
    for shot in shots:
        si = shot.get("scene_index", 0)
        if si not in scene_map:
            scene_map[si] = {
                "scene_index": si,
                "location_id": shot.get("location_id", ""),
                "shot_count": 0,
            }
        scene_map[si]["shot_count"] += 1
        # Use the first shot's location_id for the scene if not already set
        if not scene_map[si]["location_id"] and shot.get("location_id"):
            scene_map[si]["location_id"] = shot["location_id"]

    return sorted(scene_map.values(), key=lambda s: s["scene_index"])
