"""
cost_tracker.py — API cost logging and budget tracking.

Tracks per-call costs for one episode, aggregates them by model,
complexity tier, pass type and media type, and writes the cost log
to the per-episode prep directory (``prep/ep_NNN/`` under the v2 layout).
"""

import json
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional

from recoil.core.paths import ProjectPaths


@dataclass
class GenerationRecord:
    """Cost and metadata captured for one generation API call.

    Attributes:
        timestamp: Time the call was recorded (``CostTracker.record``
            stores ``time.time()`` here).
        episode: Episode number the call belongs to.
        shot_id: Numeric identifier of the shot.
        shot_name: Name of the shot.
        model: Identifier of the model that served the call.
        aspect_ratio: Requested aspect ratio, e.g. ``"9:16"``.
        image_size: Requested output size, e.g. ``"4K"``.
        cost: Cost of the call.
        tier: Complexity tier: ``"simple"``, ``"standard"`` or
            ``"complex"``.
        pass_type: Kind of pass: ``"grid"``, ``"pro"``, ``"expression"``,
            ``"video_gen"`` or ``"prompt_enrichment"``.
        media_type: ``"image"`` or ``"video"``.
        grid_type: Optional grid variant label (presumably only set for
            grid passes; confirm against callers).
        success: Whether the call succeeded.
        error: Error description for a failed call, if any.
        output_path: Location of the generated output, if any.
    """

    # Required fields: order matters, the constructor is used positionally.
    timestamp: float
    episode: int
    shot_id: int
    shot_name: str
    model: str
    aspect_ratio: str
    image_size: str
    cost: float
    tier: str
    pass_type: str

    # Optional fields with defaults.
    media_type: str = "image"
    grid_type: Optional[str] = None
    success: bool = True
    error: Optional[str] = None
    output_path: Optional[str] = None


class CostTracker:
    """Track generation costs across an episode or session.

    Every API call is stored as a ``GenerationRecord`` in ``self.records``.
    Aggregates (by model, tier, pass type, media type) only count
    successful calls; failed calls are kept in the records and counted
    by ``failed_calls``. ``save_log`` writes everything to a ``log.json``
    and ``load_log`` reads such a file back to resume tracking.
    """

    def __init__(self, episode: int, project: Optional[str] = None):
        """Create an empty tracker for *episode*.

        Args:
            episode: Episode number stamped on every record.
            project: Project identifier handed to
                ``ProjectPaths.for_project``; ``None`` is passed through
                unchanged.
        """
        self.episode = episode
        self.records: list[GenerationRecord] = []
        # v2 layout: prep/ep_NNN/ (was output/frames/ep_NNN/)
        self._output_dir = ProjectPaths.for_project(project).episode_prep_dir(
            episode
        )

    def record(
        self,
        shot_id: int,
        shot_name: str,
        model: str,
        aspect_ratio: str,
        image_size: str,
        cost: float,
        tier: str,
        pass_type: str,
        media_type: str = "image",
        grid_type: Optional[str] = None,
        success: bool = True,
        error: Optional[str] = None,
        output_path: Optional[str] = None,
    ) -> "GenerationRecord":
        """Record a generation call.

        The record is timestamped with ``time.time()``, tagged with this
        tracker's episode, appended to ``self.records`` and returned.
        """
        rec = GenerationRecord(
            timestamp=time.time(),
            episode=self.episode,
            shot_id=shot_id,
            shot_name=shot_name,
            model=model,
            aspect_ratio=aspect_ratio,
            image_size=image_size,
            cost=cost,
            tier=tier,
            pass_type=pass_type,
            media_type=media_type,
            grid_type=grid_type,
            success=success,
            error=error,
            output_path=output_path,
        )
        self.records.append(rec)
        return rec

    @property
    def total_cost(self) -> float:
        """Total cost of all successful calls."""
        return sum(r.cost for r in self.records if r.success)

    @property
    def total_calls(self) -> int:
        """Total number of API calls, successful or not."""
        return len(self.records)

    @property
    def failed_calls(self) -> int:
        """Number of failed calls."""
        return sum(1 for r in self.records if not r.success)

    def _cost_by(self, attr: str) -> dict[str, float]:
        """Sum the cost of successful calls, grouped by record attribute *attr*."""
        totals: dict[str, float] = {}
        for rec in self.records:
            if rec.success:
                key = getattr(rec, attr)
                totals[key] = totals.get(key, 0.0) + rec.cost
        return totals

    def cost_by_model(self) -> dict[str, float]:
        """Aggregate cost of successful calls by model."""
        return self._cost_by("model")

    def cost_by_tier(self) -> dict[str, float]:
        """Aggregate cost of successful calls by complexity tier."""
        return self._cost_by("tier")

    def cost_by_pass(self) -> dict[str, float]:
        """Aggregate cost of successful calls by pass type.

        Pass types include grid, pro, expression, video_gen and
        prompt_enrichment.
        """
        return self._cost_by("pass_type")

    def cost_by_media_type(self) -> dict[str, float]:
        """Aggregate cost of successful calls by media type (image, video)."""
        return self._cost_by("media_type")

    def summary(self) -> str:
        """Human-readable cost summary.

        Returns:
            A multi-line string with the totals followed by one section
            per non-empty aggregate, entries sorted by key.
        """
        lines = [
            f"=== Cost Summary — EP{self.episode:03d} ===",
            f"Total calls: {self.total_calls} ({self.failed_calls} failed)",
            f"Total cost:  ${self.total_cost:.3f}",
        ]

        # (title, aggregate, minimum number of entries needed to print it)
        sections = (
            ("By model", self.cost_by_model(), 1),
            ("By tier", self.cost_by_tier(), 1),
            ("By pass", self.cost_by_pass(), 1),
            # The media breakdown is shown only when several types are present.
            ("By media type", self.cost_by_media_type(), 2),
        )
        for title, costs, min_entries in sections:
            if len(costs) >= min_entries:
                lines.append(f"\n{title}:")
                lines.extend(
                    f"  {key}: ${cost:.3f}" for key, cost in sorted(costs.items())
                )

        return "\n".join(lines)

    def save_log(self, path: Optional[Path] = None) -> Path:
        """Save cost log to JSON file.

        The log holds the totals, every aggregate that ``summary`` reports
        (rounded to 4 decimals) and the full list of records.

        Args:
            path: Override output path (``Path`` or ``str``).
                Default: prep/ep_{NNN}/log.json

        Returns:
            Path to the written log file.
        """
        # Coerce so callers may pass a plain string path.
        path = self._output_dir / "log.json" if path is None else Path(path)

        path.parent.mkdir(parents=True, exist_ok=True)

        def rounded(costs: dict[str, float]) -> dict[str, float]:
            return {key: round(value, 4) for key, value in costs.items()}

        log = {
            "episode": self.episode,
            "total_cost": round(self.total_cost, 4),
            "total_calls": self.total_calls,
            "failed_calls": self.failed_calls,
            "cost_by_model": rounded(self.cost_by_model()),
            "cost_by_tier": rounded(self.cost_by_tier()),
            "cost_by_pass": rounded(self.cost_by_pass()),
            "cost_by_media_type": rounded(self.cost_by_media_type()),
            "records": [asdict(r) for r in self.records],
        }

        path.write_text(json.dumps(log, indent=2), encoding="utf-8")
        return path

    def load_log(self, path: Optional[Path] = None) -> None:
        """Load existing log to resume tracking.

        Records from the file are appended to ``self.records``; only the
        ``"records"`` list of the log is read. A missing file is not an
        error and leaves the tracker untouched.

        Args:
            path: Override input path (``Path`` or ``str``).
                Default: prep/ep_{NNN}/log.json

        Raises:
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If a stored record does not match the
                ``GenerationRecord`` fields.
        """
        # Coerce so callers may pass a plain string path.
        path = self._output_dir / "log.json" if path is None else Path(path)

        if not path.exists():
            return

        data = json.loads(path.read_text(encoding="utf-8"))
        # Build all records first so a malformed entry cannot leave the
        # tracker with a partially loaded log.
        loaded = [GenerationRecord(**rec_data) for rec_data in data.get("records", [])]
        self.records.extend(loaded)


if __name__ == "__main__":
    # Demo: feed a handful of simulated generation calls through a tracker,
    # print the summary, then write the log to its default location.
    tracker = CostTracker(episode=1)

    pro_model = "gemini-3-pro-image-preview"
    flash_model = "gemini-3.1-flash-image-preview"

    # Each row: (shot_id, shot_name, model, aspect_ratio, image_size, cost,
    #            tier, pass_type, extra keyword arguments)
    simulated_calls = [
        (1, "corridor_dolly", pro_model, "9:16", "4K", 0.134, "simple", "pro", {}),
        (
            2, "jinx_wedges_hook", flash_model, "1:1", "4K", 0.039,
            "standard", "grid", {"grid_type": "directors_take"},
        ),
        (2, "jinx_wedges_hook", pro_model, "9:16", "4K", 0.134, "standard", "pro", {}),
        (
            3, "rebreather_fog", flash_model, "1:1", "4K", 0.039,
            "complex", "grid", {"grid_type": "scene_coverage"},
        ),
        (3, "rebreather_fog", pro_model, "9:16", "4K", 0.134, "complex", "pro", {}),
        # A failed call: kept in the records but excluded from cost totals.
        (
            3, "rebreather_fog", pro_model, "9:16", "4K", 0.134,
            "complex", "pro", {"success": False, "error": "Rate limited"},
        ),
    ]
    for *positional, extras in simulated_calls:
        tracker.record(*positional, **extras)

    print(tracker.summary())
    print()

    # Save to the tracker's default path (log.json in the episode prep dir).
    log_path = tracker.save_log()
    print(f"Log saved: {log_path}")
