"""
spatial_compliance.py — Post-generation spatial compliance check.

Uses Gemini 2.5 Flash vision to extract character positions from generated frames,
then compares them against the spatial instruction that was sent and against
the previous shot in the scene to detect 180-degree line crosses.

Cost: ~$0.01 per check (Flash vision).
Architecture: Gemini consultation 3 rounds, fully converged.
"""

import itertools
import json
import logging
import os
import re
import time

logger = logging.getLogger(__name__)

# ── Cached Gemini client (connection reuse across calls) ─────────────
# Module-level singleton; stays None until an API key is found and a client built.
_gemini_client = None

def _get_gemini_client():
    """Return the shared ``genai.Client``, creating it lazily on first use.

    The API key is read from ``GOOGLE_API_KEY`` first, then ``GEMINI_API_KEY``.
    When neither is set this returns None and caches nothing, so a later call
    re-reads the environment.
    """
    global _gemini_client
    if _gemini_client is not None:
        return _gemini_client
    key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    if not key:
        return None
    # Deferred import: the SDK is only needed once a key is available.
    from google import genai
    _gemini_client = genai.Client(api_key=key)
    return _gemini_client

# ── Position ordinal for relative comparison ──────────────────────────
# Screen columns ranked left-to-right (LEFT=0, CENTER=1, RIGHT=2) so relative
# order between two characters can be compared numerically.
_POS_ORDINAL = {label: rank for rank, label in enumerate(("LEFT", "CENTER", "RIGHT"))}

# ── Extraction prompt template ────────────────────────────────────────
# Filled with str.format() in _extract_positions, so {character_anchors} is
# the only placeholder; any literal braces added to this text must be doubled.
# The anchor lines substituted in come from _build_anchors_text().
# The exact allowed values for position/depth/facing are constrained by
# _EXTRACTION_SCHEMA (sent as response_schema), not by this text alone.

_EXTRACTION_PROMPT = """
Analyze this vertical 9:16 cinematic frame. Extract the spatial positioning of the characters present.
If there are no characters (e.g., environment or object insert), return an empty list.

Target Characters to identify:
{character_anchors}

Rules:
1. Horizontal Position: Where is the bulk of the character's body/face located on the screen?
2. Depth: Is the character clearly in the FOREGROUND (closer to camera), BACKGROUND (further away), or SAME_PLANE (roughly equal distance/solo shot)?
3. Facing: Where is the character looking? Use NOT_APPLICABLE for extreme close-ups or indistinguishable faces.
"""

# ── Response schema for structured JSON output ────────────────────────
# Passed as `response_schema` (together with response_mime_type=
# "application/json") in _extract_positions. The enum strings below are the
# vocabulary the comparison code checks against: horizontal_position values
# are the keys of _POS_ORDINAL, and depth/facing values are compared
# literally in _compare_to_instruction.

_EXTRACTION_SCHEMA = {
    "type": "ARRAY",  # one entry per character found in the frame (may be empty)
    "items": {
        "type": "OBJECT",
        "properties": {
            "character_name": {
                "type": "STRING",
                # Anchor lines are rendered as "- Name: description", so the
                # model is asked to return only the part before the colon.
                "description": "Return ONLY the character name before the colon (e.g., 'Wren'). Do not include the visual description.",
            },
            "horizontal_position": {
                "type": "STRING",
                "enum": ["LEFT", "CENTER", "RIGHT"],
            },
            "depth": {
                "type": "STRING",
                # Only checked for OTS framing (format 2) downstream.
                "enum": ["FOREGROUND", "BACKGROUND", "SAME_PLANE"],
            },
            "facing": {
                "type": "STRING",
                # CAMERA triggers a LENS_CONTACT warning outside format 4.
                "enum": ["SCREEN-LEFT", "SCREEN-RIGHT", "CAMERA", "AWAY", "NOT_APPLICABLE"],
            },
        },
        "required": ["character_name", "horizontal_position", "depth", "facing"],
    },
}


# ── Visual anchor builder ─────────────────────────────────────────────

def _build_character_anchor(char_entry: dict, bible: dict) -> str:
    """Build one ``- Name: anchor`` line used to identify a character in frame.

    The anchor is the first sentence of the bible's ``visual_description``
    (the text before the first ``.``), capped at 60 characters with ``...``
    appended when truncated. When there is no usable description the line is
    just ``- Name``. Non-dict entries are rendered verbatim as ``- {entry}``.

    Args:
        char_entry: Character entry from the shot's ``asset_data``; its
            ``char_id`` is looked up in ``bible["characters"]``.
        bible: Global bible dict.

    Returns:
        A single bullet line for the extraction prompt.
    """
    if not isinstance(char_entry, dict):
        return f"- {char_entry}"

    char_id = char_entry.get("char_id", "unknown")
    # `or {}` also covers keys that are present but explicitly None.
    bible_chars = (bible or {}).get("characters") or {}
    bible_char = bible_chars.get(char_id) or {}
    display_name = bible_char.get("display_name") or char_id
    visual_desc = bible_char.get("visual_description") or ""

    first_sentence = visual_desc.split(".")[0].strip()
    if not first_sentence:
        # Empty, whitespace-only or leading-period descriptions would otherwise
        # produce a dangling "- Name: " line.
        return f"- {display_name}"

    anchor = first_sentence[:60]
    if len(first_sentence) > 60:
        anchor += "..."
    return f"- {display_name}: {anchor}"


def _build_anchors_text(characters: list[dict], bible: dict) -> str:
    """Render the ``{character_anchors}`` block of the extraction prompt.

    One anchor line per character (newline-separated), or a fixed placeholder
    sentence when no characters are expected.
    """
    if characters:
        anchor_lines = [_build_character_anchor(entry, bible) for entry in characters]
        return "\n".join(anchor_lines)
    return "No characters expected in frame."


# ── Flash extraction call ─────────────────────────────────────────────

def _strip_code_fence(raw_text: str) -> str:
    """Return *raw_text* without a surrounding Markdown code fence.

    Handles a leading ``` with an optional ``json`` language tag (any case)
    and a trailing ```. Text without a leading fence is only stripped of
    surrounding whitespace.
    """
    text = raw_text.strip()
    if not text.startswith("```"):
        return text
    text = text[3:]
    if text[:4].lower() == "json":
        text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _extract_positions(image_data: bytes, characters: list[dict], bible: dict) -> list[dict]:
    """Send a frame to Gemini Flash (``gemini-2.5-flash``) and extract positions.

    Args:
        image_data: Frame bytes; sent with mime type ``image/png``.
        characters: Character entries from the shot's ``asset_data``; turned
            into identification anchors for the prompt.
        bible: Global bible dict used to look up names/descriptions.

    Returns:
        List of dicts ``{character_name, horizontal_position, depth, facing}``.
        Entries that are not objects, or lack ``character_name`` /
        ``horizontal_position``, are dropped. Returns ``[]`` when the SDK or
        API key is unavailable, the call or JSON parse fails, or the response
        is not a JSON array — so callers cannot tell "nobody detected" apart
        from "extraction failed".
    """
    try:
        from google.genai import types
    except ImportError:
        logger.warning("google.genai not available for spatial compliance")
        return []

    anchors_text = _build_anchors_text(characters, bible)
    prompt = _EXTRACTION_PROMPT.format(character_anchors=anchors_text)

    client = _get_gemini_client()
    if not client:
        logger.warning("No Gemini API key for spatial compliance check")
        return []

    try:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[
                types.Part.from_bytes(data=image_data, mime_type="image/png"),
                prompt,
            ],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=_EXTRACTION_SCHEMA,
                temperature=0.1,
            ),
        )
        raw_text = response.text
        if not raw_text:
            return []
        # Structured output should be bare JSON; strip Markdown fences
        # defensively in case the model wraps it anyway.
        result = json.loads(_strip_code_fence(raw_text))
    except Exception as e:
        # Best-effort: SDK, network and JSON errors are logged and degrade to [].
        logger.error("Spatial compliance extraction failed: %s", e)
        return []

    if not isinstance(result, list):
        return []
    # Downstream comparison code expects dict entries carrying these keys;
    # drop malformed entries here rather than letting them raise later.
    return [
        item
        for item in result
        if isinstance(item, dict) and "character_name" in item and "horizontal_position" in item
    ]


# ── Severity calculation ──────────────────────────────────────────────

def _get_expected_positions(shot: dict, bible: dict) -> dict:
    """Derive the expected on-screen layout from the shot's spatial data.

    * 1 character: format 5 if ``_is_punch_in``, else 3 if ``_is_moving``,
      else 4. Position comes from the entry's ``screen_position``
      (default CENTER); depth SAME_PLANE.
    * 2+ characters with a wide shot type (WS/EWS/VWS): format 1, the two
      names from ``_resolve_lr_assignment`` placed LEFT and RIGHT on the
      same plane.
    * 2+ characters otherwise: format 2 (OTS), the foreground character on
      ``fg_side`` and the background character at CENTER.

    Returns:
        ``{display_name: {"position", "depth", "format_num"}}``; empty for
        environment-only shots or shots without characters.
    """
    # `or {}` / `or []` also cover keys that are present but explicitly None.
    spatial_data = shot.get("spatial_data") or {}
    asset_data = shot.get("asset_data") or {}
    prompt_data = shot.get("prompt_data") or {}
    routing_data = shot.get("routing_data") or {}
    characters = asset_data.get("characters") or []

    from recoil.pipeline._lib.prompt_engine import _normalize_camera_side, _resolve_display_name

    camera_side = _normalize_camera_side(spatial_data.get("camera_side", "A"))
    # A None/empty shot_type used to crash on .upper(); fall back to MS.
    shot_type = str(prompt_data.get("shot_type") or "MS").strip().upper()
    is_env = routing_data.get("is_env_only", False)

    if is_env or not characters:
        return {}  # ENV/INSERT — no character expectations

    expected = {}

    if len(characters) == 1:
        entry = characters[0]
        char_name = _resolve_display_name(entry, bible)
        pos = entry.get("screen_position", "center") if isinstance(entry, dict) else "center"

        from recoil.pipeline._lib.prompt_engine import _is_moving, _is_punch_in
        screen_dir = spatial_data.get("screen_direction", "center")

        if _is_punch_in(shot, shot.get("_scene_shots")):
            format_num = 5  # punch-in
        elif _is_moving(screen_dir, prompt_data):
            format_num = 3  # moving, per _is_moving
        else:
            format_num = 4  # static solo

        expected[char_name] = {
            # Upper-cased to match the extraction schema's LEFT/CENTER/RIGHT enum.
            "position": str(pos).strip().upper() if pos else "CENTER",
            "depth": "SAME_PLANE",
            "format_num": format_num,
        }

    elif shot_type in {"WS", "EWS", "VWS"}:
        # Wide two-shot (Format 1): characters side by side.
        from recoil.pipeline._lib.prompt_engine import _resolve_lr_assignment
        left_name, right_name = _resolve_lr_assignment(characters, spatial_data, camera_side, bible)
        expected[left_name] = {"position": "LEFT", "depth": "SAME_PLANE", "format_num": 1}
        expected[right_name] = {"position": "RIGHT", "depth": "SAME_PLANE", "format_num": 1}

    else:
        # Over-the-shoulder (Format 2).
        from recoil.pipeline._lib.prompt_engine import _resolve_ots_assignment
        fg_char, bg_char, fg_side = _resolve_ots_assignment(characters, spatial_data, camera_side, bible)
        # Upper-cased defensively so it compares against the schema enum.
        fg_pos = str(fg_side).upper()
        bg_pos = "CENTER"  # In 9:16 OTS, BG character is always center
        expected[fg_char] = {"position": fg_pos, "depth": "FOREGROUND", "format_num": 2}
        expected[bg_char] = {"position": bg_pos, "depth": "BACKGROUND", "format_num": 2}

    return expected


def _compare_to_instruction(extracted: list[dict], expected: dict) -> list[dict]:
    """Compare extracted positions against the expected spatial instruction.

    Args:
        extracted: Entries with ``character_name``, ``horizontal_position``,
            ``depth`` and ``facing`` as returned by the vision extraction.
        expected: ``{name: {"position", "depth", "format_num"}}``.

    Returns:
        List of ``{"flag", "severity", "detail"}`` dicts. Per expected
        character:
          * not found -> MISSING_CHARACTER (VIOLATION)
          * expected CENTER but got LEFT/RIGHT -> POSITION_DRIFT (INFO);
            any other position mismatch -> PROMPT_MISMATCH (VIOLATION)
          * format 2 (OTS) only: FOREGROUND/BACKGROUND swapped ->
            DEPTH_INVERSION (VIOLATION); collapsed to SAME_PLANE ->
            DEPTH_FLATTENED (WARNING)
          * facing CAMERA outside format 4 -> LENS_CONTACT (WARNING)
    """
    flags = []

    for char_name, spec in expected.items():
        char_upper = str(char_name).strip().upper()
        # Flash may return more than the bare name: "Wren (tall woman...)",
        # "Wren: tall woman..." (the anchor-line form) or "- Wren: ..." (with
        # the bullet). Accept those alongside an exact match.
        name_prefixes = (char_upper + " ", char_upper + "(", char_upper + ":")
        match = None
        for ext in extracted:
            ext_name = str(ext.get("character_name") or "").strip().upper().removeprefix("- ")
            if ext_name == char_upper or ext_name.startswith(name_prefixes):
                match = ext
                break

        if match is None:
            flags.append({
                "flag": "MISSING_CHARACTER",
                "severity": "VIOLATION",
                "detail": f"{char_name} not found in frame",
            })
            continue

        format_num = spec.get("format_num", 0)
        expected_pos = spec["position"]
        extracted_pos = match.get("horizontal_position", "CENTER")

        # Horizontal position: drifting off CENTER is informational; any other
        # mismatch (e.g. LEFT vs RIGHT) is a violation.
        if extracted_pos != expected_pos:
            if expected_pos == "CENTER" and extracted_pos in ("LEFT", "RIGHT"):
                flags.append({
                    "flag": "POSITION_DRIFT",
                    "severity": "INFO",
                    "detail": f"{char_name}: expected CENTER, got {extracted_pos}",
                })
            else:
                flags.append({
                    "flag": "PROMPT_MISMATCH",
                    "severity": "VIOLATION",
                    "detail": f"{char_name}: expected {expected_pos}, got {extracted_pos}",
                })

        # Depth is only checked for OTS framing (format 2).
        if format_num == 2:
            expected_depth = spec["depth"]
            extracted_depth = match.get("depth", "SAME_PLANE")
            if {expected_depth, extracted_depth} == {"FOREGROUND", "BACKGROUND"}:
                flags.append({
                    "flag": "DEPTH_INVERSION",
                    "severity": "VIOLATION",
                    "detail": f"{char_name}: expected {expected_depth}, got {extracted_depth}",
                })
            elif extracted_depth == "SAME_PLANE" and expected_depth != "SAME_PLANE":
                flags.append({
                    "flag": "DEPTH_FLATTENED",
                    "severity": "WARNING",
                    "detail": f"{char_name}: expected {expected_depth}, got SAME_PLANE",
                })

        # Lens contact check (all formats except Format 4 solo static)
        if match.get("facing") == "CAMERA" and format_num != 4:
            flags.append({
                "flag": "LENS_CONTACT",
                "severity": "WARNING",
                "detail": f"{char_name} is looking directly at camera",
            })

    return flags


def _cross_shot_check(
    extracted: list[dict],
    prev_compliance: dict | None,
    camera_side_changed: bool,
    is_multi_char: bool,
) -> list[dict]:
    """Check for 180-degree line crosses between adjacent shots.

    Compares RELATIVE positions (not absolute) to detect character swaps.
    """
    flags = []

    if not prev_compliance or camera_side_changed or not is_multi_char:
        if not prev_compliance and is_multi_char:
            flags.append({
                "flag": "CONTINUITY_UNCHECKED_MISSING_PREV",
                "severity": "INFO",
                "detail": "Previous shot not available for continuity check",
            })
        return flags

    prev_extracted = prev_compliance.get("extracted", [])
    if len(prev_extracted) < 2 or len(extracted) < 2:
        return flags

    # Build name→position maps (normalize to first word to handle
    # Flash returning "Wren (visual desc...)" vs just "Wren")
    def _normalize_name(n):
        """Extract character name, stripping parenthetical descriptions."""
        n = str(n or "").strip().upper()
        paren = n.find("(")
        if paren > 0:
            n = n[:paren].strip()
        return n

    prev_positions = {
        _normalize_name(e["character_name"]): _POS_ORDINAL.get(e["horizontal_position"], 1)
        for e in prev_extracted
    }
    curr_positions = {
        _normalize_name(e["character_name"]): _POS_ORDINAL.get(e["horizontal_position"], 1)
        for e in extracted
    }

    # Find characters present in both shots
    shared = set(prev_positions.keys()) & set(curr_positions.keys())
    if len(shared) < 2:
        return flags

    # Check relative order — did any pair swap?
    shared_list = sorted(shared)
    for i in range(len(shared_list)):
        for j in range(i + 1, len(shared_list)):
            a, b = shared_list[i], shared_list[j]
            prev_order = prev_positions[a] < prev_positions[b]
            curr_order = curr_positions[a] < curr_positions[b]

            if prev_order != curr_order and prev_positions[a] != prev_positions[b] and curr_positions[a] != curr_positions[b]:
                flags.append({
                    "flag": "LINE_CROSS",
                    "severity": "VIOLATION",
                    "detail": f"{a}/{b} swapped relative positions (180-degree line crossed)",
                })

    return flags


def _attribute_failure(flags: list[dict], authored_prompt: str, spatial_block: str) -> str:
    """Guess whether a spatial VIOLATION came from prompting or from generation.

    The spatial keywords present in *spatial_block* (the injected instruction)
    are looked up in *authored_prompt*. If at least half of them appear there,
    the instruction is considered to have been carried through, and the
    failure is blamed on image generation.

    Keywords match only at the start of a word, so "RIGHT" does not match
    inside "BRIGHT"/"UPRIGHT", while "CENTERED" still counts as CENTER.

    Returns:
        'NONE' if no flag is a VIOLATION; 'UNKNOWN' if either text is empty
        or the spatial block contains no recognised keyword; otherwise
        'GENERATION_FAILURE' or 'AUTHORED_PROMPT_FAILURE'.
    """
    if not any(f.get("severity") == "VIOLATION" for f in flags):
        return "NONE"

    if not authored_prompt or not spatial_block:
        return "UNKNOWN"

    def _contains(text_upper: str, keyword: str) -> bool:
        # A letter directly before the keyword means it is part of a longer
        # word (BRIGHT, UPRIGHT, FRIGHT), so it does not count.
        return re.search(rf"(?<![A-Z]){re.escape(keyword)}", text_upper) is not None

    spatial_block_upper = spatial_block.upper()
    spatial_keywords = [
        token
        for token in ("LEFT", "RIGHT", "CENTER", "FOREGROUND", "BACKGROUND", "SCREEN-LEFT", "SCREEN-RIGHT")
        if _contains(spatial_block_upper, token)
    ]

    if not spatial_keywords:
        return "UNKNOWN"

    authored_upper = authored_prompt.upper()
    matches = sum(1 for kw in spatial_keywords if _contains(authored_upper, kw))

    # If at least half the keywords reached the authored prompt, the
    # instruction was carried through and generation rendered it wrong.
    if matches >= len(spatial_keywords) / 2:
        return "GENERATION_FAILURE"
    return "AUTHORED_PROMPT_FAILURE"


def _max_severity(flags: list[dict]) -> str:
    """Return the most severe ``severity`` value among *flags*.

    Ranking is VIOLATION > WARNING > INFO > PASS; unrecognised values rank
    alongside PASS. On ties the earliest flag wins. No flags -> "PASS".
    """
    if not flags:
        return "PASS"
    rank = {"PASS": 0, "INFO": 1, "WARNING": 2, "VIOLATION": 3}
    worst_severity, worst_rank = None, -1
    for flag in flags:
        severity = flag["severity"]
        current = rank.get(severity, 0)
        # Strictly greater keeps the first flag on ties, like max().
        if current > worst_rank:
            worst_severity, worst_rank = severity, current
    return worst_severity


# ── Main entry point ──────────────────────────────────────────────────

COMPLIANCE_CHECK_COST = 0.010  # ~$0.01 per Flash vision call


def _skipped_compliance_result(skip_reason: str) -> dict:
    """Build the PASS result recorded when the vision check is skipped (cost 0.0)."""
    return {
        "extracted": [], "severity": "PASS", "flags": [],
        "failure_attribution": "NONE", "human_override": False,
        "checked_at": time.time(), "skipped": True, "skip_reason": skip_reason, "cost": 0.0,
    }


def run_spatial_compliance(
    image_data: bytes,
    shot: dict,
    bible: dict,
    prev_shot_compliance: dict | None = None,
    prev_camera_side: str | None = None,
    authored_prompt: str = "",
    spatial_block: str = "",
) -> dict:
    """Run the full spatial compliance check on a generated frame.

    Extracts character positions from the frame, compares them with the
    expected layout for this shot, compares relative order with the previous
    shot (180-degree rule), and attributes any VIOLATION to prompting or
    generation.

    Args:
        image_data: Raw PNG bytes of the generated frame
        shot: Full shot data dict (with spatial_data, asset_data, etc.)
        bible: Global bible dict
        prev_shot_compliance: The spatial_compliance dict from the previous shot's take_record
        prev_camera_side: The camera_side value from the previous shot
        authored_prompt: Flash's authored prompt text (for failure attribution)
        spatial_block: The spatial injection block that was sent (for failure attribution)

    Returns:
        dict with keys: extracted, severity, flags, failure_attribution,
        human_override, checked_at, skipped, cost (plus skip_reason when
        skipped). Skipped results are PASS with cost 0.0 and skip_reason one
        of "no_spatial_data", "env_or_no_chars" or "no_gemini_client".
    """
    # `or {}` / `or []` also cover keys that are present but explicitly None.
    spatial_data = shot.get("spatial_data") or {}
    characters = (shot.get("asset_data") or {}).get("characters") or []

    from recoil.pipeline._lib.prompt_engine import _normalize_camera_side
    camera_side = _normalize_camera_side(spatial_data.get("camera_side", "A"))
    prev_camera_side = _normalize_camera_side(prev_camera_side) if prev_camera_side else None

    # Skip check for shots without spatial data
    if not spatial_data.get("camera_side"):
        return _skipped_compliance_result("no_spatial_data")

    # Skip ENV shots and shots with no characters (save $0.01 + latency)
    is_env = (shot.get("routing_data") or {}).get("is_env_only", False)
    if is_env or not characters:
        return _skipped_compliance_result("env_or_no_chars")

    # Extract positions from the generated image
    extracted = _extract_positions(image_data, characters, bible)

    if not extracted:
        # _extract_positions also returns [] when no Gemini client can be
        # built (no API key / SDK). Without this guard every expected
        # character would be reported as MISSING_CHARACTER (VIOLATION) and
        # billed at COMPLIANCE_CHECK_COST although no call was made.
        try:
            client_ready = _get_gemini_client() is not None
        except ImportError:
            client_ready = False
        if not client_ready:
            return _skipped_compliance_result("no_gemini_client")
        # NOTE(review): a failed API call or unparseable response also yields
        # [] and is still reported as missing characters; telling the two
        # apart needs _extract_positions to signal failure explicitly.

    # Compare to the expected positions derived from the spatial instruction
    expected = _get_expected_positions(shot, bible)
    flags = _compare_to_instruction(extracted, expected)

    # Cross-shot check; _cross_shot_check skips the swap test when the
    # camera side changed between shots.
    camera_side_changed = prev_camera_side is not None and prev_camera_side != camera_side
    is_multi_char = len(characters) >= 2
    flags.extend(_cross_shot_check(extracted, prev_shot_compliance, camera_side_changed, is_multi_char))

    attribution = _attribute_failure(flags, authored_prompt, spatial_block)

    return {
        "extracted": extracted,
        "severity": _max_severity(flags),
        "flags": [{"flag": f["flag"], "severity": f["severity"], "detail": f["detail"]} for f in flags],
        "failure_attribution": attribution,
        "human_override": False,
        "checked_at": time.time(),
        "skipped": False,
        "cost": COMPLIANCE_CHECK_COST,
    }
