xuyonghao
2025-02-10 2ab8a0e98c782a55c69a22d4b49bf294b8cfc2d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
from typing import Type
 
from sqlalchemy.orm import Session
 
from Log import logger
from app.models import AgentType, current_time
from app.models.session_model import SessionModel
 
 
class SessionService:
    def __init__(self, db: Session):
        self.db = db
 
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType, user_id: int,
                       message: dict = None, workflow_type: int = 0) -> Type[SessionModel] | SessionModel:
        """
        创建一个新的会话记录。
 
        参数:
            session_id (str): 会话ID。
            name (str): 会话名称。
            agent_id (str): 代理ID。
            agent_type (AgentType): 代理类型。
 
        返回:
            SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
        """
        if not message:
            message = {"role": "user", "content": name}
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            existing_session.add_message(message)
            existing_session.update_date = current_time()
            self.db.commit()
            self.db.refresh(existing_session)
            return existing_session
 
        new_session = SessionModel(
            id=session_id,
            name=name[0:50],
            agent_id=agent_id,
            agent_type=agent_type,
            tenant_id=user_id,
            workflow=workflow_type,
            message=json.dumps([message])
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
 
    def get_session_by_id(self, session_id: str) -> Type[SessionModel] | None:
        """
        根据会话ID获取会话记录。
 
        参数:
            session_id (str): 会话ID。
 
        返回:
            SessionModel: 查找到的会话模型实例,如果未找到则返回None。
        """
        session = self.db.query(SessionModel).filter_by(id=session_id).first()
        if  session and session.message is None:
            session.message = '[]'
        return session
 
    def update_session(self, session_id: str, **kwargs) -> Type[SessionModel] | None:
        """
        更新会话记录。
 
        参数:
            session_id (str): 会话ID。
            kwargs: 需要更新的字段及其值。
 
        返回:
            SessionModel: 更新后的会话模型实例。
        """
        logger.error("更新数据---------------------------")
        self.db.commit()
        session = self.get_session_by_id(session_id)
        if session:
            if "message" in kwargs:
                session.add_message(kwargs["message"])
            # 替换其他字段
            for key, value in kwargs.items():
                if key != "message":
                    setattr(session, key, value)
            session.update_date = current_time()
            try:
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                self.db.rollback()
        return session
 
    def delete_session(self, session_id: str) -> None:
        """
        删除会话记录。
 
        参数:
            session_id (str): 会话ID。
        """
        session = self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()