zhaoqingang
2025-04-03 9683aeeafa2f1067ef061b34124a1c362df07e5e
app/models/v2/session_model.py
@@ -3,17 +3,19 @@
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type
from typing import Optional, Type, List
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
from Log import logger
# from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
@@ -22,26 +24,28 @@
    #     Index('idx_username', 'username'),
    # )
    id = Column(Integer, primary_key=True)
    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(Integer)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    reference = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    session_id = Column(String(36), index=True)
    chat_mode = Column(Integer)
    tenant_id = Column(Integer, index=True)  # 创建人
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
@@ -51,7 +55,7 @@
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
@@ -64,14 +68,28 @@
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            print(e)
            return
        self.message = json.dumps(msg)
class ChatDialogData(BaseModel):
class ChatData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
    chatId: str
    class Config:
        extra = 'allow'  # 允许其他动态字段
    def to_dict(self):
        res = {"files": [], "inputs": {}}
        if hasattr(self, 'files'):
            res['files'] = self.files
        if hasattr(self, 'inputs'):
            res['inputs'] = self.inputs
        return res
@@ -79,77 +97,62 @@
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: int, user_id: int, message: str,reference:str) -> ChatSessionModel:
    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        new_session = ChatSessionModel(
            id=session_id,
            name=name[0:255],
            agent_id=agent_id,
            agent_type=agent_type,
            create_date=current_time(),
            update_date=current_time(),
            tenant_id=user_id,
            message=message,
            reference=reference,
            **kwargs
        )
        new_session.message = json.dumps([new_session.message])
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
    def get_session_by_id(self, session_id: str) -> Type[ChatSessionModel] | None:
    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        if  session and session.message is None:
            session.message = '[]'
        return session
    def update_session_by_id(self, session_id: str, **kwargs) -> Type[ChatSessionModel] | None:
        session = self.get_session_by_id(session_id)
    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        # print(message)
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                session.add_message(kwargs["message"])
            # 替换其他字段
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                if conversation_id:
                    session.conversation_id=conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                # logger.error(e)
                self.db.rollback()
        return session
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> ChatSessionModel:
        existing_session = self.get_session_by_id(session_id)
    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session
            return await self.update_session_by_id(session_id, existing_session, kwargs.get("message"))
        new_session = ChatSessionModel(
            id=session_id,
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id=user_id,
            message=json.dumps([{"role": "user", "content": name}])
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
        existing_session = await self.create_session(session_id, **kwargs)
        return existing_session
    def delete_session(self, session_id: str) -> None:
        """
        删除会话记录。
        参数:
            session_id (str): 会话ID。
        """
        session = self.get_session_by_id(session_id)
    async def delete_session(self, session_id: str) -> None:
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()
            self.db.commit()
    async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id==user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id==agent_id)
        if keyword:
            query = query.filter(ChatSessionModel.name.like('%{}%'.format(keyword)))
        total = query.count()
        session_list = query.order_by(ChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
        return total, session_list