"""
batch_manager.py — Queue episodes, track batch progress, termination criteria.

A batch represents "generate all shots for episode N" or "re-run these specific
takes." The BatchManager:
- Initializes batch state from ExecutionStore
- Tracks what succeeded, failed, is pending, is in-flight
- Monitors total cost vs budget
- Determines termination (all complete, budget exhausted, manual pause)
- Persists state to JSON for crash recovery
"""

import contextlib
import fcntl
import json
import logging
import os
import tempfile
import threading
import time
import uuid
from pathlib import Path
from typing import Optional

from orchestrator.production_types import BatchConfig, BatchState, BatchStatus

logger = logging.getLogger(__name__)


# ── Shot status buckets ─────────────────────────────────────────────
# Statuses the production loop should pick up and (re)process.
ACTIONABLE_STATUSES = frozenset((
    "previs_pending", "previs_failed",                            # previs layer
    "keyframe_pending", "keyframe_mechanical_failed",             # keyframe layer
    "video_pending", "video_failed", "video_mechanical_failed",   # video layer
    "failed",                                                     # layer-agnostic
))

# Statuses that need no further work.
TERMINAL_STATUSES = frozenset((
    "approved", "rejected", "abandoned", "skipped", "video_complete",
))

# Statuses parked until a human reviews the output.
REVIEW_STATUSES = frozenset((
    "previs_generated",
    "keyframe_generated", "keyframe_semantic_failed",
    "video_semantic_failed", "video_ready",
))


class BatchManager:
    """Tracks batch-level progress and termination.

    Cross-process safe: writes are guarded by an intra-process threading.Lock
    plus a cross-process fcntl.flock on `<state_dir>/.batch.lock`.

    Shot counts and total cost are recomputed from the ExecutionStore by
    refresh_counts(); the per-batch JSON file in state_dir lets a RUNNING or
    PAUSED batch for the same project/episode be recovered after a crash.
    """

    def __init__(
        self,
        config: BatchConfig,
        store,
        state_dir: Optional[Path] = None,
    ):
        """
        Args:
            config: Batch configuration.
            store: ExecutionStore instance. Must provide
                get_shots_by_episode(episode_id) returning an iterable of
                shot dicts.
            state_dir: Directory for batch state JSON. Defaults to
                projects/{project}/state/visual/batches/
        """
        self._config = config
        self._store = store

        if state_dir:
            self._state_dir = state_dir
        else:
            from recoil.core.paths import ProjectPaths
            self._state_dir = (
                ProjectPaths.for_project(config.project).visual_state_dir / "batches"
            )
        self._state_dir.mkdir(parents=True, exist_ok=True)

        self._lock = threading.Lock()
        self._flock_path = self._state_dir / ".batch.lock"

        # Initialize or recover state
        self._state = self._recover_or_create()

    @contextlib.contextmanager
    def _locked(self):
        """Acquire intra-process + cross-process locks for a critical section.

        Order matches the canonical workspace/state.py pattern: threading.Lock
        outer (cheap intra-process bounce), fcntl.flock inner. Releases both
        on exit, including exceptions.
        """
        with self._lock:
            self._state_dir.mkdir(parents=True, exist_ok=True)
            lock_fd = os.open(str(self._flock_path), os.O_CREAT | os.O_RDWR)
            try:
                fcntl.flock(lock_fd, fcntl.LOCK_EX)
                try:
                    yield
                finally:
                    fcntl.flock(lock_fd, fcntl.LOCK_UN)
            finally:
                # Always close the descriptor, even if LOCK_UN raised; closing
                # the last fd for the open file also drops any flock on it.
                os.close(lock_fd)

    @property
    def state(self) -> BatchState:
        return self._state

    @property
    def config(self) -> BatchConfig:
        return self._config

    def _recover_or_create(self) -> BatchState:
        """Recover an active batch for this project/episode, or create a new one.

        State files are scanned in sorted filename order so recovery is
        deterministic; the first RUNNING or PAUSED file whose project and
        episode_id match the config wins. Unreadable, non-UTF-8, non-JSON,
        non-object, or incomplete files are logged and skipped.

        Returns:
            The recovered BatchState, or a fresh one in CREATED status.
        """
        active = (BatchStatus.RUNNING.value, BatchStatus.PAUSED.value)
        for path in sorted(self._state_dir.glob("batch_*.json")):
            try:
                data = json.loads(path.read_text(encoding="utf-8"))
                if not isinstance(data, dict):
                    raise ValueError("state file does not contain a JSON object")
                if data.get("status") not in active:
                    continue
                if (data.get("project") != self._config.project
                        or data.get("episode_id") != self._config.episode_id):
                    continue
                logger.info("Recovering batch %s from %s", data["batch_id"], path)
                return BatchState(
                    batch_id=data["batch_id"],
                    config=self._config,
                    status=BatchStatus(data["status"]),
                    started_at=data.get("started_at", 0),
                    total_cost=data.get("total_cost", 0),
                    start_cost=data.get("start_cost", 0),
                    shots_completed=data.get("shots_completed", 0),
                    shots_failed=data.get("shots_failed", 0),
                    auto_approved=data.get("auto_approved", 0),
                )
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            # (a non-UTF-8 file), both of which previously crashed __init__.
            except (ValueError, KeyError, OSError) as e:
                logger.warning("Skipping corrupt batch state %s: %s", path, e)

        # Create new batch
        batch_id = f"batch_{self._config.episode_id}_{uuid.uuid4().hex[:8]}"
        return BatchState(
            batch_id=batch_id,
            config=self._config,
            status=BatchStatus.CREATED,
        )

    def start(self) -> None:
        """Mark batch as running.

        Also snapshots start_cost (the accumulated cost already on the books
        before this run) so delta_budget_usd can bound *additional* spend.
        The snapshot is taken only on the batch's first start (started_at
        still falsy). Re-starts after a pause or crash recovery keep the
        original baseline, even when that baseline was legitimately 0.
        started_at is reset on every call.
        """
        # A falsy started_at means this batch has never been started. Using
        # start_cost == 0 as the sentinel instead would wrongly re-baseline
        # a batch whose true starting cost was 0, re-granting the delta budget.
        first_start = not self._state.started_at
        self._state.status = BatchStatus.RUNNING
        self._state.started_at = time.time()
        # Refresh counts first so start_cost reflects the real accumulated total.
        self.refresh_counts()
        if first_start:
            self._state.start_cost = self._state.total_cost
        self._save_state()

    def _attempts_exhausted(self, shot: dict) -> bool:
        """Return True if the shot has reached max_attempts_per_shot."""
        return (shot.get("attempts", 0) or 0) >= self._config.max_attempts_per_shot

    def refresh_counts(self) -> None:
        """Refresh shot counts and total cost from ExecutionStore.

        Buckets:
            failed    -- "abandoned" shots, plus actionable shots that have
                         exhausted max_attempts_per_shot. get_actionable_shots()
                         will never hand those out again, so counting them as
                         pending would keep the batch from ever completing.
            pending   -- remaining actionable, in-flight, and unknown statuses.
            completed -- other terminal statuses.
            in_review -- statuses awaiting human review.
        """
        shots = self._store.get_shots_by_episode(self._config.episode_id)

        pending = completed = failed = in_review = 0
        total_cost = 0.0

        for shot in shots:
            status = shot.get("status", "previs_pending")
            total_cost += shot.get("cost_incurred", 0) or 0

            if status == "abandoned":
                failed += 1  # Count abandoned as failed, not completed
            elif status in ACTIONABLE_STATUSES:
                if self._attempts_exhausted(shot):
                    failed += 1
                else:
                    pending += 1
            elif status in TERMINAL_STATUSES:
                completed += 1
            elif status in REVIEW_STATUSES:
                in_review += 1
            else:
                # In-flight (previs/keyframe generating, video submitted /
                # processing / downloading) or an unknown transition state.
                pending += 1

        self._state.shots_pending = pending
        self._state.shots_completed = completed
        self._state.shots_failed = failed
        self._state.shots_in_review = in_review
        self._state.total_cost = total_cost

    def get_actionable_shots(self, max_batch: int = 10) -> list[dict]:
        """Get shots ready for processing, ordered by priority.

        Priority: retries (any "*failed" status) first, then pending previs,
        then keyframe, then video; ties are broken by shot_id. Shots that
        have exhausted max_attempts_per_shot are excluded.

        Args:
            max_batch: Maximum number of shots to return.
        """
        shots = self._store.get_shots_by_episode(self._config.episode_id)

        actionable = [
            s for s in shots
            if s.get("status") in ACTIONABLE_STATUSES
            and not self._attempts_exhausted(s)
        ]

        def _priority(shot: dict) -> tuple[int, str]:
            status = shot.get("status") or ""
            # `or ""` keeps the sort key comparable when shot_id is None.
            shot_id = shot.get("shot_id") or ""
            if "failed" in status:
                return (0, shot_id)
            for rank, phase in enumerate(("previs", "keyframe", "video"), start=1):
                if phase in status:
                    return (rank, shot_id)
            return (4, shot_id)

        actionable.sort(key=_priority)
        return actionable[:max_batch]

    def _spent_this_run(self) -> float:
        """Cost accrued since start_cost was snapshotted (never negative)."""
        return max(0.0, self._state.total_cost - self._state.start_cost)

    def is_budget_exhausted(self) -> bool:
        """Check if batch cost exceeds budget.

        If BatchConfig.delta_budget_usd is set, bounds *additional* spend
        (total_cost - start_cost) rather than absolute total_cost. Useful
        for one-off tests against episodes with accumulated prior cost.
        Uses the counts from the most recent refresh_counts() call.
        """
        delta_cap = self._config.delta_budget_usd
        if delta_cap is not None:
            return self._spent_this_run() >= delta_cap
        return self._state.total_cost >= self._config.budget_usd

    def is_complete(self) -> bool:
        """Check if all shots are in terminal or review states (no more actionable)."""
        self.refresh_counts()
        return self._state.shots_pending == 0

    def check_termination(self) -> Optional[str]:
        """Check all termination criteria. Returns reason string or None.

        Call this each loop iteration. Returns:
        - "budget_exhausted" if cost >= budget
        - "complete" if no actionable shots remain
        - "paused" if status was set to PAUSED externally
        - None if batch should continue
        """
        if self._state.status == BatchStatus.PAUSED:
            return "paused"

        self.refresh_counts()

        if self.is_budget_exhausted():
            self._state.status = BatchStatus.BUDGET_EXHAUSTED
            self._save_state()
            return "budget_exhausted"

        # Counts were just refreshed; read shots_pending directly rather than
        # calling is_complete(), which would query the store a second time.
        if self._state.shots_pending == 0:
            self._state.status = BatchStatus.COMPLETE
            self._state.completed_at = time.time()
            self._save_state()
            return "complete"

        return None

    def pause(self, reason: str = "") -> None:
        """Pause the batch. Production loop should stop processing."""
        self._state.status = BatchStatus.PAUSED
        self._state.error_message = reason
        self._save_state()
        logger.info("Batch %s paused: %s", self._state.batch_id, reason)

    def record_completion(self, shot_id: str, auto_approved: bool = False) -> None:
        """Record a shot completion.

        Note: shots_completed is recalculated by refresh_counts() from the store,
        so we do not manually increment it here to avoid double-accounting.
        """
        if auto_approved:
            self._state.auto_approved += 1
        self._save_state()

    def record_failure(self, shot_id: str) -> None:
        """Record a permanent shot failure.

        Note: shots_failed is recalculated by refresh_counts() from the store,
        so we do not manually increment it here to avoid double-accounting.
        """
        self._save_state()

    def summary(self) -> dict:
        """Current batch summary.

        Returns the state dict plus "elapsed_seconds" (since the last start(),
        0 if never started) and "budget_remaining" (against delta_budget_usd
        when set, otherwise budget_usd; may be negative once overspent).
        """
        self.refresh_counts()
        elapsed = time.time() - self._state.started_at if self._state.started_at else 0
        delta_cap = self._config.delta_budget_usd
        if delta_cap is not None:
            budget_remaining = round(delta_cap - self._spent_this_run(), 4)
        else:
            budget_remaining = round(self._config.budget_usd - self._state.total_cost, 4)
        return {
            **self._state.to_dict(),
            "elapsed_seconds": round(elapsed, 1),
            "budget_remaining": budget_remaining,
        }

    def _save_state(self) -> None:
        """Persist batch state to JSON for crash recovery.

        Cross-process safe via fcntl.flock + tempfile + os.replace. The temp
        file is fsync'd before the rename so a crash cannot leave a
        truncated or empty state file behind.
        """
        path = self._state_dir / f"{self._state.batch_id}.json"
        content = json.dumps(self._state.to_dict(), indent=2, default=str)

        with self._locked():
            fd, tmp = tempfile.mkstemp(
                dir=str(self._state_dir), prefix=".batch_", suffix=".tmp"
            )
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    f.write(content)
                    f.flush()
                    os.fsync(f.fileno())
                os.replace(tmp, str(path))
            except BaseException:
                # Remove the orphaned temp file on any failure (including
                # KeyboardInterrupt), then re-raise.
                with contextlib.suppress(OSError):
                    os.unlink(tmp)
                raise
