from typing import Dict, List, Tuple

from sqlalchemy import create_engine, Column, String, Integer
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker

from app.config.config import settings
from app.models.agent_model import AgentModel
from app.models.base_model import SessionLocal, Base
from app.service.v2.initialize_data import dialog_menu_sync

# Create the database engines and session factories for the two external
# sources (BISHENG and RAGFLOW) whose agent IDs are synced into the local DB.
engine_bisheng = create_engine(settings.sgb_db_url)
engine_ragflow = create_engine(settings.fwr_db_url)
SessionBisheng = sessionmaker(autocommit=False, autoflush=False, bind=engine_bisheng)
SessionRagflow = sessionmaker(autocommit=False, autoflush=False, bind=engine_ragflow)


# NOTE(review): Flow and Dialog are declared on the local `Base`, so they share
# its metadata; if Base.metadata.create_all() runs elsewhere it would also try
# to create these tables locally — confirm that is intended.
class Flow(Base):
    """Mapping of the BISHENG ``flow`` table (only the columns read here)."""

    __tablename__ = 'flow'

    id = Column(String(255), primary_key=True)
    name = Column(String(255), nullable=False)
    # get_data_from_bisheng only selects flows whose status is 2.
    status = Column(Integer, nullable=False)


class Dialog(Base):
    """Mapping of the RAGFLOW ``dialog`` table (only the columns read here)."""

    __tablename__ = 'dialog'

    id = Column(String(255), primary_key=True)
    name = Column(String(255), nullable=False)
    # Declared as a one-character string; get_data_from_ragflow matches '1'.
    status = Column(String(1), nullable=False)


def parse_names(names_str: str) -> List[str]:
    """Split a comma-separated setting into a list of stripped, non-empty names.

    An empty or missing setting yields ``[]`` so callers fall back to fetching
    every active record instead of filtering on an empty-string name.
    """
    if not names_str:
        return []
    return [name.strip() for name in names_str.split(',') if name.strip()]


BISHENG_NAMES_TO_SYNC = parse_names(settings.fetch_sgb_agent)
RAGFLOW_NAMES_TO_SYNC = parse_names(settings.fetch_fwr_agent)


def get_data_from_bisheng(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` pairs of BISHENG flows whose status is 2.

    Args:
        names: Flow names to restrict the query to; an empty list selects
            every flow with status 2.

    Returns:
        ``(id, name)`` tuples; 32-character IDs are converted to hyphenated
        UUID form, any other ID is passed through unchanged.
    """
    db = SessionBisheng()
    try:
        query = db.query(Flow.id, Flow.name).filter(Flow.status == 2)
        if names:
            # A second filter() is ANDed with the status condition.
            query = query.filter(Flow.name.in_(names))
        print(f"Executing query: {query}")
        results = query.all()
        return [(_to_uuid_form(row[0]), row[1]) for row in results]
    finally:
        db.close()


def format_uuid(uuid_str: str) -> str:
    """Insert hyphens into a 32-character hex string (8-4-4-4-12 layout).

    Raises:
        ValueError: if *uuid_str* is not exactly 32 characters long.
    """
    # The input must be exactly 32 characters long.
    if len(uuid_str) != 32:
        raise ValueError("Input string must be 32 characters long")
    return f"{uuid_str[:8]}-{uuid_str[8:12]}-{uuid_str[12:16]}-{uuid_str[16:20]}-{uuid_str[20:]}"


def _to_uuid_form(raw_id: str) -> str:
    """Return *raw_id* hyphenated via format_uuid when it is 32 chars, else unchanged."""
    return format_uuid(raw_id) if len(raw_id) == 32 else raw_id


def get_data_from_ragflow(names: List[str]) -> List[Tuple]:
    """Fetch ``(id, name)`` rows of RAGFLOW dialogs whose status is '1'.

    Args:
        names: Dialog names to restrict the query to; an empty list selects
            every dialog with status '1'.

    Returns:
        The query rows as returned by SQLAlchemy (IDs are not reformatted).
    """
    db = SessionRagflow()
    try:
        # Dialog.status is a String(1) column, so compare against a string.
        query = db.query(Dialog.id, Dialog.name).filter(Dialog.status == '1')
        if names:
            query = query.filter(Dialog.name.in_(names))
        print(f"Executing query: {query}")
        return query.all()
    finally:
        db.close()


def update_ids_in_local(data: List[Tuple]):
    """Overwrite the ID of each local agent whose name matches a fetched row.

    Args:
        data: Rows whose element 0 is the new ID and element 1 the agent name.
            Names with no matching local agent are skipped.

    Raises:
        IntegrityError: re-raised after rolling back the failing update.
    """
    db = SessionLocal()
    try:
        for row in data:
            new_id, name = row[0], row[1]
            existing_agent = db.query(AgentModel).filter_by(name=name).first()
            if existing_agent is None:
                continue
            # The agent is already attached to this session, so assigning the
            # attribute is enough; commit per row so earlier updates persist.
            existing_agent.id = new_id
            db.commit()
    except IntegrityError:
        db.rollback()
        raise
    finally:
        db.close()


def initialize_agents():
    """Replace all rows of the local AgentModel table with the built-in agents.

    The delete and the inserts are committed together, so a failing insert
    rolls back the deletion too (assuming SessionLocal is a regular
    transactional session) instead of leaving the table empty.

    Raises:
        IntegrityError: re-raised after rollback if an insert conflicts.
    """
    db = SessionLocal()
    try:
        db.query(AgentModel).delete()

        # Tuple layout: (id, sort, name, agent_type, type)
        # NOTE(review): sort value 8 appears twice and 9 is unused — confirm
        # whether 'basic_paper_talk' was meant to be 9.
        initial_agents = [
            ('80ee430a-e396-48c4-a12c-7c7cdf5eda51', 1, '报告生成', 'BISHENG', 'report'),
            ('basic_excel_merge', 2, '报表合并', 'BASIC', 'excelMerge'),
            ('bfd090d589d811efb3630242ac190006', 4, '文档智能', 'BISHENG', 'report'),
            ('da3451da89d911efb9490242ac190006', 3, '知识问答', 'RAGFLOW', 'knowledgeQA'),
            ('e96eb7a589db11ef87d20242ac190006', 5, '智能问答', 'RAGFLOW', 'chat'),
            ('basic_excel_talk', 6, '智能数据', 'BASIC', 'excelTalk'),
            ('basic_question_talk', 7, '出题组卷', 'BASIC', 'questionTalk'),
            ('9d75142a-66eb-4e23-b7d4-03efe4584915', 8, '小数绘图', 'DIFY', 'imageTalk'),
            ('basic_paper_talk', 8, '文档出卷', 'BASIC', 'paperTalk'),
            ('basic_report_clean', 10, '文档报告', 'DIFY', 'reportWorkflow'),
        ]
        for agent_id, sort, name, agent_type, agent_kind in initial_agents:
            db.add(AgentModel(id=_to_uuid_form(agent_id), sort=sort, name=name,
                              agent_type=agent_type, type=agent_kind))
        db.commit()
        print("Initial agents inserted successfully")
    except IntegrityError:
        db.rollback()
        raise
    finally:
        db.close()


def sync_agents():
    """Copy BISHENG and RAGFLOW IDs onto local agents with matching names.

    Any exception is caught and reported with print; it is not re-raised.
    """
    try:
        bisheng_data = get_data_from_bisheng(BISHENG_NAMES_TO_SYNC)
        ragflow_data = get_data_from_ragflow(RAGFLOW_NAMES_TO_SYNC)
        update_ids_in_local(bisheng_data)
        update_ids_in_local(ragflow_data)
        print("Agents synchronized successfully")
    except Exception as e:
        print(f"Failed to sync agents: {str(e)}")


async def sync_web_menu():
    """Run dialog_menu_sync with a fresh local session and close it afterwards."""
    db = SessionLocal()
    try:
        await dialog_menu_sync(db)
    finally:
        # Session.close() is safe to call even if the callee already closed it.
        db.close()


async def sync_default_group():
    """Run dialog_menu_sync with a fresh local session and close it afterwards."""
    # NOTE(review): this is identical to sync_web_menu — confirm whether a
    # different sync routine (e.g. a default-group sync) was intended here.
    db = SessionLocal()
    try:
        await dialog_menu_sync(db)
    finally:
        db.close()