import json
|
from typing import Type
|
|
from sqlalchemy.orm import Session
|
|
from Log import logger
|
from app.models import AgentType, current_time
|
from app.models.session_model import SessionModel
|
|
|
class SessionService:
|
def __init__(self, db: Session):
|
self.db = db
|
|
def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int, message: dict = None, workflow_type: int = 0) -> Type[
|
SessionModel] | SessionModel:
|
"""
|
创建一个新的会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
name (str): 会话名称。
|
agent_id (str): 代理ID。
|
agent_type (AgentType): 代理类型。
|
|
返回:
|
SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
|
"""
|
existing_session = self.get_session_by_id(session_id)
|
if existing_session:
|
print("update success")
|
# existing_session.add_message({"role": "user", "content": name})
|
existing_session.add_message(message)
|
existing_session.update_date = current_time()
|
self.db.commit()
|
self.db.refresh(existing_session)
|
return existing_session
|
|
new_session = SessionModel(
|
id=session_id,
|
name=name[0:50],
|
agent_id=agent_id,
|
agent_type=agent_type,
|
tenant_id=user_id,
|
# message=json.dumps([{"role": "user", "content": name}])
|
workflow = workflow_type,
|
message = json.dumps([message])
|
)
|
self.db.add(new_session)
|
self.db.commit()
|
self.db.refresh(new_session)
|
return new_session
|
|
def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
|
"""
|
根据会话ID获取会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
|
返回:
|
SessionModel: 查找到的会话模型实例,如果未找到则返回None。
|
"""
|
session = self.db.query(SessionModel).filter_by(id=session_id).first()
|
if session and session.message is None:
|
session.message = '[]'
|
return session
|
|
def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
|
"""
|
更新会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
kwargs: 需要更新的字段及其值。
|
|
返回:
|
SessionModel: 更新后的会话模型实例。
|
"""
|
logger.error("更新数据---------------------------")
|
self.db.commit()
|
session = self.get_session_by_id(session_id)
|
if session:
|
if "message" in kwargs:
|
session.add_message(kwargs["message"])
|
# 替换其他字段
|
for key, value in kwargs.items():
|
if key != "message":
|
setattr(session, key, value)
|
session.update_date = current_time()
|
try:
|
self.db.commit()
|
self.db.refresh(session)
|
except Exception as e:
|
logger.error(e)
|
self.db.rollback()
|
return session
|
|
def delete_session(self, session_id: str) -> None:
|
"""
|
删除会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
"""
|
session = self.get_session_by_id(session_id)
|
if session:
|
self.db.delete(session)
|
self.db.commit()
|