#!/usr/bin/env python3
"""
Production Cost Tracker.

Logs every API call across the pipeline — generation, training, QC, analysis,
reference images, upscaling — with cost estimates from pricing_rates.json.
Aggregates by episode, category, provider, and date for production budgeting.

Usage as module:
    from cost_tracker import CostTracker
    tracker = CostTracker("leviathan/")

    # Generation (convenience method)
    tracker.log_generation(episode=1, shot_id=3, stage="previz", ...)

    # Any API call (general method)
    tracker.log(category="training", provider="fal", model="training_z_image",
                cost_override=1.70, detail="Jinx Z-Image LoRA, 2000 steps")

    tracker.summary()

Usage as CLI:
    python3 cost_tracker.py <project_path> summary
    python3 cost_tracker.py <project_path> summary --episode 1
    python3 cost_tracker.py <project_path> summary --category training
    python3 cost_tracker.py <project_path> summary --since 2026-02-01
    python3 cost_tracker.py <project_path> summary --json
    python3 cost_tracker.py <project_path> failures [--type T] [--episode N]
    python3 cost_tracker.py <project_path> rates
    python3 cost_tracker.py <project_path> reset [--episode N] [--category C]
    python3 cost_tracker.py <project_path> series-estimate
    python3 cost_tracker.py <project_path> series-estimate --take-ratio 5
    python3 cost_tracker.py <project_path> series-estimate --t2i-model flux2_lora --json
    python3 cost_tracker.py <project_path> taxonomy
    python3 cost_tracker.py <project_path> categories

Exit codes:
  0 = success
  2 = rate card missing/unreadable, or invalid command-line arguments
"""

# ╔════════════════════════════════════════════════════════════════════╗
# ║ DEPRECATED — Superseded by Starsend equivalents (Feb 2026).      ║
# ║ Kept alive for Recoil agent protocols + referencing scripts.     ║
# ║ Do NOT delete until agents/breakdown_agent.md, storyboard_agent, ║
# ║ engine_checks/structural.py, and batch_threepass.py are updated. ║
# ╚════════════════════════════════════════════════════════════════════╝

import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path


# ── Failure Taxonomy ──

# Known failure classes for generation attempts, each mapped to a short
# remediation hint. log_generation() normalizes any failure_type that is not
# a key here to "other"; the `taxonomy` CLI command prints this table.
FAILURE_TYPES = {
    "deformation": "Hands/face/body artifacts — adjust negative prompt, quality guard",
    "identity_drift": "Character doesn't match LoRA — check weight, trigger word",
    "composition_miss": "Framing doesn't match storyboard — revise prompt, try different seed",
    "lighting_miss": "Lighting doesn't match direction — add lighting layer override",
    "motion_failure": "WAN video produces bad motion — switch approach, use held_frame",
    "wardrobe_miss": "Wrong clothes for wardrobe phase — check breakdown phase resolution",
    "text_artifacts": "Unwanted text/watermarks in frame — adjust negative prompt",
    "color_drift": "Colors don't match palette — check HEX values in prompt",
    "other": "Uncategorized failure",  # fallback target for unknown failure types
}

# ── Categories ──

# Cost categories with human-readable descriptions; printed by the
# `categories` CLI command and used as the `category` value of log entries.
# NOTE(review): CostTracker.log() does not validate `category` against these
# keys — any string is accepted and will appear in summaries as-is.
CATEGORIES = {
    "generation": "T2I frame generation (previz, keyframe, production)",
    "video": "Video generation (WAN I2V)",
    "upscale": "Image upscaling (Gemini)",
    "training": "LoRA training (character, video)",
    "qc": "Visual QC and gate checks (Gemini/Claude vision)",
    "analysis": "Script analysis (script doctor, scene analysis)",
    "reference": "Reference image generation (Gemini)",
    "voice": "Voice generation (ElevenLabs TTS)",
    "nbp_derive": "NBP frame derivation from hero image (Gemini)",
}


# ── Pricing Loader ──

# Engine root: two directories above this file (after resolving symlinks).
# load_rates() also prepends it to sys.path before importing the schema
# validator.
_ENGINE_DIR = Path(__file__).resolve().parent.parent
# pricing_rates.json lives in config/, not at the engine root. An earlier
# version pointed at the engine root, which made load_rates() silently
# return ({}, "unknown") for every call.
_RATES_PATH = _ENGINE_DIR / "config" / "pricing_rates.json"


def load_rates():
    """Load the most recent rate card from pricing_rates.json.

    The file is validated through ``recoil.core.config_schema`` and the LAST
    entry of its ``rate_cards`` list is treated as the current card. No
    caching is done here: the file is re-read on every call.

    Returns:
        (card, effective_date) tuple, where ``effective_date`` is the card's
        ``effective_date`` field or "unknown" if absent.

    Returns ({}, "unknown") on missing or schema-invalid data (including a
    payload or card of the wrong JSON type), preserving the legacy
    fail-soft contract this DEPRECATED Recoil-only script needs.
    """
    if not _RATES_PATH.exists():
        return {}, "unknown"

    try:
        # Local import — keep cost_tracker importable even when the engine
        # package isn't on sys.path (this is a stand-alone CLI script).
        # Uses the module-level `sys`; no need to re-import it here.
        if str(_ENGINE_DIR) not in sys.path:
            sys.path.insert(0, str(_ENGINE_DIR))
        from recoil.core.config_schema import validate_and_load
        data = validate_and_load(_RATES_PATH, "pricing_rates")
    except Exception:
        # Deliberate fail-soft: any import/validation failure degrades to
        # "no rates" rather than crashing the caller's pipeline.
        return {}, "unknown"

    # Shape checks keep the fail-soft promise instead of raising
    # AttributeError/KeyError on unexpected JSON types.
    cards = data.get("rate_cards") if isinstance(data, dict) else None
    if not isinstance(cards, list) or not cards:
        return {}, "unknown"

    # Most recent card (last in list)
    card = cards[-1]
    if not isinstance(card, dict):
        return {}, "unknown"
    return card, card.get("effective_date", "unknown")


def estimate_cost(provider, model, resolution=None, loras=0,
                  steps=None, tokens_in=0, tokens_out=0,
                  images_out=0, duration_sec=0):
    """Estimate cost from pricing_rates.json.

    Looks up ``card[provider][model]`` in the current rate card (preferring
    ``f"{model}_lora"`` when ``loras > 0`` and that key exists) and prices
    the call by the entry's ``unit``. Unknown providers/models, unsupported
    units, and malformed rate entries price at 0.0 instead of raising, so
    cost logging cannot break the calling pipeline.

    Args:
        provider: "fal", "gemini", or "anthropic"
        model: Model key matching pricing_rates.json
        resolution: "WxH" string (for per-megapixel models); the "x"
            separator is matched case-insensitively. Missing or unparseable
            values fall back to ~0.59 MP (576x1024).
        loras: Number of LoRAs (selects _lora variant if available)
        steps: Training steps (for per-1k-step models; defaults to 1000)
        tokens_in: Input tokens (for per-token models)
        tokens_out: Output tokens (for per-token models)
        images_out: Number of images generated (for per-image models)
        duration_sec: Output duration in seconds (for per-second models;
            billed as at least 1 second, None treated as 0)

    Returns:
        (cost, rate_date) tuple
    """
    card, rate_date = load_rates()
    provider_rates = card.get(provider, {})

    # Try LoRA variant first
    key = model
    if loras > 0 and f"{model}_lora" in provider_rates:
        key = f"{model}_lora"

    rate_info = provider_rates.get(key, provider_rates.get(model))
    # Guard against missing entries and non-dict entries (e.g. a bare number).
    if not rate_info or not isinstance(rate_info, dict):
        return 0.0, rate_date

    unit = rate_info.get("unit", "")
    # .get so a malformed entry without "rate" prices at 0 instead of
    # raising KeyError out of CostTracker.log().
    rate = rate_info.get("rate", 0)

    # Per-megapixel (fal.ai T2I)
    if unit == "megapixel":
        mp = 0.590  # default (576x1024)
        if resolution:
            try:
                # str() tolerates non-str input; lower() accepts "768X1344".
                w, h = str(resolution).lower().split("x")
                mp = int(w) * int(h) / 1_000_000
            except ValueError:
                pass
        return round(rate * mp, 6), rate_date

    # Per clip (flat rate)
    if unit == "clip":
        return round(rate, 4), rate_date

    # Per second (video generation)
    if unit == "second":
        return round(rate * max(duration_sec or 0, 1), 4), rate_date

    # Per 1K steps (training)
    if unit == "1k_steps":
        k_steps = (steps or 1000) / 1000
        return round(rate * k_steps, 4), rate_date

    # Per 1M tokens (Gemini, Anthropic)
    if unit == "1M_tokens":
        input_cost = (tokens_in / 1_000_000) * rate_info.get("input", 0)
        output_cost = (tokens_out / 1_000_000) * rate_info.get("output", 0)
        image_cost = images_out * rate_info.get("image_output", 0)
        return round(input_cost + output_cost + image_cost, 6), rate_date

    # Units not priced here (e.g. 1k_chars, song, hour, month) cost 0.0.
    return 0.0, rate_date


class CostTracker:
    """Track all API costs for a project.

    Entries are kept in memory and persisted as JSON to
    ``<project_path>/visual/cost_log.json``. Every mutating call (``log``,
    ``log_generation``, ``reset``) rewrites the whole file.
    """

    def __init__(self, project_path):
        self.project_path = Path(project_path)
        self.log_path = self.project_path / "visual" / "cost_log.json"
        self._entries = self._load()

    def _load(self):
        """Return the persisted entry list, or [] if absent or unreadable.

        Fail-soft: a corrupt/unreadable file, or one whose JSON has an
        unexpected shape, yields an empty log instead of raising.
        NOTE(review): the next _save() then overwrites that file, so its
        previous contents are lost — consider backing it up first.
        """
        if not self.log_path.is_file():
            return []
        try:
            with open(self.log_path) as f:
                data = json.load(f)
        except (ValueError, OSError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return []
        entries = data.get("entries", []) if isinstance(data, dict) else []
        return entries if isinstance(entries, list) else []

    def _save(self):
        """Persist all entries, replacing the log file in one step.

        The JSON is written to a sibling ``.tmp`` file and swapped in with
        ``Path.replace`` (an atomic rename on POSIX), so an interrupted or
        failed write cannot leave a truncated cost_log.json behind — which
        ``_load`` would read as empty, wiping the history on the next save.
        """
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "project": self.project_path.name,
            "updated": datetime.now(timezone.utc).isoformat(),
            "total_entries": len(self._entries),
            "entries": self._entries,
        }
        tmp_path = self.log_path.with_name(self.log_path.name + ".tmp")
        try:
            with open(tmp_path, "w") as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(self.log_path)
        except BaseException:
            # Keep the previous log intact; discard the partial temp file.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    # ── General logging method ──

    def log(
        self,
        category,
        provider,
        model,
        success=True,
        cost_override=None,
        duration_ms=None,
        detail=None,
        episode=None,
        shot_id=None,
        resolution=None,
        loras=0,
        steps=None,
        tokens_in=0,
        tokens_out=0,
        images_out=0,
        duration_sec=0,
        metadata=None,
    ):
        """Log any API call, persist the log, and return its estimated cost.

        Args:
            category: One of CATEGORIES keys (not validated)
            provider: "fal", "gemini", "anthropic"
            model: Model key matching pricing_rates.json
            success: Whether the call succeeded
            cost_override: Exact cost if known (skips estimation)
            duration_ms: Wall-clock time of the API call
            detail: Free-text description
            episode: Episode number (if applicable)
            shot_id: Shot ID (if applicable)
            resolution: "WxH" (for generation)
            loras: LoRA count (for generation)
            steps: Inference or training steps
            tokens_in: Input tokens (for LLM calls)
            tokens_out: Output tokens (for LLM calls)
            images_out: Images generated (for Gemini image gen)
            duration_sec: Output duration in seconds (for video)
            metadata: Additional key-value pairs to store

        Returns:
            The cost recorded for this entry (override or estimate).
        """
        if cost_override is not None:
            estimated_cost = round(cost_override, 6)
            rate_date = "override"
        else:
            estimated_cost, rate_date = estimate_cost(
                provider, model, resolution=resolution, loras=loras,
                steps=steps, tokens_in=tokens_in, tokens_out=tokens_out,
                images_out=images_out, duration_sec=duration_sec,
            )

        entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "category": category,
            "provider": provider,
            "model": model,
            "success": success,
            "estimated_cost": estimated_cost,
            "rate_date": rate_date,
        }

        # Optional fields are stored only when provided/non-empty.
        if duration_ms is not None:
            entry["duration_ms"] = duration_ms
        if detail:
            entry["detail"] = detail
        if episode is not None:
            entry["episode"] = episode
        if shot_id is not None:
            entry["shot_id"] = shot_id
        if resolution:
            entry["resolution"] = resolution
        if loras:
            entry["loras"] = loras
        if steps:
            entry["steps"] = steps
        if tokens_in:
            entry["tokens_in"] = tokens_in
        if tokens_out:
            entry["tokens_out"] = tokens_out
        if images_out:
            entry["images_out"] = images_out
        # Recorded like every other pricing input, so per-second video
        # costs can be audited from the log.
        if duration_sec:
            entry["duration_sec"] = duration_sec
        if metadata:
            entry["metadata"] = metadata

        self._entries.append(entry)
        self._save()
        return estimated_cost

    # ── Convenience: generation logging (backward compatible) ──

    def log_generation(
        self,
        episode,
        shot_id,
        stage,
        model,
        success,
        attempt=1,
        steps=None,
        resolution=None,
        loras=0,
        generation_approach=None,
        motion_complexity=None,
        failure_type=None,
        failure_detail=None,
        duration_ms=None,
        seed=None,
    ):
        """Log a generation attempt (backward-compatible convenience method).

        ``stage`` selects the category: "video" → video, "upscale" → upscale
        (provider gemini), anything else → generation (provider fal). On
        failure, ``failure_type`` is normalized to a FAILURE_TYPES key
        ("other" if unknown) and stored in metadata.

        Returns:
            The estimated cost returned by ``log``.
        """
        meta = {}
        if attempt > 1:
            meta["attempt"] = attempt
        if generation_approach:
            meta["generation_approach"] = generation_approach
        if motion_complexity:
            meta["motion_complexity"] = motion_complexity
        if seed is not None:
            meta["seed"] = seed
        if not success and failure_type:
            if failure_type not in FAILURE_TYPES:
                failure_type = "other"
            meta["failure_type"] = failure_type
            if failure_detail:
                meta["failure_detail"] = failure_detail

        cat = "video" if stage == "video" else "upscale" if stage == "upscale" else "generation"

        return self.log(
            category=cat,
            provider="fal" if stage != "upscale" else "gemini",
            model=model,
            success=success,
            duration_ms=duration_ms,
            episode=episode,
            shot_id=shot_id,
            resolution=resolution,
            loras=loras,
            steps=steps,
            detail=f"stage={stage}",
            metadata=meta if meta else None,
        )

    # ── Summary ──

    def summary(self, episode=None, category=None, since=None):
        """Get cost summary with optional filters.

        Args:
            episode: Filter by episode number
            category: Filter by category (generation, training, qc, etc.)
            since: Filter by date string "YYYY-MM-DD" (compared
                lexicographically against the UTC ISO timestamps)

        Returns:
            Dict with total_cost, total_calls, success_rate and per-category /
            per-provider breakdowns; by_episode, by_date and failure_taxonomy
            are included when non-empty. When no entries match, returns zeros
            plus a "message" key.
        """
        entries = self._entries

        if episode is not None:
            entries = [e for e in entries if e.get("episode") == episode]
        if category:
            entries = [e for e in entries if e.get("category") == category]
        if since:
            entries = [e for e in entries if e.get("timestamp", "") >= since]

        if not entries:
            # success_rate included so callers can rely on the same keys as
            # the non-empty result (the CLI previously crashed here).
            return {
                "total_cost": 0,
                "total_calls": 0,
                "success_rate": 0,
                "message": "No data",
            }

        total_cost = sum(e.get("estimated_cost", 0) for e in entries)
        total_calls = len(entries)
        successes = sum(1 for e in entries if e.get("success"))
        success_rate = round(successes / total_calls * 100, 1)

        # By category
        by_category = {}
        for e in entries:
            cat = e.get("category", "?")
            if cat not in by_category:
                by_category[cat] = {"cost": 0, "count": 0}
            by_category[cat]["cost"] += e.get("estimated_cost", 0)
            by_category[cat]["count"] += 1

        # By provider
        by_provider = {}
        for e in entries:
            prov = e.get("provider", "?")
            if prov not in by_provider:
                by_provider[prov] = {"cost": 0, "count": 0}
            by_provider[prov]["cost"] += e.get("estimated_cost", 0)
            by_provider[prov]["count"] += 1

        # By episode (only entries that have episodes)
        by_episode = {}
        for e in entries:
            ep = e.get("episode")
            if ep is not None:
                if ep not in by_episode:
                    by_episode[ep] = {"cost": 0, "count": 0, "failures": 0}
                by_episode[ep]["cost"] += e.get("estimated_cost", 0)
                by_episode[ep]["count"] += 1
                if not e.get("success"):
                    by_episode[ep]["failures"] += 1

        # By date (YYYY-MM-DD prefix of the UTC timestamp)
        by_date = {}
        for e in entries:
            day = e.get("timestamp", "")[:10]
            if day:
                if day not in by_date:
                    by_date[day] = {"cost": 0, "count": 0}
                by_date[day]["cost"] += e.get("estimated_cost", 0)
                by_date[day]["count"] += 1

        # Failure taxonomy: metadata.failure_type, else the entry's detail
        # text, else "unclassified".
        failure_counts = {}
        for e in entries:
            if not e.get("success"):
                meta = e.get("metadata", {}) or {}
                ft = meta.get("failure_type", e.get("detail", "unclassified"))
                failure_counts[ft] = failure_counts.get(ft, 0) + 1

        result = {
            "total_cost": round(total_cost, 2),
            "total_calls": total_calls,
            "success_rate": success_rate,
            "by_category": {
                k: {**v, "cost": round(v["cost"], 4)}
                for k, v in sorted(by_category.items())
            },
            "by_provider": {
                k: {**v, "cost": round(v["cost"], 4)}
                for k, v in sorted(by_provider.items())
            },
        }

        if by_episode:
            result["by_episode"] = {
                k: {**v, "cost": round(v["cost"], 4)}
                for k, v in sorted(by_episode.items())
            }
        if by_date:
            result["by_date"] = {
                k: {**v, "cost": round(v["cost"], 4)}
                for k, v in sorted(by_date.items())
            }
        if failure_counts:
            result["failure_taxonomy"] = dict(
                sorted(failure_counts.items(), key=lambda x: -x[1])
            )

        return result

    def get_failures(self, failure_type=None, episode=None):
        """Get failed entries, optionally filtered by metadata failure_type
        and/or episode."""
        entries = [e for e in self._entries if not e.get("success")]
        if failure_type:
            entries = [
                e for e in entries
                if (e.get("metadata", {}) or {}).get("failure_type") == failure_type
            ]
        if episode is not None:
            entries = [e for e in entries if e.get("episode") == episode]
        return entries

    # ── Budget methods ──

    def total(self, since=None):
        """Return total estimated spend across all entries.

        Args:
            since: Optional date string "YYYY-MM-DD" to filter.
        """
        entries = self._entries
        if since:
            entries = [e for e in entries if e.get("timestamp", "") >= since]
        return round(sum(e.get("estimated_cost", 0) for e in entries), 4)

    def check_budget(self, cap_usd, margin_usd=1.0):
        """Check if spend is under budget.

        Args:
            cap_usd: Budget cap in USD.
            margin_usd: Safety margin — flag as over-budget this far before cap.

        Returns:
            (ok, remaining, total_spent) tuple.
            ok is False when total_spent + margin >= cap.
        """
        spent = self.total()
        remaining = round(cap_usd - spent, 4)
        ok = remaining > margin_usd
        return ok, remaining, spent

    def reset(self, episode=None, category=None):
        """Clear cost log entries, optionally filtered, and persist.

        With no filters every entry is removed. When ``episode`` and/or
        ``category`` are given, only entries matching ALL supplied filters
        are removed.
        """
        if episode is None and not category:
            self._entries = []
        else:
            # Previously `category` was ignored whenever `episode` was set,
            # deleting every category for that episode.
            def _matches(e):
                if episode is not None and e.get("episode") != episode:
                    return False
                if category and e.get("category") != category:
                    return False
                return True

            self._entries = [e for e in self._entries if not _matches(e)]
        self._save()


# ── Series Estimate ──

# Default assumptions for a 60-episode series (from Leviathan ep 1 storyboard).
# series_estimate() copies this dict and lets callers override only
# take_ratio, t2i_model, video_model and episodes; every other key is a fixed
# assumption.
SERIES_DEFAULTS = {
    "episodes": 60,
    # shots_per_episode drives the previz frame count; it should equal
    # standard + triptych + held shots (25 + 5 + 1), which drive keyframes,
    # clips and upscales.
    "shots_per_episode": 31,
    "standard_shots": 25,
    "triptych_shots": 5,
    "held_shots": 1,
    "keyframes_per_standard": 2,
    "keyframes_per_triptych": 1,   # 1 strip = 3 panels
    "keyframes_per_held": 1,
    "clips_per_standard": 1,
    "clips_per_triptych": 2,
    "clips_per_held": 0,
    # Generated takes per kept take; multiplies keyframe, clip and voice counts.
    "take_ratio": 10,
    # Resolutions
    "standard_resolution": "768x1344",   # 1.032 MP
    "triptych_resolution": "2048x1216",  # 2.49 MP
    "previz_resolution": "512x896",      # 0.459 MP
    # Models (keys into pricing_rates.json)
    "t2i_model": "z_image_turbo_lora",   # looked up under the "fal" card
    "video_model": "wan_2.2_i2v",        # looked up under the "fal" card
    "upscale_model": "gemini-2.0-flash", # its image_output rate, "gemini" card
    # Characters
    "num_characters": 5,
    "lora_steps_z_image": 2000,
    "lora_steps_flux": 1000,
    "lora_steps_wan": 1000,
    "ref_candidates_per_char": 50,
    # Voice
    "words_per_episode": 200,
    "chars_per_word": 5,
    "voice_rate_per_1k_chars": 0.30,
    # Subscriptions (months of production); also scales the Suno music cost.
    "production_months": 3,
}


def series_estimate(take_ratio=None, t2i_model=None, video_model=None,
                    episodes=None, as_json=False):
    """Calculate projected cost for a full series from pricing_rates.json.

    Starts from SERIES_DEFAULTS and applies every override that is not None.

    Args:
        take_ratio: Generated takes per kept take (keyframes, clips, voice).
        t2i_model: "fal" rate-card key used for previz and keyframes.
        video_model: "fal" rate-card key used for video clips.
        episodes: Number of episodes in the series.
        as_json: Accepted for backward compatibility; unused (callers
            serialize the returned dict themselves).

    Returns:
        Dict of rounded USD line items plus ``total``, ``per_episode``, the
        ``*_generated`` counts and a ``_config`` echo of the inputs — or
        ``{"error": ...}`` when no rate card can be loaded.
    """
    card, rate_date = load_rates()
    if not card:
        return {"error": "No rate card found"}

    d = dict(SERIES_DEFAULTS)
    overrides = {
        "take_ratio": take_ratio,
        "t2i_model": t2i_model,
        "video_model": video_model,
        "episodes": episodes,
    }
    d.update({k: v for k, v in overrides.items() if v is not None})

    ep = d["episodes"]
    ratio = d["take_ratio"]

    # ── Clip counts (winners per episode, then all generated takes) ──
    clips_per_ep = (d["standard_shots"] * d["clips_per_standard"]
                    + d["triptych_shots"] * d["clips_per_triptych"]
                    + d["held_shots"] * d["clips_per_held"])
    total_clips_gen = clips_per_ep * ep * ratio

    # ── Resolve per-megapixel costs ──
    def mp(res):
        w, h = res.split("x")
        return int(w) * int(h) / 1_000_000

    std_mp = mp(d["standard_resolution"])
    tri_mp = mp(d["triptych_resolution"])
    pvz_mp = mp(d["previz_resolution"])

    fal_rates = card.get("fal", {})
    t2i_info = fal_rates.get(d["t2i_model"], {})
    t2i_rate = t2i_info.get("rate", 0)
    t2i_per_mp = t2i_info.get("unit") == "megapixel"
    vid_info = fal_rates.get(d["video_model"], {})
    vid_rate = vid_info.get("rate", 0)
    vid_unit = vid_info.get("unit", "clip")

    def frame_cost(megapixels):
        # Per-megapixel models scale with resolution; any other unit is
        # treated as a flat per-image rate.
        return t2i_rate * megapixels if t2i_per_mp else t2i_rate

    kf_std_cost = frame_cost(std_mp)
    kf_tri_cost = frame_cost(tri_mp)
    kf_pvz_cost = frame_cost(pvz_mp)

    # Video cost per clip. Per-second models need a clip length; the 5 s
    # fallback is an assumption — NOTE(review): confirm against real output
    # length, or add "clip_duration_sec" to SERIES_DEFAULTS.
    clip_sec = None
    if vid_unit == "second":
        clip_sec = d.get("clip_duration_sec", 5)
        clip_cost = vid_rate * clip_sec
    else:
        clip_cost = vid_rate  # flat rate for "clip" unit

    # ── Line items ──
    items = {}

    # Pre-production
    ref_cost = 50  # reference images (Gemini + fal mix)
    script_doctor = 20
    # LoRA reference candidates are standard-resolution T2I frames; use the
    # unit-aware frame cost rather than assuming per-megapixel pricing.
    lora_candidates = d["num_characters"] * d["ref_candidates_per_char"] * kf_std_cost
    lora_z = d["num_characters"] * (d["lora_steps_z_image"] / 1000) * fal_rates.get("training_z_image", {}).get("rate", 0.85)
    lora_flux = d["num_characters"] * (d["lora_steps_flux"] / 1000) * fal_rates.get("training_flux", {}).get("rate", 3.00)
    lora_wan = d["num_characters"] * (d["lora_steps_wan"] / 1000) * fal_rates.get("training_wan", {}).get("rate", 2.00)
    preproduction = ref_cost + script_doctor + lora_candidates + lora_z + lora_flux + lora_wan
    items["pre_production"] = round(preproduction, 2)

    # Previz
    previz_frames = d["shots_per_episode"] * ep
    previz_cost = previz_frames * kf_pvz_cost
    previz_qc = previz_cost * 2  # ~2-3x regens
    items["previz"] = round(previz_cost + previz_qc, 2)

    # Keyframes
    std_kf_count = d["standard_shots"] * ep * d["keyframes_per_standard"] * ratio
    tri_kf_count = d["triptych_shots"] * ep * d["keyframes_per_triptych"] * ratio
    held_kf_count = d["held_shots"] * ep * d["keyframes_per_held"] * ratio
    keyframe_cost = (std_kf_count * kf_std_cost
                     + tri_kf_count * kf_tri_cost
                     + held_kf_count * kf_std_cost)
    items["keyframes"] = round(keyframe_cost, 2)
    items["keyframes_generated"] = std_kf_count + tri_kf_count + held_kf_count

    # Upscale (winners only — triptych strips expand to 3 panels each)
    gemini_rates = card.get("gemini", {})
    upscale_rate = gemini_rates.get(d["upscale_model"], {}).get("image_output", 0.039)
    upscale_frames = ((d["standard_shots"] * d["keyframes_per_standard"]
                       + d["triptych_shots"] * 3  # 3 panels per strip
                       + d["held_shots"] * d["keyframes_per_held"]) * ep)
    upscale_cost = upscale_frames * upscale_rate
    items["upscale"] = round(upscale_cost, 2)

    # Video
    video_cost = total_clips_gen * clip_cost
    items["video"] = round(video_cost, 2)
    items["video_clips_generated"] = total_clips_gen

    # Voice
    chars_per_ep = d["words_per_episode"] * d["chars_per_word"]
    total_chars = chars_per_ep * ep * ratio
    voice_cost = (total_chars / 1000) * d["voice_rate_per_1k_chars"]
    items["voice"] = round(voice_cost, 2)

    # Music (Suno Premier plan — $30/mo, 2000 songs)
    suno_rates = card.get("subscriptions", {}).get("suno_premier", {})
    music_cost = suno_rates.get("rate", 30) * d["production_months"]
    items["music"] = round(music_cost, 2)

    # Sound design (fixed estimate)
    items["sound_design"] = 60

    # Subscriptions
    subs = card.get("subscriptions", {})
    sub_cost = 0
    sub_cost += subs.get("claude_max_20x", {}).get("rate", 200) * d["production_months"]
    sub_cost += subs.get("gemini_ai_pro", {}).get("rate", 20) * d["production_months"]
    sub_cost += subs.get("midjourney_std", {}).get("rate", 30) * 2  # 2 months
    items["subscriptions"] = round(sub_cost, 2)

    # Sum USD line items only; the *_generated entries are counts.
    total = sum(v for k, v in items.items()
                if isinstance(v, (int, float)) and not k.endswith("_generated"))
    items["total"] = round(total, 2)
    # Guard: episodes <= 0 would raise ZeroDivisionError (or go negative).
    items["per_episode"] = round(total / ep, 2) if ep > 0 else 0.0

    # Metadata
    config = {
        "episodes": d["episodes"],
        "take_ratio": d["take_ratio"],
        "t2i_model": d["t2i_model"],
        "video_model": d["video_model"],
        "rate_date": rate_date,
    }
    if clip_sec is not None:
        # Surface the assumed clip length whenever it affected the estimate.
        config["clip_duration_sec"] = clip_sec
    items["_config"] = config

    return items


# ── CLI ──

def main():
    """Command-line entry point (see the module docstring for usage).

    Exits with status 2 when no rate card can be loaded for the ``rates``
    and ``series-estimate`` commands, or (via argparse) on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Production cost tracker")
    parser.add_argument("project", help="Path to project folder")

    subparsers = parser.add_subparsers(dest="command", required=True)

    # summary
    sum_p = subparsers.add_parser("summary", help="Show cost summary")
    sum_p.add_argument("--episode", type=int, help="Filter by episode")
    sum_p.add_argument("--category", help="Filter by category")
    sum_p.add_argument("--since", help="Filter since date (YYYY-MM-DD)")
    sum_p.add_argument("--json", action="store_true", help="JSON output")

    # failures
    fail_p = subparsers.add_parser("failures", help="Show failure log")
    fail_p.add_argument("--type", help="Filter by failure type")
    fail_p.add_argument("--episode", type=int, help="Filter by episode")

    # rates
    subparsers.add_parser("rates", help="Show current rate card")

    # reset
    reset_p = subparsers.add_parser("reset", help="Clear cost log")
    reset_p.add_argument("--episode", type=int, help="Clear only one episode")
    reset_p.add_argument("--category", help="Clear only one category")

    # series-estimate — None defaults defer to SERIES_DEFAULTS, so the CLI
    # and series_estimate() cannot drift apart.
    est_p = subparsers.add_parser("series-estimate", help="Project full series cost")
    est_p.add_argument(
        "--episodes", type=int, default=None,
        help=f"Number of episodes (default {SERIES_DEFAULTS['episodes']})",
    )
    est_p.add_argument(
        "--take-ratio", type=int, default=None,
        help=f"Takes per usable take (default {SERIES_DEFAULTS['take_ratio']})",
    )
    est_p.add_argument(
        "--t2i-model", default=None,
        help=f"T2I model key (default {SERIES_DEFAULTS['t2i_model']})",
    )
    est_p.add_argument(
        "--video-model", default=None,
        help=f"Video model key (default {SERIES_DEFAULTS['video_model']})",
    )
    est_p.add_argument("--json", action="store_true", help="JSON output")

    # taxonomy
    subparsers.add_parser("taxonomy", help="Show failure type definitions")

    # categories
    subparsers.add_parser("categories", help="Show cost categories")

    args = parser.parse_args()

    if args.command == "series-estimate":
        # Non-positive counts make the projection meaningless, and
        # episodes=0 would divide by zero in the per-episode figure.
        if args.episodes is not None and args.episodes <= 0:
            parser.error("--episodes must be a positive integer")
        if args.take_ratio is not None and args.take_ratio <= 0:
            parser.error("--take-ratio must be a positive integer")

    tracker = CostTracker(args.project)

    if args.command == "summary":
        s = tracker.summary(
            episode=args.episode, category=args.category, since=args.since
        )
        if args.json:
            print(json.dumps(s, indent=2))
        else:
            print(f"=== Cost Summary: {Path(args.project).resolve().name} ===")
            filters = []
            # `is not None` so that episode 0 is still reported as a filter.
            if args.episode is not None:
                filters.append(f"Episode {args.episode}")
            if args.category:
                filters.append(f"Category: {args.category}")
            if args.since:
                filters.append(f"Since: {args.since}")
            if filters:
                print(f"  Filters: {', '.join(filters)}")
            print()
            print(f"  Total cost:  ${s['total_cost']:.2f}")
            print(f"  API calls:   {s['total_calls']}")
            # .get: summary() of an empty/filtered-out log may omit
            # success_rate, which used to crash this line with KeyError.
            print(f"  Success:     {s.get('success_rate', 0)}%")
            if s.get("message"):
                print(f"  {s['message']}")

            if s.get("by_category"):
                print()
                print("  By Category:")
                for cat, data in s["by_category"].items():
                    print(f"    {cat}: ${data['cost']:.4f} ({data['count']} calls)")

            if s.get("by_provider"):
                print()
                print("  By Provider:")
                for prov, data in s["by_provider"].items():
                    print(f"    {prov}: ${data['cost']:.4f} ({data['count']} calls)")

            if s.get("by_episode"):
                print()
                print("  By Episode:")
                for ep, data in s["by_episode"].items():
                    fail_str = f" ({data['failures']} failures)" if data["failures"] else ""
                    print(f"    Ep {ep}: ${data['cost']:.4f} ({data['count']} calls){fail_str}")

            if s.get("by_date"):
                print()
                print("  By Date:")
                for day, data in s["by_date"].items():
                    print(f"    {day}: ${data['cost']:.4f} ({data['count']} calls)")

            if s.get("failure_taxonomy"):
                print()
                print("  Failure Types:")
                for ft, count in s["failure_taxonomy"].items():
                    print(f"    {ft}: {count}")

    elif args.command == "failures":
        failures = tracker.get_failures(
            failure_type=args.type, episode=args.episode
        )
        if not failures:
            print("No failures found.")
        else:
            print(f"=== Failures ({len(failures)}) ===")
            for f in failures:
                ts = f.get("timestamp", "?")[:19].replace("T", " ")
                cat = f.get("category", "?")
                model = f.get("model", "?")
                detail = f.get("detail", "")
                ep = f.get("episode")
                shot = f.get("shot_id")
                loc = ""
                if ep is not None:
                    loc += f"Ep{ep} "
                if shot is not None:
                    loc += f"Shot#{shot} "
                print(f"  {ts} [{cat}] {loc}{model}")
                if detail:
                    print(f"    {detail}")

    elif args.command == "series-estimate":
        est = series_estimate(
            take_ratio=args.take_ratio,
            t2i_model=args.t2i_model,
            video_model=args.video_model,
            episodes=args.episodes,
        )
        if est.get("error"):
            print(f"ERROR: {est['error']}", file=sys.stderr)
            sys.exit(2)
        if args.json:
            print(json.dumps(est, indent=2))
        else:
            cfg = est.get("_config", {})
            print("=== Series Cost Estimate ===")
            print(f"  Episodes: {cfg.get('episodes', 60)}  |  Take ratio: {cfg.get('take_ratio', 10)}x")
            print(f"  T2I model: {cfg.get('t2i_model')}  |  Video model: {cfg.get('video_model')}")
            print(f"  Rate card: {cfg.get('rate_date')}")
            print()
            print(f"  Pre-production:    ${est.get('pre_production', 0):>10,.2f}")
            print(f"  Previz:            ${est.get('previz', 0):>10,.2f}")
            print(f"  Keyframes:         ${est.get('keyframes', 0):>10,.2f}  ({est.get('keyframes_generated', 0):,} images)")
            print(f"  Upscale:           ${est.get('upscale', 0):>10,.2f}")
            print(f"  Video:             ${est.get('video', 0):>10,.2f}  ({est.get('video_clips_generated', 0):,} clips)")
            print(f"  Voice:             ${est.get('voice', 0):>10,.2f}")
            print(f"  Music:             ${est.get('music', 0):>10,.2f}")
            print(f"  Sound design:      ${est.get('sound_design', 0):>10,.2f}")
            print(f"  Subscriptions:     ${est.get('subscriptions', 0):>10,.2f}")
            print(f"  {'─' * 40}")
            print(f"  TOTAL:             ${est.get('total', 0):>10,.2f}")
            print(f"  Per episode:       ${est.get('per_episode', 0):>10,.2f}")

    elif args.command == "rates":
        card, rate_date = load_rates()
        if not card:
            print("ERROR: No rate card found. Check pricing_rates.json", file=sys.stderr)
            sys.exit(2)
        print(f"=== Rate Card (effective {rate_date}) ===")
        for provider in ("fal", "gemini", "anthropic", "elevenlabs", "suno", "kling", "runpod", "subscriptions"):
            rates = card.get(provider, {})
            if rates:
                print(f"\n  {provider}:")
                for model, info in rates.items():
                    unit = info.get("unit", "?")
                    # .get so an entry missing "rate" prints "?" instead of
                    # aborting the whole listing with KeyError.
                    rate = info.get("rate", "?")
                    if unit == "megapixel":
                        print(f"    {model}: ${rate}/MP")
                    elif unit == "clip":
                        print(f"    {model}: ${rate}/clip")
                    elif unit == "second":
                        print(f"    {model}: ${rate}/sec")
                    elif unit == "1k_steps":
                        print(f"    {model}: ${rate}/1K steps")
                    elif unit == "1M_tokens":
                        parts = [f"${info.get('input', 0)}/1M in", f"${info.get('output', 0)}/1M out"]
                        if info.get("image_output"):
                            parts.append(f"${info['image_output']}/image")
                        print(f"    {model}: {', '.join(parts)}")
                    elif unit == "1k_chars":
                        print(f"    {model}: ${rate}/1K chars")
                    elif unit == "song":
                        print(f"    {model}: ${rate}/song")
                    elif unit == "hour":
                        print(f"    {model}: ${rate}/hr")
                    elif unit == "month":
                        extra = ""
                        if info.get("chars_included"):
                            extra = f" ({info['chars_included']:,} chars)"
                        elif info.get("songs_included"):
                            extra = f" ({info['songs_included']:,} songs)"
                        print(f"    {model}: ${rate}/mo{extra}")
                    else:
                        print(f"    {model}: ${rate}/{unit}")

    elif args.command == "reset":
        tracker.reset(episode=args.episode, category=args.category)
        scope = []
        if args.episode is not None:
            scope.append(f"episode {args.episode}")
        if args.category:
            scope.append(f"category {args.category}")
        print(f"Cost log cleared ({', '.join(scope) if scope else 'all'}).")

    elif args.command == "taxonomy":
        print("=== Failure Taxonomy ===")
        print()
        for ft, desc in FAILURE_TYPES.items():
            print(f"  {ft}")
            print(f"    {desc}")
        print()

    elif args.command == "categories":
        print("=== Cost Categories ===")
        print()
        for cat, desc in CATEGORIES.items():
            print(f"  {cat}")
            print(f"    {desc}")
        print()


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 unless main() itself
    # calls sys.exit() with another code.
    raise SystemExit(main())
