import asyncio
|
import json
|
from enum import Enum
|
|
from fastapi import APIRouter, Depends
|
from sqlalchemy.orm import Session
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
from app.api import get_current_user_websocket
|
from app.config.config import settings
|
from app.models import UserModel, AgentModel
|
from app.models.base_model import get_db
|
from app.service.basic import BasicService
|
|
router = APIRouter()
|
|
# class CompletionRequest(BaseModel):
|
# id: Optional[str] = None
|
# app_id: str
|
# message: str
|
#
|
# class DownloadRequest(BaseModel):
|
# file_id: str
|
# app_id: str
|
# file_type: Optional[str] = None
|
|
|
class AdvancedAgentID(Enum):
    """String identifiers for the advanced agent kinds.

    NOTE(review): this enum is not referenced in the visible part of the
    module; presumably its values are compared against the ``agent_id``
    path parameter somewhere else -- confirm before relying on it.
    """

    # Value used for the Excel conversation agent.
    EXCEL_TALK = "excel_talk"
    # Value used for the questions conversation agent.
    QUESTIONS_TALK = "questions_talk"
|
|
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Serve one chat session over a WebSocket.

    The connection is accepted, the agent row for ``agent_id`` is looked up,
    and then every incoming JSON message is answered by forwarding its
    ``"message"`` field (together with ``chat_id``) to
    ``BasicService.excel_talk`` and sending the result back.

    Outgoing frames are JSON objects of the form
    ``{"message": ..., "type": ...}`` where ``type`` is one of
    ``"response"``, ``"error"`` or ``"close"``.

    Args:
        websocket: The client connection.
        agent_id: Primary key of the ``AgentModel`` row to talk to.
        chat_id: Conversation identifier passed through to the service.
        current_user: Resolved by ``get_current_user_websocket``. It is not
            read in this body; presumably the dependency exists to enforce
            authentication -- confirm against its implementation.
        db: SQLAlchemy session used for the agent lookup.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    service = BasicService(base_url=settings.basic_base_url)

    # NOTE(review): this is a synchronous ORM query inside an async handler;
    # it may block the event loop while the database responds.
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()

    # Set once we know the peer is gone, so that the cleanup code does not
    # try to talk to a socket that is already closed.
    peer_gone = False
    try:
        if not agent:
            # Tell the client why the session ends; the ``finally`` clause
            # below closes the socket (the original returned without
            # closing it).
            await websocket.send_json({"message": "Agent not found", "type": "close"})
            return

        while True:
            # Receive the next message from the front end.
            try:
                message = await websocket.receive_json()
            except json.JSONDecodeError:
                # Malformed JSON is a client error, not a reason to tear
                # the whole session down.
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue

            # Valid JSON is not necessarily an object (it may be a list, a
            # number, ...); only dictionaries can carry a "message" field.
            question = message.get("message") if isinstance(message, dict) else None
            if not question:
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue

            # Delegate the question to the excel_talk service call.
            result = await service.excel_talk(question, chat_id)

            # Send the answer back to the front end.
            await websocket.send_json({"message": result, "type": "response"})
    except WebSocketDisconnect:
        # Normal end of session: the client went away. There is nobody left
        # to send an error to and nothing left to close.
        peer_gone = True
    except Exception as e:
        try:
            await websocket.send_json({"message": str(e), "type": "error"})
        except (WebSocketDisconnect, RuntimeError):
            # The failure was (or coincided with) a dead connection; the
            # error report cannot be delivered.
            peer_gone = True
    finally:
        if not peer_gone:
            try:
                await websocket.close()
            except RuntimeError:
                # Best effort: the socket was already closed underneath us.
                pass
        print(f"Client {agent_id} disconnected")
|