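"""Tests for orchestrator/scene_planner.py.

Covers shot-tier classification, ENV-shot detection, scene planning, pipeline
routing, batch eligibility, and long-scene partitioning.
"""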

import pytest

from orchestrator.scene_planner import (
    classify_shot_tier,
    is_env_shot,
    plan_scene,
    plan_episode,
    route_shot,
    is_batch_eligible,
    partition_long_scene,
)


class TestClassifyShotTier:
    """Table-driven checks for every rule applied by ``classify_shot_tier``.

    Each case pairs a minimal shot dict with the tier the classifier must
    assign; the case id names the rule being exercised so a failing case
    points directly at it.
    """

    _TIER_CASES = [
        # Character-free ENV / ESTABLISHING shots -> simple.
        pytest.param({"characters_in_shot": [], "shot_type": "ENV"}, "simple", id="env"),
        pytest.param({"characters_in_shot": [], "shot_type": "ESTABLISHING"}, "simple",
                     id="establishing"),

        # Wide framings and inserts -> simple, even with a character present.
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "WIDE"}, "simple", id="wide"),
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "LS"}, "simple", id="ls"),
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "ELS"}, "simple", id="els"),
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "INSERT"}, "simple",
                     id="insert"),

        # Two or more characters -> complex.
        pytest.param({"characters_in_shot": ["jinx", "ava"], "shot_type": "MS"}, "complex",
                     id="two_characters"),
        pytest.param({"characters_in_shot": ["jinx", "ava", "kai"], "shot_type": "MS"},
                     "complex", id="three_characters"),

        # Extreme/big close-ups carrying an emotion -> complex.
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "ECU", "emotion": "fear"},
                     "complex", id="ecu_with_emotion"),
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "BCU", "emotion": "rage"},
                     "complex", id="bcu_with_emotion"),

        # An empty emotion does not promote an ECU to complex.
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "ECU", "emotion": ""},
                     "standard", id="ecu_empty_emotion"),

        # Anything else falls back to standard.
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "MS"}, "standard",
                     id="default_ms"),
        pytest.param({"characters_in_shot": ["jinx"], "shot_type": "CU"}, "standard",
                     id="default_cu"),
    ]

    @pytest.mark.parametrize("shot,expected_tier", _TIER_CASES)
    def test_tier_classification(self, shot, expected_tier):
        tier = classify_shot_tier(shot)
        assert tier == expected_tier


class TestIsEnvShot:
    """``is_env_shot`` tells the ``env_shot`` fixture apart from ``sample_shot``."""

    def test_env_shot(self, env_shot):
        """The ``env_shot`` fixture is detected as an environment shot."""
        detected = is_env_shot(env_shot)
        assert detected

    def test_character_shot(self, sample_shot):
        """The ``sample_shot`` fixture (a character shot) is not an environment shot."""
        detected = is_env_shot(sample_shot)
        assert not detected


class TestPlanScene:
    """Ordering and metadata tagging performed by ``plan_scene``."""

    def test_env_shots_first(self):
        """The ENV shot is moved to the front and flagged as the scene reference."""
        shots = [
            {"id": 2, "characters_in_shot": ["jinx"], "shot_type": "MS"},
            {"id": 1, "characters_in_shot": [], "shot_type": "ENV"},
            {"id": 3, "characters_in_shot": ["ava"], "shot_type": "CU"},
        ]
        planned = plan_scene(shots, scene_index=0)

        # Reordering must neither drop nor duplicate shots.
        assert sorted(s["id"] for s in planned) == [1, 2, 3]

        # ENV shot should be first
        assert planned[0]["id"] == 1
        assert planned[0]["_provides_scene_ref"] is True
        assert planned[0]["_is_env"] is True

    def test_scene_index_tagged(self):
        shots = [{"id": 1, "characters_in_shot": [], "shot_type": "ENV"}]
        planned = plan_scene(shots, scene_index=3)
        assert planned[0]["_scene_index"] == 3

    def test_generation_order(self):
        shots = [
            {"id": 1, "characters_in_shot": [], "shot_type": "ENV"},
            {"id": 2, "characters_in_shot": ["jinx"], "shot_type": "MS"},
        ]
        planned = plan_scene(shots, scene_index=0)
        assert planned[0]["_generation_order"] == 0
        assert planned[1]["_generation_order"] == 1

    def test_empty_scene(self):
        assert plan_scene([], 0) == []

    def test_no_env_shot(self):
        """Without an ENV shot, no planned shot provides the scene reference."""
        shots = [
            {"id": 1, "characters_in_shot": ["jinx"], "shot_type": "MS"},
            {"id": 2, "characters_in_shot": ["ava"], "shot_type": "CU"},
        ]
        planned = plan_scene(shots, 0)
        # Guard against a vacuous pass: any() over an empty list is False.
        assert len(planned) == len(shots)
        assert not any(s["_provides_scene_ref"] for s in planned)


class TestRouteShot:
    """Test all 6 routing priority levels."""

    def test_priority_1_multi_shot(self, batch_eligible_scene):
        """Scene-eligible shots → multi_shot pipeline."""
        result = route_shot(batch_eligible_scene[0], batch_eligible_scene)
        assert result["pipeline"] == "multi_shot"
        assert result["model"] == "kling-v3"

    def test_priority_2_i2v(self, i2v_shot):
        """Start+end frame → i2v pipeline."""
        result = route_shot(i2v_shot)
        assert result["pipeline"] == "i2v"
        assert result["model"] == "kling-v3"

    def test_priority_3_still_env(self, env_shot):
        """ENV shot → still pipeline."""
        result = route_shot(env_shot)
        assert result["pipeline"] == "still"

    def test_priority_3_still_insert(self):
        """INSERT with no characters → still pipeline."""
        shot = {"id": 5, "characters_in_shot": [], "shot_type": "INSERT",
                "action": "Gauge close-up"}
        result = route_shot(shot)
        assert result["pipeline"] == "still"

    def test_priority_4_long_duration(self):
        """Long duration → Veo via t2v pipeline."""
        shot = {"id": 6, "characters_in_shot": ["jinx"], "shot_type": "MS",
                "duration_s": 20, "action": "Long tracking shot"}
        result = route_shot(shot)
        assert result["pipeline"] == "t2v"
        assert result["model"] == "veo-3.1"

    def test_priority_4_complex_camera(self):
        """Complex camera movement → Veo."""
        shot = {"id": 7, "characters_in_shot": ["jinx"], "shot_type": "MS",
                "camera_movement": "STEADICAM", "action": "Tracking"}
        result = route_shot(shot)
        assert result["pipeline"] == "t2v"
        assert result["model"] == "veo-3.1"

    def test_priority_5_dialogue(self, dialogue_shot):
        """Dialogue → I2V (keyframe preserves identity)."""
        result = route_shot(dialogue_shot)
        assert result["pipeline"] == "i2v"

    def test_priority_5_multi_char(self, complex_shot):
        """2+ characters → I2V (keyframe preserves identity)."""
        result = route_shot(complex_shot)
        assert result["pipeline"] == "i2v"

    def test_priority_6_default(self):
        """Default single-character → I2V (keyframe preserves identity)."""
        shot = {"id": 8, "characters_in_shot": ["jinx"], "shot_type": "MS",
                "action": "Jinx walks", "duration_s": 5}
        result = route_shot(shot)
        assert result["pipeline"] == "i2v"

    def test_scene_shots_context(self):
        """route_shot accepts an explicit ``scene_shots=None``.

        The shot is the same one used in ``test_priority_6_default``. With no
        scene context, batching cannot apply, so it must land on the same
        default route. The previous assertion accepted every pipeline and
        therefore could never fail.
        """
        shot = {"id": 8, "characters_in_shot": ["jinx"], "shot_type": "MS",
                "action": "Jinx walks", "duration_s": 5}
        result = route_shot(shot, scene_shots=None)
        assert result["pipeline"] == "i2v"


class TestBatchEligibility:
    """Conditions under which a run of shots qualifies for multi-shot batching."""

    @staticmethod
    def _shot(shot_id, characters, location, **extra):
        """Build a minimal shot dict carrying only the fields eligibility reads."""
        return {"id": shot_id, "characters_in_shot": characters, "location": location, **extra}

    def test_eligible(self, batch_eligible_scene):
        """The shared ``batch_eligible_scene`` fixture is eligible."""
        eligible = is_batch_eligible(batch_eligible_scene)
        assert eligible

    def test_too_few_shots(self):
        """Two shots are not enough to form a batch."""
        shots = [self._shot(n, ["jinx"], "a") for n in (1, 2)]
        assert not is_batch_eligible(shots)

    def test_too_many_shots(self):
        """Ten shots exceed what a single batch may hold."""
        shots = [self._shot(n, ["jinx"], "a") for n in range(10)]
        assert not is_batch_eligible(shots)

    def test_mixed_locations(self):
        """Shots spread over more than one location cannot be batched."""
        shots = [
            self._shot(n, ["jinx"], loc)
            for n, loc in zip((1, 2, 3), ("a", "a", "b"))
        ]
        assert not is_batch_eligible(shots)

    def test_i2v_shot_disqualifies(self):
        """A single shot with a start frame (an i2v shot) rules out batching."""
        shots = [
            self._shot(1, ["jinx"], "a"),
            self._shot(2, ["jinx"], "a", start_frame="/x"),
            self._shot(3, ["jinx"], "a"),
        ]
        assert not is_batch_eligible(shots)

    def test_too_many_characters(self):
        """Five distinct characters across the scene is too many to batch."""
        shots = [
            self._shot(1, ["a", "b"], "x"),
            self._shot(2, ["c", "d"], "x"),
            self._shot(3, ["e"], "x"),
        ]
        assert not is_batch_eligible(shots)


class TestPartitionLongScene:
    """Splitting of long scenes into batch-sized groups of shots."""

    @staticmethod
    def _ids(batches):
        """Flatten partitioned batches into a sorted list of shot ids."""
        return sorted(shot["id"] for batch in batches for shot in batch)

    def test_short_scene_no_partition(self):
        shots = [{"id": i} for i in range(5)]
        result = partition_long_scene(shots)
        assert len(result) == 1
        assert len(result[0]) == 5
        assert self._ids(result) == list(range(5))

    def test_long_scene_partitioned(self):
        shots = [{"id": i} for i in range(16)]
        result = partition_long_scene(shots)
        # Every shot must appear exactly once. This also rules out an empty
        # result, for which the size check below would pass vacuously.
        assert self._ids(result) == list(range(16))
        assert all(3 <= len(batch) <= 8 for batch in result)

    def test_remainder_merged(self):
        shots = [{"id": i} for i in range(10)]
        result = partition_long_scene(shots)
        # 10 shots: first batch 8, remainder 2 → merged into first batch
        assert len(result) == 1
        assert len(result[0]) == 10
        assert self._ids(result) == list(range(10))
