import json
import logging
from datetime import datetime
from typing import Optional, Type, List

import pytz
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
from sqlalchemy.orm import Session

# from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base
|
|
|
def current_time():
    """Return "now" as a timezone-aware datetime in the Asia/Shanghai zone."""
    return datetime.now(tz=pytz.timezone('Asia/Shanghai'))
|
|
|
class ChatSessionModel(Base):
    """ORM model for one chat session and its JSON-encoded message history."""

    __tablename__ = "chat_sessions"

    # Output format for create_date / update_date in the serializers below.
    _DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    # Currently only "basic" agents are stored here; ragflow and bisheng
    # agents are fetched by calling their APIs.
    agent_type = Column(Integer)
    # Creation time; defaults to current_time() (Asia/Shanghai).
    create_date = Column(DateTime, default=current_time)
    # Last-update time; defaults to current_time() and is refreshed
    # automatically whenever the row is updated through the ORM.
    # NOTE(review): the column is a naive DateTime while current_time() returns
    # an aware value -- confirm how the DB driver stores/drops the offset.
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)
    # Creator of the session (ChatSessionDao.get_session_list filters it by user id).
    tenant_id = Column(Integer, index=True)
    # JSON-encoded list of messages; maintained by add_message().
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))

    @classmethod
    def _format_date(cls, value):
        """Format ``value`` with ``_DATE_FORMAT``; a NULL column stays ``None``."""
        return None if value is None else value.strftime(cls._DATE_FORMAT)

    def to_dict(self):
        """Serialize the session summary (without messages) for API responses.

        ``agent_id`` is exposed as ``chat_id`` and a falsy ``session_type`` is
        reported as ``0``. Dates are formatted as ``YYYY-mm-dd HH:MM:SS``.
        """
        return {
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type or 0,
            'create_date': self._format_date(self.create_date),
            'update_date': self._format_date(self.update_date),
        }

    def log_to_json(self):
        """Serialize the session including its decoded message history.

        A NULL ``message`` column is reported as an empty list, matching how
        ``add_message`` treats it.

        Raises:
            json.JSONDecodeError: if the stored history is not valid JSON.
        """
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'create_date': self._format_date(self.create_date),
            'update_date': self._format_date(self.update_date),
            'message': json.loads(self.message) if self.message is not None else [],
        }

    def add_message(self, message: dict) -> None:
        """Append ``message`` to the JSON-encoded history in ``self.message``.

        A NULL history is initialised to ``[]`` first. This is best-effort:
        when the stored history cannot be decoded, or is not a JSON list, the
        problem is logged and the history is left unchanged.

        Raises:
            TypeError: if ``message`` is not JSON-serializable.
        """
        logger = logging.getLogger(__name__)
        if self.message is None:
            self.message = '[]'
        try:
            history = json.loads(self.message)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            logger.exception("Chat session %s has an undecodable message history; message not appended", self.id)
            return
        if not isinstance(history, list):
            logger.warning("Chat session %s message history is not a JSON list; message not appended", self.id)
            return
        history.append(message)
        self.message = json.dumps(history)
|
|
|
|
class ChatData(BaseModel):
    """Chat request payload.

    Only ``sessionId`` is declared; because ``extra = 'allow'``, any other
    fields sent by the client are kept as dynamic attributes on the model.
    """

    sessionId: Optional[str] = ""

    class Config:
        extra = 'allow'  # keep undeclared (dynamic) fields on the model

    def to_dict(self):
        """Return the ``files`` and ``inputs`` extras.

        Missing extras fall back to a fresh empty list / dict respectively.
        """
        return {
            "files": getattr(self, 'files', []),
            "inputs": getattr(self, 'inputs', {}),
        }
|
|
|
|
|
class ChatSessionDao:
    """Data-access helpers for ``ChatSessionModel`` rows.

    NOTE(review): ``db`` is a synchronous SQLAlchemy ``Session`` and none of
    the queries below are awaited, so these ``async`` methods block the event
    loop while they talk to the database.
    """

    def __init__(self, db: Session):
        self.db = db

    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        """Insert and return a new chat session.

        Args:
            session_id: Primary key of the new session.
            **kwargs: Column values for ``ChatSessionModel``. An optional
                ``message`` becomes the first entry of the JSON-encoded history.

        Returns:
            The committed and refreshed session.
        """
        first_message = kwargs.pop("message", None)
        now = current_time()  # one timestamp so create_date == update_date
        new_session = ChatSessionModel(
            id=session_id,
            create_date=now,
            update_date=now,
            **kwargs,
        )
        # The history column holds a JSON list; start it empty rather than as
        # "[null]" when no initial message was supplied.
        history = [] if first_message is None else [first_message]
        new_session.message = json.dumps(history)
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session

    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        """Return the session with primary key ``session_id``, or ``None``."""
        return self.db.query(ChatSessionModel).filter_by(id=session_id).first()

    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        """Append ``message`` to a session's history and bump ``update_date``.

        Args:
            session_id: Used to load the session when ``session`` is falsy.
            session: An already-loaded ``ChatSessionModel`` or ``None``.
            message: Entry to append; ``None`` leaves the history unchanged.
            conversation_id: When truthy, overwrites ``conversation_id``.

        Returns:
            The session, or ``None`` if it does not exist. On a database
            error the transaction is rolled back, the error is logged, and the
            (unchanged) session is still returned.
        """
        if not session:
            session = await self.get_session_by_id(session_id)
        if not session:
            return session
        try:
            if conversation_id:
                session.conversation_id = conversation_id
            if message is not None:
                session.add_message(message)
            session.update_date = current_time()
            self.db.commit()
            self.db.refresh(session)
        except Exception:
            # Best-effort update: keep the caller running, but don't hide the failure.
            logging.getLogger(__name__).exception("Failed to update chat session %s", session_id)
            self.db.rollback()
        return session

    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        """Append to an existing session, or create it when it does not exist.

        On the update path only ``message`` and ``conversation_id`` from
        ``kwargs`` are applied; on the insert path all ``kwargs`` are used.
        """
        existing_session = await self.get_session_by_id(session_id)
        if existing_session is None:
            return await self.create_session(session_id, **kwargs)
        return await self.update_session_by_id(
            session_id,
            existing_session,
            kwargs.get("message"),
            conversation_id=kwargs.get("conversation_id"),
        )

    async def delete_session(self, session_id: str) -> None:
        """Delete the session with ``session_id``; a missing session is a no-op."""
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()

    async def get_session_list(self, user_id: int, agent_id: str, keyword: str, page: int, page_size: int) -> tuple[int, list[ChatSessionModel]]:
        """Return ``(total, sessions)`` for one page of a user's sessions.

        Args:
            user_id: Matched against ``tenant_id``.
            agent_id: When truthy, restrict to this agent.
            keyword: When truthy, restrict to names containing it literally.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of sessions returned.

        Returns:
            The total number of matching sessions and the requested page,
            newest ``update_date`` first.
        """
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id == user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id == agent_id)
        if keyword:
            # autoescape=True makes "%" and "_" typed by the user match literally
            # instead of acting as LIKE wildcards.
            query = query.filter(ChatSessionModel.name.contains(keyword, autoescape=True))
        total = query.count()
        page = max(page, 1)  # a negative OFFSET is rejected by the database
        session_list = (
            query.order_by(ChatSessionModel.update_date.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return total, session_list
|