文件预览

runner.py

查看 Alibabacloud Safety Checker 技能包中的文件内容。

文件内容

scripts/runner.py

"""
测试运行器 - 批量调用审核接口,支持多服务对比、异步任务轮询、结果汇总
"""
import json
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, List, Optional

from client import ContentModerationClient
from config import (
    AUDIO_SERVICES,
    IMAGE_SERVICES,
    OUTPUT_DIR,
    SAMPLES_DIR,
    TEXT_SERVICES,
    VIDEO_SERVICES,
    TestConfig,
)
from models import (
    AnnotationLabel,
    ModalityType,
    ModerationResult,
    SampleData,
    TestBatchResult,
)

logger = logging.getLogger(__name__)


class SampleLoader:
    """样本数据加载器 - 支持CSV、JSON、Excel格式"""

    @staticmethod
    def load_from_json(file_path: str) -> List[SampleData]:
        """从JSON文件加载样本"""
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        samples = []
        items = data if isinstance(data, list) else data.get("samples", [])
        for item in items:
            sample = SampleData(
                sample_id=item.get("sample_id", str(uuid.uuid4())[:8]),
                modality=ModalityType(item.get("modality", "text")),
                content=item.get("content", ""),
                data_id=item.get("data_id", ""),
                category=item.get("category", ""),
                expected_risk_level=item.get("expected_risk_level", ""),
                expected_labels=item.get("expected_labels", []),
                metadata=item.get("metadata", {}),
            )
            samples.append(sample)
        return samples

    @staticmethod
    def load_from_csv(file_path: str) -> List[SampleData]:
        """从CSV文件加载样本"""
        import csv

        samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                sample = SampleData(
                    sample_id=row.get("sample_id", str(uuid.uuid4())[:8]),
                    modality=ModalityType(row.get("modality", "text")),
                    content=row.get("content", ""),
                    data_id=row.get("data_id", ""),
                    category=row.get("category", ""),
                    expected_risk_level=row.get("expected_risk_level", ""),
                    expected_labels=row.get("expected_labels", "").split("|") if row.get("expected_labels") else [],
                )
                samples.append(sample)
        return samples

    @staticmethod
    def load_from_excel(file_path: str) -> List[SampleData]:
        """从Excel文件加载样本"""
        import pandas as pd

        df = pd.read_excel(file_path)
        samples = []
        for _, row in df.iterrows():
            expected_labels = []
            if pd.notna(row.get("expected_labels", "")):
                expected_labels = str(row["expected_labels"]).split("|")

            sample = SampleData(
                sample_id=str(row.get("sample_id", str(uuid.uuid4())[:8])),
                modality=ModalityType(str(row.get("modality", "text"))),
                content=str(row.get("content", "")),
                data_id=str(row.get("data_id", "")) if pd.notna(row.get("data_id")) else "",
                category=str(row.get("category", "")) if pd.notna(row.get("category")) else "",
                expected_risk_level=str(row.get("expected_risk_level", "")) if pd.notna(row.get("expected_risk_level")) else "",
                expected_labels=expected_labels,
            )
            samples.append(sample)
        return samples

    @classmethod
    def load(cls, file_path: str) -> List[SampleData]:
        """根据文件扩展名自动选择加载方式"""
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".json":
            return cls.load_from_json(file_path)
        elif ext == ".csv":
            return cls.load_from_csv(file_path)
        elif ext in (".xlsx", ".xls"):
            return cls.load_from_excel(file_path)
        else:
            raise ValueError(f"不支持的文件格式: {ext},请使用 .json / .csv / .xlsx")


class TestRunner:
    """测试运行器 - 协调批量审核流程"""

    def __init__(self, config: Optional[TestConfig] = None):
        self.config = config or TestConfig()
        self.client = ContentModerationClient()
        self.batch_result: Optional[TestBatchResult] = None

    def run(self, samples: List[SampleData]) -> TestBatchResult:
        """
        运行一批测试
        - 对每个样本,调用其对应模态的所有配置服务
        - 文本/图片同步调用,音频/视频异步调用后轮询结果
        """
        batch_id = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.batch_result = TestBatchResult(
            batch_id=batch_id,
            start_time=datetime.now().isoformat(),
            total_samples=len(samples),
        )

        logger.info(f"开始测试批次: {batch_id}, 样本数: {len(samples)}")

        # 按模态分组
        text_samples = [s for s in samples if s.modality == ModalityType.TEXT]
        image_samples = [s for s in samples if s.modality == ModalityType.IMAGE]
        audio_samples = [s for s in samples if s.modality == ModalityType.AUDIO]
        video_samples = [s for s in samples if s.modality == ModalityType.VIDEO]

        # 执行同步审核(文本、图片)
        self._run_sync_moderation(text_samples, ModalityType.TEXT)
        self._run_sync_moderation(image_samples, ModalityType.IMAGE)

        # 执行异步审核(音频、视频)
        self._run_async_moderation(audio_samples, ModalityType.AUDIO)
        self._run_async_moderation(video_samples, ModalityType.VIDEO)

        # 汇总统计
        self.batch_result.end_time = datetime.now().isoformat()
        self.batch_result.total_requests = len(self.batch_result.results)
        self.batch_result.success_count = sum(
            1 for r in self.batch_result.results if r.success
        )
        self.batch_result.failure_count = (
            self.batch_result.total_requests - self.batch_result.success_count
        )
        self._compute_stats()

        logger.info(
            f"批次 {batch_id} 完成: "
            f"总请求={self.batch_result.total_requests}, "
            f"成功={self.batch_result.success_count}, "
            f"失败={self.batch_result.failure_count}"
        )

        return self.batch_result

    def _get_services_for_modality(self, modality: ModalityType) -> List[str]:
        """获取指定模态应使用的服务列表"""
        if modality == ModalityType.TEXT:
            return self.config.text_services
        elif modality == ModalityType.IMAGE:
            return self.config.image_services
        elif modality == ModalityType.AUDIO:
            return self.config.audio_services
        elif modality == ModalityType.VIDEO:
            return self.config.video_services
        return []

    def _run_sync_moderation(
        self, samples: List[SampleData], modality: ModalityType
    ):
        """执行同步审核(文本、图片)- 使用线程池并发"""
        if not samples:
            return

        services = self._get_services_for_modality(modality)
        if not services:
            return

        logger.info(
            f"开始{modality.value}审核: {len(samples)}个样本 × {len(services)}个服务"
        )

        tasks = []
        for sample in samples:
            for service in services:
                tasks.append((sample, service))

        with ThreadPoolExecutor(max_workers=self.config.max_concurrent) as executor:
            futures = {}
            for sample, service in tasks:
                # 控制QPS
                time.sleep(self.config.request_interval)
                future = executor.submit(
                    self._execute_single_sync, sample, service, modality
                )
                futures[future] = (sample.sample_id, service)

            for future in as_completed(futures):
                sample_id, service = futures[future]
                try:
                    result = future.result()
                    self.batch_result.results.append(result)
                    status = "✓" if result.success else "✗"
                    logger.debug(
                        f"  {status} [{sample_id}] service={service} "
                        f"risk={result.risk_level} request_id={result.request_id}"
                    )
                except Exception as e:
                    logger.error(f"  ✗ [{sample_id}] service={service} error={e}")

    def _execute_single_sync(
        self, sample: SampleData, service: str, modality: ModalityType
    ) -> ModerationResult:
        """执行单个同步审核请求"""
        if modality == ModalityType.TEXT:
            result = self.client.moderate_text(
                content=sample.content,
                service=service,
                data_id=sample.data_id or sample.sample_id,
            )
        elif modality == ModalityType.IMAGE:
            result = self.client.moderate_image(
                image_url=sample.content,
                service=service,
                data_id=sample.data_id or sample.sample_id,
            )
        else:
            raise ValueError(f"不支持的同步模态: {modality}")

        result.sample_id = sample.sample_id
        return result

    def _run_async_moderation(
        self, samples: List[SampleData], modality: ModalityType
    ):
        """执行异步审核(音频、视频)- 提交后轮询结果"""
        if not samples:
            return

        services = self._get_services_for_modality(modality)
        if not services:
            return

        logger.info(
            f"开始{modality.value}审核: {len(samples)}个样本 × {len(services)}个服务"
        )

        # 阶段1: 提交所有任务
        submitted_tasks: List[ModerationResult] = []
        for sample in samples:
            for service in services:
                time.sleep(self.config.request_interval)
                if modality == ModalityType.AUDIO:
                    result = self.client.moderate_audio(
                        audio_url=sample.content,
                        service=service,
                        data_id=sample.data_id or sample.sample_id,
                    )
                elif modality == ModalityType.VIDEO:
                    result = self.client.moderate_video(
                        video_url=sample.content,
                        service=service,
                        data_id=sample.data_id or sample.sample_id,
                    )
                else:
                    continue

                result.sample_id = sample.sample_id
                submitted_tasks.append(result)

                if result.success and result.task_id:
                    logger.debug(
                        f"  → [{sample.sample_id}] 提交成功, "
                        f"task_id={result.task_id}"
                    )
                else:
                    logger.warning(
                        f"  ✗ [{sample.sample_id}] 提交失败: "
                        f"{result.error_message}"
                    )

        # 阶段2: 轮询异步任务结果
        pending_tasks = [t for t in submitted_tasks if t.success and t.task_id]
        start_wait = time.time()

        while pending_tasks and (
            time.time() - start_wait < self.config.async_max_wait
        ):
            time.sleep(self.config.async_poll_interval)
            still_pending = []

            for task in pending_tasks:
                if modality == ModalityType.AUDIO:
                    poll_result = self.client.poll_audio_result(
                        task.task_id, task.service
                    )
                else:
                    poll_result = self.client.poll_video_result(
                        task.task_id, task.service
                    )

                if poll_result:
                    # 解析异步结果
                    data = poll_result.get("Data", poll_result.get("data", {}))
                    if isinstance(data, dict):
                        task.risk_level = data.get("RiskLevel", data.get("riskLevel", "unknown"))
                        results_list = data.get("Result", data.get("result", []))
                        if results_list:
                            for item in results_list:
                                label = item.get("Label", item.get("label", ""))
                                if label:
                                    task.labels.append(label)
                                conf = item.get("Confidence", item.get("confidence"))
                                if conf is not None:
                                    task.confidence[label] = conf
                    task.raw_response = poll_result
                    task.response_time = datetime.now().isoformat()
                    logger.debug(
                        f"  ✓ [{task.sample_id}] 异步结果就绪, "
                        f"risk={task.risk_level}"
                    )
                else:
                    still_pending.append(task)

            pending_tasks = still_pending

        # 超时的任务标记为失败
        for task in pending_tasks:
            task.success = False
            task.error_message = "异步任务超时未返回结果"
            logger.warning(f"  ✗ [{task.sample_id}] 异步任务超时")

        # 收集所有异步结果
        self.batch_result.results.extend(submitted_tasks)

    def _compute_stats(self):
        """计算统计数据"""
        for result in self.batch_result.results:
            # 按服务统计
            if result.service not in self.batch_result.service_stats:
                self.batch_result.service_stats[result.service] = {
                    "total": 0, "success": 0, "high": 0, "medium": 0,
                    "low": 0, "none": 0,
                }
            stats = self.batch_result.service_stats[result.service]
            stats["total"] += 1
            if result.success:
                stats["success"] += 1
            stats[result.risk_level] = stats.get(result.risk_level, 0) + 1

            # 按风险等级统计
            rl = result.risk_level
            self.batch_result.risk_level_stats[rl] = (
                self.batch_result.risk_level_stats.get(rl, 0) + 1
            )

            # 按模态统计
            mod = result.modality.value
            if mod not in self.batch_result.modality_stats:
                self.batch_result.modality_stats[mod] = {
                    "total": 0, "success": 0, "high": 0, "medium": 0,
                    "low": 0, "none": 0,
                }
            mstats = self.batch_result.modality_stats[mod]
            mstats["total"] += 1
            if result.success:
                mstats["success"] += 1
            mstats[result.risk_level] = mstats.get(result.risk_level, 0) + 1