zhaoqingang
2024-11-21 ae30d9a75407c912649f11c4f44ff15c869a4f98
自研agent会话保存和查询
5个文件已修改
41 ■■■■ 已修改文件
app/api/agent.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/session.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/task/fetch_agent.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/agent.py
@@ -105,7 +105,7 @@
                return JSONResponse(status_code=200, content={"code": 400, "message": "Invalid result structure"})
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    if agent.agent_type == AgentType.BISHENG:
    elif agent.agent_type == AgentType.BISHENG:
        bisheng_service = BishengService(base_url=settings.sgb_base_url)
        try:
            token = get_bisheng_token(db, current_user.id)
@@ -139,6 +139,9 @@
            return JSONResponse(status_code=200, content={"code": 200, "data": combined_logs})
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    elif agent.agent_type == AgentType.BASIC:
        session = db.query(SessionModel).filter(SessionModel.agent_id == agent_id, SessionModel.tenant_id==current_user.id).first()
        return JSONResponse(status_code=200, content={"code": 200, "data": session.log_to_json()})
    else:
        return JSONResponse(status_code=200, content={"code": 200, "log": "Unsupported agent type"})
app/api/chat.py
@@ -209,7 +209,8 @@
                    session_id=chat_id,
                    name=question,
                    agent_id=agent_id,
                    agent_type=AgentType.BASIC
                    agent_type=AgentType.BASIC,
                    user_id=current_user.id
                )
                if not question:
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
@@ -232,6 +233,7 @@
                                if file_name:
                                    excel_url = f"/api/files/download/?agent_id=basic_question_talk&file_id={file_name}&file_type=word"
                                result = {"message": output, "type": "message", "file_url": excel_url}
                                SessionService(db).update_session(session_id=chat_id, is_incr=1, message={"role":"assistant", "content": result})
                                await websocket.send_json(result | data)
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
app/models/session_model.py
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from enum import IntEnum
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON
from app.models import AgentType, current_time
from app.models.base_model import Base
@@ -15,6 +15,9 @@
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(JSON)  # 说明
    # to_dict 方法
    def to_dict(self):
        return {
@@ -25,3 +28,14 @@
            'create_date': self.create_date,
            'update_date': self.update_date,
        }
    def log_to_json(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self.create_date,
            'update_date': self.update_date,
            'message': self.message
        }
app/service/session.py
@@ -8,7 +8,7 @@
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel:
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
        """
        创建一个新的会话记录。
@@ -23,13 +23,17 @@
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            return None  # 如果会话ID已存在,不进行任何操作
            message=existing_session.message
            message.append({"role": "user", "content": name})
            self.update_session(session_id, message=message)
        new_session = SessionModel(
            id=session_id,
            name=name,
            agent_id=agent_id,
            agent_type=agent_type
            agent_type=agent_type,
            tenant_id = user_id,
            message=[{"role": "user", "content": name}]
        )
        self.db.add(new_session)
        self.db.commit()
@@ -61,6 +65,10 @@
        """
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                message =  session.message
                message.append(kwargs["message"])
                kwargs["message"] = message
            for key, value in kwargs.items():
                setattr(session, key, value)
            self.db.commit()
app/task/fetch_agent.py
@@ -116,7 +116,7 @@
            ('da3451da89d911efb9490242ac190006', 3, '知识问答', 'RAGFLOW', 'knowledgeQA'),
            ('e96eb7a589db11ef87d20242ac190006', 5, '智能问答', 'RAGFLOW', 'chat'),
            ('basic_excel_talk', 6, '智能数据', 'BASIC', 'excelTalk'),
            ('basic_question_talk', 7, '文档出题', 'BASIC', 'questionTalk')
            ('basic_question_talk', 7, '文档出卷', 'BASIC', 'questionTalk')
        ]
        for agent in initial_agents: