# zhangqian
# 2024-11-14 a27d23e9d7dde3a220795828971f480850c22b8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from typing import Dict, List, Tuple
 
from sqlalchemy import create_engine, Column, String, Integer
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
 
from app.config.config import settings
from app.models.agent_model import AgentModel
from app.models.base_model import SessionLocal, Base
 
# Engine and session factory for the database at settings.sgb_db_url
# (used by get_data_from_bisheng).
engine_bisheng = create_engine(settings.sgb_db_url)
SessionBisheng = sessionmaker(bind=engine_bisheng, autocommit=False, autoflush=False)

# Engine and session factory for the database at settings.fwr_db_url
# (used by get_data_from_ragflow).
engine_ragflow = create_engine(settings.fwr_db_url)
SessionRagflow = sessionmaker(bind=engine_ragflow, autocommit=False, autoflush=False)
 
 
class Flow(Base):
    """ORM mapping for the ``flow`` table, queried through ``SessionBisheng``.

    Maps three columns: ``id``, ``name`` and ``status``.
    """
    # NOTE(review): this model is declared on the same ``Base`` that is imported
    # from app.models.base_model; if that Base's metadata is ever used to create
    # tables in the local database, a ``flow`` table may be created there too --
    # confirm this is intended.
    __tablename__ = 'flow'
    # Primary key; get_data_from_bisheng passes it through format_uuid, which
    # requires a 32-character string.
    id = Column(String(255), primary_key=True)
    # Used to match remote rows to local agents by name.
    name = Column(String(255), nullable=False)
    # get_data_from_bisheng keeps only rows whose status equals 2
    # (presumably the "published/online" state -- TODO confirm).
    status = Column(Integer, nullable=False)
 
 
class Dialog(Base):
    """ORM mapping for the ``dialog`` table, queried through ``SessionRagflow``.

    Maps three columns: ``id``, ``name`` and ``status``.
    """
    # NOTE(review): this model is declared on the same ``Base`` that is imported
    # from app.models.base_model; if that Base's metadata is ever used to create
    # tables in the local database, a ``dialog`` table may be created there too --
    # confirm this is intended.
    __tablename__ = 'dialog'
    # Primary key; returned unchanged by get_data_from_ragflow.
    id = Column(String(255), primary_key=True)
    # Used to match remote rows to local agents by name.
    name = Column(String(255), nullable=False)
    # Declared as a one-character string column; get_data_from_ragflow keeps
    # only rows whose status is 1 (presumably "valid/enabled" -- TODO confirm).
    status = Column(String(1), nullable=False)
 
 
def parse_names(names_str: str) -> List[str]:
    """Split a comma-separated string into a list of trimmed, non-empty names.

    Empty entries (from an empty string, doubled commas or a trailing comma)
    are dropped. This matters because callers treat an empty list as "no name
    filter"; returning ``['']`` for an empty setting would instead make them
    filter on the empty name and match nothing.

    Args:
        names_str: Comma-separated names. ``None`` or an empty string is
            treated as "no names".

    Returns:
        The names in their original order, with surrounding whitespace removed.
    """
    if not names_str:
        return []
    stripped = (part.strip() for part in names_str.split(','))
    return [name for name in stripped if name]
 
 
# Name filters parsed once, at import time, from the comma-separated settings
# values. sync_agents passes them to get_data_from_bisheng and
# get_data_from_ragflow respectively.
BISHENG_NAMES_TO_SYNC = parse_names(settings.fetch_sgb_agent)
RAGFLOW_NAMES_TO_SYNC = parse_names(settings.fetch_fwr_agent)
 
 
def get_data_from_bisheng(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` pairs for ``flow`` rows whose status is 2.

    Args:
        names: Flow names to restrict the result to. When the list is empty,
            every row with status 2 is returned.

    Returns:
        A list of ``(id, name)`` tuples in which each id has been passed
        through ``format_uuid``.

    Raises:
        ValueError: If a fetched id is not exactly 32 characters long
            (raised by ``format_uuid``).
    """
    session = SessionBisheng()
    try:
        # Start from the status filter and narrow by name only when asked to.
        query = session.query(Flow.id, Flow.name).filter(Flow.status == 2)
        if names:
            query = query.filter(Flow.name.in_(names))

        rows = query.all()
        print(f"Executing query: {query}")
        # Convert each 32-character id into hyphenated form.
        return [(format_uuid(flow_id), flow_name) for flow_id, flow_name in rows]
    finally:
        # Always release the session, including when the query or formatting fails.
        session.close()
 
 
def format_uuid(uuid_str: str) -> str:
    """Insert hyphens into a 32-character string using 8-4-4-4-12 grouping.

    Args:
        uuid_str: The string to format; it must be exactly 32 characters long.
            The characters themselves are not validated.

    Returns:
        The input split into groups of 8, 4, 4, 4 and 12 characters joined
        by ``-``.

    Raises:
        ValueError: If ``uuid_str`` is not 32 characters long.
    """
    if len(uuid_str) != 32:
        raise ValueError("Input string must be 32 characters long")

    # Slice boundaries for the 8-4-4-4-12 layout.
    boundaries = (0, 8, 12, 16, 20, 32)
    groups = [uuid_str[start:end] for start, end in zip(boundaries, boundaries[1:])]
    return "-".join(groups)
 
 
def get_data_from_ragflow(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` pairs for ``dialog`` rows whose status is ``'1'``.

    Args:
        names: Dialog names to restrict the result to. When the list is empty,
            every row with status ``'1'`` is returned.

    Returns:
        The rows returned by the query, each holding ``(id, name)``; ids are
        returned unchanged.
    """
    db = SessionRagflow()
    try:
        # Dialog.status is mapped as String(1), so compare against the string
        # '1'. Comparing a character column with the integer 1 relies on
        # implicit casting, which strict backends reject.
        query = db.query(Dialog.id, Dialog.name).filter(Dialog.status == '1')
        if names:
            query = query.filter(Dialog.name.in_(names))

        results = query.all()
        print(f"Executing query: {query}")
        return results
    finally:
        # Always release the session, including when the query fails.
        db.close()
 
 
def update_ids_in_local(data: List[Tuple]):
    """Overwrite the id of local agents with ids fetched from a remote source.

    For every row, the first ``AgentModel`` whose name equals ``row[1]`` gets
    its ``id`` set to ``row[0]``. Rows with no matching local agent are
    skipped. All changes are committed together after the loop.

    Args:
        data: Rows whose first element is the new id and whose second element
            is the agent name.

    Raises:
        IntegrityError: Re-raised after rolling back the session.
    """
    session = SessionLocal()
    try:
        for row in data:
            agent_name, remote_id = row[1], row[0]
            agent = session.query(AgentModel).filter_by(name=agent_name).first()
            if not agent:
                # No local agent with this name; nothing to update.
                continue
            agent.id = remote_id
            session.add(agent)
        session.commit()
    except IntegrityError:
        session.rollback()
        raise
    finally:
        session.close()
 
 
def initialize_agents():
    """Insert the five default agents when the local agent table is empty.

    Does nothing if at least one ``AgentModel`` row already exists. Seed ids
    that are exactly 32 characters long are passed through ``format_uuid``;
    all other ids are stored as written.

    Raises:
        IntegrityError: Re-raised after rolling back the session.
    """
    session = SessionLocal()
    try:
        # Seed only an empty table.
        if session.query(AgentModel).count() > 0:
            return

        # Each entry is (id, sort, name, agent_type, type).
        seed_rows = [
            ('80ee430a-e396-48c4-a12c-7c7cdf5eda51', 1, '报告生成', 'BISHENG', 'report'),
            ('basic_excel_merge', 2, '报表合并', 'BASIC', 'excelMerge'),
            ('bfd090d589d811efb3630242ac190006', 4, '文档智能', 'BISHENG', 'documentChat'),
            ('da3451da89d911efb9490242ac190006', 3, '知识问答', 'RAGFLOW', 'knowledgeQA'),
            ('e96eb7a589db11ef87d20242ac190006', 5, '智能问答', 'RAGFLOW', 'chat'),
        ]

        for raw_id, sort_order, display_name, agent_type, type_key in seed_rows:
            agent_id = raw_id
            if len(raw_id) == 32:
                # Only bare 32-character ids need hyphenating.
                agent_id = format_uuid(raw_id)
            session.add(
                AgentModel(
                    id=agent_id,
                    sort=sort_order,
                    name=display_name,
                    agent_type=agent_type,
                    type=type_key,
                )
            )

        session.commit()
        print("Initial agents inserted successfully")
    except IntegrityError:
        session.rollback()
        raise
    finally:
        session.close()
 
 
def sync_agents():
    """Fetch agent ids from both remote databases and apply them locally.

    Both sources are read first (``flow`` rows, then ``dialog`` rows), then
    each result set is handed to ``update_ids_in_local`` in the same order.
    Any exception is caught and reported with ``print``; nothing is re-raised.
    """
    try:
        # Read both sources before writing anything locally.
        fetched = (
            get_data_from_bisheng(BISHENG_NAMES_TO_SYNC),
            get_data_from_ragflow(RAGFLOW_NAMES_TO_SYNC),
        )
        for rows in fetched:
            update_ids_in_local(rows)

        print("Agents synchronized successfully")
    except Exception as e:
        print(f"Failed to sync agents: {str(e)}")