"""Adapter registry + routing.

Loads adapter modules from this package by name and maps
(model_id, payload capability requirements) -> (adapter instance, tier)
per provider_strategy.json.

Routing precedence (highest wins):
  1. RECOIL_PROVIDER_OVERRIDE env var (forces provider_id for ALL models)
  2. capability_exceptions in provider_strategy.json (per-capability override)
  3. primary declared in provider_strategy.json
  4. (on failure) fallback declared in provider_strategy.json

Test mode: RECOIL_PROVIDER_MODE=test swaps real adapter modules for
`providers.testing.mock_<provider>` variants (same Protocol surface).
"""

from __future__ import annotations

import importlib
import logging
import os
from pathlib import Path
from typing import Optional

from recoil.execution.providers.base import (
    ProviderAdapter,
    ProviderCapabilityError,
    UnifiedVideoPayload,
)

logger = logging.getLogger(__name__)


# Auto-discovered adapter modules expose `ADAPTER = <instance>`.
_ADAPTER_MODULE_NAMES = ("fal", "atlas", "piapi", "google", "wan", "kling", "flora", "comfyui")
_TEST_MODULE_NAMES = (
    "mock_fal",
    "mock_atlas",
    "mock_piapi",
    "mock_google",
    "mock_wan",
    "mock_kling",
    "mock_flora",
    "mock_comfyui",
)


def _strategy_path() -> Path:
    return Path(__file__).resolve().parents[2] / "config" / "provider_strategy.json"


def _is_test_mode() -> bool:
    return (os.environ.get("RECOIL_PROVIDER_MODE") or "").lower() == "test"


# ── ProviderStrategyCache ─────────────────────────────────────────────
class ProviderStrategyCache:
    """Provider-strategy + adapter cache — instance-attached state.

    Phase D MF-11: was module-level dicts (``_STRATEGY_CACHE``,
    ``_ADAPTER_CACHE``) with global mutation; promoted to instance state
    to support hot-reload + test isolation. The module-level free
    functions (:func:`load_strategy`, :func:`_load_adapter`,
    :func:`reset_caches_for_tests`) are thin wrappers around
    :data:`_default_provider_cache`.

    Hot-reload tooling (Console v2 dev loop) can construct a fresh
    instance and swap.

    Thread-safety: NOT thread-safe. Callers serialize via process-level
    bootstrap.
    """

    def __init__(self) -> None:
        # Parsed provider_strategy.json; None means "not loaded yet".
        self._strategy_cache: Optional[dict] = None
        # Adapter instances keyed "<mode>:<provider_id>" (mode is "test" or
        # "live") so flipping RECOIL_PROVIDER_MODE never serves an adapter
        # that was cached under the other mode.
        self._adapter_cache: dict[str, ProviderAdapter] = {}

    def load_strategy(self, *, force: bool = False) -> dict:
        """Return the parsed provider_strategy.json, caching the result.

        Args:
            force: re-read and re-validate the file even if already cached.

        Returns:
            The value of ``validate_and_load(path, "provider_strategy")``, or
            ``{}`` (with a warning) when the file does not exist. The empty
            result is cached as well, so the warning is not repeated until
            the next ``force=True`` load or :meth:`reset`.
        """
        if self._strategy_cache is not None and not force:
            return self._strategy_cache
        path = _strategy_path()
        if not path.is_file():
            logger.warning(
                "provider_strategy.json not found at %s — using empty strategy", path
            )
            self._strategy_cache = {}
            return self._strategy_cache
        # Deferred import, kept from the pre-MF-11 module-level implementation.
        from recoil.core.config_schema import validate_and_load
        self._strategy_cache = validate_and_load(path, "provider_strategy")
        return self._strategy_cache

    def load_adapter(self, provider_id: str) -> ProviderAdapter:
        """Import the adapter module for *provider_id* and return its ``ADAPTER``.

        In test mode the module is
        ``recoil.execution.providers.testing.mock_<provider_id>``; otherwise
        ``recoil.execution.providers.<provider_id>``. Results are cached per
        (mode, provider_id).

        Raises:
            ValueError: the provider module does not exist, a module it
                imports is missing, or it defines no ``ADAPTER`` instance.
        """
        # Read the mode exactly once so the cache key and the module we import
        # cannot disagree if the environment changes between the two uses.
        test_mode = _is_test_mode()
        cache_key = f"{'test' if test_mode else 'live'}:{provider_id}"
        if cache_key in self._adapter_cache:
            return self._adapter_cache[cache_key]

        if test_mode:
            module_name = f"recoil.execution.providers.testing.mock_{provider_id}"
        else:
            module_name = f"recoil.execution.providers.{provider_id}"

        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            missing = e.name
            if (
                missing
                and missing != module_name
                and not module_name.startswith(missing + ".")
            ):
                # The adapter module itself was found, but something it
                # imports is absent — that is not an "unknown provider".
                raise ValueError(
                    f"Provider '{provider_id}' module {module_name} failed to "
                    f"import dependency '{missing}': {e}"
                ) from e
            raise ValueError(f"Unknown provider '{provider_id}': {e}") from e

        adapter = getattr(mod, "ADAPTER", None)
        if adapter is None:
            raise ValueError(f"Module {module_name} defines no ADAPTER instance")
        self._adapter_cache[cache_key] = adapter
        return adapter

    def reset(self) -> None:
        """Drop strategy + adapter caches. Tests call this after mutating env."""
        self._strategy_cache = None
        self._adapter_cache.clear()


# Process-singleton — the canonical cache instance for this process.
# The module-level wrappers (load_strategy, _load_adapter,
# reset_caches_for_tests) look this global up on every call, so hot-reload
# tooling can construct a fresh ProviderStrategyCache() and reassign this
# module attribute. Code that bound the old object via
# ``from ... import _default_provider_cache`` keeps the old instance.
_default_provider_cache = ProviderStrategyCache()


# ── Module-level free functions — thin wrappers, preserved API ───────
def load_strategy(force: bool = False) -> dict:
    """Return the process-wide parsed provider strategy.

    Delegates to :meth:`ProviderStrategyCache.load_strategy` on the module
    singleton; ``force=True`` re-reads the file. Code that needs an isolated
    cache (tests, hot-reload) should build its own ProviderStrategyCache()
    and call ``.load_strategy()`` on it directly.
    """
    singleton = _default_provider_cache
    return singleton.load_strategy(force=force)


def _load_adapter(provider_id: str) -> ProviderAdapter:
    """Fetch the ``ADAPTER`` for *provider_id* through the process singleton.

    Thin delegate to :meth:`ProviderStrategyCache.load_adapter`. It keeps its
    leading underscore because it is intra-package private API (Phase 6).
    """
    singleton = _default_provider_cache
    return singleton.load_adapter(provider_id)


def reset_caches_for_tests() -> None:
    """Clear the singleton's strategy and adapter caches.

    For tests that mutate the environment and need the next lookup to start
    from scratch.
    """
    singleton = _default_provider_cache
    singleton.reset()


def list_adapters() -> list[ProviderAdapter]:
    """Return every adapter that currently loads (for introspection).

    Uses the mock module list in test mode and the real list otherwise.
    Providers that fail with one of the expected load-time errors are logged
    and skipped. Those errors are ImportError/ModuleNotFoundError,
    AttributeError on a missing ADAPTER, KeyError/ValueError/TypeError from a
    malformed strategy or payload schema, and pydantic.ValidationError from a
    profile shape mismatch (Phase-21-22 Phase 5 Site 12 narrowed the former
    catch-all). Any other exception — e.g. a runtime bug in a provider's
    module-level init — propagates so tests surface it.
    """
    import pydantic  # local import — pydantic isn't required by this module otherwise

    tolerated = (
        ImportError,
        ModuleNotFoundError,
        AttributeError,
        KeyError,
        ValueError,
        TypeError,
        pydantic.ValidationError,
    )
    module_names = _TEST_MODULE_NAMES if _is_test_mode() else _ADAPTER_MODULE_NAMES
    loaded: list[ProviderAdapter] = []
    for module_name in module_names:
        provider_id = module_name.replace("mock_", "")
        try:
            adapter = _load_adapter(provider_id)
        except tolerated as exc:
            logger.warning("provider %s failed to load: %s", provider_id, exc)
            continue
        loaded.append(adapter)
    return loaded


def _required_capabilities(payload: UnifiedVideoPayload) -> list[str]:
    """Derive the capabilities a payload needs."""
    req: list[str] = []
    if payload.image is not None:
        req.append("i2v")
    elif payload.reference_images:
        req.append("r2v")
    else:
        req.append("t2v")
    if payload.image_tail is not None:
        req.append("end_frame")
    if payload.generate_audio:
        req.append("audio")
    if payload.negative_prompt:
        req.append("negative_prompt")
    req.append(f"resolution_{payload.resolution}")
    return req


def _providers_supporting(capability: str, model_id: str) -> list[str]:
    """Provider ids of loadable adapters serving *model_id* with *capability*.

    Used to enrich ProviderCapabilityError with viable alternatives; order
    follows :func:`list_adapters`.
    """
    return [
        candidate.provider_id
        for candidate in list_adapters()
        if model_id in candidate.supported_models
        and candidate.capabilities.get(capability, False)
    ]


def _enforce_capabilities(
    adapter: ProviderAdapter,
    provider_id: str,
    model_id: str,
    required: list[str],
) -> None:
    """Raise ProviderCapabilityError for the first of *required* that *adapter*
    does not advertise as truthy in ``adapter.capabilities``."""
    for cap in required:
        if not adapter.capabilities.get(cap, False):
            raise ProviderCapabilityError(
                model_id=model_id,
                provider_id=provider_id,
                capability=cap,
                supported_providers=_providers_supporting(cap, model_id),
            )


def resolve_adapter(
    model_id: str,
    payload: UnifiedVideoPayload,
    tier: Optional[str] = None,
) -> tuple[ProviderAdapter, str]:
    """Pick the primary adapter+tier for (model_id, payload).

    Returns (adapter, tier). Routing precedence:
      1. ``RECOIL_PROVIDER_OVERRIDE`` env var (surrounding whitespace ignored,
         lower-cased);
      2. the first required capability listed in the model's
         ``capability_exceptions``;
      3. the model's declared ``primary``.
    Whichever adapter is chosen must support ALL required capabilities.

    Raises:
        ProviderCapabilityError: the selected adapter lacks a required
            capability.
        ValueError: no primary is declared for the model, or the chosen
            provider cannot be loaded (from ``_load_adapter``).

    The optional ``tier`` kwarg is the caller-supplied tier (CP-2 spec-review
    edit #7). It is preferred over the legacy ``RECOIL_PROVIDER_TIER_OVERRIDE``
    env var because ProductionLoop runs phases in a ThreadPoolExecutor, where
    per-thread env mutation races. When ``tier`` is None, fall through to
    the env override / strategy defaults as before. Capability enforcement is
    unchanged.
    """
    strategy = load_strategy()
    entry = strategy.get(model_id, {})
    required = _required_capabilities(payload)

    # 1. Env override — forces the provider for every model. Stripped so a
    #    value carrying stray whitespace (e.g. from a .env file) still resolves.
    override = (os.environ.get("RECOIL_PROVIDER_OVERRIDE") or "").strip()
    if override:
        provider_id = override.lower()
        tier_override = os.environ.get("RECOIL_PROVIDER_TIER_OVERRIDE")
        resolved_tier = (
            tier
            or tier_override
            or entry.get(f"{provider_id}_tier")
            or entry.get("primary_tier")
            or "default"
        )
        adapter = _load_adapter(provider_id)
        # Capability enforcement applies even on env override.
        _enforce_capabilities(adapter, provider_id, model_id, required)
        return adapter, resolved_tier

    # 2. Per-capability exception routing: first required capability with an
    #    explicit exception wins; the target must still satisfy ALL of them.
    capability_exceptions: dict = entry.get("capability_exceptions", {}) or {}
    for cap in required:
        if cap not in capability_exceptions:
            continue
        provider_id = capability_exceptions[cap]
        resolved_tier = tier or entry.get(f"{provider_id}_tier") or "default"
        adapter = _load_adapter(provider_id)
        _enforce_capabilities(adapter, provider_id, model_id, required)
        return adapter, resolved_tier

    # 3. Declared primary (refuse loud on any missing capability).
    primary_id = entry.get("primary")
    if not primary_id:
        raise ValueError(
            f"No primary provider declared for model {model_id!r} in provider_strategy.json"
        )
    primary = _load_adapter(primary_id)
    primary_tier = tier or entry.get("primary_tier") or "default"
    _enforce_capabilities(primary, primary_id, model_id, required)
    return primary, primary_tier


def resolve_fallback(
    model_id: str,
    payload: UnifiedVideoPayload,
) -> Optional[tuple[ProviderAdapter, str]]:
    """Return (adapter, tier) for the model's declared fallback, or None.

    None is returned when:
      - no ``fallback`` is declared;
      - the fallback provider cannot be loaded (ValueError from
        ``_load_adapter``);
      - the fallback lacks any capability the payload requires.

    Unlike :func:`resolve_adapter`, a missing capability here yields None
    rather than ProviderCapabilityError. The load-failure and
    missing-capability cases are logged so a misconfigured or incapable
    fallback is visible instead of silently skipped. The tier is
    ``fallback_tier`` or ``"default"``.
    """
    strategy = load_strategy()
    entry = strategy.get(model_id, {})
    fb_id = entry.get("fallback")
    if not fb_id:
        return None
    fb_tier = entry.get("fallback_tier") or "default"
    try:
        fb_adapter = _load_adapter(fb_id)
    except ValueError as exc:
        # A declared-but-unloadable fallback is a configuration problem.
        logger.warning(
            "fallback provider %s for model %s failed to load: %s", fb_id, model_id, exc
        )
        return None
    # Capability enforcement on the fallback too.
    for cap in _required_capabilities(payload):
        if not fb_adapter.capabilities.get(cap, False):
            logger.info(
                "fallback provider %s for model %s lacks capability %r; skipping",
                fb_id,
                model_id,
                cap,
            )
            return None
    return fb_adapter, fb_tier


# Public surface for star-imports. Underscore helpers such as
# ``_load_adapter`` are intentionally absent (intra-package private API).
__all__ = [
    "ProviderStrategyCache",
    "load_strategy",
    "list_adapters",
    "resolve_adapter",
    "resolve_fallback",
    "reset_caches_for_tests",
]
