import asyncio
import json
import uuid
from urllib.parse import quote

import websockets
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from sqlalchemy.orm import Session
from websockets.exceptions import ConnectionClosed

from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token

router = APIRouter()

# Client WebSocket connections that are currently being relayed, keyed by chat_id.
client_websockets = {}


def _build_service_payload(message, agent_id, chat_id):
    """Rewrite a client message in place into the shape sent upstream.

    The client's ``"message"`` field is moved under ``inputs["input"]`` and the
    routing fields ``flow_id`` / ``chat_id`` are set (overwriting any values the
    client supplied).

    Raises:
        KeyError: if the client message has no ``"message"`` field.
    """
    user_input = message.pop("message")
    message['flow_id'] = agent_id
    message['chat_id'] = chat_id
    message['inputs'] = {
        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
        "input": user_input,
    }
    return message


async def _forward_to_service(websocket, service_websocket, agent_id, chat_id):
    """Relay JSON messages from the client to the upstream service until one side fails."""
    while True:
        message = await websocket.receive_json()
        print(f"Received from client {chat_id}: {message}")
        message = _build_service_payload(message, agent_id, chat_id)
        await service_websocket.send(json.dumps(message))
        print(f"Forwarded to bisheng: {message}")


async def _forward_to_client(websocket, service_websocket, chat_id):
    """Relay messages from the upstream service to the client until one side fails."""
    while True:
        message = await service_websocket.recv()
        print(f"Received from service S: {message}")
        # recv() yields bytes for binary frames; send_text() only accepts str.
        if isinstance(message, bytes):
            await websocket.send_bytes(message)
        else:
            await websocket.send_text(message)
        print(f"Forwarded to client {chat_id}: {message}")


async def _close_client_quietly(websocket, code=1000):
    """Close the client socket, ignoring the error raised if it is already closed."""
    try:
        await websocket.close(code=code)
    except (RuntimeError, WebSocketDisconnect):
        pass


# Middle-layer WebSocket server: accepts client connections and relays them
# to the upstream service.
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket, agent_id: str, chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Proxy one client WebSocket session to the upstream assistant chat service.

    Args:
        websocket: The client connection; it is accepted immediately.
        agent_id: Upstream agent id. ``"0"`` selects ``settings.bisheng_agent_id``.
        chat_id: Conversation id. ``"0"`` generates a fresh random hex id.
        current_user: Authenticated user, used to obtain the upstream token.
        db: Database session passed to the token lookup.

    The handler runs two relay tasks (client -> service, service -> client) and
    ends the session as soon as either one stops. A client disconnect or an
    upstream close ends the session quietly; any other relay error is re-raised
    after cleanup so that it is not silently lost.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex

    # Connect to the upstream service. Both ids may come straight from the
    # request path, so percent-encode them instead of trusting them to be
    # URL-safe (a raw '&' or '#' would otherwise alter the upstream query).
    service_uri = (
        f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/"
        f"{quote(str(agent_id), safe='')}?t=&chat_id={quote(str(chat_id), safe='')}"
    )
    headers = {
        'cookie': f"access_token_cookie={token};"
    }

    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
        # Start the two relay directions as independent tasks.
        tasks = [
            asyncio.create_task(
                _forward_to_service(websocket, service_websocket, agent_id, chat_id)),
            asyncio.create_task(
                _forward_to_client(websocket, service_websocket, chat_id)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Runs on normal completion and when this handler itself is
            # cancelled, so the relay tasks never outlive the session.
            # cancel() is a no-op for a task that has already finished.
            for task in tasks:
                task.cancel()
            # Awaiting every task retrieves its exception (avoids "Task
            # exception was never retrieved") and returns it for inspection.
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)
            # Only drop the registry entry if it is still ours: another
            # connection using the same chat_id may have replaced it.
            if client_websockets.get(chat_id) is websocket:
                del client_websockets[chat_id]

        client_gone = any(isinstance(o, WebSocketDisconnect) for o in outcomes)
        service_gone = any(isinstance(o, ConnectionClosed) for o in outcomes)
        unexpected = [
            o for o in outcomes
            if isinstance(o, Exception)
            and not isinstance(
                o, (WebSocketDisconnect, ConnectionClosed, asyncio.CancelledError))
        ]

        if client_gone:
            print(f"Client {chat_id} disconnected")
        else:
            if service_gone:
                print(f"Service connection for chat {chat_id} closed")
            # Tell the client the session is over instead of leaving the
            # socket to be dropped when the handler returns.
            await _close_client_quietly(websocket, 1011 if unexpected else 1000)

        if unexpected:
            raise unexpected[0]