tnp
zhaoqingang
2025-01-07 51433cba2f35b9a2571023236006ebc69d1d4d2d
app/models/v2/session_model.py
@@ -1,26 +1,26 @@
import json
from datetime import datetime
from enum import IntEnum
from typing import Optional
import pytz
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
from Log import logger
from app.models.agent_model import AgentType
# from app.models import current_time
from app.models.base_model import Base
def current_time():
    """Return a timezone-aware ``datetime`` for "now" in the Asia/Shanghai zone.

    The pytz zone is handed directly to ``datetime.now`` (instead of being
    attached afterwards with ``replace``) so the offset is computed for the
    current instant.
    """
    shanghai = pytz.timezone('Asia/Shanghai')
    now = datetime.now(tz=shanghai)
    return now
class SessionModel(Base):
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
    __table_args__ = (
        Index('idx_username', 'username'),
    )
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )
    id = Column(Integer, primary_key=True)
    name = Column(String(255))
@@ -72,3 +72,84 @@
    sessionId: Optional[str] = ""
    question: str
    chatId: str
class ChatSessionDao:
    """Data-access helper for ``ChatSessionModel`` rows.

    Every method works through the SQLAlchemy ``Session`` supplied to the
    constructor and commits its own changes.
    """

    def __init__(self, db: Session):
        # SQLAlchemy session used for all queries and commits below.
        self.db = db

    def create_session(
        self,
        session_id: str,
        name: str,
        agent_id: str,
        agent_type: AgentType,
        user_id: int,
        message: Optional[str] = None,
        reference: Optional[str] = None,
    ) -> ChatSessionModel:
        """Create a chat session, or record a new user turn on an existing one.

        Args:
            session_id: Primary key of the session.
            name: The user's text. It is stored (truncated to 50 characters)
                as the session name for new sessions and appended as a
                ``{"role": "user", ...}`` entry.
            agent_id: Agent the session belongs to.
            agent_type: Agent type stored on the row.
            user_id: Stored in the ``tenant_id`` column.
            message: Optional initial serialized message history for a *new*
                session. When omitted, the history is seeded with a single
                user entry built from ``name``.
            reference: Optional ``reference`` value for a *new* session. It is
                only passed to the model when given.

        Returns:
            The created or updated ``ChatSessionModel``, refreshed from the DB.
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            # The session already exists: append the new user turn instead of
            # inserting a row with a duplicate primary key.
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session

        if message is None:
            # Seed the history with the opening user turn.
            message = json.dumps([{"role": "user", "content": name}])
        fields = dict(
            id=session_id,
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id=user_id,
            message=message,
        )
        if reference is not None:
            # Only forward ``reference`` when supplied, so the plain
            # five-argument call builds exactly the same row as before.
            fields["reference"] = reference
        new_session = ChatSessionModel(**fields)
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session

    def get_session_by_id(self, session_id: str) -> Optional[ChatSessionModel]:
        """Fetch a chat session by primary key.

        Args:
            session_id: Primary key of the session.

        Returns:
            The matching ``ChatSessionModel`` instance, or ``None`` if no row
            has that id. A ``None`` ``message`` is replaced with ``'[]'`` on
            the returned object.
        """
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        if session and session.message is None:
            # NOTE(review): this assigns to an ORM-tracked attribute, so the
            # '[]' default is likely persisted on the next commit of self.db.
            session.message = '[]'
        return session

    def update_session_by_id(self, session_id: str, **kwargs) -> Optional[ChatSessionModel]:
        """Apply field updates to a session and commit them.

        Args:
            session_id: Primary key of the session to update.
            **kwargs: Attribute values to set on the session. The ``message``
                key is special: it is passed to ``session.add_message``
                instead of being assigned directly.

        Returns:
            The session, or ``None`` if no session has that id. If the commit
            fails, the error is logged and the transaction is rolled back. The
            (unrefreshed) session object is still returned.
        """
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                session.add_message(kwargs["message"])
            # Assign every other field directly.
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                self.db.rollback()
        return session

    def delete_session(self, session_id: str) -> None:
        """Delete a chat session record.

        Args:
            session_id: ID of the session to delete. Unknown ids are ignored.
        """
        session = self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()