xuyonghao
2025-02-08 72a8a0a1ad6b79b8e9fb2facef121f9b5d584666
app/api/chat.py
@@ -1,16 +1,19 @@
import json
import re
import uuid
from copy import deepcopy
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse
from Log import logger
from app.api import get_current_user_websocket
from app.config.config import settings
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT_TITLE, \
    DOCUMENT_TO_TITLE, DOCUMENT_IA_QUESTIONS
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
@@ -37,11 +40,14 @@
    print(f"Client {agent_id} connected")
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    print(agent_id)
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    agent_type = agent.agent_type
    print(agent_type)
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
@@ -339,10 +345,17 @@
                        except Exception as e:
                            logger.error(e)
                        # complete_response = ""
                        files = []
                        if upload_file_id:
                            files.append({
                                "type": "image",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": upload_file_id
                            })
                        answer_str = ""
                        async for rag_response in dify_service.chat(token, current_user.id, question, upload_file_id,
                                                                    conversation_id):
                        async for rag_response in dify_service.chat(token, current_user.id, question, files,
                                                                    conversation_id, {}):
                            # print(rag_response)
                            try:
                                if rag_response[:5] == "data:":
@@ -411,25 +424,35 @@
                                print(f"Error process message of ragflow: {e2}")
                elif agent.type == "reportWorkflow":
                    token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
                    if not token:
                        await websocket.send_json({"message": "Invalid token document_to_cleaning", "type": "error"})
                    while True:
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
                        upload_files = receive_message.get('upload_files', [])
                        title = receive_message.get('title', "")
                        sub_titles = receive_message.get('sub_titles', "")
                        workflow_type = receive_message.get('workflow', 1)
                        if not upload_files:
                            await websocket.send_json({"message": "Invalid request", "type": "error"})
                            continue
                        title_number = receive_message.get('title_number', 8)
                        title_style = receive_message.get('title_style', "")
                        title_query = receive_message.get('title_query', "")
                        is_clean = receive_message.get('is_clean', 0)
                        file_type = receive_message.get('file_type', 1)
                        max_token = receive_message.get('max_tokens', 100000)
                        tokens = receive_message.get('tokens', 0)
                        if upload_files:
                            title_query = "start"
                        # if not upload_files:
                            # await websocket.send_json({"message": "Invalid request", "type": "error"})
                            # continue
                        try:
                            session = SessionService(db).create_session(
                                chat_id,
                                title,
                                title if title else title_query,
                                agent_id,
                                AgentType.DIFY,
                                current_user.id
                                current_user.id,
                                {"role": "user", "content": title if title else title_query, "type": workflow_type, "is_clean":is_clean},
                                workflow_type
                            )
                            conversation_id = session.conversation_id
                        except Exception as e:
@@ -438,23 +461,378 @@
                        }
                        files = []
                        for file in upload_files:
                            if file_type == 1:
                                files.append({
                                    "type": "document",
                                    "transfer_method": "local_file",
                                    "url": "",
                                    "upload_file_id": file
                                })
                            else:
                                files.append({
                                    "type": "document",
                                    "transfer_method": "remote_url",
                                    "url": file,
                                    "upload_file_id": ""
                                })
                        inputs_list = []
                        is_next = 0
                        if workflow_type == 1:
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_cleaning", "type": "error"})
                            inputs["input_files"] = files
                            inputs["Completion_of_main_indicators"] = title
                            inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type})
                        elif workflow_type == 2:
                            inputs["file_list"] = files
                            inputs["Completion_of_main_indicators"] = title
                            inputs["sub_titles"] = sub_titles
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT_TITLE)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_report", "type": "error"})
                            inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type})
                        elif workflow_type == 3 and is_clean == 0 and tokens < max_token:
                            inputs["file_list"] = files
                            inputs["number_of_title"] = title_number
                            inputs["title_style"] = title_style
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_TITLE)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_title", "type": "error"})
                            inputs_list.append({"inputs": inputs, "token": token, "workflow_type": workflow_type})
                        elif workflow_type == 3 and is_clean == 1 or tokens >= max_token:
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_cleaning", "type": "error"})
                            inputs["input_files"] = files
                            inputs["Completion_of_main_indicators"] = title
                            inputs_list.append({"inputs": inputs, "token": token, "workflow_type": 1})
                            inputs1 = {}
                            inputs1["file_list"] = files
                            inputs1["number_of_title"] = title_number
                            inputs1["title_style"] = title_style
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_TITLE)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_report", "type": "error"})
                            inputs_list.append({"inputs": inputs1, "token": token, "workflow_type": 3})
                        # print(inputs_list)
                        for idx, input in enumerate(inputs_list):
                            # print(input)
                            if idx < len(inputs_list)-1:
                                is_next = 1
                            else:
                                is_next = 0
                            i = input["inputs"]
                            if "file_list" in i:
                                i["file_list"] = files
                            # print(i)
                            node_list = []
                            complete_response = ""
                            workflow_list = []
                            workflow_dict = {}
                            if input["workflow_type"] == 1 or input["workflow_type"] == 2:
                                async for rag_response in dify_service.workflow(input["token"], current_user.id, i):
                                    # print(rag_response)
                                    try:
                                        if rag_response[:5] == "data:":
                                            # 如果是,则截取掉前5个字符,并去除首尾空白符
                                            complete_response = rag_response[5:].strip()
                                        elif "event: ping" in rag_response:
                                            continue
                                        else:
                                            # 否则,保持原样
                                            complete_response += rag_response
                                        try:
                                            data = json.loads(complete_response)
                                            # print(data)
                                            node_data = deepcopy(data)
                                            if "data" in node_data:
                                                if "outputs" in node_data["data"]:
                                                    node_data["data"]["outputs"] = {}
                                                if "inputs" in node_data["data"]:
                                                    node_data["data"]["inputs"] = {}
                                            # print(node_data)
                                            node_list.append(node_data)
                                            complete_response = ""
                                            if data.get("event") == "node_started":  # "event": "message_end"
                                                if "data" not in data or not data["data"]:  # 信息过滤
                                                    logger.error("非法数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                else:  # 正常输出
                                                    answer = data.get("data", "")
                                                    if isinstance(answer, str):
                                                        logger.error("----------------未知数据--------------------")
                                                        logger.error(data)
                                                        continue
                                                    elif isinstance(answer, dict):
                                                        message = answer.get("title", "")
                                                    result = {"message": message, "type": "system", "workflow":{"node_data": workflow_list}}
                                            elif data.get("event") == "node_finished":
                                                workflow_list.append({
                                                    "title": data.get("data", {}).get("title", ""),
                                                    "status": data.get("data", {}).get("status", ""),
                                                    "created_at":data.get("data", {}).get("created_at", 0),
                                                    "finished_at":data.get("data", {}).get("finished_at", 0),
                                                    "node_type":data.get("data", {}).get("node_type", 0),
                                                    "elapsed_time":data.get("data", {}).get("elapsed_time", 0),
                                                    "error":data.get("data", {}).get("error", ""),
                                                })
                                                answer = data.get("data", "")
                                                if isinstance(answer, str):
                                                    logger.error("----------------未知数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                elif isinstance(answer, dict):
                                                    message = answer.get("title", "")
                                                    if answer.get("status") == "failed":
                                                        message = answer.get("error", "")
                                                        result = {"message": message, "type": "system", "workflow":{"node_data": workflow_list}}
                                            elif data.get("event") == "workflow_finished":
                                                answer = data.get("data", "")
                                                if isinstance(answer, str):
                                                    logger.error("----------------未知数据--------------------")
                                                    logger.error(data)
                                                    result = {"message": "", "type": "close", "download_url": "", "is_next": is_next}
                                                elif isinstance(answer, dict):
                                                    download_url = ""
                                                    outputs = answer.get("outputs", {})
                                                    if outputs:
                                                        message = outputs.get("output", "")
                                                        download_url = outputs.get("download_url", "")
                                                    else:
                                                        message = answer.get("error", "")
                                                    if download_url:
                                                        files = [{
                                                            "type": "document",
                                                            "transfer_method": "remote_url",
                                                            "url": download_url,
                                                            "upload_file_id": ""
                                                        }]
                                                    workflow_dict = {
                                                        "node_data": workflow_list,
                                                        "total_tokens": answer.get("total_tokens", 0),
                                                        "created_at": answer.get("created_at", 0),
                                                        "finished_at": answer.get("finished_at", 0),
                                                        "status": answer.get("status", ""),
                                                        "error": answer.get("error", ""),
                                                        "elapsed_time": answer.get("elapsed_time", 0)
                                                    }
                                                    result = {"message": message, "type": "message", "download_url": download_url, "workflow":workflow_dict}
                                                    try:
                                                        SessionService(db).update_session(chat_id,
                                                                                          message={"role": "assistant",
                                                                                                   "content": {
                                                                                                       "answer": message,
                                                                                                       "node_list": node_list,
                                                                                                       "download_url": download_url}},
                                                                                          conversation_id=data.get(
                                                                                              "conversation_id"))
                                                        node_list = []
                                                    except Exception as e:
                                                        logger.error("保存dify的会话异常!")
                                                        logger.error(e)
                                                    try:
                                                        await websocket.send_json(result)
                                                    except Exception as e:
                                                        logger.error(e)
                                                        logger.error("返回客户端消息异常!")
                                                    result = {"message": "", "type": "close", "workflow": workflow_dict, "is_next": is_next, "download_url": download_url}
                                            else:
                                                continue
                                            try:
                                                await websocket.send_json(result)
                                            except Exception  as e:
                                                logger.error(e)
                                                logger.error("返回客户端消息异常!")
                                            complete_response = ""
                                        except json.JSONDecodeError as e:
                                            print(f"Error decoding JSON: {e}")
                                            # print(f"Response text: {text}")
                                    except Exception as e2:
                                        result = {"message": f"内部错误: {e2}", "type": "close"}
                                        await websocket.send_json(result)
                                        print(f"Error process message of ragflow: {e2}")
                            elif input["workflow_type"] == 3:
                                image_list = []
                                # print(inputs)
                                complete_response = ""
                                answer_str = ""
                                async for rag_response in dify_service.chat(input["token"], current_user.id, title_query, [],
                                                                            conversation_id, i):
                                    # print(rag_response)
                                    try:
                                        if rag_response[:5] == "data:":
                                            # 如果是,则截取掉前5个字符,并去除首尾空白符
                                            complete_response = rag_response[5:].strip()
                                        elif "event: ping" in rag_response:
                                            continue
                                        else:
                                            # 否则,保持原样
                                            complete_response += rag_response
                                        try:
                                            data = json.loads(complete_response)
                                            node_data = deepcopy(data)
                                            if "data" in node_data:
                                                if "outputs" in node_data["data"]:
                                                    node_data["data"]["outputs"] = {}
                                                if "inputs" in node_data["data"]:
                                                    node_data["data"]["inputs"] = {}
                                            # print(node_data)
                                            node_list.append(node_data)
                                            complete_response = ""
                                            if data.get("event") == "node_started":  # "event": "message_end"
                                                if "data" not in data or not data["data"]:  # 信息过滤
                                                    logger.error("非法数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                else:  # 正常输出
                                                    answer = data.get("data", "")
                                                    if isinstance(answer, str):
                                                        logger.error("----------------未知数据--------------------")
                                                        logger.error(data)
                                                        continue
                                                    elif isinstance(answer, dict):
                                                        message = answer.get("title", "")
                                                    result = {"message": message, "type": "system", "workflow":{"node_data": workflow_list}}
                                            elif data.get("event") == "node_finished":
                                                workflow_list.append({
                                                    "title": data.get("data", {}).get("title", ""),
                                                    "status": data.get("data", {}).get("status", ""),
                                                    "created_at":data.get("data", {}).get("created_at", 0),
                                                    "finished_at":data.get("data", {}).get("finished_at", 0),
                                                    "node_type":data.get("data", {}).get("node_type", 0),
                                                    "elapsed_time":data.get("data", {}).get("elapsed_time", 0),
                                                    "error":data.get("data", {}).get("error", ""),
                                                })
                                                answer = data.get("data", "")
                                                if isinstance(answer, str):
                                                    logger.error("----------------未知数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                elif isinstance(answer, dict):
                                                    message = answer.get("title", "")
                                                    if answer.get("status") == "failed":
                                                        message = answer.get("error", "")
                                                    result = {"message": message, "type": "system", "workflow":{"node_data": workflow_list}}
                                            elif data.get("event") == "message":
                                                answer_str = data.get("answer", "")
                                                # try:
                                                #     msg_dict = json.loads(answer)
                                                #     message = msg_dict.get("output",  "")
                                                # except Exception as e:
                                                #     print(e)
                                                #     continue
                                                result = {"message": answer_str, "type": "message",
                                                          "download_url": "", "workflow": {"node_data": workflow_list}}
                                                # try:
                                                #     await websocket.send_json(result)
                                                # except Exception as e:
                                                #     logger.error(e)
                                                #     logger.error("返回客户端消息异常!")
                                            elif data.get("event") == "workflow_finished":
                                                workflow_dict = {
                                                    "node_data": workflow_list,
                                                    "total_tokens": data.get("data", {}).get("total_tokens", 0),
                                                    "created_at": data.get("data", {}).get("created_at", 0),
                                                    "finished_at": data.get("data", {}).get("finished_at", 0),
                                                    "status": data.get("data", {}).get("status", ""),
                                                    "error": data.get("data", {}).get("error", ""),
                                                    "elapsed_time": data.get("data", {}).get("elapsed_time", 0)
                                                }
                                                try:
                                                    SessionService(db).update_session(chat_id,
                                                                                      message={"role": "assistant",
                                                                                               "content": {
                                                                                                   "answer": answer_str,
                                                                                                   "node_list": node_list,
                                                                                                   "download_url": ""}},
                                                                                      conversation_id=data.get(
                                                                                          "conversation_id"))
                                                    node_list = []
                                                except Exception as e:
                                                    logger.error("保存dify的会话异常!")
                                                    logger.error(e)
                                            elif data.get("event") == "message_end":
                                                result = {"message": "", "type": "close", "workflow": workflow_dict, "is_next": is_next}
                                            else:
                                                continue
                                            try:
                                                await websocket.send_json(result)
                                            except Exception as e:
                                                logger.error(e)
                                                logger.error("dify返回客户端消息异常!")
                                            complete_response = ""
                                        except json.JSONDecodeError as e:
                                            print(f"Error decoding JSON: {e}")
                                            # print(f"Response text: {text}")
                                    except Exception as e2:
                                        result = {"message": f"内部错误: {e2}", "type": "close"}
                                        await websocket.send_json(result)
                                        print(f"Error process message of ragflow: {e2}")
                elif agent.type == "documentIa":
                    token = DfTokenDao(db).get_token_by_id(DOCUMENT_IA_QUESTIONS)
                    # print(token)
                    if not token:
                        await websocket.send_json({"message": "Invalid token", "type": "error"})
                    while True:
                        conversation_id = ""
                        # print(4343)
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
                        upload_file_id = receive_message.get('upload_file_id', [])
                        question = receive_message.get('message', "")
                        if not question and not image_url:
                            await websocket.send_json({"message": "Invalid request", "type": "error"})
                            continue
                        try:
                            session = SessionService(db).create_session(
                                chat_id,
                                question,
                                agent_id,
                                AgentType.DIFY,
                                current_user.id
                            )
                            conversation_id = session.conversation_id
                        except Exception as e:
                            logger.error(e)
                        # complete_response = ""
                        files = []
                        for fileId in upload_file_id:
                            files.append({
                                "type": "document",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": file
                                "upload_file_id": fileId
                            })
                        if workflow_type == 1:
                            inputs["input_files"] = files
                        if workflow_type == 2:
                            inputs["file_list"] = files
                            inputs["Completion_of_main_indicators"] = title
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_cleaning", "type": "error"})
                        answer_str = ""
                        complete_response = ""
                        async for rag_response in dify_service.workflow(token, current_user.id, inputs):
                        async for rag_response in dify_service.chat(token, current_user.id, question, files,
                                                                    conversation_id, {}):
                            # print(rag_response)
                            try:
                                if rag_response[:5] == "data:":
@@ -467,8 +845,8 @@
                                    complete_response += rag_response
                                try:
                                    data = json.loads(complete_response)
                                    complete_response = ""
                                    if data.get("event") == "node_started" or data.get("event") == "node_finished":  # "event": "message_end"
                                    if data.get("event") == "node_started" or data.get(
                                            "event") == "node_finished":  # "event": "message_end"
                                        if "data" not in data or not data["data"]:  # 信息过滤
                                            logger.error("非法数据--------------------")
                                            logger.error(data)
@@ -484,6 +862,11 @@
                                                message = answer.get("title", "")
                                            result = {"message": message, "type": "system"}
                                            continue
                                    elif data.get("event") == "message":  # "event": "message_end"
                                        # 正常输出
                                        answer = data.get("answer", "")
                                        result = {"message": answer, "type": "stream"}
                                    elif data.get("event") == "workflow_finished":
                                        answer = data.get("data", "")
                                        if isinstance(answer, str):
@@ -494,12 +877,13 @@
                                            download_url = ""
                                            outputs = answer.get("outputs", {})
                                            if outputs:
                                                message = outputs.get("output", "")
                                                download_url = outputs.get("download_url", "")
                                                message = outputs.get("answer", "")
                                                # download_url = outputs.get("download_url", "")
                                            else:
                                                message = answer.get("error", "")
                                            result = {"message": message, "type": "message", "download_url": download_url}
                                            # result = {"message": message, "type": "message",
                                            #           "download_url": download_url}
                                            try:
                                                SessionService(db).update_session(chat_id,
                                                                                  message={"role": "assistant",
@@ -511,15 +895,16 @@
                                            except Exception as e:
                                                logger.error("保存dify的会话异常!")
                                                logger.error(e)
                                            await websocket.send_json(result)
                                            result = {"message": "", "type": "close", "download_url": ""}
                                            # await websocket.send_json(result)
                                        continue
                                    elif data.get("event") == "message_end":
                                        result = {"message": "", "type": "close"}
                                    else:
                                        continue
                                    try:
                                        await websocket.send_json(result)
                                    except Exception  as e:
                                    except Exception as e:
                                        logger.error(e)
                                        logger.error("返回客户端消息异常!")
                                    complete_response = ""