"""Phase 3 data models for run.shot / run.episode orchestration.

CoveragePassContext: sibling-aware shot grouping for coverage passes.
StopOnReview: 3-level enum controlling abort behavior on review queue entries.
OpResult: terminal result from run.shot (NEVER raises -- D6).
EpisodeResult: aggregate result from run.episode.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any


class StopOnReview(str, Enum):
    """Abort policy run.episode applies when a shot enters the review queue.

    Members (string value in parentheses):
      NEVER ("never")                 -- overnight default; keep running and
                                         triage the queue in the morning.
      ON_HARD_FAIL ("on_hard_fail")   -- abort on ICU escalation, crash, or
                                         style-anchor failure.
      ON_ANY_REVIEW ("on_any_review") -- abort on the first shot of any kind
                                         that enters the review queue.
    """

    # Keep going regardless of what lands in the review queue.
    NEVER = "never"
    # Stop only on hard failures (ICU escalation / crash / style anchor).
    ON_HARD_FAIL = "on_hard_fail"
    # Stop on any review-queue entry at all.
    ON_ANY_REVIEW = "on_any_review"


# Terminal statuses for run.shot (7 values, no partial_success).
TERMINAL_STATUSES = frozenset({
    "ok",
    "budget_exhausted_success",
    "needs_review",
    "budget_exhausted",
    "attempts_exhausted",
    "icu_escalated",
    "crashed",
})

# Statuses that are skipped on --resume (already finished or need manual action).
SKIP_ON_RESUME = frozenset({
    "ok",
    "budget_exhausted_success",
    "needs_review",
    "attempts_exhausted",
    "icu_escalated",
})

# Statuses that are retried on --resume.
RETRY_ON_RESUME = frozenset({
    "budget_exhausted",
    "crashed",
})


def _check_resume_partition(
    terminal: frozenset[str],
    skip: frozenset[str],
    retry: frozenset[str],
) -> None:
    """Raise RuntimeError unless *skip* and *retry* exactly partition *terminal*.

    The three status sets above are maintained by hand. Without this check, a
    status added to TERMINAL_STATUSES but to neither resume set (or to both)
    would have no well-defined --resume behavior.

    Raises:
        RuntimeError: if the sets overlap, leave a terminal status
            unclassified, or mention a status that is not terminal.
    """
    overlap = skip & retry
    unclassified = terminal - (skip | retry)
    unknown = (skip | retry) - terminal
    if overlap or unclassified or unknown:
        raise RuntimeError(
            "resume status sets must partition TERMINAL_STATUSES: "
            f"overlap={sorted(overlap)}, unclassified={sorted(unclassified)}, "
            f"unknown={sorted(unknown)}"
        )


# Fail fast at import time if the three sets above drift out of sync.
_check_resume_partition(TERMINAL_STATUSES, SKIP_ON_RESUME, RETRY_ON_RESUME)


@dataclass
class CoveragePassContext:
    """Where one shot sits inside a coverage pass, plus what its siblings did.

    A coverage pass bundles a primary shot with its alternates at other
    sizes/angles; run.episode assembles these from the shot plan's grouping.
    ``completed_siblings`` records the terminal status of each sibling that
    has already finished.
    """
    coverage_pass_id: str               # pass identifier, e.g. "SC01_COVERAGE_A"
    sibling_shot_ids: list[str]         # every shot id belonging to the pass
    this_shot_role: str                 # "primary", "coverage_ws", "coverage_cu", ...
    completed_siblings: dict[str, str]  # shot_id -> terminal status of finished siblings
    # Minimum number of clean shots for the pass to be viable. The stored
    # default is 0, while the intended effective default is described as N-1;
    # presumably 0 acts as a "derive N-1" sentinel resolved by the consumer --
    # TODO confirm against run.episode.
    pass_min_success: int = 0

    def to_dict(self) -> dict:
        """Serialize to a plain dict; the list and dict fields are shallow copies."""
        return dict(
            coverage_pass_id=self.coverage_pass_id,
            sibling_shot_ids=list(self.sibling_shot_ids),
            this_shot_role=self.this_shot_role,
            completed_siblings=dict(self.completed_siblings),
            pass_min_success=self.pass_min_success,
        )


@dataclass
class OpResult:
    """Terminal result of a single run.shot invocation.

    INVARIANT: run.shot NEVER raises on expected failure (D6).
    All 7 terminal statuses are returned as OpResult data.

    Construction itself is strict: a status outside TERMINAL_STATUSES is a
    programming error and ``__post_init__`` raises ValueError for it.
    """
    status: str                              # one of TERMINAL_STATUSES
    shot_id: str
    op_id: str                               # "op_<uuid7.hex[:12]>"
    output_path: str | None = None
    cost_usd: float = 0.0
    attempts: int = 0
    failure_mode: str | None = None          # FailureMode enum value
    validation_notes: list[str] = field(default_factory=list)
    review_queue_id: str | None = None
    coverage_context: dict | None = None     # presumably CoveragePassContext.to_dict() output -- confirm

    def __post_init__(self):
        """Validate ``status`` against TERMINAL_STATUSES.

        Raises:
            ValueError: if ``status`` is not a str or is not a known terminal status.
        """
        # Check the type first: an unhashable status (e.g. a list) would make the
        # frozenset membership test raise TypeError instead of ValueError.
        if not isinstance(self.status, str) or self.status not in TERMINAL_STATUSES:
            # sorted() keeps the message deterministic; a frozenset's repr order
            # varies between interpreter runs under string hash randomization.
            raise ValueError(
                f"OpResult.status must be one of {sorted(TERMINAL_STATUSES)}, "
                f"got {self.status!r}"
            )


@dataclass
class EpisodeResult:
    """Aggregate result from run.episode.

    The counters and containers are filled in by the caller; this class only
    looks up style anchors and renders the morning triage summary.
    """
    run_id: str
    episode_id: str
    total_shots: int = 0
    completed: int = 0
    by_status: dict[str, int] = field(default_factory=dict)   # status -> count, listed sorted in the summary
    total_cost_usd: float = 0.0
    budget_remaining_usd: float = 0.0
    aborted: bool = False
    abort_reason: str | None = None
    # Keyed by scene key; an optional "episode" entry is the fallback anchor.
    style_anchors: dict[str, Path] = field(default_factory=dict)
    review_queue_count: int = 0
    shot_results: list[OpResult] = field(default_factory=list)

    @property
    def style_anchor_path(self) -> Path | None:
        """Backward-compat: return the first style anchor (insertion order), or None."""
        if not self.style_anchors:
            return None
        return next(iter(self.style_anchors.values()))

    def get_style_anchor_for_scene(self, scene_key: str) -> Path | None:
        """Return the anchor for *scene_key*, else the "episode" anchor, else None."""
        if scene_key in self.style_anchors:
            return self.style_anchors[scene_key]
        return self.style_anchors.get("episode")

    def morning_summary(self, *, max_notes: int = 20) -> str:
        """Render a human-readable morning triage summary.

        Args:
            max_notes: Maximum number of per-shot validation notes to list.
                Any remainder is collapsed into an "... and N more" line.
                Negative values are treated as 0. Defaults to 20.

        Returns:
            A multi-line string without a trailing newline.
        """
        lines = [
            f"Episode {self.episode_id} -- Run {self.run_id}",
            f"  Shots: {self.completed}/{self.total_shots}",
            f"  Cost: ${self.total_cost_usd:.2f} (${self.budget_remaining_usd:.2f} remaining)",
        ]
        for status, count in sorted(self.by_status.items()):
            lines.append(f"  {status}: {count}")
        if self.aborted:
            lines.append(f"  ABORTED: {self.abort_reason}")
        if self.review_queue_count:
            lines.append(f"  Review queue: {self.review_queue_count} entries")
        if self.style_anchors:
            lines.append(f"  Style anchors: {len(self.style_anchors)} ({', '.join(self.style_anchors.keys())})")

        lines.extend(self._validation_notes_section(max_notes))
        return "\n".join(lines)

    def _validation_notes_section(self, max_notes: int) -> list[str]:
        """Build the "Validation notes" summary lines, capped at *max_notes* entries."""
        notes = [
            f"  [{r.shot_id}] {note}"
            for r in self.shot_results
            for note in r.validation_notes
        ]
        if not notes:
            return []
        # Clamp so a negative cap cannot turn the slice into "all but the last k".
        limit = max(max_notes, 0)
        section = ["  Validation notes:", *notes[:limit]]
        if len(notes) > limit:
            section.append(f"  ... and {len(notes) - limit} more")
        return section
