From 282a631b9ceee9a634ee1d93751a5254ed37ccef Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期二, 18 三月 2025 10:10:48 +0800 Subject: [PATCH] 首页知识库对话-rg --- app/service/v2/chat.py | 220 ++++++++++++++++++++++++++++++++++++++---------------- 1 files changed, 154 insertions(+), 66 deletions(-) diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py index 96672a3..3982bdc 100644 --- a/app/service/v2/chat.py +++ b/app/service/v2/chat.py @@ -6,6 +6,7 @@ import fitz from fastapi import HTTPException +from sqlalchemy import or_ from Log import logger from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \ @@ -13,7 +14,7 @@ from app.config.config import settings from app.config.const import * from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest, \ - ComplexChatDao + ComplexChatDao, KnowledgeModel, UserModel from app.models.v2.session_model import ChatSessionDao, ChatData from app.service.v2.app_driver.chat_agent import ChatAgent from app.service.v2.app_driver.chat_data import ChatBaseApply @@ -87,17 +88,45 @@ return ChatAgent(), url -async def service_chat_dialog(db, chat_id: str, question: str, session_id: str, user_id, mode: str): + +async def get_user_kb(db, user_id: int, kb_ids: list) -> list: + res = [] + user = db.query(UserModel).filter(UserModel.id == user_id).first() + if user is None: + return res + query = db.query(KnowledgeModel) + if user.permission != "admin": + klg_list = [j.id for i in user.groups for j in i.knowledges] + query = query.filter(or_(KnowledgeModel.id.in_(klg_list), KnowledgeModel.tenant_id == str(user_id))) + kb_list= query.all() + for kb in kb_list: + if kb.id in kb_ids: + if kb.permission == "team": + res.append(kb.id) + elif kb.tenant_id == str(user_id): + res.append(kb.id) + return res + else: + return kb_ids + + +async def service_chat_dialog(db, chat_id: str, question: str, session_id: str, user_id: int, mode: str, kb_ids: list): conversation_id = "" token = await get_chat_token(db, rg_api_token) url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id) + kb_id = await get_user_kb(db, user_id, kb_ids) + if not kb_id: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: The agent has no knowledge base to work with!", "status": http_400}, + ensure_ascii=False) + "\n\n" + return chat = ChatDialog() session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, RG_TYPE) if session: conversation_id = session.conversation_id message = {"role": "assistant", "answer": "", "reference": {}} try: - async for ans in chat.chat_completions(url, await chat.request_data(question, conversation_id), + async for ans in chat.chat_completions(url, await chat.complex_request_data(question, kb_id, conversation_id), await chat.get_headers(token)): data = {} error = "" @@ -451,6 +480,9 @@ node_data = [] if not query_data: query_data = {} + # print(node_data) + # print("--------------------------------------------------------") + # print(query_data) try: complex_log = ComplexChatSessionDao(db) if not conversation_id: @@ -485,20 +517,20 @@ async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest): answer_event = "" answer_agent = "" + answer_dialog = "" answer_workflow = "" download_url = "" message_id = "" task_id = "" error = "" node_list = [] + reference= {} conversation_id = "" query_data = chat_request.to_dict() new_message_id = str(uuid.uuid4()) inputs = {"is_deep": chat_request.isDeep} files = chat_request.files - if chat_request.chatMode == complex_knowledge_chat: - inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) - elif chat_request.chatMode == complex_content_optimization_chat: + if chat_request.chatMode == complex_content_optimization_chat: inputs["type"] = chat_request.optimizeType elif chat_request.chatMode == complex_dialog_chat: if not files and chat_request.parentId: @@ -512,69 +544,125 @@ "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, ensure_ascii=False) + "\n\n" return + query_data["parentId"] = new_message_id try: - token = await get_chat_token(db, chat_id) - chat, url = await get_chat_object(mode) - async for ans in chat.chat_completions(url, - await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs), - await chat.get_headers(token)): - # print(ans) - data = {} - status = http_200 - conversation_id = ans.get("conversation_id") - task_id = ans.get("task_id") - if ans.get("event") == message_error: - error = ans.get("message", "鍙傛暟寮傚父锛�") - status = http_400 - event = smart_message_error - elif ans.get("event") == message_agent: - data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} - answer_agent += ans.get("answer", "") - message_id = ans.get("message_id", "") - event = smart_message_stream - elif ans.get("event") == message_event: - data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} - answer_event += ans.get("answer", "") - message_id = ans.get("message_id", "") - event = smart_message_stream - elif ans.get("event") == message_file: - data = {"url": ans.get("url", ""), "id": ans.get("id", ""), - "type": ans.get("type", "")} - files.append(data) - event = smart_message_file - elif ans.get("event") in [workflow_started, node_started, node_finished]: - data = ans.get("data", {}) - data["inputs"] = await data_process(data.get("inputs", {})) - data["outputs"] = await data_process(data.get("outputs", {})) - data["files"] = await data_process(data.get("files", [])) - data["process_data"] = "" - if data.get("status") == "failed": - status = http_500 - error = data.get("error", "") - node_list.append(ans) - event = [smart_workflow_started, smart_node_started, smart_node_finished][ - [workflow_started, node_started, node_finished].index(ans.get("event"))] - elif ans.get("event") == workflow_finished: - data = ans.get("data", {}) - answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) - download_url = data.get("outputs", {}).get("download_url") - event = smart_workflow_finished - if data.get("status") == "failed": - status = http_500 - error = data.get("error", "") - node_list.append(ans) + if chat_request.chatMode == complex_knowledge_chat: + if not conversation_id: + session = await service_chat_sessions(db, chat_id, chat_request.query) + # print(session) + if not session or session.get("code") != 0: + yield "data: " + json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500}) + return + conversation_id = session.get("data", {}).get("id") + token = await get_chat_token(db, rg_api_token) + url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id) + chat = ChatDialog() + try: + async for ans in chat.chat_completions(url, await chat.complex_request_data(chat_request.query, chat_request.knowledgeId, conversation_id), + await chat.get_headers(token)): + data = {} + error = "" + status = http_200 + if ans.get("code", None) == 102: + error = ans.get("message", "error锛�") + status = http_400 + event = smart_message_error + else: + if isinstance(ans.get("data"), bool) and ans.get("data") is True: + event = smart_message_end + else: + data = ans.get("data", {}) + # conversation_id = data.get("session_id", "") + if "session_id" in data: + del data["session_id"] + data["prompt"] = "" + if not message_id: + message_id = data.get("id", "") + answer_dialog = data.get("answer", "") + reference = data.get("reference", {}) + event = smart_message_cover + message_str = "data: " + json.dumps( + {"event": event, "data": data, "error": error, "status": status, "message_id":message_id, + "parent_id": new_message_id, + "session_id": chat_request.sessionId}, + ensure_ascii=False) + "\n\n" + for i in range(0, len(message_str), max_chunk_size): + chunk = message_str[i:i + max_chunk_size] + # print(chunk) + yield chunk # 鍙戦�佸垎鍧楁秷鎭� + except Exception as e: - elif ans.get("event") == message_end: - event = smart_message_end - else: - continue + logger.error(e) + try: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: " + str(e), "status": http_500}, + ensure_ascii=False) + "\n\n" + except: + ... + else: + token = await get_chat_token(db, chat_id) + chat, url = await get_chat_object(mode) + async for ans in chat.chat_completions(url, + await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs), + await chat.get_headers(token)): + # print(ans) + data = {} + status = http_200 + conversation_id = ans.get("conversation_id") + task_id = ans.get("task_id") + if ans.get("event") == message_error: + error = ans.get("message", "鍙傛暟寮傚父锛�") + status = http_400 + event = smart_message_error + elif ans.get("event") == message_agent: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_agent += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_event: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_event += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_file: + data = {"url": ans.get("url", ""), "id": ans.get("id", ""), + "type": ans.get("type", "")} + files.append(data) + event = smart_message_file + elif ans.get("event") in [workflow_started, node_started, node_finished]: + data = ans.get("data", {}) + data["inputs"] = await data_process(data.get("inputs", {})) + data["outputs"] = await data_process(data.get("outputs", {})) + data["files"] = await data_process(data.get("files", [])) + data["process_data"] = "" + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) + event = [smart_workflow_started, smart_node_started, smart_node_finished][ + [workflow_started, node_started, node_finished].index(ans.get("event"))] + elif ans.get("event") == workflow_finished: + data = ans.get("data", {}) + answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) + download_url = data.get("outputs", {}).get("download_url") + event = smart_workflow_finished + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) - yield "data: " + json.dumps( - {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, - "parent_id": new_message_id, - "session_id": chat_request.sessionId}, - ensure_ascii=False) + "\n\n" + elif ans.get("event") == message_end: + event = smart_message_end + else: + continue + + yield "data: " + json.dumps( + {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, + "parent_id": new_message_id, + "session_id": chat_request.sessionId}, + ensure_ascii=False) + "\n\n" except Exception as e: logger.error(e) @@ -591,7 +679,7 @@ # "node_list": node_list, "task_id": task_id, "id": message_id, # "error": error}, conversation_id) if message_id: - await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data) + await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or answer_dialog or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list or reference, query_data=query_data) async def service_complex_upload(db, chat_id, file, user_id): files = [] -- Gitblit v1.8.0