# ==============================================================================
# PORTED FROM STARSEND: lib/vision_check.py
# DATE: 2026-03-29
# NOTE: For historical git blame prior to this date, see the starsend repository.
# ==============================================================================
"""
vision_check.py — Shared Gemini Flash vision validation utility.

Sends images to Gemini 2.5 Flash with structured validation questions,
returns pass/fail per check. Also handles video frame extraction via
ffmpeg + per-frame validation.

Key contract: On any non-transient exception, return passed=False with
gate_closed_by="exception" — the quality gate fails closed. Transient API
errors (ServiceUnavailable, DeadlineExceeded, httpx.TimeoutException) get
one retry with 2s backoff before failing closed. Per Build A SYNTHESIS D1
(Opus retry-only): the sanctioned-fallback registry's quality-neutrality
prong cannot honestly include "I didn't run the check."

Cost: ~$0.01 per image check via Gemini 2.5 Flash.
"""

import json
import logging
import os
import subprocess
import tempfile
from pathlib import Path

# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Every vision check is sent to Gemini Flash, chosen for speed and cost.
_VISION_MODEL = "gemini-2.5-flash"

# Shared Gemini client; stays None until _get_client() builds it on first use.
_cached_client = None

# Transient-retry tuning (Build A Phase 1, 2026-05-09).
# Only the exception classes collected in _TRANSIENT_EXCEPTIONS below are
# retried; everything else fails immediately so the quality gate stays honest.
_TRANSIENT_RETRY_ATTEMPTS = 2  # total calls: the original plus one retry
_TRANSIENT_RETRY_BACKOFF_S = 2.0  # seconds to wait between calls

# Assemble the transient-exception tuple once, at import time. Each optional
# dependency contributes its classes only when it can be imported; a missing
# package simply leaves the tuple shorter (possibly empty).
# NOTE(review): the client in _get_client() is built from `google.genai`;
# confirm that client can actually raise the google.api_core classes below.
_TRANSIENT_EXCEPTIONS: tuple = ()
try:
    from google.api_core import exceptions as _gapi_exc
except ImportError:
    pass
else:
    _TRANSIENT_EXCEPTIONS += (
        _gapi_exc.ServiceUnavailable,
        _gapi_exc.DeadlineExceeded,
    )
try:
    import httpx as _httpx
except ImportError:
    pass
else:
    _TRANSIENT_EXCEPTIONS += (_httpx.TimeoutException,)


def _retry_transient(fn, *, attempts: int = _TRANSIENT_RETRY_ATTEMPTS,
                     backoff_s: float = _TRANSIENT_RETRY_BACKOFF_S):
    """Call `fn()`, retrying when it fails with a transient API error.

    Transient = service unavailable, deadline exceeded, httpx timeout. These are
    network noise that a quick retry usually clears. Other exceptions (parse
    errors, image decode, auth, malformed responses) are NOT transient and
    propagate on the first failure so the caller can fail closed immediately.

    Per SYNTHESIS D1: retry, not sanctioned-fallback. The fallback registry's
    quality-neutrality prong cannot honestly include "I didn't run the check."

    Args:
        fn: Zero-argument callable to run.
        attempts: Total number of calls allowed, the first one included.
            Must be at least 1 so the check is always actually run.
        backoff_s: Seconds to sleep between two consecutive attempts.

    Returns:
        Whatever `fn()` returns on its first successful call.

    Raises:
        ValueError: If `attempts` is smaller than 1.
        Exception: The non-transient exception raised by `fn`, or the last
            transient one once every attempt has been used up.
    """
    import time

    # With attempts < 1 the loop body would never run and there would be no
    # exception to re-raise; reject the misconfiguration explicitly instead.
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts!r}")

    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:  # noqa: BLE001
            # Non-transient errors fail fast; a transient error on the final
            # attempt is re-raised unchanged for the caller to fail closed.
            if attempt == attempts or not _is_transient_error(e):
                raise
            logger.warning(
                "vision_check transient error (attempt %d/%d), retrying in %.1fs: %s",
                attempt, attempts, backoff_s, e,
            )
            time.sleep(backoff_s)


def _is_transient_error(exc: BaseException) -> bool:
    """Return True when `exc` is network noise that deserves a retry.

    Besides the classes gathered in `_TRANSIENT_EXCEPTIONS`, this recognises
    errors raised by the `google.genai` client itself: that SDK reports HTTP
    failures as `google.genai.errors.APIError` carrying the numeric status in
    `.code`, so 503 (service unavailable) and 504 (deadline exceeded) are
    matched by code rather than by a dedicated exception class.
    """
    if isinstance(exc, _TRANSIENT_EXCEPTIONS):
        return True
    try:
        # Already imported whenever a client exists, so this is a cache hit.
        from google.genai import errors as _genai_errors
    except ImportError:
        return False
    return (
        isinstance(exc, _genai_errors.APIError)
        and getattr(exc, "code", None) in (503, 504)
    )


def _get_client():
    """Return the shared Gemini client, building it on the first call.

    The API key is read from GEMINI_API_KEY, falling back to GOOGLE_API_KEY.
    The client is stored in the module-level `_cached_client` and reused by
    every later call.

    Raises:
        RuntimeError: If neither environment variable holds a non-empty value.
    """
    global _cached_client
    if _cached_client is None:
        # Imported lazily so the SDK is only loaded when a client is needed.
        from google import genai

        key = os.environ.get("GEMINI_API_KEY")
        if not key:
            key = os.environ.get("GOOGLE_API_KEY")
        if not key:
            raise RuntimeError("GEMINI_API_KEY or GOOGLE_API_KEY not set")
        _cached_client = genai.Client(api_key=key)
    return _cached_client


def _detect_mime(raw: bytes, extension: str) -> str:
    """Return the image MIME type, trusting magic bytes over the extension.

    The leading bytes of `raw` are checked for the JPEG, PNG and WebP
    signatures first. Only when none of them matches is `extension` consulted
    as a fallback; it is compared case-insensitively and may be given with or
    without its leading dot (so `Path.suffix` can be passed as-is).

    Args:
        raw: File contents (only the first 12 bytes are inspected).
        extension: File extension such as "jpg", "PNG" or ".webp".

    Returns:
        "image/jpeg", "image/png" or "image/webp". Unknown content with an
        unknown extension falls back to "image/png".
    """
    header = bytes(raw[:12])
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    # WebP is a RIFF container: "RIFF", 4 size bytes, then the "WEBP" tag.
    if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
        return "image/webp"

    _mime_map = {"jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png", "webp": "image/webp"}
    # Tolerate a dotted suffix (".jpg"); it used to miss the map and be
    # reported as PNG.
    return _mime_map.get(extension.lower().lstrip("."), "image/png")


def _build_prompt(checks: list[dict], context_description: str, intention_context: dict | None = None) -> str:
    """Build a structured prompt for Gemini from check definitions.

    Each check is {name, question, expected, severity}.
    """
    lines = [
        "You are a visual quality inspector for a film production pipeline.",
    ]
    if context_description:
        lines.append(f"Context: {context_description}")

    if intention_context:
        lines.append("")
        lines.append("## What this image SHOULD depict:")
        if intention_context.get("generation_prompt"):
            lines.append(f"Generation prompt: {intention_context['generation_prompt']}")
        if intention_context.get("character_anchor"):
            lines.append(f"Character description (MUST match exactly): {intention_context['character_anchor']}")
        if intention_context.get("scene_description"):
            lines.append(f"Scene: {intention_context['scene_description']}")
        if intention_context.get("narrative_context"):
            lines.append(f"Story context: {intention_context['narrative_context']}")
        lines.append("")
        lines.append("Compare the image against these intentions. Be strict about character appearance matching the anchor description.")

    lines.extend([
        "",
        "Answer each question about this image. For each question, respond with "
        "a short answer (a few words). Be precise and literal.",
        "",
    ])
    for i, check in enumerate(checks, 1):
        lines.append(f"Q{i} ({check['name']}): {check['question']}")
    lines.append("")
    lines.append(
        "Respond in JSON format: {\"answers\": [{\"name\": \"...\", \"answer\": \"...\"}]}"
    )
    return "\n".join(lines)


def _parse_response(raw_text: str, checks: list[dict]) -> list[dict]:
    """Parse Gemini's JSON reply and grade each check against its expectation.

    The reply is expected to be `{"answers": [{"name": ..., "answer": ...}]}`,
    optionally wrapped in a markdown code fence. A check passes when its
    `expected` text occurs (case-insensitively) inside the answer given for
    the check's `name`; a check with no answer fails.

    Fail-closed behaviour: if the reply is not valid JSON, is not a JSON
    object, or its `answers` value is not a list, every check is returned as
    failed with `gate_closed_by="parse_error"`. Individual malformed entries
    (non-dict items, non-string names) are skipped, so the checks they might
    have answered simply fail for lack of an answer.

    Args:
        raw_text: Text of the model response.
        checks: Check definitions with `name`, `expected` and `severity`.

    Returns:
        One dict per check, in the order of `checks`, with keys `name`,
        `passed`, `answer`, `expected`, `severity` (plus `error` and
        `gate_closed_by` on the parse-error path).
    """
    try:
        text = raw_text.strip()
        # Strip a surrounding markdown code fence (```json ... ```) if present.
        if text.startswith("```"):
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()

        data = json.loads(text)
        answers = data.get("answers", [])
        # Valid JSON of the wrong shape (e.g. "answers": null) is as unusable
        # as invalid JSON; route it to the same fail-closed path.
        if not isinstance(answers, list):
            raise TypeError("'answers' is not a list")
    except (json.JSONDecodeError, AttributeError, TypeError):
        # Build A Phase 1: parse failure → fail closed (was fail-open).
        # If we can't parse the response, we cannot assert the check passed.
        logger.warning("Could not parse Gemini vision response: %s", str(raw_text)[:200])
        return [
            {
                "name": check["name"],
                "passed": False,
                "answer": "",
                "expected": check["expected"],
                "severity": check["severity"],
                "error": "Could not parse response",
                "gate_closed_by": "parse_error",
            }
            for check in checks
        ]

    # Index answers by check name; a later duplicate overrides an earlier one.
    answer_map: dict[str, str] = {}
    for entry in answers:
        if not isinstance(entry, dict):
            continue
        name = entry.get("name", "")
        if not isinstance(name, str):
            continue
        value = entry.get("answer", "")
        if isinstance(value, (bool, int, float)):
            # The model answered with a bare JSON scalar (2, true, ...);
            # compare against its JSON spelling instead of crashing on .lower().
            value = json.dumps(value)
        elif not isinstance(value, str):
            # null or a nested structure is not an answer we can grade.
            value = ""
        answer_map[name] = value

    results = []
    for check in checks:
        answer = answer_map.get(check["name"], "")
        expected = check["expected"]
        # Case-insensitive substring match; an empty answer never passes.
        # NOTE(review): being a substring test, a short expectation such as
        # "no" also matches inside words like "cannot" or "unknown" — confirm
        # that check authors account for this.
        passed = bool(answer) and expected.lower() in answer.lower()
        results.append({
            "name": check["name"],
            "passed": passed,
            "answer": answer,
            "expected": expected,
            "severity": check["severity"],
        })
    return results


def validate_image(
    image_path: str | Path,
    checks: list[dict],
    context_description: str = "",
    intention_context: dict | None = None,
) -> dict:
    """Send image + structured questions to Gemini 2.5 Flash.

    Args:
        image_path: Path to the image file.
        checks: List of {name, question, expected, severity} dicts.
        context_description: Optional context about what the image should depict.
        intention_context: Optional dict with generation_prompt, character_anchor,
            scene_description, narrative_context for intention-aware validation.

    Returns:
        {
            "passed": bool,       # True if all checks pass
            "results": [...],     # Per-check results
            "error": str | None,  # Error message if API failed
        }

    On non-transient exception, returns passed=False with gate_closed_by="exception".
    Transient API errors get one retry before failing closed.
    """
    try:
        from google.genai import types

        image_path = Path(image_path)
        if not image_path.exists():
            # Build A Phase 1: missing input → fail closed (was fail-open).
            # Cannot assert a check passed if the image to check doesn't exist.
            logger.error("vision_check.validate_image — gate CLOSED: image not found: %s", image_path)
            return {
                "passed": False,
                "results": [],
                "error": f"Image not found: {image_path}",
                "gate_closed_by": "missing_input",
            }

        client = _get_client()
        image_bytes = image_path.read_bytes()
        suffix = image_path.suffix.lower().lstrip(".")
        mime = _detect_mime(image_bytes, suffix)

        prompt = _build_prompt(checks, context_description, intention_context)

        # Build A Phase 1: retry transient API errors once before propagating.
        response = _retry_transient(lambda: client.models.generate_content(
            model=_VISION_MODEL,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type=mime),
                types.Part(text=prompt),
            ],
        ))

        raw_text = response.text if hasattr(response, "text") and response.text else ""
        results = _parse_response(raw_text, checks)

        all_passed = all(r["passed"] for r in results)

        return {
            "passed": all_passed,
            "results": results,
            "error": None,
        }

    except Exception as e:
        # Build A Phase 1 (2026-05-09): flipped from passed:True to passed:False.
        # Per SYNTHESIS D1 (Opus): the quality gate fails closed on any exception.
        # _retry_transient (above) retries network noise once before propagating.
        logger.error("vision_check.validate_image — gate CLOSED: %s", e)
        return {
            "passed": False,
            "results": [],
            "error": str(e),
            "gate_closed_by": "exception",
        }


def _extract_frames(video_path: Path, num_frames: int, tmpdir: str) -> list[Path]:
    """Extract up to `num_frames` evenly spaced PNG frames from a video.

    The clip duration is read with ffprobe (falling back to 5.0 s when ffprobe
    prints nothing), then ffmpeg grabs one frame at each timestamp
    `duration * i / num_frames` for i in 0..num_frames-1. Sampling starts at
    0 and stops one step short of the end: seeking to the exact end of the
    file yields no frame, which previously made the last requested frame
    silently disappear.

    Args:
        video_path: Video file to sample.
        num_frames: Number of frames to request; values <= 0 yield no frames.
        tmpdir: Existing directory that receives `frame_NNN.png` files.

    Returns:
        Paths of the frames that were actually written, in time order. This
        can be shorter than `num_frames` when ffmpeg fails for a timestamp.

    Raises:
        FileNotFoundError: If the ffprobe or ffmpeg executable is missing.
        ValueError: If ffprobe prints a duration that is not a number.
    """
    # Ask ffprobe for the container duration, printed as a bare number.
    probe_cmd = [
        "ffprobe", "-v", "quiet",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        str(video_path),
    ]
    probe = subprocess.run(probe_cmd, capture_output=True, text=True)
    reported = probe.stdout.strip()
    duration = float(reported) if reported else 5.0

    frames = []
    for i in range(num_frames):
        # Left edge of the i-th of num_frames equal segments: always < duration.
        timestamp = duration * i / num_frames
        frame_path = Path(tmpdir) / f"frame_{i:03d}.png"

        extract_cmd = [
            "ffmpeg", "-y", "-v", "quiet",
            "-ss", f"{timestamp:.3f}",
            "-i", str(video_path),
            "-vframes", "1",
            "-q:v", "2",
            str(frame_path),
        ]
        subprocess.run(extract_cmd, capture_output=True)

        # ffmpeg runs with "-v quiet", so a missing output file is the only
        # sign of failure; record it instead of dropping the frame silently.
        if frame_path.exists():
            frames.append(frame_path)
        else:
            logger.warning(
                "vision_check._extract_frames — no frame produced at %.3fs for %s",
                timestamp, video_path,
            )

    return frames


def validate_video_frames(
    video_path: str | Path,
    checks: list[dict],
    context_description: str = "",
    num_frames: int = 5,
    intention_context: dict | None = None,
) -> dict:
    """Extract N frames from video and validate each.

    Args:
        video_path: Path to the video file.
        checks: List of {name, question, expected, severity} dicts.
        context_description: Optional context description.
        num_frames: Number of frames to extract and check (default 5).
        intention_context: Optional dict with generation_prompt, character_anchor,
            scene_description, narrative_context for intention-aware validation.

    Returns:
        {
            "passed": bool,           # True if all checks pass across all frames
            "frame_results": [...],   # Per-frame results
            "error": str | None,
        }

    On non-transient exception, returns passed=False with gate_closed_by="exception".
    """
    try:
        video_path = Path(video_path)
        if not video_path.exists():
            # Build A Phase 1: missing input → fail closed (was fail-open).
            logger.error("vision_check.validate_video — gate CLOSED: video not found: %s", video_path)
            return {
                "passed": False,
                "frame_results": [],
                "error": f"Video not found: {video_path}",
                "gate_closed_by": "missing_input",
            }

        with tempfile.TemporaryDirectory(prefix="vision_check_") as tmpdir:
            frames = _extract_frames(video_path, num_frames, tmpdir)

            if not frames:
                # Build A Phase 1: no frames extractable → fail closed (was fail-open).
                logger.error("vision_check.validate_video — gate CLOSED: no frames extracted from %s", video_path)
                return {
                    "passed": False,
                    "frame_results": [],
                    "error": "No frames extracted from video",
                    "gate_closed_by": "no_frames",
                }

            frame_results = []
            for i, frame_path in enumerate(frames):
                result = validate_image(
                    frame_path,
                    checks,
                    context_description=f"Frame {i + 1}/{len(frames)}. {context_description}",
                    intention_context=intention_context,
                )
                result["frame_index"] = i
                result["frame_path"] = str(frame_path)
                frame_results.append(result)

            # Aggregate: passed only if ALL frames pass
            all_passed = all(fr["passed"] for fr in frame_results)

            return {
                "passed": all_passed,
                "frame_results": frame_results,
                "error": None,
            }

    except Exception as e:
        # Build A Phase 1 (2026-05-09): flipped from passed:True to passed:False.
        logger.error("vision_check.validate_video_frames — gate CLOSED: %s", e)
        return {
            "passed": False,
            "frame_results": [],
            "error": str(e),
            "gate_closed_by": "exception",
        }


__all__ = [
    # Public symbols (Phase D — MF-3 + DEBT-9): the two validation entry
    # points. Every other name in this module is underscore-prefixed and
    # private.
    "validate_image",
    "validate_video_frames",
]
