zhaoqingang
2024-12-16 88360b4ac6f051f62a91e93d602fd393935071ab
app/api/chat.py
@@ -10,9 +10,12 @@
from Log import logger
from app.api import get_current_user_websocket
from app.config.config import settings
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING
from app.models import MenuCapacityModel
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.v2.api_token import DfTokenDao
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.difyService import DifyService
@@ -33,13 +36,19 @@
    tasks = []
    await websocket.accept()
    print(f"Client {agent_id} connected")
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    agent = db.query(MenuCapacityModel).filter(MenuCapacityModel.chat_id == agent_id).first()
    if not agent:
        agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
        agent_type = agent.agent_type
        chat_type = agent.type
    else:
        agent_type = agent.capacity_type
        chat_type = agent.chat_type
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    agent_type = agent.agent_type
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
@@ -47,7 +56,7 @@
    if agent_type == AgentType.RAGFLOW:
        ragflow_service = RagflowService(settings.fwr_base_url)
        token = get_ragflow_token(db, current_user.id)
        token = await get_ragflow_token(db, current_user.id)
        try:
            async def forward_to_ragflow():
                while True:
@@ -133,7 +142,7 @@
                        pass
    elif agent_type == AgentType.BISHENG:
        token = get_bisheng_token(db, current_user.id)
        token = await get_bisheng_token(db, current_user.id)
        service_uri = f"{settings.sgb_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
        headers = {'cookie': f"access_token_cookie={token};"}
@@ -225,7 +234,7 @@
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
                    continue
                logger.error(agent.type)
                if agent.type == "questionTalk":
                if chat_type == "questionTalk":
                    try:
                        data = await service.questions_talk(question, chat_id)
@@ -309,8 +318,11 @@
        # token = get_dify_token(db, current_user.id)
        try:
            async def forward_to_dify():
                if agent.type == "imageTalk":
                    token = settings.dify_api_token
                if chat_type == "imageTalk":
                    token = DfTokenDao(db).get_token_by_id(IMAGE_TO_TEXT)
                    if not token:
                        await websocket.send_json({"message": "Invalid token", "type": "error"})
                    while True:
                        image_list = []
                        is_image = False
@@ -334,9 +346,11 @@
                        except Exception as e:
                            logger.error(e)
                        # complete_response = ""
                        answer_str = ""
                        async for rag_response in dify_service.chat(token, current_user.id, question, upload_file_id,
                                                                    conversation_id):
                            # print(rag_response)
                            try:
                                if rag_response[:5] == "data:":
                                    # 如果是,则截取掉前5个字符,并去除首尾空白符
@@ -402,9 +416,11 @@
                                result = {"message": f"内部错误: {e2}", "type": "close"}
                                await websocket.send_json(result)
                                print(f"Error process message of ragflow: {e2}")
                elif agent.type == "reportWorkflow":
                    print(2323333232)
                    token = settings.dify_workflow_clean
                elif chat_type == "reportWorkflow":
                    token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING)
                    if not token:
                        await websocket.send_json({"message": "Invalid token document_to_cleaning", "type": "error"})
                    while True:
                        receive_message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {receive_message}")
@@ -426,21 +442,27 @@
                        except Exception as e:
                            logger.error(e)
                        inputs = {
                            "input_files": []
                        }
                        files = []
                        for file in upload_files:
                            inputs["input_files"].append({
                            files.append({
                                "type": "document",
                                "transfer_method": "local_file",
                                "url": "",
                                "upload_file_id": file
                            })
                        if workflow_type == 1:
                            inputs["input_files"] = files
                        if workflow_type == 2:
                            inputs["file_list"] = files
                            inputs["Completion_of_main_indicators"] = title
                            token = settings.dify_workflow_report
                            token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_REPORT)
                            if not token:
                                await websocket.send_json(
                                    {"message": "Invalid token document_to_cleaning", "type": "error"})
                        complete_response = ""
                        async for rag_response in dify_service.workflow(token, current_user.id, inputs):
                            print(rag_response)
                            # print(rag_response)
                            try:
                                if rag_response[:5] == "data:":
                                    # 如果是,则截取掉前5个字符,并去除首尾空白符