文件预览

engine.py

查看 Cny Rmb China A Shares Stock 技能包中的文件内容。

文件内容

scripts/engine.py

#!/usr/bin/env python3
"""
A-share analysis v3.0 — data engine
Core orchestration: parallel data collection + cache de-duplication + fallback strategy
"""

import logging
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

from config import MARKET_INDICES, VERSION
from data_pipeline import cached, ak_call_with_cache, _cache
from holiday import is_trading_day, get_last_trading_day, format_trading_status

# Module-level logger shared by every fetch/analysis function in this file.
logger = logging.getLogger("astock.engine")


# ============================================================
# 大盘数据
# ============================================================
def fetch_market_overview():
    """Build a one-row summary for every configured market index (cached).

    Each ``(symbol, name)`` pair in ``MARKET_INDICES`` is looked up through
    ``ak_call_with_cache("stock_zh_index_daily", ...)``.  An index is left
    out when no frame comes back or when it has fewer than two rows, since
    the day-over-day change needs the last two closes.

    Returns:
        list of dicts with the keys ``name``, ``close`` (rounded to 2
        places), ``change_pct`` (percent change versus the previous row,
        rounded to 2 places) and ``emoji`` (green / red / white marker for
        up / down / flat).  A failure on one index is logged and does not
        stop the remaining ones.
    """
    overview = []
    for symbol, name in MARKET_INDICES:
        try:
            frame = ak_call_with_cache("stock_zh_index_daily", symbol=symbol)
            if frame is None or len(frame) < 2:
                continue

            latest_close = float(frame.iloc[-1]["close"])
            previous_close = float(frame.iloc[-2]["close"])
            change = (latest_close - previous_close) / previous_close * 100

            if change > 0:
                marker = "🟢"
            elif change < 0:
                marker = "🔴"
            else:
                marker = "⚪"

            entry = {
                "name": name,
                "close": round(latest_close, 2),
                "change_pct": round(change, 2),
                "emoji": marker,
            }
            overview.append(entry)
        except Exception as e:
            logger.warning(f"Failed to fetch {name}: {e}")
    return overview


def fetch_market_trend(days=5):
    """Summarise the last ``days`` closes of every configured market index.

    For each ``(symbol, name)`` pair in ``MARKET_INDICES`` the daily index
    history is requested through ``ak_call_with_cache``.  An index is
    skipped when no frame comes back, when the frame has fewer than
    ``days`` rows, or when the window holds fewer than two points (no
    change can be computed from a single close).

    Args:
        days: Number of trailing rows to include in the window.

    Returns:
        dict keyed by index display name.  Each value contains:
            ``trend``: list of ``{"date": str, "close": float}`` in the row
                order of the fetched frame, closes rounded to 2 places;
            ``total_change_pct``: percent change from the first to the last
                (rounded) close in the window, rounded to 2 places;
            ``direction``: "📈", "📉" or "➡️" for up, down or flat.
    """
    results = {}
    for symbol, name in MARKET_INDICES:
        try:
            df = ak_call_with_cache("stock_zh_index_daily", symbol=symbol)
            if df is not None and len(df) >= days:
                recent = df.tail(days)
                trend = [
                    {"date": str(r.get("date", "")), "close": round(float(r["close"]), 2)}
                    for _, r in recent.iterrows()
                ]
                if len(trend) >= 2:
                    total = (trend[-1]["close"] - trend[0]["close"]) / trend[0]["close"] * 100
                    results[name] = {
                        "trend": trend,
                        "total_change_pct": round(total, 2),
                        "direction": "📈" if total > 0 else "📉" if total < 0 else "➡️",
                    }
        except Exception as e:
            # Best-effort: one failing index must not abort the others, but
            # the failure is now logged instead of being silently dropped.
            logger.warning(f"Failed to fetch trend for {name}: {e}")
    return results


def fetch_zt_dt():
    """Fetch today's limit-up and limit-down stock pools.

    Both pools are requested through ``ak_call_with_cache`` for the current
    local date (``YYYYMMDD``) and truncated to their first 15 rows.  The two
    requests are independent: a failure in one is logged and leaves the
    other untouched.  Rows appended before a mid-loop failure are kept.

    Returns:
        dict with two lists:
            "涨停": dicts with ``code``, ``name``, ``change_pct``,
                ``reason`` and ``turnover``;
            "跌停": dicts with ``code``, ``name`` and ``change_pct``.
        Either list is empty when its source returned nothing or failed.
    """
    # NOTE(review): the query date is always "now"; on a non-trading day the
    # pools are presumably empty — confirm whether the last trading day
    # should be used instead.
    today = datetime.now().strftime("%Y%m%d")
    result = {"涨停": [], "跌停": []}

    try:
        zt_df = ak_call_with_cache("stock_zt_pool_em", date=today)
        if zt_df is not None and not zt_df.empty:
            for _, row in zt_df.head(15).iterrows():
                result["涨停"].append({
                    "code": str(row.get("代码", "")),
                    "name": str(row.get("名称", "")),
                    "change_pct": float(row.get("涨跌幅", 0)),
                    "reason": str(row.get("涨停原因", "")),
                    "turnover": str(row.get("换手率", "")),
                })
    except Exception as e:
        # Still best-effort, but no longer silent: surface why the pool is
        # missing or truncated.
        logger.warning(f"Limit-up pool fetch failed for {today}: {e}")

    try:
        dt_df = ak_call_with_cache("stock_zt_pool_dtgc_em", date=today)
        if dt_df is not None and not dt_df.empty:
            for _, row in dt_df.head(15).iterrows():
                result["跌停"].append({
                    "code": str(row.get("代码", "")),
                    "name": str(row.get("名称", "")),
                    "change_pct": float(row.get("涨跌幅", 0)),
                })
    except Exception as e:
        logger.warning(f"Limit-down pool fetch failed for {today}: {e}")

    return result


def fetch_hot_sectors():
    """Return the first 15 rows of the industry-board listing as dicts.

    The listing comes from ``ak_call_with_cache("stock_board_industry_name_em")``.

    Returns:
        list of dicts with ``name``, ``change_pct``, ``leader``,
        ``leader_change``, ``up_count`` and ``down_count``.  The last three
        fall back to 0 when their column is absent or cannot be converted.
        An empty list is returned when no data is available or anything
        else fails (the failure is logged).
    """

    def _optional(row, column, cast):
        # Optional column: absent or unparseable values collapse to 0.
        try:
            return cast(row[column]) if column in row.index else 0
        except (ValueError, TypeError):
            return 0

    try:
        board = ak_call_with_cache("stock_board_industry_name_em")
        if board is None or board.empty:
            return []

        collected = []
        for _, row in board.head(15).iterrows():
            leader_change = _optional(row, "领涨股票-涨跌幅", float)
            rising = _optional(row, "上涨家数", int)
            falling = _optional(row, "下跌家数", int)
            collected.append({
                "name": str(row.get("板块名称", "")),
                "change_pct": float(row.get("涨跌幅", 0)),
                "leader": str(row.get("领涨股票", "")),
                "leader_change": leader_change,
                "up_count": rising,
                "down_count": falling,
            })
        return collected
    except Exception as e:
        logger.warning(f"Sector data failed: {e}")
    return []


def fetch_single_stock(code):
    """Look up one stock in the cached A-share spot table.

    Args:
        code: Value compared for equality against the "代码" column.

    Returns:
        dict describing the first matching row (``code``, ``name``, numeric
        price/volume fields converted with ``float``, and ``pe``,
        ``total_mv``, ``circ_mv``, ``turnover_rate`` passed through as
        found), or ``None`` when the table is unavailable, no row matches,
        or any step raises (the failure is logged).
    """
    # (output key, source column) pairs, in the order they appear in the result.
    numeric_fields = (
        ("price", "最新价"),
        ("change_pct", "涨跌幅"),
        ("change_amt", "涨跌额"),
        ("volume", "成交量"),
        ("turnover", "成交额"),
        ("high", "最高"),
        ("low", "最低"),
        ("open", "今开"),
        ("prev_close", "昨收"),
    )
    passthrough_fields = (
        ("pe", "市盈率-动态"),
        ("total_mv", "总市值"),
        ("circ_mv", "流通市值"),
        ("turnover_rate", "换手率"),
    )

    try:
        spot = ak_call_with_cache("stock_zh_a_spot_em")
        if spot is None or spot.empty:
            return None

        matches = spot[spot["代码"] == code]
        if matches.empty:
            return None

        record = matches.iloc[0]
        quote = {
            "code": str(record.get("代码", "")),
            "name": str(record.get("名称", "")),
        }
        for key, column in numeric_fields:
            quote[key] = float(record.get(column, 0))
        for key, column in passthrough_fields:
            quote[key] = record.get(column, "")
        return quote
    except Exception as e:
        logger.warning(f"Stock query failed for {code}: {e}")
    return None


# ============================================================
# 关联分析
# ============================================================
def analyze_correlation(stock_hot, zt_dt, sectors):
    """Correlate hot-search keywords with the limit-up board.

    Args:
        stock_hot: Iterable of dicts carrying a ``"keyword"`` string.
            Entries whose keyword is missing or not a string are skipped
            instead of raising.
        zt_dt: Dict whose optional ``"涨停"`` entry is a list of dicts with
            a ``"name"`` and, optionally, a ``"reason"``.
        sectors: Accepted for interface compatibility; not used here.

    Returns:
        dict with:
            ``hot_stock_mentions``: first 10 keywords (``#`` removed) whose
                length is between 2 and 8 characters;
            ``hot_and_zt``: those keywords that also appear as a limit-up
                stock name;
            ``hot_sectors``: always an empty list (never populated here);
            ``insights``: human-readable lines for the overlap and for the
                three most frequent non-empty limit-up reasons.
    """
    from collections import Counter

    analysis = {"hot_stock_mentions": [], "hot_and_zt": [], "hot_sectors": [], "insights": []}

    stock_names = []
    for item in stock_hot or []:
        keyword = item.get("keyword") if isinstance(item, dict) else None
        if not isinstance(keyword, str):
            continue  # malformed entry: nothing to match against
        cleaned = keyword.replace("#", "")
        if 2 <= len(cleaned) <= 8:
            stock_names.append(cleaned)
    analysis["hot_stock_mentions"] = stock_names[:10]

    limit_up = zt_dt.get("涨停") or []

    # A set gives O(1) membership tests for the overlap scan below.
    zt_names = {entry.get("name") for entry in limit_up}
    overlap = [n for n in stock_names if n in zt_names]
    if overlap:
        analysis["hot_and_zt"] = overlap
        analysis["insights"].append(f"🔥 同时出现在热搜和涨停板: {', '.join(overlap)}")

    reason_counts = Counter()
    for entry in limit_up:
        reason = entry.get("reason")
        # ``reason`` may be absent or explicitly None; only real text counts.
        if isinstance(reason, str) and reason.strip():
            reason_counts[reason.strip()] += 1
    # most_common keeps first-seen order among equal counts.
    for reason, count in reason_counts.most_common(3):
        analysis["insights"].append(f"📈 涨停原因「{reason}」: {count} 只")

    return analysis


def analyze_sector_rotation():
    """Compare the two most recent saved snapshots to spot sector rotation.

    Up to three snapshots are loaded via ``load_recent_snapshots(3)``; the
    last element is treated as the current one and the one before it as the
    previous one (assumes the helper returns them oldest-first — TODO
    confirm against ``utils.common``).

    Returns:
        ``None`` when fewer than two snapshots exist, otherwise a dict with:
            ``new``: up to 5 sector names present now but not previously;
            ``gone``: up to 5 sector names present previously but not now;
            ``hot``: up to 5 ``{"name", "today", "change"}`` entries for
                sectors in both snapshots whose ``change_pct`` moved by more
                than 1 point, sorted by ``change`` descending.
        ``new`` and ``gone`` follow the order of the snapshot they come
        from, so repeated runs on the same data give the same answer.
    """
    from utils.common import load_recent_snapshots

    snapshots = load_recent_snapshots(3)
    if len(snapshots) < 2:
        return None

    # ``or []`` also covers a snapshot that stored "sectors" as null.
    today_sectors = {s["name"]: s["change_pct"] for s in snapshots[-1].get("sectors") or []}
    prev_sectors = {s["name"]: s["change_pct"] for s in snapshots[-2].get("sectors") or []}

    # Walk the dicts (insertion-ordered) instead of slicing set differences:
    # slicing a set picks an arbitrary subset that can change between runs.
    new_names = [name for name in today_sectors if name not in prev_sectors]
    gone_names = [name for name in prev_sectors if name not in today_sectors]

    hot_rotation = []
    for name, today_pct in today_sectors.items():
        if name not in prev_sectors:
            continue
        diff = today_pct - prev_sectors[name]
        if abs(diff) > 1:
            hot_rotation.append({"name": name, "today": today_pct, "change": round(diff, 2)})
    # Stable sort: equal changes keep the current snapshot's order.
    hot_rotation.sort(key=lambda x: -x["change"])

    return {
        "new": new_names[:5],
        "gone": gone_names[:5],
        "hot": hot_rotation[:5],
    }


# ============================================================
# 主数据采集(并行)
# ============================================================
def collect_all_data(args):
    """Collect every data source in parallel and assemble the result dict.

    Args:
        args: Options object (e.g. an argparse namespace).  The attributes
            read, all optional, are ``watchlist`` and ``platforms``
            (comma-separated strings), ``proxy``, ``no_weibo`` (skip the
            hot-search sources) and ``no_market`` (skip market data, the
            rotation/percentile analysis and the snapshot save).

    Returns:
        dict holding one entry per submitted fetch (``{}`` for any fetch
        that raised or timed out), plus ``platform_status``, ``stock_hot``,
        ``correlation``, ``trading_status``, ``is_trading_day``,
        ``version`` and ``timestamp``.  ``sector_strength`` is added when
        sector data is non-empty; ``rotation`` and ``limit_percentile`` are
        added when market data is enabled.
    """
    from sources.multi_platform import fetch_all_hot_sources, filter_stock_keywords, merge_multi_platform_results
    from analysis.capital import CapitalAnalyzer
    from analysis.sentiment import SentimentAnalyzer
    from analysis.sector import SectorAnalyzer

    data = {}
    market_enabled = not getattr(args, 'no_market', False)
    # Same optional-attribute handling as the other options below.
    watchlist_arg = getattr(args, 'watchlist', None)
    watchlist = watchlist_arg.split(",") if watchlist_arg else None

    with ThreadPoolExecutor(max_workers=12) as executor:
        futures = {}

        # Multi-platform hot search
        if not getattr(args, 'no_weibo', False):
            platforms = args.platforms.split(",") if getattr(args, 'platforms', None) else None
            futures["multi_hot"] = executor.submit(
                fetch_all_hot_sources, platforms=platforms, proxy=getattr(args, 'proxy', None)
            )

        # Market data
        if market_enabled:
            futures["market"] = executor.submit(fetch_market_overview)
            futures["trend"] = executor.submit(fetch_market_trend)
            futures["zt_dt"] = executor.submit(fetch_zt_dt)
            futures["sectors"] = executor.submit(fetch_hot_sectors)

            cap = CapitalAnalyzer()
            futures["northbound"] = executor.submit(cap.fetch_northbound_flow)
            futures["northbound_top10"] = executor.submit(cap.fetch_northbound_top10)
            futures["northbound_industry"] = executor.submit(cap.fetch_northbound_industry_flow)
            futures["main_force"] = executor.submit(cap.fetch_main_force_flow)
            futures["dragon_tiger"] = executor.submit(cap.fetch_dragon_tiger)

            sent = SentimentAnalyzer()
            futures["stats"] = executor.submit(sent.fetch_market_stats)
            futures["fear_greed"] = executor.submit(sent.fetch_fear_greed_index)
            futures["margin"] = executor.submit(sent.fetch_margin_trading)

            sec_analyzer = SectorAnalyzer()
            futures["sector_fundamentals"] = executor.submit(sec_analyzer.fetch_sector_fundamentals)
            futures["multi_rotation"] = executor.submit(sec_analyzer.multi_period_rotation)

        # Gather results.
        # NOTE: the 30 s limit applies per future, counted from the moment
        # this loop reaches it, and leaving the ``with`` block still waits
        # for every worker to finish — so it bounds neither the total wait
        # nor a hung fetch.
        for name, future in futures.items():
            try:
                data[name] = future.result(timeout=30)
            except Exception as e:
                # Covers timeouts and errors raised inside the fetch itself;
                # ``!r`` keeps the exception type visible (a timeout has an
                # empty message).
                logger.warning(f"Data fetch failed for {name}: {e!r}")
                data[name] = {}

    # Post-process the multi-platform hot search
    multi_hot = data.pop("multi_hot", {})
    if isinstance(multi_hot, dict):
        from sources.multi_platform import PLATFORM_SOURCES
        data["platform_status"] = {
            PLATFORM_SOURCES.get(k, {}).get("name", k): len(v)
            for k, v in multi_hot.items()
        }
        merged = merge_multi_platform_results(multi_hot)
        data["stock_hot"] = filter_stock_keywords(merged, watchlist=watchlist)
    else:
        data["stock_hot"] = []
        data["platform_status"] = {}

    # Sector strength score
    if data.get("sectors"):
        sec_analyzer = SectorAnalyzer()
        data["sector_strength"] = sec_analyzer.compute_sector_strength(data["sectors"])

    # Hot-search vs. limit-up correlation
    if data.get("stock_hot") and data.get("zt_dt"):
        data["correlation"] = analyze_correlation(data["stock_hot"], data["zt_dt"], data.get("sectors", []))
    else:
        data["correlation"] = {}

    # Sector rotation
    if market_enabled:
        data["rotation"] = analyze_sector_rotation()
        data["limit_percentile"] = SentimentAnalyzer().compute_limit_up_percentile()

    # Trading status and run metadata
    data["trading_status"] = format_trading_status()
    data["is_trading_day"] = is_trading_day()
    data["version"] = VERSION
    data["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Persist a snapshot (read back later by analyze_sector_rotation)
    if market_enabled:
        from utils.common import save_snapshot
        save_snapshot(data)

    return data