"""Tests for Validator gate methods in recoil/pipeline/_lib/validation.py.

Originally tested _run_previz_gate_1 / _run_previz_gate_2 in
tools/generate_previs.py. Those functions were extracted into the
Validator class (run_gate_1_text, run_gate_1_image), and the tests were
updated to use the new API. One integration test also checks that
generate_previs fails closed when the text gate raises.
"""

import json
import sys
import tempfile
from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock


# Make the directory two levels above this file importable by prepending it
# to sys.path, but only if it is not already present.
_ROOT = Path(__file__).parent.parent
if not any(entry == str(_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_ROOT))

# Stub out google.genai so we can mock it at module scope.
# The real SDK may not be installed in the test environment. These stand-ins
# are registered in sys.modules before the validation module is imported
# below, and they are installed unconditionally (replacing the real SDK for
# this session if it happens to be present).
_genai_stub = ModuleType("google.genai")
_genai_types_stub = ModuleType("google.genai.types")
# Config classes are the MagicMock *class*, so calling them builds a fresh mock.
# Part is a MagicMock *instance*, so arbitrary attribute access on it (e.g. a
# factory helper) auto-creates mocks; the bare class would raise AttributeError.
_genai_types_stub.GenerateContentConfig = MagicMock  # type: ignore[attr-defined]
_genai_types_stub.Part = MagicMock()  # type: ignore[attr-defined]
_genai_types_stub.ImageConfig = MagicMock  # type: ignore[attr-defined]
_genai_stub.types = _genai_types_stub  # type: ignore[attr-defined]
_genai_stub.Client = MagicMock  # type: ignore[attr-defined]

# Ensure google package exists
if "google" not in sys.modules:
    _google_stub = ModuleType("google")
    sys.modules["google"] = _google_stub
sys.modules["google.genai"] = _genai_stub
sys.modules["google.genai.types"] = _genai_types_stub
# Registering in sys.modules alone is not enough for dotted attribute access
# (``import google.genai; google.genai.Client``); the parent package must also
# expose the submodule as an attribute.
sys.modules["google"].genai = _genai_stub  # type: ignore[attr-defined]

# Stub google.api_core.exceptions for tenacity retry logic
_api_core_stub = ModuleType("google.api_core")
_api_core_exc_stub = ModuleType("google.api_core.exceptions")
_api_core_exc_stub.ResourceExhausted = type("ResourceExhausted", (Exception,), {})  # type: ignore[attr-defined]
_api_core_exc_stub.InternalServerError = type("InternalServerError", (Exception,), {})  # type: ignore[attr-defined]
_api_core_exc_stub.ServiceUnavailable = type("ServiceUnavailable", (Exception,), {})  # type: ignore[attr-defined]
# Bind each level of the package chain so ``google.api_core.exceptions.X``
# resolves by attribute, not only via sys.modules lookups.
_api_core_stub.exceptions = _api_core_exc_stub  # type: ignore[attr-defined]
sys.modules["google.api_core"] = _api_core_stub
sys.modules["google.api_core.exceptions"] = _api_core_exc_stub
sys.modules["google"].api_core = _api_core_stub  # type: ignore[attr-defined]

from recoil.pipeline._lib.validation import Validator, ValidationResult


def _mock_text_response(text: str):
    """Return a MagicMock shaped like a genai response holding one text part.

    The result exposes ``response.candidates[0].content.parts[0]``; that part
    carries ``text`` and has ``inline_data`` set to ``None``.
    """
    only_part = MagicMock(text=text, inline_data=None)
    return MagicMock(
        candidates=[MagicMock(content=MagicMock(parts=[only_part]))],
    )


def _mock_json_response(data: dict):
    """Return a MagicMock response whose ``.text`` is ``json.dumps(data)``."""
    return MagicMock(text=json.dumps(data))


# Minimal shot spec shared by the gate tests: two characters at one location.
SAMPLE_SHOT = dict(
    shot_id="EP001_SH03",
    asset_data=dict(
        characters=[{"char_id": char_id} for char_id in ("KANE", "MIRA")],
        location_id="BRIDGE",
    ),
    action="Kane and Mira argue over the navigation console.",
)


# ── Gate 1 Text (run_gate_1_text) ────────────────────────────────────


class TestGate1TextPassResponseParsing:
    """A PASS verdict from the model is parsed into a passing ValidationResult."""

    def test_pass_verdict(self):
        canned_reply = "\n".join(
            [
                "SPEC_CHARACTERS: Kane, Mira",
                "PROMPT_CHARACTERS: Kane, Mira",
                "EXTRAS: none",
                "MISSING: none",
                "VERDICT: PASS",
            ]
        )
        fake_client = MagicMock()
        fake_client.models.generate_content.return_value = _mock_text_response(canned_reply)

        gate = Validator(api_key="test-key")
        gate._client = fake_client

        prompt = "Kane and Mira stand at the navigation console on the bridge."
        outcome = gate.run_gate_1_text(prompt, SAMPLE_SHOT)

        assert isinstance(outcome, ValidationResult)
        assert outcome.passed is True
        details = outcome.details
        assert details["verdict"] == "PASS"
        assert details["extras"] == []
        assert details["missing"] == []
        assert outcome.cost == 0.00005


class TestGate1TextFailResponseParsing:
    """A FAIL verdict is parsed with its EXTRAS and MISSING character lists."""

    def test_fail_verdict(self):
        canned_reply = "\n".join(
            [
                "SPEC_CHARACTERS: Kane, Mira",
                "PROMPT_CHARACTERS: Kane, Jax",
                "EXTRAS: Jax",
                "MISSING: Mira",
                "VERDICT: FAIL",
            ]
        )
        fake_client = MagicMock()
        fake_client.models.generate_content.return_value = _mock_text_response(canned_reply)

        gate = Validator(api_key="test-key")
        gate._client = fake_client

        outcome = gate.run_gate_1_text("Kane and Jax inspect the bridge console.", SAMPLE_SHOT)

        details = outcome.details
        assert outcome.passed is False
        assert details["verdict"] == "FAIL"
        assert "Jax" in details["extras"]
        assert "Mira" in details["missing"]
        assert outcome.cost == 0.00005


class TestGate1TextFailClosedOnException:
    """An API exception in gate 1 text yields passed=False, an error, and zero cost."""

    def test_exception_returns_fail(self):
        fake_client = MagicMock()
        fake_client.models.generate_content.side_effect = RuntimeError("API down")

        gate = Validator(api_key="test-key")
        gate._client = fake_client

        outcome = gate.run_gate_1_text("Kane and Mira on the bridge.", SAMPLE_SHOT)

        assert outcome.passed is False
        assert "error" in outcome.details
        assert outcome.cost == 0.0


def test_generate_previs_gate_1_text_validator_exception_fails_closed(tmp_path, monkeypatch):
    """generate_previs records a failed text gate when run_gate_1_text raises.

    Paths, context building, the execution store, frame generation and the
    image gate are all replaced with fakes, so only the text-gate exception
    path is exercised.
    """
    from recoil.core.paths import ProjectPaths
    from recoil.pipeline.tools import generate_previs as previs

    paths = ProjectPaths.from_root(tmp_path / "project")
    paths.plans_dir.mkdir(parents=True)
    plan_file = paths.plans_dir / "ep_001_plan.json"
    plan_file.write_text(json.dumps({"shots": [SAMPLE_SHOT]}))

    def fake_generate_flash_frame(*, context_parts):
        # Always "succeed" with a dummy image so the pipeline reaches the gates.
        return {
            "success": True,
            "image_data": DUMMY_PNG,
            "authored_prompt": "Kane and Mira on the bridge.",
        }

    def raise_text_gate(self, authored_prompt, shot):
        raise RuntimeError("validator exploded")

    def pass_image_gate(self, image_path):
        return ValidationResult("gate_1_image", True, {}, cost=0.0)

    def fake_context(**kwargs):
        return [(None, "instruction", "generate a frame")]

    fakes = [
        (previs, "ProjectPaths", SimpleNamespace(for_project=lambda project=None: paths)),
        (previs, "build_previz_context", fake_context),
        (previs, "ExecutionStore", lambda *args, **kwargs: None),
        (previs, "_generate_flash_frame", fake_generate_flash_frame),
        (Validator, "run_gate_1_text", raise_text_gate),
        (Validator, "run_gate_1_image", pass_image_gate),
    ]
    for target, attr_name, replacement in fakes:
        monkeypatch.setattr(target, attr_name, replacement)

    episode_results = previs.generate_previs(episode=1, project="project")

    text_gate = episode_results[0]["gate_1_text"]
    assert text_gate["passed"] is False
    assert "validator exploded" in text_gate["details"]["error"]


class TestGate1TextExtractsCharactersFromShot:
    """Gate 1 prompt includes the shot id, character ids and location from the shot spec."""

    def test_prompt_includes_characters(self):
        mock_client = MagicMock()
        mock_client.models.generate_content.return_value = _mock_text_response(
            "SPEC_CHARACTERS: Kane, Mira\n"
            "PROMPT_CHARACTERS: Kane, Mira\n"
            "EXTRAS: none\n"
            "MISSING: none\n"
            "VERDICT: PASS"
        )

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        validator.run_gate_1_text("test prompt", SAMPLE_SHOT)

        generate = mock_client.models.generate_content
        # Fail with a clear message if the API was never hit, rather than an
        # AttributeError on call_args being None.
        generate.assert_called()
        call_kwargs = generate.call_args.kwargs
        # The gate prompt is passed as the ``contents`` keyword argument.
        # (call_args[1] is the same mapping as .kwargs, so no second lookup is needed.)
        assert "contents" in call_kwargs
        contents = call_kwargs["contents"]
        assert "KANE" in contents
        assert "MIRA" in contents
        assert "BRIDGE" in contents
        assert "EP001_SH03" in contents


# ── Gate 1 Image (run_gate_1_image) ──────────────────────────────────


# Fake PNG payload: the 8-byte PNG signature followed by 100 zero bytes.
# Not a decodable image; it only needs to exist on disk for the gate.
DUMMY_PNG = b"\x89PNG\r\n\x1a\n" + bytes(100)


class TestGate1ImagePassAllChecks:
    """Gate 1 image passes when every mechanical check reports pass."""

    def test_pass_all(self, tmp_path):
        gate_response = {
            "black_frame": {"pass": True, "reason": "Normal image"},
            "watermark": {"pass": True, "reason": "No watermarks"},
            "anatomy": {"pass": True, "reason": "Normal anatomy"},
            "color": {"pass": True, "reason": "Normal colors"},
            "resolution": {"pass": True, "reason": "Good resolution"},
        }

        mock_client = MagicMock()
        mock_client.models.generate_content.return_value = _mock_json_response(gate_response)

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # pytest's tmp_path is a per-test directory cleaned up automatically,
        # so no manual delete=False / unlink bookkeeping is needed.
        image_path = tmp_path / "frame.png"
        image_path.write_bytes(DUMMY_PNG)

        result = validator.run_gate_1_image(image_path)

        assert result.passed is True
        assert result.gate == "gate_1_image"


class TestGate1ImageFailAnatomy:
    """Gate 1 image fails when the anatomy check fails."""

    def test_fail_anatomy(self, tmp_path):
        gate_response = {
            "black_frame": {"pass": True, "reason": "Normal"},
            "watermark": {"pass": True, "reason": "Clean"},
            "anatomy": {"pass": False, "reason": "Extra arm visible"},
            "color": {"pass": True, "reason": "OK"},
            "resolution": {"pass": True, "reason": "OK"},
        }

        mock_client = MagicMock()
        mock_client.models.generate_content.return_value = _mock_json_response(gate_response)

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # Per-test tmp dir; pytest removes it, so no try/finally unlink needed.
        image_path = tmp_path / "frame.png"
        image_path.write_bytes(DUMMY_PNG)

        result = validator.run_gate_1_image(image_path)

        assert result.passed is False
        assert result.details["anatomy"]["pass"] is False


class TestGate1ImageFailWatermark:
    """Gate 1 image fails when a watermark is detected."""

    def test_fail_watermark(self, tmp_path):
        gate_response = {
            "black_frame": {"pass": True, "reason": "Normal"},
            "watermark": {"pass": False, "reason": "Visible watermark overlay"},
            "anatomy": {"pass": True, "reason": "OK"},
            "color": {"pass": True, "reason": "OK"},
            "resolution": {"pass": True, "reason": "OK"},
        }

        mock_client = MagicMock()
        mock_client.models.generate_content.return_value = _mock_json_response(gate_response)

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # Per-test tmp dir; pytest removes it, so no try/finally unlink needed.
        image_path = tmp_path / "frame.png"
        image_path.write_bytes(DUMMY_PNG)

        result = validator.run_gate_1_image(image_path)

        assert result.passed is False
        assert result.details["watermark"]["pass"] is False


class TestGate1ImageFailBlackFrame:
    """Gate 1 image fails when the image is a black frame."""

    def test_fail_black_frame(self, tmp_path):
        gate_response = {
            "black_frame": {"pass": False, "reason": "Entirely black image"},
            "watermark": {"pass": True, "reason": "N/A"},
            "anatomy": {"pass": True, "reason": "N/A"},
            "color": {"pass": True, "reason": "N/A"},
            "resolution": {"pass": True, "reason": "N/A"},
        }

        mock_client = MagicMock()
        mock_client.models.generate_content.return_value = _mock_json_response(gate_response)

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # Per-test tmp dir; pytest removes it, so no try/finally unlink needed.
        image_path = tmp_path / "frame.png"
        image_path.write_bytes(DUMMY_PNG)

        result = validator.run_gate_1_image(image_path)

        assert result.passed is False
        assert result.details["black_frame"]["pass"] is False


class TestGate1ImageFailOnException:
    """Gate 1 image returns passed=False when the API call raises (fail-closed)."""

    def test_exception_returns_fail(self, tmp_path):
        mock_client = MagicMock()
        mock_client.models.generate_content.side_effect = RuntimeError("API down")

        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # Per-test tmp dir; pytest removes it, so no try/finally unlink needed.
        image_path = tmp_path / "frame.png"
        image_path.write_bytes(DUMMY_PNG)

        result = validator.run_gate_1_image(image_path)

        # Gate 1 image is fail-closed, like gate 1 text: an API error must
        # never be reported as a pass.
        assert result.passed is False
        assert "error" in result.details


class TestGate1ImageMissingFile:
    """Gate 1 image returns passed=False, without calling the API, when the file is absent."""

    def test_missing_file(self, tmp_path):
        mock_client = MagicMock()
        validator = Validator(api_key="test-key")
        validator._client = mock_client

        # A child of the fresh per-test directory is guaranteed not to exist,
        # unlike a hard-coded absolute path that could exist on some machine.
        missing_path = tmp_path / "image.png"
        result = validator.run_gate_1_image(missing_path)

        assert result.passed is False
        assert "error" in result.details
        assert result.cost == 0.0
        # Nothing to validate, so no (billable) model call should be made.
        mock_client.models.generate_content.assert_not_called()
