接收excel_talk返回的消息,并与缓冲区中的数据拼接,然后解析JSON。如果解析失败,存入缓冲区继续累积;如果解析成功,给前端返回并清空缓冲区。
&& 修复 保存会话消息历史不成功的bug
4个文件已修改
160 ■■■■■ 已修改文件
app/api/chat.py 55 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/basic.py 47 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/session.py 42 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py
@@ -247,43 +247,30 @@
                        await websocket.send_json(result)
                else:
                    async for result in service.excel_talk(question, chat_id):
                    async for data in service.excel_talk(question, chat_id):
                        output = data.get("output", "")
                        excel_name = data.get("excel_name", "")
                        image_name = data.get("image_name", "")
                        def build_file_url(name, file_type):
                            if not name:
                                return None
                            return (f"/api/files/download/?agent_id={agent_id}&file_id={name}"
                                    f"&file_type={file_type}")
                        excel_url = build_file_url(excel_name, 'excel')
                        image_url = build_file_url(image_name, 'image')
                        try:
                            if result[:5] == "data:":
                                # 如果是,则截取掉前5个字符,并去除首尾空白符
                                text = result[5:].strip()
                            else:
                                # 否则,保持原样
                                text = result
                            try:
                                data = json.loads(text)
                                output = data.get("output", "")
                                excel_name = data.get("excel_name", "")
                                image_name = data.get("image_name", "")
                                excel_url = None
                                image_url = None
                                if excel_name:
                                    excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel"
                                if image_name:
                                    image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image"
                                result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url}
                                try:
                                    SessionService(db).update_session(chat_id,
                                                                      message={"role": "assistant", "content": result})
                                except Exception as e:
                                    logger.error(e)
                                await websocket.send_json(result | data)
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                # print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of basic agent: {e2}")
                            SessionService(db).update_session(chat_id, message={"content": output, "role": "assistant"})
                        except Exception as e:
                            logger.error(f"Unexpected error when update_session: {e}")
                        # 发送结果给客户端
                        data["type"] = "message"
                        data["message"] = output
                        data["excel_url"] = excel_url
                        data["image_url"] = image_url
                        await websocket.send_json(data)
        except Exception as e:
            logger.error("----------------------------------------------fffffff")
            logger.error(e)
            print(e)
            await websocket.send_json({"message": "出现错误!", "type": "error"})
        finally:
            await websocket.close()
app/models/session_model.py
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from enum import IntEnum
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT
from app.models import AgentType, current_time
from app.models.base_model import Base
@@ -16,7 +16,7 @@
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(JSON)  # 说明
    message = Column(TEXT)  # 说明
    # to_dict 方法
    def to_dict(self):
@@ -37,5 +37,15 @@
            'agent_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': self.message
            'message': json.loads(self.message)
        }
    def add_message(self, message: dict):
        if self.message is None:
            self.message = '[]'
        try:
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            return
        self.message = json.dumps(msg)
app/service/basic.py
@@ -1,3 +1,5 @@
import json
import httpx
from Log import logger
@@ -62,20 +64,30 @@
        params = {'chat_id': chat_id}
        data = {"query": question}
        headers = {'Content-Type': 'application/json'}
        buffer = bytearray()
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream("POST", url, params=params, json=data, headers=headers) as response:
                if response.status_code == 200:
                    try:
                        async for answer in response.aiter_text():
                            print(f"response of excel_talk chat: {answer}")
                            yield answer
                        async for chunk in response.aiter_bytes():
                            json_data = process_buffer(chunk, buffer)
                            if json_data:
                                yield json_data
                                buffer.clear()
                    except GeneratorExit as e:
                        print(e)
                        return
                        yield {"message": "内部错误", "type": "close"}
                    finally:
                        # 在所有数据接收完毕后记录日志
                        logger.info("All messages received and processed - over")
                        yield {"message": "", "type": "close"}
                else:
                    yield f"Error: {response.status_code}"
    async def questions_talk(self,question, chat_id: str):
    async def questions_talk(self, question, chat_id: str):
        logger.error("---------------questions_talk--------------------------")
        url = f"{self.base_url}/questions/talk"
        params = {'chat_id': chat_id}
@@ -91,4 +103,27 @@
    async def questions_talk_word_download(self, file_id: str):
        """Fetch a file from the backend's ``/questions/download/word`` endpoint.

        Builds the URL from ``self.base_url`` and delegates to
        ``self.download_from_url``, passing ``file_id`` as the ``excel_name``
        query parameter. Returns whatever ``download_from_url`` returns.
        """
        url = f"{self.base_url}/questions/download/word"
        # NOTE(review): the query key is 'excel_name' although the endpoint path
        # says "word" - presumably what the backend expects; confirm.
        # NOTE(review): the return statement appears twice in this view (it looks
        # like the old/new rendering of one diff line); only the first can execute.
        return await self.download_from_url(url, params={'excel_name': file_id})
        return await self.download_from_url(url, params={'excel_name': file_id})
def process_buffer(data, buffer):
    """Accumulate a streamed chunk and try to decode the buffer as one JSON value.

    Args:
        data: The newly received chunk (``bytes``). A leading ``b'data:'``
            prefix is removed before the chunk is stored.
        buffer: A ``bytearray`` holding the bytes received so far; it is
            extended in place. This function never clears it - the caller does
            so once a value has been returned.

    Returns:
        The decoded JSON value when the accumulated bytes form a complete JSON
        document, otherwise ``None`` (more chunks are needed).
    """
    if data.startswith(b'data:'):
        # Drop the 'data:' header.
        data = data[5:]

    # Append the chunk verbatim. Stripping whitespace per chunk would corrupt a
    # JSON string that happens to be split at a space; json.loads already
    # ignores whitespace around the document.
    buffer.extend(data)

    try:
        return json.loads(buffer)
    except ValueError:
        # ValueError covers json.JSONDecodeError (incomplete document) and
        # UnicodeDecodeError (chunk boundary inside a multi-byte UTF-8
        # character); in both cases keep buffering.
        return None
app/service/session.py
@@ -1,7 +1,9 @@
from typing import Type
from sqlalchemy.orm import Session
from Log import logger
from app.models import AgentType
from app.models import AgentType, current_time
from app.models.session_model import SessionModel
@@ -9,7 +11,8 @@
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[
                                                                                                                    SessionModel] | SessionModel:
        """
        创建一个新的会话记录。
@@ -24,13 +27,15 @@
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            message=existing_session.message
            message.append({"role": "user", "content": name})
            self.update_session(session_id, message=message)
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session
        new_session = SessionModel(
            id=session_id,
            name=name[0:200],
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id = user_id,
@@ -41,7 +46,7 @@
        self.db.refresh(new_session)
        return new_session
    def get_session_by_id(self, session_id: str) -> SessionModel:
    def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
        """
        根据会话ID获取会话记录。
@@ -51,9 +56,12 @@
        返回:
            SessionModel: 查找到的会话模型实例,如果未找到则返回None。
        """
        return self.db.query(SessionModel).filter_by(id=session_id).first()
        session = self.db.query(SessionModel).filter_by(id=session_id).first()
        if session.message is None:
            session.message = '[]'
        return session
    def update_session(self, session_id: str, **kwargs) -> SessionModel:
    def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
        """
        更新会话记录。
@@ -65,15 +73,21 @@
            SessionModel: 更新后的会话模型实例。
        """
        logger.error("更新数据---------------------------")
        session = self.db.query(SessionModel).filter_by(id=session_id).first()
        self.db.commit()
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                message = session.message
                message.append(kwargs["message"])
                session = message
                session.add_message(kwargs["message"])
            # 替换其他字段
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                self.db.rollback()
        return session
    def delete_session(self, session_id: str) -> None: