From 66f5df3ec8004e91ec2f440d69755caa52ac33bd Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期五, 22 十一月 2024 01:22:19 +0800
Subject: [PATCH] 接收excel_talk返回的消息和缓冲区的数据拼接,然后解析JSON。如果解析失败,存入缓存区继续累积,如何解析成功,给前端返回并清空缓冲区。 && 修复 保存会话消息历史不成功的bug

---
 app/service/session.py |   42 ++++++++++++++++++++++++++++--------------
 1 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/app/service/session.py b/app/service/session.py
index dd60c26..79c1d3d 100644
--- a/app/service/session.py
+++ b/app/service/session.py
@@ -1,7 +1,9 @@
+from typing import Type
+
 from sqlalchemy.orm import Session
 
 from Log import logger
-from app.models import AgentType
+from app.models import AgentType, current_time
 from app.models.session_model import SessionModel
 
 
@@ -9,7 +11,8 @@
     def __init__(self, db: Session):
         self.db = db
 
-    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
+    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[
+                                                                                                                    SessionModel] | SessionModel:
         """
         鍒涘缓涓�涓柊鐨勪細璇濊褰曘��
 
@@ -24,13 +27,15 @@
         """
         existing_session = self.get_session_by_id(session_id)
         if existing_session:
-            message=existing_session.message
-            message.append({"role": "user", "content": name})
-            self.update_session(session_id, message=message)
+            existing_session.add_message({"role": "user", "content": name})
+            existing_session.update_date = current_time()
+            self.db.commit()
+            self.db.refresh(existing_session)
+            return existing_session
 
         new_session = SessionModel(
             id=session_id,
-            name=name[0:200],
+            name=name[0:50],
             agent_id=agent_id,
             agent_type=agent_type,
             tenant_id = user_id,
@@ -41,7 +46,7 @@
         self.db.refresh(new_session)
         return new_session
 
-    def get_session_by_id(self, session_id: str) -> SessionModel:
+    def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
         """
         鏍规嵁浼氳瘽ID鑾峰彇浼氳瘽璁板綍銆�
 
@@ -51,9 +56,12 @@
         杩斿洖:
             SessionModel: 鏌ユ壘鍒扮殑浼氳瘽妯″瀷瀹炰緥锛屽鏋滄湭鎵惧埌鍒欒繑鍥濶one銆�
         """
-        return self.db.query(SessionModel).filter_by(id=session_id).first()
+        session = self.db.query(SessionModel).filter_by(id=session_id).first()
+        if session.message is None:
+            session.message = '[]'
+        return session
 
-    def update_session(self, session_id: str, **kwargs) -> SessionModel:
+    def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
         """
         鏇存柊浼氳瘽璁板綍銆�
 
@@ -65,15 +73,21 @@
             SessionModel: 鏇存柊鍚庣殑浼氳瘽妯″瀷瀹炰緥銆�
         """
         logger.error("鏇存柊鏁版嵁---------------------------")
-        session = self.db.query(SessionModel).filter_by(id=session_id).first()
+        self.db.commit()
+        session = self.get_session_by_id(session_id)
         if session:
             if "message" in kwargs:
-
-                message = session.message
-                message.append(kwargs["message"])
-                session = message
+                session.add_message(kwargs["message"])
+            # 鏇挎崲鍏朵粬瀛楁
+            for key, value in kwargs.items():
+                if key != "message":
+                    setattr(session, key, value)
+            session.update_date = current_time()
+            try:
                 self.db.commit()
                 self.db.refresh(session)
+            except Exception as e:
+                self.db.rollback()
         return session
 
     def delete_session(self, session_id: str) -> None:

--
Gitblit v1.8.0