From 30ff0afd5d76a3a5aa48058210ae411253574ada Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期四, 13 三月 2025 14:55:30 +0800 Subject: [PATCH] 增加文件多轮问答 --- app/service/v2/chat.py | 36 ++++++++++++++++++++++++------------ 1 files changed, 24 insertions(+), 12 deletions(-) diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py index 0942681..96672a3 100644 --- a/app/service/v2/chat.py +++ b/app/service/v2/chat.py @@ -475,7 +475,12 @@ logger.error(e) return conversation_id, False - +async def add_query_files(db, message_id): + query = {} + complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id) + if complex_log: + query = json.loads(complex_log.query) + return query.get("files", []) async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest): answer_event = "" @@ -485,28 +490,34 @@ message_id = "" task_id = "" error = "" - files = [] node_list = [] conversation_id = "" - token = await get_chat_token(db, chat_id) - chat, url = await get_chat_object(mode) + query_data = chat_request.to_dict() + new_message_id = str(uuid.uuid4()) + inputs = {"is_deep": chat_request.isDeep} + files = chat_request.files + if chat_request.chatMode == complex_knowledge_chat: + inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) + elif chat_request.chatMode == complex_content_optimization_chat: + inputs["type"] = chat_request.optimizeType + elif chat_request.chatMode == complex_dialog_chat: + if not files and chat_request.parentId: + files = await add_query_files(db, chat_request.parentId) if chat_request.chatMode != complex_content_optimization_chat: await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id, mode, "", DF_TYPE) - conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict()) + conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data) if not message: yield "data: " + json.dumps({"message": smart_message_error, "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, ensure_ascii=False) + "\n\n" return - inputs = {"is_deep": chat_request.isDeep} - if chat_request.chatMode == complex_knowledge_chat: - inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) - elif chat_request.chatMode == complex_content_optimization_chat: - inputs["type"] = chat_request.optimizeType + query_data["parentId"] = new_message_id try: + token = await get_chat_token(db, chat_id) + chat, url = await get_chat_object(mode) async for ans in chat.chat_completions(url, - await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs), + await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs), await chat.get_headers(token)): # print(ans) data = {} @@ -561,6 +572,7 @@ yield "data: " + json.dumps( {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, + "parent_id": new_message_id, "session_id": chat_request.sessionId}, ensure_ascii=False) + "\n\n" @@ -579,7 +591,7 @@ # "node_list": node_list, "task_id": task_id, "id": message_id, # "error": error}, conversation_id) if message_id: - await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict()) + await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data) async def service_complex_upload(db, chat_id, file, user_id): files = [] -- Gitblit v1.8.0