| | |
| | | import json |
| | | import re |
| | | import uuid |
| | | from copy import deepcopy |
| | | |
| | | from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends |
| | | import asyncio |
| | |
| | | title_number = receive_message.get('title_number', 8) |
| | | title_style = receive_message.get('title_style', "") |
| | | title_query = receive_message.get('title_query', "") |
| | | is_clean = receive_message.get('is_clean', 0) |
| | | file_type = receive_message.get('file_type', 1) |
| | | max_token = receive_message.get('max_tokens', 100000) |
| | | tokens = receive_message.get('tokens', 0) |
| | | if upload_files: |
| | | title_query = "start" |
| | | try: |
| | |
| | | } |
| | | files = [] |
| | | for file in upload_files: |
| | | if file_type == 1: |
| | | files.append({ |
| | | "type": "document", |
| | | "transfer_method": "local_file", |
| | | "url": "", |
| | | "upload_file_id": file |
| | | }) |
| | | else: |
| | | files.append({ |
| | | "type": "document", |
| | | "transfer_method": "remote_url", |
| | | "url": file, |
| | | "upload_file_id": "" |
| | | }) |
| | | inputs_list = [] |
| | | is_next = 0 |
| | | if workflow_type == 1: |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING) |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_cleaning", "type": "error"}) |
| | | inputs["input_files"] = files |
| | | inputs["Completion_of_main_indicators"] = title |
| | | inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type}) |
| | | if workflow_type == 2: |
| | | inputs["file_list"] = files |
| | | inputs["Completion_of_main_indicators"] = title |
| | |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_cleaning", "type": "error"}) |
| | | elif workflow_type == 3: |
| | | inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type}) |
| | | elif workflow_type == 3 and is_clean == 0 and tokens < max_token: |
| | | inputs["file_list"] = files |
| | | inputs["number_of_title"] = title_number |
| | | inputs["title_style"] = title_style |
| | |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_title", "type": "error"}) |
| | | |
| | | inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type}) |
| | | elif workflow_type == 3 and is_clean == 1 or tokens >= max_token: |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING) |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_cleaning", "type": "error"}) |
| | | inputs["input_files"] = files |
| | | inputs["Completion_of_main_indicators"] = title |
| | | inputs_list.append({"inputs": inputs, "token": token, "workflow_type": 1}) |
| | | inputs1 = {} |
| | | inputs1["file_list"] = files |
| | | inputs1["number_of_title"] = title_number |
| | | inputs1["title_style"] = title_style |
| | | token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_TITLE) |
| | | if not token: |
| | | await websocket.send_json( |
| | | {"message": "Invalid token document_to_report", "type": "error"}) |
| | | inputs_list.append({"inputs": inputs1, "token": token, "workflow_type": 3}) |
| | | complete_response = "" |
| | | if workflow_type == 1 or workflow_type == 2: |
| | | async for rag_response in dify_service.workflow(token, current_user.id, inputs): |
| | | for idx, input in enumerate(inputs_list): |
| | | # print(input) |
| | | if idx < len(inputs_list) - 1: |
| | | is_next = 1 |
| | | else: |
| | | is_next = 0 |
| | | i = input["inputs"] |
| | | if "file_list" in i: |
| | | i["file_list"] = files |
| | | # print(i) |
| | | node_list = [] |
| | | complete_response = "" |
| | | workflow_list = [] |
| | | workflow_dict = {} |
| | | if input["workflow_type"] == 1 or input["workflow_type"] == 2: |
| | | async for rag_response in dify_service.workflow(input["token"], current_user.id, i): |
| | | # print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | # print(data) |
| | | node_data = deepcopy(data) |
| | | if "data" in node_data: |
| | | if "outputs" in node_data["data"]: |
| | | node_data["data"]["outputs"] = {} |
| | | if "inputs" in node_data["data"]: |
| | | node_data["data"]["inputs"] = {} |
| | | # print(node_data) |
| | | node_list.append(node_data) |
| | | |
| | | complete_response = "" |
| | | if data.get("event") == "node_started" or data.get("event") == "node_finished": # "event": "message_end" |
| | | if data.get("event") == "node_started": # "event": "message_end" |
| | | |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | |
| | | |
| | | message = answer.get("title", "") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | result = {"message": message, "type": "system", |
| | | "workflow": {"node_data": workflow_list}} |
| | | elif data.get("event") == "node_finished": |
| | | workflow_list.append({ |
| | | "title": data.get("data", {}).get("title", ""), |
| | | "status": data.get("data", {}).get("status", ""), |
| | | "created_at": data.get("data", {}).get("created_at", 0), |
| | | "finished_at": data.get("data", {}).get("finished_at", 0), |
| | | "node_type": data.get("data", {}).get("node_type", 0), |
| | | "elapsed_time": data.get("data", {}).get("elapsed_time", 0), |
| | | "error": data.get("data", {}).get("error", ""), |
| | | }) |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | if answer.get("status") == "failed": |
| | | message = answer.get("error", "") |
| | | result = {"message": message, "type": "system", |
| | | "workflow": {"node_data": workflow_list}} |
| | | |
| | | elif data.get("event") == "workflow_finished": |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | result = {"message": "", "type": "close", "download_url": "", |
| | | "is_next": is_next} |
| | | elif isinstance(answer, dict): |
| | | download_url = "" |
| | | outputs = answer.get("outputs", {}) |
| | |
| | | download_url = outputs.get("download_url", "") |
| | | else: |
| | | message = answer.get("error", "") |
| | | |
| | | result = {"message": message, "type": "message", "download_url": download_url} |
| | | if download_url: |
| | | files = [{ |
| | | "type": "document", |
| | | "transfer_method": "remote_url", |
| | | "url": download_url, |
| | | "upload_file_id": "" |
| | | }] |
| | | workflow_dict = { |
| | | "node_data": workflow_list, |
| | | "total_tokens": answer.get("total_tokens", 0), |
| | | "created_at": answer.get("created_at", 0), |
| | | "finished_at": answer.get("finished_at", 0), |
| | | "status": answer.get("status", ""), |
| | | "error": answer.get("error", ""), |
| | | "elapsed_time": answer.get("elapsed_time", 0) |
| | | } |
| | | result = {"message": message, "type": "message", |
| | | "download_url": download_url, "workflow": workflow_dict} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "node_list": node_list, |
| | | "download_url": download_url}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | node_list = [] |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | try: |
| | | await websocket.send_json(result) |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | except Exception as e: |
| | | logger.error(e) |
| | | logger.error("返回客户端消息异常!") |
| | | |
| | | result = {"message": "", "type": "close", "workflow": workflow_dict, |
| | | "is_next": is_next, "download_url": download_url} |
| | | |
| | | |
| | | else: |
| | |
| | | logger.error(e) |
| | | logger.error("返回客户端消息异常!") |
| | | complete_response = "" |
| | | |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of ragflow: {e2}") |
| | | elif workflow_type == 3: |
| | | elif input["workflow_type"] == 3: |
| | | image_list = [] |
| | | # print(inputs) |
| | | complete_response = "" |
| | | async for rag_response in dify_service.chat(token, current_user.id, title_query, [], |
| | | conversation_id, inputs): |
| | | print(rag_response) |
| | | answer_str = "" |
| | | async for rag_response in dify_service.chat(input["token"], current_user.id, |
| | | title_query, [], |
| | | conversation_id, i): |
| | | # print(rag_response) |
| | | try: |
| | | if rag_response[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | |
| | | complete_response += rag_response |
| | | try: |
| | | data = json.loads(complete_response) |
| | | node_data = deepcopy(data) |
| | | if "data" in node_data: |
| | | if "outputs" in node_data["data"]: |
| | | node_data["data"]["outputs"] = {} |
| | | if "inputs" in node_data["data"]: |
| | | node_data["data"]["inputs"] = {} |
| | | # print(node_data) |
| | | node_list.append(node_data) |
| | | complete_response = "" |
| | | if data.get("event") == "node_started" or data.get( |
| | | "event") == "node_finished": # "event": "message_end" |
| | | if data.get("event") == "node_started": # "event": "message_end" |
| | | if "data" not in data or not data["data"]: # 信息过滤 |
| | | logger.error("非法数据--------------------") |
| | | logger.error(data) |
| | |
| | | |
| | | message = answer.get("title", "") |
| | | |
| | | result = {"message": message, "type": "system"} |
| | | result = {"message": message, "type": "system", |
| | | "workflow": {"node_data": workflow_list}} |
| | | elif data.get("event") == "node_finished": |
| | | workflow_list.append({ |
| | | "title": data.get("data", {}).get("title", ""), |
| | | "status": data.get("data", {}).get("status", ""), |
| | | "created_at": data.get("data", {}).get("created_at", 0), |
| | | "finished_at": data.get("data", {}).get("finished_at", 0), |
| | | "node_type": data.get("data", {}).get("node_type", 0), |
| | | "elapsed_time": data.get("data", {}).get("elapsed_time", 0), |
| | | "error": data.get("data", {}).get("error", ""), |
| | | }) |
| | | |
| | | answer = data.get("data", "") |
| | | if isinstance(answer, str): |
| | | logger.error("----------------未知数据--------------------") |
| | | logger.error(data) |
| | | continue |
| | | elif isinstance(answer, dict): |
| | | |
| | | message = answer.get("title", "") |
| | | if answer.get("status") == "failed": |
| | | message = answer.get("error", "") |
| | | result = {"message": message, "type": "system", |
| | | "workflow": {"node_data": workflow_list}} |
| | | elif data.get("event") == "message": |
| | | message = data.get("answer", "") |
| | | answer_str = data.get("answer", "") |
| | | # try: |
| | | # msg_dict = json.loads(answer) |
| | | # message = msg_dict.get("output", "") |
| | | # except Exception as e: |
| | | # print(e) |
| | | # continue |
| | | result = {"message": message, "type": "message", |
| | | "download_url": ""} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": message, |
| | | "download_url": ""}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | result = {"message": answer_str, "type": "message", |
| | | "download_url": "", "workflow": {"node_data": workflow_list}} |
| | | |
| | | # try: |
| | | # await websocket.send_json(result) |
| | | # except Exception as e: |
| | | # logger.error(e) |
| | | # logger.error("返回客户端消息异常!") |
| | | |
| | | elif data.get("event") == "workflow_finished": |
| | | workflow_dict = { |
| | | "node_data": workflow_list, |
| | | "total_tokens": data.get("data", {}).get("total_tokens", 0), |
| | | "created_at": data.get("data", {}).get("created_at", 0), |
| | | "finished_at": data.get("data", {}).get("finished_at", 0), |
| | | "status": data.get("data", {}).get("status", ""), |
| | | "error": data.get("data", {}).get("error", ""), |
| | | "elapsed_time": data.get("data", {}).get("elapsed_time", 0) |
| | | } |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", |
| | | "content": { |
| | | "answer": answer_str, |
| | | "node_list": node_list, |
| | | "download_url": ""}}, |
| | | conversation_id=data.get( |
| | | "conversation_id")) |
| | | node_list = [] |
| | | except Exception as e: |
| | | logger.error("保存dify的会话异常!") |
| | | logger.error(e) |
| | | elif data.get("event") == "message_end": |
| | | result = {"message": "", "type": "close", "download_url": ""} |
| | | result = {"message": "", "type": "close", "workflow": workflow_dict, |
| | | "is_next": is_next} |
| | | else: |
| | | continue |
| | | try: |