文件预览

gateway.py

查看 OmniSkill 通用技能调度器 技能包中的文件内容。

文件内容

src/gateway/gateway.py

import json
import socket
import sys
import threading
import traceback
import logging
from concurrent.futures import Future
from typing import Any, Dict, Optional

# 引入调度引擎和配置
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dispatcher.dispatcher import DispatcherEngine
import config.settings as settings

logger = logging.getLogger(__name__)

class AdaptiveGateway:
    """自适应触发网关 (Adaptive Trigger Gateway)
    
    负责接收跨语言的 RPC/IPC 请求,
    比如套接字(Socket)或子进程标准输入输出(Stdio),
    并将其路由至调度引擎 (DispatcherEngine)。
    它支持阻塞等待结果(同步调用)和即刻返回(异步调用),
    屏蔽了底层的路由细节。
    """
    def __init__(self, dispatcher: DispatcherEngine):
        self.dispatcher = dispatcher
        self._running = False
        self._socket_thread: Optional[threading.Thread] = None
        self._stdio_thread: Optional[threading.Thread] = None
        self._server_socket: Optional[socket.socket] = None

    def process_payload(self, raw_data: str) -> str:
        """解析请求、调用调度引擎、生成响应"""
        try:
            req = json.loads(raw_data)
        except json.JSONDecodeError:
            return json.dumps({"status": "error", "error": "无效的 JSON 格式"})

        req_id = req.get("id", "unknown_request")
        route_type = req.get("route_type")
        payload = req.get("payload")
        mode = req.get("mode", "sync")  # sync(同步) 或 async(异步)
        args = req.get("args", [])
        kwargs = req.get("kwargs", {})

        if not route_type or not payload:
            return json.dumps({
                "id": req_id,
                "status": "error",
                "error": "缺少 route_type 或 payload 字段"
            })

        try:
            # 连接调度引擎
            future: Future = self.dispatcher.dispatch(route_type, payload, *args, **kwargs)
            
            if mode == "sync":
                # 同步模式:阻塞等待
                result = future.result()
                return json.dumps({
                    "id": req_id,
                    "status": "success",
                    "result": result
                })
            elif mode == "async":
                # 异步模式:即刻返回
                return json.dumps({
                    "id": req_id,
                    "status": "pending",
                    "message": "异步请求已受理"
                })
            else:
                return json.dumps({
                    "id": req_id,
                    "status": "error",
                    "error": f"未知的 mode: {mode}"
                })
        except Exception as e:
            logger.error(f"网关执行异常: {traceback.format_exc()}")
            return json.dumps({
                "id": req_id,
                "status": "error",
                "error": str(e)
            })

    def _socket_worker(self, host: str, port: int):
        """基于 Socket 的跨语言进程间通信"""
        self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        
        try:
            self._server_socket.bind((host, port))
            self._server_socket.listen(5)
            self._server_socket.settimeout(1.0)
            logger.info(f"Socket 监听已开启:{host}:{port}")
        except Exception as e:
            logger.error(f"Socket 绑定失败: {e}")
            self._running = False
            return
            
        while self._running:
            try:
                conn, addr = self._server_socket.accept()
            except socket.timeout:
                continue
            except Exception as e:
                if self._running:
                    logger.error(f"Socket accept 异常: {e}")
                break
                
            # 启动线程处理连接
            def handle_client(client_conn):
                with client_conn:
                    try:
                        # 假设单次请求不超过设定的缓冲区大小
                        data = client_conn.recv(settings.GATEWAY_BUFFER_SIZE)
                        if not data:
                            return
                        response_str = self.process_payload(data.decode('utf-8'))
                        client_conn.sendall(response_str.encode('utf-8'))
                    except Exception as e:
                        logger.error(f"处理客户端连接异常: {e}")
            
            # 启动线程处理请求
            threading.Thread(target=handle_client, args=(conn,), daemon=True).start()

    def start_socket_server(self, host: str = settings.GATEWAY_HOST, port: int = settings.GATEWAY_PORT):
        """启动 Socket 监听服务"""
        if self._running:
            return
        self._running = True
        self._socket_thread = threading.Thread(target=self._socket_worker, args=(host, port), daemon=True)
        self._socket_thread.start()

    def _stdio_worker(self):
        """基于子进程标准输入输出的 IPC"""
        logger.info("标准输入输出监听启动")
        while self._running:
            try:
                # 读取一行作为 JSON 请求
                line = sys.stdin.readline()
                if not line:
                    break
                line = line.strip()
                if not line:
                    continue
                
                # 处理请求
                response_str = self.process_payload(line)
                # 将响应输出
                sys.stdout.write(response_str + '\n')
                sys.stdout.flush()
            except Exception as e:
                logger.error(f"读取标准输入异常: {e}")
                
    def start_stdio_server(self):
        """启动 Stdio 监听服务"""
        if self._running:
            return
        self._running = True
        self._stdio_thread = threading.Thread(target=self._stdio_worker, daemon=True)
        self._stdio_thread.start()

    def stop(self):
        """停止所有服务"""
        self._running = False
        if self._server_socket:
            try:
                self._server_socket.close()
            except:
                pass
        
        # 等待线程结束
        if self._socket_thread and self._socket_thread.is_alive():
            self._socket_thread.join(timeout=2.0)
            
        logger.info("网关已停止")