"""
batch_critic.py — IP2: Batch Boundary Critic.

Diagnostic-only critic that runs at checkpoint batches (3, 6, 9, 12).
Uses Gemini Flash to evaluate 6 narrative quality dimensions across
the completed batch of episodes.

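All dimensions are SOFT (advisory). Failures produce course-correction
instructions meant to be injected into the NEXT batch — they never trigger
regeneration. If the model call or its response parsing fails, every
dimension is reported as passed (the critic fails open).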

Invocation:
    python3 tools/batch_critic.py ./[project] --batch N

Output:
    projects/{project}/state/batch_critic_batch_NN.json
"""

import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# Cross-engine import path setup, run before the recoil.core imports below.
# NOTE(review): this prepends <recoil root>/pipeline to sys.path, but the
# imports that follow are `recoil.core.critic` / `recoil.core.paths`, which
# resolve only if the directory *containing* the `recoil` package is
# importable (installed package, PYTHONPATH, ...). Confirm whether the
# pipeline directory is still needed here.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Assumes this file sits one directory below the recoil root (e.g. tools/).
_RECOIL_ROOT = _SCRIPT_DIR.parent
_PIPELINE_ROOT = _RECOIL_ROOT / "pipeline"

# Prepend so this directory takes precedence over other sys.path entries; the
# membership check avoids inserting a duplicate entry.
if str(_PIPELINE_ROOT) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_ROOT))

from recoil.core.critic import CriticLoop, CriticResult, Dimension, Severity
from recoil.core.paths import ProjectPaths

# NOTE(review): configuring the root logger at import time also affects any
# program that merely imports this module (e.g. a caller of
# run_batch_critic); consider moving this under the __main__ guard.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

# Checkpoint batches where this critic runs; run_batch_critic() returns a
# "skipped" result for any other batch number.
CHECKPOINT_BATCHES = {3, 6, 9, 12}

# The 6 narrative dimensions. The model is asked for one entry per name, and
# the critic always emits exactly one SOFT Dimension per name, in this order.
DIMENSIONS = [
    "VOICE",           # Character voice consistency
    "PATTERN_FATIGUE", # Repetitive structural patterns
    "ARC_EARNING",     # Emotional beats feel earned, not rushed
    "CONTINUITY",      # Cross-episode factual consistency
    "TEXTURE_TONE",    # Prose texture and tonal variety
    "EXPOSITION",      # Show-don't-tell ratio
]

_SYSTEM_PROMPT = """You are a senior showrunner reviewing a batch of 5 episodes from a vertical microdrama series.

Evaluate the batch against these 6 dimensions. For each, return:
- "pass" or "fail"
- A 1-2 sentence explanation

DIMENSIONS:
1. VOICE — Does each character sound distinct and consistent? Are idioms used at appropriate frequency?
2. PATTERN_FATIGUE — Are episode structures varied? (Hook types, cliffhanger types, scene patterns.) Flag if 3+ consecutive episodes follow the same structure.
3. ARC_EARNING — Do emotional beats feel earned by prior setup? Or do characters arrive at emotions without build-up?
4. CONTINUITY — Are locations, character states, props, and timeline consistent across episodes? Flag any contradictions.
5. TEXTURE_TONE — Does the prose have variety in rhythm, sentence length, and diction? Or does it flatten into a single register?
6. EXPOSITION — Is backstory woven into action, or dumped in dialogue/narration? Flag any "As you know, Bob" moments.

OUTPUT FORMAT (strict JSON):
{
  "dimensions": {
    "VOICE": {"passed": true, "message": "..."},
    "PATTERN_FATIGUE": {"passed": false, "message": "Episodes 11-13 all use silent hook + mid-action cliffhanger."},
    "ARC_EARNING": {"passed": true, "message": "..."},
    "CONTINUITY": {"passed": false, "message": "Marcus has a bandage in EP12 but it's never mentioned in EP13."},
    "TEXTURE_TONE": {"passed": true, "message": "..."},
    "EXPOSITION": {"passed": true, "message": "..."}
  },
  "course_corrections": [
    "Next batch: vary hook type — use dialogue hook for at least 1 of 5 episodes.",
    "Next batch: reference Marcus's injury in continuity notes."
  ]
}"""


class BatchBoundaryCritic(CriticLoop):
    """IP2: Batch-level narrative quality critic via Gemini Flash.

    All dimensions are SOFT. This critic diagnoses but never blocks: if the
    model call fails or its response cannot be parsed, every dimension is
    reported as passed with the failure reason in its message.

    After ``evaluate`` runs, ``course_corrections`` holds the model's
    suggestions for the next batch (empty when the call or parse failed).
    """

    # Maximum number of treatment characters prepended to the prompt as
    # reference context.
    TREATMENT_CHAR_LIMIT = 3000

    # Lower-cased string values the model might use instead of JSON `false`.
    _FAIL_WORDS = frozenset({"fail", "failed", "false", "no"})

    def __init__(
        self,
        episodes_text: str,
        treatment_text: str = "",
        batch_num: int = 0,
        experience_pool_dir: Path | None = None,
    ):
        """Create a single-attempt, diagnostic-only critic for one batch.

        Args:
            episodes_text: Concatenated episode text for the batch. Stored on
                the instance; ``evaluate`` reads the ``artifact`` it is given.
            treatment_text: Optional series treatment used as prompt context.
            batch_num: Batch number, used in the prompt header and shot id.
            experience_pool_dir: Forwarded unchanged to ``CriticLoop``.
        """
        super().__init__(
            name="batch_boundary",
            max_attempts=1,  # Diagnostic only
            experience_pool_dir=experience_pool_dir,
            shot_id=f"batch_{batch_num:02d}",
        )
        self.episodes_text = episodes_text
        self.treatment_text = treatment_text
        self.batch_num = batch_num
        self._course_corrections: list[str] = []

    def evaluate(self, artifact: str, context: dict) -> list[Dimension]:
        """Call Gemini Flash to evaluate batch quality.

        Args:
            artifact: Concatenated episode text (the batch).
            context: Unused.

        Returns:
            One SOFT Dimension per entry in DIMENSIONS. On any API failure
            all of them are passed, with the error in their message.
        """
        # Reset first so a failed call never reports corrections left over
        # from an earlier evaluation on this instance.
        self._course_corrections = []

        user_prompt = f"# BATCH {self.batch_num} EPISODES\n\n{artifact}"
        if self.treatment_text:
            treatment = self.treatment_text[: self.TREATMENT_CHAR_LIMIT]
            user_prompt = f"# SERIES TREATMENT (for reference)\n\n{treatment}\n\n{user_prompt}"

        try:
            flash_result = self._call_flash(user_prompt)
        except Exception as e:  # graceful degradation: the critic is advisory
            logger.error("Flash call failed: %s", e)
            return self._all_pass(f"Critic unavailable: {e}")

        if not flash_result.get("success"):
            return self._all_pass(
                f"Flash call failed: {flash_result.get('error', 'unknown')}"
            )

        return self._parse_response(flash_result["text"])

    def auto_fix(self, artifact, failed_dims, context):
        """No auto-fix for batch critic — diagnostic only."""
        return artifact

    @staticmethod
    def _all_pass(message: str) -> list[Dimension]:
        """Return a passing SOFT result for every dimension, all with *message*."""
        return [
            Dimension(name=d, severity=Severity.SOFT, passed=True, message=message)
            for d in DIMENSIONS
        ]

    def _call_flash(self, user_prompt: str) -> dict:
        """Call Gemini Flash for batch evaluation.

        Returns:
            ``{"success": True, "text": ..., "cost": ...}`` when the response
            contains text, otherwise ``{"success": False, "error": ...}``.
            Exceptions raised by the SDK call itself propagate to the caller.
        """
        try:
            from google import genai
            from google.genai import types as genai_types
        except ImportError:
            return {"success": False, "error": "google-genai not installed"}

        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            return {"success": False, "error": "GEMINI_API_KEY not set"}

        client = genai.Client(api_key=api_key)
        config = genai_types.GenerateContentConfig(
            temperature=0.3,
            system_instruction=_SYSTEM_PROMPT,
            response_modalities=["TEXT"],
            response_mime_type="application/json",
        )

        # Prefer the project's configured Flash model; fall back to a fixed
        # model name if the profile lookup is unavailable or fails.
        try:
            from recoil.core.model_profiles import get_model
            flash_model = get_model("flash", "text")
        except Exception:
            flash_model = "gemini-2.5-flash"

        response = client.models.generate_content(
            model=flash_model,
            contents=user_prompt,
            config=config,
        )

        # Concatenate the text parts of every candidate.
        chunks: list[str] = []
        for candidate in (getattr(response, "candidates", None) or []):
            content = candidate.content
            if content and content.parts:
                chunks.extend(
                    part.text for part in content.parts if getattr(part, "text", None)
                )
        text = "".join(chunks).strip()

        if text:
            # Fixed nominal cost; not derived from the response's token usage.
            return {"success": True, "text": text, "cost": 0.001}
        return {"success": False, "error": "No text in Flash response"}

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding ``` / ```json Markdown fence, if present."""
        stripped = text.strip()
        if stripped.startswith("```"):
            stripped = stripped[3:]
            if stripped.lower().startswith("json"):
                stripped = stripped[4:]
            if stripped.endswith("```"):
                stripped = stripped[:-3]
        return stripped.strip()

    @classmethod
    def _coerce_passed(cls, value) -> bool:
        """Interpret a model-supplied pass/fail value as a bool.

        JSON booleans are used as-is. Strings such as "fail"/"false" map to
        False; a plain ``bool("fail")`` would be True. Numbers are truthy
        unless zero. Missing or unrecognised values default to True, matching
        the critic's fail-open stance.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() not in cls._FAIL_WORDS
        if isinstance(value, (int, float)):
            return value != 0
        return True

    @staticmethod
    def _normalize_corrections(raw) -> list[str]:
        """Coerce the model's ``course_corrections`` value to a list of strings.

        A bare string becomes a one-item list. Non-list values yield ``[]``.
        ``None`` and blank items are dropped.
        """
        if isinstance(raw, str):
            raw = [raw]
        if not isinstance(raw, list):
            return []
        cleaned = (str(item).strip() for item in raw if item is not None)
        return [item for item in cleaned if item]

    def _parse_response(self, text: str) -> list[Dimension]:
        """Parse the Flash JSON response into one SOFT Dimension per name.

        Tolerates a Markdown code fence around the JSON. Missing or malformed
        per-dimension entries are reported as passed. String ``passed`` values
        such as ``"fail"`` are coerced. If the payload is not a JSON object,
        every dimension is reported as passed with a parse-error message.
        Also stores the response's course corrections on the instance.
        """
        try:
            data = json.loads(self._strip_code_fence(text))
            if not isinstance(data, dict):
                raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            logger.warning("Could not parse Flash response: %s", e)
            self._course_corrections = []
            return self._all_pass(f"Parse error: {e}")

        dim_data = data.get("dimensions")
        if not isinstance(dim_data, dict):
            dim_data = {}

        dims = []
        for dim_name in DIMENSIONS:
            entry = dim_data.get(dim_name)
            if isinstance(entry, dict):
                passed = self._coerce_passed(entry.get("passed", True))
                message = entry.get("message", "")
                message = "" if message is None else str(message)
            else:
                # Bare value (e.g. "fail" or false) or no entry at all.
                passed = self._coerce_passed(entry)
                message = ""
            dims.append(Dimension(
                name=dim_name,
                severity=Severity.SOFT,
                passed=passed,
                message=message,
            ))

        # Stored for the caller to pass on as next-batch instructions.
        self._course_corrections = self._normalize_corrections(
            data.get("course_corrections")
        )
        return dims

    @property
    def course_corrections(self) -> list[str]:
        """Course correction strings for the next batch's orchestrator instructions."""
        return getattr(self, "_course_corrections", [])


def _load_batch_episodes(
    episodes_dir: Path, start_ep: int, end_ep: int
) -> tuple[list[str], list[int]]:
    """Read ep_NNN.md files for episodes start_ep..end_ep (inclusive).

    Returns the texts of the files that exist, in episode order, plus the
    numbers of the episodes whose files were not found.
    """
    texts: list[str] = []
    missing: list[int] = []
    for ep_num in range(start_ep, end_ep + 1):
        ep_path = episodes_dir / f"ep_{ep_num:03d}.md"
        if ep_path.exists():
            texts.append(ep_path.read_text(encoding="utf-8"))
        else:
            logger.warning("Episode %d not found: %s", ep_num, ep_path)
            missing.append(ep_num)
    return texts, missing


def run_batch_critic(project_dir: Path, batch_num: int) -> dict:
    """Run the batch boundary critic for a completed batch.

    Only batches in CHECKPOINT_BATCHES are evaluated. Any other batch number
    returns a "skipped" result without reading files or calling the model.

    Args:
        project_dir: Path to project directory (e.g. projects/leviathan).
        batch_num: Batch number (1-12).

    Returns:
        - ``{"skipped": True, "reason": ...}`` for non-checkpoint batches.
        - ``{"error": ...}`` if none of the batch's episode files exist.
        - Otherwise a dict with ``critic_result``, ``course_corrections``,
          ``output_path`` and ``missing_episodes``. The last lists episode
          numbers in the batch range that were not found; the critic then ran
          on the episodes that do exist.
    """
    if batch_num not in CHECKPOINT_BATCHES:
        return {
            "skipped": True,
            "reason": f"Batch {batch_num} is not a checkpoint batch ({CHECKPOINT_BATCHES})",
        }

    # Batches are 5 episodes wide: batch N covers episodes 5N-4 .. 5N.
    start_ep = (batch_num - 1) * 5 + 1
    end_ep = batch_num * 5

    episode_texts, missing = _load_batch_episodes(
        project_dir / "episodes", start_ep, end_ep
    )
    if not episode_texts:
        return {"error": f"No episodes found for batch {batch_num}"}

    batch_text = "\n\n---\n\n".join(episode_texts)

    # Series treatment is optional reference context for the critic.
    treatment_path = project_dir / "treatment.md"
    treatment_text = (
        treatment_path.read_text(encoding="utf-8") if treatment_path.exists() else ""
    )

    # The state dir serves both as the experience pool and as the output location.
    state_dir = ProjectPaths.from_root(project_dir).state_dir

    critic = BatchBoundaryCritic(
        episodes_text=batch_text,
        treatment_text=treatment_text,
        batch_num=batch_num,
        experience_pool_dir=state_dir,
    )
    _, result = critic.run(batch_text)
    result_dict = result.to_dict()
    corrections = critic.course_corrections

    output = {
        "batch": batch_num,
        "episodes": f"{start_ep}-{end_ep}",
        "missing_episodes": missing,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "critic_result": result_dict,
        "course_corrections": corrections,
    }

    output_path = state_dir / f"batch_critic_batch_{batch_num:02d}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
    logger.info("Batch critic output: %s", output_path)

    return {
        "critic_result": result_dict,
        "course_corrections": corrections,
        "output_path": str(output_path),
        "missing_episodes": missing,
    }


if __name__ == "__main__":
    # CLI entry point: prints the run_batch_critic() result as JSON on stdout.
    parser = argparse.ArgumentParser(description="IP2: Batch Boundary Critic")
    parser.add_argument("project", help="Project directory path")
    parser.add_argument("--batch", type=int, required=True, help="Batch number (1-12)")
    # Accepted for CLI compatibility only; output is always JSON.
    parser.add_argument("--json", action="store_true", help="Output as JSON (default)")
    args = parser.parse_args()

    project_path = Path(args.project).resolve()
    if not project_path.is_dir():
        # Errors go to stderr so stdout stays reserved for the JSON result.
        print(f"ERROR: Project directory not found: {project_path}", file=sys.stderr)
        sys.exit(1)

    result = run_batch_critic(project_path, args.batch)
    print(json.dumps(result, indent=2))
