文件预览

run_golden_eval.py

查看 YouOS 技能包中的文件内容。

文件内容

scripts/run_golden_eval.py

"""Run golden benchmark evaluation against curated test cases."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import yaml

from app.core.settings import get_var_dir

ROOT_DIR = Path(__file__).resolve().parents[1]
GOLDEN_PATH = ROOT_DIR / "configs" / "benchmarks" / "golden.yaml"
# Per-instance: results from each instance's nightly land in its own var/
# so multiple instances don't clobber each other's last-eval JSON.
RESULTS_PATH = get_var_dir() / "golden_results.json"


def load_golden_cases(path: Path | None = None) -> list[dict[str, Any]]:
    """Load golden benchmark cases from YAML."""
    p = path or GOLDEN_PATH
    data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    return data.get("cases", [])


def score_case(
    case: dict[str, Any],
    draft: str,
    detected_mode: str,
    detected_language: str | None = None,
) -> dict[str, Any]:
    """Score a single golden benchmark case."""
    draft_lower = draft.lower()
    words = draft.split()
    word_count = len(words)

    # Keyword hit rate
    expected_keywords = case.get("expected_keywords", [])
    if expected_keywords:
        hits = sum(1 for kw in expected_keywords if kw.lower() in draft_lower)
        keyword_hit_rate = hits / len(expected_keywords)
    else:
        keyword_hit_rate = 1.0

    # Mode match
    expected_mode = case.get("expected_mode", "work")
    mode_match = detected_mode == expected_mode

    # Brevity check
    max_words = case.get("max_words", 100)
    brevity_pass = word_count <= max_words

    # Language detection check
    expected_language = case.get("expected_language")
    language_match = True
    if expected_language and detected_language:
        language_match = detected_language == expected_language

    # Overall pass/warn/fail — max_words violation is now a fail condition
    if not brevity_pass:
        status = "fail"
    elif keyword_hit_rate >= 0.5 and mode_match and language_match:
        status = "pass"
    elif keyword_hit_rate >= 0.25 or mode_match:
        status = "warn"
    else:
        status = "fail"

    result = {
        "case_id": case["id"],
        "description": case.get("description", ""),
        "keyword_hit_rate": round(keyword_hit_rate, 2),
        "mode_match": mode_match,
        "detected_mode": detected_mode,
        "expected_mode": expected_mode,
        "word_count": word_count,
        "max_words": max_words,
        "brevity_pass": brevity_pass,
        "status": status,
    }
    if expected_language:
        result["expected_language"] = expected_language
        result["detected_language"] = detected_language
        result["language_match"] = language_match
    return result


def run_golden_eval(
    *,
    generate_fn=None,
    database_url: str | None = None,
    configs_dir: Path | None = None,
    golden_path: Path | None = None,
) -> dict[str, Any]:
    """Run the full golden evaluation suite.

    If generate_fn is None, returns empty results (for testing without model).
    """
    cases = load_golden_cases(golden_path)
    results: list[dict[str, Any]] = []

    for case in cases:
        if generate_fn is not None:
            output = generate_fn(
                case["inbound"],
                database_url=database_url,
                configs_dir=configs_dir,
            )
            draft = output.get("draft", "")
            detected_mode = output.get("detected_mode", "unknown")
            detected_language = output.get("detected_language")
        else:
            draft = ""
            detected_mode = "unknown"
            detected_language = None

        result = score_case(case, draft, detected_mode, detected_language)
        results.append(result)

    total = len(results)
    passed = sum(1 for r in results if r["status"] == "pass")
    warned = sum(1 for r in results if r["status"] == "warn")
    failed = sum(1 for r in results if r["status"] == "fail")

    summary = {
        "total": total,
        "passed": passed,
        "warned": warned,
        "failed": failed,
        "results": results,
    }

    return summary


def save_results(summary: dict[str, Any], path: Path | None = None) -> None:
    """Save golden results to JSON."""
    p = path or RESULTS_PATH
    p.parent.mkdir(parents=True, exist_ok=True)
    with open(p, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)


def format_scorecard(summary: dict[str, Any]) -> str:
    """Format golden benchmark results as a scorecard."""
    lines: list[str] = []
    lines.append("Golden Benchmark Results")
    lines.append("=" * 60)

    for r in summary["results"]:
        icon = {"pass": "PASS", "warn": "WARN", "fail": "FAIL"}.get(r["status"], "?")
        kw_pct = int(r["keyword_hit_rate"] * 100)
        lines.append(f"  {r['case_id']:<30} {icon:<5} | kw={kw_pct}% mode={'Y' if r['mode_match'] else 'N'} words={r['word_count']}/{r['max_words']}")

    lines.append("=" * 60)
    lines.append(f"  Total: {summary['total']} | Pass: {summary['passed']} | Warn: {summary['warned']} | Fail: {summary['failed']}")
    return "\n".join(lines)


def main() -> None:
    from app.core.settings import get_settings
    from app.db.bootstrap import resolve_sqlite_path

    parser = argparse.ArgumentParser(description="Run golden benchmark evaluation")
    parser.add_argument("--golden", type=Path, default=GOLDEN_PATH, help="Path to golden.yaml")
    parser.add_argument("--summary-only", action="store_true", help="Print scorecard without saving")
    parser.add_argument("--db-path", type=Path, default=resolve_sqlite_path(get_settings().database_url))
    args = parser.parse_args()

    from app.generation.service import DraftRequest, generate_draft

    database_url = f"sqlite:///{args.db_path}"
    configs_dir = ROOT_DIR / "configs"

    def _generate(prompt_text, *, database_url, configs_dir):
        response = generate_draft(
            DraftRequest(inbound_message=prompt_text),
            database_url=database_url,
            configs_dir=configs_dir,
        )
        return {
            "draft": response.draft,
            "detected_mode": response.detected_mode,
            "confidence": response.confidence,
        }

    summary = run_golden_eval(
        generate_fn=_generate,
        database_url=database_url,
        configs_dir=configs_dir,
        golden_path=args.golden,
    )

    print(format_scorecard(summary))

    if not args.summary_only:
        save_results(summary)
        print(f"\nResults saved to {RESULTS_PATH}")


if __name__ == "__main__":
    main()