From 30ff0afd5d76a3a5aa48058210ae411253574ada Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 13 三月 2025 14:55:30 +0800
Subject: [PATCH] 增加文件多轮问答
---
app/service/v2/chat.py | 70 ++++++++++++++++++++++++++---------
1 files changed, 52 insertions(+), 18 deletions(-)
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index 05a37be..96672a3 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -1,6 +1,7 @@
import asyncio
import io
import json
+import time
import uuid
import fitz
@@ -11,7 +12,8 @@
DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
from app.config.config import settings
from app.config.const import *
-from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest
+from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest, \
+ ComplexChatDao
from app.models.v2.session_model import ChatSessionDao, ChatData
from app.service.v2.app_driver.chat_agent import ChatAgent
from app.service.v2.app_driver.chat_data import ChatBaseApply
@@ -90,7 +92,7 @@
token = await get_chat_token(db, rg_api_token)
url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id)
chat = ChatDialog()
- session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, 1)
+ session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, RG_TYPE)
if session:
conversation_id = session.conversation_id
message = {"role": "assistant", "answer": "", "reference": {}}
@@ -173,7 +175,7 @@
else:
query = "start new conversation"
session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id,
- mode, conversation_id, 3)
+ mode, conversation_id, DF_TYPE)
if session:
conversation_id = session.conversation_id
try:
@@ -251,8 +253,15 @@
"error": error}, conversation_id)
+
+
async def service_chat_basic(db, chat_id: str, chat_data: ChatData, session_id: str, user_id, mode: str):
- ...
+
+ if chat_id == basic_report_talk:
+ complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chat_data.report_mode)
+ if complex_chat:
+ ...
+
async def service_chat_parameters(db, chat_id, user_id):
@@ -285,7 +294,15 @@
async def service_chat_session_log(db, session_id):
session_log = await ChatSessionDao(db).get_session_by_id(session_id)
- return json.dumps(session_log.log_to_json() if session_log else {})
+ if not session_log:
+ return {}
+ log_info =session_log.log_to_json()
+ if session_log.event_type == complex_chat:
+
+ total, message_list = await ComplexChatSessionDao(db).get_session_list(session_id)
+ log_info["message"] = [message.log_to_json() for message in message_list[::-1]]
+
+ return json.dumps(log_info)
async def service_chat_upload(db, chat_id, file, user_id):
@@ -458,7 +475,12 @@
logger.error(e)
return conversation_id, False
-
+async def add_query_files(db, message_id):
+ query = {}
+ complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id)
+ if complex_log:
+ query = json.loads(complex_log.query)
+ return query.get("files", [])
async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
answer_event = ""
@@ -468,23 +490,34 @@
message_id = ""
task_id = ""
error = ""
- files = []
node_list = []
- token = await get_chat_token(db, chat_id)
- chat, url = await get_chat_object(mode)
- conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
- if not message:
- yield "data: " + json.dumps({"message": smart_message_error,
- "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500},
- ensure_ascii=False) + "\n\n"
- return
+ conversation_id = ""
+ query_data = chat_request.to_dict()
+ new_message_id = str(uuid.uuid4())
inputs = {"is_deep": chat_request.isDeep}
+ files = chat_request.files
if chat_request.chatMode == complex_knowledge_chat:
inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
-
+ elif chat_request.chatMode == complex_content_optimization_chat:
+ inputs["type"] = chat_request.optimizeType
+ elif chat_request.chatMode == complex_dialog_chat:
+ if not files and chat_request.parentId:
+ files = await add_query_files(db, chat_request.parentId)
+ if chat_request.chatMode != complex_content_optimization_chat:
+ await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id,
+ mode, "", DF_TYPE)
+ conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data)
+ if not message:
+ yield "data: " + json.dumps({"message": smart_message_error,
+ "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500},
+ ensure_ascii=False) + "\n\n"
+ return
+ query_data["parentId"] = new_message_id
try:
+ token = await get_chat_token(db, chat_id)
+ chat, url = await get_chat_object(mode)
async for ans in chat.chat_completions(url,
- await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs),
+ await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs),
await chat.get_headers(token)):
# print(ans)
data = {}
@@ -539,6 +572,7 @@
yield "data: " + json.dumps(
{"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id,
+ "parent_id": new_message_id,
"session_id": chat_request.sessionId},
ensure_ascii=False) + "\n\n"
@@ -557,7 +591,7 @@
# "node_list": node_list, "task_id": task_id, "id": message_id,
# "error": error}, conversation_id)
if message_id:
- await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
+ await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data)
async def service_complex_upload(db, chat_id, file, user_id):
files = []
--
Gitblit v1.8.0