"""FeedbackAgent — Diagnoses gate failures and produces prompt/ref fixes.

Called from StepRunner's retry loop at a single insertion point.
StepRunner calls diagnose() with failure context.
FeedbackAgent returns FeedbackFix or None (unfixable, escalate to ICU).
"""
import json
import logging
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


class FeedbackStrategy(str, Enum):
    """Named strategies for feedback attempts.

    Members also subclass ``str``, so each one compares equal to its string
    value. Within this module the ``.value`` string is matched against a
    model profile's ``forbidden_reroll_strategies`` list and written to the
    JSONL feedback log, so renaming a value changes both of those contracts.
    """
    # The members below are not constructed anywhere in this module —
    # presumably they are selected by the fix registry's match_fix();
    # confirm against fix_registry.
    ANATOMY_ANCHOR = "anatomy_anchor"
    FACE_ISOLATION = "face_isolation"
    REF_PRUNE_AND_ANCHOR = "ref_prune_and_anchor"
    WARDROBE_EXAGGERATION = "wardrobe_exaggeration"
    WARDROBE_CONTRASTIVE = "wardrobe_contrastive"
    LIGHTING_LOCK = "lighting_lock"
    ENVIRONMENT_REINFORCE = "environment_reinforce"
    STYLE_ANCHOR = "style_anchor"
    LVLM_REWRITE = "lvlm_rewrite"
    DIFFICULTY_DOWNGRADE = "difficulty_downgrade"
    VIDEO_SHORTEN = "video_shorten"
    VIDEO_MOTION_SIMPLIFY = "video_motion_simplify"
    # Built directly by FeedbackAgent._try_crop_to_closeup() as a last resort.
    CROP_TO_CLOSEUP = "crop_to_closeup"
    # Phase 3 additions (2026-04-10)
    # Built directly by FeedbackAgent._safe_fallback(), tried in this order.
    SEED_REROLL = "seed_reroll"
    REDUCE_REFS = "reduce_refs"
    ASPECT_CROP = "aspect_crop"


@dataclass(frozen=True)
class RefChanges:
    """Describes how to modify the reference image set."""
    keep_only: list[str] | None = None
    remove: list[str] | None = None
    grayscale_expression_refs: bool = False
    face_crop_expression_refs: bool = False
    boost_scene_ref: bool = False
    crop_to_closeup: bool = False

    @property
    def has_removals(self) -> bool:
        return bool(self.remove) or bool(self.keep_only)

    @property
    def has_additions(self) -> bool:
        return self.boost_scene_ref


@dataclass(frozen=True)
class FeedbackFix:
    """A diagnosed fix to apply before retry.

    Immutable value object returned by ``FeedbackAgent.diagnose()``.
    Field order is part of the positional constructor signature.
    """
    strategy: FeedbackStrategy
    # Optional reference-set edits; None means "leave refs untouched".
    ref_changes: RefChanges | None
    # Terms to append to the negative prompt (see apply_negative_prompt).
    negative_prompt_additions: list[str]
    # Fixes built in this module use 0.20-0.30; presumably a 0..1 score — confirm.
    confidence: float
    # Human-readable explanation, used for logging.
    rationale: str
    # Cost attributed to producing the diagnosis (0.00 for deterministic fixes here).
    diagnosis_cost: float

    def apply_negative_prompt(self, existing: str) -> str:
        """Merge negative prompt additions with existing.

        The existing prompt is treated as a comma-separated list of terms.
        An addition is skipped only when the *same term* is already present.
        A plain substring test would wrongly drop "wide shot" when the prompt
        already contains "extreme wide shot". Duplicates within the additions
        themselves are emitted once.

        Args:
            existing: Current negative prompt (may be empty).

        Returns:
            ``existing`` unchanged when nothing new needs adding; otherwise
            the existing prompt followed by the new terms, joined by ", ".
        """
        seen = {term.strip() for term in existing.split(",")}
        seen.discard("")

        fresh: list[str] = []
        for raw in self.negative_prompt_additions:
            # An addition may itself carry several comma-separated terms.
            for piece in raw.split(","):
                term = piece.strip()
                if not term or term in seen:
                    continue
                seen.add(term)
                fresh.append(term)

        if not fresh:
            return existing

        # Drop stray leading/trailing separators so we never emit ", ," runs.
        base = existing.strip(", ")
        return ", ".join([base, *fresh]) if base else ", ".join(fresh)


@dataclass
class FeedbackAttempt:
    """One feedback attempt, as kept in history and used for the autopsy/log.

    Mutable on purpose: ``__post_init__`` fills in ``timestamp`` when the
    caller did not provide one.
    """
    attempt_number: int
    strategy: FeedbackStrategy
    gate_failed: str
    failure_reason: str
    ref_changes_applied: str | None
    result: str  # "passed" | "failed_same" | "failed_different" | "failed_worse"
    output_path: str | None
    cost: float
    timestamp: float = 0.0  # 0.0 is the "not supplied" sentinel

    def __post_init__(self):
        """Stamp the record with ``time.time()`` unless a timestamp was given."""
        if self.timestamp != 0.0:
            return
        self.timestamp = time.time()


class FeedbackAgent:
    """Standalone feedback module called from StepRunner's retry loop."""

    def __init__(self, project_id: str, engine_memory_path: Path | None = None):
        from .fix_registry import FIX_REGISTRY
        from .constants import NEVER_HEAL
        self.project_id = project_id
        self.fix_registry = FIX_REGISTRY
        self.never_heal = NEVER_HEAL
        self.attempt_history: list[FeedbackAttempt] = []
        self._memory_path = engine_memory_path

    def diagnose(
        self,
        verdict,
        prompt_sections: list[dict],
        current_refs: list[Path],
        ref_metadata: dict,
        failed_output_path: Path,
        attempt_number: int,
        modality: str = "keyframe",
        target_model: str = "",
        sibling_ref_path: Path | None = None,
    ) -> Optional[FeedbackFix]:
        """Analyze a gate failure and produce a fix, or None if unfixable."""
        from .fix_registry import match_fix
        from .constants import SEVERITY_CEILING

        # Never-feedback shortlist
        failure_cat = getattr(verdict, 'details', {}).get("failure_category", "")
        if failure_cat in self.never_heal:
            logger.info("Feedback skipped: never-feedback category '%s'", failure_cat)
            return None

        # Severity ceiling
        severity = getattr(verdict, 'details', {}).get("total_severity", 0)
        if severity > SEVERITY_CEILING:
            logger.info("Feedback skipped: severity %d exceeds ceiling %d", severity, SEVERITY_CEILING)
            return None

        # Video triage (Phase 3 — stub for now)
        if modality == "video":
            logger.info("Video feedback not yet implemented — escalating to ICU")
            return None

        # Load forbidden strategies from model profile
        forbidden_strategies: set[str] = set()
        if target_model:
            try:
                from recoil.core import model_profiles
                profile = model_profiles.get_profile(target_model)
                forbidden_strategies = set(profile.get("forbidden_reroll_strategies", []))
            except Exception:
                logger.warning("Could not load model profile for %s", target_model)

        # Deterministic fix from registry
        fix = match_fix(verdict, attempt_number)
        if fix:
            # Check if the matched fix uses a forbidden strategy
            if forbidden_strategies and fix.strategy.value in forbidden_strategies:
                logger.info(
                    "Feedback: %s FORBIDDEN for model %s — trying safe fallback",
                    fix.strategy.value, target_model,
                )
                fix = self._safe_fallback(forbidden_strategies, attempt_number, verdict)
                if fix:
                    logger.info("Feedback: safe fallback %s (confidence %.2f)", fix.strategy.value, fix.confidence)
                    if sibling_ref_path:
                        fix = self._inject_sibling_ref(fix, sibling_ref_path)
                    return fix
                return None
            logger.info("Feedback: %s (confidence %.2f)", fix.strategy.value, fix.confidence)
            if sibling_ref_path:
                fix = self._inject_sibling_ref(fix, sibling_ref_path)
            return fix

        # Safe fallback when no deterministic fix found but model has restrictions
        if forbidden_strategies:
            fix = self._safe_fallback(forbidden_strategies, attempt_number, verdict)
            if fix:
                logger.info("Feedback: safe fallback %s (confidence %.2f)", fix.strategy.value, fix.confidence)
                if sibling_ref_path:
                    fix = self._inject_sibling_ref(fix, sibling_ref_path)
                return fix

        # CROP_TO_CLOSEUP fallback — skip if forbidden
        if not forbidden_strategies or "crop_to_closeup" not in forbidden_strategies:
            fix = self._try_crop_to_closeup(verdict, attempt_number)
            if fix:
                logger.info("Feedback: CROP_TO_CLOSEUP (last resort, confidence %.2f)", fix.confidence)
                if sibling_ref_path:
                    fix = self._inject_sibling_ref(fix, sibling_ref_path)
                return fix

        # LVLM fallback (Phase 2 — stub for now)
        logger.info("No deterministic fix found. LVLM fallback not yet implemented.")
        return None

    def _inject_sibling_ref(
        self,
        fix: FeedbackFix,
        sibling_ref_path: Path,
    ) -> FeedbackFix:
        """Inject sibling ref as continuity_anchor into a FeedbackFix."""
        return FeedbackFix(
            strategy=fix.strategy,
            ref_changes=fix.ref_changes,
            negative_prompt_additions=fix.negative_prompt_additions,
            confidence=fix.confidence,
            rationale=f"{fix.rationale} [+sibling_ref: {sibling_ref_path.name}]",
            diagnosis_cost=fix.diagnosis_cost,
        )

    def _try_crop_to_closeup(
        self,
        verdict,
        attempt_number: int,
    ) -> Optional[FeedbackFix]:
        """Last-resort: crop a medium+ shot to close-up to salvage the face.

        Guards (ALL must pass):
        - Original framing is MS or wider (CU/MCU/ECU already tight)
        - Not an action or establishing shot (would break narrative)
        - All other strategies exhausted (attempt >= 2)
        - Max 2 consecutive CUs not exceeded (tracked per-session)
        - Shot verb_strength is LOW (vague action survives reframing)

        Synthesis: Track crop % per series — >5% means upstream ref/prompt problem.
        """
        from .constants import (
            CROP_CLOSEUP_MIN_FRAMING_ORDER,
            CROP_CLOSEUP_BLOCKED_SHOT_TYPES,
            CROP_CLOSEUP_MIN_ATTEMPT,
            CROP_CLOSEUP_MAX_CONSECUTIVE,
        )

        # Guard: must have failed enough times
        if attempt_number < CROP_CLOSEUP_MIN_ATTEMPT:
            return None

        # Guard: check framing from verdict details
        shot_type = getattr(verdict, 'details', {}).get("shot_type", "")
        if not shot_type:
            shot_type = getattr(verdict, 'details', {}).get("framing", "")
        shot_upper = shot_type.upper() if shot_type else ""

        # Framing order: EWS=0, WS=1, MWS=2, MS=3, MCU=4, CU=5, ECU=6
        FRAMING_ORDER = {"EWS": 0, "WS": 1, "MWS": 2, "MS": 3, "MCU": 4, "CU": 5, "ECU": 6, "OTS": 3}
        framing_rank = FRAMING_ORDER.get(shot_upper, -1)
        if framing_rank == -1 or framing_rank > CROP_CLOSEUP_MIN_FRAMING_ORDER:
            return None  # Already a close-up (MCU/CU/ECU) or unknown framing

        # Guard: not action or establishing
        if shot_upper in CROP_CLOSEUP_BLOCKED_SHOT_TYPES:
            return None

        # Guard: consecutive CU limit
        recent_crops = sum(
            1 for a in self.attempt_history[-CROP_CLOSEUP_MAX_CONSECUTIVE:]
            if a.strategy == FeedbackStrategy.CROP_TO_CLOSEUP
        )
        if recent_crops >= CROP_CLOSEUP_MAX_CONSECUTIVE:
            logger.warning("CROP_TO_CLOSEUP: max consecutive (%d) reached", CROP_CLOSEUP_MAX_CONSECUTIVE)
            return None

        # Guard: verb_strength should be LOW (from PlanPassCritic)
        verb_strength = getattr(verdict, 'details', {}).get("verb_strength", "")
        if verb_strength and verb_strength.upper() not in ("LOW", ""):
            return None

        return FeedbackFix(
            strategy=FeedbackStrategy.CROP_TO_CLOSEUP,
            ref_changes=RefChanges(
                keep_only=["hero"],
                face_crop_expression_refs=True,
                crop_to_closeup=True,
            ),
            negative_prompt_additions=["wide shot", "full body", "establishing shot"],
            confidence=0.30,
            rationale=f"Last-resort crop-to-closeup: {shot_upper} -> CU (attempt {attempt_number})",
            diagnosis_cost=0.00,
        )

    def _safe_fallback(
        self,
        forbidden_strategies: set[str],
        attempt_number: int,
        verdict,
    ) -> Optional[FeedbackFix]:
        """Fallback order when prompt mutation is forbidden.

        Order: (1) SEED_REROLL, (2) REDUCE_REFS, (3) ASPECT_CROP.
        Each is tried only if not in the forbidden set.
        """
        fallback_order = [
            (FeedbackStrategy.SEED_REROLL, None),
            (FeedbackStrategy.REDUCE_REFS, RefChanges(keep_only=["hero"])),
            (FeedbackStrategy.ASPECT_CROP, None),
        ]

        for strategy, ref_changes in fallback_order:
            if strategy.value in forbidden_strategies:
                continue
            return FeedbackFix(
                strategy=strategy,
                ref_changes=ref_changes,
                negative_prompt_additions=[],
                confidence=0.20,
                rationale=f"Safe fallback: {strategy.value} (prompt mutation forbidden)",
                diagnosis_cost=0.00,
            )

        # All safe fallbacks also forbidden — try crop_to_closeup
        return self._try_crop_to_closeup(verdict, attempt_number)

    def log_attempt(self, shot_id: str, attempt: FeedbackAttempt) -> None:
        """Append feedback attempt to history and JSONL log."""
        self.attempt_history.append(attempt)
        if self._memory_path:
            log_path = self._memory_path / "feedback" / "feedback_log.jsonl"
            log_path.parent.mkdir(parents=True, exist_ok=True)
            entry = {
                "timestamp": attempt.timestamp,
                "shot_id": shot_id,
                "project": self.project_id,
                "attempt": attempt.attempt_number,
                "strategy": attempt.strategy.value,
                "gate_failed": attempt.gate_failed,
                "failure_reason": attempt.failure_reason[:200],
                "result": attempt.result,
                "cost": attempt.cost,
            }
            with open(log_path, "a") as f:
                f.write(json.dumps(entry) + "\n")

    def log_success(self, shot_id: str, character_id: str,
                    strategy: FeedbackStrategy, model_version: str) -> None:
        """Note that a feedback strategy succeeded for a shot.

        Currently this only emits an INFO log line with the shot id and the
        strategy value; ``character_id`` and ``model_version`` are accepted
        but not used, and nothing is persisted.
        """
        # Auto-inject tracking (for Gate 1) is slated for Phase 2.
        strategy_name = strategy.value
        logger.info("Feedback SUCCESS: shot=%s strategy=%s", shot_id, strategy_name)

    def generate_autopsy(self, shot_data: dict, final_verdict) -> dict:
        """Generate ICU escalation report from attempt_history."""
        from .autopsy import generate_autopsy_report
        return generate_autopsy_report(
            shot_id=shot_data.get("shot_id", "unknown"),
            project_id=self.project_id,
            attempt_history=self.attempt_history,
            final_verdict=final_verdict,
            shot_data=shot_data,
        )
