import json
import logging
import pytz
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type, List, Tuple
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index

# from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base

# Module-level logger: failures in message handling and session updates were
# previously printed or swallowed silently; they are now logged.
logger = logging.getLogger(__name__)


def current_time():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai."""
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)


class ChatSessionModel(Base):
    """ORM model for a row of the ``chat_sessions`` table.

    ``message`` stores the session's history as a JSON-encoded list
    (written by ``ChatSessionDao.create_session`` and ``add_message``).
    """
    __tablename__ = "chat_sessions"
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )

    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    # Currently only "basic" sessions are stored here; ragflow and bisheng
    # sessions are fetched through their own APIs.
    agent_type = Column(Integer)
    # Creation time; defaults to the current Asia/Shanghai time.
    create_date = Column(DateTime, default=current_time)
    # Last-update time; defaults to now and is refreshed automatically on update.
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)
    tenant_id = Column(Integer, index=True)  # Creator (owning user id)
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))

    def to_dict(self):
        """Return the session summary dict (without message history)."""
        return {
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }

    def log_to_json(self):
        """Return the session as a dict including its decoded message history.

        A ``None`` ``message`` column is reported as an empty list instead of
        raising ``TypeError`` from ``json.loads``.
        """
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message) if self.message else [],
        }

    def add_message(self, message: dict):
        """Append *message* to the JSON-encoded list stored in ``self.message``.

        A ``None`` column is treated as an empty list. If the stored text cannot
        be decoded into a list, the error is logged and the message is dropped
        (best-effort), leaving the stored history untouched.
        """
        if self.message is None:
            self.message = '[]'
        try:
            msg = json.loads(self.message)
            msg.append(message)  # AttributeError if the stored JSON is not a list
        except (ValueError, TypeError, AttributeError):
            logger.exception("Failed to append message to chat session %s", self.id)
            return
        self.message = json.dumps(msg)


class ChatData(BaseModel):
    """Chat request payload; fields beyond ``sessionId`` are accepted as-is."""
    sessionId: Optional[str] = ""

    class Config:
        extra = 'allow'  # Allow additional dynamic fields


class ChatSessionDao:
    """Data-access helpers for ``ChatSessionModel`` over a SQLAlchemy ``Session``.

    NOTE(review): the methods are ``async`` but issue blocking synchronous
    queries on ``self.db``.
    """

    def __init__(self, db: Session):
        self.db = db

    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        """Insert and commit a new session row.

        :param session_id: primary key of the new session.
        :param kwargs: column values for the row. An optional ``message`` becomes
            the first entry of the stored message list; when it is absent (or
            ``None``) the history starts as an empty list rather than ``[null]``.
        :return: the refreshed, persisted session.
        """
        initial_message = kwargs.pop("message", None)
        new_session = ChatSessionModel(
            id=session_id,
            create_date=current_time(),
            update_date=current_time(),
            **kwargs
        )
        new_session.message = json.dumps([initial_message] if initial_message is not None else [])
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session

    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        """Return the session with primary key *session_id*, or ``None``."""
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        return session

    async def update_session_by_id(self, session_id: str, session, message: dict,
                                   conversation_id=None) -> ChatSessionModel | None:
        """Append *message* to a session and commit.

        :param session_id: used to look the session up when *session* is falsy.
        :param session: an already-loaded session, or ``None``.
        :param message: message dict appended to the history.
        :param conversation_id: when truthy, overwrites ``conversation_id``.
        :return: the session (``None`` if not found). On a DB error the
            transaction is rolled back, the error is logged, and the session is
            still returned.
        """
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            try:
                if conversation_id:
                    session.conversation_id = conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception:
                # Best-effort update: roll back, but no longer hide the failure.
                logger.exception("Failed to update chat session %s", session_id)
                self.db.rollback()
        return session

    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        """Append ``kwargs['message']`` to an existing session, or create one.

        ``conversation_id`` in *kwargs* is honoured on both the insert and the
        update path.
        """
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            return await self.update_session_by_id(
                session_id,
                existing_session,
                kwargs.get("message"),
                conversation_id=kwargs.get("conversation_id"),
            )
        existing_session = await self.create_session(session_id, **kwargs)
        return existing_session

    async def delete_session(self, session_id: str) -> None:
        """Delete the session with id *session_id* if it exists, and commit."""
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()

    async def get_session_list(self, user_id: int, agent_id: str, keyword: str, page: int,
                               page_size: int) -> Tuple[int, List[ChatSessionModel]]:
        """Return ``(total, sessions)`` for one page of a tenant's sessions.

        :param user_id: tenant whose sessions are listed.
        :param agent_id: optional filter on ``agent_id``.
        :param keyword: optional case-as-database substring filter on ``name``;
            ``%`` and ``_`` in it match literally.
        :param page: 1-based page number (values below 1 are treated as 1).
        :param page_size: number of rows per page.
        :return: total matching rows and the page, newest ``update_date`` first.
        """
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id == user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id == agent_id)
        if keyword:
            # autoescape=True escapes LIKE wildcards in the user-supplied keyword.
            query = query.filter(ChatSessionModel.name.contains(keyword, autoescape=True))
        total = query.count()
        # Clamp so a page <= 0 cannot produce a negative OFFSET.
        page = max(page, 1)
        session_list = (
            query.order_by(ChatSessionModel.update_date.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return total, session_list