zhaoqingang
2024-11-22 95bfccca0260a6ff3e994ebbbbfafb61b7dab442
app/service/session.py
@@ -1,7 +1,9 @@
from typing import Type
from sqlalchemy.orm import Session
from Log import logger
from app.models import AgentType
from app.models import AgentType, current_time
from app.models.session_model import SessionModel
@@ -9,7 +11,8 @@
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> Type[
                                                                                                                    SessionModel] | SessionModel:
        """
        创建一个新的会话记录。
@@ -22,19 +25,20 @@
        返回:
            SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
        """
        logger.error("-------------xieru")
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            message=existing_session.message
            message.append({"role": "user", "content": name})
            self.update_session(session_id, message=message)
        logger.error("-------------xieru------------1")
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session
        new_session = SessionModel(
            id=session_id,
            name=name[0:200],
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id = user_id,
            tenant_id=user_id,
            message=[{"role": "user", "content": name}]
        )
        self.db.add(new_session)
@@ -42,7 +46,7 @@
        self.db.refresh(new_session)
        return new_session
    def get_session_by_id(self, session_id: str) -> SessionModel:
    def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
        """
        根据会话ID获取会话记录。
@@ -52,9 +56,12 @@
        返回:
            SessionModel: 查找到的会话模型实例,如果未找到则返回None。
        """
        return self.db.query(SessionModel).filter_by(id=session_id).first()
        session = self.db.query(SessionModel).filter_by(id=session_id).first()
        if session.message is None:
            session.message = '[]'
        return session
    def update_session(self, session_id: str, **kwargs) -> SessionModel:
    def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
        """
        更新会话记录。
@@ -66,18 +73,21 @@
            SessionModel: 更新后的会话模型实例。
        """
        logger.error("更新数据---------------------------")
        session = self.db.query(SessionModel).filter_by(id=session_id).first()
        self.db.commit()
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                message = session.message
                logger.error(kwargs)
                message.append(kwargs["message"])
                session.message = message
                logger.error("更新数据--------------------------11111111-")
                logger.error(message)
                session.add_message(kwargs["message"])
            # 替换其他字段
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                self.db.commit()
                self.db.refresh(session)
                logger.error("更新数据完成--------------------------1111111122222222-")
            except Exception as e:
                self.db.rollback()
        return session
    def delete_session(self, session_id: str) -> None: