"""
cost_tracker.py — API cost logging and budget tracking.

Tracks per-call costs for an episode, aggregates them by model, tier,
pass type, and media type, and writes a JSON cost log to the output directory.
"""

import json
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional

from lib.constants import STARSEND_ROOT, project_output_dir


@dataclass
class GenerationRecord:
    """A single API call record.

    One instance describes one generation call: when it was recorded, which
    shot and model it was for, what it cost, and whether it succeeded.
    ``CostTracker.record`` fills ``timestamp`` from ``time.time()`` and
    ``episode`` from the tracker; every other field is supplied by the caller.

    Records are written to the log with ``dataclasses.asdict`` and rebuilt
    from the saved dicts via keyword arguments, so the field names double as
    the JSON keys of each entry in the log's ``records`` list.
    """
    timestamp: float            # time.time() value captured when the call was recorded
    episode: int                # episode number of the tracker that created the record
    shot_id: int
    shot_name: str
    model: str                  # grouping key for CostTracker.cost_by_model()
    aspect_ratio: str           # e.g. "9:16", "1:1"
    image_size: str             # e.g. "4K"
    cost: float                 # counted in totals only when success is True; shown with a "$" prefix
    tier: str                   # simple, standard, complex
    pass_type: str              # "grid", "pro", "expression", "video_gen", "prompt_enrichment"
    media_type: str = "image"   # "image" or "video"
    grid_type: Optional[str] = None    # e.g. "directors_take", "scene_coverage"; None when not given
    success: bool = True        # False marks a failed call (excluded from cost totals)
    error: Optional[str] = None        # error text for a failed call, if any
    output_path: Optional[str] = None  # where the generated output was written, if known


class CostTracker:
    """Track generation costs across an episode or session.

    Every API call is stored as a ``GenerationRecord`` in ``self.records``.
    Cost totals and per-key aggregates count successful calls only; failed
    calls are still kept and reported through ``failed_calls``.
    ``save_log`` writes the records plus the aggregates to a JSON file and
    ``load_log`` reads the records back so tracking can be resumed.
    """

    def __init__(self, episode: int, project: Optional[str] = None):
        """Create an empty tracker for one episode.

        Args:
            episode: Episode number; used in the summary header and in the
                default log directory name (``ep_{NNN}``).
            project: Passed through to ``project_output_dir`` to pick the
                output root. ``None`` lets that helper choose its default.
        """
        self.episode = episode
        self.records: list[GenerationRecord] = []
        self._output_dir = project_output_dir(project) / "frames" / f"ep_{episode:03d}"

    def record(self, shot_id: int, shot_name: str, model: str,
               aspect_ratio: str, image_size: str, cost: float,
               tier: str, pass_type: str, media_type: str = "image",
               grid_type: Optional[str] = None,
               success: bool = True, error: Optional[str] = None,
               output_path: Optional[str] = None) -> "GenerationRecord":
        """Record a generation call.

        The record is stamped with the current ``time.time()`` and this
        tracker's episode, appended to ``self.records``, and returned.

        Args:
            shot_id: Numeric id of the shot the call was made for.
            shot_name: Name of that shot.
            model: Model identifier used for the call.
            aspect_ratio: Requested aspect ratio (e.g. ``"9:16"``).
            image_size: Requested output size (e.g. ``"4K"``).
            cost: Cost of the call.
            tier: Complexity tier (simple, standard, complex).
            pass_type: Pass the call belongs to (e.g. ``"grid"``, ``"pro"``).
            media_type: ``"image"`` or ``"video"``.
            grid_type: Optional grid variant for grid passes.
            success: ``False`` for a failed call; failed calls are excluded
                from all cost totals.
            error: Optional error text for a failed call.
            output_path: Optional path of the generated output.

        Returns:
            The newly created record.
        """
        rec = GenerationRecord(
            timestamp=time.time(),
            episode=self.episode,
            shot_id=shot_id,
            shot_name=shot_name,
            model=model,
            aspect_ratio=aspect_ratio,
            image_size=image_size,
            cost=cost,
            tier=tier,
            pass_type=pass_type,
            media_type=media_type,
            grid_type=grid_type,
            success=success,
            error=error,
            output_path=output_path,
        )
        self.records.append(rec)
        return rec

    @property
    def total_cost(self) -> float:
        """Total cost of all successful calls."""
        return sum(r.cost for r in self.records if r.success)

    @property
    def total_calls(self) -> int:
        """Total number of API calls, successful or not."""
        return len(self.records)

    @property
    def failed_calls(self) -> int:
        """Number of failed calls."""
        return sum(1 for r in self.records if not r.success)

    def _successful_cost_by(self, attr: str) -> dict[str, float]:
        """Sum the cost of successful records, grouped by record attribute ``attr``."""
        totals: dict[str, float] = {}
        for rec in self.records:
            if rec.success:
                key = getattr(rec, attr)
                totals[key] = totals.get(key, 0.0) + rec.cost
        return totals

    @staticmethod
    def _rounded(totals: dict[str, float]) -> dict[str, float]:
        """Return a copy of ``totals`` with every value rounded to 4 decimals."""
        return {key: round(value, 4) for key, value in totals.items()}

    def cost_by_model(self) -> dict[str, float]:
        """Aggregate cost of successful calls by model."""
        return self._successful_cost_by("model")

    def cost_by_tier(self) -> dict[str, float]:
        """Aggregate cost of successful calls by complexity tier."""
        return self._successful_cost_by("tier")

    def cost_by_pass(self) -> dict[str, float]:
        """Aggregate cost of successful calls by pass type (grid, pro, expression, ...)."""
        return self._successful_cost_by("pass_type")

    def cost_by_media_type(self) -> dict[str, float]:
        """Aggregate cost of successful calls by media type (image, video)."""
        return self._successful_cost_by("media_type")

    def summary(self) -> str:
        """Human-readable cost summary.

        Returns:
            A multi-line string with the call counts, the total cost, and one
            section per aggregate. The media-type section is shown only when
            more than one media type has been charged, since with a single
            type it would merely repeat the total.
        """
        lines = [
            f"=== Cost Summary — EP{self.episode:03d} ===",
            f"Total calls: {self.total_calls} ({self.failed_calls} failed)",
            f"Total cost:  ${self.total_cost:.3f}",
        ]

        # (section title, aggregate, minimum number of entries needed to show it)
        sections = [
            ("By model", self.cost_by_model(), 1),
            ("By tier", self.cost_by_tier(), 1),
            ("By pass", self.cost_by_pass(), 1),
            ("By media type", self.cost_by_media_type(), 2),
        ]
        for title, totals, min_entries in sections:
            if len(totals) < min_entries:
                continue
            lines.append(f"\n{title}:")
            lines.extend(f"  {key}: ${cost:.3f}" for key, cost in sorted(totals.items()))

        return "\n".join(lines)

    def save_log(self, path: Optional[Path] = None) -> Path:
        """Save cost log to JSON file.

        The log holds the episode number, the totals, every per-key aggregate
        (model, tier, pass, media type; each rounded to 4 decimals), and the
        full list of records. Missing parent directories are created.

        Args:
            path: Override output path (``Path`` or ``str``). Default:
                ``<project output dir>/frames/ep_{NNN}/log.json``.

        Returns:
            Path to the written log file.
        """
        # Coerce so that a plain string path works as well as a Path.
        path = self._output_dir / "log.json" if path is None else Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        log = {
            "episode": self.episode,
            "total_cost": round(self.total_cost, 4),
            "total_calls": self.total_calls,
            "failed_calls": self.failed_calls,
            "cost_by_model": self._rounded(self.cost_by_model()),
            "cost_by_tier": self._rounded(self.cost_by_tier()),
            "cost_by_pass": self._rounded(self.cost_by_pass()),
            "cost_by_media_type": self._rounded(self.cost_by_media_type()),
            "records": [asdict(r) for r in self.records],
        }

        path.write_text(json.dumps(log, indent=2), encoding="utf-8")
        return path

    def load_log(self, path: Optional[Path] = None) -> None:
        """Load existing log to resume tracking.

        Records from the file are appended to whatever the tracker already
        holds, so calling this twice on the same file duplicates them. A
        missing file is not an error: the tracker is simply left unchanged.

        Args:
            path: Override input path (``Path`` or ``str``). Default:
                ``<project output dir>/frames/ep_{NNN}/log.json``.

        Raises:
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If a stored record does not match the
                ``GenerationRecord`` fields.
        """
        path = self._output_dir / "log.json" if path is None else Path(path)

        if not path.exists():
            return

        data = json.loads(path.read_text(encoding="utf-8"))
        # Build every record before touching self.records so that a malformed
        # entry cannot leave the tracker half-loaded.
        loaded = [GenerationRecord(**rec_data) for rec_data in data.get("records", [])]
        self.records.extend(loaded)


if __name__ == "__main__":
    # Smoke demo: feed the tracker a handful of simulated generation calls,
    # print the summary, then write the log to the default location.
    pro_model = "gemini-3-pro-image-preview"
    flash_model = "gemini-3.1-flash-image-preview"

    # Each entry is (positional args, keyword args) for CostTracker.record.
    simulated_calls = [
        ((1, "corridor_dolly", pro_model, "9:16", "4K", 0.134, "simple", "pro"),
         {}),
        ((2, "jinx_wedges_hook", flash_model, "1:1", "4K", 0.039, "standard", "grid"),
         {"grid_type": "directors_take"}),
        ((2, "jinx_wedges_hook", pro_model, "9:16", "4K", 0.134, "standard", "pro"),
         {}),
        ((3, "rebreather_fog", flash_model, "1:1", "4K", 0.039, "complex", "grid"),
         {"grid_type": "scene_coverage"}),
        ((3, "rebreather_fog", pro_model, "9:16", "4K", 0.134, "complex", "pro"),
         {}),
        # One failed call, to show that failures are counted but not charged.
        ((3, "rebreather_fog", pro_model, "9:16", "4K", 0.134, "complex", "pro"),
         {"success": False, "error": "Rate limited"}),
    ]

    demo = CostTracker(episode=1)
    for positional, keyword in simulated_calls:
        demo.record(*positional, **keyword)

    print(demo.summary())
    print()

    # Persist the log and report where it went.
    saved_to = demo.save_log()
    print(f"Log saved: {saved_to}")
