zhaoqingang
2024-11-21 ae30d9a75407c912649f11c4f44ff15c869a4f98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import asyncio
import json
from enum import Enum
 
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.websockets import WebSocket, WebSocketDisconnect
 
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models import UserModel, AgentModel
from app.models.base_model import get_db
from app.service.basic import BasicService
 
# Router that carries the agent chat WebSocket endpoint defined in this module.
router = APIRouter()
 
 
class AdvancedAgentID(Enum):
    """String identifiers for the "advanced" agents.

    Look a member up by its raw identifier with ``AdvancedAgentID("excel_talk")``.

    NOTE(review): nothing in this module currently consults this enum; the
    WebSocket handler always routes to ``excel_talk`` — confirm intended use.
    """

    # Identifier string "excel_talk".
    EXCEL_TALK = "excel_talk"
    # Identifier string "questions_talk".
    QUESTIONS_TALK = "questions_talk"

 
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Serve a question/answer chat session for one agent over a WebSocket.

    After accepting the connection, the agent is looked up by
    ``AgentModel.id == agent_id``. If no row matches, a
    ``{"message": "Agent not found", "type": "close"}`` message is sent and
    the socket is closed. Otherwise the handler loops: each incoming JSON
    object must carry a non-empty ``"message"`` field. That field is forwarded
    to ``BasicService.excel_talk`` together with ``chat_id``, and the result is
    sent back as ``{"message": result, "type": "response"}``.

    Invalid payloads get an ``{"message": "Invalid request", "type": "error"}``
    reply and the loop continues. Invalid means malformed JSON, JSON that is
    not an object, or a missing or empty ``"message"`` field. A client
    disconnect ends the session quietly. Any other exception is reported to
    the client as ``{"message": str(e), "type": "error"}`` (best effort), and
    then the socket is closed.

    Args:
        websocket: The client connection.
        agent_id: Value matched against ``AgentModel.id``.
        chat_id: Conversation identifier passed through to ``excel_talk``.
        current_user: User resolved by the ``get_current_user_websocket``
            dependency (not otherwise used in this handler).
        db: SQLAlchemy session used for the agent lookup.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        # Close explicitly instead of relying on the server to tear the
        # connection down when the handler returns.
        await websocket.close()
        return

    # NOTE(review): current_user is not checked against agent_id / chat_id;
    # confirm whether an ownership/permission check is required here.
    # NOTE(review): every agent is routed to excel_talk regardless of its
    # type; confirm this is intended.
    service = BasicService(base_url=settings.basic_base_url)

    invalid_request = {"message": "Invalid request", "type": "error"}
    # False once the connection is known to be gone, so `finally` does not
    # try to close a socket the client has already torn down.
    connected = True
    try:
        while True:
            # Receive the next request from the frontend.
            try:
                message = await websocket.receive_json()
            except json.JSONDecodeError:
                # A malformed frame only invalidates this request, not the session.
                await websocket.send_json(invalid_request)
                continue

            # Non-object JSON (list, string, number) has no "message" field.
            question = message.get("message") if isinstance(message, dict) else None
            if not question:
                await websocket.send_json(invalid_request)
                continue

            result = await service.excel_talk(question, chat_id)

            # Send the service result back to the frontend.
            await websocket.send_json({"message": result, "type": "response"})
    except WebSocketDisconnect:
        # The client closed the connection; sending or closing now would fail.
        connected = False
    except Exception as e:
        # NOTE(review): str(e) may expose internal error details to the client.
        try:
            await websocket.send_json({"message": str(e), "type": "error"})
        except Exception:
            # Best effort: the connection may already be unusable.
            connected = False
    finally:
        if connected:
            await websocket.close()
        print(f"Client {agent_id} disconnected")