From 5ef590b70cc8e2de16083af2ee2d977daae5587c Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期二, 19 十一月 2024 16:34:29 +0800
Subject: [PATCH] 会话缓存本地数据库

---
 app/service/ragflow.py     |    6 ++
 app/service/dialog.py      |   36 +++++++++++++++++-
 app/service/group.py       |    1 
 app/api/chat.py            |   10 ++++-
 app/models/dialog_model.py |   27 +++++++++++++
 app/api/agent.py           |    3 +
 6 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/app/api/agent.py b/app/api/agent.py
index 8e7f1c3..3178144 100644
--- a/app/api/agent.py
+++ b/app/api/agent.py
@@ -12,6 +12,7 @@
 from app.models.base_model import get_db
 from app.models.user_model import UserModel
 from app.service.bisheng import BishengService
+from app.service.dialog import get_session_history
 from app.service.ragflow import RagflowService
 from app.service.service_token import get_ragflow_token, get_bisheng_token
 
@@ -41,6 +42,8 @@
         try:
             token = get_ragflow_token(db, current_user.id)
             result = await ragflow_service.get_chat_sessions(token, agent_id)
+            if not result:
+                result = await get_session_history(db, current_user.id, agent_id)
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
         return ResponseList(code=200, msg="", data=result)
diff --git a/app/api/chat.py b/app/api/chat.py
index fe86fb5..ea1be48 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -10,6 +10,7 @@
 from app.models.agent_model import AgentModel, AgentType
 from app.models.base_model import get_db
 from app.models.user_model import UserModel
+from app.service.dialog import update_session_history
 from app.service.ragflow import RagflowService
 from app.service.service_token import get_bisheng_token, get_ragflow_token
 
@@ -44,6 +45,7 @@
         try:
             async def forward_to_ragflow():
                 while True:
+                    is_new = False
                     message = await websocket.receive_json()
                     print(f"Received from client {chat_id}: {message}")
                     chat_history = message.get('chatHistory', [])
@@ -51,9 +53,10 @@
                     if len(chat_history) == 0:
                         chat_history = await ragflow_service.get_session_history(token, chat_id)
                         if len(chat_history) == 0:
+                            is_new = True
                             chat_history = await ragflow_service.set_session(token, agent_id,
                                                                              message, chat_id, True)
-                            print("chat_history------------------------", chat_history)
+                            # print("chat_history------------------------", chat_history)
                             if len(chat_history) == 0:
                                 result = {"message": "鍐呴儴閿欒锛氬垱寤轰細璇濆け璐�", "type": "close"}
                                 await websocket.send_json(result)
@@ -91,11 +94,14 @@
                                 complete_response = ""
                             except json.JSONDecodeError as e:
                                 print(f"Error decoding JSON: {e}")
-                                print(f"Response text: {text}")
+                                # print(f"Response text: {text}")
                         except Exception as e2:
                             result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"}
                             await websocket.send_json(result)
                             print(f"Error process message of ragflow: {e2}")
+                    dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1)
+                    await update_session_history(db, dialog_chat_history, current_user.id, is_new)
+
             # 鍚姩浠诲姟澶勭悊瀹㈡埛绔秷鎭�
             tasks = [
                 asyncio.create_task(forward_to_ragflow())
diff --git a/app/models/dialog_model.py b/app/models/dialog_model.py
index e46a950..7a1848a 100644
--- a/app/models/dialog_model.py
+++ b/app/models/dialog_model.py
@@ -1,6 +1,6 @@
 from datetime import datetime
 
-from sqlalchemy import Column, Integer, String, Table, ForeignKey, DateTime, BigInteger, Text, Float, Boolean
+from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, JSON
 from sqlalchemy.orm import relationship, backref
 
 from app.models.base_model import Base
@@ -33,4 +33,29 @@
             'description': self.description,
             'icon': self.icon,
             'status': self.status
+        }
+
+
+class ConversationModel(Base):
+    __tablename__ = 'conversation'
+    id = Column(String(32), primary_key=True)  #  id
+    create_date = Column(DateTime)             # 鍒涘缓鏃堕棿
+    create_time = Column(BigInteger)
+    update_date = Column(DateTime)             # 鏇存柊鏃堕棿
+    update_time = Column(BigInteger)
+    tenant_id = Column(Integer)              # 鍒涘缓浜�
+    dialog_id = Column(String(32))
+    name = Column(String(255))                 # 鍚嶇О
+    message = Column(JSON)                 # 璇存槑
+    reference = Column(JSON)                         # 鍥炬爣
+
+    def get_id(self):
+        return str(self.id)
+
+
+    def to_json(self):
+        return {
+            'id': self.id,
+            'updated_time': self.update_time,
+            'name': self.name,
         }
\ No newline at end of file
diff --git a/app/service/dialog.py b/app/service/dialog.py
index 4fa7a63..bcca339 100644
--- a/app/service/dialog.py
+++ b/app/service/dialog.py
@@ -1,4 +1,5 @@
-from app.models import KnowledgeModel, GroupModel, DialogModel
+from app.api.user import user_list
+from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel
 from app.models.user_model import UserModel
 from Log import logger
 
@@ -19,4 +20,35 @@
                     dialog_list.append(k)
                     kld_set.add(k.id)
 
-    return {"rows": [kld.to_json() for kld in dialog_list]}
\ No newline at end of file
+    return {"rows": [kld.to_json() for kld in dialog_list]}
+
+
+async def update_session_history(db, data: dict, user_id, is_new):
+    session_id = data.get("id")
+    if not session_id:
+        logger.error("鏇存柊鍥炶瘽璁板綍澶辫触锛亄}".format(data))
+    if is_new:
+        try:
+            data["tenant_id"] = user_id
+            conversation_model = ConversationModel(**data)
+            db.add(conversation_model)
+            db.commit()
+        except Exception as e:
+            logger.error(e)
+            db.rollback()
+    else:
+        try:
+            data["tenant_id"] = user_id
+            del data["id"]
+            db.query(ConversationModel).filter(ConversationModel.id == session_id).update(data)
+            db.commit()
+        except Exception as e:
+            logger.error(e)
+            db.rollback()
+
+
+async def get_session_history(db, user_id, dialog_id):
+    session_list = db.query(ConversationModel).filter(ConversationModel.tenant_id.__eq__(user_id),
+                                                      ConversationModel.dialog_id.__eq__(dialog_id)).order_by(
+        ConversationModel.update_time.desc()).all()
+    return [i.to_json() for i in session_list]
diff --git a/app/service/group.py b/app/service/group.py
index 361062d..af95c76 100644
--- a/app/service/group.py
+++ b/app/service/group.py
@@ -95,6 +95,7 @@
             for user1 in new_users:
                 for user2 in new_users:
                     if user1 != user2:
+                        print(user1, user2)
                         await ragflow_service.add_user_tenant(token, user_dict[user1]["rg_id"],
                                                               user_dict[user2]["email"],
                                                               user_dict[user2]["rg_id"])
diff --git a/app/service/ragflow.py b/app/service/ragflow.py
index 6d6012d..769631d 100644
--- a/app/service/ragflow.py
+++ b/app/service/ragflow.py
@@ -146,12 +146,15 @@
                 }
             ] if data else []
 
-    async def get_session_history(self, token: str, chat_id: str) -> list:
+    async def get_session_history(self, token: str, chat_id: str, is_all: int=0):
         url = f"{self.base_url}/v1/conversation/get?conversation_id={chat_id}"
         headers = {"Authorization": token}
         async with httpx.AsyncClient() as client:
             response = await client.get(url, headers=headers)
             data = self._handle_response(response)
+            # print("----------------data----------------------:", data)
+            if is_all:
+                return data
             return data.get("message", [])
 
     async def upload_and_parse(self, token: str, chat_id: str, filename: str, file: bytes) -> str:
@@ -172,6 +175,7 @@
         data = {"email": email, "user_id": user_id}
         async with httpx.AsyncClient(timeout=60) as client:
             response = await client.post(url, headers=headers, json=data)
+            print(response)
             if response.status_code != 200:
                 raise Exception(f"Ragflow add user to tenant failed: {response.text}")
 

--
Gitblit v1.8.0