"""CP-9 Phase 4 — GeminiVisionEvalNode unit tests.

Verifies the EvalNode adapter that bridges :func:`gemini_vision.score_artifact`
to the :class:`EvalNode` Protocol. All tests mock at the
``score_artifact`` boundary so no live API calls fire.
"""

import sys
import pathlib
from unittest.mock import patch

sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent.parent.parent))
from recoil.core.paths import ensure_pipeline_importable  # noqa: E402

ensure_pipeline_importable()

import pytest  # noqa: E402

from recoil.execution.providers import gemini_vision as _gv  # noqa: E402
from recoil.pipeline.core.eval import EvalContext, EvalNode, EvalResult  # noqa: E402
from recoil.pipeline.core.runners._gemini_vision_eval_node import (  # noqa: E402
    GeminiVisionEvalNode,
)


def _provider_result(**overrides):
    """Return a canned ``EvalProviderResult`` as ``score_artifact`` would.

    Any keyword passed in ``overrides`` replaces the matching default field.
    """
    defaults = {
        "score": 0.8,
        "reasoning": "Composed and well-lit.",
        "cost_usd": 0.0042,
        "model_used": "gemini-3.1-pro-preview",
        "request_id": "req_xyz",
        "raw_metadata": {
            "warnings": [],
            "finish_reason": "STOP",
            "usage": {"promptTokenCount": 1200, "candidatesTokenCount": 220},
        },
    }
    fields = {**defaults, **overrides}
    return _gv.EvalProviderResult(**fields)


def _ctx(modality="image", **overrides):
    """Return an ``EvalContext`` for *modality*; ``overrides`` replace fields.

    ``modality`` only shapes the default ``rubric`` and ``judge_id`` strings;
    ``EvalContext`` itself does not carry a modality.
    """
    defaults = {
        "target_artifact_path": pathlib.Path("/tmp/x.png"),
        "target_take": None,
        "prompt": "generation prompt",
        "rubric": f"Score this {modality} 0-1.",
        "judge_id": f"eval_{modality}_v1",
        "metadata": {},
    }
    return EvalContext(**{**defaults, **overrides})


# ── Construction ───────────────────────────────────────────────────────


def test_satisfies_eval_node_protocol():
    """An image node is an ``EvalNode`` and exposes the members it requires."""
    node = GeminiVisionEvalNode(artifact_modality="image")
    assert isinstance(node, EvalNode)
    for attr_name in ("judge_id", "model_used"):
        assert hasattr(node, attr_name), attr_name
    assert callable(node.evaluate)


def test_default_judge_id_per_modality():
    """Without an explicit judge_id, each modality gets ``eval_<modality>_v1``."""
    expected_ids = {
        "image": "eval_image_v1",
        "video": "eval_video_v1",
        "audio": "eval_audio_v1",
    }
    for modality, judge_id in expected_ids.items():
        node = GeminiVisionEvalNode(artifact_modality=modality)
        assert node.judge_id == judge_id


def test_custom_judge_id_honored():
    """An explicit ``judge_id`` replaces the per-modality default."""
    custom_id = "my_panel_judge_a"
    node = GeminiVisionEvalNode(artifact_modality="image", judge_id=custom_id)
    assert node.judge_id == custom_id


def test_default_model_id_matches_provider_default():
    """Nodes default to the provider's ``DEFAULT_MODEL_ID``, whose value is pinned."""
    node = GeminiVisionEvalNode(artifact_modality="image")
    provider_default = _gv.DEFAULT_MODEL_ID
    assert node.model_used == provider_default
    assert provider_default == "gemini-3.1-pro-preview"


def test_custom_model_id_honored():
    """An explicit ``model_id`` is reported back as ``model_used``."""
    custom_model = "gemini-future-x"
    node = GeminiVisionEvalNode(artifact_modality="image", model_id=custom_model)
    assert node.model_used == custom_model


def test_unsupported_modality_rejected_at_construction():
    """An unknown modality fails fast with ``ValueError`` at construction time."""
    with pytest.raises(ValueError) as excinfo:
        GeminiVisionEvalNode(artifact_modality="multispectral")
    excinfo.match("unsupported artifact_modality")


# ── evaluate() — happy path ────────────────────────────────────────────


def test_evaluate_calls_score_artifact_with_correct_args():
    """evaluate() forwards node config and context fields to score_artifact."""
    node = GeminiVisionEvalNode(artifact_modality="image")
    seen_kwargs = {}

    def _record(**kwargs):
        seen_kwargs.update(kwargs)
        return _provider_result()

    with patch.object(_gv, "score_artifact", side_effect=_record):
        result = node.evaluate(_ctx(modality="image"))

    assert isinstance(result, EvalResult)
    expected = {
        "artifact_path": pathlib.Path("/tmp/x.png"),
        "artifact_modality": "image",
        # The context's rubric, not its generation prompt, becomes the judge prompt.
        "prompt": "Score this image 0-1.",
        "judge_id": "eval_image_v1",
        "model_id": "gemini-3.1-pro-preview",
        "api_key_env_var": _gv.DEFAULT_AUTH_ENV_VAR,
        "timeout_s": _gv.DEFAULT_TIMEOUT_S,
    }
    assert {key: seen_kwargs[key] for key in expected} == expected
    assert seen_kwargs["transport"] is None


def test_evaluate_translates_provider_result_to_eval_result():
    """Provider fields map onto EvalResult; extras are surfaced in ``metadata``."""
    provider_result = _provider_result(
        score=0.55,
        reasoning="Mid-tier coherence.",
        cost_usd=0.0085,
        model_used="gemini-3.1-pro-preview",
        request_id="req_translation",
        raw_metadata={"warnings": ["score_clipped"], "finish_reason": "STOP"},
    )
    node = GeminiVisionEvalNode(artifact_modality="video")
    with patch.object(_gv, "score_artifact", return_value=provider_result):
        result = node.evaluate(_ctx(modality="video"))

    assert result.score == 0.55
    assert result.reasoning == "Mid-tier coherence."
    assert result.cost_usd == pytest.approx(0.0085)
    assert result.model_used == "gemini-3.1-pro-preview"
    assert result.judge_id == "eval_video_v1"
    # raw_metadata keys, the request id, and the node's modality all appear.
    expected_metadata = {
        "warnings": ["score_clipped"],
        "finish_reason": "STOP",
        "request_id": "req_translation",
        "artifact_modality": "video",
    }
    observed = {key: result.metadata[key] for key in expected_metadata}
    assert observed == expected_metadata


def test_evaluate_threads_transport_from_context_metadata():
    """A ``_transport`` entry in context metadata reaches score_artifact as-is."""
    transport_stub = object()
    seen_kwargs = {}

    def _record(**kwargs):
        seen_kwargs.update(kwargs)
        return _provider_result()

    node = GeminiVisionEvalNode(artifact_modality="image")
    with patch.object(_gv, "score_artifact", side_effect=_record):
        node.evaluate(_ctx(metadata={"_transport": transport_stub}))
    assert seen_kwargs["transport"] is transport_stub


def test_evaluate_uses_modality_from_constructor_not_context():
    """The node passes its constructor modality to score_artifact, not any
    field from the EvalContext (EvalContext does not carry modality).

    The context is deliberately built for a *different* modality ("image")
    than the node ("audio"). If both agreed, a node that inferred modality
    from the context's rubric or judge_id strings would still pass.
    """
    node = GeminiVisionEvalNode(artifact_modality="audio")
    captured = {}

    def _fake(**kwargs):
        captured.update(kwargs)
        return _provider_result()

    with patch.object(_gv, "score_artifact", side_effect=_fake):
        node.evaluate(_ctx(modality="image"))
    assert captured["artifact_modality"] == "audio"


# ── evaluate() — exception propagation ─────────────────────────────────


def test_evaluate_propagates_provider_error():
    """Adapter exceptions MUST propagate so the wrapping runner can map
    them into a failure-RunResult per audit § 12c R2.

    Identity is asserted, not just type. A type-only check would also pass
    if the node swallowed the provider error and raised a fresh instance.
    """
    node = GeminiVisionEvalNode(artifact_modality="image")
    provider_error = _gv.EvalServerError("Gemini 503")
    with patch.object(_gv, "score_artifact", side_effect=provider_error):
        with pytest.raises(_gv.EvalServerError) as excinfo:
            node.evaluate(_ctx())
    assert excinfo.value is provider_error


def test_evaluate_propagates_auth_error():
    """An auth failure from score_artifact surfaces unchanged from evaluate().

    The exact exception object is checked, so re-wrapping it in a new
    ``EvalAuthError`` would fail the test.
    """
    node = GeminiVisionEvalNode(artifact_modality="image")
    auth_error = _gv.EvalAuthError("missing key")
    with patch.object(_gv, "score_artifact", side_effect=auth_error):
        with pytest.raises(_gv.EvalAuthError) as excinfo:
            node.evaluate(_ctx())
    assert excinfo.value is auth_error


def test_evaluate_propagates_unknown_exception():
    """Non-provider exceptions are not wrapped or swallowed by evaluate().

    ``pytest.raises(RuntimeError)`` alone would also accept any wrapper
    exception that subclasses ``RuntimeError``, so identity is pinned too.
    """
    node = GeminiVisionEvalNode(artifact_modality="image")
    unexpected = RuntimeError("boom")
    with patch.object(_gv, "score_artifact", side_effect=unexpected):
        with pytest.raises(RuntimeError) as excinfo:
            node.evaluate(_ctx())
    assert excinfo.value is unexpected


# ── Custom timeout / api key threading ─────────────────────────────────


def test_custom_timeout_threaded():
    """A constructor ``timeout_s`` is forwarded verbatim to score_artifact."""
    custom_timeout = 30.0
    seen_kwargs = {}

    def _record(**kwargs):
        seen_kwargs.update(kwargs)
        return _provider_result()

    node = GeminiVisionEvalNode(artifact_modality="image", timeout_s=custom_timeout)
    with patch.object(_gv, "score_artifact", side_effect=_record):
        node.evaluate(_ctx())
    assert seen_kwargs["timeout_s"] == custom_timeout


def test_custom_api_key_env_var_threaded():
    """A constructor ``api_key_env_var`` is forwarded verbatim to score_artifact."""
    env_var_name = "MY_CUSTOM_KEY"
    seen_kwargs = {}

    def _record(**kwargs):
        seen_kwargs.update(kwargs)
        return _provider_result()

    node = GeminiVisionEvalNode(
        artifact_modality="image", api_key_env_var=env_var_name
    )
    with patch.object(_gv, "score_artifact", side_effect=_record):
        node.evaluate(_ctx())
    assert seen_kwargs["api_key_env_var"] == env_var_name
