From 66f5df3ec8004e91ec2f440d69755caa52ac33bd Mon Sep 17 00:00:00 2001 From: zhangqian <zhangqian@123.com> Date: 星期五, 22 十一月 2024 01:22:19 +0800 Subject: [PATCH] 接收excel_talk返回的消息和缓冲区的数据拼接,然后解析JSON。如果解析失败,存入缓存区继续累积,如何解析成功,给前端返回并清空缓冲区。 && 修复 保存会话消息历史不成功的bug --- app/models/session_model.py | 16 ++++- app/api/chat.py | 55 +++++++----------- app/service/basic.py | 47 +++++++++++++-- app/service/session.py | 42 +++++++++---- 4 files changed, 103 insertions(+), 57 deletions(-) diff --git a/app/api/chat.py b/app/api/chat.py index 9784eb9..1d12e91 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -247,43 +247,30 @@ await websocket.send_json(result) else: - async for result in service.excel_talk(question, chat_id): + async for data in service.excel_talk(question, chat_id): + output = data.get("output", "") + excel_name = data.get("excel_name", "") + image_name = data.get("image_name", "") + + def build_file_url(name, file_type): + if not name: + return None + return (f"/api/files/download/?agent_id={agent_id}&file_id={name}" + f"&file_type={file_type}") + excel_url = build_file_url(excel_name, 'excel') + image_url = build_file_url(image_name, 'image') try: - if result[:5] == "data:": - # 濡傛灉鏄紝鍒欐埅鍙栨帀鍓�5涓瓧绗︼紝骞跺幓闄ら灏剧┖鐧界 - text = result[5:].strip() - else: - # 鍚﹀垯锛屼繚鎸佸師鏍� - text = result - try: - data = json.loads(text) - output = data.get("output", "") - excel_name = data.get("excel_name", "") - image_name = data.get("image_name", "") - excel_url = None - image_url = None - if excel_name: - excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel" - if image_name: - image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image" - result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url} - try: - SessionService(db).update_session(chat_id, - message={"role": "assistant", "content": result}) - except Exception as e: - logger.error(e) - await websocket.send_json(result | data) - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") - # print(f"Response text: {text}") - except Exception as e2: - result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"} - await websocket.send_json(result) - print(f"Error process message of basic agent: {e2}") + SessionService(db).update_session(chat_id, message={"content": output, "role": "assistant"}) + except Exception as e: + logger.error(f"Unexpected error when update_session: {e}") + # 鍙戦�佺粨鏋滅粰瀹㈡埛绔� + data["type"] = "message" + data["message"] = output + data["excel_url"] = excel_url + data["image_url"] = image_url + await websocket.send_json(data) except Exception as e: - logger.error("----------------------------------------------fffffff") logger.error(e) - print(e) await websocket.send_json({"message": "鍑虹幇閿欒锛�", "type": "error"}) finally: await websocket.close() diff --git a/app/models/session_model.py b/app/models/session_model.py index 6aed237..9536471 100644 --- a/app/models/session_model.py +++ b/app/models/session_model.py @@ -1,7 +1,7 @@ import json from datetime import datetime from enum import IntEnum -from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON +from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT from app.models import AgentType, current_time from app.models.base_model import Base @@ -16,7 +16,7 @@ create_date = Column(DateTime, default=current_time) # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿 update_date = Column(DateTime, default=current_time, onupdate=current_time) # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊 tenant_id = Column(Integer) # 鍒涘缓浜� - message = Column(JSON) # 璇存槑 + message = Column(TEXT) # 璇存槑 # to_dict 鏂规硶 def to_dict(self): @@ -37,5 +37,15 @@ 'agent_id': self.agent_id, 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"), 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"), - 'message': self.message + 'message': json.loads(self.message) } + + def add_message(self, message: dict): + if self.message is None: + self.message = '[]' + try: + msg = json.loads(self.message) + msg.append(message) + except Exception as e: + return + self.message = json.dumps(msg) diff --git a/app/service/basic.py b/app/service/basic.py index 33e5f86..9c22206 100644 --- a/app/service/basic.py +++ b/app/service/basic.py @@ -1,3 +1,5 @@ +import json + import httpx from Log import logger @@ -62,20 +64,30 @@ params = {'chat_id': chat_id} data = {"query": question} headers = {'Content-Type': 'application/json'} + buffer = bytearray() async with httpx.AsyncClient(timeout=300.0) as client: async with client.stream("POST", url, params=params, json=data, headers=headers) as response: if response.status_code == 200: try: - async for answer in response.aiter_text(): - print(f"response of excel_talk chat: {answer}") - yield answer + async for chunk in response.aiter_bytes(): + json_data = process_buffer(chunk, buffer) + if json_data: + yield json_data + buffer.clear() except GeneratorExit as e: print(e) - return + yield {"message": "鍐呴儴閿欒", "type": "close"} + finally: + # 鍦ㄦ墍鏈夋暟鎹帴鏀跺畬姣曞悗璁板綍鏃ュ織 + logger.info("All messages received and processed - over") + yield {"message": "", "type": "close"} + else: yield f"Error: {response.status_code}" - async def questions_talk(self,question, chat_id: str): + + + async def questions_talk(self, question, chat_id: str): logger.error("---------------questions_talk--------------------------") url = f"{self.base_url}/questions/talk" params = {'chat_id': chat_id} @@ -91,4 +103,27 @@ async def questions_talk_word_download(self, file_id: str): url = f"{self.base_url}/questions/download/word" - return await self.download_from_url(url, params={'excel_name': file_id}) \ No newline at end of file + return await self.download_from_url(url, params={'excel_name': file_id}) + + +def process_buffer(data, buffer): + def try_parse_json(data1): + try: + return True, json.loads(data1) + except json.JSONDecodeError: + return False, None + + if data.startswith(b'data:'): + # 鍒犻櫎 'data:' 澶� + data = data[5:].strip() + else: + pass + + # 鐩存帴鎷兼帴鍒扮紦鍐插尯灏濊瘯瑙f瀽JSON + buffer.extend(data.strip()) + success, parsed_data = try_parse_json(buffer) + if success: + return parsed_data + else: + # 瑙f瀽澶辫触锛岀户缁嫾鎺� + return None diff --git a/app/service/session.py b/app/service/session.py index dd60c26..79c1d3d 100644 --- a/app/service/session.py +++ b/app/service/session.py @@ -1,7 +1,9 @@ +from typing import Type + from sqlalchemy.orm import Session from Log import logger -from app.models import AgentType +from app.models import AgentType, current_time from app.models.session_model import SessionModel @@ -9,7 +11,8 @@ def __init__(self, db: Session): self.db = db - def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel: + def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[ + SessionModel] | SessionModel: """ 鍒涘缓涓�涓柊鐨勪細璇濊褰曘�� @@ -24,13 +27,15 @@ """ existing_session = self.get_session_by_id(session_id) if existing_session: - message=existing_session.message - message.append({"role": "user", "content": name}) - self.update_session(session_id, message=message) + existing_session.add_message({"role": "user", "content": name}) + existing_session.update_date = current_time() + self.db.commit() + self.db.refresh(existing_session) + return existing_session new_session = SessionModel( id=session_id, - name=name[0:200], + name=name[0:50], agent_id=agent_id, agent_type=agent_type, tenant_id = user_id, @@ -41,7 +46,7 @@ self.db.refresh(new_session) return new_session - def get_session_by_id(self, session_id: str) -> SessionModel: + def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None: """ 鏍规嵁浼氳瘽ID鑾峰彇浼氳瘽璁板綍銆� @@ -51,9 +56,12 @@ 杩斿洖: SessionModel: 鏌ユ壘鍒扮殑浼氳瘽妯″瀷瀹炰緥锛屽鏋滄湭鎵惧埌鍒欒繑鍥濶one銆� """ - return self.db.query(SessionModel).filter_by(id=session_id).first() + session = self.db.query(SessionModel).filter_by(id=session_id).first() + if session.message is None: + session.message = '[]' + return session - def update_session(self, session_id: str, **kwargs) -> SessionModel: + def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None: """ 鏇存柊浼氳瘽璁板綍銆� @@ -65,15 +73,21 @@ SessionModel: 鏇存柊鍚庣殑浼氳瘽妯″瀷瀹炰緥銆� """ logger.error("鏇存柊鏁版嵁---------------------------") - session = self.db.query(SessionModel).filter_by(id=session_id).first() + self.db.commit() + session = self.get_session_by_id(session_id) if session: if "message" in kwargs: - - message = session.message - message.append(kwargs["message"]) - session = message + session.add_message(kwargs["message"]) + # 鏇挎崲鍏朵粬瀛楁 + for key, value in kwargs.items(): + if key != "message": + setattr(session, key, value) + session.update_date = current_time() + try: self.db.commit() self.db.refresh(session) + except Exception as e: + self.db.rollback() return session def delete_session(self, session_id: str) -> None: -- Gitblit v1.8.0