"""Tests for lib/bible_loader.py and bible-aware prompt_engine functions.

Covers:
  5a. Bible integrity (structure, confidence values, model completeness)
  5b. Accessor tests (get_model_rules, supports_negative_prompt, etc.)
  5c. Length enforcement (_enforce_prompt_length)
  5d. AR validation (validate_start_frame_ar)
  5e. Generic builder (build_prompt_from_bible)
  5f. Edge cases (no-period truncation fallback, malformed target AR)
"""

import logging

import pytest

from recoil.pipeline._lib.bible_loader import (
    get_ar_rules,
    get_frame_rules,
    get_gotchas,
    get_i2v_ar_behavior,
    get_model_rules,
    get_optimal_word_range,
    get_prompt_rules,
    get_ref_rules,
    load_bible,
    reload_bible,
    supports_negative_prompt,
)
from recoil.pipeline._lib.prompt_engine import (
    _enforce_prompt_length,
    build_prompt_from_bible,
    validate_start_frame_ar,
)


# ══════════════════════════════════════════════════════════════════════
# 5a. Bible Integrity Tests
# ══════════════════════════════════════════════════════════════════════

# Every model key PROMPT_BIBLE.yaml is expected to define. The order matters
# only for readability and for the parametrized smoke test that reuses it.
EXPECTED_MODELS = [
    "seedream-v4.5",
    "seedream-v5-lite",
    "gemini-3-pro-image-preview",
    "gemini-3.1-flash-image-preview",
    "seeddance-2.0",
    "kling-v3",
    "wan-2.7-i2v",
    "wan-2.7-r2v",
    "veo-3.1",
    "wan-2.2",
    "z-image-turbo",
    "flux-2",
]

# Top-level sections each expected model entry must contain.
REQUIRED_SECTIONS = "meta prompt refs aspect_ratio".split()

# Allowed values for a model's meta.confidence field.
VALID_CONFIDENCE = {"UNTESTED", "INFERRED", "CONFIRMED"}

# Allowed values for aspect_ratio.i2v_behavior. None is accepted because the
# integrity test reads the key with .get(), so models without it yield None.
VALID_I2V_BEHAVIORS = {None, "matches_start_frame", "respects_param"}


class TestBibleIntegrity:
    """5a — structural checks on PROMPT_BIBLE.yaml as returned by load_bible()."""

    @staticmethod
    def _expected_entries(bible):
        """Yield (name, data) pairs for bible entries listed in EXPECTED_MODELS.

        Entries are produced in the bible's own iteration order; extra models
        in the bible are ignored, and missing ones are simply absent here
        (test_all_12_models_present is what reports those).
        """
        return (
            (name, data)
            for name, data in bible.items()
            if name in EXPECTED_MODELS
        )

    def test_all_12_models_present(self):
        loaded = load_bible()
        for expected in EXPECTED_MODELS:
            assert expected in loaded, f"Missing model: {expected}"

    def test_every_model_has_required_sections(self):
        for name, data in self._expected_entries(load_bible()):
            for section in REQUIRED_SECTIONS:
                assert section in data, (
                    f"{name} missing required section: {section}"
                )

    def test_all_confidence_values_valid(self):
        for name, data in self._expected_entries(load_bible()):
            level = data["meta"]["confidence"]
            assert level in VALID_CONFIDENCE, (
                f"{name} has invalid confidence: {level}"
            )

    def test_no_negative_prompt_without_param(self):
        """A truthy negative_prompt flag requires a non-null negative_param."""
        for name, data in self._expected_entries(load_bible()):
            rules = data["prompt"]
            if not rules.get("negative_prompt"):
                continue
            assert rules.get("negative_param") is not None, (
                f"{name} has negative_prompt=true but null negative_param"
            )

    def test_all_i2v_behavior_values_valid(self):
        for name, data in self._expected_entries(load_bible()):
            mode = data["aspect_ratio"].get("i2v_behavior")
            assert mode in VALID_I2V_BEHAVIORS, (
                f"{name} has invalid i2v_behavior: {mode}"
            )

    def test_kling_v3_mode_specific_optimal_words(self):
        """kling-v3 must define separate i2v and t2v word ranges."""
        words = load_bible()["kling-v3"]["prompt"]["optimal_words"]
        assert isinstance(words, dict), "kling-v3 optimal_words should be a dict"
        for mode in ("i2v", "t2v"):
            assert mode in words, f"kling-v3 missing {mode} optimal_words"
        # The upper bound for i2v must be tighter than the one for t2v.
        i2v_max = words["i2v"][1]
        t2v_max = words["t2v"][1]
        assert i2v_max < t2v_max, (
            "kling-v3 i2v max should be less than t2v max"
        )


# ══════════════════════════════════════════════════════════════════════
# 5b. Accessor Tests
# ══════════════════════════════════════════════════════════════════════


class TestGetModelRules:
    """get_model_rules: rule dict for known models, None for unknown ones."""

    def test_known_model_returns_dict(self):
        rules = get_model_rules("kling-v3")
        assert isinstance(rules, dict)
        for key in ("meta", "prompt"):
            assert key in rules

    def test_unknown_model_returns_none(self):
        assert get_model_rules("nonexistent-model-xyz") is None


class TestSupportsNegativePrompt:
    """supports_negative_prompt must return the bool singletons True/False."""

    def test_kling_supports_negative(self):
        flag = supports_negative_prompt("kling-v3")
        assert flag is True

    def test_seedream_no_negative(self):
        flag = supports_negative_prompt("seedream-v4.5")
        assert flag is False

    def test_seeddance_no_negative(self):
        flag = supports_negative_prompt("seeddance-2.0")
        assert flag is False


class TestGetOptimalWordRange:
    """get_optimal_word_range: (min, max) word bounds, optionally per mode."""

    def test_default_mode(self):
        result = get_optimal_word_range("seedream-v4.5")
        assert isinstance(result, tuple)
        assert len(result) == 2
        assert result[0] < result[1]

    def test_kling_i2v_mode(self):
        result = get_optimal_word_range("kling-v3", mode="i2v")
        assert result == (15, 40)

    def test_kling_t2v_mode(self):
        result = get_optimal_word_range("kling-v3", mode="t2v")
        assert result == (50, 100)

    def test_kling_default_mode(self):
        result = get_optimal_word_range("kling-v3", mode="default")
        assert result == (50, 100)

    def test_unknown_mode_falls_back(self):
        """Requesting a mode that doesn't exist falls back to the default range."""
        result = get_optimal_word_range("kling-v3", mode="nonexistent")
        assert isinstance(result, tuple)
        # Compare against the accessor's own default instead of only checking
        # the shape: any 2-tuple (e.g. the i2v range) would otherwise pass.
        assert result == get_optimal_word_range("kling-v3", mode="default")

    def test_unknown_model_raises(self):
        with pytest.raises(KeyError, match="Unknown model"):
            get_optimal_word_range("nonexistent-model-xyz")


class TestGetI2vArBehavior:
    """get_i2v_ar_behavior: how each video model treats the AR parameter in I2V."""

    def test_seeddance_respects_param(self):
        behavior = get_i2v_ar_behavior("seeddance-2.0")
        assert behavior == "respects_param"

    def test_kling_matches_start_frame(self):
        behavior = get_i2v_ar_behavior("kling-v3")
        assert behavior == "matches_start_frame"

    def test_veo_respects_param(self):
        behavior = get_i2v_ar_behavior("veo-3.1")
        assert behavior == "respects_param"

    def test_wan_i2v_matches_start_frame(self):
        behavior = get_i2v_ar_behavior("wan-2.7-i2v")
        assert behavior == "matches_start_frame"

    def test_image_model_returns_none(self):
        behavior = get_i2v_ar_behavior("seedream-v4.5")
        assert behavior is None


class TestGetGotchas:
    """get_gotchas: non-empty list for documented models, KeyError otherwise."""

    def test_seedream_has_gotchas(self):
        notes = get_gotchas("seedream-v4.5")
        assert isinstance(notes, list)
        assert notes, "expected at least one gotcha"

    def test_kling_has_gotchas(self):
        notes = get_gotchas("kling-v3")
        assert isinstance(notes, list)
        assert notes, "expected at least one gotcha"

    def test_unknown_model_raises(self):
        with pytest.raises(KeyError, match="Unknown model"):
            get_gotchas("nonexistent-model-xyz")


class TestOtherAccessors:
    """Section accessors: prompt, refs, aspect_ratio and frame rules."""

    def test_get_prompt_rules_returns_dict(self):
        section = get_prompt_rules("kling-v3")
        assert isinstance(section, dict)
        for key in ("optimal_words", "style"):
            assert key in section

    def test_get_prompt_rules_unknown_raises(self):
        with pytest.raises(KeyError, match="Unknown model"):
            get_prompt_rules("nonexistent-model-xyz")

    def test_get_ref_rules_returns_dict(self):
        section = get_ref_rules("seedream-v4.5")
        assert isinstance(section, dict)
        assert "max_count" in section

    def test_get_ar_rules_returns_dict(self):
        section = get_ar_rules("kling-v3")
        assert isinstance(section, dict)
        assert "supported" in section

    def test_get_frame_rules_video_model(self):
        section = get_frame_rules("kling-v3")
        assert isinstance(section, dict)
        assert "start_param" in section

    def test_get_frame_rules_image_model_returns_none(self):
        # Image models have no start/end frame configuration at all.
        assert get_frame_rules("seedream-v4.5") is None


class TestReloadBible:
    def test_reload_clears_cache(self):
        """reload_bible() clears the cache; re-loading yields identical content."""
        # Load once to populate the cache.
        before = load_bible()
        reload_bible()
        after = load_bible()
        # Compare key sets first for a readable failure, then the full
        # content — the previous check only compared keys, so a reload that
        # returned different per-model data would have gone unnoticed.
        assert set(before.keys()) == set(after.keys())
        assert after == before


# ══════════════════════════════════════════════════════════════════════
# 5c. Length Enforcement Tests
# ══════════════════════════════════════════════════════════════════════


class TestEnforcePromptLength:
    """_enforce_prompt_length: pass-through, word-count warnings, truncation."""

    def test_short_kling_i2v_passes(self):
        """A short I2V prompt within optimal range passes without truncation."""
        text = "Slow pan left revealing the corridor. Subject turns."
        assert _enforce_prompt_length(text, "kling-v3", mode="i2v") == text

    def test_500_word_kling_i2v_logs_warning(self, caplog):
        """A 500-word prompt for Kling I2V logs a WARNING."""
        text = " ".join("word" for _ in range(500))
        engine_logger = "recoil.pipeline._lib.prompt_engine"
        with caplog.at_level(logging.WARNING, logger=engine_logger):
            _enforce_prompt_length(text, "kling-v3", mode="i2v")
        hits = [
            msg for msg in caplog.messages if "words" in msg and "optimal" in msg
        ]
        assert hits, "Expected warning about word count exceeding optimal range"

    def test_exceeding_max_chars_truncates(self):
        """Prompt exceeding max_chars gets truncated at sentence boundary."""
        # kling-v3 has max_chars=2500. 100 copies of a 43-char sentence
        # (4300 chars) exceed it while offering many period boundaries.
        limit = 2500
        text = "This is a test sentence with enough words. " * 100
        assert len(text) > limit
        trimmed = _enforce_prompt_length(text, "kling-v3")
        assert len(trimmed) <= limit
        # Landing on a sentence boundary means the result ends with a period.
        assert trimmed.endswith(".")

    def test_unknown_model_passes_through(self):
        """Unknown model returns prompt unchanged."""
        text = "Some prompt text."
        assert _enforce_prompt_length(text, "nonexistent-model-xyz") == text


# ══════════════════════════════════════════════════════════════════════
# 5d. AR Validation Tests
# ══════════════════════════════════════════════════════════════════════


class TestValidateStartFrameAr:
    """validate_start_frame_ar: warns only when a model ties output AR to the frame."""

    def test_none_path_returns_empty(self):
        result = validate_start_frame_ar(None, "kling-v3", "9:16")
        assert result == []

    def test_seeddance_returns_empty(self):
        """Seedance respects_param — no AR mismatch possible."""
        result = validate_start_frame_ar("/some/path.png", "seeddance-2.0", "9:16")
        assert result == []

    def test_kling_with_square_image_returns_warning(self, tmp_path):
        """A 1:1 image with Kling (matches_start_frame) and 9:16 target = warning."""
        # Skips the test (rather than erroring) when Pillow is not installed.
        Image = pytest.importorskip("PIL.Image")

        # Create a 100x100 (1:1) test image
        img_path = tmp_path / "square.png"
        Image.new("RGB", (100, 100), color="red").save(img_path)

        result = validate_start_frame_ar(img_path, "kling-v3", "9:16")
        assert len(result) == 1
        assert "mismatch" in result[0].lower()

    def test_kling_with_matching_ar_no_warning(self, tmp_path):
        """A 9:16 image with Kling and 9:16 target = no warning."""
        Image = pytest.importorskip("PIL.Image")

        # Create a 9:16 test image (90x160)
        img_path = tmp_path / "portrait.png"
        Image.new("RGB", (90, 160), color="blue").save(img_path)

        result = validate_start_frame_ar(img_path, "kling-v3", "9:16")
        assert result == []


# ══════════════════════════════════════════════════════════════════════
# 5e. Generic Builder Tests
# ══════════════════════════════════════════════════════════════════════


class TestBuildPromptFromBible:
    """build_prompt_from_bible: style-aware prompt assembly per model."""

    def test_seedream_returns_str_with_description(self):
        result = build_prompt_from_bible(
            model="seedream-v4.5",
            scene_description="A neon-lit alley at night",
        )
        assert isinstance(result, str)
        assert "neon-lit alley" in result

    def test_z_image_turbo_shorter_output(self):
        """Z-Image Turbo uses keywords style — should produce shorter output."""
        scene = "A neon-lit alley at night. Rain falling on concrete."
        turbo_result = build_prompt_from_bible(
            model="z-image-turbo",
            scene_description=scene,
        )
        seedream_result = build_prompt_from_bible(
            model="seedream-v4.5",
            scene_description=scene,
        )
        assert isinstance(turbo_result, str)
        assert len(turbo_result) > 0
        # Keywords style splits on periods/commas into tag-like fragments
        assert "," in turbo_result
        # Compare against the seedream prompt for the same scene; without
        # this, the "shorter output" claim above is never checked.
        assert len(turbo_result) <= len(seedream_result), (
            f"z-image-turbo output ({len(turbo_result)} chars) is longer than "
            f"seedream-v4.5 output ({len(seedream_result)} chars)"
        )

    @pytest.mark.parametrize("model", EXPECTED_MODELS)
    def test_all_models_produce_nonempty_str(self, model):
        """Smoke test: every model produces a non-trivial string."""
        result = build_prompt_from_bible(
            model=model,
            scene_description="A neon-lit scene at night",
        )
        assert isinstance(result, str)
        # The message reports the actual length; the old text ("shorter than
        # 5 chars") was wrong for an output of exactly 5 chars.
        assert len(result) > 5, (
            f"{model} produced output of {len(result)} chars; expected more than 5"
        )

    def test_unknown_model_raises(self):
        with pytest.raises(KeyError, match="Unknown model"):
            build_prompt_from_bible(
                model="nonexistent-model-xyz",
                scene_description="test",
            )

    def test_cot_prose_has_step_by_step_prefix(self):
        """cot_prose style should produce 'Let's think step by step' prefix."""
        result = build_prompt_from_bible(
            model="seedream-v5-lite",
            scene_description="A warrior in a dark forest",
        )
        assert result.startswith("Let's think step by step")

    def test_kwargs_appear_in_output(self):
        """Passing characters, location, camera kwargs should include them in output."""
        result = build_prompt_from_bible(
            model="seedream-v4.5",
            scene_description="A neon-lit alley",
            characters=["Marcus", "Elena"],
            location="downtown Tokyo",
            camera="wide shot",
        )
        assert "Marcus" in result
        assert "Elena" in result
        assert "downtown Tokyo" in result
        assert "wide shot" in result


# ══════════════════════════════════════════════════════════════════════
# 5f. Additional Edge Case Tests
# ══════════════════════════════════════════════════════════════════════


class TestEnforcePromptLengthEdgeCases:
    def test_truncation_no_period_uses_rstrip_fallback(self):
        """When the prompt has no periods, truncation falls back to rstrip."""
        # kling-v3 has max_chars=2500
        # Build a prompt with no sentence boundaries (no periods)
        prompt = "word " * 600  # 3000 chars, no periods
        assert len(prompt) > 2500
        result = _enforce_prompt_length(prompt, "kling-v3")
        assert len(result) <= 2500
        # No period found, so it should fall back to rstrip
        assert not result.endswith(".")
        # Actually verify the fallback: something survives, no trailing
        # whitespace is left behind, and truncation only cuts from the end.
        assert result
        assert result == result.rstrip()
        assert prompt.startswith(result)


class TestValidateStartFrameArEdgeCases:
    def test_malformed_target_ar_returns_warning(self, tmp_path):
        """A malformed target_ar string (e.g. 'invalid') should return a parse error, not crash."""
        # Skips the test (rather than erroring) when Pillow is not installed.
        Image = pytest.importorskip("PIL.Image")

        # Create a small test image in pytest's per-test temp dir, which
        # pytest cleans up itself (no manual mkstemp/unlink needed).
        img_path = tmp_path / "square.png"
        Image.new("RGB", (100, 100), color="green").save(img_path)

        # Pass a str path, as before, to keep covering string-path input.
        result = validate_start_frame_ar(str(img_path), "kling-v3", "invalid")
        assert len(result) == 1
        assert "could not parse" in result[0].lower()
