"""
production_types.py — Type definitions for the Production Orchestrator.

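Enums and dataclasses for orchestration state: frozen configuration types
(RetryPolicy, BatchConfig, ...) plus mutable runtime records (BatchState,
RetryRequest, ProvenanceRecord). These are the contracts between the production
loop and its subsystems (batch manager, retry dispatcher, autonomy controller,
learning engine, provenance writer).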
"""

import copy
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


# ── Enums ──────────────────────────────────────────────────────────

class BatchStatus(str, Enum):
    """States a production batch moves through.

    Members are ``str`` subclasses, so they compare equal to (and serialize
    as) their lowercase string values.
    """
    CREATED = "created"                    # default status of a fresh BatchState
    RUNNING = "running"
    PAUSED = "paused"
    BUDGET_EXHAUSTED = "budget_exhausted"  # spend limit reached
    COMPLETE = "complete"
    FAILED = "failed"


class FailureCategory(str, Enum):
    """Why a generation attempt failed; selects the retry path to take.

    Per-category retry limits are configured in DEFAULT_RETRY_POLICIES.
    """
    # -- infrastructure ------------------------------------------------------
    TRANSIENT = "transient"                                # 429, 500, 503, timeout — auto-retry with backoff
    # -- quality gates -------------------------------------------------------
    GATE_MECHANICAL = "gate_mechanical"                    # Gate 1 fail — retry with different seed
    GATE_IDENTITY = "gate_identity"                        # Gate 2A fail — retry with stronger refs
    GATE_WARDROBE = "gate_wardrobe"                        # Gate 2A wardrobe fail — check phase, swap ref
    GATE_VIDEO_DRIFT = "gate_video_drift"                  # Gate 3 — flag for review (not auto-reject)
    # -- model / request problems --------------------------------------------
    CONTENT_FILTER = "content_filter"                      # Model refused — needs prompt rewrite
    PROMPT_DURATION_MISMATCH = "prompt_duration_mismatch"  # fal.ai schema validation — retry with corrected duration
    # -- terminal ------------------------------------------------------------
    PERMANENT = "permanent"                                # Exhausted retries or unfixable
    BUDGET = "budget"                                      # Budget exceeded — pause batch


class ShotPhase(str, Enum):
    """Generation stage of a shot, listed in declaration order.

    Values are plain lowercase strings (the enum also subclasses ``str``).
    """
    PREVIS = "previs"      # previsualization pass
    KEYFRAME = "keyframe"  # still-frame pass
    VIDEO = "video"        # motion pass


# ── Configuration ──────────────────────────────────────────────────

@dataclass(frozen=True)
class RetryPolicy:
    """Immutable retry settings for a single FailureCategory.

    Fields (positional order matters for callers that pass them positionally):
      max_retries          -- number of retries allowed; 0 means "do not retry".
      base_backoff_seconds -- starting delay before a retry.
      max_backoff_seconds  -- upper bound on the delay.
      backoff_multiplier   -- delay growth factor between attempts
                              (applied by the consumer; not computed here).
    """

    max_retries: int = 3
    base_backoff_seconds: float = 2.0
    max_backoff_seconds: float = 120.0
    backoff_multiplier: float = 2.0


# Baseline retry behaviour per failure category. BatchConfig copies this
# mapping for each batch, so per-batch overrides do not leak back here.
DEFAULT_RETRY_POLICIES: dict[FailureCategory, RetryPolicy] = {
    # Infrastructure errors: retry with backoff, capped at two minutes.
    FailureCategory.TRANSIENT: RetryPolicy(
        max_retries=5,
        base_backoff_seconds=5.0,
        max_backoff_seconds=120.0,
    ),
    # Gate failures: retry immediately (no backoff) with adjusted inputs.
    FailureCategory.GATE_MECHANICAL: RetryPolicy(
        max_retries=3,
        base_backoff_seconds=0.0,
    ),
    FailureCategory.GATE_IDENTITY: RetryPolicy(
        max_retries=2,
        base_backoff_seconds=0.0,
    ),
    FailureCategory.GATE_WARDROBE: RetryPolicy(
        max_retries=2,
        base_backoff_seconds=0.0,
    ),
    # Drift and refusals are not retried automatically: review / human fix.
    FailureCategory.GATE_VIDEO_DRIFT: RetryPolicy(max_retries=0),
    FailureCategory.CONTENT_FILTER: RetryPolicy(max_retries=0),
    # Schema mismatch on duration: immediate retry with corrected value.
    FailureCategory.PROMPT_DURATION_MISMATCH: RetryPolicy(
        max_retries=2,
        base_backoff_seconds=0.0,
    ),
    # Terminal categories never retry.
    FailureCategory.PERMANENT: RetryPolicy(max_retries=0),
    FailureCategory.BUDGET: RetryPolicy(max_retries=0),
}


@dataclass(frozen=True)
class RetryCostPolicy:
    """Spending cap for retries of a single pass (carried by BatchConfig).

    The StrategyEngine compares a strategy's estimated cost plus the retry
    spend already accumulated on the pass against ``max_retry_spend_usd``.
    If the cap would be exceeded, the strategy is skipped and the pass is
    escalated, so one stubborn shot cannot consume the nightly budget.

    ``warn_threshold_usd`` is presumably a soft warning level below the hard
    cap; its consumer is not in this file.
    """

    max_retry_spend_usd: float = 6.00   # hard per-pass retry ceiling (USD)
    warn_threshold_usd: float = 4.00    # soft threshold (USD)


@dataclass(frozen=True)
class AutonomyConfig:
    """Rules deciding whether a take is auto-approved or flagged for a human.

    Autonomy is off by default (``enabled=False``).
    """

    enabled: bool = False
    require_all_gates_pass: bool = True
    require_gate_3: bool = False
    # Shot types that are never auto-approved.
    exclude_shot_types: tuple[str, ...] = ("two_shot", "complex")
    # Per-batch cap on auto-approvals; 0 means no cap once enabled.
    max_auto_approve_per_batch: int = 0


@dataclass(frozen=True)
class BatchConfig:
    """Settings for one production batch.

    ``project`` and ``episode_id`` are required; every other field has a
    default. Spend is bounded by ``budget_usd`` unless ``delta_budget_usd`` is
    set (see its field comment below).

    NOTE(review): ``retry_policies`` is a dict and ``pass_filter`` may be a
    list, so the frozen guarantee is only shallow: those containers can be
    mutated in place. Also, the dataclass-generated ``__hash__`` hashes every
    field, so ``hash()`` on an instance raises TypeError.
    """

    project: str
    episode_id: str
    budget_usd: float = 25.0
    max_concurrent: int = 1            # Phase 1: shots run one at a time
    poll_interval_seconds: float = 5.0
    autonomy: AutonomyConfig = field(default_factory=AutonomyConfig)
    # Each batch gets its own shallow copy of the module-level defaults.
    retry_policies: dict[FailureCategory, RetryPolicy] = field(
        default_factory=lambda: DEFAULT_RETRY_POLICIES.copy()
    )
    pass_filter: Optional[list[str]] = None  # when set, run only these pass IDs
    max_attempts_per_shot: int = 5
    # When not None, ``budget_usd`` is ignored: the batch counts as exhausted
    # once spend *added during this run* (total_cost - start_cost) exceeds
    # this amount. Meant for one-off pass tests on episodes that already
    # carry accumulated cost.
    delta_budget_usd: Optional[float] = None
    retry_cost_policy: RetryCostPolicy = field(default_factory=RetryCostPolicy)


# ── Runtime State ──────────────────────────────────────────────────

@dataclass
class BatchState:
    """Mutable runtime state for a production batch.

    Tracks status, cost totals and per-shot counters while the batch runs,
    together with the ``BatchConfig`` it was started with (``None`` until one
    is attached).
    """
    batch_id: str = ""
    # Optional: a bare BatchState has no config yet; to_dict() then reports
    # empty strings for project/episode_id.
    config: Optional[BatchConfig] = None
    status: BatchStatus = BatchStatus.CREATED
    started_at: float = 0.0
    completed_at: float = 0.0
    total_cost: float = 0.0
    # Snapshot of total_cost at the moment the batch started. Used with
    # BatchConfig.delta_budget_usd to bound *additional* spend within this run.
    start_cost: float = 0.0
    shots_completed: int = 0
    shots_failed: int = 0
    shots_pending: int = 0
    shots_in_review: int = 0
    auto_approved: int = 0
    error_message: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a flat, JSON-friendly snapshot of this state.

        ``project`` and ``episode_id`` come from the config ("" when no config
        is attached), ``status`` is emitted as its string value, and both cost
        fields are rounded to 4 decimal places.
        """
        cfg = self.config
        return {
            "batch_id": self.batch_id,
            "project": cfg.project if cfg is not None else "",
            "episode_id": cfg.episode_id if cfg is not None else "",
            "status": self.status.value,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "total_cost": round(self.total_cost, 4),
            "start_cost": round(self.start_cost, 4),
            "shots_completed": self.shots_completed,
            "shots_failed": self.shots_failed,
            "shots_pending": self.shots_pending,
            "shots_in_review": self.shots_in_review,
            "auto_approved": self.auto_approved,
            "error_message": self.error_message,
        }


@dataclass
class RetryRequest:
    """A scheduled retry: the shot, why it failed, and when it may run again.

    ``retry_at`` is an absolute ``time.time()`` timestamp; ``ready`` becomes
    True once the wall clock reaches it.
    """
    shot_id: str
    failure_category: FailureCategory
    attempt_number: int
    retry_at: float                        # earliest time.time() at which to retry
    fix_suggestion: Optional[dict] = None  # From FeedbackAgent
    error_message: Optional[str] = None
    original_model: Optional[str] = None

    @property
    def ready(self) -> bool:
        """True when the backoff has elapsed (the boundary counts as ready)."""
        now = time.time()
        return self.retry_at <= now


@dataclass
class ProvenanceRecord:
    """Full reproduction recipe for a single take.

    Captures identifiers, generation parameters, optional video settings,
    gate results, cost, human review and lineage. ``to_dict`` produces the
    nested form used for serialization.
    """
    take_id: str           # e.g. "EP001_SH012_T3"
    shot_id: str           # e.g. "EP001_SH012"
    episode_id: str
    project: str
    attempt: int
    timestamp: float = field(default_factory=time.time)

    # Generation parameters
    phase: str = ""         # "previs", "keyframe", "video"
    model: str = ""
    endpoint: str = ""
    prompt: str = ""
    negative_prompt: str = ""
    seed: Optional[int] = None
    params: dict = field(default_factory=dict)
    refs_used: list[dict] = field(default_factory=list)

    # Video-specific
    video_model: Optional[str] = None
    video_prompt: Optional[str] = None
    video_duration: Optional[int] = None
    start_frame_path: Optional[str] = None

    # Gate results
    gates: dict[str, dict] = field(default_factory=dict)

    # Cost
    cost: dict[str, float] = field(default_factory=dict)

    # Human review
    human_review: Optional[dict] = None

    # Lineage
    parent_take: Optional[str] = None
    change_reason: Optional[str] = None
    prompt_diff: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize the record into a nested dict.

        Always present: identifiers, ``timestamp``, ``generation``, ``gates``
        and ``cost``. Optional sections appear only when their key field is
        truthy: ``video`` (``video_model``), ``human_review`` (non-empty
        review) and ``lineage`` (``parent_take``).

        The result is a deep copy, like ``dataclasses.asdict``, so callers
        may redact or enrich it without mutating this record's ``params``,
        ``refs_used``, ``gates``, ``cost`` or ``human_review``.
        NOTE(review): this assumes those containers hold plain data
        (dicts/lists/str/numbers); deepcopy raises on uncopyable objects.
        """
        d = {
            "take_id": self.take_id,
            "shot_id": self.shot_id,
            "episode_id": self.episode_id,
            "project": self.project,
            "attempt": self.attempt,
            "timestamp": self.timestamp,
            "generation": {
                "phase": self.phase,
                "model": self.model,
                "endpoint": self.endpoint,
                "prompt": self.prompt,
                "negative_prompt": self.negative_prompt,
                "seed": self.seed,
                "params": self.params,
                "refs_used": self.refs_used,
            },
            "gates": self.gates,
            "cost": self.cost,
        }
        if self.video_model:
            d["video"] = {
                "model": self.video_model,
                "prompt": self.video_prompt,
                "duration": self.video_duration,
                "start_frame": self.start_frame_path,
            }
        if self.human_review:
            d["human_review"] = self.human_review
        if self.parent_take:
            d["lineage"] = {
                "parent_take": self.parent_take,
                "change_reason": self.change_reason,
                "prompt_diff": self.prompt_diff,
            }
        # Detach from the record's own containers so mutations of the
        # returned dict never leak back into the provenance record.
        return copy.deepcopy(d)
