"""Filter-safety glossary loader and mode resolver."""

from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
from types import MappingProxyType
from typing import Any, Mapping
import os
import re

import yaml

from recoil.core.paths import RECOIL_ROOT
from recoil.pipeline._lib.recoil_bridge import load_project_config


# Default glossary file, read when load_glossary() is called without a path.
_GLOSSARY_PATH = RECOIL_ROOT / "config" / "filter_safety.yaml"
# JSONL file under the dispatch-logs directory. Nothing in this module reads
# or writes it; presumably callers append build_flag_label() records here —
# TODO confirm against the caller.
FILTER_SAFETY_LABELS_PATH = (
    RECOIL_ROOT / "_dispatch_logs" / "filter_safety_labels.jsonl"
)
# Environment variable that filter_safety_mode() consults first.
_FILTER_SAFETY_ENV_VAR = "RECOIL_FILTER_SAFETY"
# Project-config key consulted when the environment variable is not set.
_FILTER_SAFETY_CONFIG_KEY = "filter_safety_mode"
# The only mode strings filter_safety_mode() will return.
MODES = ("off", "shadow")
# A stripped line is a section header when it starts with a capital letter and
# contains only capitals, spaces, "&", "/" or "-", with an optional final ":".
_SECTION_HEADER_RE = re.compile(r"^[A-Z][A-Z &/-]*:?$")
# Sentences are separated by any run of ".", "!", "?" or newline characters.
_SENTENCE_SPLIT_RE = re.compile(r"[.!?\n]+")
# Top-level keys the glossary YAML must contain. "schema_version" is only
# checked for presence; its value is not read by the loader in this module.
_REQUIRED_KEYS = (
    "schema_version",
    "categories",
    "swaps",
    "safe_sections",
    "stack_threshold",
)


@dataclass(frozen=True)
class Glossary:
    """Validated filter-safety glossary as produced by the loader."""

    # Category name -> terms that belong to that category.
    categories: Mapping[str, tuple[str, ...]]
    # Category name -> {lower-cased term to avoid: suggested replacement}.
    swaps: Mapping[str, Mapping[str, str]]
    # Name fragments; a section whose header contains one (case-insensitive)
    # has its findings reported as INFO instead of WARN.
    safe_sections: tuple[str, ...]
    # A sentence is reported only when it hits MORE than this many categories.
    stack_threshold: int


@dataclass(frozen=True)
class LintFinding:
    """One sentence that lint_prompt() reported for stacking glossary categories."""

    # The stripped sentence text that was matched.
    sentence: str
    # Glossary categories with at least one term found in the sentence.
    categories: tuple[str, ...]
    # Lower-cased glossary terms found in the sentence, in match order.
    terms: tuple[str, ...]
    # "INFO" when the sentence is under a safe section, otherwise "WARN".
    level: str
    # Header of the enclosing section (trailing ":" removed), or None for text
    # that appears before any header.
    section: str | None
    # Matched term -> replacement, for terms that have a swap entry.
    # NOTE(review): this is a plain dict, so hashing an instance raises
    # TypeError even though the dataclass is frozen.
    suggested_swaps: dict[str, str]


def load_glossary(path: Path | None = None) -> Glossary:
    """Return the validated filter-safety glossary.

    Args:
        path: Glossary file to read. When omitted, the default glossary is
            returned through the cached default loader; an explicit path is
            read on every call.

    Raises:
        ValueError: if the glossary file cannot be loaded or fails validation.
    """
    if path is not None:
        return _load_glossary_from_path(Path(path))
    return _load_default_glossary()


@lru_cache(maxsize=1)
def _load_default_glossary() -> Glossary:
    """Load the glossary from the default path; lru_cache keeps the result."""
    default_glossary = _load_glossary_from_path(_GLOSSARY_PATH)
    return default_glossary


def _load_glossary_from_path(path: Path) -> Glossary:
    """Read, parse, and validate the glossary YAML stored at *path*.

    Args:
        path: Location of the glossary file.

    Returns:
        A Glossary whose ``categories`` and ``swaps`` are wrapped in
        read-only mapping proxies.

    Raises:
        ValueError: if the file is missing, unreadable, not valid UTF-8 or
            YAML, not a mapping, lacks a required key, or holds a value of
            the wrong type (including a boolean or negative
            ``stack_threshold``).
    """
    if not path.exists():
        raise ValueError(f"missing filter safety glossary file: {path}")

    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        raise ValueError(f"invalid filter safety YAML in {path}: {exc}") from exc
    except (OSError, UnicodeDecodeError) as exc:
        # UnicodeDecodeError is already a ValueError, but re-raising it here
        # gives callers the same path-bearing message as other read failures.
        raise ValueError(f"cannot read filter safety glossary {path}: {exc}") from exc

    if not isinstance(raw, dict):
        raise ValueError("filter safety glossary must be a mapping")

    for key in _REQUIRED_KEYS:
        if key not in raw:
            raise ValueError(f"filter safety glossary missing required key {key!r}")

    categories = _validate_categories(raw["categories"])
    swaps = _validate_swaps(raw["swaps"])
    safe_sections = _validate_string_list(raw["safe_sections"], "safe_sections")

    stack_threshold = raw["stack_threshold"]
    # bool is a subclass of int, so YAML `true`/`yes` would otherwise slip
    # through and be treated as a threshold of 1.
    if isinstance(stack_threshold, bool) or not isinstance(stack_threshold, int):
        raise ValueError("stack_threshold must be an integer")
    # A negative threshold would make the linter report every sentence, even
    # those that match no glossary term at all.
    if stack_threshold < 0:
        raise ValueError("stack_threshold must not be negative")

    return Glossary(
        categories=MappingProxyType(categories),
        swaps=MappingProxyType(swaps),
        safe_sections=tuple(safe_sections),
        stack_threshold=stack_threshold,
    )


def _validate_categories(value: object) -> dict[str, tuple[str, ...]]:
    if not isinstance(value, dict):
        raise ValueError("categories must be a mapping")
    categories: dict[str, tuple[str, ...]] = {}
    for category, terms in value.items():
        if not isinstance(category, str):
            raise ValueError("category names must be strings")
        if not isinstance(terms, list):
            raise ValueError(f"category {category!r} must be a list")
        for term in terms:
            if not isinstance(term, str):
                raise ValueError(f"category {category!r} terms must be strings")
        categories[category] = tuple(terms)
    return categories


def _validate_swaps(value: object) -> dict[str, Mapping[str, str]]:
    if not isinstance(value, dict):
        raise ValueError("swaps must be a mapping")
    swaps: dict[str, Mapping[str, str]] = {}
    for category, entries in value.items():
        if not isinstance(category, str):
            raise ValueError("swap category names must be strings")
        if not isinstance(entries, dict):
            raise ValueError(f"swaps category {category!r} must be a mapping")
        normalized: dict[str, str] = {}
        for avoid, replacement in entries.items():
            if not isinstance(avoid, str):
                raise ValueError(f"swap key in {category!r} must be a string")
            if not isinstance(replacement, str):
                raise ValueError(f"swap replacement for {avoid!r} must be a string")
            normalized[avoid.lower()] = replacement
        swaps[category] = MappingProxyType(normalized)
    return swaps


def _validate_string_list(value: object, name: str) -> tuple[str, ...]:
    if not isinstance(value, list):
        raise ValueError(f"{name} must be a list")
    for item in value:
        if not isinstance(item, str):
            raise ValueError(f"{name} entries must be strings")
    return tuple(value)


def filter_safety_mode(project: str) -> str:
    """Return the filter-safety mode for *project*.

    The environment variable wins whenever it is set (even to an empty
    string). Otherwise the project config value is used, and a missing or
    falsy config value falls back to "shadow".

    Raises:
        ValueError: if the resolved mode is not one of MODES.
    """
    env_mode = os.environ.get(_FILTER_SAFETY_ENV_VAR)
    if env_mode is not None:
        resolved = env_mode
    else:
        project_cfg = load_project_config(project) or {}
        resolved = project_cfg.get(_FILTER_SAFETY_CONFIG_KEY) or "shadow"

    if resolved in MODES:
        return resolved
    raise ValueError(
        f"invalid filter safety mode {resolved!r}; expected one of {', '.join(MODES)}"
    )


def lint_prompt(text: str, *, glossary: Glossary | None = None) -> list[LintFinding]:
    """Lint prompt *text* and return one finding per over-stacked sentence.

    The text is split into sections and then sentences. A sentence is
    reported when it matches terms from more categories than the glossary's
    ``stack_threshold``. Findings under a safe section are levelled "INFO";
    all others are "WARN".

    Args:
        text: Prompt text to inspect.
        glossary: Glossary to lint against; the default glossary is loaded
            when omitted.
    """
    active = glossary or load_glossary()
    results: list[LintFinding] = []

    for heading, body in _split_sections(text):
        in_safe_section = _is_safe_section(heading, active.safe_sections)
        severity = "INFO" if in_safe_section else "WARN"
        for sentence in _split_sentences(body):
            hit_categories, hit_terms = _match_sentence(sentence, active)
            if len(hit_categories) > active.stack_threshold:
                results.append(
                    LintFinding(
                        sentence=sentence,
                        categories=hit_categories,
                        terms=hit_terms,
                        level=severity,
                        section=heading,
                        suggested_swaps=_suggested_swaps(hit_terms, active),
                    )
                )

    return results


def summarize_findings(findings) -> dict:
    """Condense findings into a small JSON-serializable summary.

    Args:
        findings: Iterable of objects exposing ``level``, ``terms`` and
            ``categories``; it is consumed once.

    Returns:
        A dict with the number of "WARN" and "INFO" findings, the sorted
        distinct categories, and every matched term ordered by descending
        frequency, ties broken alphabetically.
    """
    warn_total = 0
    info_total = 0
    term_tally: Counter = Counter()
    seen_categories: set = set()

    for item in list(findings):
        if item.level == "WARN":
            warn_total += 1
        elif item.level == "INFO":
            info_total += 1
        term_tally.update(item.terms)
        seen_categories.update(item.categories)

    ranked_terms = sorted(term_tally, key=lambda term: (-term_tally[term], term))
    return {
        "warn": warn_total,
        "info": info_total,
        "categories_hit": sorted(seen_categories),
        "top_terms": ranked_terms,
    }


def build_flag_label(
    take: Any,
    failure: Any,
    *,
    project: str,
    episode: Any,
    shot_id: str,
    model: str | None = None,
    ts: str | None = None,
) -> dict:
    """Assemble a label record for a provider content-filter block.

    Args:
        take: Object whose workflow steps are searched for the prompt.
        failure: Object whose ``error`` and ``model`` attributes are read
            when present.
        project: Stored under "project".
        episode: Stored under "episode".
        shot_id: Stored under "shot_id".
        model: Explicit model name; falls back to ``failure.model`` and then
            to the model recorded on *take*.
        ts: Explicit timestamp; the current UTC time is used when omitted.

    Returns:
        A dict that always carries ts/project/episode/shot_id/model/
        provider_error. It then holds either "prompt_error", or "prompt"
        together with "lint" (or "lint_error" when linting raised).
    """
    provider_error = getattr(failure, "error", None) or ""
    prompt, prompt_error = _resolve_flag_prompt(take, provider_error)
    resolved_model = model or getattr(failure, "model", None) or _take_model(take)

    label = {
        "ts": ts or _utc_now_iso8601(),
        "project": project,
        "episode": episode,
        "shot_id": shot_id,
        "model": resolved_model,
        "provider_error": provider_error,
    }

    if prompt_error is None:
        label["prompt"] = prompt
        # Linting is best-effort: a lint failure is recorded on the label
        # rather than preventing the label from being built.
        try:
            label["lint"] = summarize_findings(lint_prompt(prompt))
        except Exception as exc:  # noqa: BLE001
            label["lint_error"] = str(exc)
    else:
        label["prompt_error"] = prompt_error

    return label


def _utc_now_iso8601() -> str:
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def _resolve_flag_prompt(
    take: Any, provider_error: str
) -> tuple[str | None, str | None]:
    """Pick the prompt most likely responsible for a content-filter failure.

    Only steps whose payload is a dict containing "prompt" are considered.
    Preference order:

    1. the first step whose own failure looks like a content-filter block
       (returned immediately);
    2. the last step whose failure text matches *provider_error*;
    3. the first step that has a prompt at all.

    Returns:
        ``(prompt, None)`` on success or ``(None, reason)`` when no usable
        prompt could be resolved.
    """
    steps = getattr(getattr(take, "workflow", None), "steps", None)
    if not steps:
        return None, "take.workflow.steps missing or empty"

    fallback: tuple[str | None, str | None] | None = None
    error_match: tuple[str | None, str | None] | None = None

    for step in steps:
        payload = getattr(step, "payload", None)
        if isinstance(payload, dict) and "prompt" in payload:
            resolved = _prompt_value(payload.get("prompt"), step)
            if fallback is None:
                fallback = resolved
            if _step_failure_is_content_filter(step):
                return resolved
            if _step_failure_matches(step, provider_error):
                # Later matches overwrite earlier ones on purpose of parity
                # with the loop order: the last matching step wins.
                error_match = resolved

    for choice in (error_match, fallback):
        if choice is not None:
            return choice
    return None, "no workflow step payload contains prompt"


def _prompt_value(value: Any, step: Any) -> tuple[str | None, str | None]:
    step_id = getattr(step, "step_id", "<unknown>")
    if value is None:
        return None, f"prompt for step {step_id!r} is null"
    if not isinstance(value, str):
        return None, f"prompt for step {step_id!r} is not a string"
    return value, None


def _step_failure_matches(step: Any, provider_error: str) -> bool:
    """Report whether *step* failed with an error related to *provider_error*.

    A step with no recorded failure never matches. A failed step matches
    when *provider_error* is empty, or when either error text contains the
    other.
    """
    step_error = _step_failure_error(step)
    if not step_error:
        return False
    if not provider_error:
        return True
    return step_error in provider_error or provider_error in step_error


def _step_failure_is_content_filter(step: Any) -> bool:
    """Report whether *step* failed with a content-filter style error.

    The step's failure text is lower-cased and checked for any of the
    substrings in CONTENT_FILTER_PATTERNS. The patterns are imported lazily,
    and only when the step actually has a failure message.
    """
    step_error = _step_failure_error(step)
    if step_error:
        from recoil.pipeline.core.failure_mode import CONTENT_FILTER_PATTERNS

        lowered = step_error.lower()
        for pattern in CONTENT_FILTER_PATTERNS:
            if pattern in lowered:
                return True
    return False


def _step_failure_error(step: Any) -> str | None:
    receipt = getattr(step, "receipt", None)
    run_result = getattr(receipt, "run_result", None)
    if getattr(run_result, "success", None) is not False:
        return None
    return getattr(run_result, "error", None)


def _take_model(take: Any) -> str:
    metadata = getattr(take, "take_metadata", None)
    if isinstance(metadata, dict) and metadata.get("model"):
        return str(metadata["model"])
    provenance = getattr(getattr(take, "workflow", None), "global_provenance", None)
    if isinstance(provenance, dict) and provenance.get("model"):
        return str(provenance["model"])
    return ""


def _split_sections(text: str) -> list[tuple[str | None, str]]:
    """Split *text* into ``(header, body)`` pairs at section-header lines.

    The first pair always has header None and holds whatever precedes the
    first header (possibly an empty body). A header's trailing ":" is
    dropped; body lines are kept unstripped and rejoined with newlines.
    """
    collected: list[tuple[str | None, list[str]]] = [(None, [])]

    for raw_line in text.splitlines():
        candidate = raw_line.strip()
        if candidate and _SECTION_HEADER_RE.fullmatch(candidate):
            heading = candidate[:-1] if candidate.endswith(":") else candidate
            collected.append((heading, []))
        else:
            # Non-header lines belong to the most recently opened section.
            collected[-1][1].append(raw_line)

    return [(heading, "\n".join(body_lines)) for heading, body_lines in collected]


def _split_sentences(text: str) -> list[str]:
    """Split *text* into stripped, non-empty sentences."""
    trimmed = (piece.strip() for piece in _SENTENCE_SPLIT_RE.split(text))
    return [piece for piece in trimmed if piece]


def _match_sentence(
    sentence: str, glossary: Glossary
) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Find the glossary categories and terms that occur in *sentence*.

    Terms are compared case-insensitively and reported lower-cased. A term
    that has already matched is not tested again, so a term listed under
    several categories only counts toward the first category that lists it.

    Returns:
        ``(categories, terms)``, both in glossary iteration order.
    """
    hit_categories: list[str] = []
    hit_terms: list[str] = []
    already_matched: set[str] = set()

    for name, candidates in glossary.categories.items():
        fresh: list[str] = []
        for candidate in candidates:
            key = candidate.lower()
            if key in already_matched:
                continue
            if _phrase_re(candidate).search(sentence):
                already_matched.add(key)
                fresh.append(key)
        if fresh:
            hit_terms.extend(fresh)
            hit_categories.append(name)

    return tuple(hit_categories), tuple(hit_terms)


@lru_cache(maxsize=512)
def _phrase_re(term: str) -> re.Pattern[str]:
    return re.compile(rf"(?<!\w){re.escape(term)}(?!\w)", re.IGNORECASE)


def _is_safe_section(section: str | None, safe_sections: tuple[str, ...]) -> bool:
    if section is None:
        return False
    section_upper = section.upper()
    return any(safe.upper() in section_upper for safe in safe_sections)


def _suggested_swaps(terms: tuple[str, ...], glossary: Glossary) -> dict[str, str]:
    suggestions: dict[str, str] = {}
    for term in terms:
        for swaps in glossary.swaps.values():
            replacement = swaps.get(term)
            if replacement is not None:
                suggestions[term] = replacement
                break
    return suggestions
