"""Phase C — cost canonical helpers tests."""

import pytest
import logging
from unittest.mock import patch
from recoil.pipeline.core.cost import (
    CostMissingError,
    read_cost_from_result,
    read_cost_from_result_safe,
    read_cost_from_record_safe,
    compute_cost,
    get_cost,
)


class FakeResult:
    """Minimal stand-in result object passed to the cost helpers under test.

    ``id`` and ``metadata`` are always set. A ``cost_usd`` attribute is
    created only when ``cost_usd_attr`` is not ``None``, so an instance can
    also model a result that has no such attribute at all.
    """

    def __init__(self, id="test", metadata=None, cost_usd_attr=None):
        self.id, self.metadata = id, metadata
        if cost_usd_attr is None:
            # Leave ``cost_usd`` undefined so ``hasattr`` reports False.
            return
        self.cost_usd = cost_usd_attr


class TestReadCostFromResult:
    """Tests for ``read_cost_from_result``, which raises when no cost is found."""

    def test_reads_from_metadata(self):
        """A ``cost_usd`` entry in ``metadata`` is returned as the cost."""
        r = FakeResult(metadata={"cost_usd": 0.42})
        assert read_cost_from_result(r) == 0.42

    def test_reads_from_attribute_when_no_metadata(self):
        """An object without ``metadata`` yields its ``cost_usd`` attribute."""
        class R:
            id = "x"
            cost_usd = 0.7
        assert read_cost_from_result(R()) == 0.7

    def test_raises_when_metadata_missing_cost(self):
        """Metadata lacking a ``cost_usd`` key raises ``CostMissingError``."""
        r = FakeResult(metadata={"other": 1})
        with pytest.raises(CostMissingError):
            read_cost_from_result(r)

    def test_raises_when_metadata_cost_is_none(self):
        """An explicit ``None`` cost in metadata raises ``CostMissingError``."""
        r = FakeResult(metadata={"cost_usd": None})
        with pytest.raises(CostMissingError):
            read_cost_from_result(r)

    def test_raises_when_no_metadata_and_no_attr(self):
        """An object with neither ``metadata`` nor ``cost_usd`` raises."""
        class R:
            id = "x"
        with pytest.raises(CostMissingError):
            read_cost_from_result(R())

    def test_zero_is_legitimate(self):
        """A cost of ``0.0`` is returned as a value, not treated as missing."""
        r = FakeResult(metadata={"cost_usd": 0.0})
        assert read_cost_from_result(r) == 0.0

    def test_error_carries_result_id(self):
        """The raised ``CostMissingError`` exposes the result's ``id``."""
        r = FakeResult(id="EP001_SH02", metadata={})
        # pytest.raises fails the test when nothing is raised; a bare
        # try/except would pass silently in that case and never check the id.
        with pytest.raises(CostMissingError) as excinfo:
            read_cost_from_result(r)
        assert excinfo.value.result_id == "EP001_SH02"


class TestReadCostFromResultSafe:
    """Checks for ``read_cost_from_result_safe`` and its 0.0 fallback."""

    @staticmethod
    def _fallback_logged(caplog):
        """Report whether any captured record mentions ``FALLBACK_FIRED``."""
        for record in caplog.records:
            if "FALLBACK_FIRED" in record.message:
                return True
        return False

    def test_returns_zero_on_missing(self, caplog):
        """Empty metadata yields ``0.0`` and a ``FALLBACK_FIRED`` log record."""
        result = FakeResult(metadata={})
        with caplog.at_level(logging.WARNING):
            cost = read_cost_from_result_safe(result)
        assert cost == 0.0
        assert self._fallback_logged(caplog)

    def test_returns_value_on_present(self):
        """A present ``cost_usd`` is returned unchanged."""
        result = FakeResult(metadata={"cost_usd": 1.23})
        cost = read_cost_from_result_safe(result)
        assert cost == 1.23

    def test_no_log_when_value_present(self, caplog):
        """No ``FALLBACK_FIRED`` record is emitted when the cost is present."""
        result = FakeResult(metadata={"cost_usd": 1.0})
        with caplog.at_level(logging.WARNING):
            read_cost_from_result_safe(result)
        assert not self._fallback_logged(caplog)


class TestReadCostFromRecordSafe:
    """Checks for ``read_cost_from_record_safe`` on dict and non-dict input."""

    def test_reads_dict_record(self):
        """A dict record's ``cost_usd`` value is returned."""
        record = {"cost_usd": 0.5}
        assert read_cost_from_record_safe(record) == 0.5

    def test_returns_zero_on_missing_key(self, caplog):
        """A record without ``cost_usd`` yields ``0.0`` and logs the fallback."""
        with caplog.at_level(logging.WARNING):
            cost = read_cost_from_record_safe({"other": 1})
        assert cost == 0.0
        messages = [entry.message for entry in caplog.records]
        assert any("FALLBACK_FIRED" in text for text in messages)

    def test_returns_zero_on_unparseable(self, caplog):
        """A non-numeric ``cost_usd`` string yields ``0.0``."""
        # NOTE(review): logs are captured but not asserted on here — confirm
        # whether a FALLBACK_FIRED record is expected for unparseable values.
        with caplog.at_level(logging.WARNING):
            cost = read_cost_from_record_safe({"cost_usd": "garbage"})
        assert cost == 0.0

    def test_returns_zero_on_non_dict_input(self, caplog):
        """A non-dict argument yields ``0.0`` and logs the fallback."""
        with caplog.at_level(logging.WARNING):
            cost = read_cost_from_record_safe("not-a-dict")
        assert cost == 0.0
        messages = [entry.message for entry in caplog.records]
        assert any("FALLBACK_FIRED" in text for text in messages)

    def test_zero_record_value_passes_through(self):
        """A stored cost of ``0.0`` is returned as-is."""
        record = {"cost_usd": 0.0}
        assert read_cost_from_record_safe(record) == 0.0


class TestComputeCost:
    """Tests for ``compute_cost`` across the billing units used below.

    Costs produced by float arithmetic are compared with ``pytest.approx``
    so the tests do not depend on the exact order of operations inside
    ``compute_cost``.
    """

    def test_per_second(self):
        """10 seconds at 0.1 per second costs 1.0."""
        profile = {"billing_unit": "per_second", "cost_per_second": 0.1}
        assert compute_cost("m", duration_s=10, profile=profile) == pytest.approx(1.0)

    def test_per_1k_chars(self):
        """2000 characters at 0.30 per 1k characters costs 0.60."""
        profile = {"billing_unit": "per_1k_chars", "cost_per_1k_chars": 0.30}
        assert compute_cost("m", char_count=2000, profile=profile) == pytest.approx(0.60)

    def test_per_1k_tokens(self):
        """Input and output tokens are priced at their own per-1k rates."""
        profile = {
            "billing_unit": "per_1k_tokens",
            "cost_per_1k_tokens_input": 0.002,
            "cost_per_1k_tokens_output": 0.012,
        }
        assert compute_cost(
            "m", token_input_count=1000, token_output_count=500, profile=profile
        ) == pytest.approx(0.002 + 0.006, rel=1e-6)

    def test_flat_per_image(self):
        """Flat per-image billing returns the profile's ``image_cost``."""
        profile = {"billing_unit": "flat_per_image", "image_cost": 0.039}
        assert compute_cost("m", profile=profile) == 0.039

    def test_tier_keyed_rate(self):
        """A dict-valued rate is looked up by ``tier`` before pricing."""
        profile = {
            "billing_unit": "per_second",
            "cost_per_second": {"standard": 0.10, "professional": 0.20},
        }
        assert compute_cost(
            "m", duration_s=5, tier="professional", profile=profile
        ) == pytest.approx(1.0)

    def test_missing_duration_raises(self):
        """Per-second billing without ``duration_s`` raises ``ValueError``."""
        profile = {"billing_unit": "per_second", "cost_per_second": 0.1}
        with pytest.raises(ValueError):
            compute_cost("m", profile=profile)

    def test_unknown_billing_unit_raises(self):
        """An unrecognised ``billing_unit`` raises ``CostMissingError``."""
        profile = {"billing_unit": "unknown_kind"}
        with pytest.raises(CostMissingError):
            compute_cost("m", profile=profile)

    def test_missing_profile_raises(self):
        """With no profile found for the model, ``CostMissingError`` is raised."""
        # NOTE(review): this patches recoil.core.model_profiles.get_profile,
        # while the helpers are imported from recoil.pipeline.core.cost —
        # confirm this is the lookup compute_cost actually resolves at call time.
        with patch("recoil.core.model_profiles.get_profile", return_value=None), \
                pytest.raises(CostMissingError):
            compute_cost("nonexistent_model")


class TestGetCost:
    """Exercises ``get_cost`` with and without ``allow_missing``."""

    def test_get_cost_default_raises_on_missing(self):
        """Without ``allow_missing``, empty metadata raises ``CostMissingError``."""
        empty = FakeResult(metadata={})
        with pytest.raises(CostMissingError):
            get_cost(empty)

    def test_get_cost_allow_missing_returns_zero(self, caplog):
        """With ``allow_missing=True``, empty metadata yields ``0.0``."""
        # NOTE(review): logs are captured at WARNING but not asserted on —
        # confirm whether a fallback log record should be checked here.
        empty = FakeResult(metadata={})
        with caplog.at_level(logging.WARNING):
            cost = get_cost(empty, allow_missing=True)
        assert cost == 0.0

    def test_get_cost_returns_value(self):
        """A present ``cost_usd`` is returned as the cost."""
        priced = FakeResult(metadata={"cost_usd": 2.5})
        cost = get_cost(priced)
        assert cost == 2.5
