zhaoqingang
2025-03-04 370120fd4154ce6c5f69d16a4a343a016cf2e816
app/models/v2/session_model.py
@@ -1,39 +1,51 @@
import json
from datetime import datetime
from enum import IntEnum
from typing import Optional
import pytz
from pydantic import BaseModel
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type, List
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
# from Log import logger
from app.models.agent_model import AgentType
# from app.models import current_time
from app.models.base_model import Base
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
class SessionModel(Base):
    __tablename__ = "sessions"
    id = Column(String(255), primary_key=True)
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )
    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 目前只存basic的,ragflow和bisheng的调接口获取
    agent_type = Column(Integer)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer, index=True)  # 创建人
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
@@ -43,7 +55,7 @@
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
@@ -56,10 +68,80 @@
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            print(e)
            return
        self.message = json.dumps(msg)
class ChatDialogData(BaseModel):
class ChatData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
    class Config:
        extra = 'allow'  # 允许其他动态字段
class ChatSessionDao:
    def __init__(self, db: Session):
        self.db = db
    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        new_session = ChatSessionModel(
            id=session_id,
            create_date=current_time(),
            update_date=current_time(),
            **kwargs
        )
        new_session.message = json.dumps([new_session.message])
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        return session
    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        # print(message)
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            try:
                if conversation_id:
                    session.conversation_id=conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                # logger.error(e)
                self.db.rollback()
        return session
    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            return await self.update_session_by_id(session_id, existing_session, kwargs.get("message"))
        existing_session = await self.create_session(session_id, **kwargs)
        return existing_session
    async def delete_session(self, session_id: str) -> None:
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()
    async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id==user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id==agent_id)
        if keyword:
            query = query.filter(ChatSessionModel.name.like('%{}%'.format(keyword)))
        total = query.count()
        session_list = query.order_by(ChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
        return total, session_list