#!/usr/bin/env python3
"""Pipeline Integration Audit Harness.

Wraps every video generation test with correctness assertions at each
pipeline stage. Not just "did the video look good" — but "did the pipeline
do the right thing at every step."

Audit checkpoints per test:
  1. PROMPT:  prompt_engine compiled the correct model-specific prompt
  2. REFS:    ref_selector allocated the right number and type of refs
  3. PAYLOAD: Client built the correct payload type with expected fields
  4. REQUEST: API request body contains exactly the right keys (no silent drops)
  5. STORE:   ExecutionStore has the take record with correct metadata
  6. COST:    Cost was calculated correctly from model profile

Any failure = FAIL LOUDLY with full diagnostic output.
No fallbacks. No "generate anyway." If a check fails, we stop and report.

Usage:
    python3 recoil/tools/shootout/audit_harness.py --plan afterimage
    python3 recoil/tools/shootout/audit_harness.py --plan tartarus --shots EP001_SH03,EP001_SH05
    python3 recoil/tools/shootout/audit_harness.py --dry-run  # Audit prompts+refs only, no API calls
    python3 recoil/tools/shootout/audit_harness.py --audit-only  # Re-audit existing results (planned; this flag is not implemented yet)
"""

import argparse
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional

# Ensure recoil is importable
# SCRIPT_DIR is the directory holding this file; RECOIL_ROOT is two levels above it.
# NOTE(review): the usage examples place this file at recoil/tools/shootout/, which
# would make RECOIL_ROOT the `recoil` package directory itself, while `import recoil`
# needs that directory's parent on sys.path — confirm the intended root.
SCRIPT_DIR = Path(__file__).parent
RECOIL_ROOT = SCRIPT_DIR.parent.parent
sys.path.insert(0, str(RECOIL_ROOT))

from recoil.core.paths import ensure_pipeline_importable, ProjectPaths
# Must run before the remaining recoil imports below.
ensure_pipeline_importable()

from recoil.core import model_profiles
from recoil.execution.assembler import allocate_references, WanI2VPayload, WanR2VPayload

# Configure root logging at import time; every message in this module goes
# through the "audit" logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("audit")

# Load expected parameter schemas
# When expected_params.json does not exist next to this script, EXPECTED_PARAMS is
# an empty dict, so audit_request_body finds no schema for any endpoint and
# reports the REQUEST checkpoint as passed with a "skipping validation" note.
EXPECTED_PARAMS_PATH = SCRIPT_DIR / "expected_params.json"
EXPECTED_PARAMS = json.loads(EXPECTED_PARAMS_PATH.read_text()) if EXPECTED_PARAMS_PATH.exists() else {}


# ══════════════════════════════════════════════════════════════════════
# Audit Result Types
# ══════════════════════════════════════════════════════════════════════

@dataclass
class CheckResult:
    """Outcome of one audit checkpoint.

    ``str()`` on an instance yields the indented one-line
    ``[PASS]``/``[FAIL]`` form used in the detailed report output.
    """
    # Checkpoint label: PROMPT, REFS, PAYLOAD, REQUEST, STORE or COST.
    checkpoint: str
    # True when every assertion of the checkpoint held.
    passed: bool
    # Human-readable explanation of the outcome.
    details: str
    # Raw data kept for diagnostics; each instance gets its own dict.
    data: dict = field(default_factory=dict)

    def __str__(self):
        if self.passed:
            verdict = "PASS"
        else:
            verdict = "FAIL"
        return "  [{}] {}: {}".format(verdict, self.checkpoint, self.details)


@dataclass
class TestAudit:
    """Everything recorded while auditing a single test generation.

    ``checks`` accumulates one ``CheckResult`` per checkpoint that ran; the
    test as a whole passes only while none of them has failed.
    """
    test_id: str
    model: str
    mode: str             # i2v, between, r2v, r2v_scene, multishot, t2v, continuation
    content: str          # afterimage or tartarus
    checks: list[CheckResult] = field(default_factory=list)
    generation_result: Optional[dict] = None
    wall_time_s: float = 0.0
    cost_usd: float = 0.0

    @property
    def fail_count(self) -> int:
        """Number of recorded checks that did not pass."""
        return len([check for check in self.checks if not check.passed])

    @property
    def passed(self) -> bool:
        """True when no recorded check failed (also true with no checks at all)."""
        return not any(not check.passed for check in self.checks)

    def summary(self) -> str:
        """Render a multi-line text block: header, one line per check, optional cost line."""
        rule = "=" * 60
        if self.passed:
            verdict = "PASS"
        else:
            verdict = f"FAIL ({self.fail_count} failures)"

        report = [
            f"\n{rule}",
            f"TEST: {self.test_id} | {self.model} | {self.mode} | {verdict}",
            rule,
        ]
        report.extend(str(check) for check in self.checks)
        # Cost and wall time are only shown once a generation result is attached.
        if self.generation_result:
            report.append(f"  [INFO] Cost: ${self.cost_usd:.3f} | Wall: {self.wall_time_s:.1f}s")
        return "\n".join(report)


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 1: PROMPT — Did prompt_engine compile correctly?
# ══════════════════════════════════════════════════════════════════════

def audit_prompt(shot: dict, model: str, mode: str, bible: dict, config: dict) -> CheckResult:
    """Verify prompt_engine compiled the right prompt variant for this model.

    Args:
        shot: Shot record, passed through to ``compile_all_prompts``.
        model: Model identifier, e.g. ``"wan-2.7-i2v"``.
        mode: Generation mode, e.g. ``"i2v"`` or ``"between"``.
        bible: Passed through to ``compile_all_prompts``.
        config: Passed through to ``compile_all_prompts``.

    Returns:
        A PROMPT ``CheckResult``. It fails when the (model, mode) pair has no
        prompt-key mapping, when compilation raises, when the expected key is
        missing, when the prompt is not a usable string, or when its word
        count falls outside the range expected for that prompt key.
    """
    from recoil.pipeline._lib.prompt_engine import compile_all_prompts

    # Determine which prompt key this model/mode should use
    key_map = {
        ("wan-2.7-i2v", "i2v"): "wan_i2v",
        ("wan-2.7-i2v", "between"): "wan_between",
        ("wan-2.7-i2v", "oner"): "wan_i2v",
        ("wan-2.7-r2v", "r2v"): "wan_r2v",
        ("wan-2.7-r2v", "r2v_scene"): "wan_r2v",
        ("kling-v3", "i2v"): "kling_i2v",
        ("kling-v3", "multishot"): "kling_i2v",
        ("kling-v3", "oner"): "kling_i2v",
        ("kling-v3", "t2v"): "kling_t2v",
        ("veo-3.1", "i2v"): "veo_t2v",
        ("veo-3.1", "oner"): "veo_t2v",
        ("veo-3.1", "t2v"): "veo_t2v",
    }

    expected_key = key_map.get((model, mode))
    if not expected_key:
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"No prompt key mapping for model={model} mode={mode}",
        )

    # Any compiler failure becomes a FAIL result rather than aborting the run.
    try:
        compiled = compile_all_prompts(shot, bible, config)
    except Exception as e:
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"compile_all_prompts() raised: {e}",
        )

    if expected_key not in compiled:
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"Expected key '{expected_key}' not in compiled prompts. Keys: {sorted(compiled.keys())}",
        )

    prompt = compiled[expected_key]
    # A None or otherwise non-string prompt must be reported as a failure;
    # calling len()/strip() on it would crash the audit instead.
    if not isinstance(prompt, str):
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"Prompt for '{expected_key}' is not a string (got {type(prompt).__name__})",
            data={"prompt": prompt},
        )
    if len(prompt.strip()) < 10:
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"Prompt for '{expected_key}' is empty or too short ({len(prompt)} chars)",
            data={"prompt": prompt},
        )

    # Check word count is in expected range
    word_count = len(prompt.split())
    expected_ranges = {
        "kling_i2v": (5, 40),
        "kling_t2v": (30, 150),
        "veo_t2v": (30, 300),
        "wan_i2v": (30, 350),
        "wan_between": (30, 350),
        "wan_r2v": (50, 500),
    }
    lo, hi = expected_ranges.get(expected_key, (5, 500))
    if word_count < lo or word_count > hi:
        return CheckResult(
            checkpoint="PROMPT",
            passed=False,
            details=f"'{expected_key}' word count {word_count} outside expected range [{lo}-{hi}]",
            data={"prompt": prompt, "word_count": word_count},
        )

    return CheckResult(
        checkpoint="PROMPT",
        passed=True,
        details=f"'{expected_key}' compiled OK ({word_count} words)",
        data={"prompt_key": expected_key, "prompt": prompt, "word_count": word_count},
    )


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 2: REFS — Did ref_selector allocate correctly?
# ══════════════════════════════════════════════════════════════════════

def audit_refs(model: str, mode: str, num_chars: int, has_props: bool, is_env: bool) -> CheckResult:
    """Check the reference allocation returned for this model/mode.

    The mode is translated to a pipeline name, ``allocate_references`` is
    called with it, and the Wan-specific expectations are applied to the
    result. Models without "wan" in their name only need the call to succeed.

    Returns:
        A REFS ``CheckResult``; it fails when ``allocate_references`` raises
        or when a Wan expectation is violated.
    """
    # Several modes share a pipeline; a mode missing from this table is
    # treated as "i2v".
    pipeline_by_mode = {
        **dict.fromkeys(("i2v", "between", "oner", "continuation"), "i2v"),
        **dict.fromkeys(("r2v", "r2v_scene"), "r2v"),
        "multishot": "multi_shot",
        "t2v": "t2v",
    }
    pipeline = pipeline_by_mode.get(mode, "i2v")

    try:
        alloc = allocate_references(pipeline, model, num_chars, has_props, is_env)
    except Exception as e:
        return CheckResult(
            checkpoint="REFS",
            passed=False,
            details=f"allocate_references() raised: {e}",
        )

    problems = []
    is_wan = "wan" in model

    if is_wan and pipeline == "i2v":
        # Wan I2V: identity and scene refs must both be absent.
        identity = alloc.get("identity", 0)
        if identity > 0:
            problems.append(f"Wan I2V should have identity=0, got {identity}")
        scene = alloc.get("scene", 0)
        if scene > 0:
            problems.append(f"Wan I2V should have scene=0, got {scene}")

    if is_wan and pipeline == "r2v":
        # Wan R2V: identity refs are required once characters exist; scene refs never.
        identity = alloc.get("identity", 0)
        if identity == 0 and num_chars > 0:
            problems.append(f"Wan R2V with {num_chars} chars should have identity>0, got 0")
        scene = alloc.get("scene", 0)
        if scene > 0:
            problems.append(f"Wan R2V should have scene=0 (prompt carries env), got {scene}")

    if not problems:
        return CheckResult(
            checkpoint="REFS",
            passed=True,
            details=f"Allocation OK: {alloc}",
            data={"allocation": alloc},
        )

    return CheckResult(
        checkpoint="REFS",
        passed=False,
        details=f"Allocation errors: {'; '.join(problems)}",
        data={"allocation": alloc, "model": model, "pipeline": pipeline},
    )


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 3: PAYLOAD — Did the client build the right payload type?
# ══════════════════════════════════════════════════════════════════════

def audit_payload(model: str, mode: str, payload) -> CheckResult:
    """Check that the payload has the expected class and sane fields.

    The class name of ``payload`` is compared with the name expected for the
    (model, mode) pair; pairs without an entry skip that comparison. Wan
    payloads additionally get field checks plus their own ``validate()``.

    Returns:
        A PAYLOAD ``CheckResult``.
    """
    expected_by_combo = {}
    expected_by_combo.update(
        {("wan-2.7-i2v", m): "WanI2VPayload" for m in ("i2v", "between", "oner")}
    )
    expected_by_combo.update(
        {("wan-2.7-r2v", m): "WanR2VPayload" for m in ("r2v", "r2v_scene")}
    )
    expected_by_combo.update({
        ("kling-v3", "i2v"): "dict",
        ("kling-v3", "multishot"): "MultiShotPayload",
        ("kling-v3", "oner"): "dict",
        ("kling-v3", "t2v"): "dict",
    })
    expected_by_combo.update(
        {("veo-3.1", m): "dict" for m in ("i2v", "oner", "t2v")}
    )

    wanted = expected_by_combo.get((model, mode))
    type_name = type(payload).__name__

    if wanted and type_name != wanted:
        return CheckResult(
            checkpoint="PAYLOAD",
            passed=False,
            details=f"Expected {wanted}, got {type_name}",
            data={"expected": wanted, "actual": type_name},
        )

    # Field-level validation, only defined for the two Wan payload classes.
    problems = []
    if isinstance(payload, WanI2VPayload):
        if not payload.image_url:
            problems.append("WanI2VPayload.image_url is empty")
        if payload.enable_prompt_expansion:
            problems.append("WanI2VPayload.enable_prompt_expansion should be False")
        if mode == "between" and not payload.end_image_url:
            problems.append("In Between mode but end_image_url is None")
        # validate() may return nothing/empty when there is nothing to report.
        problems.extend(payload.validate() or [])
    elif isinstance(payload, WanR2VPayload):
        if not payload.reference_image_urls:
            problems.append("WanR2VPayload.reference_image_urls is empty")
        problems.extend(payload.validate() or [])

    if not problems:
        return CheckResult(
            checkpoint="PAYLOAD",
            passed=True,
            details=f"Payload type {type_name} OK",
            data={"type": type_name},
        )

    return CheckResult(
        checkpoint="PAYLOAD",
        passed=False,
        details=f"Payload validation: {'; '.join(problems)}",
        data={"errors": problems},
    )


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 4: REQUEST — API request body has correct keys?
# ══════════════════════════════════════════════════════════════════════

def audit_request_body(endpoint: str, body: dict) -> CheckResult:
    """Compare an outgoing request body with the endpoint's expected schema.

    The schema comes from ``EXPECTED_PARAMS[endpoint]`` and may define
    ``required``, ``optional``, ``forbidden`` and ``type_checks``. An endpoint
    without a (non-empty) schema is reported as passed with a skip note.
    Only the ``"str"`` and ``"bool"`` type names are enforced; any other type
    name listed under ``type_checks`` is ignored.

    This is the check that would have caught tail_image_url.

    Returns:
        A REQUEST ``CheckResult``.
    """
    schema = EXPECTED_PARAMS.get(endpoint)
    sent_keys = sorted(body.keys())

    if not schema:
        return CheckResult(
            checkpoint="REQUEST",
            passed=True,
            details=f"No schema for endpoint {endpoint} — skipping validation",
            data={"endpoint": endpoint, "body_keys": sent_keys},
        )

    present = set(body.keys())
    problems = []

    # Required keys that were not sent.
    problems.extend(
        f"MISSING required key: '{name}'"
        for name in schema.get("required", [])
        if name not in present
    )

    # Forbidden keys that were sent anyway.
    problems.extend(
        f"FORBIDDEN key present: '{name}' — this will be silently dropped by fal.ai!"
        for name in schema.get("forbidden", [])
        if name in present
    )

    # Anything that is neither required nor optional counts as unknown.
    unexpected = present - set(schema.get("required", []) + schema.get("optional", []))
    if unexpected:
        problems.append(f"UNKNOWN keys (may be silently dropped): {sorted(unexpected)}")

    # Type checks compare the value's class name with the schema's type name.
    for name, wanted in schema.get("type_checks", {}).items():
        if name not in body:
            continue
        got = type(body[name]).__name__
        if wanted in ("str", "bool") and got != wanted:
            problems.append(f"Type mismatch: '{name}' should be {wanted}, got {got} ({body[name]})")

    if not problems:
        return CheckResult(
            checkpoint="REQUEST",
            passed=True,
            details=f"Request body OK. Keys: {sent_keys}",
            data={"endpoint": endpoint, "body_keys": sent_keys},
        )

    return CheckResult(
        checkpoint="REQUEST",
        passed=False,
        details=f"Request body validation FAILED: {'; '.join(problems)}",
        data={"endpoint": endpoint, "body_keys": sent_keys, "errors": problems},
    )


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 5: STORE — Did ExecutionStore record correctly?
# ══════════════════════════════════════════════════════════════════════

def audit_store(store, shot_id: str, model: str) -> CheckResult:
    """Verify the ExecutionStore has the take record with correct metadata.

    Looks up the shot via ``store.get_shot(shot_id)`` and inspects the last
    entry of its ``takes`` list: it must have a ``file_path``, a ``model``
    equal to ``model``, and a numeric ``cost_usd`` greater than zero.

    Args:
        store: Object exposing ``get_shot(shot_id)``.
        shot_id: Identifier of the shot to look up.
        model: Model the latest take is expected to be recorded under.

    Returns:
        A STORE ``CheckResult``; lookup errors and malformed records are
        reported as failures instead of being raised.
    """
    try:
        shot_data = store.get_shot(shot_id)
    except Exception as e:
        return CheckResult(
            checkpoint="STORE",
            passed=False,
            details=f"ExecutionStore.get_shot({shot_id}) raised: {e}",
        )

    if not shot_data:
        return CheckResult(
            checkpoint="STORE",
            passed=False,
            details=f"No shot data found for {shot_id}",
        )

    takes = shot_data.get("takes", [])
    if not takes:
        return CheckResult(
            checkpoint="STORE",
            passed=False,
            details=f"Shot {shot_id} has no takes recorded",
            data={"shot_data_keys": sorted(shot_data.keys())},
        )

    latest_take = takes[-1]
    errors = []

    if not latest_take.get("file_path"):
        errors.append("Take has no file_path")
    if latest_take.get("model") != model:
        errors.append(f"Take model={latest_take.get('model')}, expected {model}")

    # A stored cost of None (or any non-number) must be reported as a bad
    # record; comparing it with 0 directly would raise TypeError.
    cost = latest_take.get("cost_usd", 0)
    if not isinstance(cost, (int, float)) or cost <= 0:
        errors.append(f"Take cost is {cost} — should be >0")

    if errors:
        return CheckResult(
            checkpoint="STORE",
            passed=False,
            details=f"Take record issues: {'; '.join(errors)}",
            data={"take": latest_take},
        )

    return CheckResult(
        checkpoint="STORE",
        passed=True,
        details=f"Take #{latest_take.get('take_number', '?')} recorded, cost=${cost:.3f}",
        data={"take": latest_take},
    )


# ══════════════════════════════════════════════════════════════════════
# Checkpoint 6: COST — Was cost calculated correctly?
# ══════════════════════════════════════════════════════════════════════

def audit_cost(model: str, duration: int, reported_cost: float) -> CheckResult:
    """Compare a reported cost with ``model_profiles.get_cost(model) * duration``.

    A difference of at most $0.01 is accepted as rounding noise.

    Returns:
        A COST ``CheckResult``; a failing ``get_cost`` lookup is reported as
        a failure rather than raised.
    """
    try:
        per_second = model_profiles.get_cost(model)
    except Exception as e:
        return CheckResult(
            checkpoint="COST",
            passed=False,
            details=f"model_profiles.get_cost({model}) raised: {e}",
        )

    max_drift = 0.01  # $0.01 tolerance for rounding
    expected = per_second * duration
    drift = abs(reported_cost - expected)

    if not drift > max_drift:
        return CheckResult(
            checkpoint="COST",
            passed=True,
            details=f"Cost OK: ${reported_cost:.3f} ({per_second}/s * {duration}s)",
        )

    return CheckResult(
        checkpoint="COST",
        passed=False,
        details=f"Cost mismatch: reported ${reported_cost:.3f}, expected ${expected:.3f} ({per_second}/s * {duration}s)",
        data={"rate": per_second, "duration": duration, "expected": expected, "reported": reported_cost},
    )


# ══════════════════════════════════════════════════════════════════════
# Request Body Interceptor
# ══════════════════════════════════════════════════════════════════════

class RequestInterceptor:
    """Monkey-patches client submit methods to capture request bodies.

    ``patch_client`` replaces a client's ``_fal_request`` attribute with a
    wrapper that records every outgoing body in ``captured_requests`` and
    then forwards the call unchanged. Leaving the ``with`` block restores
    every patched attribute; exceptions raised inside the block propagate.
    """

    def __init__(self):
        self.captured_requests: list[dict] = []
        # One (object, attribute name, previous value) entry per applied patch.
        self._originals: list[tuple] = []

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Undo patches newest-first: when the same client was patched more
        # than once, restoring oldest-first would leave an earlier
        # interceptor installed. Popping also empties the list, so a second
        # exit cannot re-apply stale values.
        while self._originals:
            obj, attr, original = self._originals.pop()
            setattr(obj, attr, original)

    @staticmethod
    def _sanitize(value):
        """Return a log-friendly copy of ``value``.

        Strings starting with ``data:`` are replaced by a length summary,
        lists longer than three items by an item count; shorter lists and
        dicts are sanitized element by element. Other values pass through.
        """
        if isinstance(value, str):
            if value.startswith("data:"):
                return f"data:...({len(value)} chars)"
            return value
        if isinstance(value, list):
            if len(value) > 3:
                return f"[{len(value)} items]"
            return [RequestInterceptor._sanitize(item) for item in value]
        if isinstance(value, dict):
            return {k: RequestInterceptor._sanitize(v) for k, v in value.items()}
        return value

    def patch_client(self, client):
        """Patch a client's _fal_request to capture outgoing bodies.

        Clients without a ``_fal_request`` attribute are left untouched and
        a warning is logged.
        """
        if not hasattr(client, '_fal_request'):
            logger.warning("Client %s has no _fal_request — skipping interception", type(client).__name__)
            return

        original = client._fal_request
        self._originals.append((client, '_fal_request', original))
        captured = self.captured_requests
        sanitize = self._sanitize

        def intercepted(method, url, json_data=None, *args, **kwargs):
            # Calls without a body are forwarded without being recorded.
            if json_data is not None:
                captured.append({
                    "method": method,
                    "url": url,
                    "body_keys": sorted(json_data.keys()),
                    # Truncated copy that is safe to print or log.
                    "body_sanitized": {k: sanitize(v) for k, v in json_data.items()},
                    "body_raw": json_data,  # Full body for schema validation
                    "timestamp": time.time(),
                })
                logger.info(
                    "INTERCEPTED %s %s — keys: %s",
                    method, url.split("/requests/")[0] if "/requests/" in url else url,
                    sorted(json_data.keys()),
                )
            # Extra positional/keyword arguments are passed through untouched
            # so the wrapper never narrows the patched method's signature.
            return original(method, url, json_data, *args, **kwargs)

        client._fal_request = intercepted

    def get_submit_request(self) -> Optional[dict]:
        """Return the first captured request whose method is "POST", or None."""
        for req in self.captured_requests:
            if req["method"] == "POST":
                return req
        return None


# ══════════════════════════════════════════════════════════════════════
# Audit Report
# ══════════════════════════════════════════════════════════════════════

def write_audit_report(audits: list[TestAudit], output_dir: Path) -> Path:
    """Write the audit results as a Markdown report plus a JSON summary.

    Both files are created in ``output_dir`` (created if missing), named
    ``audit_report_<timestamp>.md`` and ``audit_results_<timestamp>.json``
    with a second-resolution local timestamp, and encoded as UTF-8.

    Args:
        audits: Audit records to report on.
        output_dir: Directory the two files are written to.

    Returns:
        Path of the Markdown report.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_path = output_dir / f"audit_report_{timestamp}.md"

    total = len(audits)
    passed = sum(1 for a in audits if a.passed)
    failed = total - passed
    total_cost = sum(a.cost_usd for a in audits)

    lines = [
        f"# Pipeline Audit Report — {timestamp}",
        "",
        f"**Total tests:** {total} | **Passed:** {passed} | **Failed:** {failed} | **Cost:** ${total_cost:.2f}",
        "",
    ]

    # Failures come first, each with the diagnostic data of its failed checks.
    if failed > 0:
        lines.append("## FAILURES")
        lines.append("")
        for a in audits:
            if not a.passed:
                lines.append(f"### {a.test_id} ({a.model}, {a.mode})")
                for c in a.checks:
                    if not c.passed:
                        lines.append(f"- **{c.checkpoint}:** {c.details}")
                        if c.data:
                            for k, v in c.data.items():
                                if k == "prompt":
                                    # Prompts can be long; show only the first 200 characters.
                                    lines.append(f"  - {k}: {str(v)[:200]}...")
                                elif k == "errors":
                                    for err in v:
                                        lines.append(f"  - {err}")
                                else:
                                    lines.append(f"  - {k}: {v}")
                lines.append("")

    lines.append("## All Results")
    lines.append("")
    lines.append("| Test | Model | Mode | PROMPT | REFS | PAYLOAD | REQUEST | STORE | COST | Result |")
    lines.append("|------|-------|------|--------|------|---------|---------|-------|------|--------|")

    for a in audits:
        check_map = {c.checkpoint: c for c in a.checks}
        cols = []
        for cp in ["PROMPT", "REFS", "PAYLOAD", "REQUEST", "STORE", "COST"]:
            c = check_map.get(cp)
            if c is None:
                # Checkpoint was not run for this test.
                cols.append("--")
            elif c.passed:
                cols.append("OK")
            else:
                cols.append("FAIL")
        result = "PASS" if a.passed else "FAIL"
        lines.append(f"| {a.test_id} | {a.model} | {a.mode} | {' | '.join(cols)} | {result} |")

    lines.append("")
    lines.append("## Detailed Output")
    lines.append("")
    for a in audits:
        lines.append(a.summary())
        lines.append("")

    # The report contains non-ASCII text (the em dash in the title, arbitrary
    # prompt text), so the encoding is pinned instead of left to the locale.
    report_path.write_text("\n".join(lines), encoding="utf-8")
    logger.info("Audit report written to %s", report_path)

    # Also write machine-readable JSON
    json_path = output_dir / f"audit_results_{timestamp}.json"
    json_data = {
        "timestamp": timestamp,
        "total": total,
        "passed": passed,
        "failed": failed,
        "total_cost": total_cost,
        "tests": [
            {
                "test_id": a.test_id,
                "model": a.model,
                "mode": a.mode,
                "content": a.content,
                "passed": a.passed,
                "fail_count": a.fail_count,
                "cost_usd": a.cost_usd,
                "wall_time_s": a.wall_time_s,
                "checks": [
                    {"checkpoint": c.checkpoint, "passed": c.passed, "details": c.details}
                    for c in a.checks
                ],
            }
            for a in audits
        ],
    }
    json_path.write_text(json.dumps(json_data, indent=2), encoding="utf-8")
    logger.info("Audit JSON written to %s", json_path)

    return report_path


# ══════════════════════════════════════════════════════════════════════
# Dry Run — Audit prompts + refs without API calls
# ══════════════════════════════════════════════════════════════════════

def run_dry_audit(
    plan_path: Path,
    models: Optional[list[str]] = None,
    shot_ids: Optional[list[str]] = None,
):
    """Audit prompt compilation and ref allocation without making API calls.

    Runs the PROMPT and REFS checkpoints for every shot in the plan against
    every mode of every selected model, then writes the report files to
    ``results/`` next to this script. No cost, no video generation.

    Args:
        plan_path: Path of the plan JSON; its ``shots``, ``bible_stub`` and
            ``project`` entries are read.
        models: Model IDs to audit; all four known models when omitted/empty.
        shot_ids: When given, only shots whose ``shot_id`` is listed are
            audited; requested IDs missing from the plan are logged.

    Returns:
        The list of ``TestAudit`` records, one per shot/model/mode.
    """
    plan = json.loads(plan_path.read_text(encoding="utf-8"))
    shots = plan.get("shots", [])
    bible = plan.get("bible_stub", {})
    config = {"film_stock": "Kodak Vision3 500T"}

    if not models:
        models = ["kling-v3", "veo-3.1", "wan-2.7-i2v", "wan-2.7-r2v"]

    if shot_ids:
        wanted = set(shot_ids)
        missing = sorted(wanted - {s.get("shot_id") for s in shots})
        if missing:
            logger.warning("Requested shot IDs not found in plan: %s", missing)
        shots = [s for s in shots if s.get("shot_id") in wanted]

    modes_per_model = {
        "kling-v3": ["i2v", "oner", "multishot", "t2v"],
        "veo-3.1": ["i2v", "oner", "t2v"],
        "wan-2.7-i2v": ["i2v", "between", "oner"],
        "wan-2.7-r2v": ["r2v", "r2v_scene"],
    }

    audits = []

    for shot in shots:
        shot_id = shot["shot_id"]
        # `or {}` also covers an explicit JSON null for these sections.
        routing = shot.get("routing_data") or {}
        assets = shot.get("asset_data") or {}
        num_chars = routing.get("num_characters", 1)
        has_props = bool(assets.get("props"))
        is_env = routing.get("is_env_only", False)

        for model in models:
            # Models not listed above are audited in "i2v" mode only.
            for mode in modes_per_model.get(model, ["i2v"]):
                test_id = f"{shot_id}__{model.replace('.', '_')}__{mode}"
                audit = TestAudit(
                    test_id=test_id,
                    model=model,
                    mode=mode,
                    content=plan.get("project", "unknown"),
                )

                # Checkpoint 1: PROMPT
                audit.checks.append(audit_prompt(shot, model, mode, bible, config))

                # Checkpoint 2: REFS
                audit.checks.append(audit_refs(model, mode, num_chars, has_props, is_env))

                audits.append(audit)

    # Report
    output_dir = SCRIPT_DIR / "results"
    write_audit_report(audits, output_dir)

    # Summary
    total = len(audits)
    passed = sum(1 for a in audits if a.passed)
    failed = total - passed

    logger.info("")
    logger.info("=" * 60)
    logger.info("DRY AUDIT COMPLETE")
    logger.info("=" * 60)
    logger.info("Tests: %d | Passed: %d | Failed: %d", total, passed, failed)

    if failed > 0:
        logger.error("FAILURES:")
        for a in audits:
            if not a.passed:
                logger.error("  %s: %s", a.test_id, [c.details for c in a.checks if not c.passed])

    return audits


# ══════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════

def _split_csv_arg(value: Optional[str]) -> Optional[list[str]]:
    """Split a comma-separated CLI value into stripped, non-empty items (None if none)."""
    if not value:
        return None
    items = [item.strip() for item in value.split(",") if item.strip()]
    return items or None


def main():
    """CLI entry point: resolve the plan, then run the dry audit if requested.

    With ``--dry-run`` the process exits 0 only when at least one test was
    audited and every test passed; otherwise it exits 1. Without it, only a
    notice is logged because the full audit is not wired up yet.
    """
    parser = argparse.ArgumentParser(description="Pipeline Integration Audit Harness")
    parser.add_argument("--plan", type=str, default="afterimage",
                        help="Plan source: 'afterimage' or 'tartarus' or path to plan JSON")
    parser.add_argument("--shots", type=str, default=None,
                        help="Comma-separated shot IDs to test (default: all)")
    parser.add_argument("--models", type=str, default=None,
                        help="Comma-separated model IDs (default: all)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Audit prompts + refs only, no API calls")
    args = parser.parse_args()

    # Resolve plan path
    if args.plan == "afterimage":
        plan_path = ProjectPaths.for_project("afterimage").plans_dir / "test_plan.json"
    elif args.plan == "tartarus":
        plan_path = ProjectPaths.for_project("tartarus").plans_dir / "ep_001_plan.json"
    else:
        plan_path = Path(args.plan)

    if not plan_path.exists():
        logger.error("Plan not found: %s", plan_path)
        sys.exit(1)

    models = _split_csv_arg(args.models)
    shot_ids = _split_csv_arg(args.shots)

    if args.dry_run:
        audits = run_dry_audit(plan_path, models, shot_ids)
        # An audit that checked nothing must not be reported as a success.
        if not audits:
            logger.error("No tests were audited — check the plan and the --shots/--models filters.")
            sys.exit(1)
        sys.exit(0 if all(a.passed for a in audits) else 1)

    # Full audit with API calls — TODO: wire up after build #2 completes
    # NOTE(review): this path returns normally (exit status 0) although no
    # checks ran — confirm whether callers should get a non-zero status here.
    logger.info("Full audit mode requires build #2 (prompt engine) to be complete.")
    logger.info("Use --dry-run to audit prompts and refs without API calls.")
    logger.info("Full audit will be enabled after prompt_engine integration lands.")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
