From 5ef590b70cc8e2de16083af2ee2d977daae5587c Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期二, 19 十一月 2024 16:34:29 +0800 Subject: [PATCH] 会话缓存本地数据库 --- app/service/ragflow.py | 6 ++ app/service/dialog.py | 36 +++++++++++++++++- app/service/group.py | 1 app/api/chat.py | 10 ++++- app/models/dialog_model.py | 27 +++++++++++++ app/api/agent.py | 3 + 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/app/api/agent.py b/app/api/agent.py index 8e7f1c3..3178144 100644 --- a/app/api/agent.py +++ b/app/api/agent.py @@ -12,6 +12,7 @@ from app.models.base_model import get_db from app.models.user_model import UserModel from app.service.bisheng import BishengService +from app.service.dialog import get_session_history from app.service.ragflow import RagflowService from app.service.service_token import get_ragflow_token, get_bisheng_token @@ -41,6 +42,8 @@ try: token = get_ragflow_token(db, current_user.id) result = await ragflow_service.get_chat_sessions(token, agent_id) + if not result: + result = await get_session_history(db, current_user.id, agent_id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) return ResponseList(code=200, msg="", data=result) diff --git a/app/api/chat.py b/app/api/chat.py index fe86fb5..ea1be48 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -10,6 +10,7 @@ from app.models.agent_model import AgentModel, AgentType from app.models.base_model import get_db from app.models.user_model import UserModel +from app.service.dialog import update_session_history from app.service.ragflow import RagflowService from app.service.service_token import get_bisheng_token, get_ragflow_token @@ -44,6 +45,7 @@ try: async def forward_to_ragflow(): while True: + is_new = False message = await websocket.receive_json() print(f"Received from client {chat_id}: {message}") chat_history = message.get('chatHistory', []) @@ -51,9 +53,10 @@ if len(chat_history) == 0: chat_history = await ragflow_service.get_session_history(token, chat_id) if len(chat_history) == 0: + is_new = True chat_history = await ragflow_service.set_session(token, agent_id, message, chat_id, True) - print("chat_history------------------------", chat_history) + # print("chat_history------------------------", chat_history) if len(chat_history) == 0: result = {"message": "鍐呴儴閿欒锛氬垱寤轰細璇濆け璐�", "type": "close"} await websocket.send_json(result) @@ -91,11 +94,14 @@ complete_response = "" except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") - print(f"Response text: {text}") + # print(f"Response text: {text}") except Exception as e2: result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"} await websocket.send_json(result) print(f"Error process message of ragflow: {e2}") + dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1) + await update_session_history(db, dialog_chat_history, current_user.id, is_new) + # 鍚姩浠诲姟澶勭悊瀹㈡埛绔秷鎭� tasks = [ asyncio.create_task(forward_to_ragflow()) diff --git a/app/models/dialog_model.py b/app/models/dialog_model.py index e46a950..7a1848a 100644 --- a/app/models/dialog_model.py +++ b/app/models/dialog_model.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, Integer, String, Table, ForeignKey, DateTime, BigInteger, Text, Float, Boolean +from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, JSON from sqlalchemy.orm import relationship, backref from app.models.base_model import Base @@ -33,4 +33,29 @@ 'description': self.description, 'icon': self.icon, 'status': self.status + } + + +class ConversationModel(Base): + __tablename__ = 'conversation' + id = Column(String(32), primary_key=True) # id + create_date = Column(DateTime) # 鍒涘缓鏃堕棿 + create_time = Column(BigInteger) + update_date = Column(DateTime) # 鏇存柊鏃堕棿 + update_time = Column(BigInteger) + tenant_id = Column(Integer) # 鍒涘缓浜� + dialog_id = Column(String(32)) + name = Column(String(255)) # 鍚嶇О + message = Column(JSON) # 璇存槑 + reference = Column(JSON) # 鍥炬爣 + + def get_id(self): + return str(self.id) + + + def to_json(self): + return { + 'id': self.id, + 'updated_time': self.update_time, + 'name': self.name, } \ No newline at end of file diff --git a/app/service/dialog.py b/app/service/dialog.py index 4fa7a63..bcca339 100644 --- a/app/service/dialog.py +++ b/app/service/dialog.py @@ -1,4 +1,5 @@ -from app.models import KnowledgeModel, GroupModel, DialogModel +from app.api.user import user_list +from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel from app.models.user_model import UserModel from Log import logger @@ -19,4 +20,35 @@ dialog_list.append(k) kld_set.add(k.id) - return {"rows": [kld.to_json() for kld in dialog_list]} \ No newline at end of file + return {"rows": [kld.to_json() for kld in dialog_list]} + + +async def update_session_history(db, data: dict, user_id, is_new): + session_id = data.get("id") + if not session_id: + logger.error("鏇存柊鍥炶瘽璁板綍澶辫触锛亄}".format(data)) + if is_new: + try: + data["tenant_id"] = user_id + conversation_model = ConversationModel(**data) + db.add(conversation_model) + db.commit() + except Exception as e: + logger.error(e) + db.rollback() + else: + try: + data["tenant_id"] = user_id + del data["id"] + db.query(ConversationModel).filter(ConversationModel.id == session_id).update(data) + db.commit() + except Exception as e: + logger.error(e) + db.rollback() + + +async def get_session_history(db, user_id, dialog_id): + session_list = db.query(ConversationModel).filter(ConversationModel.tenant_id.__eq__(user_id), + ConversationModel.dialog_id.__eq__(dialog_id)).order_by( + ConversationModel.update_time.desc()).all() + return [i.to_json() for i in session_list] diff --git a/app/service/group.py b/app/service/group.py index 361062d..af95c76 100644 --- a/app/service/group.py +++ b/app/service/group.py @@ -95,6 +95,7 @@ for user1 in new_users: for user2 in new_users: if user1 != user2: + print(user1, user2) await ragflow_service.add_user_tenant(token, user_dict[user1]["rg_id"], user_dict[user2]["email"], user_dict[user2]["rg_id"]) diff --git a/app/service/ragflow.py b/app/service/ragflow.py index 6d6012d..769631d 100644 --- a/app/service/ragflow.py +++ b/app/service/ragflow.py @@ -146,12 +146,15 @@ } ] if data else [] - async def get_session_history(self, token: str, chat_id: str) -> list: + async def get_session_history(self, token: str, chat_id: str, is_all: int=0): url = f"{self.base_url}/v1/conversation/get?conversation_id={chat_id}" headers = {"Authorization": token} async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers) data = self._handle_response(response) + # print("----------------data----------------------:", data) + if is_all: + return data return data.get("message", []) async def upload_and_parse(self, token: str, chat_id: str, filename: str, file: bytes) -> str: @@ -172,6 +175,7 @@ data = {"email": email, "user_id": user_id} async with httpx.AsyncClient(timeout=60) as client: response = await client.post(url, headers=headers, json=data) + print(response) if response.status_code != 200: raise Exception(f"Ragflow add user to tenant failed: {response.text}") -- Gitblit v1.8.0