文件预览

setup_db.py

查看 AI Fitness Coach 技能包中的文件内容。

文件内容

scripts/setup_db.py

#!/usr/bin/env python3
"""
Free Exercise DB 数据库设置脚本

功能:
- 检查数据库是否存在
- 从 Gitee 克隆/下载数据库
- 更新数据库到最新版本
- 验证数据库完整性

跨平台兼容:Windows / Linux / macOS
"""

import os
import sys
import json
import shutil
import argparse
import subprocess
from pathlib import Path
from urllib.request import urlopen, urlretrieve
from urllib.error import URLError

# 数据库信息
DB_REPO_URL = "https://gitee.com/kaiji1126/free-exercise-db.git"
DB_ARCHIVE_URL = "https://gitee.com/kaiji1126/free-exercise-db/repository/archive/main.zip"
DB_NAME = "free-exercise-db"

def get_skill_dir():
    """获取技能目录路径"""
    # 脚本在 scripts/ 目录下,技能目录是上一级
    return Path(__file__).parent.parent.resolve()

def get_db_path():
    """获取数据库路径"""
    return get_skill_dir() / DB_NAME

def check_db_exists():
    """检查数据库是否存在"""
    db_path = get_db_path()
    exercises_path = db_path / "exercises"
    dist_path = db_path / "dist" / "exercises.json"
    
    return exercises_path.exists() or dist_path.exists()

def check_git_available():
    """检查 git 是否可用"""
    try:
        result = subprocess.run(
            ["git", "--version"],
            capture_output=True,
            text=True,
            timeout=10
        )
        return result.returncode == 0
    except (subprocess.SubprocessError, FileNotFoundError):
        return False

def clone_with_git():
    """使用 git 克隆数据库"""
    db_path = get_db_path()
    
    print(f"正在从 Gitee 克隆数据库...")
    print(f"仓库: {DB_REPO_URL}")
    
    try:
        result = subprocess.run(
            ["git", "clone", DB_REPO_URL, str(db_path)],
            capture_output=True,
            text=True,
            timeout=300  # 5分钟超时
        )
        
        if result.returncode == 0:
            print("✓ 数据库克隆成功!")
            return True
        else:
            print(f"✗ 克隆失败: {result.stderr}")
            return False
            
    except subprocess.SubprocessError as e:
        print(f"✗ 克隆出错: {e}")
        return False

def download_with_urllib():
    """使用 urllib 下载数据库压缩包"""
    import zipfile
    import tempfile
    
    db_path = get_db_path()
    temp_dir = Path(tempfile.mkdtemp())
    
    print(f"正在从 Gitee 下载数据库...")
    print(f"地址: {DB_ARCHIVE_URL}")
    
    try:
        zip_path = temp_dir / "db.zip"
        urlretrieve(DB_ARCHIVE_URL, zip_path)
        print("✓ 下载完成")
        
        # 解压
        print("正在解压...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
        
        # 找到解压后的目录(通常带有 -main 后缀)
        extracted_dirs = [d for d in temp_dir.iterdir() if d.is_dir()]
        if not extracted_dirs:
            print("✗ 解压失败:未找到目录")
            return False
        
        source_dir = extracted_dirs[0]
        
        # 移动到目标位置
        if db_path.exists():
            shutil.rmtree(db_path)
        
        shutil.move(str(source_dir), str(db_path))
        print("✓ 解压完成")
        
        # 清理临时文件
        shutil.rmtree(temp_dir)
        
        return True
        
    except (URLError, zipfile.BadZipFile, OSError) as e:
        print(f"✗ 下载/解压失败: {e}")
        if temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)
        return False

def setup_database(force=False):
    """设置数据库"""
    db_path = get_db_path()
    
    # 检查是否已存在
    if check_db_exists() and not force:
        print(f"✓ 数据库已存在: {db_path}")
        print("  使用 --force 重新下载")
        return True
    
    # 如果强制更新,先删除旧数据库
    if force and db_path.exists():
        print("删除旧数据库...")
        shutil.rmtree(db_path)
    
    # 优先使用 git
    if check_git_available():
        print("检测到 git,使用 git 克隆...")
        if clone_with_git():
            return True
        print("git 克隆失败,尝试直接下载...")
    
    # 回退到直接下载
    return download_with_urllib()

def update_database():
    """更新数据库到最新版本"""
    db_path = get_db_path()
    
    if not db_path.exists():
        print("数据库不存在,正在下载...")
        return setup_database()
    
    # 检查是否是 git 仓库
    git_dir = db_path / ".git"
    if git_dir.exists():
        print("正在更新数据库(git pull)...")
        try:
            result = subprocess.run(
                ["git", "-C", str(db_path), "pull"],
                capture_output=True,
                text=True,
                timeout=60
            )
            
            if result.returncode == 0:
                print("✓ 数据库更新成功!")
                return True
            else:
                print(f"✗ 更新失败: {result.stderr}")
                return False
                
        except subprocess.SubprocessError as e:
            print(f"✗ 更新出错: {e}")
            return False
    else:
        print("数据库不是 git 仓库,重新下载...")
        return setup_database(force=True)

def verify_database():
    """验证数据库完整性"""
    db_path = get_db_path()
    exercises_path = db_path / "exercises"
    dist_path = db_path / "dist" / "exercises.json"
    
    print(f"验证数据库: {db_path}")
    
    # 检查目录存在
    if not db_path.exists():
        print("✗ 数据库目录不存在")
        return False
    
    # 检查 exercises 目录
    if exercises_path.exists():
        exercise_count = len(list(exercises_path.glob("*")))
        print(f"✓ 找到 {exercise_count} 个动作目录")
        
        # 抽查几个动作文件
        sample_dirs = list(exercises_path.glob("*"))[:3]
        for d in sample_dirs:
            exercise_json = d / "exercise.json"
            if exercise_json.exists():
                try:
                    with open(exercise_json, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    print(f"  ✓ {data.get('name', d.name)}")
                except (json.JSONDecodeError, IOError) as e:
                    print(f"  ✗ {d.name}: {e}")
                    return False
        return True
    
    # 检查 dist/exercises.json
    if dist_path.exists():
        try:
            with open(dist_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                print(f"✓ 找到合并文件,包含 {len(data)} 个动作")
                return True
            else:
                print("✗ 合并文件格式错误")
                return False
        except (json.JSONDecodeError, IOError) as e:
            print(f"✗ 读取合并文件失败: {e}")
            return False
    
    print("✗ 数据库不完整:缺少 exercises 目录和 dist/exercises.json")
    return False

def get_db_info():
    """获取数据库信息"""
    db_path = get_db_path()
    
    if not db_path.exists():
        print("数据库不存在")
        return
    
    print(f"数据库路径: {db_path}")
    
    # 统计动作数量
    exercises_path = db_path / "exercises"
    if exercises_path.exists():
        count = len(list(exercises_path.glob("*")))
        print(f"动作数量: {count}")
    
    # 检查 git 信息
    git_dir = db_path / ".git"
    if git_dir.exists():
        print("数据源: git 仓库")
        try:
            result = subprocess.run(
                ["git", "-C", str(db_path), "log", "-1", "--format=%h %s"],
                capture_output=True,
                text=True,
                timeout=10
            )
            if result.returncode == 0:
                print(f"最新提交: {result.stdout.strip()}")
        except subprocess.SubprocessError:
            pass
    else:
        print("数据源: 压缩包下载")

def main():
    parser = argparse.ArgumentParser(
        description="Free Exercise DB 数据库设置工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python setup_db.py              # 设置数据库(如果不存在)
  python setup_db.py --force      # 强制重新下载
  python setup_db.py --update     # 更新到最新版本
  python setup_db.py --verify     # 验证数据库完整性
  python setup_db.py --info       # 显示数据库信息
  python setup_db.py --check      # 检查数据库是否存在
        """
    )
    
    parser.add_argument("--force", action="store_true", help="强制重新下载")
    parser.add_argument("--update", action="store_true", help="更新到最新版本")
    parser.add_argument("--verify", action="store_true", help="验证数据库完整性")
    parser.add_argument("--info", action="store_true", help="显示数据库信息")
    parser.add_argument("--check", action="store_true", help="检查数据库是否存在")
    
    args = parser.parse_args()
    
    # 检查模式
    if args.check:
        exists = check_db_exists()
        print(f"数据库存在: {exists}")
        if exists:
            print(f"路径: {get_db_path()}")
        sys.exit(0 if exists else 1)
    
    # 信息模式
    if args.info:
        get_db_info()
        sys.exit(0)
    
    # 验证模式
    if args.verify:
        success = verify_database()
        sys.exit(0 if success else 1)
    
    # 更新模式
    if args.update:
        success = update_database()
        sys.exit(0 if success else 1)
    
    # 设置模式(默认)
    success = setup_database(force=args.force)
    if success:
        verify_database()
    sys.exit(0 if success else 1)

if __name__ == "__main__":
    main()