"""Tests for Flash enrichment wrapper — ENV bypass, config bypass, validation."""

from unittest.mock import patch


def test_env_bypass():
    """With ``is_env=True`` the prompt must come back unchanged, tagged ``bypass_env``."""
    from recoil.pipeline._lib.prompt_engine import enrich_prompt

    original_prompt = "empty corridor"

    returned_prompt, version_tag = enrich_prompt(
        original_prompt, "nbp", [], is_env=True
    )

    assert returned_prompt == original_prompt
    assert version_tag == "bypass_env"


def test_config_bypass():
    """A project config with ``skip_flash_enrichment`` set must bypass enrichment.

    The prompt is expected back untouched with the ``bypass_config`` version tag.
    """
    from recoil.pipeline._lib.prompt_engine import enrich_prompt

    skip_config = {"skip_flash_enrichment": True}
    outcome = enrich_prompt("test", "nbp", [], project_config=skip_config)

    returned_prompt, version_tag = outcome
    assert returned_prompt == "test"
    assert version_tag == "bypass_config"


def test_missing_prompt_file_bypass():
    """An unknown model name must yield the original prompt and ``bypass_no_prompt_file``."""
    from recoil.pipeline._lib.prompt_engine import enrich_prompt

    original_prompt = "test prompt"
    locked_terms = ["term1"]

    # "nonexistent_model" is not patched here, so this runs against whatever
    # prompt files the real loader can (not) find.
    returned_prompt, version_tag = enrich_prompt(
        original_prompt, "nonexistent_model", locked_terms
    )

    assert returned_prompt == original_prompt
    assert version_tag == "bypass_no_prompt_file"


# ══════════════════════════════════════════════════════════════════════
# MOCK-BASED TESTS — exercise Flash API call + LOCKED_TERMS validation
# ══════════════════════════════════════════════════════════════════════


def _patch_enrichment(flash_return, prompt_files=None):
    """Build the two patchers used by the mock-based enrichment tests.

    Nothing is patched until the returned objects are entered (``with``) or
    started by the caller.

    Args:
        flash_return: Behaviour of the patched ``_call_flash_enrichment``.
            A list supplies sequential return values (first call, retry
            call, etc.); a callable or an exception instance is installed as
            the mock's ``side_effect``; any other value (typically a string)
            is returned by every call.
        prompt_files: Optional dict mapping filename substrings to content.
            An exact filename key wins; otherwise the first key (in dict
            order) contained in the requested filename is used, and a
            filename matching no key yields "instructions". When omitted,
            filenames containing "base" yield "base instructions" and every
            other filename yields "model instructions".

    Returns:
        Tuple of (flash_patcher, prompt_patcher) context-manager patches.
    """
    loader_target = "recoil.pipeline._lib.prompt_engine.load_prompt_file"
    flash_target = "recoil.pipeline._lib.prompt_engine._call_flash_enrichment"

    if prompt_files is None:

        def _load(filename):
            if "base" in filename:
                return "base instructions"
            return "model instructions"

    else:

        def _load(filename):
            # Exact keys keep working exactly as before.
            if filename in prompt_files:
                return prompt_files[filename]
            # Documented contract: keys are filename *substrings*.
            name = str(filename)
            for fragment, content in prompt_files.items():
                if isinstance(fragment, str) and fragment in name:
                    return content
            return "instructions"

    prompt_patch = patch(loader_target, side_effect=_load)

    # Lists (sequential results), callables and exception instances all have
    # side_effect semantics on a Mock; passing an exception instance as
    # return_value would hand the exception back instead of raising it.
    if isinstance(flash_return, (list, BaseException)) or callable(flash_return):
        flash_patch = patch(flash_target, side_effect=flash_return)
    else:
        flash_patch = patch(flash_target, return_value=flash_return)

    return flash_patch, prompt_patch


class TestEnrichPromptMocked:
    """Mock-based tests that exercise the Flash API call path."""

    @staticmethod
    def _run_enrichment(flash_return, locked_terms):
        """Call ``enrich_prompt("corridor scene", "nbp", ...)`` with both patches active.

        Args:
            flash_return: Forwarded to ``_patch_enrichment`` to script the
                patched Flash call.
            locked_terms: Locked-terms list passed to ``enrich_prompt``.

        Returns:
            Tuple of (result, version, flash_mock) so each test can assert on
            the output and on how often / how the Flash mock was called.
        """
        from recoil.pipeline._lib.prompt_engine import enrich_prompt

        flash_patcher, prompt_patcher = _patch_enrichment(flash_return)
        with flash_patcher as flash_mock, prompt_patcher:
            result, version = enrich_prompt(
                "corridor scene",
                "nbp",
                locked_terms,
            )
        return result, version, flash_mock

    def test_happy_path(self):
        """A response containing every locked term is accepted after one call."""
        expected = "A cinematic wide shot, amber light, of the ruined corridor"

        result, version, flash_mock = self._run_enrichment(
            expected, ["amber light", "ruined corridor"]
        )

        assert result == expected
        assert version == "v1.0"
        # No retry is expected when the first response is valid.
        assert flash_mock.call_count == 1

    def test_retry_on_locked_terms_violation(self):
        """First response drops a locked term; the retry (temperature 0.0) succeeds."""
        attempts = [
            # Missing "ruined corridor".
            "A cinematic wide shot, amber light, of the hallway",
            # Contains both locked terms.
            "A cinematic wide shot, amber light, of the ruined corridor",
        ]

        result, version, flash_mock = self._run_enrichment(
            attempts, ["amber light", "ruined corridor"]
        )

        assert result == attempts[1]
        assert version == "v1.0"
        assert flash_mock.call_count == 2
        # The second (retry) call must pass temperature=0.0 as a keyword.
        second_call = flash_mock.call_args_list[1]
        assert second_call.kwargs.get("temperature") == 0.0

    def test_timeout_fallback(self):
        """An empty Flash result (simulating timeout/error) falls back to the original prompt."""
        result, version, flash_mock = self._run_enrichment("", ["amber light"])

        assert result == "corridor scene"
        assert version == "fallback_error"
        assert flash_mock.call_count == 1

    def test_persistent_locked_terms_failure(self):
        """Both attempts drop a locked term, so the original prompt is returned."""
        # Lacks "ruined corridor" on the first call and on the retry.
        missing_term = "A cinematic wide shot, amber light, of the hallway"

        result, version, flash_mock = self._run_enrichment(
            [missing_term] * 2, ["amber light", "ruined corridor"]
        )

        assert result == "corridor scene"
        assert version == "fallback_validation"
        assert flash_mock.call_count == 2

    def test_malformed_response_fallback(self):
        """An empty Flash result (simulating a malformed/non-text response) triggers fallback.

        The empty string stands in for a malformed response; the test expects
        ``fallback_error`` after a single Flash call, i.e. no retry.
        """
        result, version, flash_mock = self._run_enrichment("", ["amber light"])

        assert result == "corridor scene"
        assert version == "fallback_error"
        assert flash_mock.call_count == 1
