"""S21-S26: Pipeline integrity checks."""

import json
import os
import re

from recoil_checks import register_check


def check_sub_pipeline_routes(base, discovered):
    """S21: Sub-pipeline routes cover all shot types.

    Every ``.py`` file under ``<base>/lib`` that mentions ``route_shot`` or
    ``SubPipeline`` is treated as routing logic. Each such file is expected to
    name (case-insensitive substring match) all four sub-pipelines:
    "still", "i2v", "t2v" and "multi_shot".

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists.
    """
    passes, fails, warns = [], [], []
    pipelines = {"still", "i2v", "t2v", "multi_shot"}

    # Lazily yield (filename, full path) for every Python file below lib/.
    python_files = (
        (filename, os.path.join(dirpath, filename))
        for dirpath, _subdirs, filenames in os.walk(os.path.join(base, "lib"))
        for filename in filenames
        if filename.endswith(".py")
    )

    for filename, full_path in python_files:
        with open(full_path, encoding="utf-8") as handle:
            source = handle.read()

        # Files without routing markers are not inspected any further.
        if "route_shot" not in source and "SubPipeline" not in source:
            continue
        passes.append(f"{filename}: contains routing logic")

        lowered = source.lower()
        found = {name for name in pipelines if name in lowered}
        missing = pipelines - found
        if not missing:
            passes.append(f"{filename}: all 4 sub-pipelines referenced")
        else:
            warns.append(f"{filename}: missing pipeline references: {missing}")

    # No pass messages means no file contained routing logic at all.
    if not passes:
        warns.append("No routing logic found in lib/")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_cost_rates_match(base, discovered):
    """S22: Cost rates in code match model_profiles.json.

    Every numeric ``cost_per*`` field of every profile in
    ``<base>/config/model_profiles.json`` is collected as a known rate. Each
    ``.py`` file under ``<base>/lib`` is then scanned for decimal literals
    that follow the word ``cost`` on the same line. A literal equal to a
    known rate is reported as a pass. Any other literal is reported as a
    warning because it may be a stale hardcoded rate.

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists. A missing profiles
        file is a warning; an unparseable one is a failure.
    """
    passes, fails, warns = [], [], []

    profiles_path = os.path.join(base, "config", "model_profiles.json")
    if not os.path.isfile(profiles_path):
        warns.append("model_profiles.json not found")
        return {"pass": passes, "fail": fails, "warn": warns}

    try:
        with open(profiles_path, encoding="utf-8") as f:
            profiles = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        fails.append("model_profiles.json: invalid JSON")
        return {"pass": passes, "fail": fails, "warn": warns}

    # Collect every numeric cost_per* rate from every profile. Entries that
    # are not dicts (e.g. metadata strings) are skipped rather than crashing
    # on .items(). bool is excluded because it is a subclass of int.
    known_costs = set()
    if isinstance(profiles, dict):
        for profile in profiles.values():
            if not isinstance(profile, dict):
                continue
            for key, val in profile.items():
                if (
                    key.startswith("cost_per")
                    and isinstance(val, (int, float))
                    and not isinstance(val, bool)
                ):
                    known_costs.add(val)

    # A decimal literal somewhere after "cost" on the same line ('.' does not
    # match a newline, so a match never spans lines).
    cost_literal = re.compile(r'cost.*?(\d+\.\d+)')

    for root, _, files in os.walk(os.path.join(base, "lib")):
        for f in files:
            if not f.endswith(".py"):
                continue
            path = os.path.join(root, f)
            with open(path, encoding="utf-8") as fh:
                content = fh.read()

            for m in cost_literal.finditer(content):
                val = float(m.group(1))
                if val in known_costs:
                    passes.append(f"{f}: cost {val} matches model_profiles")
                else:
                    warns.append(f"{f}: cost {val} not found in model_profiles")

    if not passes and not warns:
        passes.append("No hardcoded costs found to verify")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_prompt_templates(base, discovered):
    """S23: Prompt template files are valid.

    Templates are looked for in two places:

    * ``<base>/data/prompt_templates``: a dedicated directory, so every
      regular file in it is treated as a template.
    * ``<base>/config``: a shared directory, so only files whose names
      contain "prompt" or "template" (case-insensitive) are considered.

    Templates ending in ``.json`` must parse as UTF-8 JSON. Other templates
    are only counted as found.

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists.
    """
    passes, fails, warns = [], [], []

    # (directory, whether a filename must mention prompt/template to count)
    template_dirs = [
        (os.path.join(base, "data", "prompt_templates"), False),
        (os.path.join(base, "config"), True),
    ]

    found_templates = False
    for template_dir, require_name_match in template_dirs:
        if not os.path.isdir(template_dir):
            continue
        # Sorted so the report order is stable across filesystems.
        for f in sorted(os.listdir(template_dir)):
            path = os.path.join(template_dir, f)
            if not os.path.isfile(path):
                continue
            lowered = f.lower()
            if require_name_match and "prompt" not in lowered and "template" not in lowered:
                continue

            found_templates = True
            if not f.endswith(".json"):
                continue
            try:
                with open(path, encoding="utf-8") as fh:
                    json.load(fh)
            except (json.JSONDecodeError, UnicodeDecodeError):
                fails.append(f"{f}: invalid JSON")
            else:
                passes.append(f"{f}: valid JSON")

    if not found_templates:
        passes.append("No prompt template files found (templates may be inline)")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_render_schema(base, discovered):
    """S24: render_schema.py defines required data classes.

    Looks in ``<base>/lib/render_schema.py`` for a ``class`` statement for
    each of ShotRecord, PromptData, RoutingData and AssetData.

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists. A missing file and
        any missing class are both reported as warnings.
    """
    passes, fails, warns = [], [], []

    schema_path = os.path.join(base, "lib", "render_schema.py")
    if not os.path.isfile(schema_path):
        warns.append("render_schema.py not found")
        return {"pass": passes, "fail": fails, "warn": warns}

    with open(schema_path, encoding="utf-8") as f:
        content = f.read()

    expected_classes = ["ShotRecord", "PromptData", "RoutingData", "AssetData"]
    for cls in expected_classes:
        # Require a class statement for exactly this name. The pattern is
        # anchored at the start of a (possibly indented) line and ends at a
        # word boundary, so "class ShotRecordBase" and mid-line mentions such
        # as "# see class RoutingData" do not count.
        pattern = rf"^[ \t]*class[ \t]+{re.escape(cls)}\b"
        if re.search(pattern, content, re.MULTILINE):
            passes.append(f"Class {cls} defined")
        else:
            warns.append(f"Class {cls} not found in render_schema.py")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_plan_structure(base, discovered):
    """S25: Plan files have required fields.

    Samples up to the first five ``.json`` files, in sorted filename order,
    from ``<base>/data/plans``. Each must contain a JSON object with at least
    the keys "episode" and "shots".

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists. Unparseable files,
        non-object JSON and missing keys are failures.
    """
    passes, fails, warns = [], [], []

    plans_dir = os.path.join(base, "data", "plans")
    if not os.path.isdir(plans_dir):
        warns.append("No plans/ directory")
        return {"pass": passes, "fail": fails, "warn": warns}

    # Filter to .json before taking the sample. If the raw listing were
    # sliced first, non-JSON entries (README, .DS_Store, ...) could use up
    # the sample and nothing would be checked.
    plan_files = [f for f in sorted(os.listdir(plans_dir)) if f.endswith(".json")][:5]
    required = {"episode", "shots"}

    for f in plan_files:
        path = os.path.join(plans_dir, f)
        try:
            with open(path, encoding="utf-8") as fh:
                data = json.load(fh)
        except (json.JSONDecodeError, UnicodeDecodeError):
            fails.append(f"{f}: invalid JSON")
            continue

        # A list or scalar at the top level has no keys to check.
        if not isinstance(data, dict):
            fails.append(f"{f}: top-level JSON value is not an object")
            continue

        missing = required - set(data.keys())
        if missing:
            fails.append(f"{f}: missing keys {missing}")
        else:
            passes.append(f"{f}: has required keys")

    if not passes and not fails:
        passes.append("No plans to check")

    return {"pass": passes, "fail": fails, "warn": warns}


def check_validation_module(base, discovered):
    """S26: validation.py module exists and has key validators.

    Counts the function definitions in ``<base>/lib/validation.py`` whose
    names start with "validate", "check" or "verify". Both plain ``def`` and
    ``async def`` count, at any indentation, so methods are included.

    Args:
        base: Project root directory.
        discovered: Unused here; part of the shared check signature.

    Returns:
        dict with "pass", "fail" and "warn" message lists. A missing module
        or zero matching functions is reported as a warning.
    """
    passes, fails, warns = [], [], []

    val_path = os.path.join(base, "lib", "validation.py")
    if not os.path.isfile(val_path):
        warns.append("lib/validation.py not found")
        return {"pass": passes, "fail": fails, "warn": warns}

    with open(val_path, encoding="utf-8") as f:
        content = f.read()

    passes.append("validation.py exists")

    # One match per definition line, so the reported number is the actual
    # function count rather than the number of distinct prefixes seen.
    validator_def = re.compile(
        r"^[ \t]*(?:async[ \t]+)?def[ \t]+(?:validate|check|verify)\w*[ \t]*\(",
        re.MULTILINE,
    )
    found = len(validator_def.findall(content))
    if found:
        passes.append(f"{found} validation functions found")
    else:
        warns.append("No validation functions found")

    return {"pass": passes, "fail": fails, "warn": warns}


# Register S21-S26 with the shared check registry, all under the "pipeline"
# section. S24 is the only one that also passes quick=True.
register_check(
    "s21_sub_pipelines",
    "Sub-Pipeline Coverage",
    check_sub_pipeline_routes,
    section="pipeline",
)
register_check(
    "s22_cost_rates",
    "Cost Rates Match Profiles",
    check_cost_rates_match,
    section="pipeline",
)
register_check(
    "s23_prompt_templates",
    "Prompt Templates Valid",
    check_prompt_templates,
    section="pipeline",
)
register_check(
    "s24_render_schema",
    "Render Schema Classes",
    check_render_schema,
    section="pipeline",
    quick=True,
)
register_check(
    "s25_plan_structure",
    "Plan Structure",
    check_plan_structure,
    section="pipeline",
)
register_check(
    "s26_validation_module",
    "Validation Module",
    check_validation_module,
    section="pipeline",
)
