#!/usr/bin/python3
"""
Verify Pattern Variety

Checks hook/cliffhanger pattern compliance:
1. No 4+ consecutive same hook type
2. No 4+ consecutive same cliffhanger type (after pilot)
3. Overall distribution within 70-85% targets
4. Pattern violations flagged with specific episodes

Usage: python3 verify_pattern_variety.py <project_path>
Example: python3 verify_pattern_variety.py ./leviathan

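Returns:
- Exit code 0: No errors (warnings such as distribution drift or an at-limit streak may still be printed)
- Exit code 1: Errors found (consecutive-run violations), or invalid usage / missing project / unusable state file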
"""

import json
import sys
from pathlib import Path

# Add engine tools to path
_SCRIPT_DIR = Path(__file__).parent.resolve()
sys.path.insert(0, str(_SCRIPT_DIR))
sys.path.insert(0, str(_SCRIPT_DIR.parent.parent))  # CLAUDE_PROJECTS, for recoil.*
from recoil.core.paths import ProjectPaths

# Import from single source of truth
from engine_constants import (
    MAX_CONSECUTIVE_SAME_TYPE as MAX_CONSECUTIVE,
    PILOT_EPISODE_COUNT as PILOT_EPISODES,
    HOOK_SILENT_VALIDATION_MIN,
    HOOK_SILENT_VALIDATION_MAX,
    CLIFFHANGER_MIDACTION_VALIDATION_MIN,
    CLIFFHANGER_MIDACTION_VALIDATION_MAX,
)

# Pattern distribution validation ranges, each a (min_percent, max_percent)
# pair, sourced from engine_constants (documented in CONSTANTS.md).
TARGET_SILENT_PERCENT = (
    HOOK_SILENT_VALIDATION_MIN,
    HOOK_SILENT_VALIDATION_MAX,
)
TARGET_MIDACTION_PERCENT = (
    CLIFFHANGER_MIDACTION_VALIDATION_MIN,
    CLIFFHANGER_MIDACTION_VALIDATION_MAX,
)


def load_orchestrator_state(project_path: Path) -> dict | None:
    """Load ``orchestrator_state.json`` from the project's state directory.

    Args:
        project_path: Project root; the state directory is resolved through
            ``ProjectPaths.from_root(project_path).state_dir``.

    Returns:
        The parsed state as a dict, or ``None`` when the file does not exist,
        cannot be read or decoded, is not valid JSON, or does not contain a
        JSON object at the top level. Every case except a missing file prints
        a diagnostic to stderr.
    """
    state_file = ProjectPaths.from_root(project_path).state_dir / "orchestrator_state.json"
    if not state_file.exists():
        return None

    try:
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(state_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON in orchestrator_state.json: {e}", file=sys.stderr)
        return None
    except (OSError, UnicodeDecodeError) as e:
        # Permission problems, a file removed after the exists() check, or
        # non-UTF-8 bytes previously escaped as an uncaught traceback.
        print(f"Error: Could not read orchestrator_state.json: {e}", file=sys.stderr)
        return None

    # Callers use dict methods (.get) on the result, so reject other JSON roots.
    if not isinstance(data, dict):
        print("Error: orchestrator_state.json must contain a JSON object.", file=sys.stderr)
        return None
    return data


def find_consecutive_runs(history: list, threshold: int = MAX_CONSECUTIVE) -> list:
    """
    Scan ``history`` for maximal runs of equal adjacent values.

    Only runs strictly longer than ``threshold`` are reported. Each result is a
    ``(value, start_index, length)`` tuple, where ``start_index`` is the 0-based
    position of the run's first element; results appear in input order.
    A falsy ``history`` (empty list or None) yields an empty list.
    """
    if not history:
        return []

    found = []
    size = len(history)
    run_start = 0
    while run_start < size:
        anchor = history[run_start]
        run_end = run_start + 1
        # Extend the run while the next element equals the run's first value.
        while run_end < size and history[run_end] == anchor:
            run_end += 1
        run_length = run_end - run_start
        if run_length > threshold:
            found.append((anchor, run_start, run_length))
        run_start = run_end
    return found


def _distribution_issue(category, value, count, total, target, label, low_advice, high_advice):
    """Build a distribution warning when ``count / total`` lies outside ``target``.

    ``target`` is a ``(min_percent, max_percent)`` pair. Returns ``None`` when
    there is no data (``total <= 0``) or the percentage is within range.
    """
    if total <= 0:
        return None
    percent = round(count / total * 100, 1)
    low, high = target
    if percent < low:
        issue_type = "distribution_low"
        recommendation = f"{label} at {percent}%, below target {low}%. {low_advice}"
    elif percent > high:
        issue_type = "distribution_high"
        recommendation = f"{label} at {percent}%, above target {high}%. {high_advice}"
    else:
        return None
    return {
        "type": issue_type,
        "category": category,
        "value": value,
        "percent": percent,
        "target": f"{low}-{high}%",
        "severity": "warning",
        "recommendation": recommendation,
    }


def _run_issues(category, history, first_episode, noun, limit_suffix=""):
    """Report every run in ``history`` longer than ``MAX_CONSECUTIVE`` as an error.

    ``first_episode`` is the 1-based episode number of ``history[0]``; it maps
    run indices onto inclusive episode ranges.
    """
    issues = []
    for pattern, start_idx, length in find_consecutive_runs(history, MAX_CONSECUTIVE):
        start_ep = first_episode + start_idx
        end_ep = start_ep + length - 1
        issues.append({
            "type": "consecutive_violation",
            "category": category,
            "pattern": pattern,
            "consecutive": length,
            "episodes": f"{start_ep}-{end_ep}",
            "severity": "error",
            "recommendation": (
                f"{length} consecutive {pattern} {noun} (Episodes {start_ep}-{end_ep}). "
                f"Max allowed is {MAX_CONSECUTIVE}{limit_suffix}."
            ),
        })
    return issues


def _at_limit_issue(category, pattern, count, recommendation):
    """Warn when the current streak ``count`` sits exactly at ``MAX_CONSECUTIVE``.

    Returns ``None`` otherwise; streaks already past the limit are reported as
    errors by the run check instead.
    """
    if count != MAX_CONSECUTIVE:
        return None
    return {
        "type": "at_limit",
        "category": category,
        "pattern": pattern,
        "consecutive": count,
        "severity": "warning",
        "recommendation": recommendation,
    }


def verify_patterns(state: dict) -> list:
    """
    Verify pattern variety.

    Args:
        state: Orchestrator state. Reads ``position.last_completed_episode``
            and the ``pattern_state.hooks`` / ``pattern_state.cliffhangers``
            sections (counts, ``history`` lists and current streak counters).
            Missing or null sections are treated as empty.

    Returns:
        A list of issue dicts in check order. Each has ``type``, ``category``,
        ``severity`` ("error" or "warning") and ``recommendation``, plus
        type-specific keys. Only consecutive-run violations are errors.
    """
    pattern_state = state.get("pattern_state") or {}
    position = state.get("position") or {}
    current_episode = position.get("last_completed_episode", 0)

    hooks = pattern_state.get("hooks") or {}
    cliffs = pattern_state.get("cliffhangers") or {}

    candidates = []

    # Check 1: hook distribution (share of silent hooks).
    silent_count = hooks.get("silent_count", 0)
    candidates.append(_distribution_issue(
        "hooks", "silent",
        silent_count, silent_count + hooks.get("dialogue_count", 0),
        TARGET_SILENT_PERCENT, "Silent hooks",
        "Use more silent hooks.",
        "Use more dialogue hooks for variety.",
    ))

    # Check 2: consecutive hook runs over the full history (index 0 = episode 1).
    candidates.extend(_run_issues("hooks", hooks.get("history") or [], 1, "hooks"))

    # Current hook streaks sitting exactly at the limit.
    candidates.append(_at_limit_issue(
        "hooks", "silent", hooks.get("consecutive_silent", 0),
        f"{MAX_CONSECUTIVE} consecutive silent hooks. Next hook MUST be dialogue to avoid violation.",
    ))
    candidates.append(_at_limit_issue(
        "hooks", "dialogue", hooks.get("consecutive_dialogue", 0),
        f"{MAX_CONSECUTIVE} consecutive dialogue hooks. Next hook MUST be silent to avoid violation.",
    ))

    # Check 3: cliffhanger distribution. This uses the aggregate counts stored
    # in state as-is; only the run and at-limit checks are post-pilot.
    midaction_count = cliffs.get("mid_action_count", 0)
    candidates.append(_distribution_issue(
        "cliffhangers", "mid-action",
        midaction_count, midaction_count + cliffs.get("aftermath_count", 0),
        TARGET_MIDACTION_PERCENT, "Mid-action cliffhangers",
        "Use more mid-action endings.",
        "Use more aftermath endings for variety.",
    ))

    # Check 4: consecutive cliffhanger runs, only from episode PILOT_EPISODES + 1
    # onward. `or []` keeps a JSON-null history from crashing the slice.
    post_pilot_history = (cliffs.get("history") or [])[PILOT_EPISODES:]
    candidates.extend(_run_issues(
        "cliffhangers", post_pilot_history, PILOT_EPISODES + 1, "cliffhangers", " after pilot",
    ))

    # Current cliffhanger streaks: only warn once past the pilot.
    if current_episode > PILOT_EPISODES:
        candidates.append(_at_limit_issue(
            "cliffhangers", "mid-action", cliffs.get("consecutive_mid_action", 0),
            f"{MAX_CONSECUTIVE} consecutive mid-action cliffhangers. Next MUST be aftermath to avoid violation.",
        ))
        candidates.append(_at_limit_issue(
            "cliffhangers", "aftermath", cliffs.get("consecutive_aftermath", 0),
            f"{MAX_CONSECUTIVE} consecutive aftermath cliffhangers. Next MUST be mid-action to avoid violation.",
        ))

    return [issue for issue in candidates if issue is not None]


def _split_percent(part, total):
    """Return ``(part_pct, rest_pct)`` for ``part`` out of ``total`` (total > 0).

    Both values are percentages rounded to one decimal place. The complement
    is rounded too: in binary floating point ``100 - 71.4`` evaluates to
    ``28.599999999999994``, which would otherwise be printed verbatim.
    """
    part_pct = round(part / total * 100, 1)
    return part_pct, round(100 - part_pct, 1)


def main():
    """Command-line entry point: report hook/cliffhanger pattern health.

    Reads ``<project_path>`` from ``sys.argv[1]``, prints the current
    distribution and consecutive counts from ``orchestrator_state.json``,
    then lists every issue returned by ``verify_patterns``.

    Exit status (via ``sys.exit``):
        1 -- missing argument, nonexistent project path, missing/unusable
             state file, or at least one issue with severity "error".
        0 -- otherwise; warnings alone do not fail the check.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 verify_pattern_variety.py <project_path>", file=sys.stderr)
        print("Example: python3 verify_pattern_variety.py ./leviathan", file=sys.stderr)
        sys.exit(1)

    project_path = Path(sys.argv[1]).resolve()

    if not project_path.exists():
        print(f"Error: Project path does not exist: {project_path}", file=sys.stderr)
        sys.exit(1)

    state = load_orchestrator_state(project_path)
    if not state:
        # load_orchestrator_state returns None both for a missing file and for
        # one it could not parse, so the message must cover both cases.
        print(
            "Error: No usable orchestrator_state.json found. "
            "Run init_orchestrator_state.py first.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Treat explicit JSON nulls the same as absent sections.
    position = state.get("position") or {}
    current_ep = position.get("last_completed_episode", 0)
    pattern_state = state.get("pattern_state") or {}
    hooks = pattern_state.get("hooks") or {}
    cliffs = pattern_state.get("cliffhangers") or {}

    heavy_rule = "=" * 60
    light_rule = "─" * 60

    print(f"\n{heavy_rule}")
    print("PATTERN VARIETY CHECK")
    print(heavy_rule)
    print(f"\nProject: {project_path.name}")
    print(f"Current position: Episode {current_ep}")

    # Current distribution
    print(f"\n{light_rule}")
    print("DISTRIBUTION:")
    print(light_rule)

    silent_count = hooks.get("silent_count", 0)
    dialogue_count = hooks.get("dialogue_count", 0)
    total_hooks = silent_count + dialogue_count
    if total_hooks > 0:
        silent_pct, dialogue_pct = _split_percent(silent_count, total_hooks)
        print(f"  Hooks:        {silent_pct}% silent / {dialogue_pct}% dialogue")
        print(f"                ({silent_count} silent, {dialogue_count} dialogue)")
        print(f"                Target: {TARGET_SILENT_PERCENT[0]}-{TARGET_SILENT_PERCENT[1]}% silent")

    midaction_count = cliffs.get("mid_action_count", 0)
    aftermath_count = cliffs.get("aftermath_count", 0)
    total_cliffs = midaction_count + aftermath_count
    if total_cliffs > 0:
        midaction_pct, aftermath_pct = _split_percent(midaction_count, total_cliffs)
        print(f"\n  Cliffhangers: {midaction_pct}% mid-action / {aftermath_pct}% aftermath")
        print(f"                ({midaction_count} mid-action, {aftermath_count} aftermath)")
        print(f"                Target: {TARGET_MIDACTION_PERCENT[0]}-{TARGET_MIDACTION_PERCENT[1]}% mid-action")

    # Current consecutive streaks
    print(f"\n{light_rule}")
    print("CONSECUTIVE:")
    print(light_rule)
    print(f"  Hooks:        {hooks.get('consecutive_silent', 0)} silent, {hooks.get('consecutive_dialogue', 0)} dialogue")
    print(f"  Cliffhangers: {cliffs.get('consecutive_mid_action', 0)} mid-action, {cliffs.get('consecutive_aftermath', 0)} aftermath")
    print(f"  (Max allowed: {MAX_CONSECUTIVE})")

    issues = verify_patterns(state)

    if not issues:
        print(f"\n{light_rule}")
        print("✓ All patterns healthy")
        print(f"{heavy_rule}\n")
        sys.exit(0)

    # List errors before warnings.
    errors = [i for i in issues if i["severity"] == "error"]
    warnings = [i for i in issues if i["severity"] == "warning"]

    print(f"\n{light_rule}")
    print(f"ISSUES FOUND: {len(errors)} errors, {len(warnings)} warnings")
    print(light_rule)

    for issue in errors + warnings:
        severity_icon = "✗" if issue["severity"] == "error" else "⚠"
        print(f"\n  {severity_icon} [{issue['type'].upper()}]")
        print(f"    {issue['recommendation']}")

    print(f"\n{heavy_rule}")

    # Only errors fail the run; warnings are advisory.
    sys.exit(1 if errors else 0)


if __name__ == "__main__":
    main()
