"""Tests for tools/screen_test_gen.py — Screen Test image generation."""

from pathlib import Path
from unittest.mock import MagicMock, patch, PropertyMock

import pytest

from tools.screen_test_gen import (
    build_phase_prompt,
    enrich_director_note,
    generate_phase_image,
)


# ── Sample data ──────────────────────────────────────────────────────


def _sample_char():
    """Return a fresh character record (``name`` + ``visual_description``) for tests.

    A new dict is built on every call so tests may mutate it freely.
    """
    description = (
        "Young woman, early 20s. Lean athletic build. "
        "Short asymmetric dark hair, undercut on the left. "
        "Sharp features, high cheekbones, olive skin. "
        "Small scar across the bridge of her nose."
    )
    return dict(name="JINX", visual_description=description)


def _sample_phase():
    """Return a fresh "Salvager" look-phase record for tests.

    Wardrobe, hair, makeup and marks are all populated; a new dict is built on
    every call so tests may mutate or extend it freely.
    """
    return dict(
        phase_id="jinx_salvager",
        label="Salvager",
        wardrobe="Worn gray tactical vest over black compression top, cargo pants, heavy boots.",
        hair="Short asymmetric dark hair, undercut on the left, slightly windswept.",
        makeup="Minimal. Smudged eyeliner. Sun-weathered skin.",
        marks="Small scar across the bridge of her nose. Grease stains on hands.",
    )


# ── build_phase_prompt ───────────────────────────────────────────────


class TestBuildPhasePrompt:
    """Tests for ``build_phase_prompt``.

    Assertions target text that appears in only one input field (or only in the
    fixed prompt template). That way a passing test proves the specific field or
    template fragment reached the prompt, rather than matching a duplicate
    elsewhere in the sample data.
    """

    def test_includes_visual_description(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "olive skin" in prompt
        assert "high cheekbones" in prompt

    def test_includes_wardrobe(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "tactical vest" in prompt
        assert "cargo pants" in prompt

    def test_includes_hair(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "undercut" in prompt
        # "undercut" also appears in the character's visual_description, so it
        # alone cannot prove the hair field was used; "windswept" is hair-only.
        assert "windswept" in prompt

    def test_includes_makeup(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "Smudged eyeliner" in prompt

    def test_includes_marks(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "Grease stains" in prompt

    def test_includes_arri_alexa_anchor(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "Arri Alexa" in prompt

    def test_includes_portrait_aspect(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "9:16" in prompt

    def test_includes_gray_backdrop(self):
        # The sample wardrobe mentions a "gray tactical vest", which would satisfy
        # a naive "gray" check by itself. Swap in a wardrobe without that word so
        # the match has to come from the backdrop instruction in the template.
        phase = {**_sample_phase(), "wardrobe": "Black compression top, cargo pants, heavy boots."}
        prompt = build_phase_prompt(_sample_char(), phase).lower()
        assert "gray" in prompt or "grey" in prompt

    def test_includes_reference_match_instruction(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "reference" in prompt.lower()

    def test_without_director_note(self):
        prompt = build_phase_prompt(_sample_char(), _sample_phase())
        assert "DIRECTOR" not in prompt

    def test_with_director_note(self):
        note = "She just survived a fight. Breathing hard. Blood on her knuckles."
        prompt = build_phase_prompt(_sample_char(), _sample_phase(), enriched_note=note)
        assert "Blood on her knuckles" in prompt
        assert "DIRECTOR" in prompt

    def test_missing_optional_phase_fields(self):
        """Phase with only wardrobe should still build a valid prompt."""
        char = _sample_char()
        phase = {
            "phase_id": "jinx_minimal",
            "label": "Minimal",
            "wardrobe": "Plain white t-shirt, jeans.",
        }
        prompt = build_phase_prompt(char, phase)
        assert "white t-shirt" in prompt
        assert "Arri Alexa" in prompt


# ── enrich_director_note ─────────────────────────────────────────────


class TestEnrichDirectorNote:
    """Tests for ``enrich_director_note`` with the text client mocked out."""

    @staticmethod
    def _install_text_client(mock_get_client, text):
        """Make *mock_get_client* return a client whose ``generate_content`` yields *text*.

        Returns the mock client so tests can inspect the recorded call.
        """
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.text = text
        mock_client.models.generate_content.return_value = mock_response
        mock_get_client.return_value = mock_client
        return mock_client

    @patch("tools.screen_test_gen._get_text_client")
    def test_returns_enriched_text(self, mock_get_client):
        mock_client = self._install_text_client(
            mock_get_client,
            "  Sweat-streaked face, clenched jaw, labored breathing, adrenaline flush  ",
        )

        result = enrich_director_note(
            base_description="Young woman, early 20s, athletic build.",
            director_note="She just survived a fight.",
        )

        # Surrounding whitespace from the model response is stripped.
        assert result == "Sweat-streaked face, clenched jaw, labored breathing, adrenaline flush"
        mock_client.models.generate_content.assert_called_once()

    @patch("tools.screen_test_gen._get_text_client")
    def test_passes_base_description_and_note(self, mock_get_client):
        mock_client = self._install_text_client(mock_get_client, "enriched output")

        enrich_director_note(
            base_description="Test character description.",
            director_note="Test director note.",
        )

        mock_client.models.generate_content.assert_called_once()
        # Both inputs must reach the model. Check the whole recorded call
        # (contents, config, ...) so the test does not depend on which argument
        # carries them.
        request = str(mock_client.models.generate_content.call_args)
        assert "Test character description." in request
        assert "Test director note." in request

    @patch("tools.screen_test_gen._get_text_client")
    def test_uses_flash_model(self, mock_get_client):
        mock_client = self._install_text_client(mock_get_client, "enriched")

        enrich_director_note("desc", "note")

        # Look only at the model argument (keyword or first positional).
        # Searching the whole call would let any mention of "flash" in the
        # prompt satisfy the test.
        args, kwargs = mock_client.models.generate_content.call_args
        model = kwargs.get("model", args[0] if args else "")
        assert "flash" in str(model).lower()


# ── generate_phase_image ─────────────────────────────────────────────


class TestGeneratePhaseImage:
    """Tests for ``generate_phase_image`` with the image client mocked out."""

    def _mock_response_with_image(self, image_bytes=b"fake_png_data"):
        """Build a mock genai response containing inline image data."""
        mock_part = MagicMock()
        mock_part.inline_data.data = image_bytes
        mock_candidate = MagicMock()
        mock_candidate.content.parts = [mock_part]
        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        return mock_response

    def _mock_response_empty(self):
        """Build a mock genai response with no image data."""
        mock_response = MagicMock()
        mock_response.candidates = []
        return mock_response

    @staticmethod
    def _install_image_client(mock_get_client, response=None, error=None):
        """Make *mock_get_client* return a client whose ``generate_content``
        returns *response*, or raises *error* when one is given.

        Returns the mock client so tests can inspect the recorded call.
        """
        mock_api_client = MagicMock()
        if error is not None:
            mock_api_client.models.generate_content.side_effect = error
        else:
            mock_api_client.models.generate_content.return_value = response
        mock_get_client.return_value = mock_api_client
        return mock_api_client

    @staticmethod
    def _contents_of(call):
        """Return the ``contents`` argument of a recorded ``generate_content`` call.

        Handles both the keyword form (``contents=...``) and the positional form
        (``generate_content(model, contents, ...)``); returns ``[]`` if absent.
        """
        args, kwargs = call
        if "contents" in kwargs:
            return kwargs["contents"]
        return args[1] if len(args) > 1 else []

    @patch("tools.screen_test_gen._get_image_client")
    def test_success_with_hero_only(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_image_bytes")
        output = tmp_path / "output" / "phase.png"
        self._install_image_client(mock_get_client, self._mock_response_with_image())

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=None,
            prompt="test prompt",
            output_path=output,
        )

        assert result is True
        assert output.exists()
        assert output.read_bytes() == b"fake_png_data"

    @patch("tools.screen_test_gen._get_image_client")
    def test_success_with_hero_and_three_quarter(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        three_quarter = tmp_path / "three_quarter.png"
        three_quarter.write_bytes(b"tq_bytes")
        output = tmp_path / "output" / "phase.png"
        mock_api_client = self._install_image_client(
            mock_get_client, self._mock_response_with_image()
        )

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=three_quarter,
            prompt="test prompt",
            output_path=output,
        )

        assert result is True
        # Hero part + three-quarter part + prompt => at least 3 content items.
        contents = self._contents_of(mock_api_client.models.generate_content.call_args)
        assert len(contents) >= 3

    @patch("tools.screen_test_gen._get_image_client")
    def test_success_with_anchor(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        anchor = tmp_path / "anchor.png"
        anchor.write_bytes(b"anchor_bytes")
        output = tmp_path / "output" / "phase.png"
        mock_api_client = self._install_image_client(
            mock_get_client, self._mock_response_with_image()
        )

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=None,
            prompt="test prompt",
            output_path=output,
            anchor_path=anchor,
        )

        assert result is True
        # Hero part + anchor part + prompt => at least 3 content items.
        contents = self._contents_of(mock_api_client.models.generate_content.call_args)
        assert len(contents) >= 3

    @patch("tools.screen_test_gen._get_image_client")
    def test_creates_output_dirs(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        output = tmp_path / "deep" / "nested" / "dir" / "phase.png"
        self._install_image_client(mock_get_client, self._mock_response_with_image())

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=None,
            prompt="test prompt",
            output_path=output,
        )

        assert result is True
        assert output.parent.exists()

    @patch("tools.screen_test_gen._get_image_client")
    def test_returns_false_on_empty_response(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        output = tmp_path / "output" / "phase.png"
        self._install_image_client(mock_get_client, self._mock_response_empty())

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=None,
            prompt="test prompt",
            output_path=output,
        )

        assert result is False
        assert not output.exists()

    @patch("tools.screen_test_gen._get_image_client")
    def test_returns_false_on_api_exception(self, mock_get_client, tmp_path):
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        output = tmp_path / "output" / "phase.png"
        self._install_image_client(mock_get_client, error=RuntimeError("API down"))

        result = generate_phase_image(
            hero_path=hero,
            three_quarter_path=None,
            prompt="test prompt",
            output_path=output,
        )

        assert result is False
        # A failed API call must not leave a (partial or empty) image behind.
        assert not output.exists()

    @patch("tools.screen_test_gen._get_image_client")
    def test_uses_correct_config(self, mock_get_client, tmp_path):
        """Verify temperature=0.6, 9:16 aspect, IMAGE+TEXT modalities."""
        hero = tmp_path / "hero.png"
        hero.write_bytes(b"hero_bytes")
        output = tmp_path / "output" / "phase.png"
        self._install_image_client(mock_get_client, self._mock_response_with_image())

        with patch("tools.screen_test_gen.genai_types") as mock_types:
            mock_types.Part.from_bytes.return_value = MagicMock()
            mock_config = MagicMock()
            mock_types.GenerateContentConfig.return_value = mock_config
            mock_types.ImageConfig.return_value = MagicMock()

            generate_phase_image(
                hero_path=hero,
                three_quarter_path=None,
                prompt="test prompt",
                output_path=output,
            )

            # Fail with a clear message (not a TypeError on None) if the config
            # objects were never constructed.
            config_call = mock_types.GenerateContentConfig.call_args
            assert config_call is not None, "GenerateContentConfig was never built"
            assert config_call[1]["temperature"] == 0.6
            assert config_call[1]["response_modalities"] == ["IMAGE", "TEXT"]

            image_config_call = mock_types.ImageConfig.call_args
            assert image_config_call is not None, "ImageConfig was never built"
            assert image_config_call[1]["aspect_ratio"] == "9:16"
