import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from websockets.exceptions import ConnectionClosed, WebSocketException
from sqlalchemy.orm import Session

from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token

router = APIRouter()

# Bisheng message types that are relayed to the client; all others are dropped.
_BISHENG_FORWARDED_TYPES = ("close", "stream", "end_cover")


async def _notify_close(websocket: WebSocket, message: str) -> None:
    """Best-effort: send a ``{"type": "close"}`` message to the client.

    Failures are printed and ignored, because the client may already have
    disconnected by the time an error is being reported.
    """
    try:
        await websocket.send_json({"message": message, "type": "close"})
    except Exception as e:
        print(f"Failed to notify client: {e!r}")


def _translate_ragflow_chunk(rag_response) -> dict:
    """Convert one streamed RAGFlow response into the client message format.

    An optional leading ``data:`` prefix is stripped before JSON parsing. Then:

    * ``data is True``  -> end of the answer, mapped to a ``close`` message;
    * ``data is None``  -> backend error, mapped to a ``stream`` message that
      carries ``retmsg`` (falling back to ``retcode``);
    * otherwise         -> the ``answer`` field as a ``stream`` message.

    Raises:
        ValueError (json.JSONDecodeError) / AttributeError for malformed chunks.
    """
    # NOTE(review): assumes RagflowService.chat yields str lines, possibly
    # SSE-style "data: {...}" — confirm against the service implementation.
    if rag_response[:5] == "data:":
        json_str = rag_response[5:].strip()
    else:
        json_str = rag_response
    json_data = json.loads(json_str)
    data = json_data.get("data")
    if data is True:
        # Answer finished.
        return {"message": "", "type": "close"}
    if data is None:
        # Backend reported an error. Use str() because retcode (or a missing
        # retmsg) may not be a str; plain "+" raised TypeError there.
        answer = json_data.get("retmsg", json_data.get("retcode"))
        return {"message": "内部错误:" + str(answer), "type": "stream"}
    # Normal incremental output.
    return {"message": data.get("answer", ""), "type": "stream"}


async def _proxy_ragflow(websocket: WebSocket, db: Session, current_user: UserModel,
                         agent_id: str, chat_id: str) -> None:
    """Relay client questions to RAGFlow and stream the answers back until the client leaves."""
    ragflow_service = RagflowService(settings.ragflow_base_url)
    token = get_ragflow_token(db, current_user.id)
    try:
        while True:
            message = await websocket.receive_json()
            print(f"Received from client {chat_id}: {message}")
            chat_history = message.get('chatHistory', [])
            if not chat_history:
                # No history supplied by the client: create the session first.
                chat_history = await ragflow_service.set_session(
                    token, agent_id, message["message"], chat_id, True)
                if not chat_history:
                    await websocket.send_json({"message": "内部错误:创建会话失败", "type": "close"})
                    # Do not query RAGFlow with an empty history; wait for the next message.
                    continue
            async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
                try:
                    print(f"Received from ragflow: {rag_response}")
                    result = _translate_ragflow_chunk(rag_response)
                    await websocket.send_json(result)
                    print(f"Forwarded to client {chat_id}: {result}")
                except WebSocketDisconnect:
                    # The client is gone; sending an error message would fail too.
                    raise
                except Exception as e:
                    # A bad chunk is reported, but the rest of the stream is still relayed.
                    result = {"message": f"内部错误: {e}", "type": "close"}
                    await websocket.send_json(result)
                    print(f"Error process message of ragflow: {e}")
    except WebSocketDisconnect:
        print(f"Client {chat_id} disconnected")
    except Exception as e:
        # Top-level boundary: previously such errors vanished inside an unawaited task.
        print(f"Error in ragflow proxy for chat {chat_id}: {e!r}")
        await _notify_close(websocket, f"内部错误: {e}")


async def _relay_bisheng(websocket: WebSocket, service_websocket, agent_id: str, chat_id: str) -> None:
    """Pump messages both ways between the client and Bisheng until either side stops."""

    async def forward_to_service():
        # Client -> Bisheng: wrap the client's text in Bisheng's assistant-chat payload.
        while True:
            message = await websocket.receive_json()
            print(f"Received from client, {chat_id}: {message}")
            message['flow_id'] = agent_id
            message['chat_id'] = chat_id
            msg = message.pop("message")
            message['inputs'] = {
                "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                "input": msg,
            }
            await service_websocket.send(json.dumps(message))
            print(f"Forwarded to bisheng: {message}")

    async def forward_to_client():
        # Bisheng -> client: forward close/stream/end_cover frames, mapping
        # everything except "close" to "stream".
        while True:
            raw = await service_websocket.recv()
            print(f"Received from bisheng: {raw}")
            data = json.loads(raw)
            msg_type = data.get("type")
            if msg_type not in _BISHENG_FORWARDED_TYPES:
                continue
            result = {
                "message": data.get("message", ""),
                "type": "close" if msg_type == "close" else "stream",
            }
            await websocket.send_json(result)
            print(f"Forwarded to client, {chat_id}: {result}")

    tasks = [
        asyncio.create_task(forward_to_service()),
        asyncio.create_task(forward_to_client()),
    ]
    done = set()
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Stop the other direction (and both, if this coroutine is cancelled);
        # gather() also retrieves every task exception so none is left unobserved.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    # asyncio.wait does not raise task errors, so inspect them explicitly.
    for task in done:
        if task.cancelled():
            continue
        exc = task.exception()
        if isinstance(exc, WebSocketDisconnect):
            print(f"Client {chat_id} disconnected")
        elif isinstance(exc, ConnectionClosed):
            print(f"Bisheng connection closed for chat {chat_id}: {exc}")
            await _notify_close(websocket, f"内部错误: {exc}")
        elif exc is not None:
            print(f"Error relaying bisheng chat {chat_id}: {exc!r}")
            await _notify_close(websocket, f"内部错误: {exc}")


async def _proxy_bisheng(websocket: WebSocket, db: Session, current_user: UserModel,
                         agent_id: str, chat_id: str) -> None:
    """Open the upstream Bisheng WebSocket for this chat and relay traffic over it."""
    token = get_bisheng_token(db, current_user.id)
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {'cookie': f"access_token_cookie={token};"}
    try:
        # NOTE(review): `extra_headers` is the legacy-client keyword; the new
        # asyncio client in websockets >= 14 expects `additional_headers`. Verify
        # against the pinned websockets version.
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            await _relay_bisheng(websocket, service_websocket, agent_id, chat_id)
    except (OSError, asyncio.TimeoutError, WebSocketException) as e:
        # Connection/handshake failure (or close error) with Bisheng: tell the client.
        print(f"Failed to talk to bisheng for chat {chat_id}: {e!r}")
        await _notify_close(websocket, f"内部错误: {e}")


# Middle-layer WebSocket server: accepts client connections and proxies them to the agent backend.
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket, agent_id: str, chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Proxy a client chat WebSocket to the RAGFlow or Bisheng backend of ``agent_id``.

    The backend is chosen by the agent's ``agent_type``. Validation failures and
    unsupported agent types are answered with a single ``{"type": "close"}``
    message before returning.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        await websocket.send_json({"message": "Agent not found", "type": "close"})
        return

    if chat_id == "" or chat_id == "0":
        await websocket.send_json({"message": "Chat ID not found", "type": "close"})
        return

    agent_type = agent.agent_type
    if agent_type == AgentType.RAGFLOW:
        await _proxy_ragflow(websocket, db, current_user, agent_id, chat_id)
    elif agent_type == AgentType.BISHENG:
        await _proxy_bisheng(websocket, db, current_user, agent_id, chat_id)
    else:
        await websocket.send_json({"message": "Agent not found", "type": "close"})