文件内容
scripts/data_pipeline.py
#!/usr/bin/env python3
"""
A-share analysis v3.0 — data pipeline.

Provides a session-scoped TTL/LRU cache, HTTP GET helpers with timeout and
retries (each request opens its own urllib connection; there is no
connection pooling), a sliding-window rate limiter, and a cached,
rate-limited wrapper around AKShare calls.
"""
import json
import time
import logging
import urllib.request
import urllib.error
from functools import wraps
from datetime import datetime, timedelta
from collections import OrderedDict, deque
from threading import Lock

from config import (
    CACHE_TTL_SECONDS, HTTP_TIMEOUT, HTTP_RETRIES, HTTP_RETRY_DELAY,
    AK_RATE_LIMIT_CALLS, AK_RATE_LIMIT_WINDOW,
)
# Logger used by every component in this module (cache, HTTP client,
# rate limiter, AKShare wrapper).
logger = logging.getLogger("astock.pipeline")
# ============================================================
# 会话级 LRU 缓存
# ============================================================
class SessionCache:
    """Thread-safe, session-scoped LRU cache whose entries expire after a TTL.

    Entries live in an ``OrderedDict`` kept in least-recently-used order, and
    every public method holds one ``threading.Lock``. Entry ages are measured
    with ``time.monotonic()``. With the wall clock, a backwards clock step
    made entries never expire, and a forward step expired all of them at once.

    Note: ``get`` returns ``None`` on a miss, so a stored ``None`` value is
    indistinguishable from a miss.
    """

    def __init__(self, max_size=128, ttl=CACHE_TTL_SECONDS):
        """
        Args:
            max_size: maximum number of entries; ``<= 0`` disables storage.
            ttl: entry lifetime in seconds.
        """
        self._cache = OrderedDict()
        self._ttl = ttl
        self._max_size = max_size
        self._lock = Lock()

    def get(self, key):
        """Return the value cached under *key*, or ``None`` if absent or expired.

        A hit marks the entry as most recently used. An expired entry is
        removed on lookup.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            if time.monotonic() - entry["ts"] < self._ttl:
                self._cache.move_to_end(key)
                return entry["value"]
            del self._cache[key]
            return None

    def set(self, key, value):
        """Store *value* under *key*, evicting entries to stay within ``max_size``.

        When the cache is full, expired entries are dropped first. Only if that
        frees no room is the least-recently-used live entry evicted.
        """
        if self._max_size <= 0:
            # Capacity <= 0 means "cache disabled". Previously this crashed by
            # calling popitem() on an empty OrderedDict.
            return
        with self._lock:
            now = time.monotonic()
            if key in self._cache:
                # Re-inserting moves the key to the most-recently-used end.
                del self._cache[key]
            elif len(self._cache) >= self._max_size:
                self._purge_expired(now)
                while len(self._cache) >= self._max_size:
                    self._cache.popitem(last=False)
            self._cache[key] = {"value": value, "ts": now}

    def _purge_expired(self, now):
        """Drop every expired entry. Caller must hold ``self._lock``."""
        stale = [k for k, entry in self._cache.items() if now - entry["ts"] >= self._ttl]
        for k in stale:
            del self._cache[k]

    def invalidate(self, key):
        """Remove *key* if present; a missing key is ignored."""
        with self._lock:
            self._cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        with self._lock:
            self._cache.clear()

    def stats(self):
        """Return current size and configuration.

        ``size`` may include expired entries that have not been purged yet.
        """
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self._max_size,
                "ttl": self._ttl,
            }
# Module-wide cache instance (default size and TTL). It is shared by `cached`
# and `ak_call_with_cache`, and returned by `get_cache`.
_cache = SessionCache()
def cached(key_func):
    """Build a decorator that memoizes results in the module-wide cache.

    Args:
        key_func: called with the wrapped function's arguments; returns the
            cache key for that call.

    Only non-``None`` results are stored. A call that returns ``None`` is
    re-executed every time.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_func(*args, **kwargs)
            hit = _cache.get(cache_key)
            if hit is None:
                fresh = func(*args, **kwargs)
                if fresh is not None:
                    _cache.set(cache_key, fresh)
                return fresh
            logger.debug(f"Cache hit: {cache_key}")
            return hit
        return wrapper
    return decorator
# ============================================================
# HTTP client (timeout + retry; urllib opens a new connection per request, no connection reuse)
# ============================================================
class HttpClient:
    """Minimal HTTP GET client with timeout, retries and browser-like headers.

    Built on ``urllib.request.urlopen``, which opens a new connection for
    every request (no keep-alive pooling). All public methods are
    classmethods.
    """

    # Sent with every request; per-call ``headers`` override these.
    DEFAULT_HEADERS = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
    }

    @classmethod
    def get_json(cls, url, headers=None, timeout=HTTP_TIMEOUT, retries=HTTP_RETRIES):
        """GET *url* and parse the body as JSON.

        Returns:
            The decoded JSON value, or ``None`` if every attempt failed.
        """
        return cls._request(url, "json", headers, timeout, retries)

    @classmethod
    def get_text(cls, url, headers=None, timeout=HTTP_TIMEOUT, retries=HTTP_RETRIES):
        """GET *url* and return the decoded body text, or ``None`` on failure."""
        return cls._request(url, "text", headers, timeout, retries)

    @staticmethod
    def _decode_body(resp, raw):
        """Decode *raw* bytes with the charset declared by *resp*.

        Falls back to UTF-8 when no charset is declared or Python does not
        know the declared one. The previous code always decoded as UTF-8. A
        non-UTF-8 body (e.g. GBK, if an upstream endpoint uses it) then raised
        UnicodeDecodeError, was retried pointlessly, and was finally reported
        as a failure.
        """
        charset = resp.headers.get_content_charset() or "utf-8"
        try:
            return raw.decode(charset)
        except LookupError:
            return raw.decode("utf-8")

    @classmethod
    def _request(cls, url, mode, headers, timeout, retries):
        """Perform a GET with up to ``retries + 1`` attempts.

        Retry policy:
          * HTTP 429: back off ``HTTP_RETRY_DELAY * (attempt + 2)`` seconds.
          * HTTP 5xx: wait ``HTTP_RETRY_DELAY`` seconds.
          * Other HTTP errors (4xx): give up immediately and return ``None``.
          * Any other exception (network, timeout, decode or JSON error):
            wait ``HTTP_RETRY_DELAY`` seconds.
        No sleep happens after the final attempt. Negative *retries* is
        treated as 0, so at least one attempt is always made.

        Args:
            mode: ``"json"`` parses the body; any other value returns text.

        Returns:
            Parsed JSON or text, or ``None`` once all attempts have failed.
        """
        merged_headers = dict(cls.DEFAULT_HEADERS)
        if headers:
            merged_headers.update(headers)
        attempts = max(retries, 0) + 1
        last_error = None
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                req = urllib.request.Request(url, headers=merged_headers)
                with urllib.request.urlopen(req, timeout=timeout) as resp:
                    raw = cls._decode_body(resp, resp.read())
                if mode == "json":
                    return json.loads(raw)
                return raw
            except urllib.error.HTTPError as e:
                last_error = e
                if e.code == 429:
                    # Rate limited: back off longer on each attempt.
                    wait = HTTP_RETRY_DELAY * (attempt + 2)
                    if not is_last:
                        logger.warning(f"HTTP 429 rate limited, waiting {wait:.1f}s: {url}")
                elif e.code >= 500:
                    wait = HTTP_RETRY_DELAY
                    logger.warning(f"HTTP {e.code} server error (attempt {attempt+1}): {url}")
                else:
                    # Other client errors will not succeed on retry.
                    logger.error(f"HTTP {e.code} client error: {url}")
                    return None
            except Exception as e:
                last_error = e
                wait = HTTP_RETRY_DELAY
                if not is_last:
                    logger.debug(f"Request failed (attempt {attempt+1}): {e}")
            # Sleeping after the final attempt would only delay the failure.
            if not is_last:
                time.sleep(wait)
        logger.error(f"Request failed after {attempts} attempts: {last_error}")
        return None
# ============================================================
# AKShare 速率限制器
# ============================================================
class RateLimiter:
    """Thread-safe sliding-window rate limiter.

    Permits at most ``max_calls`` acquisitions within any ``window``-second
    span. Timestamps come from ``time.monotonic()``. With the wall clock, a
    backwards clock step made recorded calls look future-dated, so they never
    aged out, and ``acquire`` could sleep for as long as the clock step.
    """

    def __init__(self, max_calls=AK_RATE_LIMIT_CALLS, window=AK_RATE_LIMIT_WINDOW):
        """
        Args:
            max_calls: calls permitted per window.
            window: window length in seconds.
        """
        self._max_calls = max_calls
        self._window = window
        # Monotonic timestamps of recent acquisitions, oldest first (appends
        # are non-decreasing, so expired entries are always at the left end).
        self._timestamps = deque()
        self._lock = Lock()

    def _prune(self, now):
        """Drop timestamps that have left the window. Caller must hold the lock."""
        while self._timestamps and now - self._timestamps[0] >= self._window:
            self._timestamps.popleft()

    def acquire(self):
        """Block until a call is permitted, then record it.

        The lock is held while sleeping. Concurrent callers are therefore
        admitted one at a time, which keeps the limit exact across threads.
        ``stats`` also waits during the sleep.
        """
        with self._lock:
            now = time.monotonic()
            self._prune(now)
            if len(self._timestamps) >= self._max_calls:
                # Wait until the oldest call leaves the window, plus 50 ms slack.
                oldest = self._timestamps[0]
                wait = self._window - (now - oldest) + 0.05
                if wait > 0:
                    logger.debug(f"Rate limit: waiting {wait:.2f}s")
                    time.sleep(wait)
            self._timestamps.append(time.monotonic())

    def stats(self):
        """Return the number of calls in the current window and the configured limits."""
        with self._lock:
            now = time.monotonic()
            active = sum(1 for t in self._timestamps if now - t < self._window)
            return {
                "active_calls": active,
                "max_calls": self._max_calls,
                "window": self._window,
            }
# Module-wide rate limiter (default limits come from the config constants).
# It is shared by `rate_limited` and `ak_call_with_cache`, and returned by
# `get_rate_limiter`.
_rate_limiter = RateLimiter()
def rate_limited(func):
    """Decorate *func* so each call first takes a slot from the module-wide limiter.

    The call blocks in ``_rate_limiter.acquire()`` until a slot is free, then
    forwards all arguments to *func* and returns its result unchanged.
    """
    @wraps(func)
    def throttled(*args, **kwargs):
        _rate_limiter.acquire()
        return func(*args, **kwargs)

    return throttled
# ============================================================
# AKShare 缓存包装器
# ============================================================
# Lower-cased substrings that identify a throttling error. They are
# deliberately multi-word: the former bare "rate"/"limit" checks also matched
# unrelated messages ("separate", "iterate", "delimiter").
_AK_RATE_LIMIT_MARKERS = ("rate limit", "ratelimit", "too many requests", "limit exceeded")
# Lower-cased substrings that identify a timeout. "timed out" is the wording
# socket/urllib3 use ("Read timed out."), which "timeout" alone can miss.
_AK_TIMEOUT_MARKERS = ("timeout", "timed out")


def _ak_error_kind(exc):
    """Classify an exception from an AKShare call for logging.

    Returns:
        ``"rate_limit"``, ``"timeout"`` or ``"error"``.
    """
    msg = str(exc).lower()
    if any(marker in msg for marker in _AK_RATE_LIMIT_MARKERS):
        return "rate_limit"
    if isinstance(exc, TimeoutError) or any(marker in msg for marker in _AK_TIMEOUT_MARKERS):
        return "timeout"
    return "error"


def ak_call_with_cache(func_name, *args, **kwargs):
    """Call ``akshare.<func_name>(*args, **kwargs)`` with caching and rate limiting.

    Use this instead of calling ``ak.xxx()`` directly to avoid repeated
    requests. Non-``None`` results are stored in the module-wide cache under
    a key built from the function name and the ``repr`` of the arguments.

    NOTE(review): cached objects are returned by reference. If results are
    mutable (e.g. DataFrames), a caller that mutates one alters the cached
    copy. Confirm that callers treat results as read-only.

    Returns:
        The AKShare result, or ``None`` if the function does not exist,
        akshare cannot be imported, or the call raised. All failures are
        logged, never raised.
    """
    cache_key = f"ak:{func_name}:{args}:{sorted(kwargs.items())}"
    result = _cache.get(cache_key)
    if result is not None:
        return result
    try:
        import akshare as ak
        func = getattr(ak, func_name, None)
        if func is None:
            logger.error(f"AKShare function not found: {func_name}")
            return None
        # Only consume a rate-limit slot for a call that will actually be made.
        _rate_limiter.acquire()
        result = func(*args, **kwargs)
        if result is not None:
            _cache.set(cache_key, result)
        return result
    except Exception as e:
        # Best-effort boundary: classify for the log level, then degrade to None.
        kind = _ak_error_kind(e)
        if kind == "rate_limit":
            logger.warning(f"AKShare rate limited: {func_name}")
        elif kind == "timeout":
            logger.warning(f"AKShare timeout: {func_name}")
        else:
            logger.error(f"AKShare error in {func_name}: {e}")
        return None
def get_cache():
    """Return the module-level ``_cache`` SessionCache used by the caching helpers."""
    return _cache


def get_rate_limiter():
    """Return the module-level ``_rate_limiter`` RateLimiter used for AKShare calls."""
    return _rate_limiter