"""Tests for vision critics — RefImageCritic, StartFrameCritic, VideoFrameCritic.

All tests mock validate_image / validate_video_frames calls so no real
Gemini API calls are made.
"""

import pytest
from unittest.mock import patch

from recoil.core.critic import Severity
from recoil.pipeline._lib.critics.ref_image_critic import RefImageCritic
from recoil.pipeline._lib.critics.start_frame_critic import StartFrameCritic
from recoil.pipeline._lib.critics.video_frame_critic import VideoFrameCritic


# ── Helpers ──────────────────────────────────────────────────────────────


def _make_image_result(checks, overrides=None):
    """Build a validate_image return dict where all checks pass by default.

    Args:
        checks: list of check dicts (name, question, expected, severity).
        overrides: dict mapping check name -> answer string to override.
    """
    overrides = overrides or {}
    results = []
    for check in checks:
        name = check["name"]
        expected = check["expected"]
        answer = overrides.get(name, expected)  # pass by default
        passed = expected.lower() in answer.lower()
        results.append(
            {
                "name": name,
                "passed": passed,
                "answer": answer,
                "expected": expected,
                "severity": check["severity"],
            }
        )
    return {
        "passed": all(r["passed"] for r in results),
        "results": results,
        "error": None,
    }


def _make_video_result(checks, num_frames=5, failing_frames_by_check=None):
    """Build a validate_video_frames return dict.

    Args:
        checks: list of check dicts.
        num_frames: how many frames.
        failing_frames_by_check: dict mapping check name -> list of frame
            indices that should fail that check.
    """
    failing_frames_by_check = failing_frames_by_check or {}
    frame_results = []
    for i in range(num_frames):
        per_check_results = []
        for check in checks:
            name = check["name"]
            expected = check["expected"]
            failing_indices = failing_frames_by_check.get(name, [])
            if i in failing_indices:
                # Fail: answer is the opposite of expected
                answer = "yes" if expected.lower() == "no" else "no"
                passed = False
            else:
                answer = expected
                passed = True
            per_check_results.append(
                {
                    "name": name,
                    "passed": passed,
                    "answer": answer,
                    "expected": expected,
                    "severity": check["severity"],
                }
            )
        frame_results.append(
            {
                "passed": all(r["passed"] for r in per_check_results),
                "results": per_check_results,
                "error": None,
                "frame_index": i,
                "frame_path": f"/tmp/frame_{i:03d}.png",
            }
        )
    return {
        "passed": all(fr["passed"] for fr in frame_results),
        "frame_results": frame_results,
        "error": None,
    }


# ── TestRefImageCritic ───────────────────────────────────────────────────


class TestRefImageCritic:
    """Unit tests for ``RefImageCritic``; the vision call is always patched out."""

    # Module-level name RefImageCritic resolves validate_image through.
    _TARGET = "recoil.pipeline._lib.critics.ref_image_critic.validate_image"

    @staticmethod
    def _dim(dims, name):
        """Return the first dimension called *name* (IndexError when absent)."""
        return [d for d in dims if d.name == name][0]

    def test_init_human_defaults(self):
        """A human defaults to 2 legs down, 4 limbs, a single attempt and no props."""
        critic = RefImageCritic(character_type="human")
        assert critic.character_type == "human"
        assert (critic.legs_on_ground, critic.total_limbs) == (2, 4)
        assert critic.max_attempts == 1
        assert critic.expected_props == []

    def test_init_quadruped_defaults(self):
        """A quadruped stands on all 4 limbs."""
        critic = RefImageCritic(character_type="quadruped")
        assert (critic.legs_on_ground, critic.total_limbs) == (4, 4)

    def test_init_vehicle_defaults(self):
        """A vehicle has no limbs, so anatomy is never checked."""
        critic = RefImageCritic(character_type="vehicle")
        assert (critic.legs_on_ground, critic.total_limbs) == (0, 0)

    def test_init_custom_limb_counts(self):
        """Explicit limb counts win over the character-type defaults."""
        critic = RefImageCritic(character_type="human", legs_on_ground=1, total_limbs=3)
        assert (critic.legs_on_ground, critic.total_limbs) == (1, 3)

    @patch(_TARGET)
    def test_passing_evaluation(self, fake_validate):
        """When every check is answered as expected, every dimension passes."""
        critic = RefImageCritic(character_type="human")
        checks = critic._build_checks()
        fake_validate.return_value = _make_image_result(checks)

        dims = critic.evaluate("/fake/image.png", {})

        assert all(d.passed for d in dims)
        assert len(dims) == len(checks)
        # Human anatomy checks plus the background check are all reported.
        assert {"LIMB_COUNT", "EXTRA_APPENDAGES", "BACKGROUND_CLEAN"} <= {
            d.name for d in dims
        }

    @patch(_TARGET)
    def test_extra_limbs_detection(self, fake_validate):
        """A phantom-limb answer makes EXTRA_APPENDAGES a HARD failure."""
        critic = RefImageCritic(character_type="human")
        checks = critic._build_checks()
        phantom = "yes, there are extra fingers"
        fake_validate.return_value = _make_image_result(
            checks, overrides={"EXTRA_APPENDAGES": phantom}
        )

        extra_dim = self._dim(critic.evaluate("/fake/image.png", {}), "EXTRA_APPENDAGES")

        assert not extra_dim.passed
        assert extra_dim.severity == Severity.HARD
        assert phantom in extra_dim.message

    @patch(_TARGET)
    def test_prop_check_pass(self, fake_validate):
        """PROP_HELD passes when the expected prop is in hand."""
        critic = RefImageCritic(character_type="human", expected_props=["sword"])
        fake_validate.return_value = _make_image_result(critic._build_checks())

        prop_dim = self._dim(critic.evaluate("/fake/image.png", {}), "PROP_HELD")

        assert prop_dim.passed

    @patch(_TARGET)
    def test_prop_check_fail(self, fake_validate):
        """PROP_HELD is a HARD failure when the expected prop is not held."""
        critic = RefImageCritic(character_type="human", expected_props=["sword"])
        fake_validate.return_value = _make_image_result(
            critic._build_checks(),
            overrides={"PROP_HELD": "no, not holding anything"},
        )

        prop_dim = self._dim(critic.evaluate("/fake/image.png", {}), "PROP_HELD")

        assert not prop_dim.passed
        assert prop_dim.severity == Severity.HARD

    @patch(_TARGET)
    def test_vehicle_no_anatomy_checks(self, fake_validate):
        """Vehicles skip LIMB_COUNT/EXTRA_APPENDAGES but keep BACKGROUND_CLEAN."""
        critic = RefImageCritic(character_type="vehicle")
        fake_validate.return_value = _make_image_result(critic._build_checks())

        names = {d.name for d in critic.evaluate("/fake/image.png", {})}

        assert names.isdisjoint({"LIMB_COUNT", "EXTRA_APPENDAGES"})
        assert "BACKGROUND_CLEAN" in names

    @patch(_TARGET)
    def test_api_error_raises(self, fake_validate):
        """Phase 2.5 Task 4: a vision API error must raise (fail closed), never pass silently."""
        critic = RefImageCritic(character_type="human")
        # "passed": True on purpose: the error field alone must trigger the raise.
        fake_validate.return_value = {
            "passed": True,
            "results": [],
            "error": "Gemini API timeout",
        }

        with pytest.raises(RuntimeError, match="vision check failed"):
            critic.evaluate("/fake/image.png", {})


# ── TestStartFrameCritic ─────────────────────────────────────────────────


class TestStartFrameCritic:
    """Tests for StartFrameCritic."""

    @staticmethod
    def _marcus():
        """Return a fresh description of the "Marcus" test character.

        A new dict per call keeps tests independent even if the critic
        mutates the descriptions it is given.
        """
        return {
            "name": "Marcus",
            "hair": "dark brown",
            "facial_hair": "full beard",
            "clothing": "leather jacket",
        }

    @staticmethod
    def _dim(dims, name):
        """Return the first dimension named *name*, failing with a readable message."""
        matches = [d for d in dims if d.name == name]
        assert matches, f"no dimension named {name!r}; got {[d.name for d in dims]}"
        return matches[0]

    def test_init_defaults(self):
        """Default init: scene background, no character descriptions."""
        critic = StartFrameCritic()
        assert critic.expected_background == "scene"
        assert critic.character_descriptions == []
        assert critic.expected_elements == []
        assert critic.max_attempts == 1

    def test_init_solid_color_background(self):
        """Can specify solid color expected background."""
        critic = StartFrameCritic(expected_background="solid color")
        assert critic.expected_background == "solid color"

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_white_background_failure(self, mock_validate):
        """BACKGROUND_VALID fails when scene expected but got white background."""
        critic = StartFrameCritic(expected_background="scene")
        checks = critic._build_checks()

        mock_validate.return_value = _make_image_result(
            checks,
            overrides={"BACKGROUND_VALID": "no, the background is solid white"},
        )

        dims = critic.evaluate("/fake/start_frame.png", {})

        bg_dim = self._dim(dims, "BACKGROUND_VALID")
        assert not bg_dim.passed
        assert bg_dim.severity == Severity.HARD

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_background_valid_pass(self, mock_validate):
        """BACKGROUND_VALID passes when scene is present."""
        critic = StartFrameCritic(expected_background="scene")
        checks = critic._build_checks()

        mock_validate.return_value = _make_image_result(checks)

        dims = critic.evaluate("/fake/start_frame.png", {})

        assert self._dim(dims, "BACKGROUND_VALID").passed

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_beard_detection_via_character_descriptions(self, mock_validate):
        """CHARACTER_IDENTITY_MARCUS fails (HARD) when the expected beard is missing."""
        critic = StartFrameCritic(character_descriptions=[self._marcus()])
        checks = critic._build_checks()

        mock_validate.return_value = _make_image_result(
            checks,
            overrides={
                "CHARACTER_IDENTITY_MARCUS": "no, the character is clean-shaven"
            },
        )

        dims = critic.evaluate("/fake/start_frame.png", {})

        # Assert on the exact check we failed, not on whichever identity
        # dimension happens to come first.
        char_dim = self._dim(dims, "CHARACTER_IDENTITY_MARCUS")
        assert not char_dim.passed
        assert char_dim.severity == Severity.HARD

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_character_identity_pass(self, mock_validate):
        """CHARACTER_IDENTITY_* passes when character matches description."""
        critic = StartFrameCritic(character_descriptions=[self._marcus()])
        checks = critic._build_checks()

        mock_validate.return_value = _make_image_result(checks)

        dims = critic.evaluate("/fake/start_frame.png", {})

        char_dims = [d for d in dims if d.name.startswith("CHARACTER_IDENTITY")]
        assert char_dims, f"no CHARACTER_IDENTITY dims in {[d.name for d in dims]}"
        assert all(d.passed for d in char_dims)

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_scene_elements_check(self, mock_validate):
        """SCENE_ELEMENTS_* is SOFT severity — missing element doesn't hard-fail."""
        critic = StartFrameCritic(expected_elements=["car", "power pole"])
        checks = critic._build_checks()

        scene_checks = [c for c in checks if c["name"].startswith("SCENE_ELEMENTS")]
        assert scene_checks, "critic built no SCENE_ELEMENTS checks"
        missing = scene_checks[0]
        # Answer the opposite of what is expected so this element reads as absent.
        missing_answer = "yes" if missing["expected"].strip().lower() == "no" else "no"

        mock_validate.return_value = _make_image_result(
            checks,
            overrides={missing["name"]: missing_answer},
        )

        dims = critic.evaluate("/fake/start_frame.png", {})

        scene_dims = [d for d in dims if d.name.startswith("SCENE_ELEMENTS")]
        assert len(scene_dims) == 2
        for d in scene_dims:
            assert d.severity == Severity.SOFT

        # The missing element fails, but only softly: every HARD dimension
        # still passes, so the frame is not hard-rejected.
        assert not self._dim(dims, missing["name"]).passed
        assert all(d.passed for d in dims if d.severity == Severity.HARD)

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_composition_valid_check(self, mock_validate):
        """COMPOSITION_VALID is always present and HARD severity."""
        critic = StartFrameCritic()
        checks = critic._build_checks()

        mock_validate.return_value = _make_image_result(checks)

        dims = critic.evaluate("/fake/start_frame.png", {})

        comp_dim = self._dim(dims, "COMPOSITION_VALID")
        assert comp_dim.passed
        assert comp_dim.severity == Severity.HARD

    @patch("recoil.pipeline._lib.critics.start_frame_critic.validate_image")
    def test_api_error_raises(self, mock_validate):
        """Phase 2.5 Task 4: vision API failure fail-closes (raises), not silent pass."""
        critic = StartFrameCritic(character_descriptions=[self._marcus()])

        # "passed": True on purpose: the error field alone must trigger the raise.
        mock_validate.return_value = {
            "passed": True,
            "results": [],
            "error": "Connection reset",
        }

        with pytest.raises(RuntimeError, match="vision check failed"):
            critic.evaluate("/fake/start_frame.png", {})


# ── TestVideoFrameCritic ─────────────────────────────────────────────────


class TestVideoFrameCritic:
    """Tests for VideoFrameCritic."""

    @staticmethod
    def _dim(dims, name):
        """Return the first dimension named *name*, failing with a readable message."""
        matches = [d for d in dims if d.name == name]
        assert matches, f"no dimension named {name!r}; got {[d.name for d in dims]}"
        return matches[0]

    def test_init_defaults(self):
        """Default init: human, 5 frames, no style or elements."""
        critic = VideoFrameCritic()
        assert critic.character_type == "human"
        assert critic.num_frames == 5
        assert critic.expected_style == ""
        assert critic.expected_elements == []
        assert critic.max_attempts == 1

    def test_init_custom_params(self):
        """Custom params are stored correctly."""
        critic = VideoFrameCritic(
            character_type="quadruped",
            expected_style="dark cinematic",
            expected_elements=["deer", "forest"],
            num_frames=8,
        )
        assert critic.character_type == "quadruped"
        assert critic.expected_style == "dark cinematic"
        assert critic.expected_elements == ["deer", "forest"]
        assert critic.num_frames == 8

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_all_frames_pass(self, mock_validate):
        """All checks pass across all frames."""
        critic = VideoFrameCritic(
            character_type="human",
            expected_style="noir",
        )
        checks = critic._build_checks()

        mock_validate.return_value = _make_video_result(checks, num_frames=5)

        dims = critic.evaluate("/fake/video.mp4", {})

        assert all(d.passed for d in dims)
        dim_names = [d.name for d in dims]
        assert "EXTRA_LIMBS" in dim_names
        assert "STYLE_CONSISTENT" in dim_names

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_extra_limbs_in_frames_with_timestamp_reporting(self, mock_validate):
        """EXTRA_LIMBS fails with specific frame indices reported."""
        critic = VideoFrameCritic(character_type="human")
        checks = critic._build_checks()

        # Frames 1 and 3 have extra limbs
        mock_validate.return_value = _make_video_result(
            checks,
            num_frames=5,
            failing_frames_by_check={"EXTRA_LIMBS": [1, 3]},
        )

        dims = critic.evaluate("/fake/video.mp4", {})

        limbs_dim = self._dim(dims, "EXTRA_LIMBS")
        assert not limbs_dim.passed
        assert limbs_dim.severity == Severity.HARD
        # The exact message format is the critic's choice; only require that
        # both failing frame indices are mentioned.
        assert "1" in limbs_dim.message
        assert "3" in limbs_dim.message

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_element_persistence_check(self, mock_validate):
        """ELEMENT_PERSISTENCE is SOFT and checks per element (unique names)."""
        # Two elements are needed to observe per-element, uniquely named dims.
        critic = VideoFrameCritic(
            character_type="human",
            expected_elements=["car", "power pole"],
        )
        checks = critic._build_checks()

        mock_validate.return_value = _make_video_result(checks, num_frames=5)

        dims = critic.evaluate("/fake/video.mp4", {})

        elem_dims = [d for d in dims if d.name.startswith("ELEMENT_PERSISTENCE_")]
        elem_names = [d.name for d in elem_dims]
        assert len(elem_names) >= 2, f"expected one dim per element, got {elem_names}"
        assert len(set(elem_names)) == len(elem_names), f"duplicate names: {elem_names}"
        for d in elem_dims:
            assert d.severity == Severity.SOFT

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_no_style_check_when_empty(self, mock_validate):
        """STYLE_CONSISTENT is omitted when expected_style is empty."""
        critic = VideoFrameCritic(character_type="human", expected_style="")
        checks = critic._build_checks()

        mock_validate.return_value = _make_video_result(checks, num_frames=5)

        dims = critic.evaluate("/fake/video.mp4", {})

        style_dims = [d for d in dims if d.name == "STYLE_CONSISTENT"]
        assert len(style_dims) == 0

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_api_error_raises(self, mock_validate):
        """Phase 2.5 Task 4: vision API failure fail-closes (raises), not silent pass."""
        critic = VideoFrameCritic(character_type="human")

        # "passed": True on purpose: the error field alone must trigger the raise.
        mock_validate.return_value = {
            "passed": True,
            "frame_results": [],
            "error": "ffmpeg not found",
        }

        with pytest.raises(RuntimeError, match="vision check failed"):
            critic.evaluate("/fake/video.mp4", {})

    @patch("recoil.pipeline._lib.critics.video_frame_critic.validate_video_frames")
    def test_no_frames_extracted(self, mock_validate):
        """Phase 2.5 Task 4: no frames extracted is a HARD FRAME_EXTRACTION failure, not silent pass."""
        critic = VideoFrameCritic(character_type="human")

        mock_validate.return_value = {
            "passed": True,
            "frame_results": [],
            "error": None,
        }

        dims = critic.evaluate("/fake/video.mp4", {})

        # Fail-closed: an empty extraction must surface as a failing HARD
        # FRAME_EXTRACTION dimension rather than an all-pass result.
        assert any(
            d.name == "FRAME_EXTRACTION"
            and not d.passed
            and d.severity == Severity.HARD
            for d in dims
        )
