zhaoqingang
2025-03-28 226202d6eee6480f3386c6295be26fad42940cc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
import pytz
 
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type, List
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
 
# from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base
 
 
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
 
 
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
 
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )
 
    id = Column(String(36), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(Integer)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer, index=True)  # 创建人
    message = Column(TEXT)
    reference = Column(TEXT)
    conversation_id = Column(String(36), index=True)
    event_type = Column(String(16))
    session_type = Column(String(16))
 
    # to_dict 方法
    def to_dict(self):
        return {
            'session_id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'event_type': self.event_type,
            'session_type': self.session_type if self.session_type else 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
 
    def log_to_json(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'chat_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
        }
 
    def add_message(self, message: dict):
        if self.message is None:
            self.message = '[]'
        try:
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            print(e)
            return
        self.message = json.dumps(msg)
 
 
 
class ChatData(BaseModel):
    sessionId: Optional[str] = ""
 
    class Config:
        extra = 'allow'  # 允许其他动态字段
 
 
 
    def to_dict(self):
        res = {"files": [], "inputs": {}}
        if hasattr(self, 'files'):
            res['files'] = self.files
        if hasattr(self, 'inputs'):
            res['inputs'] = self.inputs
        return res
 
 
 
 
class ChatSessionDao:
    def __init__(self, db: Session):
        self.db = db
 
    async def create_session(self, session_id: str, **kwargs) -> ChatSessionModel:
        new_session = ChatSessionModel(
            id=session_id,
            create_date=current_time(),
            update_date=current_time(),
            **kwargs
        )
        new_session.message = json.dumps([new_session.message])
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
 
    async def get_session_by_id(self, session_id: str) -> ChatSessionModel | None:
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        return session
 
    async def update_session_by_id(self, session_id: str, session, message: dict, conversation_id=None) -> ChatSessionModel | None:
        # print(message)
        if not session:
            session = await self.get_session_by_id(session_id)
        if session:
            try:
                if conversation_id:
                    session.conversation_id=conversation_id
                session.add_message(message)
                session.update_date = current_time()
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                # logger.error(e)
                self.db.rollback()
        return session
 
    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ChatSessionModel:
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            return await self.update_session_by_id(session_id, existing_session, kwargs.get("message"))
 
        existing_session = await self.create_session(session_id, **kwargs)
        return existing_session
 
    async def delete_session(self, session_id: str) -> None:
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()
 
    async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
        query = self.db.query(ChatSessionModel).filter(ChatSessionModel.tenant_id==user_id)
        if agent_id:
            query = query.filter(ChatSessionModel.agent_id==agent_id)
        if keyword:
            query = query.filter(ChatSessionModel.name.like('%{}%'.format(keyword)))
        total = query.count()
        session_list = query.order_by(ChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
        return total, session_list