"""Phase C — failure_mode canonical taxonomy tests."""

import pytest
from recoil.pipeline.core.failure_mode import (
    FailureMode,
    FailureCategory,
    classify_failure,
    failure_category_for,
    UnknownFailureEscalation,
    TRANSIENT_PATTERN_STRINGS,
    TRANSIENT_HTTP_CODES,
)


class TestClassifyFailure:
    """Checks of ``classify_failure`` for error-text and HTTP-status inputs.

    As exercised here, ``classify_failure`` returns a ``(mode, confidence)``
    pair, and raises ``UnknownFailureEscalation`` for unrecognised input
    unless ``escalate_unknown=False`` is passed.
    """

    # --- classification from error text ---------------------------------

    def test_429_string_match_returns_transient(self):
        """A 429 rate-limit message is TRANSIENT with confidence in [0.85, 1.0]."""
        mode, conf = classify_failure(error_text="429 rate limit exceeded")
        assert mode is FailureMode.TRANSIENT
        assert 0.85 <= conf <= 1.0

    def test_503_string_match(self):
        """A 503 message is TRANSIENT."""
        mode, _ = classify_failure(error_text="503 service unavailable")
        assert mode is FailureMode.TRANSIENT

    def test_504_string_match_was_strategy_registry_only(self):
        """A 504 message is TRANSIENT."""
        # NOTE(review): the test name suggests this pattern previously lived
        # only in a "strategy registry" pattern list - confirm the history.
        mode, _ = classify_failure(error_text="504 gateway timeout")
        assert mode is FailureMode.TRANSIENT

    def test_econnreset_was_retry_dispatcher_only(self):
        """An ECONNRESET message is TRANSIENT."""
        # NOTE(review): the test name suggests this pattern previously lived
        # only in a "retry dispatcher" pattern list - confirm the history.
        mode, _ = classify_failure(error_text="ECONNRESET on socket")
        assert mode is FailureMode.TRANSIENT

    def test_500_was_retry_dispatcher_only(self):
        """A 500 message is TRANSIENT."""
        mode, _ = classify_failure(error_text="500 internal server error")
        assert mode is FailureMode.TRANSIENT

    def test_content_filter_classification(self):
        """A content-policy message is CONTENT_FILTER_HARD_BLOCK."""
        mode, _ = classify_failure(error_text="content policy violation")
        assert mode is FailureMode.CONTENT_FILTER_HARD_BLOCK

    def test_budget_classification(self):
        """A 402 insufficient-balance message is COST_OVERRUN."""
        mode, _ = classify_failure(error_text="402 insufficient balance")
        assert mode is FailureMode.COST_OVERRUN

    def test_schema_classification(self):
        """A 422 validation message is PROMPT_DURATION_MISMATCH."""
        mode, _ = classify_failure(error_text="422 input should be valid")
        assert mode is FailureMode.PROMPT_DURATION_MISMATCH

    # --- classification from HTTP status alone --------------------------

    def test_http_status_500_returns_transient(self):
        """HTTP status 500 with no error text is TRANSIENT."""
        mode, _ = classify_failure(http_status=500)
        assert mode is FailureMode.TRANSIENT

    def test_http_status_429_returns_transient(self):
        """HTTP status 429 with no error text is TRANSIENT."""
        mode, _ = classify_failure(http_status=429)
        assert mode is FailureMode.TRANSIENT

    def test_http_status_422_returns_schema(self):
        """HTTP status 422 with no error text is PROMPT_DURATION_MISMATCH."""
        mode, _ = classify_failure(http_status=422)
        assert mode is FailureMode.PROMPT_DURATION_MISMATCH

    # --- unknown input and escalation -----------------------------------

    def test_unknown_input_escalates_by_default(self):
        """Unrecognised text raises UnknownFailureEscalation by default."""
        with pytest.raises(UnknownFailureEscalation):
            classify_failure(error_text="garbage random string nobody knows")

    def test_unknown_input_returns_unknown_when_escalate_false(self):
        """With escalate_unknown=False the result is (UNKNOWN, 0.0), no raise."""
        mode, conf = classify_failure(
            error_text="garbage random string",
            escalate_unknown=False,
        )
        assert mode is FailureMode.UNKNOWN
        assert conf == 0.0

    def test_empty_input_escalates(self):
        """Calling with no arguments raises UnknownFailureEscalation."""
        with pytest.raises(UnknownFailureEscalation):
            classify_failure()

    def test_escalation_carries_caller_context(self):
        """The raised escalation exposes ``caller`` and ``error_text``."""
        # pytest.raises already fails the test when nothing is raised, so the
        # manual try/except/else + pytest.fail bookkeeping is unnecessary, and
        # this matches the other escalation tests in this class.
        with pytest.raises(UnknownFailureEscalation) as excinfo:
            classify_failure(error_text="garbage", caller="test_caller")
        assert excinfo.value.caller == "test_caller"
        assert excinfo.value.error_text == "garbage"


class TestFailureCategoryFor:
    """Checks of the ``FailureMode`` -> ``FailureCategory`` mapping function."""

    def test_transient_maps_to_transient(self):
        """TRANSIENT mode maps to the TRANSIENT category."""
        category = failure_category_for(FailureMode.TRANSIENT)
        assert category is FailureCategory.TRANSIENT

    def test_content_filter_maps_to_content_filter(self):
        """CONTENT_FILTER_HARD_BLOCK mode maps to the CONTENT_FILTER category."""
        category = failure_category_for(FailureMode.CONTENT_FILTER_HARD_BLOCK)
        assert category is FailureCategory.CONTENT_FILTER

    def test_identity_drift_maps_to_gate_identity(self):
        """IDENTITY_DRIFT mode maps to the GATE_IDENTITY category."""
        category = failure_category_for(FailureMode.IDENTITY_DRIFT)
        assert category is FailureCategory.GATE_IDENTITY

    def test_none_raises(self):
        """FailureMode.NONE is rejected with ValueError."""
        with pytest.raises(ValueError):
            failure_category_for(FailureMode.NONE)

    def test_unknown_escalates(self):
        """FailureMode.UNKNOWN is rejected with UnknownFailureEscalation."""
        with pytest.raises(UnknownFailureEscalation):
            failure_category_for(FailureMode.UNKNOWN)

    def test_mapping_is_total_over_failure_mode(self):
        """Every mode other than NONE and UNKNOWN yields a FailureCategory."""
        # NONE and UNKNOWN are the two modes the tests above expect to raise,
        # so they are left out of the totality sweep.
        unmapped = (FailureMode.NONE, FailureMode.UNKNOWN)
        mappable_modes = [m for m in FailureMode if m not in unmapped]
        for candidate in mappable_modes:
            assert isinstance(failure_category_for(candidate), FailureCategory)


class TestPatternUnion:
    """Membership checks on the exported transient pattern and code constants."""

    @staticmethod
    def _some_pattern_contains(token):
        """Return True if ``token`` is a substring of any transient pattern."""
        return any(token in pattern for pattern in TRANSIENT_PATTERN_STRINGS)

    def test_429_present(self):
        """At least one transient pattern string contains "429"."""
        assert self._some_pattern_contains("429")

    def test_econnreset_present(self):
        """At least one transient pattern string contains "ECONNRESET"."""
        assert self._some_pattern_contains("ECONNRESET")

    def test_504_present(self):
        """At least one transient pattern string contains "504"."""
        assert self._some_pattern_contains("504")

    def test_500_present(self):
        """At least one transient pattern string contains "500"."""
        assert self._some_pattern_contains("500")

    def test_http_codes_cover_5xx(self):
        """Codes 500 through 504 are all in TRANSIENT_HTTP_CODES."""
        expected_codes = (500, 501, 502, 503, 504)
        missing = [c for c in expected_codes if c not in TRANSIENT_HTTP_CODES]
        assert not missing

    def test_429_not_in_http_codes(self):
        """429 is not a member of TRANSIENT_HTTP_CODES."""
        # NOTE(review): TestClassifyFailure expects classify_failure(http_status=429)
        # to be TRANSIENT, so 429 is presumably handled outside this constant -
        # confirm against the implementation.
        assert 429 not in TRANSIENT_HTTP_CODES
