import asyncio
import json
import re
import uuid

import websockets
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
from sqlalchemy.orm import Session

from Log import logger
from app.api import get_current_user_websocket
from app.config.config import settings
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.v2.api_token import DfTokenDao
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.difyService import DifyService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
from app.service.session import SessionService

router = APIRouter()

# Markdown image links produced by the StepFun image generator inside Dify
# answers. They are swapped for the ``/api/files/image/<id>`` links collected
# from ``message_file`` events (see ``_dify_image_talk``).
_STEPFUN_IMAGE_MARKER = "![](https://res.stepfun.com/"
_STEPFUN_IMAGE_PATTERN = re.compile(r'!\[\] *\(https://res\.stepfun\.com/image_gen/[^)]+\)')


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _strip_sse_prefix(chunk):
    """Return the payload of an SSE ``data:`` line, or *chunk* unchanged otherwise."""
    if chunk[:5] == "data:":
        return chunk[5:].strip()
    return chunk


async def _safe_send(websocket: WebSocket, payload: dict) -> None:
    """Best-effort ``send_json``: log instead of raising if the client is gone."""
    try:
        await websocket.send_json(payload)
    except Exception as e:
        logger.error(f"Failed to send message to client: {e}")


async def _safe_close(websocket: WebSocket) -> None:
    """Best-effort close; closing an already closed/disconnected socket may raise."""
    try:
        await websocket.close()
    except Exception as e:
        logger.debug(f"WebSocket already closed: {e}")


async def _run_until_first_done(*coros) -> None:
    """Run *coros* as tasks until the first one finishes, then cancel the rest.

    ``asyncio.wait`` does not re-raise exceptions from the tasks it waits on
    (e.g. ``WebSocketDisconnect`` from ``receive_json``), so they are retrieved
    and logged here instead of being lost / reported as "never retrieved".
    """
    tasks = [asyncio.create_task(coro) for coro in coros]
    try:
        done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if isinstance(exc, WebSocketDisconnect):
                logger.info(f"WebSocket connection closed with code {exc.code}: {exc.reason}")
            elif exc is not None:
                logger.error(f"Exception occurred: {exc!r}")
    finally:
        # Also runs if this coroutine is itself cancelled while waiting.
        for task in tasks:
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception as e:
                    logger.error(f"Error while cancelling task: {e!r}")


# ---------------------------------------------------------------------------
# RAGFlow
# ---------------------------------------------------------------------------

async def _handle_ragflow(websocket: WebSocket, agent_id: str, chat_id: str,
                          current_user: UserModel, db: Session) -> None:
    """Relay client questions to RAGFlow and stream the answers back.

    After each answered question the RAGFlow session history is written to the
    local DB via ``update_session_history``.
    """
    ragflow_service = RagflowService(settings.fwr_base_url)
    token = await get_ragflow_token(db, current_user.id)

    async def forward_to_ragflow():
        while True:
            message = await websocket.receive_json()
            logger.info(f"Received from client {chat_id}: {message}")
            chat_history = message.get('chatHistory', [])
            message["role"] = "user"
            if not chat_history:
                chat_history = await ragflow_service.get_session_history(token, chat_id)
                if not chat_history:
                    # No server-side session yet: create one, passing this message.
                    chat_history = await ragflow_service.set_session(token, agent_id, message, chat_id, True)
                    if not chat_history:
                        await websocket.send_json({"message": "内部错误:创建会话失败", "type": "close"})
                        await websocket.close()
                        return
                else:
                    # Existing session: append the new user turn to its history.
                    chat_history.append({
                        "content": message["message"],
                        "doc_ids": message.get("doc_ids", []),
                        "role": "user",
                    })

            # Chunks may split a JSON document; buffer until it parses.
            buffer = ""
            async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
                try:
                    buffer += _strip_sse_prefix(rag_response)
                    try:
                        json_data = json.loads(buffer)
                    except json.JSONDecodeError as e:
                        logger.debug(f"Error decoding JSON: {e}")
                        continue
                    data = json_data.get("data")
                    if data is True:
                        # End of the answer stream.
                        result = {"message": "", "type": "close"}
                    elif data is None:
                        # Error payload; retcode may be an int, so format rather than concatenate.
                        answer = json_data.get("retmsg", json_data.get("retcode"))
                        result = {"message": f"内部错误:{answer}", "type": "message"}
                    else:
                        result = {"message": data.get("answer", ""), "type": "message",
                                  "reference": data.get("reference", {})}
                    await websocket.send_json(result)
                    buffer = ""
                except WebSocketDisconnect:
                    raise
                except Exception as e2:
                    await websocket.send_json({"message": f"内部错误: {e2}", "type": "close"})
                    logger.error(f"Error process message of ragflow: {e2}")

            try:
                dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1)
                await update_session_history(db, dialog_chat_history, current_user.id)
            except Exception as e:
                logger.error(e)
                logger.error("-----------------保存ragflow的历史会话异常-----------------")

    try:
        await _run_until_first_done(forward_to_ragflow())
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
    finally:
        logger.info("Cleaning up resources of ragflow")


# ---------------------------------------------------------------------------
# Bisheng
# ---------------------------------------------------------------------------

async def _handle_bisheng(websocket: WebSocket, agent_id: str, chat_id: str,
                          current_user: UserModel, db: Session) -> None:
    """Bridge the client socket to a Bisheng assistant WebSocket in both directions."""
    token = await get_bisheng_token(db, current_user.id)
    service_uri = f"{settings.sgb_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {'cookie': f"access_token_cookie={token};"}
    # NOTE(review): ``extra_headers`` is the legacy websockets client keyword;
    # newer releases name it ``additional_headers`` — confirm the pinned version.
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:

        async def forward_to_service():
            """Wrap each client message in Bisheng's assistant envelope and forward it."""
            while True:
                message = await websocket.receive_json()
                logger.info(f"Received from client, {chat_id}: {message}")
                message['flow_id'] = agent_id
                message['chat_id'] = chat_id
                user_input = message.pop("message")
                message['inputs'] = {
                    "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                    "input": user_input,
                }
                await service_websocket.send(json.dumps(message))
                logger.info(f"Forwarded to bisheng: {message}")

        async def forward_to_client():
            """Forward Bisheng close/stream/end_cover frames; other frame types are dropped."""
            while True:
                raw = await service_websocket.recv()
                logger.info(f"Received from bisheng: {raw}")
                data = json.loads(raw)
                msg_type = data.get("type")
                if msg_type not in ("close", "stream", "end_cover"):
                    continue
                result = {"message": data.get("message", ""),
                          "type": "close" if msg_type == "close" else "stream"}
                await websocket.send_json(result)
                logger.info(f"Forwarded to client, {chat_id}: {result}")

        try:
            await _run_until_first_done(forward_to_service(), forward_to_client())
        except Exception as e:
            logger.error(f"Exception occurred: {e}")
        finally:
            logger.info("Cleaning up resources of bisheng")


# ---------------------------------------------------------------------------
# BASIC
# ---------------------------------------------------------------------------

def _basic_file_url(agent_id: str, name: str, file_type: str) -> str:
    """Build the download URL for a file produced by a BASIC agent."""
    return f"/api/files/download/?agent_id={agent_id}&file_id={name}&file_type={file_type}"


async def _basic_question_talk(websocket: WebSocket, service, chat_id: str,
                               question: str, db: Session) -> None:
    """Answer one question with the BASIC ``questionTalk`` agent.

    Service errors are reported as a ``type: close`` frame but do not end the
    receive loop; a client disconnect is re-raised.
    """
    try:
        data = await service.questions_talk(question, chat_id)
        output = data.get("output", "")
        file_name = data.get("filename", "")
        file_url = None
        if file_name:
            file_url = (f"/api/files/download/?agent_id=basic_question_talk"
                        f"&file_id={file_name}&file_type=word")
        result = {"message": output, "type": "message", "file_url": file_url, "file_name": file_name}
        try:
            SessionService(db).update_session(chat_id, message={"role": "assistant", "content": result})
        except Exception as e:
            logger.error(e)
        await websocket.send_json(result)
    except WebSocketDisconnect:
        raise
    except Exception as e2:
        logger.error(f"Error process message of basic chuti agent: {e2}")
        await websocket.send_json({"message": f"内部错误: {e2}", "type": "close"})


async def _basic_excel_talk(websocket: WebSocket, service, agent_id: str, chat_id: str,
                            question: str, db: Session) -> None:
    """Stream an ``excel_talk`` answer to the client, then store the last message chunk.

    File URLs persist across chunks once a file name has been seen, so later
    chunks still carry them.
    """
    message_data = {}
    excel_url = image_url = ""
    excel_name = image_name = ""
    async for data in service.excel_talk(question, chat_id):
        output = data.get("output", "")
        e_name = data.get("excel_name", "")
        i_name = data.get("image_name", "")
        if e_name:
            excel_url = _basic_file_url(agent_id, e_name, "excel")
            excel_name = e_name
        if i_name:
            image_url = _basic_file_url(agent_id, i_name, "image")
            image_name = i_name
        if data.get("type") == "message":
            message_data = {
                "content": output,
                "excel_url": excel_url,
                "image_url": image_url,
                "image_name": image_name,
                "excel_name": excel_name,
                "sql": data.get("sql", ""),
                "code": data.get("code", ""),
                "e": data.get("e", ""),
                "role": "assistant",
            }
        data["message"] = output
        data["excel_url"] = excel_url
        data["image_url"] = image_url
        await websocket.send_json(data)

    if message_data:
        try:
            SessionService(db).update_session(chat_id, message=message_data)
        except Exception as e:
            logger.error(f"Unexpected error when update_session: {e}")


async def _handle_basic(websocket: WebSocket, agent: AgentModel, agent_id: str, chat_id: str,
                        current_user: UserModel, db: Session) -> None:
    """Serve BASIC agents: ``questionTalk`` or, for any other sub-type, ``excel_talk``."""
    try:
        service = BasicService(base_url=settings.basic_base_url)
        while True:
            message = await websocket.receive_json()
            question = message.get("message")
            # Validate before touching the DB so empty requests create no session.
            if not question:
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue
            try:
                SessionService(db).create_session(
                    chat_id, question, agent_id, AgentType.BASIC, current_user.id
                )
            except Exception as e:
                logger.error(e)
            logger.debug(f"BASIC agent sub-type: {agent.type}")
            if agent.type == "questionTalk":
                await _basic_question_talk(websocket, service, chat_id, question, db)
            else:
                await _basic_excel_talk(websocket, service, agent_id, chat_id, question, db)
    except WebSocketDisconnect as e:
        logger.info(f"Client {chat_id} disconnected: {e.code}")
    except Exception as e:
        logger.error(e)
        await _safe_send(websocket, {"message": "出现错误!", "type": "error"})


# ---------------------------------------------------------------------------
# Dify
# ---------------------------------------------------------------------------

async def _dify_image_talk(websocket: WebSocket, dify_service, agent_id: str, chat_id: str,
                           current_user: UserModel, db: Session) -> None:
    """Relay ``imageTalk`` chat turns to Dify and stream the cumulative answer.

    Each ``agent_message`` event sends the whole answer so far (not a delta).
    Generated images arrive as ``message_file`` events; they are saved locally
    and their links replace StepFun image links in the answer text.
    """
    token = DfTokenDao(db).get_token_by_id(IMAGE_TO_TEXT)
    if not token:
        await websocket.send_json({"message": "Invalid token", "type": "error"})
        return

    while True:
        image_list = []
        is_image = False
        conversation_id = ""
        receive_message = await websocket.receive_json()
        logger.info(f"Received from client {chat_id}: {receive_message}")
        upload_file_id = receive_message.get('upload_file_id', "")
        question = receive_message.get('message', "")
        # A turn needs either text or an uploaded image.
        if not question and not upload_file_id:
            await websocket.send_json({"message": "Invalid request", "type": "error"})
            continue
        try:
            session = SessionService(db).create_session(
                chat_id, question, agent_id, AgentType.DIFY, current_user.id
            )
            conversation_id = session.conversation_id
        except Exception as e:
            logger.error(e)

        answer_str = ""
        async for rag_response in dify_service.chat(token, current_user.id, question,
                                                    upload_file_id, conversation_id):
            try:
                try:
                    data = json.loads(_strip_sse_prefix(rag_response))
                except json.JSONDecodeError as e:
                    logger.debug(f"Error decoding JSON: {e}")
                    continue
                event = data.get("event")
                if event == "agent_message":
                    answer = data.get("answer")
                    if not answer:
                        logger.error("非法数据--------------------")
                        continue
                    if isinstance(answer, str):
                        if _STEPFUN_IMAGE_MARKER in answer and image_list:
                            is_image = True
                            url_image = image_list.pop()
                            # Callable replacement: url_image is inserted literally.
                            answer_str += _STEPFUN_IMAGE_PATTERN.sub(lambda _m: url_image, answer)
                        else:
                            answer_str += answer
                    elif isinstance(answer, dict):
                        logger.error("未知数据体:0---------------------------------")
                        logger.error(answer)
                        answer_str += answer.get("action_input", "")
                    result = {"message": answer_str, "type": "message"}
                elif event == "message_end":
                    images_url = []
                    # Attach the last image if it never replaced a StepFun link.
                    if image_list and not is_image:
                        answer_str += image_list[-1]
                    result = {"message": answer_str, "type": "close"}
                    try:
                        SessionService(db).update_session(
                            chat_id,
                            message={"role": "assistant",
                                     "content": {"answer": answer_str, "images": images_url}},
                            conversation_id=data.get("conversation_id"),
                        )
                    except Exception as e:
                        logger.error("保存dify的会话异常!")
                        logger.error(e)
                elif event == "message_file":
                    await dify_service.save_images(data.get("url"), data.get("id") + ".png")
                    image_list.append(f"![](/api/files/image/{data.get('id')})")
                    continue
                else:
                    continue
                await websocket.send_json(result)
            except WebSocketDisconnect:
                raise
            except Exception as e2:
                await websocket.send_json({"message": f"内部错误: {e2}", "type": "close"})
                logger.error(f"Error process message of dify: {e2}")


async def _stream_dify_workflow(websocket: WebSocket, dify_service, token, inputs: dict,
                                chat_id: str, current_user: UserModel, db: Session) -> None:
    """Run one Dify workflow and forward its progress and result to the client.

    Node start/finish events are sent as ``type: system`` frames carrying the
    node title. ``workflow_finished`` sends the output (if any) and then a
    ``type: close`` frame.
    """
    buffer = ""
    async for rag_response in dify_service.workflow(token, current_user.id, inputs):
        try:
            if rag_response[:5] == "data:":
                buffer = rag_response[5:].strip()
            elif "event: ping" in rag_response:
                continue
            else:
                # Continuation of a JSON payload split across chunks.
                buffer += rag_response
            try:
                data = json.loads(buffer)
            except json.JSONDecodeError as e:
                logger.debug(f"Error decoding JSON: {e}")
                continue
            buffer = ""

            event = data.get("event")
            if event in ("node_started", "node_finished"):
                node = data.get("data")
                if not node:
                    logger.error("非法数据--------------------")
                    logger.error(data)
                    continue
                if not isinstance(node, dict):
                    logger.error("----------------未知数据--------------------")
                    logger.error(data)
                    continue
                result = {"message": node.get("title", ""), "type": "system"}
            elif event == "workflow_finished":
                finished = data.get("data", "")
                if isinstance(finished, dict):
                    outputs = finished.get("outputs", {})
                    if outputs:
                        message = outputs.get("output", "")
                        download_url = outputs.get("download_url", "")
                    else:
                        message = finished.get("error", "")
                        download_url = ""
                    try:
                        SessionService(db).update_session(
                            chat_id,
                            message={"role": "assistant",
                                     "content": {"answer": message, "download_url": download_url}},
                            conversation_id=data.get("conversation_id"),
                        )
                    except Exception as e:
                        logger.error("保存dify的会话异常!")
                        logger.error(e)
                    await websocket.send_json(
                        {"message": message, "type": "message", "download_url": download_url})
                else:
                    logger.error("----------------未知数据--------------------")
                    logger.error(data)
                result = {"message": "", "type": "close", "download_url": ""}
            else:
                continue

            try:
                await websocket.send_json(result)
            except Exception as e:
                logger.error(e)
                logger.error("返回客户端消息异常!")
        except WebSocketDisconnect:
            raise
        except Exception as e2:
            await websocket.send_json({"message": f"内部错误: {e2}", "type": "close"})
            logger.error(f"Error process message of dify workflow: {e2}")


async def _dify_report_workflow(websocket: WebSocket, dify_service, agent_id: str, chat_id: str,
                                current_user: UserModel, db: Session) -> None:
    """Receive document-workflow requests and run the matching Dify workflow.

    ``workflow == 2`` runs the report workflow (``DOCUMENT_TO_REPORT`` token,
    ``title`` as ``Completion_of_main_indicators``). Any other value uses the
    cleaning token; only ``1`` passes the files as ``input_files``.
    """
    while True:
        receive_message = await websocket.receive_json()
        logger.info(f"Received from client {chat_id}: {receive_message}")
        upload_files = receive_message.get('upload_files', [])
        title = receive_message.get('title', "")
        workflow_type = receive_message.get('workflow', 1)
        if not upload_files:
            await websocket.send_json({"message": "Invalid request", "type": "error"})
            continue

        # Resolve the token per request so one report request does not leak its
        # token into later cleaning requests.
        if workflow_type == 2:
            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT)
            missing_token_message = "Invalid token document_to_report"
        else:
            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
            missing_token_message = "Invalid token document_to_cleaning"
        if not token:
            await websocket.send_json({"message": missing_token_message, "type": "error"})
            continue

        try:
            SessionService(db).create_session(
                chat_id, title, agent_id, AgentType.DIFY, current_user.id
            )
        except Exception as e:
            logger.error(e)

        files = [
            {"type": "document", "transfer_method": "local_file", "url": "", "upload_file_id": file_id}
            for file_id in upload_files
        ]
        inputs = {}
        if workflow_type == 1:
            inputs["input_files"] = files
        elif workflow_type == 2:
            inputs["file_list"] = files
            inputs["Completion_of_main_indicators"] = title

        await _stream_dify_workflow(websocket, dify_service, token, inputs, chat_id, current_user, db)


async def _handle_dify(websocket: WebSocket, agent: AgentModel, agent_id: str, chat_id: str,
                       current_user: UserModel, db: Session) -> None:
    """Dispatch a Dify agent to its sub-type relay (``imageTalk`` / ``reportWorkflow``)."""
    dify_service = DifyService(settings.dify_base_url)
    if agent.type == "imageTalk":
        relay = _dify_image_talk(websocket, dify_service, agent_id, chat_id, current_user, db)
    elif agent.type == "reportWorkflow":
        relay = _dify_report_workflow(websocket, dify_service, agent_id, chat_id, current_user, db)
    else:
        logger.error(f"Unsupported dify agent type: {agent.type}")
        return
    try:
        await _run_until_first_done(relay)
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
    finally:
        logger.info("Cleaning up resources of dify")


# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------

# Middle-layer WebSocket server: accepts client connections and relays them to
# the backend that serves the requested agent.
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket, agent_id: str, chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Accept a client WebSocket and dispatch it to the handler for the agent's type.

    Args:
        websocket: The client connection.
        agent_id: Primary key of the ``AgentModel`` to talk to.
        chat_id: Conversation id; ``""`` and ``"0"`` are rejected.
        current_user: User resolved by ``get_current_user_websocket``.
        db: SQLAlchemy session from ``get_db``.

    Setup errors are reported to the client as ``{"message": ..., "type": "close"}``.
    The connection is closed (best effort) when the handler returns.
    """
    await websocket.accept()
    logger.info(f"Client {agent_id} connected")
    try:
        agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
        if not agent:
            await websocket.send_json({"message": "Agent not found", "type": "close"})
            return
        agent_type = agent.agent_type
        logger.info(f"Agent {agent_id} type: {agent_type}")
        if chat_id in ("", "0"):
            await websocket.send_json({"message": "Chat ID not found", "type": "close"})
            return

        # Exactly one handler runs per connection.
        if agent_type == AgentType.RAGFLOW:
            await _handle_ragflow(websocket, agent_id, chat_id, current_user, db)
        elif agent_type == AgentType.BISHENG:
            await _handle_bisheng(websocket, agent_id, chat_id, current_user, db)
        elif agent_type == AgentType.BASIC:
            await _handle_basic(websocket, agent, agent_id, chat_id, current_user, db)
        elif agent_type == AgentType.DIFY:
            await _handle_dify(websocket, agent, agent_id, chat_id, current_user, db)
        else:
            await websocket.send_json({"message": "Agent not found", "type": "close"})
    finally:
        await _safe_close(websocket)
        logger.info(f"Client {agent_id} disconnected")