zhangqian
2024-11-19 7305b7b9c88be497452e4dcf8b70decef0353bad
发送问答消息时创建会话记录
3个文件已修改
1个文件已添加
106 ■■■■■ 已修改文件
app/api/chat.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/__init__.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/session.py 80 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py
@@ -14,6 +14,7 @@
from app.service.basic import BasicService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
from app.service.session import SessionService
router = APIRouter()
@@ -203,6 +204,12 @@
                # 接收前端消息
                message = await websocket.receive_json()
                question = message.get("message")
                SessionService(db).create_session(
                    session_id=chat_id,
                    name=question,
                    agent_id=agent_id,
                    agent_type=AgentType.BASIC
                )
                if not question:
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
                    continue
app/models/__init__.py
@@ -1,3 +1,7 @@
from zoneinfo import ZoneInfo
import pytz
from .agent_model import *
from .dialog_model import *
from .group_model import *
@@ -6,4 +10,10 @@
from .organization_model import *
from .resource_model import *
from .role_model import *
from .user_model import *
from .user_model import *
# 获取当前时区的时间
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
app/models/session_model.py
@@ -3,7 +3,7 @@
from enum import IntEnum
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime
from app.models import AgentType
from app.models import AgentType, current_time
from app.models.base_model import Base
@@ -13,9 +13,8 @@
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime)  # 创建时间
    update_date = Column(DateTime)  # 更新时间
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    # to_dict 方法
    def to_dict(self):
        return {
app/service/session.py
New file
@@ -0,0 +1,80 @@
from sqlalchemy.orm import Session
from app.models import AgentType
from app.models.session_model import SessionModel
class SessionService:
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel:
        """
        创建一个新的会话记录。
        参数:
            session_id (str): 会话ID。
            name (str): 会话名称。
            agent_id (str): 代理ID。
            agent_type (AgentType): 代理类型。
        返回:
            SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            return None  # 如果会话ID已存在,不进行任何操作
        new_session = SessionModel(
            id=session_id,
            name=name,
            agent_id=agent_id,
            agent_type=agent_type
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
    def get_session_by_id(self, session_id: str) -> SessionModel:
        """
        根据会话ID获取会话记录。
        参数:
            session_id (str): 会话ID。
        返回:
            SessionModel: 查找到的会话模型实例,如果未找到则返回None。
        """
        return self.db.query(SessionModel).filter_by(id=session_id).first()
    def update_session(self, session_id: str, **kwargs) -> SessionModel:
        """
        更新会话记录。
        参数:
            session_id (str): 会话ID。
            kwargs: 需要更新的字段及其值。
        返回:
            SessionModel: 更新后的会话模型实例。
        """
        session = self.get_session_by_id(session_id)
        if session:
            for key, value in kwargs.items():
                setattr(session, key, value)
            self.db.commit()
            self.db.refresh(session)
        return session
    def delete_session(self, session_id: str) -> None:
        """
        删除会话记录。
        参数:
            session_id (str): 会话ID。
        """
        session = self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()