"""Tests for recoil.execution.execution_store — JSON file-per-shot execution state backend."""

import json
import logging
import threading
import time
from unittest.mock import patch

import pytest

from recoil.execution.execution_store import (
    ExecutionStore,
    InvalidTransitionError,
    ShotExistsError,
    VALID_TRANSITIONS,
)


def _make_store(db_path):
    """Build an ExecutionStore at *db_path* with the SQLite migration hook disabled.

    ``_try_migrate_from_sqlite`` is patched out only while the store is being
    constructed, so the hook does not run during setup (test isolation).
    """
    no_migration = patch.object(ExecutionStore, "_try_migrate_from_sqlite")
    with no_migration:
        store = ExecutionStore(db_path=db_path)
    return store


def _make_shot(shot_id="ep01_s01_001", episode_id="ep01", **overrides):
    """Helper to build a shot dict with sensible defaults."""
    shot = {
        "shot_id": shot_id,
        "episode_id": episode_id,
        "pipeline": "still",
        "model": "gemini-3-pro-image-preview",
        "status": "previs_pending",
        "cost_incurred": 0.0,
        "attempts": 0,
        "max_attempts": 3,
    }
    shot.update(overrides)
    return shot


# ── Store Initialization ────────────────────────────────────────────


class TestStoreInit:
    """Construction-time behaviour: directory layout and on-disk file format."""

    @pytest.fixture
    def open_store(self):
        """Yield a factory that builds stores; every store it built is closed at teardown."""
        opened = []

        def _open(path):
            store = _make_store(path)
            opened.append(store)
            return store

        yield _open
        for store in opened:
            store.close()

    def test_creates_shots_directory(self, tmp_path, open_store):
        target = tmp_path / "shots"
        open_store(target)
        assert target.exists()
        assert target.is_dir()

    def test_db_path_with_db_suffix_uses_sibling_shots_dir(self, tmp_path, open_store):
        store = open_store(tmp_path / "test.db")
        sibling = tmp_path / "shots"
        assert store._shots_dir == sibling
        assert sibling.is_dir()

    def test_creates_parent_dirs(self, tmp_path, open_store):
        deep = tmp_path.joinpath("nested", "dirs", "shots")
        open_store(deep)
        assert deep.exists()

    def test_shot_files_are_json(self, tmp_path, open_store):
        target = tmp_path / "shots"
        store = open_store(target)
        store.insert_shot(_make_shot())
        written = target / "ep01_s01_001.json"
        assert written.exists()
        payload = json.loads(written.read_text())
        assert payload["shot_id"] == "ep01_s01_001"


# ── Shot CRUD ────────────────────────────────────────────────────────


class TestInsertAndGet:
    """insert_shot / insert_shots_batch round-trips and the get_* accessors."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed at teardown even if the test fails."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_insert_and_get_shot(self, store):
        store.insert_shot(_make_shot())
        result = store.get_shot("ep01_s01_001")

        assert result is not None
        assert result["shot_id"] == "ep01_s01_001"
        assert result["episode_id"] == "ep01"
        assert result["pipeline"] == "still"
        assert result["status"] == "previs_pending"
        assert result["cost_incurred"] == 0.0

    def test_get_nonexistent_shot_returns_none(self, store):
        assert store.get_shot("does_not_exist") is None

    def test_get_shot_status_default(self, store):
        # Non-existent shot returns default status
        assert store.get_shot_status("missing") == "previs_pending"

    def test_get_shot_status_existing(self, store):
        store.insert_shot(_make_shot(status="video_complete"))
        assert store.get_shot_status("ep01_s01_001") == "video_complete"

    def test_insert_overwrites_on_duplicate(self, store):
        store.insert_shot(_make_shot(status="previs_pending"))
        store.insert_shot(_make_shot(status="video_complete"))
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "video_complete"

    def test_insert_shots_batch(self, store):
        shot_ids = [f"ep01_s01_{i:03d}" for i in range(5)]
        store.insert_shots_batch([_make_shot(shot_id=sid) for sid in shot_ids])
        all_shots = store.get_all_shots()
        # Check identity, not just count: dropped, duplicated or mis-keyed
        # entries could still add up to five records.
        assert len(all_shots) == 5
        assert {s["shot_id"] for s in all_shots} == set(shot_ids)


# ── Update Operations ────────────────────────────────────────────────


class TestUpdateShot:
    """update_shot: status transitions, accumulating/merging fields, auto-create."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed at teardown even if the test fails."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_update_status_valid_transition(self, store, caplog):
        store.insert_shot(_make_shot())
        # previs_pending -> previs_generating (valid transition)
        with caplog.at_level(logging.WARNING, logger="recoil.execution.execution_store"):
            store.update_shot("ep01_s01_001", status="previs_generating")
        # Non-standard transitions are applied with a warning rather than
        # rejected, so the absence of that warning is the only observable
        # difference between a valid and an invalid transition.
        assert "non-standard transition" not in caplog.text
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "previs_generating"

    def test_update_status_invalid_transition_warns(self, store, caplog):
        """Phase 2.5 softened transitions: non-standard transitions warn, not raise."""
        store.insert_shot(_make_shot())
        # previs_pending -> video_complete is NOT a standard transition
        # but is now allowed with a warning
        with caplog.at_level(logging.WARNING, logger="recoil.execution.execution_store"):
            store.update_shot("ep01_s01_001", status="video_complete")
        assert "non-standard transition" in caplog.text
        assert "previs_pending -> video_complete" in caplog.text
        # Transition was applied despite being non-standard
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "video_complete"

    def test_update_status_same_state_noop(self, store):
        store.insert_shot(_make_shot())
        # Same state is allowed (no-op)
        store.update_shot("ep01_s01_001", status="previs_pending")
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "previs_pending"

    def test_cost_accumulates(self, store):
        store.insert_shot(_make_shot(cost_incurred=0.10))
        store.update_shot("ep01_s01_001", cost_incurred=0.05)
        result = store.get_shot("ep01_s01_001")
        assert abs(result["cost_incurred"] - 0.15) < 1e-9

    def test_gate_results_merge(self, store):
        store.insert_shot(_make_shot())
        store.update_shot("ep01_s01_001", gate_results={"gate1": "pass"})
        store.update_shot("ep01_s01_001", gate_results={"gate2": "fail"})
        result = store.get_shot("ep01_s01_001")
        assert result["gate_results"] == {"gate1": "pass", "gate2": "fail"}

    def test_takes_replaced_wholesale(self, store):
        store.insert_shot(_make_shot())
        store.update_shot("ep01_s01_001", takes=["take1.png", "take2.png"])
        store.update_shot("ep01_s01_001", takes=["take3.png"])
        result = store.get_shot("ep01_s01_001")
        assert result["takes"] == ["take3.png"]

    def test_update_auto_creates_missing_shot(self, store):
        store.update_shot("new_shot", status="video_submitted")
        result = store.get_shot("new_shot")
        assert result is not None
        assert result["status"] == "video_submitted"

    def test_append_take_via_update(self, store):
        store.insert_shot(_make_shot())
        store.update_shot("ep01_s01_001", append_take={"path": "take1.png"})
        store.update_shot("ep01_s01_001", append_take={"path": "take2.png"})
        result = store.get_shot("ep01_s01_001")
        assert len(result["takes"]) == 2
        assert result["takes"][0] == {"path": "take1.png"}
        assert result["takes"][1] == {"path": "take2.png"}


# ── State Machine ────────────────────────────────────────────────────


class TestStateMachine:
    """Status state machine: happy path, softened invalid transitions, force reset."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed at teardown even if the test fails."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_full_previs_to_video_path(self, store, caplog):
        """Walk a shot through the full happy path without any non-standard transition."""
        store.insert_shot(_make_shot())
        transitions = [
            "previs_generating",
            "previs_generated",
            "previs_approved",
            "keyframe_generating",
            "keyframe_generated",
            "keyframe_approved",
            "video_pending",
            "video_submitted",
            "video_processing",
            "video_ready",
            "video_complete",
            "approved",
        ]
        with caplog.at_level(logging.WARNING, logger="recoil.execution.execution_store"):
            for status in transitions:
                store.update_shot("ep01_s01_001", status=status)
        # Non-standard transitions are applied anyway (with a warning), so
        # without this check the test would pass even if the happy path were
        # missing from the transition table.
        assert "non-standard transition" not in caplog.text, caplog.text
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "approved"

    def test_invalid_transition_warns(self, store, caplog):
        """Phase 2.5 softened transitions: non-standard transitions warn, not raise."""
        store.insert_shot(_make_shot())
        with caplog.at_level(logging.WARNING, logger="recoil.execution.execution_store"):
            store.update_shot("ep01_s01_001", status="approved")
        assert "non-standard transition" in caplog.text
        assert "previs_pending -> approved" in caplog.text
        # Transition was applied despite being non-standard
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "approved"

    def test_failed_transition_always_allowed(self, store):
        store.insert_shot(_make_shot())
        store.update_shot("ep01_s01_001", status="failed")
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "failed"

    def test_force_reset_bypasses_state_machine(self, store):
        store.insert_shot(_make_shot())
        store.force_reset_status("ep01_s01_001", "approved", "manual override")
        result = store.get_shot("ep01_s01_001")
        assert result["status"] == "approved"

    def test_force_reset_invalid_target_raises(self, store):
        store.insert_shot(_make_shot())
        with pytest.raises(ValueError, match="Invalid target state"):
            store.force_reset_status("ep01_s01_001", "bogus_state", "test")


# ── Create Shot (duplicate guard) ────────────────────────────────────


class TestCreateShot:
    """create_shot: builds a new shot and raises ShotExistsError for a duplicate id."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_create_shot(self, store):
        store.create_shot("s1", {"episode_id": "ep01"})
        created = store.get_shot("s1")
        assert created is not None
        assert created["episode_id"] == "ep01"
        assert created["status"] == "previs_pending"

    def test_create_duplicate_raises(self, store):
        store.create_shot("s1")
        with pytest.raises(ShotExistsError):
            store.create_shot("s1")


# ── Query by Status ──────────────────────────────────────────────────


class TestQueryByStatus:
    """get_shots_by_status (single and multiple statuses) and get_shots_by_episode."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed at teardown even if the test fails."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_get_shots_by_single_status(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", status="previs_pending"),
            _make_shot(shot_id="s2", status="previs_generated"),
            _make_shot(shot_id="s3", status="previs_pending"),
        ])
        results = store.get_shots_by_status("previs_pending")
        assert len(results) == 2
        assert all(r["status"] == "previs_pending" for r in results)

    def test_get_shots_by_multiple_statuses(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", status="previs_pending"),
            _make_shot(shot_id="s2", status="video_complete"),
            _make_shot(shot_id="s3", status="previs_generated"),
        ])
        results = store.get_shots_by_status("previs_pending", "previs_generated")
        # Pin which shots come back, not just how many: returning s2 plus a
        # duplicate would also yield two results.
        assert sorted(r["shot_id"] for r in results) == ["s1", "s3"]

    def test_get_shots_by_episode(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", episode_id="ep01"),
            _make_shot(shot_id="s2", episode_id="ep02"),
            _make_shot(shot_id="s3", episode_id="ep01"),
        ])
        results = store.get_shots_by_episode("ep01")
        assert len(results) == 2
        assert all(r["episode_id"] == "ep01" for r in results)


# ── Orphan Detection ─────────────────────────────────────────────────


class TestOrphanDetection:
    """detect_orphans / recover_orphan across differing session ids."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_detect_orphans_from_different_session(self, store):
        store.insert_shot(_make_shot(
            shot_id="s1",
            status="video_submitted",
            session_id="old_session",
            job_id="job_123",
        ))
        found = store.detect_orphans("new_session")
        assert len(found) == 1
        first = found[0]
        assert first["shot_id"] == "s1"
        assert first["old_session_id"] == "old_session"

    def test_no_orphans_same_session(self, store):
        store.insert_shot(_make_shot(
            shot_id="s1",
            status="video_submitted",
            session_id="current",
        ))
        assert len(store.detect_orphans("current")) == 0

    def test_completed_shots_not_orphans(self, store):
        store.insert_shot(_make_shot(
            shot_id="s1",
            status="video_complete",
            session_id="old_session",
        ))
        assert len(store.detect_orphans("new_session")) == 0

    def test_recover_orphan(self, store):
        store.insert_shot(_make_shot(
            shot_id="s1",
            status="video_submitted",
            session_id="old",
            attempts=1,
        ))
        store.recover_orphan("s1", "new_session")
        recovered = store.get_shot("s1")
        assert recovered["session_id"] == "new_session"
        assert recovered["attempts"] == 2  # incremented


# ── Cost Aggregation ─────────────────────────────────────────────────


class TestCostAggregation:
    """total_cost() across the whole store and filtered by episode."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_total_cost_all(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", cost_incurred=0.10),
            _make_shot(shot_id="s2", cost_incurred=0.20),
            _make_shot(shot_id="s3", cost_incurred=0.05),
        ])
        # abs-only tolerance: approx ignores relative tolerance when only abs is given.
        assert store.total_cost() == pytest.approx(0.35, abs=1e-9)

    def test_total_cost_by_episode(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", episode_id="ep01", cost_incurred=0.10),
            _make_shot(shot_id="s2", episode_id="ep02", cost_incurred=0.20),
        ])
        assert store.total_cost("ep01") == pytest.approx(0.10, abs=1e-9)
        assert store.total_cost("ep02") == pytest.approx(0.20, abs=1e-9)

    def test_total_cost_empty_store(self, store):
        assert store.total_cost() == 0.0


# ── Summary Generation ───────────────────────────────────────────────


class TestSummary:
    """summary() and budget_summary() aggregation over shots."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed at teardown even if the test fails."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_summary_all(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", status="previs_pending", cost_incurred=0.10),
            _make_shot(shot_id="s2", status="previs_pending", cost_incurred=0.10),
            _make_shot(shot_id="s3", status="video_complete", cost_incurred=0.50),
        ])
        s = store.summary()
        assert s["episode_id"] == "all"
        assert s["total_shots"] == 3
        assert s["by_status"]["previs_pending"] == 2
        assert s["by_status"]["video_complete"] == 1
        assert abs(s["total_cost"] - 0.70) < 1e-4

    def test_summary_by_episode(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", episode_id="ep01"),
            _make_shot(shot_id="s2", episode_id="ep02"),
        ])
        s = store.summary(episode_id="ep01")
        assert s["episode_id"] == "ep01"
        assert s["total_shots"] == 1

    def test_summary_empty_store(self, store):
        s = store.summary()
        assert s["total_shots"] == 0
        assert s["total_cost"] == 0.0
        assert s["by_status"] == {}

    def test_budget_summary(self, store):
        store.insert_shots_batch([
            _make_shot(shot_id="s1", episode_id="ep01", status="previs_pending", cost_incurred=0.10),
            _make_shot(shot_id="s2", episode_id="ep01", status="video_complete", cost_incurred=0.50),
            _make_shot(shot_id="s3", episode_id="ep02", status="previs_pending", cost_incurred=0.05),
        ])
        bs = store.budget_summary()
        assert len(bs["episodes"]) == 2
        # Look episodes up by id instead of list position, so the test does
        # not depend on the order in which budget_summary emits them.
        by_episode = {ep["episode_id"]: ep for ep in bs["episodes"]}
        assert set(by_episode) == {"ep01", "ep02"}
        ep01 = by_episode["ep01"]
        assert ep01["total_shots"] == 2
        assert ep01["prep"] == 1
        assert ep01["in_the_can"] == 1
        assert by_episode["ep02"]["total_shots"] == 1
        assert abs(bs["season_total_cost"] - 0.65) < 1e-4


# ── Concurrent Access ────────────────────────────────────────────────


class TestConcurrentAccess:
    """Two threads inserting disjoint shots concurrently must not lose or corrupt any."""

    def test_two_threads_writing(self, tmp_path):
        store = _make_store(tmp_path / "shots")
        try:
            # The barrier releases both writers at the same moment to maximise overlap.
            barrier = threading.Barrier(2)
            errors = []

            def writer(thread_id, count):
                try:
                    barrier.wait(timeout=5)
                    for i in range(count):
                        store.insert_shot(_make_shot(
                            shot_id=f"t{thread_id}_s{i}",
                            episode_id=f"ep{thread_id}",
                        ))
                except Exception as e:  # collected and surfaced by the assert below
                    errors.append(e)

            # daemon=True so a deadlocked writer cannot keep the test process alive.
            threads = [
                threading.Thread(target=writer, args=(tid, 20), daemon=True)
                for tid in (1, 2)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join(timeout=15)

            # join(timeout=...) returns silently on timeout; check liveness explicitly.
            hung = [t.name for t in threads if t.is_alive()]
            assert not hung, f"Writer threads did not finish: {hung}"
            assert len(errors) == 0, f"Concurrent write errors: {errors}"

            all_shots = store.get_all_shots()
            assert len(all_shots) == 40
            expected_ids = {f"t{tid}_s{i}" for tid in (1, 2) for i in range(20)}
            assert {s["shot_id"] for s in all_shots} == expected_ids
        finally:
            # Close even when an assertion above fails.
            store.close()


# ── Atomic Transition ────────────────────────────────────────────────


class TestAtomicTransition:
    """atomic_transition: applies only when the current status is in ``allowed_from``."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_atomic_transition_success(self, store):
        store.insert_shot(_make_shot())
        moved = store.atomic_transition(
            "ep01_s01_001",
            allowed_from={"previs_pending"},
            to_state="previs_generating",
        )
        assert moved is True
        assert store.get_shot("ep01_s01_001")["status"] == "previs_generating"

    def test_atomic_transition_wrong_current_state(self, store):
        store.insert_shot(_make_shot())
        moved = store.atomic_transition(
            "ep01_s01_001",
            allowed_from={"keyframe_generating"},
            to_state="keyframe_generated",
        )
        assert moved is False
        # Status unchanged
        assert store.get_shot("ep01_s01_001")["status"] == "previs_pending"

    def test_atomic_transition_nonexistent_shot(self, store):
        moved = store.atomic_transition(
            "missing",
            allowed_from={"previs_pending"},
            to_state="previs_generating",
        )
        assert moved is False


# ── Delete Shot ──────────────────────────────────────────────────────


class TestDeleteShot:
    """delete_shot returns True only when a shot was actually removed."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_delete_existing(self, store):
        store.insert_shot(_make_shot())
        removed = store.delete_shot("ep01_s01_001")
        assert removed is True
        assert store.get_shot("ep01_s01_001") is None

    def test_delete_nonexistent(self, store):
        removed = store.delete_shot("missing")
        assert removed is False


# ── Edge Cases ───────────────────────────────────────────────────────


class TestEdgeCases:
    """Defaults, JSON field parsing, lifecycle no-ops and on-disk file details."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store under ``tmp_path/shots``; closed during teardown."""
        instance = _make_store(tmp_path / "shots")
        yield instance
        instance.close()

    def test_json_fields_parsed(self, store):
        store.insert_shot(_make_shot(
            gate_results={"gate1": "pass"},
            takes=["take1.png"],
        ))
        loaded = store.get_shot("ep01_s01_001")
        assert isinstance(loaded["gate_results"], dict)
        assert isinstance(loaded["takes"], list)

    def test_empty_gate_results_default(self, store):
        store.insert_shot(_make_shot())
        loaded = store.get_shot("ep01_s01_001")
        assert loaded["gate_results"] == {}
        assert loaded["takes"] == []

    def test_close_idempotent(self, tmp_path):
        doomed = _make_store(tmp_path / "shots")
        for _ in range(2):
            doomed.close()  # the second call must not raise

    def test_checkpoint_is_noop(self, store):
        store.checkpoint()  # Should not raise

    def test_get_all_shots_empty(self, store):
        assert store.get_all_shots() == []

    def test_shot_minimal_fields(self, store):
        """Insert a shot with only the required shot_id."""
        store.insert_shot({"shot_id": "minimal"})
        loaded = store.get_shot("minimal")
        assert loaded is not None
        assert loaded["episode_id"] == ""
        assert loaded["status"] == "previs_pending"
        assert loaded["cost_incurred"] == 0.0

    def test_atomic_write_creates_json_file(self, store, tmp_path):
        """Verify shots are stored as individual JSON files."""
        store.insert_shot(_make_shot(shot_id="test_shot_42"))
        on_disk = tmp_path / "shots" / "test_shot_42.json"
        assert on_disk.exists()
        payload = json.loads(on_disk.read_text())
        assert payload["shot_id"] == "test_shot_42"
        assert "updated_at" in payload

    def test_retry_waste_cost_accumulates(self, store):
        store.insert_shot(_make_shot(retry_waste_cost=0.05))
        store.update_shot("ep01_s01_001", retry_waste_cost=0.03)
        loaded = store.get_shot("ep01_s01_001")
        assert loaded["retry_waste_cost"] == pytest.approx(0.08, abs=1e-9)
