From 6bac1630e5af5890a6922bdc624e591eb19a12eb Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期四, 13 三月 2025 18:36:07 +0800 Subject: [PATCH] 知识库对接rg --- app/service/v2/chat.py | 241 +++++++++++++++++++++++++++++++++-------------- 1 files changed, 167 insertions(+), 74 deletions(-) diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py index 83ea02a..38683a8 100644 --- a/app/service/v2/chat.py +++ b/app/service/v2/chat.py @@ -1,6 +1,7 @@ import asyncio import io import json +import time import uuid import fitz @@ -11,7 +12,8 @@ DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL from app.config.config import settings from app.config.const import * -from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest +from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest, \ + ComplexChatDao from app.models.v2.session_model import ChatSessionDao, ChatData from app.service.v2.app_driver.chat_agent import ChatAgent from app.service.v2.app_driver.chat_data import ChatBaseApply @@ -90,7 +92,7 @@ token = await get_chat_token(db, rg_api_token) url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id) chat = ChatDialog() - session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, 1) + session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, RG_TYPE) if session: conversation_id = session.conversation_id message = {"role": "assistant", "answer": "", "reference": {}} @@ -173,7 +175,7 @@ else: query = "start new conversation" session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id, - mode, conversation_id, 3) + mode, conversation_id, DF_TYPE) if session: conversation_id = session.conversation_id try: @@ -251,8 +253,15 @@ "error": error}, conversation_id) + + async def service_chat_basic(db, chat_id: str, chat_data: ChatData, session_id: str, user_id, mode: str): - ... + + if chat_id == basic_report_talk: + complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chat_data.report_mode) + if complex_chat: + ... + async def service_chat_parameters(db, chat_id, user_id): @@ -285,7 +294,15 @@ async def service_chat_session_log(db, session_id): session_log = await ChatSessionDao(db).get_session_by_id(session_id) - return json.dumps(session_log.log_to_json() if session_log else {}) + if not session_log: + return {} + log_info =session_log.log_to_json() + if session_log.event_type == complex_chat: + + total, message_list = await ComplexChatSessionDao(db).get_session_list(session_id) + log_info["message"] = [message.log_to_json() for message in message_list[::-1]] + + return json.dumps(log_info) async def service_chat_upload(db, chat_id, file, user_id): @@ -434,6 +451,9 @@ node_data = [] if not query_data: query_data = {} + # print(node_data) + # print("--------------------------------------------------------") + # print(query_data) try: complex_log = ComplexChatSessionDao(db) if not conversation_id: @@ -458,89 +478,162 @@ logger.error(e) return conversation_id, False - +async def add_query_files(db, message_id): + query = {} + complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id) + if complex_log: + query = json.loads(complex_log.query) + return query.get("files", []) async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest): answer_event = "" answer_agent = "" + answer_dialog = "" answer_workflow = "" download_url = "" message_id = "" task_id = "" error = "" - files = [] node_list = [] - token = await get_chat_token(db, chat_id) - chat, url = await get_chat_object(mode) - conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict()) - if not message: - yield "data: " + json.dumps({"message": smart_message_error, - "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, - ensure_ascii=False) + "\n\n" - return + reference= {} + conversation_id = "" + query_data = chat_request.to_dict() + new_message_id = str(uuid.uuid4()) inputs = {"is_deep": chat_request.isDeep} - if chat_request.chatMode == complex_knowledge_chat: - inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) + files = chat_request.files + if chat_request.chatMode == complex_content_optimization_chat: + inputs["type"] = chat_request.optimizeType + elif chat_request.chatMode == complex_dialog_chat: + if not files and chat_request.parentId: + files = await add_query_files(db, chat_request.parentId) + if chat_request.chatMode != complex_content_optimization_chat: + await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id, + mode, "", DF_TYPE) + conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data) + if not message: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, + ensure_ascii=False) + "\n\n" + return + query_data["parentId"] = new_message_id try: - async for ans in chat.chat_completions(url, - await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs), - await chat.get_headers(token)): - print(ans) - data = {} - status = http_200 - conversation_id = ans.get("conversation_id") - task_id = ans.get("task_id") - if ans.get("event") == message_error: - error = ans.get("message", "鍙傛暟寮傚父锛�") - status = http_400 - event = smart_message_error - elif ans.get("event") == message_agent: - data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} - answer_agent += ans.get("answer", "") - message_id = ans.get("message_id", "") - event = smart_message_stream - elif ans.get("event") == message_event: - data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} - answer_event += ans.get("answer", "") - message_id = ans.get("message_id", "") - event = smart_message_stream - elif ans.get("event") == message_file: - data = {"url": ans.get("url", ""), "id": ans.get("id", ""), - "type": ans.get("type", "")} - files.append(data) - event = smart_message_file - elif ans.get("event") in [workflow_started, node_started, node_finished]: - data = ans.get("data", {}) - data["inputs"] = await data_process(data.get("inputs", {})) - data["outputs"] = await data_process(data.get("outputs", {})) - data["files"] = await data_process(data.get("files", [])) - data["process_data"] = "" - if data.get("status") == "failed": - status = http_500 - error = data.get("error", "") - node_list.append(ans) - event = [smart_workflow_started, smart_node_started, smart_node_finished][ - [workflow_started, node_started, node_finished].index(ans.get("event"))] - elif ans.get("event") == workflow_finished: - data = ans.get("data", {}) - answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) - download_url = data.get("outputs", {}).get("download_url") - event = smart_workflow_finished - if data.get("status") == "failed": - status = http_500 - error = data.get("error", "") - node_list.append(ans) + if chat_request.chatMode == complex_knowledge_chat: + if not conversation_id: + session = await service_chat_sessions(db, chat_id, chat_request.query) + # print(session) + if not session or session.get("code") != 0: + yield "data: " + json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500}) + return + conversation_id = session.get("data", {}).get("id") + token = await get_chat_token(db, rg_api_token) + url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id) + chat = ChatDialog() + try: + async for ans in chat.chat_completions(url, await chat.complex_request_data(chat_request.query, chat_request.knowledgeId, conversation_id), + await chat.get_headers(token)): + data = {} + error = "" + status = http_200 + if ans.get("code", None) == 102: + error = ans.get("message", "error锛�") + status = http_400 + event = smart_message_error + else: + if isinstance(ans.get("data"), bool) and ans.get("data") is True: + event = smart_message_end + else: + data = ans.get("data", {}) + # conversation_id = data.get("session_id", "") + if "session_id" in data: + del data["session_id"] + data["prompt"] = "" + if not message_id: + message_id = data.get("id", "") + answer_dialog = data.get("answer", "") + reference = data.get("reference", {}) + event = smart_message_cover + message_str = "data: " + json.dumps( + {"event": event, "data": data, "error": error, "status": status, "message_id":message_id, + "parent_id": new_message_id, + "session_id": chat_request.sessionId}, + ensure_ascii=False) + "\n\n" + for i in range(0, len(message_str), max_chunk_size): + chunk = message_str[i:i + max_chunk_size] + # print(chunk) + yield chunk # 鍙戦�佸垎鍧楁秷鎭� + except Exception as e: - elif ans.get("event") == message_end: - event = smart_message_end - else: - continue + logger.error(e) + try: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: " + str(e), "status": http_500}, + ensure_ascii=False) + "\n\n" + except: + ... + else: + token = await get_chat_token(db, chat_id) + chat, url = await get_chat_object(mode) + async for ans in chat.chat_completions(url, + await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs), + await chat.get_headers(token)): + # print(ans) + data = {} + status = http_200 + conversation_id = ans.get("conversation_id") + task_id = ans.get("task_id") + if ans.get("event") == message_error: + error = ans.get("message", "鍙傛暟寮傚父锛�") + status = http_400 + event = smart_message_error + elif ans.get("event") == message_agent: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_agent += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_event: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_event += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_file: + data = {"url": ans.get("url", ""), "id": ans.get("id", ""), + "type": ans.get("type", "")} + files.append(data) + event = smart_message_file + elif ans.get("event") in [workflow_started, node_started, node_finished]: + data = ans.get("data", {}) + data["inputs"] = await data_process(data.get("inputs", {})) + data["outputs"] = await data_process(data.get("outputs", {})) + data["files"] = await data_process(data.get("files", [])) + data["process_data"] = "" + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) + event = [smart_workflow_started, smart_node_started, smart_node_finished][ + [workflow_started, node_started, node_finished].index(ans.get("event"))] + elif ans.get("event") == workflow_finished: + data = ans.get("data", {}) + answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) + download_url = data.get("outputs", {}).get("download_url") + event = smart_workflow_finished + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) - yield "data: " + json.dumps( - {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, - "session_id": chat_request.sessionId}, - ensure_ascii=False) + "\n\n" + elif ans.get("event") == message_end: + event = smart_message_end + else: + continue + + yield "data: " + json.dumps( + {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, + "parent_id": new_message_id, + "session_id": chat_request.sessionId}, + ensure_ascii=False) + "\n\n" except Exception as e: logger.error(e) @@ -557,7 +650,7 @@ # "node_list": node_list, "task_id": task_id, "id": message_id, # "error": error}, conversation_id) if message_id: - await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict()) + await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or answer_dialog or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list or reference, query_data=query_data) async def service_complex_upload(db, chat_id, file, user_id): files = [] -- Gitblit v1.8.0