zhaoqingang
2024-12-31 6b4093952e555e1eb2713bd85133a5f697cda1e0
app/api/chat.py
@@ -1,17 +1,22 @@
import json
import re
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse
from Log import logger
from app.api import get_current_user_websocket
from app.config.config import settings
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT_TITLE, \
    DOCUMENT_TO_TITLE, DOCUMENT_IA_QUESTIONS
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.common.api_token import DfTokenDao
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.difyService import DifyService
@@ -34,11 +39,14 @@
    print(f"Client {agent_id} connected")
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    print(agent_id)
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    agent_type = agent.agent_type
    print(agent_type)
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
@@ -234,7 +242,7 @@
                        excel_url = None
                        if file_name:
                            excel_url = f"/api/files/download/?agent_id=basic_question_talk&file_id={file_name}&file_type=word"
                        result = {"message": output, "type": "message", "file_url": excel_url, "file_name":file_name}
                        result = {"message": output, "type": "message", "file_url": excel_url, "file_name": file_name}
                        try:
                            SessionService(db).update_session(chat_id,
                                                              message={"role": "assistant", "content": result})
@@ -250,37 +258,53 @@
                        await websocket.send_json(result)
                else:
                    message_data = {}
                    logger.error("---------------------excel_talk-----------------------------")
                    excel_url = ""
                    image_url = ""
                    image_name = ""
                    excel_name = ""
                    async for data in service.excel_talk(question, chat_id):
                        # logger.error(data)
                        output = data.get("output", "")
                        excel_name = data.get("excel_name", "")
                        image_name = data.get("image_name", "")
                        e_name = data.get("excel_name", "")
                        i_name = data.get("image_name", "")
                        def build_file_url(name, file_type):
                            if not name:
                                return None
                            return (f"/api/files/download/?agent_id={agent_id}&file_id={name}"
                                    f"&file_type={file_type}")
                        excel_url = build_file_url(excel_name, 'excel')
                        image_url = build_file_url(image_name, 'image')
                        if excel_url or data.get("e", ""):
                            try:
                                SessionService(db).update_session(chat_id,
                                                                  message={
                                                                      "content": output,
                                                                      "excel_url": excel_url,
                                                                      "image_url": image_url,
                                                                      "sql": data.get("sql", ""),
                                                                      "code": data.get("code", ""),
                                                                      "e": data.get("e", ""),
                                                                      "role": "assistant"})
                            except Exception as e:
                                logger.error(f"Unexpected error when update_session: {e}")
                        if e_name:
                            excel_url = build_file_url(e_name, 'excel')
                            excel_name = e_name
                        if i_name:
                            image_url = build_file_url(i_name, 'image')
                            image_name = i_name
                        if data["type"] == "message":
                            message_data = {
                                "content": output,
                                "excel_url": excel_url,
                                "image_url": image_url,
                                "image_name": image_name,
                                "excel_name": excel_name,
                                "sql": data.get("sql", ""),
                                "code": data.get("code", ""),
                                "e": data.get("e", ""),
                                "role": "assistant"}
                        # 发送结果给客户端
                        data["type"] = "message"
                        # data["type"] = "message"
                        data["message"] = output
                        data["excel_url"] = excel_url
                        data["image_url"] = image_url
                        await websocket.send_json(data)
                    if message_data:
                        try:
                            SessionService(db).update_session(chat_id, message=message_data)
                        except Exception as e:
                            logger.error(f"Unexpected error when update_session: {e}")
        except Exception as e:
            logger.error(e)
            await websocket.send_json({"message": "出现错误!", "type": "error"})
@@ -290,93 +314,459 @@
    if agent_type == AgentType.DIFY:
        dify_service = DifyService(settings.dify_base_url)
        # token = get_dify_token(db, current_user.id)
        token = settings.dify_api_token
        try:
            async def forward_to_dify():
                while True:
                    conversation_id = ""
                    receive_message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {receive_message}")
                    upload_file_id = receive_message.get('upload_file_id', "")
                    question = receive_message.get('message', "")
                    if not question and not image_url:
                        await websocket.send_json({"message": "Invalid request", "type": "error"})
                        continue
                    try:
                        session = SessionService(db).create_session(
                            chat_id,
                            question,
                            agent_id,
                            AgentType.DIFY,
                            current_user.id
                        )
                        conversation_id = session.conversation_id
                    except Exception as e:
                        logger.error(e)
                    # complete_response = ""
                    answer_str = ""
                    async for rag_response in dify_service.chat(token, current_user.id, question, upload_file_id, conversation_id):
                        # print("=============================================")
                        # print(rag_response)
                if agent.type == "imageTalk":
                    token = DfTokenDao(db).get_token_by_id(IMAGE_TO_TEXT)
                    if not token:
                        await websocket.send_json({"message": "Invalid token", "type": "error"})
                    while True:
                        image_list = []
                        is_image = False
                        conversation_id = ""
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
                        upload_file_id = receive_message.get('upload_file_id', "")
                        question = receive_message.get('message', "")
                        if not question and not image_url:
                            await websocket.send_json({"message": "Invalid request", "type": "error"})
                            continue
                        try:
                            if rag_response[:5] == "data:":
                                # 如果是,则截取掉前5个字符,并去除首尾空白符
                                complete_response = rag_response[5:].strip()
                            else:
                                # 否则,保持原样
                                complete_response = rag_response
                            # complete_response += text
                            session = SessionService(db).create_session(
                                chat_id,
                                question,
                                agent_id,
                                AgentType.DIFY,
                                current_user.id
                            )
                            conversation_id = session.conversation_id
                        except Exception as e:
                            logger.error(e)
                        # complete_response = ""
                        files = []
                        if upload_file_id:
                            files.append({
                                "type": "image",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": upload_file_id
                            })
                        answer_str = ""
                        async for rag_response in dify_service.chat(token, current_user.id, question, files,
                                                                    conversation_id, {}):
                            # print(rag_response)
                            try:
                                data = json.loads(complete_response)
                                complete_response = ""
                                # data = json_data.get("data")
                                if data.get("event") == "agent_message":# "event": "message_end"
                                    if "answer" not in  data or not data["answer"]:  # 信息过滤
                                        logger.error("非法数据--------------------")
                                        # logger.error(data)
                                        continue
                                    else:  # 正常输出
                                        answer = data.get("answer", "")
                                        if isinstance(answer, str):
                                            answer_str += answer
                                        elif isinstance(answer, dict):
                                            logger.error("未知数据体:0---------------------------------")
                                            logger.error(answer)
                                            answer_str += answer.get("action_input", "")
                                        result = {"message": answer_str, "type": "message"}
                                elif data.get("event") == "message_end":
                                    message_files = []
                                    res_msg = await dify_service.get_session_history(token, data.get("conversation_id"), str(current_user.id))
                                    if len(res_msg) > 0:
                                        message_files = res_msg[0].get("message_files")
                                    result = {"message": answer_str, "type": "close", "message_files": message_files}
                                    try:
                                        SessionService(db).update_session(chat_id,
                                                                          message={"role": "assistant", "content": {"answer":answer_str, "images":[i.get("url") for i in message_files]}},conversation_id=data.get("conversation_id"))
                                    except Exception as e:
                                        logger.error("保存dify的会话异常!")
                                        logger.error(e)
                                elif data.get("event") == "message_file":
                                    url = data.get("url", "")
                                    result = {"message": url, "type": "image"}
                                if rag_response[:5] == "data:":
                                    # 如果是,则截取掉前5个字符,并去除首尾空白符
                                    complete_response = rag_response[5:].strip()
                                else:
                                    continue
                                await websocket.send_json(result)
                                complete_response = ""
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                # print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of ragflow: {e2}")
                                    # 否则,保持原样
                                    complete_response = rag_response
                                try:
                                    data = json.loads(complete_response)
                                    if data.get("event") == "agent_message":  # "event": "message_end"
                                        if "answer" not in data or not data["answer"]:  # 信息过滤
                                            logger.error("非法数据--------------------")
                                            # logger.error(data)
                                            continue
                                        else:  # 正常输出
                                            answer = data.get("answer", "")
                                            if isinstance(answer, str):
                                                if "![](https://res.stepfun.com/" in answer and image_list:
                                                    is_image = True
                                                    pattern = r'!\[\] *\(https://res\.stepfun\.com/image_gen/[^)]+\)'
                                                    url_image = image_list.pop()
                                                    new_answer = re.sub(pattern, url_image, answer)
                                                    answer_str += new_answer
                                                else:
                                                    answer_str += answer
                                            elif isinstance(answer, dict):
                                                logger.error("未知数据体:0---------------------------------")
                                                logger.error(answer)
                                                answer_str += answer.get("action_input", "")
                                            result = {"message": answer_str, "type": "message"}
                                    elif data.get("event") == "message_end":
                                        images_url = []
                                        if image_list and not is_image:
                                            answer_str += image_list[-1]
                                        result = {"message": answer_str,
                                                  "type": "close"}  # , "message_files": images_url
                                        try:
                                            SessionService(db).update_session(chat_id,
                                                                              message={"role": "assistant",
                                                                                       "content": {"answer": answer_str,
                                                                                                   "images": images_url}},
                                                                              conversation_id=data.get(
                                                                                  "conversation_id"))
                                        except Exception as e:
                                            logger.error("保存dify的会话异常!")
                                            logger.error(e)
                                    elif data.get("event") == "message_file":
                                        await  dify_service.save_images(data.get("url"), data.get("id") + ".png")
                                        image_list.append(f"![](/api/files/image/{data.get('id')})")
                                        # result = {"message": answer_str, "type": "message"}
                                        continue
                                    else:
                                        continue
                                    await websocket.send_json(result)
                                    complete_response = ""
                                except json.JSONDecodeError as e:
                                    print(f"Error decoding JSON: {e}")
                                    # print(f"Response text: {text}")
                            except Exception as e2:
                                result = {"message": f"内部错误: {e2}", "type": "close"}
                                await websocket.send_json(result)
                                print(f"Error process message of ragflow: {e2}")
                elif agent.type == "reportWorkflow":
                    token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
                    if not token:
                        await websocket.send_json({"message": "Invalid token document_to_cleaning", "type": "error"})
                    while True:
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
                        upload_files = receive_message.get('upload_files', [])
                        title = receive_message.get('title', "")
                        sub_titles = receive_message.get('sub_titles', "")
                        workflow_type = receive_message.get('workflow', 1)
                        title_number = receive_message.get('title_number', 8)
                        title_style = receive_message.get('title_style', "")
                        title_query = receive_message.get('title_query', "")
                        if upload_files:
                            title_query = "start"
                        # if not upload_files:
                            # await websocket.send_json({"message": "Invalid request", "type": "error"})
                            # continue
                        try:
                            session = SessionService(db).create_session(
                                chat_id,
                                title if title else title_query,
                                agent_id,
                                AgentType.DIFY,
                                current_user.id
                            )
                            conversation_id = session.conversation_id
                        except Exception as e:
                            logger.error(e)
                        inputs = {
                        }
                        files = []
                        for file in upload_files:
                            files.append({
                                "type": "document",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": file
                            })
                        if workflow_type == 1:
                            inputs["input_files"] = files
                        elif workflow_type == 2:
                            inputs["file_list"] = files
                            inputs["Completion_of_main_indicators"] = title
                            inputs["sub_titles"] = sub_titles
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT_TITLE)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_report", "type": "error"})
                        elif workflow_type == 3:
                            inputs["file_list"] = files
                            inputs["number_of_title"] = title_number
                            inputs["title_style"] = title_style
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_TITLE)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_title", "type": "error"})
                        complete_response = ""
                        if workflow_type == 1 or workflow_type == 2:
                            async for rag_response in dify_service.workflow(token, current_user.id, inputs):
                                # print(rag_response)
                                try:
                                    if rag_response[:5] == "data:":
                                        # 如果是,则截取掉前5个字符,并去除首尾空白符
                                        complete_response = rag_response[5:].strip()
                                    elif "event: ping" in rag_response:
                                        continue
                                    else:
                                        # 否则,保持原样
                                        complete_response += rag_response
                                    try:
                                        data = json.loads(complete_response)
                                        complete_response = ""
                                        if data.get("event") == "node_started" or data.get("event") == "node_finished":  # "event": "message_end"
                                            if "data" not in data or not data["data"]:  # 信息过滤
                                                logger.error("非法数据--------------------")
                                                logger.error(data)
                                                continue
                                            else:  # 正常输出
                                                answer = data.get("data", "")
                                                if isinstance(answer, str):
                                                    logger.error("----------------未知数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                elif isinstance(answer, dict):
                                                    message = answer.get("title", "")
                                                result = {"message": message, "type": "system"}
                                        elif data.get("event") == "workflow_finished":
                                            answer = data.get("data", "")
                                            if isinstance(answer, str):
                                                logger.error("----------------未知数据--------------------")
                                                logger.error(data)
                                                result = {"message": "", "type": "close", "download_url": ""}
                                            elif isinstance(answer, dict):
                                                download_url = ""
                                                outputs = answer.get("outputs", {})
                                                if outputs:
                                                    message = outputs.get("output", "")
                                                    download_url = outputs.get("download_url", "")
                                                else:
                                                    message = answer.get("error", "")
                                                result = {"message": message, "type": "message", "download_url": download_url}
                                                try:
                                                    SessionService(db).update_session(chat_id,
                                                                                      message={"role": "assistant",
                                                                                               "content": {
                                                                                                   "answer": message,
                                                                                                   "download_url": download_url}},
                                                                                      conversation_id=data.get(
                                                                                          "conversation_id"))
                                                except Exception as e:
                                                    logger.error("保存dify的会话异常!")
                                                    logger.error(e)
                                                try:
                                                    await websocket.send_json(result)
                                                except Exception as e:
                                                    logger.error(e)
                                                    logger.error("返回客户端消息异常!")
                                                result = {"message": "", "type": "close", "download_url": ""}
                                        else:
                                            continue
                                        try:
                                            await websocket.send_json(result)
                                        except Exception  as e:
                                            logger.error(e)
                                            logger.error("返回客户端消息异常!")
                                        complete_response = ""
                                    except json.JSONDecodeError as e:
                                        print(f"Error decoding JSON: {e}")
                                        # print(f"Response text: {text}")
                                except Exception as e2:
                                    result = {"message": f"内部错误: {e2}", "type": "close"}
                                    await websocket.send_json(result)
                                    print(f"Error process message of ragflow: {e2}")
                        elif workflow_type == 3:
                            image_list = []
                            # print(inputs)
                            complete_response = ""
                            async for rag_response in dify_service.chat(token, current_user.id, title_query, [],
                                                                        conversation_id, inputs):
                                print(rag_response)
                                try:
                                    if rag_response[:5] == "data:":
                                        # 如果是,则截取掉前5个字符,并去除首尾空白符
                                        complete_response = rag_response[5:].strip()
                                    elif "event: ping" in rag_response:
                                        continue
                                    else:
                                        # 否则,保持原样
                                        complete_response += rag_response
                                    try:
                                        data = json.loads(complete_response)
                                        complete_response = ""
                                        if data.get("event") == "node_started" or data.get(
                                                "event") == "node_finished":  # "event": "message_end"
                                            if "data" not in data or not data["data"]:  # 信息过滤
                                                logger.error("非法数据--------------------")
                                                logger.error(data)
                                                continue
                                            else:  # 正常输出
                                                answer = data.get("data", "")
                                                if isinstance(answer, str):
                                                    logger.error("----------------未知数据--------------------")
                                                    logger.error(data)
                                                    continue
                                                elif isinstance(answer, dict):
                                                    message = answer.get("title", "")
                                                result = {"message": message, "type": "system"}
                                        elif data.get("event") == "message":
                                            message = data.get("answer", "")
                                            # try:
                                            #     msg_dict = json.loads(answer)
                                            #     message = msg_dict.get("output",  "")
                                            # except Exception as e:
                                            #     print(e)
                                            #     continue
                                            result = {"message": message, "type": "message",
                                                      "download_url": ""}
                                            try:
                                                SessionService(db).update_session(chat_id,
                                                                                  message={"role": "assistant",
                                                                                           "content": {
                                                                                               "answer": message,
                                                                                               "download_url": ""}},
                                                                                  conversation_id=data.get(
                                                                                      "conversation_id"))
                                            except Exception as e:
                                                logger.error("保存dify的会话异常!")
                                                logger.error(e)
                                            # try:
                                            #     await websocket.send_json(result)
                                            # except Exception as e:
                                            #     logger.error(e)
                                            #     logger.error("返回客户端消息异常!")
                                        elif data.get("event") == "message_end":
                                            result = {"message": "", "type": "close", "download_url": ""}
                                        else:
                                            continue
                                        try:
                                            await websocket.send_json(result)
                                        except Exception as e:
                                            logger.error(e)
                                            logger.error("dify返回客户端消息异常!")
                                        complete_response = ""
                                    except json.JSONDecodeError as e:
                                        print(f"Error decoding JSON: {e}")
                                        # print(f"Response text: {text}")
                                except Exception as e2:
                                    result = {"message": f"内部错误: {e2}", "type": "close"}
                                    await websocket.send_json(result)
                                    print(f"Error process message of ragflow: {e2}")
                elif agent.type == "documentIa":
                    print(122112)
                    token = DfTokenDao(db).get_token_by_id(DOCUMENT_IA_QUESTIONS)
                    # print(token)
                    if not token:
                        await websocket.send_json({"message": "Invalid token", "type": "error"})
                    while True:
                        conversation_id = ""
                        # print(4343)
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
                        upload_file_id = receive_message.get('upload_file_id', [])
                        question = receive_message.get('message', "")
                        if not question and not image_url:
                            await websocket.send_json({"message": "Invalid request", "type": "error"})
                            continue
                        try:
                            session = SessionService(db).create_session(
                                chat_id,
                                question,
                                agent_id,
                                AgentType.DIFY,
                                current_user.id
                            )
                            conversation_id = session.conversation_id
                        except Exception as e:
                            logger.error(e)
                        # complete_response = ""
                        files = []
                        for fileId in upload_file_id:
                            files.append({
                                "type": "document",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": fileId
                            })
                        answer_str = ""
                        complete_response = ""
                        async for rag_response in dify_service.chat(token, current_user.id, question, files,
                                                                    conversation_id, {}):
                            # print(rag_response)
                            try:
                                if rag_response[:5] == "data:":
                                    # 如果是,则截取掉前5个字符,并去除首尾空白符
                                    complete_response = rag_response[5:].strip()
                                elif "event: ping" in rag_response:
                                    continue
                                else:
                                    # 否则,保持原样
                                    complete_response += rag_response
                                try:
                                    data = json.loads(complete_response)
                                    if data.get("event") == "node_started" or data.get(
                                            "event") == "node_finished":  # "event": "message_end"
                                        if "data" not in data or not data["data"]:  # 信息过滤
                                            logger.error("非法数据--------------------")
                                            logger.error(data)
                                            continue
                                        else:  # 正常输出
                                            answer = data.get("data", "")
                                            if isinstance(answer, str):
                                                logger.error("----------------未知数据--------------------")
                                                logger.error(data)
                                                continue
                                            elif isinstance(answer, dict):
                                                message = answer.get("title", "")
                                            result = {"message": message, "type": "system"}
                                            continue
                                    elif data.get("event") == "message":  # "event": "message_end"
                                        # 正常输出
                                        answer = data.get("answer", "")
                                        result = {"message": answer, "type": "stream"}
                                    elif data.get("event") == "workflow_finished":
                                        answer = data.get("data", "")
                                        if isinstance(answer, str):
                                            logger.error("----------------未知数据--------------------")
                                            logger.error(data)
                                            result = {"message": "", "type": "close", "download_url": ""}
                                        elif isinstance(answer, dict):
                                            download_url = ""
                                            outputs = answer.get("outputs", {})
                                            if outputs:
                                                message = outputs.get("answer", "")
                                                # download_url = outputs.get("download_url", "")
                                            else:
                                                message = answer.get("error", "")
                                            # result = {"message": message, "type": "message",
                                            #           "download_url": download_url}
                                            try:
                                                SessionService(db).update_session(chat_id,
                                                                                  message={"role": "assistant",
                                                                                           "content": {
                                                                                               "answer": message,
                                                                                               "download_url": download_url}},
                                                                                  conversation_id=data.get(
                                                                                      "conversation_id"))
                                            except Exception as e:
                                                logger.error("保存dify的会话异常!")
                                                logger.error(e)
                                            # await websocket.send_json(result)
                                        continue
                                    elif data.get("event") == "message_end":
                                        result = {"message": "", "type": "close"}
                                    else:
                                        continue
                                    try:
                                        await websocket.send_json(result)
                                    except Exception as e:
                                        logger.error(e)
                                        logger.error("返回客户端消息异常!")
                                    complete_response = ""
                                except json.JSONDecodeError as e:
                                    print(f"Error decoding JSON: {e}")
                                    # print(f"Response text: {text}")
                            except Exception as e2:
                                result = {"message": f"内部错误: {e2}", "type": "close"}
                                await websocket.send_json(result)
                                print(f"Error process message of ragflow: {e2}")
            # 启动任务处理客户端消息
            tasks = [