From 2dc4a7392eef26fdadd00fde1baf8b471ab25ca5 Mon Sep 17 00:00:00 2001 From: xuyonghao <898441624@qq.com> Date: 星期二, 17 十二月 2024 15:16:44 +0800 Subject: [PATCH] user_app表app补全注册接口 --- app/api/chat.py | 25 ++++++++++++++++--------- 1 files changed, 16 insertions(+), 9 deletions(-) diff --git a/app/api/chat.py b/app/api/chat.py index 4344b62..ecae273 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -11,10 +11,11 @@ from app.api import get_current_user_websocket from app.config.config import settings from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_REPORT, DOCUMENT_TO_CLEANING +from app.models import MenuCapacityModel from app.models.agent_model import AgentModel, AgentType from app.models.base_model import get_db from app.models.user_model import UserModel -from app.service.common.api_token import DfTokenDao +from app.service.v2.api_token import DfTokenDao from app.service.dialog import update_session_history from app.service.basic import BasicService from app.service.difyService import DifyService @@ -35,13 +36,19 @@ tasks = [] await websocket.accept() print(f"Client {agent_id} connected") - - agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() + agent = db.query(MenuCapacityModel).filter(MenuCapacityModel.chat_id == agent_id).first() + if not agent: + agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() + agent_type = agent.agent_type + chat_type = agent.type + else: + agent_type = agent.capacity_type + chat_type = agent.chat_type if not agent: ret = {"message": "Agent not found", "type": "close"} await websocket.send_json(ret) return - agent_type = agent.agent_type + if chat_id == "" or chat_id == "0": ret = {"message": "Chat ID not found", "type": "close"} await websocket.send_json(ret) @@ -49,7 +56,7 @@ if agent_type == AgentType.RAGFLOW: ragflow_service = RagflowService(settings.fwr_base_url) - token = get_ragflow_token(db, current_user.id) + token = await get_ragflow_token(db, current_user.id) try: async def forward_to_ragflow(): while True: @@ -135,7 +142,7 @@ pass elif agent_type == AgentType.BISHENG: - token = get_bisheng_token(db, current_user.id) + token = await get_bisheng_token(db, current_user.id) service_uri = f"{settings.sgb_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}" headers = {'cookie': f"access_token_cookie={token};"} @@ -227,7 +234,7 @@ await websocket.send_json({"message": "Invalid request", "type": "error"}) continue logger.error(agent.type) - if agent.type == "questionTalk": + if chat_type == "questionTalk": try: data = await service.questions_talk(question, chat_id) @@ -311,7 +318,7 @@ # token = get_dify_token(db, current_user.id) try: async def forward_to_dify(): - if agent.type == "imageTalk": + if chat_type == "imageTalk": token = DfTokenDao(db).get_token_by_id(IMAGE_TO_TEXT) if not token: await websocket.send_json({"message": "Invalid token", "type": "error"}) @@ -409,7 +416,7 @@ result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"} await websocket.send_json(result) print(f"Error process message of ragflow: {e2}") - elif agent.type == "reportWorkflow": + elif chat_type == "reportWorkflow": token = DfTokenDao(db).get_token_by_id(DOCUMENT_TO_CLEANING) if not token: -- Gitblit v1.8.0