"""pass_store.py — Per-project JSON file-per-episode backend for coverage pass state.

Each project stores pass state as one JSON file per episode at:
    projects/{project}/state/visual/passes/ep_{NNN}_pass_state.json

Single-machine safe (fcntl + atomic replace). Cross-machine safety via the
project write lease (state_lease.py) — fcntl does NOT span machines.
- fcntl.flock prevents lost updates from concurrent read-modify-write across processes.
- os.replace prevents torn writes (atomic rename on POSIX).
- Both required; either alone is insufficient.
- Lock file: <passes_dir>/.passes.lock (auto-created).

Reference: workspace/state.py:91-113 (canonical pattern).

Parallel to ExecutionStore — does NOT modify or integrate with it.
"""

import contextlib
import fcntl
import json
import logging
import os
import re
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional

from recoil.core.paths import ProjectPaths
from recoil.execution.state_lease import ensure_write_lease
from recoil.pipeline._lib.schema_versions import PASS_STORE_SCHEMA_VERSION
from recoil.pipeline.core.cost import read_cost_from_record_safe

logger = logging.getLogger(__name__)


def _project_passes_dir(project: str) -> Path:
    """Resolve the directory where *project* keeps its pass-state files.

    Args:
        project: Project name (e.g. "leviathan").

    Returns:
        The ``passes_dir`` reported by ``ProjectPaths`` for that project.
    """
    project_paths = ProjectPaths.for_project(project)
    return project_paths.passes_dir


def _episode_file(passes_dir: Path, episode_id: str) -> Path:
    """Return the JSON file path for an episode's pass state.

    IDs that start with ``EP``/``ep`` followed by digits, optionally with a
    single ``_`` separator, are normalised to a zero-padded three-digit
    number. As a result ``'EP001'``, ``'EP1'``, ``'ep_001'`` and ``'ep_1'``
    all map to ``ep_001_pass_state.json``. Any other ID is used verbatim,
    after replacing every character outside ``[a-zA-Z0-9_-]`` with ``_``.

    Args:
        passes_dir: Directory holding the per-episode pass-state files.
        episode_id: Episode identifier (e.g. 'EP001').

    Returns:
        Path of the episode's pass-state JSON file inside *passes_dir*.
    """
    # The optional underscore lets 'ep_2' resolve to the same file as 'EP002'
    # instead of falling through to the raw-ID branch below.
    m = re.match(r"[Ee][Pp]_?(\d+)", episode_id)
    if m:
        ep_num = int(m.group(1))
        return passes_dir / f"ep_{ep_num:03d}_pass_state.json"
    # Fallback: sanitise the raw ID so it cannot contain path separators.
    safe_id = re.sub(r"[^a-zA-Z0-9_-]", "_", episode_id)
    return passes_dir / f"{safe_id}_pass_state.json"


def _episode_id_from_pass_id(pass_id: str) -> str:
    """Derive the canonical upper-case episode ID from a pass ID.

    The leading ``EP<digits>`` prefix is matched case-insensitively, so
    ``'EP001_PASS_003_L_SADIE_B'`` and ``'ep001_...'`` both yield ``'EP001'``.

    Raises:
        ValueError: if *pass_id* does not begin with 'EP' followed by digits.
    """
    prefix = re.match(r"(EP\d+)", pass_id, re.IGNORECASE)
    if prefix is None:
        raise ValueError(
            f"Malformed pass_id '{pass_id}': expected format starting with 'EP' "
            f"followed by digits (e.g. 'EP001_PASS_003_L_SADIE_B')"
        )
    return prefix.group(1).upper()


class PassStore:
    """Per-project JSON file-per-episode store for coverage pass state.

    Each episode's passes are stored in a single JSON file. Writes are
    single-machine safe (fcntl + atomic replace). Cross-machine safety is via
    the project write lease (state_lease.py) — fcntl does NOT span machines.

    Read-only methods (``get_pass``, ``list_passes``, ``next_pass_counter``)
    take no lock. Writers publish whole files with ``os.replace``, so a reader
    sees either the previous file or the new one, never a partial write.

    Args:
        project: Project name (e.g. "leviathan"). Determines storage location.
    """

    def __init__(self, project: str):
        self.project = project
        self._passes_dir = _project_passes_dir(project)
        # Intra-process guard; fcntl.flock in _locked() covers other processes.
        self._lock = threading.Lock()
        self._flock_path = self._passes_dir / ".passes.lock"
        # Directory is created on first write, not on init

    @property
    def passes_dir(self) -> Path:
        """Directory holding this project's per-episode pass-state files."""
        return self._passes_dir

    # ── Atomic I/O ────────────────────────────────────────────────

    def _ensure_dir(self) -> None:
        """Create the passes directory if it doesn't exist."""
        self._passes_dir.mkdir(parents=True, exist_ok=True)

    def _read_episode_file(
        self, episode_id: str, *, quarantine_corrupt: bool = False
    ) -> dict:
        """Read an episode's pass state file.

        Returns an empty dict when the file is missing. It also returns an
        empty dict, after logging a warning, when the file cannot be read or
        decoded, or is not a JSON object.

        Args:
            episode_id: Episode identifier (any form ``_episode_file`` accepts).
            quarantine_corrupt: Callers that will unconditionally rewrite the
                file under the lock set this. If the file exists but is
                unreadable, it is first moved aside (see ``_quarantine_file``)
                so the passes it held are not silently overwritten.
        """
        path = _episode_file(self._passes_dir, episode_id)
        if not path.is_file():
            return {}
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError
            # (invalid UTF-8 bytes); OSError covers I/O failures.
            logger.warning("Failed to read pass state for %s: %s", episode_id, e)
        else:
            if isinstance(data, dict):
                return data
            logger.warning(
                "Failed to read pass state for %s: expected a JSON object, got %s",
                episode_id,
                type(data).__name__,
            )
        if quarantine_corrupt:
            self._quarantine_file(path)
        return {}

    def _quarantine_file(self, path: Path) -> None:
        """Rename an unreadable state file to ``<name>.corrupt-<ns>``.

        This is best-effort: if the rename fails, the error is logged and the
        caller proceeds.
        """
        backup = path.with_name(f"{path.name}.corrupt-{time.time_ns()}")
        try:
            os.replace(path, backup)
        except OSError:
            logger.exception("Could not quarantine unreadable pass state %s", path)
        else:
            logger.error("Quarantined unreadable pass state %s -> %s", path, backup)

    def _write_episode_file_locked(self, episode_id: str, data: dict) -> None:
        """Atomic write via tempfile + os.replace(). Caller must hold the lock.

        The temp file is flushed and fsync'd before the rename, so the new
        contents are on disk before the rename publishes them.
        """
        # setdefault preserves any prior on-disk version so a future constant
        # bump does not silently rewrite legacy records.
        data.setdefault("schema_version", PASS_STORE_SCHEMA_VERSION)
        path = _episode_file(self._passes_dir, episode_id)
        fd, tmp = tempfile.mkstemp(
            dir=str(self._passes_dir), prefix=f".tmp_{episode_id}_", suffix=".json"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, default=str)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, str(path))
        except BaseException:
            # Also covers KeyboardInterrupt, so no stray .tmp_* file is left.
            try:
                os.unlink(tmp)
            except OSError:
                pass
            raise

    @contextlib.contextmanager
    def _locked(self):
        """Acquire intra-process + cross-process locks for a critical section.

        Order matches the canonical workspace/state.py pattern: threading.Lock
        outer (cheap intra-process bounce), fcntl.flock inner. The project
        write lease is checked and the passes directory created while holding
        the thread lock. Releases both locks on exit, including on exceptions.
        """
        with self._lock:
            ensure_write_lease(self.project)
            self._ensure_dir()
            lock_fd = os.open(str(self._flock_path), os.O_CREAT | os.O_RDWR)
            try:
                fcntl.flock(lock_fd, fcntl.LOCK_EX)
                yield
            finally:
                fcntl.flock(lock_fd, fcntl.LOCK_UN)
                os.close(lock_fd)

    # ── Pass CRUD ─────────────────────────────────────────────────

    def create_pass(self, pass_id: str, segment_shot_ids: list[str]) -> None:
        """Create a new pass record with status 'pending'.

        An existing record with the same ``pass_id`` is replaced outright.
        Its takes, cost and timestamps are not carried over.

        Args:
            pass_id: Unique pass identifier (e.g. 'EP001_PASS_003_L_SADIE_B').
            segment_shot_ids: Ordered shot IDs this pass covers. The value is
                materialised into a list once, so any iterable works.

        Raises:
            ValueError: if ``pass_id`` does not start with 'EP' + digits.
        """
        episode_id = _episode_id_from_pass_id(pass_id)
        shot_ids = list(segment_shot_ids)
        now = time.time()
        record = {
            "pass_id": pass_id,
            "status": "pending",
            "segment_shot_ids": shot_ids,
            "video_path": None,
            "cost_usd": 0.0,
            "segment_timestamps": {},
            # N shots are separated by N-1 cuts.
            "expected_cuts": max(0, len(shot_ids) - 1),
            "detected_cuts": 0,
            "confirmed_timestamps": {},
            "extraction_method": None,  # "auto_verified" | "human_confirmed" | None
            "scene_detect_threshold": None,
            "scene_detection_raw": [],
            "cuts_diverged": False,
            # AMEND_SPEC_01 Phase 1 — alignment + happy-accident tracking.
            # alignment_score:    aligned_count / expected_count on the over-
            #                     detection path; None when irrelevant.
            # model_added_count:  number of detected cuts preserved as alt
            #                     coverage ("model-added" segments).
            "alignment_score": None,
            "model_added_count": 0,
            "lineage_ref": None,  # reserved for Phase 8 reject-and-regenerate
            "takes": [],
            "retry_strategy": None,
            "created_at": now,
            "updated_at": now,
        }
        with self._locked():
            # This write is unconditional, so an unreadable file is moved
            # aside rather than replaced by a file holding only this pass.
            state = self._read_episode_file(episode_id, quarantine_corrupt=True)
            state.setdefault("passes", {})[pass_id] = record
            self._write_episode_file_locked(episode_id, state)

    def update_pass(self, pass_id: str, **fields) -> None:
        """Merge fields into an existing pass record.

        Merge rules:
          * ``cost_usd`` is added to the stored total, the same way
            ExecutionStore accumulates cost_incurred. ``None`` is ignored and
            does not overwrite the running total.
          * A dict passed for ``segment_timestamps`` or
            ``confirmed_timestamps`` is merged into the stored dict. Keys
            absent from the update are kept; keys present are overwritten.
          * Any other field is overwritten.

        ``updated_at`` is refreshed. An unknown ``pass_id`` logs a warning and
        nothing is written.

        Args:
            pass_id: Pass identifier.
            **fields: Fields to update (merged into existing record).
        """
        episode_id = _episode_id_from_pass_id(pass_id)
        with self._locked():
            state = self._read_episode_file(episode_id)
            passes = state.get("passes", {})
            if pass_id not in passes:
                logger.warning("Pass %s not found for update", pass_id)
                return
            # record is the dict stored inside state, so mutating it in place
            # is enough; no write-back into passes/state is needed.
            record = passes[pass_id]
            for key, value in fields.items():
                if key == "cost_usd":
                    if value is not None:
                        record["cost_usd"] = read_cost_from_record_safe(record) + value
                elif key in (
                    "segment_timestamps",
                    "confirmed_timestamps",
                ) and isinstance(value, dict):
                    existing_ts = record.get(key) or {}
                    if isinstance(existing_ts, dict):
                        existing_ts.update(value)
                        record[key] = existing_ts
                    else:
                        record[key] = value
                else:
                    record[key] = value
            record["updated_at"] = time.time()
            self._write_episode_file_locked(episode_id, state)

    def get_pass(self, pass_id: str) -> Optional[dict]:
        """Read a single pass record (no lock taken).

        Args:
            pass_id: Pass identifier.

        Returns:
            Pass record dict, or None if not found.
        """
        episode_id = _episode_id_from_pass_id(pass_id)
        state = self._read_episode_file(episode_id)
        return state.get("passes", {}).get(pass_id)

    def append_pass_take(self, pass_id: str, take: dict) -> None:
        """Append a take record to a pass's takes list.

        A missing or non-list ``takes`` value is reset to an empty list before
        appending. An unknown ``pass_id`` logs a warning and nothing is
        written.

        Args:
            pass_id: Pass identifier.
            take: Take record dict.
        """
        episode_id = _episode_id_from_pass_id(pass_id)
        with self._locked():
            state = self._read_episode_file(episode_id)
            passes = state.get("passes", {})
            if pass_id not in passes:
                logger.warning("Pass %s not found for append_pass_take", pass_id)
                return
            record = passes[pass_id]
            if not isinstance(record.get("takes"), list):
                record["takes"] = []
            record["takes"].append(take)
            record["updated_at"] = time.time()
            self._write_episode_file_locked(episode_id, state)

    def link_pass_to_shots(self, pass_id: str, shot_id_map: dict) -> None:
        """Update constituent shots with coverage_pass_id.

        Writes ``{"coverage_pass_id": pass_id, **metadata}`` under
        ``state["shot_links"][shot_id]`` in the episode's pass state file.
        Non-dict metadata is ignored. Keys in a dict metadata are spread after
        ``coverage_pass_id`` and therefore override it.

        This does NOT modify ExecutionStore — that's the caller's
        responsibility if cross-store linking is needed.

        Args:
            pass_id: Pass identifier.
            shot_id_map: Dict mapping shot_id -> any additional metadata to store.
        """
        episode_id = _episode_id_from_pass_id(pass_id)
        with self._locked():
            # This write is unconditional; preserve an unreadable file first.
            state = self._read_episode_file(episode_id, quarantine_corrupt=True)
            links = state.setdefault("shot_links", {})
            for shot_id, metadata in shot_id_map.items():
                links[shot_id] = {
                    "coverage_pass_id": pass_id,
                    **(metadata if isinstance(metadata, dict) else {}),
                }
            self._write_episode_file_locked(episode_id, state)

    def list_passes(self, episode_id: str) -> list[dict]:
        """Return all passes for an episode (no lock taken).

        Args:
            episode_id: Episode identifier (e.g. 'EP001').

        Returns:
            List of pass record dicts, sorted by pass_id.
        """
        state = self._read_episode_file(episode_id)
        passes = state.get("passes", {})
        return sorted(passes.values(), key=lambda p: p.get("pass_id", ""))

    def next_pass_counter(self, episode_id: str) -> int:
        """Return the next available pass counter for an episode.

        Scans existing pass_ids of the form EP{NNN}_PASS_{CCC}_SH... and
        returns max(counter) + 1. Returns 1 if no passes exist.
        """
        # Match both old-format (pre-migration) and new-format pass_ids.
        # During the migration window, both forms may coexist in PassStore.
        # {3,} rather than {3}: PASS_1000 must read as 1000, not be truncated
        # to 100 (which would hand out 1000 a second time).
        pattern = re.compile(r"^EP\d+_PASS_(\d{3,})")
        matches = (
            pattern.match(p.get("pass_id", "")) for p in self.list_passes(episode_id)
        )
        return max((int(m.group(1)) for m in matches if m), default=0) + 1

    def close(self) -> None:
        """No-op. JSON files don't need connection management."""
        pass


# Public API (Phase D — MF-3 + DEBT-9): only the store class is exported;
# the underscore-prefixed path/ID helpers remain module-private.
__all__ = ["PassStore"]
