zhaoqingang
2025-04-03 9683aeeafa2f1067ef061b34124a1c362df07e5e
app/models/v2/session_model.py
@@ -1,47 +1,51 @@
import json
from datetime import datetime
from enum import IntEnum
from typing import Optional
import pytz
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type, List
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
# from Log import logger
from app.models.agent_model import AgentType
# from app.models import current_time
from app.models.base_model import Base
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
class SessionModel(Base):
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
    __table_args__ = (
        Index('idx_username', 'username'),
    )
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )
    id = Column(Integer, primary_key=True)
    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(Integer)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    reference = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    session_id = Column(String(36), index=True)
    chat_mode = Column(Integer)
    tenant_id = Column(Integer, index=True)  # 创建人
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
@@ -51,7 +55,7 @@
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
@@ -64,11 +68,91 @@
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            print(e)
            return
        self.message = json.dumps(msg)
class ChatDialogData(BaseModel):
class ChatData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
    chatId: str
    class Config:
        extra = 'allow'  # 允许其他动态字段
    def to_dict(self):
        res = {"files": [], "inputs": {}}
        if hasattr(self, 'files'):
            res['files'] = self.files
        if hasattr(self, 'inputs'):
            res['inputs'] = self.inputs
        return res
class ChatSessionDao:
    def __init__(self, db: Session):
        self.db = db
    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        new_session = ChatSessionModel(
            id=session_id,
            create_date=current_time(),
            update_date=current_time(),
            **kwargs
        )
        new_session.message = json.dumps([new_session.message])
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        return session
    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        # print(message)
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            try:
                if conversation_id:
                    session.conversation_id=conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                # logger.error(e)
                self.db.rollback()
        return session
    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            return await self.update_session_by_id(session_id, existing_session, kwargs.get("message"))
        existing_session = await self.create_session(session_id, **kwargs)
        return existing_session
    async def delete_session(self, session_id: str) -> None:
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()
    async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id==user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id==agent_id)
        if keyword:
            query = query.filter(ChatSessionModel.name.like('%{}%'.format(keyword)))
        total = query.count()
        session_list = query.order_by(ChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
        return total, session_list