From 66f5df3ec8004e91ec2f440d69755caa52ac33bd Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期五, 22 十一月 2024 01:22:19 +0800
Subject: [PATCH] 接收excel_talk返回的消息和缓冲区的数据拼接,然后解析JSON。如果解析失败,存入缓存区继续累积,如何解析成功,给前端返回并清空缓冲区。 && 修复 保存会话消息历史不成功的bug

---
 app/models/session_model.py |   16 ++++-
 app/api/chat.py             |   55 +++++++-----------
 app/service/basic.py        |   47 +++++++++++++--
 app/service/session.py      |   42 +++++++++----
 4 files changed, 103 insertions(+), 57 deletions(-)

diff --git a/app/api/chat.py b/app/api/chat.py
index 9784eb9..1d12e91 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -247,43 +247,30 @@
                         await websocket.send_json(result)
 
                 else:
-                    async for result in service.excel_talk(question, chat_id):
+                    async for data in service.excel_talk(question, chat_id):
+                        output = data.get("output", "")
+                        excel_name = data.get("excel_name", "")
+                        image_name = data.get("image_name", "")
+
+                        def build_file_url(name, file_type):
+                            if not name:
+                                return None
+                            return (f"/api/files/download/?agent_id={agent_id}&file_id={name}"
+                                    f"&file_type={file_type}")
+                        excel_url = build_file_url(excel_name, 'excel')
+                        image_url = build_file_url(image_name, 'image')
                         try:
-                            if result[:5] == "data:":
-                                # 濡傛灉鏄紝鍒欐埅鍙栨帀鍓�5涓瓧绗︼紝骞跺幓闄ら灏剧┖鐧界
-                                text = result[5:].strip()
-                            else:
-                                # 鍚﹀垯锛屼繚鎸佸師鏍�
-                                text = result
-                            try:
-                                data = json.loads(text)
-                                output = data.get("output", "")
-                                excel_name = data.get("excel_name", "")
-                                image_name = data.get("image_name", "")
-                                excel_url = None
-                                image_url = None
-                                if excel_name:
-                                    excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel"
-                                if image_name:
-                                    image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image"
-                                result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url}
-                                try:
-                                    SessionService(db).update_session(chat_id,
-                                                                      message={"role": "assistant", "content": result})
-                                except Exception as e:
-                                    logger.error(e)
-                                await websocket.send_json(result | data)
-                            except json.JSONDecodeError as e:
-                                print(f"Error decoding JSON: {e}")
-                                # print(f"Response text: {text}")
-                        except Exception as e2:
-                            result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"}
-                            await websocket.send_json(result)
-                            print(f"Error process message of basic agent: {e2}")
+                            SessionService(db).update_session(chat_id, message={"content": output, "role": "assistant"})
+                        except Exception as e:
+                            logger.error(f"Unexpected error when update_session: {e}")
+                        # 鍙戦�佺粨鏋滅粰瀹㈡埛绔�
+                        data["type"] = "message"
+                        data["message"] = output
+                        data["excel_url"] = excel_url
+                        data["image_url"] = image_url
+                        await websocket.send_json(data)
         except Exception as e:
-            logger.error("----------------------------------------------fffffff")
             logger.error(e)
-            print(e)
             await websocket.send_json({"message": "鍑虹幇閿欒锛�", "type": "error"})
         finally:
             await websocket.close()
diff --git a/app/models/session_model.py b/app/models/session_model.py
index 6aed237..9536471 100644
--- a/app/models/session_model.py
+++ b/app/models/session_model.py
@@ -1,7 +1,7 @@
 import json
 from datetime import datetime
 from enum import IntEnum
-from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON
+from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT
 
 from app.models import AgentType, current_time
 from app.models.base_model import Base
@@ -16,7 +16,7 @@
     create_date = Column(DateTime, default=current_time)  # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿
     update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊
     tenant_id = Column(Integer)  # 鍒涘缓浜�
-    message = Column(JSON)  # 璇存槑
+    message = Column(TEXT)  # 璇存槑
 
     # to_dict 鏂规硶
     def to_dict(self):
@@ -37,5 +37,15 @@
             'agent_id': self.agent_id,
             'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
             'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
-            'message': self.message
+            'message': json.loads(self.message)
         }
+
+    def add_message(self, message: dict):
+        if self.message is None:
+            self.message = '[]'
+        try:
+            msg = json.loads(self.message)
+            msg.append(message)
+        except Exception as e:
+            return
+        self.message = json.dumps(msg)
diff --git a/app/service/basic.py b/app/service/basic.py
index 33e5f86..9c22206 100644
--- a/app/service/basic.py
+++ b/app/service/basic.py
@@ -1,3 +1,5 @@
+import json
+
 import httpx
 
 from Log import logger
@@ -62,20 +64,30 @@
         params = {'chat_id': chat_id}
         data = {"query": question}
         headers = {'Content-Type': 'application/json'}
+        buffer = bytearray()
         async with httpx.AsyncClient(timeout=300.0) as client:
             async with client.stream("POST", url, params=params, json=data, headers=headers) as response:
                 if response.status_code == 200:
                     try:
-                        async for answer in response.aiter_text():
-                            print(f"response of excel_talk chat: {answer}")
-                            yield answer
+                        async for chunk in response.aiter_bytes():
+                            json_data = process_buffer(chunk, buffer)
+                            if json_data:
+                                yield json_data
+                                buffer.clear()
                     except GeneratorExit as e:
                         print(e)
-                        return
+                        yield {"message": "鍐呴儴閿欒", "type": "close"}
+                    finally:
+                        # 鍦ㄦ墍鏈夋暟鎹帴鏀跺畬姣曞悗璁板綍鏃ュ織
+                        logger.info("All messages received and processed - over")
+                        yield {"message": "", "type": "close"}
+
                 else:
                     yield f"Error: {response.status_code}"
 
-    async def questions_talk(self,question, chat_id: str):
+
+
+    async def questions_talk(self, question, chat_id: str):
         logger.error("---------------questions_talk--------------------------")
         url = f"{self.base_url}/questions/talk"
         params = {'chat_id': chat_id}
@@ -91,4 +103,27 @@
 
     async def questions_talk_word_download(self, file_id: str):
         url = f"{self.base_url}/questions/download/word"
-        return await self.download_from_url(url, params={'excel_name': file_id})
\ No newline at end of file
+        return await self.download_from_url(url, params={'excel_name': file_id})
+
+
+def process_buffer(data, buffer):
+    def try_parse_json(data1):
+        try:
+            return True, json.loads(data1)
+        except json.JSONDecodeError:
+            return False, None
+
+    if data.startswith(b'data:'):
+        # 鍒犻櫎 'data:' 澶�
+        data = data[5:].strip()
+    else:
+        pass
+
+    # 鐩存帴鎷兼帴鍒扮紦鍐插尯灏濊瘯瑙f瀽JSON
+    buffer.extend(data.strip())
+    success, parsed_data = try_parse_json(buffer)
+    if success:
+        return parsed_data
+    else:
+        # 瑙f瀽澶辫触锛岀户缁嫾鎺�
+        return None
diff --git a/app/service/session.py b/app/service/session.py
index dd60c26..79c1d3d 100644
--- a/app/service/session.py
+++ b/app/service/session.py
@@ -1,7 +1,9 @@
+from typing import Type
+
 from sqlalchemy.orm import Session
 
 from Log import logger
-from app.models import AgentType
+from app.models import AgentType, current_time
 from app.models.session_model import SessionModel
 
 
@@ -9,7 +11,8 @@
     def __init__(self, db: Session):
         self.db = db
 
-    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
+    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[
+                                                                                                                    SessionModel] | SessionModel:
         """
         鍒涘缓涓�涓柊鐨勪細璇濊褰曘��
 
@@ -24,13 +27,15 @@
         """
         existing_session = self.get_session_by_id(session_id)
         if existing_session:
-            message=existing_session.message
-            message.append({"role": "user", "content": name})
-            self.update_session(session_id, message=message)
+            existing_session.add_message({"role": "user", "content": name})
+            existing_session.update_date = current_time()
+            self.db.commit()
+            self.db.refresh(existing_session)
+            return existing_session
 
         new_session = SessionModel(
             id=session_id,
-            name=name[0:200],
+            name=name[0:50],
             agent_id=agent_id,
             agent_type=agent_type,
             tenant_id = user_id,
@@ -41,7 +46,7 @@
         self.db.refresh(new_session)
         return new_session
 
-    def get_session_by_id(self, session_id: str) -> SessionModel:
+    def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
         """
         鏍规嵁浼氳瘽ID鑾峰彇浼氳瘽璁板綍銆�
 
@@ -51,9 +56,12 @@
         杩斿洖:
             SessionModel: 鏌ユ壘鍒扮殑浼氳瘽妯″瀷瀹炰緥锛屽鏋滄湭鎵惧埌鍒欒繑鍥濶one銆�
         """
-        return self.db.query(SessionModel).filter_by(id=session_id).first()
+        session = self.db.query(SessionModel).filter_by(id=session_id).first()
+        if session.message is None:
+            session.message = '[]'
+        return session
 
-    def update_session(self, session_id: str, **kwargs) -> SessionModel:
+    def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
         """
         鏇存柊浼氳瘽璁板綍銆�
 
@@ -65,15 +73,21 @@
             SessionModel: 鏇存柊鍚庣殑浼氳瘽妯″瀷瀹炰緥銆�
         """
         logger.error("鏇存柊鏁版嵁---------------------------")
-        session = self.db.query(SessionModel).filter_by(id=session_id).first()
+        self.db.commit()
+        session = self.get_session_by_id(session_id)
         if session:
             if "message" in kwargs:
-
-                message = session.message
-                message.append(kwargs["message"])
-                session = message
+                session.add_message(kwargs["message"])
+            # 鏇挎崲鍏朵粬瀛楁
+            for key, value in kwargs.items():
+                if key != "message":
+                    setattr(session, key, value)
+            session.update_date = current_time()
+            try:
                 self.db.commit()
                 self.db.refresh(session)
+            except Exception as e:
+                self.db.rollback()
         return session
 
     def delete_session(self, session_id: str) -> None:

--
Gitblit v1.8.0