将excel_talk返回的消息与缓冲区中已有的数据拼接,然后解析JSON。如果解析失败,则将数据保留在缓冲区中继续累积;如果解析成功,则将结果返回给前端并清空缓冲区。
&& 修复 保存会话消息历史不成功的bug
| | |
| | | await websocket.send_json(result) |
| | | |
| | | else: |
| | | async for result in service.excel_talk(question, chat_id): |
| | | async for data in service.excel_talk(question, chat_id): |
| | | output = data.get("output", "") |
| | | excel_name = data.get("excel_name", "") |
| | | image_name = data.get("image_name", "") |
| | | |
| | | def build_file_url(name, file_type): |
| | | if not name: |
| | | return None |
| | | return (f"/api/files/download/?agent_id={agent_id}&file_id={name}" |
| | | f"&file_type={file_type}") |
| | | excel_url = build_file_url(excel_name, 'excel') |
| | | image_url = build_file_url(image_name, 'image') |
| | | try: |
| | | if result[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | text = result[5:].strip() |
| | | else: |
| | | # 否则,保持原样 |
| | | text = result |
| | | try: |
| | | data = json.loads(text) |
| | | output = data.get("output", "") |
| | | excel_name = data.get("excel_name", "") |
| | | image_name = data.get("image_name", "") |
| | | excel_url = None |
| | | image_url = None |
| | | if excel_name: |
| | | excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel" |
| | | if image_name: |
| | | image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image" |
| | | result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url} |
| | | try: |
| | | SessionService(db).update_session(chat_id, |
| | | message={"role": "assistant", "content": result}) |
| | | except Exception as e: |
| | | logger.error(e) |
| | | await websocket.send_json(result | data) |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | # print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of basic agent: {e2}") |
| | | SessionService(db).update_session(chat_id, message={"content": output, "role": "assistant"}) |
| | | except Exception as e: |
| | | logger.error(f"Unexpected error when update_session: {e}") |
| | | # 发送结果给客户端 |
| | | data["type"] = "message" |
| | | data["message"] = output |
| | | data["excel_url"] = excel_url |
| | | data["image_url"] = image_url |
| | | await websocket.send_json(data) |
| | | except Exception as e: |
| | | logger.error("----------------------------------------------fffffff") |
| | | logger.error(e) |
| | | print(e) |
| | | await websocket.send_json({"message": "出现错误!", "type": "error"}) |
| | | finally: |
| | | await websocket.close() |
| | |
| | | import json |
| | | from datetime import datetime |
| | | from enum import IntEnum |
| | | from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON |
| | | from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT |
| | | |
| | | from app.models import AgentType, current_time |
| | | from app.models.base_model import Base |
| | |
| | | create_date = Column(DateTime, default=current_time) # 创建时间,默认值为当前时区时间 |
| | | update_date = Column(DateTime, default=current_time, onupdate=current_time) # 更新时间,默认值为当前时区时间,更新时自动更新 |
| | | tenant_id = Column(Integer) # 创建人 |
| | | message = Column(JSON) # 说明 |
| | | message = Column(TEXT) # 说明 |
| | | |
| | | # to_dict 方法 |
| | | def to_dict(self): |
| | |
| | | 'agent_id': self.agent_id, |
| | | 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"), |
| | | 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"), |
| | | 'message': self.message |
| | | 'message': json.loads(self.message) |
| | | } |
| | | |
def add_message(self, message: dict):
    """Append one chat message to the history stored in ``self.message``.

    The history is kept as JSON text encoding a list of message dicts: it is
    decoded with ``json.loads`` and written back with ``json.dumps``. A missing
    history (``None``) is initialised to an empty list first. A history that is
    already a Python list (e.g. assigned directly and not yet serialised) is
    accepted as-is instead of being rejected.

    Args:
        message: JSON-serialisable message to append, e.g.
            ``{"role": "user", "content": "..."}``.

    Raises:
        TypeError: if ``message`` is not JSON-serialisable (from ``json.dumps``).

    If the stored history cannot be decoded into a list, the message is not
    appended and the stored value is left untouched (so existing data is not
    overwritten), but a warning is logged instead of failing silently.
    """
    import logging

    if self.message is None:
        self.message = '[]'

    if isinstance(self.message, list):
        # Copy so the caller's list object is not mutated behind its back.
        history = list(self.message)
    else:
        try:
            history = json.loads(self.message)
        except (TypeError, ValueError) as exc:
            logging.getLogger(__name__).warning(
                "Cannot decode stored session history, message not saved: %s", exc)
            return
        if not isinstance(history, list):
            logging.getLogger(__name__).warning(
                "Stored session history is a %s, not a list; message not saved",
                type(history).__name__)
            return

    history.append(message)
    self.message = json.dumps(history)
| | |
| | | import json |
| | | |
| | | import httpx |
| | | |
| | | from Log import logger |
| | |
| | | params = {'chat_id': chat_id} |
| | | data = {"query": question} |
| | | headers = {'Content-Type': 'application/json'} |
| | | buffer = bytearray() |
| | | async with httpx.AsyncClient(timeout=300.0) as client: |
| | | async with client.stream("POST", url, params=params, json=data, headers=headers) as response: |
| | | if response.status_code == 200: |
| | | try: |
| | | async for answer in response.aiter_text(): |
| | | print(f"response of excel_talk chat: {answer}") |
| | | yield answer |
| | | async for chunk in response.aiter_bytes(): |
| | | json_data = process_buffer(chunk, buffer) |
| | | if json_data: |
| | | yield json_data |
| | | buffer.clear() |
| | | except GeneratorExit as e: |
| | | print(e) |
| | | return |
| | | yield {"message": "内部错误", "type": "close"} |
| | | finally: |
| | | # 在所有数据接收完毕后记录日志 |
| | | logger.info("All messages received and processed - over") |
| | | yield {"message": "", "type": "close"} |
| | | |
| | | else: |
| | | yield f"Error: {response.status_code}" |
| | | |
| | | async def questions_talk(self,question, chat_id: str): |
| | | |
| | | |
| | | async def questions_talk(self, question, chat_id: str): |
| | | logger.error("---------------questions_talk--------------------------") |
| | | url = f"{self.base_url}/questions/talk" |
| | | params = {'chat_id': chat_id} |
| | |
| | | |
async def questions_talk_word_download(self, file_id: str):
    """Fetch a Word file from the questions backend (``/questions/download/word``).

    Args:
        file_id: Identifier of the file to fetch; it is forwarded to the
            backend as the ``excel_name`` query parameter.

    Returns:
        Whatever ``self.download_from_url`` returns for that URL.
    """
    url = f"{self.base_url}/questions/download/word"
    # NOTE(review): the query key is "excel_name" even though this endpoint
    # serves a Word file — presumably the backend reuses that name; confirm.
    return await self.download_from_url(url, params={'excel_name': file_id})
| | | |
| | | |
def process_buffer(data, buffer):
    """Append one streamed chunk to ``buffer`` and try to decode the accumulated JSON.

    Chunks come from ``response.aiter_bytes()``, so their boundaries are
    arbitrary. One JSON document may be split across several chunks, and a
    split may fall inside a string value or inside a multi-byte UTF-8
    character.

    Args:
        data: Raw ``bytes`` chunk just received. A leading ``data:`` marker is
            removed before buffering. At the start of a new message, whitespace
            before the marker is tolerated.
        buffer: ``bytearray`` holding the bytes accumulated so far. It is
            extended in place; clearing it after a successful parse is left to
            the caller.

    Returns:
        The decoded JSON value once ``buffer`` holds a complete JSON document,
        otherwise ``None`` (keep accumulating).

    NOTE(review): a single chunk carrying two complete ``data:`` events, or a
    stale partial document followed by a new event, still cannot be decoded.
    Handling that needs event-level framing rather than byte concatenation.
    """
    def try_parse_json(raw):
        try:
            return True, json.loads(raw)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError: the buffer currently ends in the middle of a
            # multi-byte UTF-8 character; the remaining bytes are still to come.
            return False, None

    # Only when no real data has been buffered yet is this chunk the start of a
    # message, so only then skip separator whitespace in front of "data:".
    head = data.lstrip() if not buffer.strip() else data
    if head.startswith(b'data:'):
        data = head[5:]

    # Do not strip the chunk: whitespace at a chunk boundary may belong to a
    # JSON string value. json.loads already ignores whitespace around the document.
    buffer.extend(data)
    success, parsed_data = try_parse_json(buffer)
    return parsed_data if success else None
| | |
| | | from typing import Type |
| | | |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from Log import logger |
| | | from app.models import AgentType |
| | | from app.models import AgentType, current_time |
| | | from app.models.session_model import SessionModel |
| | | |
| | | |
| | |
def __init__(self, db: Session):
    """Create the service around an existing database session.

    Args:
        db: Session object this service uses for its queries and commits.
    """
    # The session is owned by the caller; the service only keeps a reference.
    self.db = db
| | | |
| | | def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel: |
| | | def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[ |
| | | SessionModel] | SessionModel: |
| | | """ |
| | | 创建一个新的会话记录。 |
| | | |
| | |
| | | """ |
| | | existing_session = self.get_session_by_id(session_id) |
| | | if existing_session: |
| | | message=existing_session.message |
| | | message.append({"role": "user", "content": name}) |
| | | self.update_session(session_id, message=message) |
| | | existing_session.add_message({"role": "user", "content": name}) |
| | | existing_session.update_date = current_time() |
| | | self.db.commit() |
| | | self.db.refresh(existing_session) |
| | | return existing_session |
| | | |
| | | new_session = SessionModel( |
| | | id=session_id, |
| | | name=name[0:200], |
| | | name=name[0:50], |
| | | agent_id=agent_id, |
| | | agent_type=agent_type, |
| | | tenant_id = user_id, |
| | |
| | | self.db.refresh(new_session) |
| | | return new_session |
| | | |
| | | def get_session_by_id(self, session_id: str) -> SessionModel: |
| | | def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None: |
| | | """ |
| | | 根据会话ID获取会话记录。 |
| | | |
| | |
| | | 返回: |
| | | SessionModel: 查找到的会话模型实例,如果未找到则返回None。 |
| | | """ |
| | | return self.db.query(SessionModel).filter_by(id=session_id).first() |
| | | session = self.db.query(SessionModel).filter_by(id=session_id).first() |
| | | if session.message is None: |
| | | session.message = '[]' |
| | | return session |
| | | |
| | | def update_session(self, session_id: str, **kwargs) -> SessionModel: |
| | | def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None: |
| | | """ |
| | | 更新会话记录。 |
| | | |
| | |
| | | SessionModel: 更新后的会话模型实例。 |
| | | """ |
| | | logger.error("更新数据---------------------------") |
| | | session = self.db.query(SessionModel).filter_by(id=session_id).first() |
| | | self.db.commit() |
| | | session = self.get_session_by_id(session_id) |
| | | if session: |
| | | if "message" in kwargs: |
| | | |
| | | message = session.message |
| | | message.append(kwargs["message"]) |
| | | session = message |
| | | session.add_message(kwargs["message"]) |
| | | # 替换其他字段 |
| | | for key, value in kwargs.items(): |
| | | if key != "message": |
| | | setattr(session, key, value) |
| | | session.update_date = current_time() |
| | | try: |
| | | self.db.commit() |
| | | self.db.refresh(session) |
| | | except Exception as e: |
| | | self.db.rollback() |
| | | return session |
| | | |
| | | def delete_session(self, session_id: str) -> None: |