from typing import Dict, List, Tuple
|
|
from sqlalchemy import create_engine, Column, String, Integer
|
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.orm import sessionmaker
|
|
from app.config.config import settings
|
from app.models.agent_model import AgentModel
|
from app.models.base_model import SessionLocal, Base
|
|
# Database engines and session factories for the Bisheng and RAGFlow
# databases, which the functions in this module read flow/dialog rows from.
# Sessions produced here neither autocommit nor autoflush; every caller is
# responsible for closing the session it opens.
engine_bisheng = create_engine(settings.sgb_db_url)
engine_ragflow = create_engine(settings.fwr_db_url)

SessionBisheng = sessionmaker(bind=engine_bisheng, autocommit=False, autoflush=False)
SessionRagflow = sessionmaker(bind=engine_ragflow, autocommit=False, autoflush=False)
|
|
|
class Flow(Base):
    """Read-only view of the ``flow`` table, queried through ``SessionBisheng``.

    NOTE(review): this model is declared on the ``Base`` imported from
    ``app.models.base_model``. If ``Base.metadata.create_all`` is ever run
    against the local engine, it would create a ``flow`` table there as well.
    Confirm, or move external-database models to their own declarative base.
    """

    __tablename__ = 'flow'

    # Primary key string id (hex id in the source DB, 255 chars max).
    id = Column('id', String(length=255), primary_key=True)
    # Display name; used to match local agents by name.
    name = Column('name', String(length=255), nullable=False)
    # Integer status code; the sync query filters on a specific value.
    status = Column('status', Integer, nullable=False)
|
|
|
class Dialog(Base):
    """Read-only view of the ``dialog`` table, queried through ``SessionRagflow``.

    NOTE(review): this model is declared on the ``Base`` imported from
    ``app.models.base_model``. If ``Base.metadata.create_all`` is ever run
    against the local engine, it would create a ``dialog`` table there as well.
    Confirm, or move external-database models to their own declarative base.
    """

    __tablename__ = 'dialog'

    # Primary key string id (255 chars max).
    id = Column('id', String(length=255), primary_key=True)
    # Display name; used to match local agents by name.
    name = Column('name', String(length=255), nullable=False)
    # Single-character status code stored as a string.
    status = Column('status', String(length=1), nullable=False)
|
|
|
# Parse a comma-separated list of names from a configuration value.
def parse_names(names_str: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty names.

    Blank entries are dropped, whether they come from an empty setting,
    doubled commas, or a trailing comma. A falsy input (``""`` or ``None``)
    therefore yields ``[]`` rather than ``['']``. This matters because
    ``get_data_from_bisheng`` / ``get_data_from_ragflow`` only apply a name
    filter when the list is non-empty. With ``['']`` they would instead filter
    on ``name IN ('')``.

    Args:
        names_str: e.g. ``"agent a, agent b"``. ``None`` is tolerated.

    Returns:
        The names in their original order, with surrounding whitespace removed.
    """
    if not names_str:
        return []
    stripped = (part.strip() for part in names_str.split(','))
    return [name for name in stripped if name]
|
|
|
# Names to sync from each external system, read from settings once at import.
# The Bisheng list is evaluated before the RAGFlow list.
BISHENG_NAMES_TO_SYNC, RAGFLOW_NAMES_TO_SYNC = (
    parse_names(settings.fetch_sgb_agent),
    parse_names(settings.fetch_fwr_agent),
)
|
|
|
def get_data_from_bisheng(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` pairs for flows from the Bisheng database.

    Only flows with ``status == 2`` are returned. Presumably that is Bisheng's
    published/online state; TODO confirm. When ``names`` is non-empty the
    result is further restricted to flows whose name is in it. An empty list
    applies no name filter.

    Ids are normalized with the same rule ``initialize_agents`` applies to its
    seed ids: a 32-character id is hyphenated via ``format_uuid``, and any
    other id is passed through unchanged.

    Args:
        names: Flow names to restrict the query to; empty means "all".

    Returns:
        A list of ``(id, name)`` tuples.
    """
    db = SessionBisheng()
    try:
        query = db.query(Flow.id, Flow.name).filter(Flow.status == 2)
        if names:
            query = query.filter(Flow.name.in_(names))

        # Print the statement before running it, so a failing query is still visible.
        print(f"Executing query: {query}")
        results = query.all()

        # Guard the reformat so one id of unexpected length does not raise
        # ValueError and abort the whole sync.
        return [
            (format_uuid(flow_id) if len(flow_id) == 32 else flow_id, flow_name)
            for flow_id, flow_name in results
        ]
    finally:
        db.close()
|
|
|
def format_uuid(uuid_str: str) -> str:
    """Insert hyphens into a 32-character id to give 8-4-4-4-12 UUID layout.

    The input is only sliced; its characters are not validated as hex.

    Raises:
        ValueError: if ``uuid_str`` is not exactly 32 characters long.
    """
    if len(uuid_str) == 32:
        groups = (
            uuid_str[:8],
            uuid_str[8:12],
            uuid_str[12:16],
            uuid_str[16:20],
            uuid_str[20:],
        )
        return "{}-{}-{}-{}-{}".format(*groups)
    raise ValueError("Input string must be 32 characters long")
|
|
|
def get_data_from_ragflow(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` pairs for dialogs from the RAGFlow database.

    Only dialogs whose status is ``'1'`` are returned. When ``names`` is
    non-empty the result is further restricted to dialogs whose name is in it.
    An empty list applies no name filter.

    ``Dialog.status`` is declared ``String(1)``, so it is compared against the
    string ``'1'``. Comparing against the integer ``1`` relied on the database
    casting implicitly.

    NOTE(review): ids are returned exactly as stored, i.e. not passed through
    ``format_uuid``. ``initialize_agents``, however, hyphenates its 32-character
    RAGFLOW seed ids. Confirm which form consumers of AgentModel.id expect.

    Args:
        names: Dialog names to restrict the query to; empty means "all".

    Returns:
        A list of plain ``(id, name)`` tuples.
    """
    db = SessionRagflow()
    try:
        query = db.query(Dialog.id, Dialog.name).filter(Dialog.status == '1')
        if names:
            query = query.filter(Dialog.name.in_(names))

        # Print the statement before running it, so a failing query is still visible.
        print(f"Executing query: {query}")
        return [(dialog_id, dialog_name) for dialog_id, dialog_name in query.all()]
    finally:
        db.close()
|
|
|
def update_ids_in_local(data: List[Tuple]):
    """Re-key local agents to ids fetched from an external system.

    For each row ``(new_id, name, ...)``, the first local ``AgentModel`` with
    that ``name`` has its ``id`` replaced by ``new_id``. Rows whose name
    matches no local agent are ignored. All changes are committed together in
    a single ``commit()`` at the end.

    Args:
        data: Rows whose first element is the new id and whose second element
            is the agent name, as returned by ``get_data_from_bisheng`` /
            ``get_data_from_ragflow``.

    Raises:
        IntegrityError: if the update violates a database constraint. One
            possible cause is that ``new_id`` is already used by another
            agent; confirm against the AgentModel schema. The transaction is
            rolled back before re-raising.
    """
    db = SessionLocal()
    try:
        for row in data:
            new_id, name = row[0], row[1]
            agent = db.query(AgentModel).filter_by(name=name).first()
            if agent is None:
                continue
            # The instance came from this session's query, so it is already
            # tracked; assigning the attribute is enough for commit() to
            # persist it. No db.add() is needed.
            agent.id = new_id

        db.commit()
    except IntegrityError:
        db.rollback()
        raise
    finally:
        db.close()
|
|
|
def initialize_agents():
    """Seed the local agent table with the built-in agents, at most once.

    Does nothing if any ``AgentModel`` row already exists. Otherwise it
    inserts the fixed seed list below, commits, and prints a confirmation.
    Seed ids that are exactly 32 characters long are hyphenated with
    ``format_uuid``. All other ids (the ``basic_*`` ones) are stored verbatim.

    NOTE(review): the 32-character RAGFLOW seed ids are hyphenated here too,
    while ``get_data_from_ragflow`` returns dialog ids as stored. A later
    ``sync_agents`` run may therefore rewrite these ids in a different form.
    Confirm which form consumers expect.

    Raises:
        IntegrityError: on a constraint violation while inserting. The
            transaction is rolled back before re-raising.
    """
    session = SessionLocal()
    try:
        if session.query(AgentModel).count() > 0:
            return

        # Each row: (id, sort, name, agent_type, type)
        seed_rows = [
            ('80ee430a-e396-48c4-a12c-7c7cdf5eda51', 1, '报告生成', 'BISHENG', 'report'),
            ('basic_excel_merge', 2, '报表合并', 'BASIC', 'excelMerge'),
            ('bfd090d589d811efb3630242ac190006', 4, '文档智能', 'BISHENG', 'report'),
            ('da3451da89d911efb9490242ac190006', 3, '知识问答', 'RAGFLOW', 'knowledgeQA'),
            ('e96eb7a589db11ef87d20242ac190006', 5, '智能问答', 'RAGFLOW', 'chat'),
            ('basic_excel_talk', 6, '智能数据', 'BASIC', 'excelTalk'),
            ('basic_question_talk', 7, '文档出卷', 'BASIC', 'questionTalk'),
        ]

        for raw_id, sort_order, display_name, agent_type, kind in seed_rows:
            if len(raw_id) == 32:
                agent_id = format_uuid(raw_id)
            else:
                agent_id = raw_id
            session.add(
                AgentModel(
                    id=agent_id,
                    sort=sort_order,
                    name=display_name,
                    agent_type=agent_type,
                    type=kind,
                )
            )

        session.commit()
        print("Initial agents inserted successfully")
    except IntegrityError:
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
def sync_agents():
    """Pull agent ids from Bisheng and RAGFlow and apply them to local agents.

    Both sources are queried first, using the configured name lists. Bisheng
    results are then applied, followed by RAGFlow results. Each apply step
    commits on its own, so if the RAGFlow step fails the Bisheng updates
    already committed are kept.

    This is a best-effort entry point: any exception is reported (message on
    stdout, traceback on stderr) and is not propagated to the caller.
    """
    import traceback

    try:
        bisheng_data = get_data_from_bisheng(BISHENG_NAMES_TO_SYNC)
        ragflow_data = get_data_from_ragflow(RAGFLOW_NAMES_TO_SYNC)

        update_ids_in_local(bisheng_data)
        update_ids_in_local(ragflow_data)

        print("Agents synchronized successfully")
    except Exception as e:
        # Swallowed on purpose (best-effort), but ``str(e)`` alone drops the
        # exception type and can be empty, so also emit the traceback.
        print(f"Failed to sync agents: {str(e)}")
        traceback.print_exc()
|