From ae30d9a75407c912649f11c4f44ff15c869a4f98 Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期四, 21 十一月 2024 15:42:42 +0800 Subject: [PATCH] 自研agent会话保存和查询 --- app/models/session_model.py | 16 +++++++++++++++- app/api/chat.py | 4 +++- app/task/fetch_agent.py | 2 +- app/service/session.py | 14 +++++++++++--- app/api/agent.py | 5 ++++- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/app/api/agent.py b/app/api/agent.py index 4e410d4..af36024 100644 --- a/app/api/agent.py +++ b/app/api/agent.py @@ -105,7 +105,7 @@ return JSONResponse(status_code=200, content={"code": 400, "message": "Invalid result structure"}) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - if agent.agent_type == AgentType.BISHENG: + elif agent.agent_type == AgentType.BISHENG: bisheng_service = BishengService(base_url=settings.sgb_base_url) try: token = get_bisheng_token(db, current_user.id) @@ -139,6 +139,9 @@ return JSONResponse(status_code=200, content={"code": 200, "data": combined_logs}) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + elif agent.agent_type == AgentType.BASIC: + session = db.query(SessionModel).filter(SessionModel.agent_id == agent_id, SessionModel.tenant_id==current_user.id).first() + return JSONResponse(status_code=200, content={"code": 200, "data": session.log_to_json()}) else: return JSONResponse(status_code=200, content={"code": 200, "log": "Unsupported agent type"}) diff --git a/app/api/chat.py b/app/api/chat.py index de05de8..60d2fb8 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -209,7 +209,8 @@ session_id=chat_id, name=question, agent_id=agent_id, - agent_type=AgentType.BASIC + agent_type=AgentType.BASIC, + user_id=current_user.id ) if not question: await websocket.send_json({"message": "Invalid request", "type": "error"}) @@ -232,6 +233,7 @@ if file_name: excel_url = f"/api/files/download/?agent_id=basic_question_talk&file_id={file_name}&file_type=word" result = {"message": output, "type": "message", "file_url": excel_url} + SessionService(db).update_session(session_id=chat_id, is_incr=1, message={"role":"assistant", "content": result}) await websocket.send_json(result | data) except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") diff --git a/app/models/session_model.py b/app/models/session_model.py index 44d0b74..11bd33a 100644 --- a/app/models/session_model.py +++ b/app/models/session_model.py @@ -1,7 +1,7 @@ import json from datetime import datetime from enum import IntEnum -from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime +from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON from app.models import AgentType, current_time from app.models.base_model import Base @@ -15,6 +15,9 @@ agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False) # 鐩墠鍙瓨basic鐨勶紝ragflow鍜宐isheng鐨勮皟鎺ュ彛鑾峰彇 create_date = Column(DateTime, default=current_time) # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿 update_date = Column(DateTime, default=current_time, onupdate=current_time) # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊 + tenant_id = Column(Integer) # 鍒涘缓浜� + message = Column(JSON) # 璇存槑 + # to_dict 鏂规硶 def to_dict(self): return { @@ -25,3 +28,14 @@ 'create_date': self.create_date, 'update_date': self.update_date, } + + def log_to_json(self): + return { + 'id': self.id, + 'name': self.name, + 'agent_type': self.agent_type, + 'agent_id': self.agent_id, + 'create_date': self.create_date, + 'update_date': self.update_date, + 'message': self.message + } diff --git a/app/service/session.py b/app/service/session.py index b3b698f..78ae31a 100644 --- a/app/service/session.py +++ b/app/service/session.py @@ -8,7 +8,7 @@ def __init__(self, db: Session): self.db = db - def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel: + def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel: """ 鍒涘缓涓�涓柊鐨勪細璇濊褰曘�� @@ -23,13 +23,17 @@ """ existing_session = self.get_session_by_id(session_id) if existing_session: - return None # 濡傛灉浼氳瘽ID宸插瓨鍦紝涓嶈繘琛屼换浣曟搷浣� + message=existing_session.message + message.append({"role": "user", "content": name}) + self.update_session(session_id, message=message) new_session = SessionModel( id=session_id, name=name, agent_id=agent_id, - agent_type=agent_type + agent_type=agent_type, + tenant_id = user_id, + message=[{"role": "user", "content": name}] ) self.db.add(new_session) self.db.commit() @@ -61,6 +65,10 @@ """ session = self.get_session_by_id(session_id) if session: + if "message" in kwargs: + message = session.message + message.append(kwargs["message"]) + kwargs["message"] = message for key, value in kwargs.items(): setattr(session, key, value) self.db.commit() diff --git a/app/task/fetch_agent.py b/app/task/fetch_agent.py index 6e80963..a39461a 100644 --- a/app/task/fetch_agent.py +++ b/app/task/fetch_agent.py @@ -116,7 +116,7 @@ ('da3451da89d911efb9490242ac190006', 3, '鐭ヨ瘑闂瓟', 'RAGFLOW', 'knowledgeQA'), ('e96eb7a589db11ef87d20242ac190006', 5, '鏅鸿兘闂瓟', 'RAGFLOW', 'chat'), ('basic_excel_talk', 6, '鏅鸿兘鏁版嵁', 'BASIC', 'excelTalk'), - ('basic_question_talk', 7, '鏂囨。鍑洪', 'BASIC', 'questionTalk') + ('basic_question_talk', 7, '鏂囨。鍑哄嵎', 'BASIC', 'questionTalk') ] for agent in initial_agents: -- Gitblit v1.8.0