From 30ff0afd5d76a3a5aa48058210ae411253574ada Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 13 三月 2025 14:55:30 +0800
Subject: [PATCH] 增加文件多轮问答

---
 app/service/v2/chat.py |   36 ++++++++++++++++++++++++------------
 1 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index 0942681..96672a3 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -475,7 +475,12 @@
         logger.error(e)
         return conversation_id, False
 
-
+async def add_query_files(db, message_id):
+    query = {}
+    complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id)
+    if complex_log:
+        query = json.loads(complex_log.query)
+    return query.get("files", [])
 
 async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
     answer_event = ""
@@ -485,28 +490,34 @@
     message_id = ""
     task_id = ""
     error = ""
-    files = []
     node_list = []
     conversation_id = ""
-    token = await get_chat_token(db, chat_id)
-    chat, url = await get_chat_object(mode)
+    query_data = chat_request.to_dict()
+    new_message_id = str(uuid.uuid4())
+    inputs = {"is_deep": chat_request.isDeep}
+    files = chat_request.files
+    if chat_request.chatMode == complex_knowledge_chat:
+        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
+    elif chat_request.chatMode == complex_content_optimization_chat:
+        inputs["type"] = chat_request.optimizeType
+    elif chat_request.chatMode == complex_dialog_chat:
+        if not files and chat_request.parentId:
+            files = await add_query_files(db, chat_request.parentId)
     if chat_request.chatMode != complex_content_optimization_chat:
         await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id,
                                 mode, "", DF_TYPE)
-        conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
+        conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data)
         if not message:
             yield "data: " + json.dumps({"message": smart_message_error,
                                          "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500},
                                         ensure_ascii=False) + "\n\n"
             return
-    inputs = {"is_deep": chat_request.isDeep}
-    if chat_request.chatMode == complex_knowledge_chat:
-        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
-    elif chat_request.chatMode == complex_content_optimization_chat:
-        inputs["type"] = chat_request.optimizeType
+    query_data["parentId"] = new_message_id
     try:
+        token = await get_chat_token(db, chat_id)
+        chat, url = await get_chat_object(mode)
         async for ans in chat.chat_completions(url,
-                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs),
+                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs),
                                                await chat.get_headers(token)):
             # print(ans)
             data = {}
@@ -561,6 +572,7 @@
 
             yield "data: " + json.dumps(
                 {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id,
+                 "parent_id": new_message_id,
                  "session_id": chat_request.sessionId},
                 ensure_ascii=False) + "\n\n"
 
@@ -579,7 +591,7 @@
         #                                           "node_list": node_list, "task_id": task_id, "id": message_id,
         #                                           "error": error}, conversation_id)
         if message_id:
-            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
+            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data)
 
 async def service_complex_upload(db, chat_id, file, user_id):
     files = []

--
Gitblit v1.8.0