from sqlalchemy.orm import Session
|
|
from app.models import AgentType
|
from app.models.session_model import SessionModel
|
|
|
class SessionService:
|
def __init__(self, db: Session):
|
self.db = db
|
|
def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> SessionModel:
|
"""
|
创建一个新的会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
name (str): 会话名称。
|
agent_id (str): 代理ID。
|
agent_type (AgentType): 代理类型。
|
|
返回:
|
SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
|
"""
|
existing_session = self.get_session_by_id(session_id)
|
if existing_session:
|
message=existing_session.message
|
message.append({"role": "user", "content": name})
|
self.update_session(session_id, message=message)
|
|
new_session = SessionModel(
|
id=session_id,
|
name=name,
|
agent_id=agent_id,
|
agent_type=agent_type,
|
tenant_id = user_id,
|
message=[{"role": "user", "content": name}]
|
)
|
self.db.add(new_session)
|
self.db.commit()
|
self.db.refresh(new_session)
|
return new_session
|
|
def get_session_by_id(self, session_id: str) -> SessionModel:
|
"""
|
根据会话ID获取会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
|
返回:
|
SessionModel: 查找到的会话模型实例,如果未找到则返回None。
|
"""
|
return self.db.query(SessionModel).filter_by(id=session_id).first()
|
|
def update_session(self, session_id: str, **kwargs) -> SessionModel:
|
"""
|
更新会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
kwargs: 需要更新的字段及其值。
|
|
返回:
|
SessionModel: 更新后的会话模型实例。
|
"""
|
session = self.get_session_by_id(session_id)
|
if session:
|
if "message" in kwargs:
|
message = session.message
|
message.append(kwargs["message"])
|
kwargs["message"] = message
|
for key, value in kwargs.items():
|
setattr(session, key, value)
|
self.db.commit()
|
self.db.refresh(session)
|
return session
|
|
def delete_session(self, session_id: str) -> None:
|
"""
|
删除会话记录。
|
|
参数:
|
session_id (str): 会话ID。
|
"""
|
session = self.get_session_by_id(session_id)
|
if session:
|
self.db.delete(session)
|
self.db.commit()
|