import asyncio
import json
from typing import Any, Dict, List, Tuple
from urllib.parse import quote

import websockets
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from Log import logger
from app.api import get_current_user_websocket, ResponseList, get_current_user, format_file_url, process_files
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.bisheng import BishengService
from app.service.service_token import get_bisheng_token

router = APIRouter()

# Text sent to the client (with type "close") when the bisheng side fails.
_CONNECTION_ERROR_MESSAGE = "连接异常!"


def _stream_frame(message: Any, files: List[Any]) -> Dict[str, Any]:
    """Build a client-bound "stream" frame."""
    return {"message": message, "type": "stream", "files": files}


async def _reject(websocket: WebSocket, message: str) -> None:
    """Accept the handshake only to deliver a terminal error frame, then close.

    Starlette rejects data frames sent before ``accept()``, so the error text
    can only reach the client once the handshake has completed.
    """
    await websocket.accept()
    await websocket.send_json({"message": message, "type": "close"})
    await websocket.close()


async def _send_quietly(websocket: WebSocket, payload: Dict[str, Any]) -> None:
    """Best-effort send to the client, which may already have disconnected."""
    try:
        await websocket.send_json(payload)
    except Exception as e:
        logger.error(f"Could not notify client: {e!r}")


def _translate_bisheng_message(
    data: Dict[str, Any], files: List[Any], is_answer: bool
) -> Tuple[List[Dict[str, Any]], bool]:
    """Map one bisheng message onto the frames sent to the client.

    Args:
        data: Decoded JSON message received from bisheng.
        files: ``data["files"]`` normalized to a list (already passed through
            ``process_files``).
        is_answer: True once a "processing" message has been seen since the
            last question that carried steps; it turns "answer" messages into
            a bare newline frame.

    Returns:
        The frames to send, in order, and the updated ``is_answer`` flag.
    """
    # `or` so an explicit JSON null behaves like a missing key.
    steps = data.get("intermediate_steps") or ""
    msg = data.get("message", "")
    category = data.get("category", "")
    frames: List[Dict[str, Any]] = []

    # A question carrying reasoning steps starts a new answer cycle.
    if category == "question" and steps:
        is_answer = False
        frames.append(_stream_frame(steps + "\n", files))

    # First answer of a cycle: flush its steps, ensuring they end in a newline.
    if category == "answer" and not is_answer:
        frames.append(_stream_frame(steps if steps.endswith("\n") else steps + "\n\n", files))

    if category == "answer" and is_answer:
        frames.append(_stream_frame("\n", files))
    elif data.get("type") == "close":
        frames.append({"message": "", "type": "close", "files": files})
    elif category == "processing":
        is_answer = True
        frames.append(_stream_frame(msg, files))
    elif files:
        frames.append(_stream_frame("", files))
    elif category == "system" and steps:
        frames.append(_stream_frame(steps, files))
    else:
        logger.error(f"Unhandled bisheng message: {data}")
    return frames, is_answer


async def _forward_to_service(websocket: WebSocket, service_websocket: Any,
                              agent_id: str, chat_id: str) -> None:
    """Forward client JSON messages to bisheng, tagged with the flow and chat IDs."""
    while True:
        message = await websocket.receive_json()
        logger.info(f"Received from client, {chat_id}: {message}")
        if not isinstance(message, dict):
            logger.error(f"Ignoring non-object client message: {message!r}")
            continue
        # Overwrite any client-supplied IDs with the ones taken from the URL.
        message['flow_id'] = agent_id
        message['chat_id'] = chat_id
        await service_websocket.send(json.dumps(message))
        logger.info(f"Forwarded to bisheng: {message}")


async def _forward_to_client(websocket: WebSocket, service_websocket: Any,
                             agent_id: str) -> None:
    """Relay bisheng messages to the client until something fails.

    On any failure the client receives a single "close" frame and this pump
    stops. Stopping matters: once the bisheng socket is closed every
    ``recv()`` raises again, so continuing to loop would spin forever.
    """
    is_answer = False
    while True:
        try:
            data = json.loads(await service_websocket.recv())
            files = data.get("files") or []
            process_files(files, agent_id)
            frames, is_answer = _translate_bisheng_message(data, files, is_answer)
            for frame in frames:
                await websocket.send_json(frame)
        except Exception as e:
            logger.error(e)
            await _send_quietly(
                websocket, {"message": _CONNECTION_ERROR_MESSAGE, "type": "close", "files": []}
            )
            return


async def _relay(websocket: WebSocket, service_websocket: Any,
                 agent_id: str, chat_id: str) -> None:
    """Pump messages both ways until either direction stops, then stop the other."""
    tasks = [
        asyncio.create_task(_forward_to_service(websocket, service_websocket, agent_id, chat_id)),
        asyncio.create_task(_forward_to_client(websocket, service_websocket, agent_id)),
    ]
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        # asyncio.wait does not raise task exceptions; inspect them explicitly.
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if isinstance(exc, WebSocketDisconnect):
                logger.info(f"WebSocket connection closed with code {exc.code}: {exc.reason}")
            elif exc is not None:
                logger.error(f"Relay task failed: {exc!r}")
    finally:
        for task in tasks:
            if not task.done():
                task.cancel()
        # Wait for cancellation so no pump outlives the bisheng connection.
        await asyncio.gather(*tasks, return_exceptions=True)


@router.websocket("/ws/{agent_id}/{chat_id}")
async def report_chat(websocket: WebSocket, agent_id: str, chat_id: str,
                      current_user: UserModel = Depends(get_current_user_websocket),
                      db: Session = Depends(get_db)):
    """Proxy a client chat websocket to the bisheng chat websocket of ``agent_id``.

    Validation failures (unknown agent, missing chat id, non-bisheng agent)
    are reported to the client as a ``{"type": "close"}`` frame, after which
    the socket is closed.
    """
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        await _reject(websocket, "Agent not found")
        return
    if chat_id in ("", "0"):
        await _reject(websocket, "Chat ID not found")
        return
    if agent.agent_type != AgentType.BISHENG:
        await _reject(websocket, "Agent error")
        return

    token = get_bisheng_token(db, current_user.id)
    # chat_id comes straight from the URL path; quote it so it cannot inject
    # extra query parameters into the upstream request.
    service_uri = (
        f"{settings.sgb_websocket_url}/api/v1/chat/{agent_id}"
        f"?type=L1&t=&chat_id={quote(chat_id, safe='')}"
    )
    headers = {'cookie': f"access_token_cookie={token};"}

    await websocket.accept()
    logger.info(f"Client {agent_id} connected")
    try:
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            await _relay(websocket, service_websocket, agent_id, chat_id)
    except Exception as e:
        # e.g. the bisheng handshake failed: tell the client instead of dropping it.
        logger.error(f"Bisheng websocket error for chat {chat_id}: {e!r}")
        await _send_quietly(
            websocket, {"message": _CONNECTION_ERROR_MESSAGE, "type": "close", "files": []}
        )
    finally:
        logger.info("Cleaning up resources of bisheng report")


@router.get("/variables/list", response_model=ResponseList)
async def get_variables(agent_id: str = Query(..., description="The ID of the agent"),
                        db: Session = Depends(get_db),
                        current_user: UserModel = Depends(get_current_user)):
    """Return bisheng's variable list for ``agent_id``.

    Responds with code 404 when the agent does not exist. Raises
    ``HTTPException(500)`` when fetching the token or the list fails.
    """
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return ResponseList(code=404, msg="Agent not found")
    bisheng_service = BishengService(base_url=settings.sgb_base_url)
    try:
        token = get_bisheng_token(db, current_user.id)
        result = await bisheng_service.variable_list(token, agent_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return ResponseList(code=200, msg="", data=result)