"""
retry_dispatcher.py — Categorize failures and queue retries with backoff.

The dispatcher does NOT execute retries. It:
1. Classifies a StepResult failure into a FailureCategory
2. Checks the retry budget (per-shot attempt count vs policy max)
3. Computes the next retry time (with backoff)
4. Queues and returns a RetryRequest, or returns None when the failure is
   not retryable (no retry policy for its category, or budget exhausted)

The production loop calls dispatch() after a failed step, then processes
the retry queue on each loop iteration.
"""

import logging
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional

from orchestrator.production_types import (
    FailureCategory,
    RetryPolicy,
    RetryRequest,
    DEFAULT_RETRY_POLICIES,
)

# Phase C: canonical pattern lists live in pipeline.core.failure_mode.
# The aliased re-imports below keep the old private names available for
# callers that imported them directly from this module
# (production_loop:818 imports _TRANSIENT_PATTERNS).
#
# Path-order bootstrap: importing `pipeline.core.failure_mode` triggers
# `pipeline/core/__init__.py`, which imports `pipeline.core.dispatch`,
# which does `from core.paths import RECOIL_ROOT`. If PIPELINE_ROOT
# precedes RECOIL_ROOT on sys.path, `core` resolves to the
# partially-initialized `pipeline.core` package and the import fails with
# a circular-import error. This mirrors the workaround used by
# `pipeline/tests/conftest.py` and `pipeline/core/tests/conftest.py`:
# put the directory three `.parent` hops above this file at the front of
# sys.path so `core` resolves to `recoil/core/` (the genuine `core.paths`
# module).
# NOTE(review): this mutates sys.path at import time for every importer of
# this module. The variable is named RECOIL_ROOT, yet the import below is
# `recoil.pipeline...`, which needs the directory *containing* `recoil/`
# on sys.path — confirm which directory this actually resolves to.
_RECOIL_ROOT_STR = str(Path(__file__).resolve().parent.parent.parent)
# Remove-then-insert moves an existing entry to the front instead of
# adding a duplicate (list.remove drops only the first occurrence).
if _RECOIL_ROOT_STR in sys.path:
    sys.path.remove(_RECOIL_ROOT_STR)
sys.path.insert(0, _RECOIL_ROOT_STR)

from recoil.pipeline.core.failure_mode import (  # noqa: E402  # sys.path bootstrap above
    TRANSIENT_PATTERN_STRINGS as _TRANSIENT_PATTERNS,
    IDENTITY_PATTERNS as _IDENTITY_PATTERNS,
    WARDROBE_PATTERNS as _WARDROBE_PATTERNS,
    SCHEMA_PATTERNS as _SCHEMA_PATTERNS,
    classify_failure as _canonical_classify,
    failure_category_for,
    UnknownFailureEscalation,
    FailureMode,
)

# Module-level logger; RetryDispatcher logs budget exhaustion, queued
# retries, and cancellations through it.
logger = logging.getLogger(__name__)


def _contains_any(haystack: str, patterns) -> bool:
    """Return True if any pattern occurs in *haystack*, case-insensitively.

    *haystack* must already be lower-cased; each pattern is lower-cased here
    so every pattern list is matched the same way.
    """
    return any(pattern.lower() in haystack for pattern in patterns)


def _classify_semantic_gate(gate_verdict, error: str) -> FailureCategory:
    """Disambiguate a semantic-gate failure into GATE_WARDROBE or GATE_IDENTITY.

    Structured gate mismatches are consulted first, in order; the first
    mismatch whose category or visual evidence hits a wardrobe pattern
    (checked before identity) decides. Otherwise the lower-cased error text
    is checked for wardrobe patterns. Everything else is GATE_IDENTITY.
    """
    details = gate_verdict.details if gate_verdict else None
    if details:
        # `or` guards against explicit nulls in gate output
        # (e.g. {"mismatches": None} or {"category": None}).
        for mismatch in details.get("mismatches") or []:
            cat_str = (mismatch.get("category") or "").lower()
            evidence = (mismatch.get("visual_evidence") or "").lower()
            if _contains_any(cat_str, _WARDROBE_PATTERNS) or _contains_any(
                evidence, _WARDROBE_PATTERNS
            ):
                return FailureCategory.GATE_WARDROBE
            if _contains_any(cat_str, _IDENTITY_PATTERNS) or _contains_any(
                evidence, _IDENTITY_PATTERNS
            ):
                return FailureCategory.GATE_IDENTITY

    if _contains_any(error, _WARDROBE_PATTERNS):
        return FailureCategory.GATE_WARDROBE
    # Identity-pattern hits and the no-match default both map to
    # GATE_IDENTITY, so no separate identity check is needed here.
    return FailureCategory.GATE_IDENTITY


def classify_failure(
    step_result,
    shot_data: Optional[dict] = None,
) -> FailureCategory:
    """Classify a failed StepResult into a FailureCategory.

    Phase C: wraps pipeline.core.failure_mode.classify_failure and
    failure_category_for while keeping the public signature, so callers
    still receive a FailureCategory. Resolution order:

    1. "budget" anywhere in the error text -> BUDGET.
    2. Any transient pattern -> TRANSIENT.
    3. Any schema pattern -> PROMPT_DURATION_MISMATCH.
    4. Canonical classifier. FailureMode.UNKNOWN / FailureMode.NONE, or an
       UnknownFailureEscalation, -> PERMANENT (unmatched inputs fall
       through to "permanent", as before Phase C).
    5. Gate state. A "semantic" final_state -> GATE_WARDROBE or
       GATE_IDENTITY; "mechanical" -> GATE_MECHANICAL; a deferred gate
       verdict -> GATE_VIDEO_DRIFT.
    6. Otherwise failure_category_for(mode). If that raises ValueError or
       UnknownFailureEscalation, the result is PERMANENT.

    All text matching is case-insensitive: the error text and every pattern
    are lower-cased before comparison.

    Args:
        step_result: Failed StepResult from StepRunner; its ``error``,
            ``final_state`` and ``gate_verdict`` attributes are read.
        shot_data: Optional shot record from ExecutionStore. Accepted for
            interface compatibility; not currently consulted.

    Returns:
        The appropriate FailureCategory.
    """
    error = (step_result.error or "").lower()
    final_state = step_result.final_state or ""
    gv = step_result.gate_verdict

    # Budget check stays inline. It was a top-level branch pre-Phase-C, and
    # the canonical path would also yield BUDGET for cost overruns.
    if "budget" in error:
        return FailureCategory.BUDGET

    # Inline pre-checks run TRANSIENT, then SCHEMA, before delegating. The
    # canonical classifier runs CONTENT_FILTER before SCHEMA before
    # TRANSIENT, which over-classifies mixed errors. Examples: "429 ... 422"
    # is expected to be TRANSIENT. In "HTTP 422: ... rejected", "rejected"
    # is now a canonical content-filter pattern and would shadow the "422"
    # schema pattern.
    if _contains_any(error, _TRANSIENT_PATTERNS):
        return FailureCategory.TRANSIENT
    if _contains_any(error, _SCHEMA_PATTERNS):
        return FailureCategory.PROMPT_DURATION_MISMATCH

    try:
        mode, _ = _canonical_classify(
            error_text=error,
            gate_verdict=gv,
            http_status=None,
            escalate_unknown=False,
            caller="retry_dispatcher.classify_failure",
        )
    except UnknownFailureEscalation:
        # Defensive: escalate_unknown=False is passed, but coarsen anyway.
        return FailureCategory.PERMANENT

    if mode is FailureMode.UNKNOWN or mode is FailureMode.NONE:
        return FailureCategory.PERMANENT

    # NOTE(review): the gate-state branches below are only reached when the
    # canonical classifier recognised the error. A "semantic"/"mechanical"
    # final_state with unrecognised error text returned PERMANENT above —
    # confirm that is intended.
    if "semantic" in final_state:
        return _classify_semantic_gate(gv, error)

    if "mechanical" in final_state:
        return FailureCategory.GATE_MECHANICAL

    if gv and gv.deferred:
        return FailureCategory.GATE_VIDEO_DRIFT

    try:
        return failure_category_for(mode)
    except (ValueError, UnknownFailureEscalation):
        return FailureCategory.PERMANENT


class RetryDispatcher:
    """Queues and manages retries with per-category backoff.

    Each failure is classified, checked against the RetryPolicy for its
    category and, if budget remains, queued as a RetryRequest delayed by
    exponential backoff. Attempt counts are kept per (shot_id,
    FailureCategory), so every category draws on its own retry budget.

    Not thread-safe — designed for single-threaded production loop.
    """

    def __init__(
        self,
        policies: Optional[dict[FailureCategory, RetryPolicy]] = None,
    ):
        """Create a dispatcher.

        Args:
            policies: Per-category retry policies. ``None`` selects a copy of
                DEFAULT_RETRY_POLICIES. An explicitly empty mapping is
                honoured: every category then falls back to
                ``RetryPolicy(max_retries=0)`` in dispatch(), i.e. no retries.
        """
        # `is None`, not `or`: `{}` is falsy, so `policies or defaults`
        # would silently replace an explicit empty mapping with the defaults.
        self._policies = (
            dict(DEFAULT_RETRY_POLICIES) if policies is None else policies
        )
        self._queue: list[RetryRequest] = []
        # shot_id -> category -> number of retries queued so far.
        self._attempt_counts: dict[str, dict[FailureCategory, int]] = defaultdict(
            lambda: defaultdict(int)
        )

    @property
    def queue_size(self) -> int:
        """Number of retry requests currently waiting in the queue."""
        return len(self._queue)

    @property
    def queue(self) -> list[RetryRequest]:
        """Shallow copy of the pending queue, in insertion order."""
        return list(self._queue)

    def dispatch(
        self,
        shot_id: str,
        step_result,
        shot_data: Optional[dict] = None,
        fix_suggestion: Optional[dict] = None,
        authoritative_attempts: Optional[int] = None,
    ) -> Optional[RetryRequest]:
        """Classify a failure and queue a retry, or return None if not retryable.

        Args:
            shot_id: Shot identifier.
            step_result: Failed StepResult; ``error`` and ``model`` are copied
                onto the RetryRequest.
            shot_data: Full shot record for context (passed to
                classify_failure).
            fix_suggestion: Optional fix from FeedbackAgent.
            authoritative_attempts: Canonical attempt count from ExecutionStore.
                Used for RetryRequest.attempt_number instead of the internal
                counter.

        Returns:
            The queued RetryRequest, or None when the category has no policy
            or its retry budget for this shot is exhausted.
        """
        category = classify_failure(step_result, shot_data)
        # Categories without a configured policy are never retried.
        policy = self._policies.get(category, RetryPolicy(max_retries=0))

        # Check retry budget
        attempts = self._attempt_counts[shot_id][category]
        if attempts >= policy.max_retries:
            logger.info(
                "Shot %s: %s retry budget exhausted (%d/%d) → permanent",
                shot_id, category.value, attempts, policy.max_retries,
            )
            return None

        # Exponential backoff in this category's attempt count, capped at
        # max_backoff_seconds.
        backoff = min(
            policy.base_backoff_seconds * (policy.backoff_multiplier ** attempts),
            policy.max_backoff_seconds,
        )
        retry_at = time.time() + backoff

        # Track attempt
        self._attempt_counts[shot_id][category] += 1
        # Prefer the store's canonical count. Otherwise report the total
        # retries queued for this shot across all categories.
        if authoritative_attempts is not None:
            attempt_number = authoritative_attempts
        else:
            attempt_number = sum(self._attempt_counts[shot_id].values())

        request = RetryRequest(
            shot_id=shot_id,
            failure_category=category,
            attempt_number=attempt_number,
            retry_at=retry_at,
            fix_suggestion=fix_suggestion,
            error_message=step_result.error,
            original_model=step_result.model,
        )
        self._queue.append(request)

        logger.info(
            "Shot %s: queued %s retry #%d (backoff %.1fs, total attempts %d)",
            shot_id, category.value, attempts + 1, backoff, attempt_number,
        )
        return request

    def get_ready(self) -> list[RetryRequest]:
        """Remove and return every queued request that reports itself ready.

        Both the returned requests and the ones left queued keep their
        original relative order.
        """
        # Evaluate `ready` exactly once per request. Reading it twice (once to
        # select, once to keep) can lose a request that becomes ready between
        # the two reads: it lands in neither list. This presumes `ready`
        # depends on the clock (retry_at).
        ready: list[RetryRequest] = []
        pending: list[RetryRequest] = []
        for request in self._queue:
            (ready if request.ready else pending).append(request)
        self._queue = pending
        return ready

    def cancel(self, shot_id: str) -> int:
        """Cancel all pending retries for a shot. Returns count removed.

        Only the queue is touched; the shot's attempt counts are kept, so
        cancelled retries still count against its retry budget.
        """
        before = len(self._queue)
        self._queue = [r for r in self._queue if r.shot_id != shot_id]
        removed = before - len(self._queue)
        if removed:
            logger.info("Cancelled %d retries for shot %s", removed, shot_id)
        return removed

    def total_attempts(self, shot_id: str) -> int:
        """Total retry attempts across all categories for a shot.

        Uses ``.get`` so querying an unknown shot does not create an entry
        in the defaultdict.
        """
        return sum(self._attempt_counts.get(shot_id, {}).values())

    def reset(self) -> None:
        """Clear all queued retries and attempt counts."""
        self._queue.clear()
        self._attempt_counts.clear()
