zhaoqingang
2025-04-11 e078028f1a8da34f3cc2cb5095b8e103a996f553
app/models/v2/session_model.py
@@ -3,11 +3,11 @@
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type
from typing import Optional, Type, List
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
from Log import logger
# from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base
@@ -35,14 +35,17 @@
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
@@ -52,7 +55,7 @@
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
@@ -71,10 +74,23 @@
class ChatDialogData(BaseModel):
class ChatData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
    chatId: str
    class Config:
        extra = 'allow'  # 允许其他动态字段
    def to_dict(self):
        res = {"files": [], "inputs": {}}
        if hasattr(self, 'files'):
            res['files'] = self.files
        if hasattr(self, 'inputs'):
            res['inputs'] = self.inputs
        return res
class ChatSessionDao:
@@ -98,17 +114,20 @@
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        return session
    async def update_session_by_id(self, session_id: str, session, message: dict) -> ChatSessionModel | None:
    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        # print(message)
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            try:
                if conversation_id:
                    session.conversation_id=conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                # logger.error(e)
                self.db.rollback()
        return session
@@ -125,3 +144,15 @@
        if session:
            self.db.delete(session)
            self.db.commit()
    async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id==user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id==agent_id)
        if keyword:
            query = query.filter(ChatSessionModel.name.like('%{}%'.format(keyword)))
        total = query.count()
        session_list = query.order_by(ChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
        return total, session_list