import json
|
import pytz
|
|
from datetime import datetime
|
from sqlalchemy.orm import Session
|
from typing import Optional, Type
|
from pydantic import BaseModel
|
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
|
|
from Log import logger
|
from app.models.agent_model import AgentType
|
from app.models.base_model import Base
|
|
def current_time():
    """Return the current moment as a timezone-aware datetime in Asia/Shanghai.

    Passed uncalled as the ``default``/``onupdate`` callable of the timestamp
    columns in this module, and also called directly when a timestamp is
    set by hand.
    """
    return datetime.now(tz=pytz.timezone('Asia/Shanghai'))
|
|
class ChatSessionModel(Base):
    """ORM model for one row of the ``chat_sessions`` table.

    The conversation history lives in ``message`` as a JSON-encoded list
    (see :meth:`add_message`); a NULL value is treated as an empty history.
    """

    __tablename__ = "chat_sessions"

    id = Column(Integer, primary_key=True)

    name = Column(String(255))

    agent_id = Column(String(255))

    # Only "basic" agent sessions are stored here for now; ragflow and bisheng
    # sessions are fetched through their respective APIs.
    agent_type = Column(Integer)

    # Creation time; defaults to current_time() (Asia/Shanghai wall time).
    create_date = Column(DateTime, default=current_time)

    # Last-update time; defaults to current_time() and is refreshed
    # automatically by SQLAlchemy whenever the row is UPDATEd.
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)

    # Creator of the session (the DAO stores the user id here).
    tenant_id = Column(Integer)

    # JSON-encoded list of chat messages (see add_message).
    message = Column(TEXT)

    # Reference data attached to the session; its format is not defined here.
    reference = Column(TEXT)

    conversation_id = Column(String(64))

    session_id = Column(String(36), index=True)

    chat_mode = Column(Integer)

    @staticmethod
    def _format_date(value):
        """Format a datetime as ``YYYY-MM-DD HH:MM:SS``; ``None`` is returned unchanged."""
        # Dates are still None on a freshly constructed, not-yet-flushed
        # instance, because column defaults only apply at INSERT time.
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary of the session, without the message history."""
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self._format_date(self.create_date),
            'update_date': self._format_date(self.update_date),
        }

    def log_to_json(self) -> dict:
        """Return :meth:`to_dict` plus the decoded message history under ``'message'``.

        A NULL ``message`` column is reported as an empty list.

        Raises:
            json.JSONDecodeError: if the stored history is not valid JSON.
        """
        data = self.to_dict()
        data['message'] = json.loads(self.message) if self.message is not None else []
        return data

    def add_message(self, message: dict):
        """Append ``message`` to the JSON-encoded history stored in ``self.message``.

        A NULL history is first initialised to ``'[]'``. This is best-effort:
        if the stored history cannot be decoded into a JSON list, it is left
        untouched, an error is logged, and no exception is raised.

        Raises:
            TypeError: if ``message`` itself is not JSON-serializable.
        """
        if self.message is None:
            self.message = '[]'

        try:
            history = json.loads(self.message)
        except (TypeError, ValueError) as exc:  # JSONDecodeError is a ValueError
            logger.error(f"chat session {self.id}: cannot decode message history, message not added: {exc}")
            return

        if not isinstance(history, list):
            logger.error(
                f"chat session {self.id}: message history is {type(history).__name__}, "
                f"expected list; message not added"
            )
            return

        history.append(message)
        self.message = json.dumps(history)
|
|
|
class ChatDialogData(BaseModel):
    """Pydantic payload describing one chat dialog turn.

    Field names are camelCase, presumably to match the client's JSON keys;
    confirm against the API caller.
    """

    # Existing chat session id; defaults to "" (presumably meaning "no session
    # yet / start a new one"; TODO confirm against the handler).
    sessionId: Optional[str] = ""

    # The question text to send (presumably the user's message).
    question: str

    # Identifier of the target chat. NOTE(review): exact meaning (agent id vs.
    # chat app id) is not visible in this file; confirm.
    chatId: str
|
|
|
|
class ChatSessionDao:
    """Data-access object for ``ChatSessionModel`` rows, bound to one SQLAlchemy ``Session``.

    NOTE(review): sessions are keyed by the Integer primary key ``id`` but are
    looked up and created with a ``str`` ``session_id``. The String(36)
    ``session_id`` column is never populated here. Confirm which column is
    intended.
    """

    def __init__(self, db: Session):
        self.db = db

    def _commit(self) -> None:
        """Commit the current transaction; roll back before re-raising on failure."""
        try:
            self.db.commit()
        except Exception:
            # After a failed flush/commit the Session refuses further work
            # until rollback() is called, so restore it before propagating.
            self.db.rollback()
            raise

    def get_session_by_id(self, session_id: str) -> Optional[ChatSessionModel]:
        """Return the session whose primary key equals ``session_id``, or ``None``.

        A NULL ``message`` on the loaded instance is replaced with ``'[]'``.
        Note that this mutates the instance, so a later commit on the same
        Session persists ``'[]'`` (and triggers ``update_date``'s onupdate).
        """
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        if session is not None and session.message is None:
            session.message = '[]'
        return session

    def update_session_by_id(self, session_id: str, **kwargs) -> Optional[ChatSessionModel]:
        """Update a session in place and commit.

        ``kwargs["message"]`` (a dict), if present, is appended to the
        history via ``ChatSessionModel.add_message``. Every other keyword is
        set verbatim as an attribute. ``update_date`` is always bumped.

        Commit/refresh failures are logged and rolled back rather than raised;
        the (then expired) instance is still returned.

        Returns:
            The updated session, or ``None`` if no session matches ``session_id``.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            return None

        if "message" in kwargs:
            session.add_message(kwargs["message"])

        # Replace the remaining fields as given.
        for key, value in kwargs.items():
            if key != "message":
                setattr(session, key, value)

        session.update_date = current_time()
        try:
            self.db.commit()
            self.db.refresh(session)
        except Exception as e:
            logger.error(e)
            self.db.rollback()
        return session

    def create_session(
        self,
        session_id: str,
        name: str,
        agent_id: str,
        agent_type: AgentType,
        user_id: int,
        message: Optional[str] = None,
        reference: Optional[str] = None,
    ) -> ChatSessionModel:
        """Create a session, or append to it if it already exists, then commit.

        This replaces two previously separate ``create_session`` definitions.
        The later one silently shadowed the earlier one. Calls with the
        five-argument form behave exactly as before.

        Args:
            session_id: Primary-key value for the session.
            name: The user's first message. It is also used, truncated to
                50 characters, as the session name.
            agent_id: Id of the agent serving this session.
            agent_type: Agent type stored in ``agent_type``.
            user_id: Stored as ``tenant_id`` (the session creator).
            message: Optional pre-encoded JSON history for a *new* session.
                When omitted, the history starts as ``[{"role": "user",
                "content": name}]``.
            reference: Optional reference text for a *new* session.

        If the session already exists, ``name`` is appended to its history as
        a user message, ``update_date`` is bumped, and ``message``/``reference``
        are ignored.

        Raises:
            Whatever the commit raises; the Session is rolled back first.
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self._commit()
            self.db.refresh(existing_session)
            return existing_session

        if message is None:
            message = json.dumps([{"role": "user", "content": name}])

        new_session = ChatSessionModel(
            id=session_id,
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id=user_id,
            message=message,
            reference=reference,
        )
        self.db.add(new_session)
        self._commit()
        self.db.refresh(new_session)
        return new_session

    def delete_session(self, session_id: str) -> None:
        """Delete a chat session record, if it exists, and commit.

        Args:
            session_id: The session id (primary key).

        Raises:
            Whatever the commit raises; the Session is rolled back first.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            return
        self.db.delete(session)
        self._commit()
|