"""Spatial data validation pass — runs after Stage 2 (Storyboard Pass).

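Catches contradictions between spatial_data fields and shot content:
  - toward-camera on shots with no explicit movement toward camera
  - toward-camera on close-ups (CU/ECU/BCU/XCU)
  - eyeline conflicts: toward-camera while action_line/subject_line
    describes an off-screen gaze
  - long runs of consecutive center shots (warning only; solo
    interaction defaults are not currently checked)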

Can auto-correct obvious errors or flag for human review.
"""

import json
import logging
from pathlib import Path

from recoil.core.paths import ProjectPaths

logger = logging.getLogger(__name__)

# Heuristics for detecting incorrect toward-camera assignments.
# Matching is a plain substring search over the lower-cased
# action_line + subject_line text (see validate_spatial_data).

# Phrases suggesting the subject is looking somewhere other than the lens.
_OFF_SCREEN_GAZE_KEYWORDS = [
    "looking at", "staring at", "gazing at", "watching", "examining",
    "scanning", "focused on", "eyes on", "gaze locked on", "gaze fixed on",
    "looking down", "looking away", "looking off", "looking up",
    "turned toward", "facing away", "turned away",
    "off-screen", "off screen", "offscreen",
]

# Phrases that justify a toward-camera screen direction.
_TOWARD_CAMERA_VALID_KEYWORDS = [
    "walking toward camera", "moving toward camera", "approaching camera",
    "charging at camera", "lunging toward", "runs toward",
    "breaking the fourth wall", "addresses the audience",
]

# Shot types treated as close-ups; compared against the stripped,
# upper-cased prompt_data.shot_type so "cu" and "CU" are equivalent.
_CLOSE_UP_SHOT_TYPES = frozenset({"CU", "ECU", "BCU", "XCU"})

# A run of this many consecutive center shots (or more) triggers a warning
# on every shot from this position in the run onward.
_CENTER_RUN_WARN_THRESHOLD = 4


def validate_spatial_data(plan: dict, auto_correct: bool = True) -> dict:
    """Validate and optionally correct spatial_data in a shot plan.

    For every shot whose ``spatial_data.screen_direction`` is
    ``"toward-camera"`` and whose action/subject text contains none of the
    ``_TOWARD_CAMERA_VALID_KEYWORDS``:

      * if the text mentions an off-screen gaze and/or the shot is a
        close-up, the direction is treated as an error. With
        ``auto_correct`` it is reset to ``"center"`` and exactly ONE
        correction is recorded (reasons joined with ``"; "`` when both
        apply). Without it, one warning is emitted per reason.
      * otherwise the shot gets a "no explicit movement" warning for
        human review.

    Afterwards, every shot that is at position
    ``_CENTER_RUN_WARN_THRESHOLD`` or later in a run of consecutive center
    shots gets a warning. This uses the post-correction directions, and
    shots without a screen_direction count as center.

    Missing or null ``shots`` / ``spatial_data`` / ``prompt_data`` /
    ``prompt_skeleton`` / text fields are treated as empty.

    Args:
        plan: Parsed plan dict; ``plan["shots"]`` is a list of shot dicts.
            Shot ``spatial_data`` dicts are modified in place when
            corrections are applied.
        auto_correct: Apply corrections instead of only reporting them.

    Returns dict with:
      corrections: list of {shot_id, field, old, new, reason}
      warnings: list of {shot_id, message}
      stats: {total, corrected, warned}
    """
    shots = plan.get("shots") or []
    corrections = []
    warnings = []

    for shot in shots:
        shot_id = shot.get("shot_id", "?")
        # ``or {}`` also covers keys that are present but null in the JSON.
        # A non-empty spatial_data dict is returned as-is, so mutating
        # ``spatial`` below updates the plan itself.
        spatial = shot.get("spatial_data") or {}
        prompt_data = shot.get("prompt_data") or {}
        skeleton = prompt_data.get("prompt_skeleton") or {}
        shot_type = prompt_data.get("shot_type") or "MS"

        if spatial.get("screen_direction", "center") != "toward-camera":
            continue

        # Combine text fields for keyword matching
        action = (skeleton.get("action_line") or "").lower()
        subject = (skeleton.get("subject_line") or "").lower()
        combined = f"{action} {subject}"

        # Explicit movement toward camera / fourth-wall break: keep as-is.
        if any(kw in combined for kw in _TOWARD_CAMERA_VALID_KEYWORDS):
            continue

        reasons = []
        # ── Check 1: toward-camera with off-screen gaze in text ──
        if any(kw in combined for kw in _OFF_SCREEN_GAZE_KEYWORDS):
            reasons.append(
                "Action/subject mentions off-screen gaze "
                "but screen_direction is toward-camera"
            )
        # ── Check 2: CU/ECU with toward-camera (almost always wrong) ──
        if str(shot_type).strip().upper() in _CLOSE_UP_SHOT_TYPES:
            reasons.append(
                f"Close-up ({shot_type}) with toward-camera — "
                "close-ups are almost always off-axis observations"
            )

        if not reasons:
            # toward-camera without explicit movement — flag for review
            warnings.append({
                "shot_id": shot_id,
                "message": (
                    "screen_direction=toward-camera but no explicit "
                    "movement toward camera found in action text"
                ),
            })
        elif auto_correct:
            # One correction per shot, even when both checks fired, so the
            # recorded "old" value is accurate and stats are not inflated.
            spatial["screen_direction"] = "center"
            corrections.append({
                "shot_id": shot_id,
                "field": "screen_direction",
                "old": "toward-camera",
                "new": "center",
                "reason": "; ".join(reasons),
            })
        else:
            warnings.extend(
                {"shot_id": shot_id, "message": reason} for reason in reasons
            )

    # ── Check 3: Excessive center pattern (warning only) ──
    # Runs after the per-shot checks so auto-corrected shots count as center.
    center_run = 0
    for shot in shots:
        direction = (shot.get("spatial_data") or {}).get(
            "screen_direction", "center"
        )
        center_run = center_run + 1 if direction == "center" else 0
        if center_run >= _CENTER_RUN_WARN_THRESHOLD:
            warnings.append({
                "shot_id": shot.get("shot_id", "?"),
                "message": f"{center_run} consecutive center shots — consider lateral variation",
            })

    return {
        "corrections": corrections,
        "warnings": warnings,
        "stats": {
            "total": len(shots),
            "corrected": len(corrections),
            "warned": len(warnings),
        },
    }


def validate_and_fix_plan(
    episode: int, project: str, auto_correct: bool = True
) -> dict:
    """Load an episode plan, run spatial validation, and save it if corrected.

    The plan is read from ``ep_<episode:03d>_plan.json`` inside
    ``ProjectPaths.for_project(project).plans_dir``. When ``auto_correct``
    is true and at least one correction was applied, the updated plan is
    written back and each correction is logged at INFO level.

    Args:
        episode: Episode number, formatted as three zero-padded digits.
        project: Project identifier passed to ``ProjectPaths.for_project``.
        auto_correct: Apply and persist corrections instead of only
            reporting them.

    Returns:
        The ``validate_spatial_data`` result dict, or ``{"error": ...}`` when
        the plan file does not exist. JSON decoding errors propagate.
    """
    plan_path = (
        ProjectPaths.for_project(project).plans_dir
        / f"ep_{episode:03d}_plan.json"
    )
    if not plan_path.exists():
        return {"error": f"Plan not found: {plan_path}"}

    plan = json.loads(plan_path.read_text(encoding="utf-8"))
    result = validate_spatial_data(plan, auto_correct=auto_correct)

    if auto_correct and result["corrections"]:
        payload = json.dumps(plan, indent=2, ensure_ascii=False)
        # Write to a sibling temp file, then swap it in with a rename, so a
        # crash or disk error mid-write cannot leave a truncated plan.
        tmp_path = plan_path.with_name(plan_path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(plan_path)
        except OSError:
            tmp_path.unlink(missing_ok=True)
            raise
        logger.info(
            "Spatial validation: %d corrections applied to ep_%03d",
            len(result["corrections"]), episode,
        )
        for c in result["corrections"]:
            logger.info(
                "  %s: %s %s → %s (%s)",
                c["shot_id"], c["field"], c["old"], c["new"], c["reason"],
            )

    return result
