import asyncio
import json
import logging
from enum import Enum

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState

from app.api import get_current_user_websocket
from app.config.config import settings
from app.models import UserModel, AgentModel
from app.models.base_model import get_db
from app.service.basic import BasicService

logger = logging.getLogger(__name__)

router = APIRouter()

# class CompletionRequest(BaseModel):
#     id: Optional[str] = None
#     app_id: str
#     message: str
#
# class DownloadRequest(BaseModel):
#     file_id: str
#     app_id: str
#     file_type: Optional[str] = None


class AdvancedAgentID(Enum):
    """Identifiers of the advanced "talk" agents."""

    # NOTE(review): not referenced by handle_client below, which always calls
    # excel_talk; presumably intended for dispatching on agent_id — confirm.
    EXCEL_TALK = "excel_talk"
    QUESTIONS_TALK = "questions_talk"


def _is_open(websocket: WebSocket) -> bool:
    """Return True while neither the client nor this app has closed *websocket*."""
    return (
        websocket.client_state == WebSocketState.CONNECTED
        and websocket.application_state == WebSocketState.CONNECTED
    )


@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(
    websocket: WebSocket,
    agent_id: str,
    chat_id: str,
    current_user: UserModel = Depends(get_current_user_websocket),
    db: Session = Depends(get_db),
):
    """Serve a chat session for one agent over a WebSocket.

    Protocol (JSON frames):
      * client -> server: an object with a non-empty ``"message"`` field.
      * server -> client:
          - ``{"message": <result>, "type": "response"}`` for each answer,
          - ``{"message": "Invalid request", "type": "error"}`` for a frame
            that is not valid JSON, not an object, or lacks ``"message"``
            (the session stays open),
          - ``{"message": "Agent not found", "type": "close"}`` if no
            ``AgentModel`` has ``id == agent_id`` (the socket is then closed),
          - ``{"message": <error text>, "type": "error"}`` on an unexpected
            failure, after which the socket is closed.

    Args:
        websocket: The incoming connection; accepted by this handler.
        agent_id: Looked up as ``AgentModel.id``.
        chat_id: Forwarded to ``BasicService.excel_talk`` with each question.
        current_user: Resolved by ``get_current_user_websocket``; unused in the
            body and presumably present only to enforce authentication — confirm.
        db: SQLAlchemy session used for the agent lookup.
    """
    await websocket.accept()
    logger.info("Client %s connected", agent_id)

    service = BasicService(base_url=settings.basic_base_url)

    try:
        # NOTE(review): synchronous ORM query executed on the event loop.
        agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
        if not agent:
            await websocket.send_json({"message": "Agent not found", "type": "close"})
            return  # the finally-block below closes the socket

        while True:
            # Receive a message from the frontend.
            try:
                message = await websocket.receive_json()
            except ValueError:
                # Malformed JSON (json.JSONDecodeError) or undecodable bytes:
                # reject this frame but keep the session alive.
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue

            question = message.get("message") if isinstance(message, dict) else None
            if not question:
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue

            # NOTE(review): always calls excel_talk regardless of agent_id.
            result = await service.excel_talk(question, chat_id)

            # Send the result back to the frontend.
            await websocket.send_json({"message": result, "type": "response"})
    except WebSocketDisconnect as exc:
        # The client closed the connection; there is nobody left to notify.
        logger.info("Client %s closed the connection (code=%s)", agent_id, exc.code)
    except Exception as exc:
        logger.exception("Error while serving client %s", agent_id)
        if _is_open(websocket):
            try:
                await websocket.send_json({"message": str(exc), "type": "error"})
            except Exception:
                # The socket may have gone away between the state check and the send.
                logger.debug("Could not deliver error to client %s", agent_id, exc_info=True)
    finally:
        if _is_open(websocket):
            try:
                await websocket.close()
            except Exception:
                logger.debug("Error closing websocket for client %s", agent_id, exc_info=True)
        logger.info("Client %s disconnected", agent_id)