"""Tests for orchestrator/manifest.py — EpisodeLog state tracking."""

import json
import tempfile
from pathlib import Path

import pytest

from orchestrator.manifest import EpisodeLog


@pytest.fixture
def tmp_output_dir(tmp_path):
    """Per-test episode output path under pytest's tmp dir (not created on disk)."""
    episode_dir = tmp_path.joinpath("ep_001")
    return episode_dir


@pytest.fixture
def log(tmp_output_dir):
    """EpisodeLog for episode 1, backed by the per-test output directory."""
    episode_log = EpisodeLog(episode=1, output_dir=tmp_output_dir)
    return episode_log


class TestEpisodeLogInit:
    """Construction and shot registration."""

    def test_creates_with_episode(self, log):
        """The episode number given to the constructor is exposed as-is."""
        expected_episode = 1
        assert log.episode == expected_episode

    def test_init_shots(self, log):
        """Every shot registered through init_shots starts out pending."""
        shot_ids = [1, 2, 3]
        log.init_shots(shot_ids)
        for shot_id in shot_ids:
            assert log.get_shot_status(shot_id) == "pending"

    def test_unknown_shot_returns_pending(self, log):
        """A shot that was never registered reports the pending default."""
        never_registered = 999
        assert log.get_shot_status(never_registered) == "pending"


class TestEpisodeLogUpdateShot:
    """update_shot: status changes, metadata handling and validation."""

    def test_update_status(self, log):
        log.init_shots([1])
        log.update_shot(1, "submitted")
        assert log.get_shot_status(1) == "submitted"

    def test_update_with_metadata(self, log):
        """Metadata kwargs are accepted and actually recorded, not just the status."""
        log.update_shot(
            1,
            "complete",
            cost=0.134,
            output_path="/output/frame.png",
            tier="simple",
            pipeline="still",
            model="gemini-3-pro-image-preview",
        )
        assert log.get_shot_status(1) == "complete"
        # Checking the status alone would not catch dropped metadata. The cost
        # is observable through total_cost(), so pin it here.
        assert log.total_cost() == pytest.approx(0.134)

    def test_invalid_status_raises(self, log):
        """An unknown status is rejected without mutating the shot."""
        log.init_shots([1])
        with pytest.raises(ValueError, match="Invalid status"):
            log.update_shot(1, "bogus_status")
        # A rejected update must leave the previously stored status untouched.
        assert log.get_shot_status(1) == "pending"

    def test_update_creates_shot_if_missing(self, log):
        """Updating a shot that was never registered implicitly creates it."""
        log.update_shot(42, "submitted")
        assert log.get_shot_status(42) == "submitted"


class TestEpisodeLogQueries:
    """Status-filtered shot listings."""

    def test_get_pending_shots(self, log):
        """Shots moved to a terminal status drop out of the pending list."""
        log.init_shots([1, 2, 3])
        for shot_id, status in ((1, "complete"), (2, "failed")):
            log.update_shot(shot_id, status)
        pending = log.get_pending_shots()
        assert pending == [3]

    def test_get_failed_shots(self, log):
        """Only the shot marked failed is listed as failed."""
        log.init_shots([1, 2, 3])
        log.update_shot(2, "failed", error="API error")
        failed = log.get_failed_shots()
        assert failed == [2]

    def test_get_complete_shots(self, log):
        """Every shot marked complete is listed, regardless of order."""
        log.init_shots([1, 2, 3])
        for shot_id, cost in ((1, 0.134), (3, 0.039)):
            log.update_shot(shot_id, "complete", cost=cost)
        complete = sorted(log.get_complete_shots())
        assert complete == [1, 3]

    def test_get_in_progress_shots(self, log):
        """Both submitted and processing shots count as in progress."""
        log.init_shots([1, 2])
        for shot_id, status in ((1, "submitted"), (2, "processing")):
            log.update_shot(shot_id, status)
        in_progress = sorted(log.get_in_progress_shots())
        assert in_progress == [1, 2]


class TestEpisodeLogCost:
    """Aggregate cost reporting."""

    def test_total_cost(self, log):
        """Recorded costs are summed; a failure without a cost adds nothing."""
        for shot_id, cost in ((1, 0.134), (2, 0.039)):
            log.update_shot(shot_id, "complete", cost=cost)
        log.update_shot(3, "failed")
        expected = pytest.approx(0.173)
        assert log.total_cost() == expected

    def test_total_cost_empty(self, log):
        """A log with no shots costs exactly nothing."""
        nothing = 0.0
        assert log.total_cost() == nothing


class TestEpisodeLogSummary:
    """The summary() report."""

    def test_summary(self, log):
        """Summary reports the episode, the shot count and per-status tallies."""
        log.init_shots([1, 2, 3])
        log.update_shot(1, "complete", cost=0.134)
        log.update_shot(2, "failed")

        report = log.summary()
        tallies = report["by_status"]
        assert report["episode"] == 1
        assert report["total_shots"] == 3
        for status in ("complete", "failed", "pending"):
            assert tallies[status] == 1


class TestEpisodeLogPersistence:
    """Saving to and loading from log.json in the output directory."""

    def test_save_and_load(self, log, tmp_output_dir):
        """Statuses and per-shot cost survive a save/reload round trip."""
        log.init_shots([1, 2, 3])
        log.update_shot(1, "complete", cost=0.134)
        log.update_shot(2, "submitted")
        log.save()

        # Reload from disk into a brand-new instance.
        loaded = EpisodeLog(episode=1, output_dir=tmp_output_dir)
        assert loaded.get_shot_status(1) == "complete"
        assert loaded.get_shot_status(2) == "submitted"
        assert loaded.get_shot_status(3) == "pending"
        # Metadata must round-trip too, not just statuses.
        assert loaded.total_cost() == pytest.approx(0.134)

    def test_save_creates_directory(self, tmp_path):
        """save() creates missing parent directories and writes valid JSON."""
        deep_dir = tmp_path / "nested" / "ep_001"
        m = EpisodeLog(episode=1, output_dir=deep_dir)
        m.init_shots([1])
        m.save()
        log_file = deep_dir / "log.json"
        assert log_file.exists()
        # A file that exists but is not parseable would break the next load.
        json.loads(log_file.read_text(encoding="utf-8"))

    def test_corrupt_log_starts_fresh(self, tmp_output_dir):
        """An unparseable log.json is ignored, and a later save replaces it."""
        tmp_output_dir.mkdir(parents=True, exist_ok=True)
        (tmp_output_dir / "log.json").write_text("{{invalid json")

        m = EpisodeLog(episode=1, output_dir=tmp_output_dir)
        # Should not crash, should start fresh
        assert m.get_shot_status(1) == "pending"
        assert m.total_cost() == 0.0

        # Recovery is only useful if the fresh log can overwrite the corrupt
        # file and be read back afterwards.
        m.update_shot(1, "complete")
        m.save()
        reloaded = EpisodeLog(episode=1, output_dir=tmp_output_dir)
        assert reloaded.get_shot_status(1) == "complete"


class TestEpisodeLogKeyframeStatuses:
    """Test keyframe workflow statuses."""

    def test_keyframe_workflow(self, log):
        """A shot can step through pending -> generated -> approved."""
        for stage in ("keyframe_pending", "keyframe_generated", "keyframe_approved"):
            log.update_shot(1, stage)
            assert log.get_shot_status(1) == stage

    def test_keyframe_rejected_loops_back(self, log):
        """A rejected keyframe can be sent back to pending for regeneration."""
        log.update_shot(1, "keyframe_generated")
        # Reject first, then loop back to pending; each hop must stick.
        for stage in ("keyframe_rejected", "keyframe_pending"):
            log.update_shot(1, stage)
            assert log.get_shot_status(1) == stage
