"""
coverage_validator.py — BLOCK/WARN/INFO validation for coverage passes.

BLOCK errors prevent --lock; WARN results are flagged for human review; INFO
results are only logged.
"""

from types import SimpleNamespace

from recoil.core.model_profiles import get_segment_duration_bounds
from recoil.pipeline._lib.prose_validator import (
    Severity,
    ValidationResult,
    verify_authored_prose,
)
from orchestrator.coverage_planner import (
    CoveragePass,
    MAX_DURATION_S,
    MAX_SEGMENTS,
    MAX_CHARS_I2V,
    MAX_CHARS_T2V,
    SHOT_TYPE_ORDER,
)


def prose_verify(p: CoveragePass) -> list[ValidationResult]:
    """REC-72 D0b: deterministic gate that diffs LLM-authored prose against the
    pass's prompt skeleton.

    Inputs live in ``p.generation_config["prose_verify"]``, a dict:
        {"authored_prose": <str>,
         "prompt_skeleton": <dict>,
         "strategy": <str>}               # optional, defaults to "coverage_pass"
    where ``prompt_skeleton`` carries:
        {"char_ids": [<str>, ...],
         "segments": [{"timecode": "[m:ss-m:ss]"}, ...]  # or any list of segments
         "duration_s": <int>}             # target_editorial_duration_s

    STRICT NO-OP when the ``prose_verify`` key is ABSENT — which is the case for
    ALL of Phase 0 and every deterministic-builder pass until the D2 author stage
    populates it. Returns an empty list in that case, so existing passes are
    unaffected.

    When present, the prose is checked by ``verify_authored_prose``, which is
    expected to emit Severity.BLOCK on:
      - coverage:  any skeleton char_id missing from the authored prose;
                   the author's duration not aligned to the skeleton duration.
      - structure: segment count in prose (timecodes found) != skeleton count;
                   any skeleton segment without a [m:ss-m:ss] timecode in prose.
      - safety:    any @ImageN literal in the AUTHORED prose (must be pre-bind).
    This function additionally emits a ``prose_verify_duration`` BLOCK when the
    skeleton specifies ``duration_s`` and it differs from ``p.duration_s``.

    Args:
        p: The coverage pass whose ``generation_config`` may carry the
            ``prose_verify`` spec.

    Returns:
        The validation results (empty when the spec is absent).
    """
    spec = p.generation_config.get("prose_verify")
    if spec is None:
        return []  # strict no-op — no authored prose yet (Phase 0 + all builder passes)

    prose = spec.get("authored_prose") or ""
    skeleton = spec.get("prompt_skeleton") or {}

    # Only a missing/None duration means "not provided". A plain `or` fallback
    # would silently replace a 0 with the pass duration for the primitive while
    # the legacy comparison below still treated that 0 as provided.
    skel_duration = skeleton.get("duration_s")
    target_duration = skel_duration if skel_duration is not None else p.duration_s

    # Duck-typed stand-ins for the shot primitive / strategy objects that
    # verify_authored_prose expects; only the attributes set here are provided.
    primitive = SimpleNamespace(
        shot_id=p.pass_id,
        target_editorial_duration_s=target_duration,
        char_ids=skeleton.get("char_ids") or [],
        timing_segments=skeleton.get("segments") or [],
        refs={},
    )
    strategy = SimpleNamespace(
        name=spec.get("strategy") or "coverage_pass",
        required_inputs=[],
    )
    results = verify_authored_prose(prose, primitive, strategy)

    # Preserve the legacy CoveragePass comparison in addition to the canonical
    # authored-timecode duration sum.
    if skel_duration is not None and skel_duration != p.duration_s:
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "prose_verify_duration",
            f"Skeleton target duration {skel_duration}s does not match the pass "
            f"duration {p.duration_s}s",
        ))

    return results


def validate_pass(p: CoveragePass) -> list[ValidationResult]:
    """Run all BLOCK and WARN checks against a single coverage pass.

    Args:
        p: The pass to validate. Reads ``generation_config`` (``mode``,
            ``model``, ``start_frame_path``, ``cfg_scale``, ``prose_verify``),
            ``element_config["character_elements"]``, ``segments``,
            ``character_count``, ``duration_s``, ``pass_type``, ``pass_id``
            and ``focus_character``.

    Returns:
        BLOCK results first (including any from ``prose_verify``), then WARN
        results. An empty list means the pass passed every check.
    """
    results: list[ValidationResult] = []

    mode = p.generation_config.get("mode", "i2v")
    char_count = p.character_count

    # ── BLOCK checks ──

    # Element budget exceeded (start frame takes 1 slot for I2V)
    if mode == "i2v" and char_count > MAX_CHARS_I2V:
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "element_budget",
            f"I2V element budget exceeded: {char_count} chars + start frame > {MAX_CHARS_I2V + 1} total slots",
        ))
    elif mode == "t2v" and char_count > MAX_CHARS_T2V:
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "element_budget",
            f"T2V element budget exceeded: {char_count} chars > {MAX_CHARS_T2V} slots",
        ))

    # Duration exceeded
    if p.duration_s > MAX_DURATION_S:
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "duration",
            f"Duration {p.duration_s}s exceeds {MAX_DURATION_S}s limit",
        ))

    # Segment duration vs model bounds (BLOCK). An unknown model is itself a
    # BLOCK, since the bounds cannot be checked without its profile.
    model = p.generation_config.get("model")
    if model:
        try:
            min_d, max_d = get_segment_duration_bounds(model)
        except KeyError:
            results.append(ValidationResult(
                severity=Severity.BLOCK,
                pass_id=p.pass_id,
                check="unknown_model",
                message=f"Model {model!r} not found in model_profiles.json",
            ))
        else:
            for i, seg in enumerate(p.segments):
                if seg.duration_s < min_d or seg.duration_s > max_d:
                    results.append(ValidationResult(
                        severity=Severity.BLOCK,
                        pass_id=p.pass_id,
                        check="segment_duration_model_bounds",
                        message=f"Segment {i} duration {seg.duration_s}s outside model bounds [{min_d}, {max_d}]",
                    ))

    # Segment count exceeded
    if len(p.segments) > MAX_SEGMENTS:
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "segment_count",
            f"Segment count {len(p.segments)} exceeds {MAX_SEGMENTS} limit",
        ))

    # Missing start frame for I2V
    if mode == "i2v" and not p.generation_config.get("start_frame_path"):
        results.append(ValidationResult(
            Severity.BLOCK, p.pass_id, "missing_start_frame",
            "I2V mode requires a start frame but none was resolved",
        ))

    # Prose-verify (REC-72 D0b): strict NO-OP unless generation_config carries a
    # "prose_verify" key (populated by the D2 author stage). When present, diffs
    # the authored prose against the skeleton and emits Severity.BLOCK on failure.
    results.extend(prose_verify(p))

    # ── WARN checks ──

    # cfg_scale mismatch.
    # NOTE(review): the trigger thresholds (< 0.50 for character, > 0.55 for env)
    # are looser than the recommended values quoted in the messages (>= 0.55 /
    # <= 0.50), leaving a 0.50–0.55 band that never warns — presumably a
    # deliberate tolerance band; confirm.
    cfg = p.generation_config.get("cfg_scale", 0.55)
    if p.pass_type == "character" and cfg < 0.50:
        results.append(ValidationResult(
            Severity.WARN, p.pass_id, "cfg_mismatch",
            f"Character pass with low cfg_scale {cfg} (expected >= 0.55)",
        ))
    if p.pass_type == "env" and cfg > 0.55:
        results.append(ValidationResult(
            Severity.WARN, p.pass_id, "cfg_mismatch",
            f"ENV pass with high cfg_scale {cfg} (expected <= 0.50)",
        ))

    # Duration outlier within pass
    if len(p.segments) >= 2:
        durs = [s.duration_s for s in p.segments]
        shortest, longest = min(durs), max(durs)
        # A non-positive shortest segment makes the ratio meaningless; skip it.
        # (shortest > 0 already implies longest > 0.)
        if shortest > 0 and longest / shortest > 2.0:
            results.append(ValidationResult(
                Severity.WARN, p.pass_id, "duration_outlier",
                f"Duration spread {shortest}s-{longest}s (ratio > 2x)",
            ))

    # Shot type inconsistency (WS → ECU → WS pattern). Unknown shot types rank 5.
    if len(p.segments) >= 3:
        ranks = [SHOT_TYPE_ORDER.get(s.shot_type, 5) for s in p.segments]
        for i in range(1, len(ranks) - 1):
            jump_in = abs(ranks[i] - ranks[i-1])
            jump_out = abs(ranks[i+1] - ranks[i])
            if jump_in >= 4 and jump_out >= 4:
                results.append(ValidationResult(
                    Severity.WARN, p.pass_id, "shot_type_jump",
                    f"Shot type whiplash at segment {i}: "
                    f"{p.segments[i-1].shot_type} → {p.segments[i].shot_type} → {p.segments[i+1].shot_type}",
                ))
                break  # report only the first whiplash per pass

    # Prompt bleed risk: segment references non-focus character loaded as element.
    # Matching is a case-insensitive substring test against each segment prompt.
    if p.character_count > 1:
        focus = (p.focus_character or "").lower()
        loaded_ids = (
            elem.get("char_id", "")
            for elem in p.element_config.get("character_elements", [])
        )
        non_focus_ids = [cid for cid in loaded_ids if cid and cid.lower() != focus]
        # Stop at the first (segment, char_id) hit across ALL segments — one
        # warning per pass is enough. A `break` inside nested loops would only
        # leave the inner loop and warn once per matching segment instead.
        bleed = next(
            (
                (seg, char_id)
                for seg in p.segments
                for char_id in non_focus_ids
                if char_id.lower() in (seg.prompt or "").lower()
            ),
            None,
        )
        if bleed is not None:
            seg, char_id = bleed
            results.append(ValidationResult(
                Severity.WARN, p.pass_id, "prompt_bleed",
                f"Segment {seg.segment_index} prompt references non-focus "
                f"character '{char_id}' (loaded but not featured)",
            ))

    return results


def validate_all_passes(passes: list[CoveragePass]) -> list[ValidationResult]:
    """Validate every pass, then append cross-pass INFO findings.

    Args:
        passes: The coverage passes to check.

    Returns:
        Per-pass results in input order, followed by a ``single_segment_ratio``
        INFO result when more than half of the passes have exactly one segment.
    """
    findings: list[ValidationResult] = [
        finding for cov_pass in passes for finding in validate_pass(cov_pass)
    ]

    # ── INFO checks (cross-pass) ──
    if not passes:
        return findings

    total = len(passes)
    singles = sum(len(cov_pass.segments) == 1 for cov_pass in passes)
    share = singles / total
    if share > 0.50:
        findings.append(ValidationResult(
            Severity.INFO, "*", "single_segment_ratio",
            f"{singles}/{total} passes ({share:.0%}) are single-segment",
        ))

    return findings
