文件预览

test_upload_files.py

查看 腾讯乐享知识库 Lexiang Knowledge Base 技能包中的文件内容。

文件内容

scripts/test_upload_files.py

"""
upload-files.py 单元测试

覆盖范围:
- Bug 4: asyncio.gather 缺少 return_exceptions=True(部分失败时其他结果不丢失)
- Bug 5: upload_file_async 中同步阻塞读文件(改为 run_in_executor 异步读取)
"""

import asyncio
import time
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# 将脚本目录加入 sys.path
import sys
import os
import importlib.util

# upload-files.py 文件名含连字符,无法直接 import,使用 importlib 动态加载
_script_dir = os.path.dirname(os.path.abspath(__file__))
_spec = importlib.util.spec_from_file_location(
    "upload_files",
    os.path.join(_script_dir, "upload-files.py"),
)
_mod = importlib.util.module_from_spec(_spec)
sys.modules["upload_files"] = _mod
_spec.loader.exec_module(_mod)

from upload_files import LexiangUploader, UploadTask, UploadSession


# ─────────────────────────────────────────────
# 辅助工厂
# ─────────────────────────────────────────────

def make_task(name: str = "test.md") -> UploadTask:
    return UploadTask(
        local_path=f"/tmp/{name}",
        file_name=name,
        parent_entry_id="entry-001",
        mime_type="text/markdown",
        file_size=1024,
        content_hash="abc123",
    )


def make_session(name: str = "test.md", upload_url: str = "http://example.com/upload") -> UploadSession:
    return UploadSession(
        session_id=f"sess-{name}",
        upload_url=upload_url,
        task=make_task(name),
    )


# ─────────────────────────────────────────────
# Bug 4: gather return_exceptions=True
# ─────────────────────────────────────────────

class TestExecuteParallelUploads:
    """测试 execute_parallel_uploads 的异常处理行为(Bug 4)"""

    @pytest.mark.asyncio
    async def test_all_success_returns_all_results(self):
        """全部成功时,返回所有 (session, True) 结果"""
        uploader = LexiangUploader(parallel=3)
        sessions = [make_session(f"file{i}.md", f"http://ok.com/{i}") for i in range(3)]

        async def mock_upload(url, path, mime):
            return True

        with patch.object(uploader, "upload_file_async", side_effect=mock_upload):
            results = await uploader.execute_parallel_uploads(sessions)

        assert len(results) == 3
        for session, success in results:
            assert success is True

    @pytest.mark.asyncio
    async def test_partial_failure_does_not_lose_other_results(self):
        """部分失败时,其他成功结果不丢失(Bug 4 核心验证)"""
        uploader = LexiangUploader(parallel=3)
        sessions = [
            make_session("ok1.md",   "http://ok.com/1"),
            make_session("fail.md",  "http://fail.com/2"),
            make_session("ok2.md",   "http://ok.com/3"),
        ]

        async def mock_upload(url, path, mime):
            if "fail" in url:
                raise RuntimeError("模拟上传失败")
            return True

        with patch.object(uploader, "upload_file_async", side_effect=mock_upload):
            results = await uploader.execute_parallel_uploads(sessions)

        # 三个结果都应该返回,不能因为一个失败而丢失其他
        assert len(results) == 3

        ok1_result   = next(r for r in results if r[0].task.file_name == "ok1.md")
        fail_result  = next(r for r in results if r[0].task.file_name == "fail.md")
        ok2_result   = next(r for r in results if r[0].task.file_name == "ok2.md")

        assert ok1_result[1]  is True,  "ok1.md 应该成功"
        assert ok2_result[1]  is True,  "ok2.md 应该成功"
        assert fail_result[1] is False, "fail.md 应该失败"

    @pytest.mark.asyncio
    async def test_failed_session_error_is_recorded(self):
        """失败的任务,错误信息应记录到 task.error"""
        uploader = LexiangUploader(parallel=2)
        sessions = [make_session("fail.md", "http://fail.com")]

        async def mock_upload(url, path, mime):
            raise ConnectionError("网络超时")

        with patch.object(uploader, "upload_file_async", side_effect=mock_upload):
            results = await uploader.execute_parallel_uploads(sessions)

        session, success = results[0]
        assert success is False
        assert "网络超时" in session.task.error

    @pytest.mark.asyncio
    async def test_all_fail_returns_all_false(self):
        """全部失败时,返回所有 (session, False) 结果,不抛异常"""
        uploader = LexiangUploader(parallel=2)
        sessions = [make_session(f"fail{i}.md") for i in range(3)]

        async def mock_upload(url, path, mime):
            raise RuntimeError("全部失败")

        with patch.object(uploader, "upload_file_async", side_effect=mock_upload):
            # 修复前:这里会直接抛出 RuntimeError
            results = await uploader.execute_parallel_uploads(sessions)

        assert len(results) == 3
        for _, success in results:
            assert success is False


# ─────────────────────────────────────────────
# Bug 5: 异步读文件不阻塞事件循环
# ─────────────────────────────────────────────

class TestUploadFileAsync:
    """测试 upload_file_async 不阻塞事件循环(Bug 5)"""

    @pytest.mark.asyncio
    async def test_parallel_uploads_are_concurrent(self, tmp_path):
        """并行上传时,两个任务应真正并发执行,总耗时接近单个任务耗时"""
        f1 = tmp_path / "a.md"
        f2 = tmp_path / "b.md"
        f1.write_bytes(b"x" * 1024)
        f2.write_bytes(b"x" * 1024)

        DELAY = 0.15  # 模拟每个上传耗时 150ms

        uploader = LexiangUploader(parallel=2)

        # mock_resp 是 async with session.put(...) as response 里的 response
        mock_resp = MagicMock()
        mock_resp.status = 200

        # mock_put_ctx 是 session.put(...) 返回的异步上下文管理器
        mock_put_ctx = MagicMock()
        mock_put_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_put_ctx.__aexit__ = AsyncMock(return_value=False)

        # session.put 是普通方法(非 coroutine),返回异步上下文管理器
        # 同时在 __aenter__ 里注入延迟,模拟网络耗时
        async def slow_aenter(*args, **kwargs):
            await asyncio.sleep(DELAY)
            return mock_resp
        mock_put_ctx.__aenter__ = slow_aenter

        mock_session = MagicMock()
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=False)
        mock_session.put = MagicMock(return_value=mock_put_ctx)

        start = time.monotonic()
        with patch("aiohttp.ClientSession", return_value=mock_session):
            await asyncio.gather(
                uploader.upload_file_async("http://x.com/1", str(f1), "text/markdown"),
                uploader.upload_file_async("http://x.com/2", str(f2), "text/markdown"),
            )
        elapsed = time.monotonic() - start

        # 真正并发:耗时应接近 DELAY 而非 2*DELAY
        assert elapsed < DELAY * 1.8, (
            f"上传疑似被串行化,耗时 {elapsed:.2f}s,预期 < {DELAY * 1.8:.2f}s"
        )

    @pytest.mark.asyncio
    async def test_file_read_uses_executor(self, tmp_path):
        """文件读取应通过 run_in_executor 执行,不直接调用同步 open()"""
        test_file = tmp_path / "test.md"
        test_file.write_bytes(b"hello world")

        uploader = LexiangUploader()
        executor_called = []

        original_run_in_executor = asyncio.get_event_loop().run_in_executor

        async def patched_run_in_executor(executor, func, *args):
            executor_called.append(func)
            return await original_run_in_executor(executor, func, *args)

        mock_resp = MagicMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        mock_session = MagicMock()
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=False)
        mock_session.put = MagicMock(return_value=mock_resp)

        loop = asyncio.get_event_loop()
        with patch.object(loop, "run_in_executor", side_effect=patched_run_in_executor):
            with patch("aiohttp.ClientSession", return_value=mock_session):
                await uploader.upload_file_async(
                    "http://x.com/upload", str(test_file), "text/markdown"
                )

        # 应该至少有一次 run_in_executor 调用(用于读文件)
        assert len(executor_called) >= 1, "文件读取应通过 run_in_executor 异步执行"

    @pytest.mark.asyncio
    async def test_fallback_sync_uses_executor(self, tmp_path):
        """无 aiohttp 时,降级路径也应通过 run_in_executor 执行,不直接阻塞"""
        test_file = tmp_path / "test.md"
        test_file.write_bytes(b"hello")

        uploader = LexiangUploader()
        executor_called = []

        original_run_in_executor = asyncio.get_event_loop().run_in_executor

        async def patched_run_in_executor(executor, func, *args):
            executor_called.append(True)
            return await original_run_in_executor(executor, func, *args)

        def mock_upload_sync(url, path, mime):
            return True

        loop = asyncio.get_event_loop()
        with patch("upload_files.HAS_AIOHTTP", False):
            with patch.object(uploader, "upload_file_sync", side_effect=mock_upload_sync):
                with patch.object(loop, "run_in_executor", side_effect=patched_run_in_executor):
                    await uploader.upload_file_async(
                        "http://x.com/upload", str(test_file), "text/markdown"
                    )

        assert len(executor_called) >= 1, "降级路径也应通过 run_in_executor 执行"


# ─────────────────────────────────────────────
# 其他基础功能测试
# ─────────────────────────────────────────────

class TestGenerateUploadPlan:
    """测试上传计划生成"""

    def test_plan_contains_all_steps(self):
        """每个文件应生成 3 个步骤:申请、上传、确认"""
        uploader = LexiangUploader()
        tasks = [make_task("a.md"), make_task("b.md")]
        plan = uploader.generate_upload_plan(tasks)

        assert plan["total_files"] == 2
        assert len(plan["steps"]) == 6  # 每个文件 3 步

    def test_plan_total_size(self):
        """计划中的总大小应等于所有文件大小之和"""
        uploader = LexiangUploader()
        tasks = [make_task("a.md"), make_task("b.md")]
        plan = uploader.generate_upload_plan(tasks)

        assert plan["total_size"] == 2048  # 每个 task file_size=1024


class TestScanFolder:
    """测试文件夹扫描"""

    def test_scan_filters_by_extension(self, tmp_path):
        """扫描时应按扩展名过滤"""
        (tmp_path / "doc.md").write_text("# hello")
        (tmp_path / "image.png").write_bytes(b"\x89PNG")
        (tmp_path / "note.txt").write_text("note")

        uploader = LexiangUploader()
        tasks = uploader.scan_folder(str(tmp_path), "entry-001", extensions=[".md"])

        assert len(tasks) == 1
        assert tasks[0].file_name == "doc.md"

    def test_scan_ignores_patterns(self, tmp_path):
        """扫描时应忽略指定模式的路径"""
        node_modules = tmp_path / "node_modules"
        node_modules.mkdir()
        (node_modules / "pkg.md").write_text("pkg")
        (tmp_path / "readme.md").write_text("readme")

        uploader = LexiangUploader()
        tasks = uploader.scan_folder(str(tmp_path), "entry-001", extensions=[".md"])

        assert len(tasks) == 1
        assert tasks[0].file_name == "readme.md"