#!/usr/bin/env python3
"""
Validation Dispatcher

Thin router that detects the format from project config and dispatches
to the format-specific validator.

Usage:
    python3 tools/validate_episode.py <episode_path> [project_path] [--json]

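If no project_path is provided, or the project's project_config.json is
missing or is not valid JSON, the kill_box format is used. The validator is
imported dynamically from formats/{format_name}/validate.py, resolved relative
to the parent of this script's directory.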

Exit codes:
    0 = Valid
    1 = Invalid (errors found)
    2 = File/config error
"""

import sys
import json
import importlib.util
import argparse
from pathlib import Path


# =============================================================================
# FORMAT DETECTION
# =============================================================================

def detect_format(project_path: Path = None) -> str:
    """
    Detect the format from project_config.json.

    Reads ``<project_path>/project_config.json`` and returns its ``"format"``
    value (e.g. 'kill_box', 'puzzle_box', 'kill_box_micro').

    Args:
        project_path: Project directory (``Path`` or ``str``), or None.

    Returns:
        The configured format name. Falls back to 'kill_box' when
        project_path is None, the config file is missing or unreadable, its
        contents are not valid UTF-8 JSON, its top level is not a JSON object,
        or its "format" value is missing or not a non-empty string.
    """
    default = 'kill_box'
    if project_path is None:
        return default

    config_file = Path(project_path) / 'project_config.json'
    try:
        config = json.loads(config_file.read_text(encoding='utf-8'))
    except (OSError, ValueError):
        # OSError: missing/unreadable file (also covers a vanished file, so no
        # separate exists() check is needed). ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError from read_text().
        return default

    # A JSON array/scalar has no .get(); treat it like any other bad config.
    if not isinstance(config, dict):
        return default

    format_name = config.get('format', default)
    # An empty or non-string name would later be joined into a filesystem
    # path by the loader, so reject it here.
    if not isinstance(format_name, str) or not format_name:
        return default
    return format_name


# =============================================================================
# DYNAMIC IMPORT
# =============================================================================

def load_validator(format_name: str):
    """
    Dynamically import the validator module from formats/{format_name}/validate.py.

    The ``formats`` directory is resolved relative to the parent of this
    script's directory.

    Args:
        format_name: Name of a format directory under ``formats/``. It must be
            a single plain path component: not empty, not '.' or '..', and
            with no path separators or drive prefix.

    Returns:
        The executed validator module, which defines callable
        ``validate_episode`` and ``validate_batch`` attributes.

    Raises:
        ImportError: if the name is invalid, the validator file is missing,
            executing the module raises, or a required function is missing.
    """
    # format_name typically comes from a project config file. Joining an
    # absolute path with '/' discards the left-hand side, and '..' walks out
    # of formats/. Either would execute an arbitrary validate.py, so only
    # single, plain directory names are accepted.
    if (not isinstance(format_name, str)
            or format_name in ('', '.', '..')
            or '/' in format_name or '\\' in format_name
            or Path(format_name).name != format_name):
        raise ImportError(f"Invalid format name: {format_name!r}")

    # Resolve formats directory relative to this script
    engine_root = Path(__file__).parent.parent
    validator_path = engine_root / 'formats' / format_name / 'validate.py'

    if not validator_path.is_file():
        raise ImportError(
            f"No validator found for format '{format_name}' at {validator_path}"
        )

    module_name = f"formats.{format_name}.validate"
    spec = importlib.util.spec_from_file_location(module_name, validator_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {validator_path}")
    module = importlib.util.module_from_spec(spec)

    # Register before executing, per the importlib "import a source file
    # directly" recipe, so that code inside the validator which looks its
    # own module up in sys.modules works.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        sys.modules.pop(module_name, None)
        # Re-raise as ImportError so the dispatchers report a broken validator
        # (e.g. a SyntaxError) instead of crashing with a traceback.
        raise ImportError(
            f"Failed to load validator for '{format_name}' "
            f"from {validator_path}: {e}"
        ) from e

    # Verify the module exports the required functions
    missing = [name for name in ('validate_episode', 'validate_batch')
               if not callable(getattr(module, name, None))]
    if missing:
        sys.modules.pop(module_name, None)
        raise ImportError(
            f"Validator for '{format_name}' missing {missing[0]}() function"
        )

    return module


# =============================================================================
# DISPATCH
# =============================================================================

def dispatch_validate_episode(episode_path: str, project_path: str = None,
                              constants: dict = None) -> dict:
    """
    Route one episode to the validate_episode() of the detected format.

    The format name comes from detect_format(), which receives
    ``Path(project_path)`` when project_path is truthy and None otherwise.
    The matching validator module is then obtained from load_validator().

    Args:
        episode_path: Episode file path, forwarded unchanged to the validator.
        project_path: Optional project directory used for format detection.
        constants: Forwarded to the validator as the ``constants`` keyword.

    Returns:
        The validator's own result dict, with a ``'format'`` key set on it.
        If load_validator() raises ImportError, a failure dict is returned
        instead: {valid: False, errors: [message], warnings: [], metrics: {},
        format: str}.
    """
    project_dir = None if not project_path else Path(project_path)
    format_name = detect_format(project_dir)

    try:
        validator = load_validator(format_name)
    except ImportError as load_error:
        return dict(
            valid=False,
            errors=[str(load_error)],
            warnings=[],
            metrics={},
            format=format_name,
        )
    else:
        outcome = validator.validate_episode(episode_path, constants=constants)

    outcome['format'] = format_name
    return outcome


def dispatch_validate_batch(episode_paths: list, project_path: str = None,
                            constants: dict = None) -> dict:
    """
    Route a batch of episodes to the validate_batch() of the detected format.

    The format name comes from detect_format(), which receives
    ``Path(project_path)`` when project_path is truthy and None otherwise.
    The matching validator module is then obtained from load_validator().

    Args:
        episode_paths: Episode file paths, forwarded unchanged.
        project_path: Optional project directory used for format detection.
        constants: Forwarded to the validator as the ``constants`` keyword.

    Returns:
        The validator's own batch result dict, with a ``'format'`` key set on
        it. If load_validator() raises ImportError, a failure dict is returned
        instead: {valid: False, episode_results: [], errors: [message],
        format: str}.
    """
    project_dir = None if not project_path else Path(project_path)
    format_name = detect_format(project_dir)

    try:
        validator = load_validator(format_name)
    except ImportError as load_error:
        return dict(
            valid=False,
            episode_results=[],
            errors=[str(load_error)],
            format=format_name,
        )
    else:
        batch_outcome = validator.validate_batch(episode_paths,
                                                 constants=constants)

    batch_outcome['format'] = format_name
    return batch_outcome


# =============================================================================
# CLI
# =============================================================================

def main(argv=None):
    """
    Command-line entry point: validate a single episode and exit.

    Args:
        argv: Optional list of command-line arguments. When None (the
            default), argparse reads sys.argv[1:], so ``main()`` behaves
            exactly like a script invocation.

    Exit codes:
        0 if the validator reports the episode valid, 1 if it reports it
        invalid, 2 if the episode path is not an existing file or an explicit
        project path is not an existing directory.
    """
    parser = argparse.ArgumentParser(
        description='Validate an episode against its format-specific rules'
    )
    parser.add_argument('episode', type=Path, help='Path to episode file')
    parser.add_argument('project', type=Path, nargs='?', default=None,
                        help='Path to project directory (for format detection)')
    parser.add_argument('--json', action='store_true',
                        help='Output results as JSON')

    args = parser.parse_args(argv)

    if not args.episode.exists():
        print(f"ERROR: File not found: {args.episode}", file=sys.stderr)
        sys.exit(2)
    if not args.episode.is_file():
        # A directory (or other non-file) would presumably only fail later
        # inside the format validator; report it as a file error up front.
        print(f"ERROR: Not a file: {args.episode}", file=sys.stderr)
        sys.exit(2)

    # An explicit project path that is not a directory would otherwise make
    # detect_format() silently fall back to the default format; per the
    # documented exit codes, that is a file/config error.
    if args.project is not None and not args.project.is_dir():
        print(f"ERROR: Project directory not found: {args.project}",
              file=sys.stderr)
        sys.exit(2)

    project_path = str(args.project) if args.project is not None else None
    result = dispatch_validate_episode(str(args.episode), project_path=project_path)

    if args.json:
        print(json.dumps(result, indent=2))
    else:
        status = "VALID" if result['valid'] else "INVALID"
        format_name = result.get('format', 'unknown')
        print(f"\n{'=' * 60}")
        print(f"Episode: {args.episode.name}")
        print(f"Format:  {format_name}")
        print(f"Status:  {status}")
        print(f"{'=' * 60}")

        if result.get('errors'):
            print("\nERRORS:")
            for err in result['errors']:
                print(f"  - {err}")

        if result.get('warnings'):
            print("\nWARNINGS:")
            for warn in result['warnings']:
                print(f"  - {warn}")

        if result.get('metrics'):
            print("\nMETRICS:")
            for key, val in result['metrics'].items():
                print(f"  {key}: {val}")

    sys.exit(0 if result['valid'] else 1)


# Run the CLI only when this file is executed as a script; importing it
# (e.g. to call the dispatch_* functions) has no side effects.
if __name__ == '__main__':
    main()
