tnp
zhaoqingang
2025-01-07 51433cba2f35b9a2571023236006ebc69d1d4d2d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import pytz
 
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Optional, Type
from pydantic import BaseModel
from sqlalchemy import Column, String, Integer, DateTime, JSON, TEXT, Index
 
from Log import logger
from app.models.agent_model import AgentType
from app.models.base_model import Base
 
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
 
class ChatSessionModel(Base):
    __tablename__ = "chat_sessions"
 
    # __table_args__ = (
    #     Index('idx_username', 'username'),
    # )
 
    id = Column(Integer, primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(Integer)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time, index=True)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    reference = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    session_id = Column(String(36), index=True)
    chat_mode = Column(Integer)
 
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
 
    def log_to_json(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
        }
 
    def add_message(self, message: dict):
        if self.message is None:
            self.message = '[]'
        try:
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            return
        self.message = json.dumps(msg)
 
 
class ChatDialogData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
    chatId: str
 
 
 
class ChatSessionDao:
    def __init__(self, db: Session):
        self.db = db
 
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: int, user_id: int, message: str,reference:str) -> ChatSessionModel:
        new_session = ChatSessionModel(
            id=session_id,
            name=name[0:255],
            agent_id=agent_id,
            agent_type=agent_type,
            create_date=current_time(),
            update_date=current_time(),
            tenant_id=user_id,
            message=message,
            reference=reference,
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
 
    def get_session_by_id(self, session_id: str) -> Type[ChatSessionModel] | None:
        session = self.db.query(ChatSessionModel).filter_by(id=session_id).first()
        if  session and session.message is None:
            session.message = '[]'
        return session
 
    def update_session_by_id(self, session_id: str, **kwargs) -> Type[ChatSessionModel] | None:
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                session.add_message(kwargs["message"])
            # 替换其他字段
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                self.db.rollback()
        return session
 
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int) -> ChatSessionModel:
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            existing_session.add_message({"role": "user", "content": name})
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session
 
        new_session = ChatSessionModel(
            id=session_id,
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id=user_id,
            message=json.dumps([{"role": "user", "content": name}])
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
 
    def delete_session(self, session_id: str) -> None:
        """
        删除会话记录。
 
        参数:
            session_id (str): 会话ID。
        """
        session = self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()