"""S5-S8: Execution store integrity checks."""

import json
import os
import sqlite3

from recoil_checks import register_check


def _find_db(base):
    """Find the execution store database."""
    candidates = [
        os.path.join(base, "data", "execution_store.db"),
        os.path.join(base, "data", "execution_store.sqlite"),
    ]
    for c in candidates:
        if os.path.isfile(c):
            return c
    return None


def check_db_schema(base, discovered):
    """S5: Execution store DB contains the expected tables.

    Verifies that the ``shots``, ``takes`` and ``episodes`` tables exist in
    the execution store database.  Only table names are inspected; columns
    are not checked.

    Args:
        base: Directory under which ``_find_db`` looks for the database.
        discovered: Not used by this check (presumably part of the common
            check signature -- confirm against ``register_check``).

    Returns:
        A dict with ``"pass"``, ``"fail"`` and ``"warn"`` keys, each a list
        of message strings.  A missing DB or a missing table is a warning;
        a SQLite error is a failure.
    """
    passes, fails, warns = [], [], []

    db_path = _find_db(base)
    if not db_path:
        warns.append("No execution store DB found (may not be initialized yet)")
        return {"pass": passes, "fail": fails, "warn": warns}

    try:
        conn = sqlite3.connect(db_path)
        # sqlite3's own context manager only commits/rolls back, it does not
        # close, so use try/finally to release the handle even on error.
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = {row[0] for row in cursor.fetchall()}
        finally:
            conn.close()
    except sqlite3.Error as e:
        fails.append(f"DB error: {e}")
    else:
        # Sorted so the message order is stable from run to run (set
        # iteration order of strings depends on hash randomization).
        for t in sorted({"shots", "takes", "episodes"}):
            if t in tables:
                passes.append(f"Table '{t}' exists")
            else:
                warns.append(f"Table '{t}' not found (may use different schema)")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_orphaned_generating(base, discovered):
    """S6: No shots stuck in 'generating' state.

    Counts rows in the ``shots`` table whose ``status`` is ``'generating'``.

    Args:
        base: Directory under which ``_find_db`` looks for the database.
        discovered: Not used by this check (presumably part of the common
            check signature -- confirm against ``register_check``).

    Returns:
        A dict with ``"pass"``, ``"fail"`` and ``"warn"`` keys, each a list
        of message strings.  This check never fails: a missing DB, a query
        error, or a non-zero count are all reported as warnings.
    """
    passes, fails, warns = [], [], []

    db_path = _find_db(base)
    if not db_path:
        warns.append("No execution store DB found")
        return {"pass": passes, "fail": fails, "warn": warns}

    try:
        conn = sqlite3.connect(db_path)
        # Close the handle even when the query raises (e.g. no shots table).
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT COUNT(*) FROM shots WHERE status = 'generating'"
            )
            count = cursor.fetchone()
        finally:
            conn.close()
    except sqlite3.Error as e:
        # Keep the underlying reason so the warning is actionable.
        warns.append(f"Could not query shots table: {e}")
    else:
        if count and count[0] > 0:
            warns.append(f"{count[0]} shots stuck in 'generating' state")
        else:
            passes.append("No orphaned generating shots")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_takes_json(base, discovered):
    """S7: Takes records have valid JSON metadata.

    Samples up to 100 rows of the ``takes`` table with non-NULL ``metadata``
    and tries to parse each value as JSON.

    Args:
        base: Directory under which ``_find_db`` looks for the database.
        discovered: Not used by this check (presumably part of the common
            check signature -- confirm against ``register_check``).

    Returns:
        A dict with ``"pass"``, ``"fail"`` and ``"warn"`` keys, each a list
        of message strings.  Unparseable metadata is a failure (listing up
        to five offending ids); a missing DB or query error is a warning.
    """
    passes, fails, warns = [], [], []

    db_path = _find_db(base)
    if not db_path:
        warns.append("No execution store DB found")
        return {"pass": passes, "fail": fails, "warn": warns}

    try:
        conn = sqlite3.connect(db_path)
        # Close the handle even when the query raises (e.g. no takes table).
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT id, metadata FROM takes WHERE metadata IS NOT NULL LIMIT 100"
            )
            rows = cursor.fetchall()
        finally:
            conn.close()
    except sqlite3.Error as e:
        warns.append(f"Could not query takes table: {e}")
        return {"pass": passes, "fail": fails, "warn": warns}

    if not rows:
        passes.append("No takes with metadata to check")
        return {"pass": passes, "fail": fails, "warn": warns}

    bad_ids = []
    for row_id, meta in rows:
        try:
            json.loads(meta)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError and also the
            # UnicodeDecodeError raised for BLOB values that are not valid
            # UTF-8; TypeError covers non-text values such as integers.
            bad_ids.append(row_id)

    if bad_ids:
        shown = ", ".join(str(i) for i in bad_ids[:5])
        more = ", ..." if len(bad_ids) > 5 else ""
        fails.append(
            f"{len(bad_ids)}/{len(rows)} takes have invalid JSON metadata"
            f" (ids: {shown}{more})"
        )
    else:
        passes.append(f"All {len(rows)} takes have valid JSON metadata")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_wal_mode(base, discovered):
    """S8: DB uses WAL mode for concurrent access safety.

    Reads ``PRAGMA journal_mode`` from the execution store database and
    compares it, case-insensitively, with ``"wal"``.

    Args:
        base: Directory under which ``_find_db`` looks for the database.
        discovered: Not used by this check (presumably part of the common
            check signature -- confirm against ``register_check``).

    Returns:
        A dict with ``"pass"``, ``"fail"`` and ``"warn"`` keys, each a list
        of message strings.  This check never fails: a missing DB, a
        non-WAL mode, or a SQLite error are all reported as warnings.
    """
    passes, fails, warns = [], [], []

    db_path = _find_db(base)
    if not db_path:
        warns.append("No execution store DB found")
        return {"pass": passes, "fail": fails, "warn": warns}

    try:
        conn = sqlite3.connect(db_path)
        # Close the handle even when the pragma raises (e.g. the file is
        # not actually a SQLite database).
        try:
            cursor = conn.cursor()
            cursor.execute("PRAGMA journal_mode")
            mode = cursor.fetchone()[0]
        finally:
            conn.close()
    except sqlite3.Error as e:
        warns.append(f"Could not check journal mode: {e}")
    else:
        if mode.lower() == "wal":
            passes.append("DB uses WAL journal mode")
        else:
            warns.append(f"DB uses '{mode}' journal mode (WAL recommended)")

    return {"pass": passes, "fail": fails, "warn": warns}


# Register checks S5-S8 under the "store" section, in numeric order.
register_check(
    "s5_db_schema",
    "Execution Store Schema",
    check_db_schema,
    section="store",
)
register_check(
    "s6_orphaned_generating",
    "Orphaned Generating Shots",
    check_orphaned_generating,
    section="store",
)
register_check(
    "s7_takes_json",
    "Takes JSON Integrity",
    check_takes_json,
    section="store",
)
register_check(
    "s8_wal_mode",
    "DB WAL Mode",
    check_wal_mode,
    section="store",
)
