文件预览

data_pipeline.py

查看 Cny Rmb China A Shares Stock 技能包中的文件内容。

文件内容

scripts/data_pipeline.py

#!/usr/bin/env python3
"""
A-share analysis v3.0 — data pipeline.

Session-level cache + HTTP connection reuse + rate limiting + retry mechanism.

NOTE(review): the HTTP client visible in this module sends every request
through ``urllib.request.urlopen``; no connection pooling is visible here —
confirm whether "connection reuse" is provided elsewhere.
"""

import json
import time
import logging
import urllib.request
import urllib.error
from functools import wraps
from datetime import datetime, timedelta
from collections import OrderedDict
from threading import Lock

from config import (
    CACHE_TTL_SECONDS, HTTP_TIMEOUT, HTTP_RETRIES, HTTP_RETRY_DELAY,
    AK_RATE_LIMIT_CALLS, AK_RATE_LIMIT_WINDOW,
)

# Module-level logger; records are emitted under the "astock.pipeline" name.
logger = logging.getLogger("astock.pipeline")


# ============================================================
# Session-level LRU cache
# ============================================================
class SessionCache:
    """Thread-safe, size-bounded session cache with per-entry TTL expiry.

    Entries are kept in least-recently-used order: a successful ``get`` marks
    the key as most recently used, and ``set`` evicts the least recently used
    entry once ``max_size`` entries are stored. Every public method takes the
    internal lock for the duration of the call.

    Args:
        max_size: Maximum number of entries held at once. A value of zero or
            less disables storage (``set`` becomes a no-op).
        ttl: Lifetime of an entry in seconds, counted from the ``set`` call.
    """

    def __init__(self, max_size=128, ttl=CACHE_TTL_SECONDS):
        self._cache = OrderedDict()
        self._ttl = ttl
        self._max_size = max_size
        self._lock = Lock()

    def get(self, key):
        """Return the live value stored under ``key``, or ``None``.

        ``None`` is returned both when the key is absent and when its entry
        has outlived the TTL; an expired entry is removed on the way out.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            # Ages are measured on the monotonic clock so that system clock
            # adjustments can neither expire entries early nor keep them alive.
            if time.monotonic() - entry["ts"] < self._ttl:
                self._cache.move_to_end(key)
                return entry["value"]
            del self._cache[key]
            return None

    def set(self, key, value):
        """Store ``value`` under ``key``, evicting the oldest entry if full.

        Re-setting an existing key refreshes both its timestamp and its
        position in the LRU order.
        """
        with self._lock:
            if self._max_size <= 0:
                # Nothing may be stored; without this guard popitem() would
                # raise KeyError on the empty mapping.
                return
            if key in self._cache:
                del self._cache[key]
            elif len(self._cache) >= self._max_size:
                self._cache.popitem(last=False)
            self._cache[key] = {"value": value, "ts": time.monotonic()}

    def invalidate(self, key):
        """Drop ``key`` from the cache; a missing key is ignored."""
        with self._lock:
            self._cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        with self._lock:
            self._cache.clear()

    def stats(self):
        """Return a dict with the current ``size``, ``max_size`` and ``ttl``.

        ``size`` counts stored entries, including ones that have expired but
        have not been looked up (and therefore removed) yet.
        """
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self._max_size,
                "ttl": self._ttl,
            }


# Global cache instance, created with the default size and TTL
_cache = SessionCache()


def cached(key_func):
    """Build a decorator that stores a function's results in the global cache.

    ``key_func`` is called with the same positional and keyword arguments as
    the decorated function and must return the cache key. A ``None`` result
    is never stored and never counts as a hit, so calls returning ``None``
    always run the wrapped function again.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_func(*args, **kwargs)
            hit = _cache.get(cache_key)
            if hit is None:
                # Miss (or expired entry): compute, then keep non-None results.
                fresh = func(*args, **kwargs)
                if fresh is not None:
                    _cache.set(cache_key, fresh)
                return fresh
            logger.debug(f"Cache hit: {cache_key}")
            return hit
        return wrapper
    return decorator


# ============================================================
# HTTP client (connection reuse + retry) — NOTE(review): no connection pooling is visible in this section
# ============================================================
class HttpClient:
    """HTTP GET helper with a timeout and retry/back-off handling.

    All methods are classmethods; failures are logged and reported to the
    caller as ``None`` rather than raised.
    """

    DEFAULT_HEADERS = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
    }

    @classmethod
    def get_json(cls, url, headers=None, timeout=HTTP_TIMEOUT, retries=HTTP_RETRIES):
        """GET ``url`` and return the parsed JSON body, or ``None`` on failure."""
        return cls._request(url, "json", headers, timeout, retries)

    @classmethod
    def get_text(cls, url, headers=None, timeout=HTTP_TIMEOUT, retries=HTTP_RETRIES):
        """GET ``url`` and return the decoded body text, or ``None`` on failure."""
        return cls._request(url, "text", headers, timeout, retries)

    @classmethod
    def _request(cls, url, mode, headers, timeout, retries):
        """Perform the GET, retrying up to ``retries`` additional times.

        Args:
            url: Address to fetch.
            mode: ``"json"`` to parse the body with ``json.loads``; any other
                value returns the decoded text.
            headers: Optional mapping merged over ``DEFAULT_HEADERS``.
            timeout: Timeout in seconds passed to ``urlopen``.
            retries: Number of retries after the first attempt.

        Returns:
            The parsed JSON / text body, or ``None`` when a 4xx response other
            than 429 is received or when every attempt fails.
        """
        merged_headers = {**cls.DEFAULT_HEADERS}
        if headers:
            merged_headers.update(headers)

        last_error = None
        for attempt in range(retries + 1):
            # No point sleeping after the final attempt: nothing follows it.
            is_last_attempt = attempt >= retries
            try:
                req = urllib.request.Request(url, headers=merged_headers)
                with urllib.request.urlopen(req, timeout=timeout) as resp:
                    body = resp.read()
                    try:
                        raw = body.decode("utf-8")
                    except UnicodeDecodeError:
                        # Not UTF-8: fall back to the charset the server
                        # declared (e.g. GBK), if it declared one.
                        declared = resp.headers.get_content_charset()
                        if not declared:
                            raise
                        raw = body.decode(declared)
                    if mode == "json":
                        return json.loads(raw)
                    return raw
            except urllib.error.HTTPError as e:
                last_error = e
                if e.code == 429:
                    if is_last_attempt:
                        break
                    # Rate limited: back off longer with each attempt.
                    wait = HTTP_RETRY_DELAY * (attempt + 2)
                    logger.warning(f"HTTP 429 rate limited, waiting {wait:.1f}s: {url}")
                    time.sleep(wait)
                elif e.code >= 500:
                    logger.warning(f"HTTP {e.code} server error (attempt {attempt+1}): {url}")
                    if not is_last_attempt:
                        time.sleep(HTTP_RETRY_DELAY)
                else:
                    # Other 4xx responses are not retried.
                    logger.error(f"HTTP {e.code} client error: {url}")
                    return None
            except Exception as e:
                last_error = e
                if not is_last_attempt:
                    logger.debug(f"Request failed (attempt {attempt+1}): {e}")
                    time.sleep(HTTP_RETRY_DELAY)

        logger.error(f"Request failed after {retries+1} attempts: {last_error}")
        return None


# ============================================================
# AKShare rate limiter
# ============================================================
class RateLimiter:
    """Sliding-window rate limiter.

    Allows ``max_calls`` permits per ``window`` seconds. Timestamps are taken
    from the monotonic clock so that system clock adjustments cannot stretch
    or shrink the computed wait.

    Args:
        max_calls: Number of permits allowed inside one window.
        window: Window length in seconds.
    """

    def __init__(self, max_calls=AK_RATE_LIMIT_CALLS, window=AK_RATE_LIMIT_WINDOW):
        self._max_calls = max_calls
        self._window = window
        self._timestamps = []
        self._lock = Lock()

    def acquire(self):
        """Take one permit, sleeping first if the window is already full.

        The sleep happens while the lock is held, so other threads calling
        ``acquire`` or ``stats`` wait until it finishes.
        """
        with self._lock:
            now = time.monotonic()
            # Drop timestamps that have slid out of the window.
            self._timestamps = [t for t in self._timestamps if now - t < self._window]
            if len(self._timestamps) >= self._max_calls:
                # Wait until the oldest permit expires (plus a small margin).
                oldest = self._timestamps[0]
                wait = self._window - (now - oldest) + 0.05
                if wait > 0:
                    logger.debug(f"Rate limit: waiting {wait:.2f}s")
                    time.sleep(wait)
            self._timestamps.append(time.monotonic())

    def stats(self):
        """Return ``active_calls`` (permits inside the window), ``max_calls`` and ``window``."""
        with self._lock:
            now = time.monotonic()
            active = [t for t in self._timestamps if now - t < self._window]
            return {
                "active_calls": len(active),
                "max_calls": self._max_calls,
                "window": self._window,
            }


# Global rate limiter, created with the default call budget and window
_rate_limiter = RateLimiter()


def rate_limited(func):
    """Decorator: take a permit from the global rate limiter before each call."""
    @wraps(func)
    def throttled(*args, **kwargs):
        # May block inside acquire() until a permit is available.
        _rate_limiter.acquire()
        outcome = func(*args, **kwargs)
        return outcome
    return throttled


# ============================================================
# AKShare cached call wrapper
# ============================================================
def ak_call_with_cache(func_name, *args, **kwargs):
    """
    Call an AKShare function by name with caching, rate limiting and fallback.

    Intended as a replacement for calling ``ak.xxx()`` directly so that
    repeated requests are served from the global cache.

    Args:
        func_name: Name of the attribute to look up on the ``akshare`` module.
        *args, **kwargs: Forwarded unchanged to that function; they are also
            part of the cache key.

    Returns:
        The function's result (cached when it is not ``None``), or ``None``
        if the function does not exist or any exception is raised.
    """
    cache_key = f"ak:{func_name}:{args}:{sorted(kwargs.items())}"
    result = _cache.get(cache_key)
    if result is not None:
        return result

    try:
        import akshare as ak
        func = getattr(ak, func_name, None)
        if func is None:
            logger.error(f"AKShare function not found: {func_name}")
            return None
        # Take the permit only once the function is known to exist, so a bad
        # name neither consumes a slot nor blocks waiting for one.
        _rate_limiter.acquire()
        result = func(*args, **kwargs)
        if result is not None:
            _cache.set(cache_key, result)
        return result
    except Exception as e:
        # Classify the failure by message text, only to choose the log level.
        error_msg = str(e).lower()
        if "rate" in error_msg or "limit" in error_msg:
            logger.warning(f"AKShare rate limited: {func_name}")
        elif "timeout" in error_msg:
            logger.warning(f"AKShare timeout: {func_name}")
        else:
            logger.error(f"AKShare error in {func_name}: {e}")
        return None


def get_cache():
    """Return the module's global cache instance."""
    return _cache


def get_rate_limiter():
    """Return the module's global rate limiter instance."""
    return _rate_limiter