| | |
| | | import json |
| | | import re |
| | | import uuid |
| | | |
| | | from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends |
| | | import asyncio |
| | | import websockets |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from Log import logger |
| | | from app.api import get_current_user_websocket |
| | | from app.config.config import settings |
| | | from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING, DOCUMENT_IA_QUESTIONS, \ |
| | | DOCUMENT_TO_REPORT_TITLE, DOCUMENT_TO_TITLE, DOCUMENT_TO_PAPER |
| | | from app.models import MenuCapacityModel |
| | | from app.models.agent_model import AgentModel, AgentType |
| | | from app.models.base_model import get_db |
| | | from app.models.user_model import UserModel |
| | | from app.service.v2.api_token import DfTokenDao |
| | | from app.service.dialog import update_session_history |
| | | from app.service.basic import BasicService |
| | | from app.service.difyService import DifyService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.token import get_bisheng_token, get_ragflow_token |
| | | from app.service.service_token import get_bisheng_token, get_ragflow_token |
| | | from app.service.session import SessionService |
| | | |
| | | router = APIRouter() |
| | | |
| | |
| | | chat_id: str, |
| | | current_user: UserModel = Depends(get_current_user_websocket), |
| | | db: Session = Depends(get_db)): |
| | | tasks = [] |
| | | await websocket.accept() |
| | | print(f"Client {agent_id} connected") |
| | | |
| | | agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() |
| | | agent = db.query(MenuCapacityModel).filter(MenuCapacityModel.chat_id == agent_id).first() |
| | | if not agent: |
| | | agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() |
| | | agent_type = agent.agent_type |
| | | chat_type = agent.type |
| | | else: |
| | | agent_type = agent.capacity_type |
| | | chat_type = agent.chat_type |
| | | # print(agent_type) |
| | | # print(chat_type) |
| | | if not agent: |
| | | ret = {"message": "Agent not found", "type": "close"} |
| | | await websocket.send_json(ret) |
| | | return |
| | | agent_type = agent.agent_type |
| | | |
| | | if chat_id == "" or chat_id == "0": |
| | | ret = {"message": "Chat ID not found", "type": "close"} |
| | | await websocket.send_json(ret) |
| | | return |
| | | |
| | | # print(agent_type) |
| | | # print(chat_type) |
| | | if agent_type == AgentType.RAGFLOW: |
| | | ragflow_service = RagflowService(settings.ragflow_base_url) |
| | | token = get_ragflow_token(db, current_user.id) |
| | | ragflow_service = RagflowService(settings.fwr_base_url) |
| | | token = await get_ragflow_token(db, current_user.id) |
| | | try: |
| | | async def forward_to_ragflow(): |
| | | while True: |
| | | message = await websocket.receive_json() |
| | | print(f"Received from client {chat_id}: {message}") |
| | | chat_history = message.get('chatHistory', []) |
| | | message["role"] = "user" |
| | | if len(chat_history) == 0: |
| | | chat_history = await ragflow_service.get_session_history(token, chat_id) |
| | | if len(chat_history) == 0: |
| | | chat_history = await ragflow_service.set_session(token, agent_id, |
| | | message["message"], chat_id, True) |
| | | message, chat_id, True) |
| | | # print("chat_history------------------------", chat_history) |
| | | if len(chat_history) == 0: |
| | | result = {"message": "内部错误:创建会话失败", "type": "close"} |
| | | await websocket.send_json(result) |
| | | await websocket.close() |
| | | return |
| | | else: |
| | | chat_history.append({ |
| | | "content": message["message"], |
| | | "doc_ids": message.get("doc_ids", []), |
| | | "role": "user" |
| | | }) |
| | | complete_response = "" |
| | | async for rag_response in ragflow_service.chat(token, chat_id, chat_history): |
| | | try: |
| | | print(f"Received from ragflow: {rag_response}") |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | text = rag_response[5:].strip() |
| | | else: |
| | | # 否则,保持原样 |
| | | text = rag_response |
| | | complete_response += text |
| | | try: |
| | | json_data = json.loads(text) |
| | | json_data = json.loads(complete_response) |
| | | data = json_data.get("data") |
| | | if data is True: # 完成输出 |
| | | result = {"message": "", "type": "close"} |
| | | elif data is None: # 发生错误 |
| | | answer = json_data.get("retmsg", json_data.get("retcode")) |
| | | result = {"message": "内部错误:" + answer, "type": "stream"} |
| | | result = {"message": "内部错误:" + answer, "type": "message"} |
| | | else: # 正常输出 |
| | | answer = data.get("answer", "") |
| | | result = {"message": answer, "type": "stream"} |
| | | except json.JSONDecodeError: |
| | | result = {"message": text, "type": "stream"} |
| | | reference = data.get("reference", {}) |
| | | result = {"message": answer, "type": "message", "reference": reference} |
| | | await websocket.send_json(result) |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Forwarded to client {chat_id}: {result}") |
| | | except Exception as e: |
| | | result = {"message": f"内部错误: {e}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e}") |
| | | print(f"Error process message of ragflow: {e2}") |
| | | try: |
| | | dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1) |
| | | await update_session_history(db, dialog_chat_history, current_user.id) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("-----------------保存ragflow的历史会话异常-----------------") |
| | | |
| | | # 启动任务处理客户端消息 |
| | | tasks = [ |
| | | asyncio.create_task(forward_to_ragflow()) |
| | | ] |
| | | await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) |
| | | except WebSocketDisconnect: |
| | | print(f"Client {chat_id} disconnected") |
| | | except WebSocketDisconnect as e1: |
| | | print(f"Client {chat_id} disconnected: {e1}") |
| | | await websocket.close() |
| | | except Exception as e: |
| | | print(f"Exception occurred: {e}") |
| | | |
| | | finally: |
| | | print("Cleaning up resources of ragflow") |
| | | # 取消所有任务 |
| | | for task in tasks: |
| | | if not task.done(): |
| | | task.cancel() |
| | | try: |
| | | await task |
| | | except asyncio.CancelledError: |
| | | pass |
| | | |
| | | elif agent_type == AgentType.BISHENG: |
| | | token = get_bisheng_token(db, current_user.id) |
| | | service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}" |
| | | token = await get_bisheng_token(db, current_user.id) |
| | | service_uri = f"{settings.sgb_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}" |
| | | headers = {'cookie': f"access_token_cookie={token};"} |
| | | |
| | | async with websockets.connect(service_uri, extra_headers=headers) as service_websocket: |
| | |
| | | except asyncio.CancelledError: |
| | | pass |
| | | |
| | | except WebSocketDisconnect: |
| | | print(f"Client {chat_id} disconnected") |
| | | except WebSocketDisconnect as e: |
| | | print(f"WebSocket connection closed with code {e.code}: {e.reason}") |
| | | await websocket.close() |
| | | await service_websocket.close() |
| | | except Exception as e: |
| | | print(f"Exception occurred: {e}") |
| | | |
| | | finally: |
| | | print("Cleaning up resources of bisheng") |
| | | # 取消所有任务 |
| | | for task in tasks: |
| | | if not task.done(): |
| | | task.cancel() |
| | | try: |
| | | await task |
| | | except asyncio.CancelledError: |
| | | pass |
| | | elif agent_type == AgentType.BASIC: |
| | | try: |
| | | service = BasicService(base_url=settings.basic_base_url) |
| | | while True: |
| | | # 接收前端消息 |
| | | message = await websocket.receive_json() |
| | | question = message.get("message") |
| | | try: |
| | | SessionService(db).create_session( |
| | | chat_id, |
| | | question, |
| | | agent_id, |
| | | AgentType.BASIC, |
| | | current_user.id |
| | | ) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | if not question: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | # logger.error(agent.type) |
| | | if chat_type == "questionTalk": |
| | | |
| | | try: |
| | | data = await service.questions_talk(question, chat_id) |
| | | output = data.get("output", "") |
| | | file_name = data.get("filename", "") |
| | | |
| | | excel_url = None |
| | | if file_name: |
| | | excel_url = f"/api/files/download/?agent_id=basic_question_talk&file_id={file_name}&file_type=word" |
| | | result = {"message": output, "type": "message", "file_url": excel_url, "file_name": file_name} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", "content": result}) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("-----------------返回数据--------------------") |
| | | await websocket.send_json(result) |
| | | except Exception as e2: |
| | | |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | logger.error(str(e2)) |
| | | logger.error(f"Error process message of basic chuti agent: {e2}") |
| | | await websocket.send_json(result) |
| | | |
| | | else: |
| | | message_data = {} |
| | | logger.error("---------------------excel_talk-----------------------------") |
| | | excel_url = "" |
| | | image_url = "" |
| | | image_name = "" |
| | | excel_name = "" |
| | | async for data in service.excel_talk(question, chat_id): |
| | | # logger.error(data) |
| | | output = data.get("output", "") |
| | | e_name = data.get("excel_name", "") |
| | | i_name = data.get("image_name", "") |
| | | |
| | | def build_file_url(name, file_type): |
| | | if not name: |
| | | return None |
| | | return (f"/api/files/download/?agent_id={agent_id}&file_id={name}" |
| | | f"&file_type={file_type}") |
| | | |
| | | if e_name: |
| | | excel_url = build_file_url(e_name, 'excel') |
| | | excel_name = e_name |
| | | if i_name: |
| | | image_url = build_file_url(i_name, 'image') |
| | | image_name = i_name |
| | | if data["type"] == "message": |
| | | message_data = { |
| | | "content": output, |
| | | "excel_url": excel_url, |
| | | "image_url": image_url, |
| | | "image_name": image_name, |
| | | "excel_name": excel_name, |
| | | "sql": data.get("sql", ""), |
| | | "code": data.get("code", ""), |
| | | "e": data.get("e", ""), |
| | | "role": "assistant"} |
| | | |
| | | # 发送结果给客户端 |
| | | # data["type"] = "message" |
| | | data["message"] = output |
| | | data["excel_url"] = excel_url |
| | | data["image_url"] = image_url |
| | | await websocket.send_json(data) |
| | | if message_data: |
| | | try: |
| | | SessionService(db).update_session(chat_id, message=message_data) |
| | | except Exception as e: |
| | | logger.error(f"Unexpected error when update_session: {e}") |
| | | except Exception as e: |
| | | logger.error(e) |
| | | await websocket.send_json({"message": "出现错误!", "type": "error"}) |
| | | finally: |
| | | await websocket.close() |
| | | print(f"Client {agent_id} disconnected") |
| | | if agent_type == AgentType.DIFY: |
| | | dify_service = DifyService(settings.dify_base_url) |
| | | # token = get_dify_token(db, current_user.id) |
| | | try: |
| | | async def forward_to_dify(): |
| | | if chat_type == "imageTalk": |
| | | token = DfTokenDao(db).get_token_by_id(IMAGE_TO_TEXT) |
| | | if not token: |
| | | await websocket.send_json({"message": "Invalid token", "type": "error"}) |
| | | |
| | | while True: |
| | | image_list = [] |
| | | is_image = False |
| | | conversation_id = "" |
| | | receive_message = await websocket.receive_json() |
| | | print(f"Received from client {chat_id}: {receive_message}") |
| | | upload_file_id = receive_message.get('upload_file_id', "") |
| | | question = receive_message.get('message', "") |
| | | if not question and not image_url: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | try: |
| | | session = SessionService(db).create_session( |
| | | chat_id, |
| | | question, |
| | | agent_id, |
| | | AgentType.DIFY, |
| | | current_user.id |
| | | ) |
| | | conversation_id = session.conversation_id |
| | | except Exception as e: |
| | | logger.error(e) |
| | | # complete_response = "" |
| | | |
| | | answer_str = "" |
| | | files = [] |
| | | if upload_file_id: |
| | | files.append({ |
| | | "type": "image", |
| | | "transfer_method": "local_file", |
| | | "url": "", |
| | | "upload_file_id": upload_file_id |
| | | }) |
| | | async for rag_response in dify_service.chat(token, current_user.id, question, files, |
| | | conversation_id, {}): |
| | | # print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | complete_response = rag_response[5:].strip() |
| | | else: |
| | | # 否则,保持原样 |
| | | complete_response = rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | if data.get("event") == "agent_message": # "event": "message_end" |
| | | if "answer" not in data or not data["answer"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | # logger.error(data) |
| | | |
| | | continue |
| | | else: # 正常输出 |
| | | answer = data.get("answer", "") |
| | | if isinstance(answer, str): |
| | | if "]+\)' |
| | | url_image = image_list.pop() |
| | | new_answer = re.sub(pattern, url_image, answer) |
| | | answer_str += new_answer |
| | | else: |
| | | answer_str += answer |
| | | |
| | | elif isinstance(answer, dict): |
| | | logger.error("未知数据体:0---------------------------------") |
| | | logger.error(answer) |
| | | answer_str += answer.get("action_input", "") |
| | | |
| | | result = {"message": answer_str, "type": "message"} |
| | | elif data.get("event") == "message_end": |
| | | images_url = [] |
| | | if image_list and not is_image: |
| | | answer_str += image_list[-1] |
| | | result = {"message": answer_str, |
| | | "type": "close"} # , "message_files": images_url |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": {"answer": answer_str, |
| | | "images": images_url}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | elif data.get("event") == "message_file": |
| | | await dify_service.save_images(data.get("url"), data.get("id") + ".png") |
| | | image_list.append(f"})") |
| | | # result = {"message": answer_str, "type": "message"} |
| | | continue |
| | | else: |
| | | continue |
| | | await websocket.send_json(result) |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | elif chat_type == "reportWorkflow": |
| | | |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING) |
| | | if not token: |
| | | await websocket.send_json({"message": "Invalid token document_to_cleaning", "type": "error"}) |
| | | while True: |
| | | receive_message = await websocket.receive_json() |
| | | print(f"Received from client {chat_id}: {receive_message}") |
| | | upload_files = receive_message.get('upload_files', []) |
| | | title = receive_message.get('title', "") |
| | | workflow_type = receive_message.get('workflow', 1) |
| | | sub_titles = receive_message.get('sub_titles', "") |
| | | title_number = receive_message.get('title_number', 8) |
| | | title_style = receive_message.get('title_style', "") |
| | | title_query = receive_message.get('title_query', "") |
| | | if upload_files: |
| | | title_query = "start" |
| | | try: |
| | | session = SessionService(db).create_session( |
| | | chat_id, |
| | | title, |
| | | agent_id, |
| | | AgentType.DIFY, |
| | | current_user.id |
| | | ) |
| | | conversation_id = session.conversation_id |
| | | except Exception as e: |
| | | logger.error(e) |
| | | inputs = { |
| | | } |
| | | files = [] |
| | | for file in upload_files: |
| | | files.append({ |
| | | "type": "document", |
| | | "transfer_method": "local_file", |
| | | "url": "", |
| | | "upload_file_id": file |
| | | }) |
| | | if workflow_type == 1: |
| | | inputs["input_files"] = files |
| | | if workflow_type == 2: |
| | | inputs["file_list"] = files |
| | | inputs["Completion_of_main_indicators"] = title |
| | | inputs["sub_titles"] = sub_titles |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT_TITLE) |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_cleaning", "type": "error"}) |
| | | elif workflow_type == 3: |
| | | inputs["file_list"] = files |
| | | inputs["number_of_title"] = title_number |
| | | inputs["title_style"] = title_style |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_TITLE) |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_title", "type": "error"}) |
| | | |
| | | complete_response = "" |
| | | if workflow_type == 1 or workflow_type == 2: |
| | | async for rag_response in dify_service.workflow(token, current_user.id, inputs): |
| | | # print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | complete_response = rag_response[5:].strip() |
| | | elif "event: ping" in rag_response: |
| | | continue |
| | | else: |
| | | # 否则,保持原样 |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | complete_response = "" |
| | | if data.get("event") == "node_started" or data.get("event") == "node_finished": # "event": "message_end" |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | else: # 正常输出 |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | elif data.get("event") == "workflow_finished": |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | elif isinstance(answer, dict): |
| | | download_url = "" |
| | | outputs = answer.get("outputs", {}) |
| | | if outputs: |
| | | message = outputs.get("output", "") |
| | | download_url = outputs.get("download_url", "") |
| | | else: |
| | | message = answer.get("error", "") |
| | | |
| | | result = {"message": message, "type": "message", "download_url": download_url} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "download_url": download_url}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | await websocket.send_json(result) |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | |
| | | |
| | | else: |
| | | continue |
| | | try: |
| | | await websocket.send_json(result) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("返回客户端消息异常!") |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | elif workflow_type == 3: |
| | | image_list = [] |
| | | # print(inputs) |
| | | complete_response = "" |
| | | async for rag_response in dify_service.chat(token, current_user.id, title_query, [], |
| | | conversation_id, inputs): |
| | | print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | complete_response = rag_response[5:].strip() |
| | | elif "event: ping" in rag_response: |
| | | continue |
| | | else: |
| | | # 否则,保持原样 |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | complete_response = "" |
| | | if data.get("event") == "node_started" or data.get( |
| | | "event") == "node_finished": # "event": "message_end" |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | else: # 正常输出 |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | elif data.get("event") == "message": |
| | | message = data.get("answer", "") |
| | | # try: |
| | | # msg_dict = json.loads(answer) |
| | | # message = msg_dict.get("output", "") |
| | | # except Exception as e: |
| | | # print(e) |
| | | # continue |
| | | result = {"message": message, "type": "message", |
| | | "download_url": ""} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "download_url": ""}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | # try: |
| | | # await websocket.send_json(result) |
| | | # except Exception as e: |
| | | # logger.error(e) |
| | | # logger.error("返回客户端消息异常!") |
| | | |
| | | elif data.get("event") == "message_end": |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | else: |
| | | continue |
| | | try: |
| | | await websocket.send_json(result) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("dify返回客户端消息异常!") |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | elif chat_type == "documentIa": |
| | | # print(122112) |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_IA_QUESTIONS) |
| | | # print(token) |
| | | if not token: |
| | | await websocket.send_json({"message": "Invalid token", "type": "error"}) |
| | | |
| | | while True: |
| | | conversation_id = "" |
| | | # print(4343) |
| | | receive_message = await websocket.receive_json() |
| | | print(f"Received from client {chat_id}: {receive_message}") |
| | | upload_file_id = receive_message.get('upload_file_id', []) |
| | | question = receive_message.get('message', "") |
| | | if not question and not image_url: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | try: |
| | | session = SessionService(db).create_session( |
| | | chat_id, |
| | | question, |
| | | agent_id, |
| | | AgentType.DIFY, |
| | | current_user.id |
| | | ) |
| | | conversation_id = session.conversation_id |
| | | except Exception as e: |
| | | logger.error(e) |
| | | # complete_response = "" |
| | | files = [] |
| | | for fileId in upload_file_id: |
| | | files.append({ |
| | | "type": "document", |
| | | "transfer_method": "local_file", |
| | | "url": "", |
| | | "upload_file_id": fileId |
| | | }) |
| | | |
| | | answer_str = "" |
| | | complete_response = "" |
| | | async for rag_response in dify_service.chat(token, current_user.id, question, files, |
| | | conversation_id, {}): |
| | | print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | complete_response = rag_response[5:].strip() |
| | | elif "event: ping" in rag_response: |
| | | continue |
| | | else: |
| | | # 否则,保持原样 |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | if data.get("event") == "node_started" or data.get( |
| | | "event") == "node_finished": # "event": "message_end" |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | else: # 正常输出 |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | if answer.get("status") == "failed": |
| | | message = answer.get("error") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | # continue |
| | | elif data.get("event") == "message": # "event": "message_end" |
| | | # 正常输出 |
| | | answer = data.get("answer", "") |
| | | result = {"message": answer, "type": "stream"} |
| | | elif data.get("event") == "error": |
| | | answer = data.get("message", "") |
| | | result = {"message": answer, "type": "system"} |
| | | elif data.get("event") == "workflow_finished": |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | # result = {"message": "", "type": "close", "download_url": ""} |
| | | elif isinstance(answer, dict): |
| | | download_url = "" |
| | | outputs = answer.get("outputs", {}) |
| | | if outputs: |
| | | message = outputs.get("answer", "") |
| | | # download_url = outputs.get("download_url", "") |
| | | else: |
| | | message = answer.get("error", "") |
| | | |
| | | result = {"message": message, "type": "system", |
| | | "download_url": download_url} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "download_url": download_url}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | # await websocket.send_json(result) |
| | | # continue |
| | | elif data.get("event") == "message_end": |
| | | result = {"message": "", "type": "close"} |
| | | |
| | | else: |
| | | continue |
| | | try: |
| | | await websocket.send_json(result) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("返回客户端消息异常!") |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | elif chat_type == "paperTalk": |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_PAPER) |
| | | # print(token) |
| | | if not token: |
| | | await websocket.send_json({"message": "Invalid token", "type": "error"}) |
| | | |
| | | while True: |
| | | conversation_id = "" |
| | | inputs = {} |
| | | # print(4343) |
| | | receive_message = await websocket.receive_json() |
| | | print(f"Received from client {chat_id}: {receive_message}") |
| | | if "difficulty" in receive_message: |
| | | inputs["Question_Difficulty"] = receive_message["difficulty"] |
| | | if "is_paper" in receive_message: |
| | | inputs["Generate_test_paper"] = receive_message["is_paper"] |
| | | if "single_choice" in receive_message: |
| | | inputs["Multiple_choice_questions"] = receive_message["single_choice"] |
| | | if "gap_filling" in receive_message: |
| | | inputs["Fill_in_blank"] = receive_message["gap_filling"] |
| | | if "true_or_false" in receive_message: |
| | | inputs["true_or_false"] = receive_message["true_or_false"] |
| | | if "multiple_choice" in receive_message: |
| | | inputs["Multiple_Choice"] = receive_message["multiple_choice"] |
| | | if "easy_question" in receive_message: |
| | | inputs["Short_Answer_Questions"] = receive_message["easy_question"] |
| | | if "case_questions" in receive_message: |
| | | inputs["Case_Questions"] = receive_message["case_questions"] |
| | | if "key_words" in receive_message: |
| | | inputs["key_words"] = receive_message["key_words"] |
| | | upload_files = receive_message.get('upload_files', []) |
| | | question = receive_message.get('message', "") |
| | | session_log = SessionService(db).get_session_by_id(chat_id) |
| | | if not session_log and not upload_files: |
| | | await websocket.send_json({"message": "需要上传文档!", "type": "error"}) |
| | | continue |
| | | try: |
| | | session = SessionService(db).create_session( |
| | | chat_id, |
| | | question if question else "开始出题", |
| | | agent_id, |
| | | AgentType.DIFY, |
| | | current_user.id |
| | | ) |
| | | conversation_id = session.conversation_id |
| | | except Exception as e: |
| | | logger.error(e) |
| | | # complete_response = "" |
| | | |
| | | files = [] |
| | | for fileId in upload_files: |
| | | files.append({ |
| | | "type": "document", |
| | | "transfer_method": "local_file", |
| | | "url": "", |
| | | "upload_file_id": fileId |
| | | }) |
| | | if files: |
| | | inputs["upload_files"] = files |
| | | # print(inputs) |
| | | if not question and not inputs: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | |
| | | if not question: |
| | | question = "开始出题" |
| | | complete_response = "" |
| | | async for rag_response in dify_service.chat(token, current_user.id, question, files, |
| | | conversation_id, inputs): |
| | | # print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | complete_response = rag_response[5:].strip() |
| | | elif "event: ping" in rag_response: |
| | | continue |
| | | else: |
| | | # 否则,保持原样 |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | # print(data) |
| | | if data.get("event") == "node_started" or data.get( |
| | | "event") == "node_finished": # "event": "message_end" |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | else: # 正常输出 |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | # continue |
| | | elif data.get("event") == "message": # "event": "message_end" |
| | | # 正常输出 |
| | | answer = data.get("answer", "") |
| | | result = {"message": answer, "type": "stream"} |
| | | elif data.get("event") == "error": |
| | | answer = data.get("message", "") |
| | | result = {"message": answer, "type": "system"} |
| | | elif data.get("event") == "workflow_finished": |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | elif isinstance(answer, dict): |
| | | download_url = "" |
| | | outputs = answer.get("outputs", {}) |
| | | if outputs: |
| | | message = outputs.get("answer", "") |
| | | download_url = outputs.get("download_url", "") |
| | | else: |
| | | message = answer.get("error", "") |
| | | |
| | | result = {"message": message, "type": "system", |
| | | "download_url": download_url} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "download_url": download_url}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | # await websocket.send_json(result) |
| | | # continue |
| | | elif data.get("event") == "message_end": |
| | | result = {"message": "", "type": "close"} |
| | | |
| | | else: |
| | | continue |
| | | try: |
| | | await websocket.send_json(result) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("返回客户端消息异常!") |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | |
| | | |
| | | # 启动任务处理客户端消息 |
| | | tasks = [ |
| | | asyncio.create_task(forward_to_dify()) |
| | | ] |
| | | await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) |
| | | except WebSocketDisconnect as e1: |
| | | print(f"Client {chat_id} disconnected: {e1}") |
| | | await websocket.close() |
| | | except Exception as e: |
| | | print(f"Exception occurred: {e}") |
| | | |
| | | finally: |
| | | print("Cleaning up resources of ragflow") |
| | | # 取消所有任务 |
| | | for task in tasks: |
| | | if not task.done(): |
| | | task.cancel() |
| | | try: |
| | | await task |
| | | except asyncio.CancelledError: |
| | | pass |
| | | else: |
| | | ret = {"message": "Agent not found", "type": "close"} |
| | | await websocket.send_json(ret) |
| | | |