import json
|
|
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends, HTTPException, Query
|
import asyncio
|
import websockets
|
from sqlalchemy.orm import Session
|
from app.api import get_current_user_websocket, ResponseList, get_current_user, format_file_url, process_files
|
from app.config.config import settings
|
from app.models.agent_model import AgentModel, AgentType
|
from app.models.base_model import get_db
|
from app.models.user_model import UserModel
|
from app.service.bisheng import BishengService
|
from app.service.service_token import get_bisheng_token
|
|
# Router that holds the Bisheng report websocket and variable-list endpoints defined in this module.
router = APIRouter()
|
|
|
async def _reject_websocket(websocket: WebSocket, message: str) -> None:
    """Accept the client socket, send a ``type: close`` error frame, then close it.

    The socket must be accepted before anything can be sent on it; sending
    (and awaiting the send) is what actually delivers the error to the client.
    """
    await websocket.accept()
    await websocket.send_json({"message": message, "type": "close"})
    await websocket.close()


async def _close_client_websocket(websocket: WebSocket) -> None:
    """Close the client socket unless either side has already closed it."""
    from fastapi.websockets import WebSocketState

    if (websocket.application_state == WebSocketState.CONNECTED
            and websocket.client_state == WebSocketState.CONNECTED):
        try:
            await websocket.close()
        except (RuntimeError, OSError, WebSocketDisconnect) as e:
            # The peer may vanish between the state check and the close frame.
            print(f"Failed to close client websocket: {e}")


async def _forward_to_service(websocket: WebSocket, service_websocket, agent_id: str, chat_id: str) -> None:
    """Relay every JSON message from the client to Bisheng until an error ends the loop.

    Each message is tagged with ``flow_id`` (set to the agent id, which is also
    the flow id used in the Bisheng service URL) and ``chat_id``.
    Raises ``WebSocketDisconnect`` when the client goes away.
    """
    while True:
        message = await websocket.receive_json()
        print(f"Received from client, {chat_id}: {message}")
        message['flow_id'] = agent_id
        message['chat_id'] = chat_id
        await service_websocket.send(json.dumps(message))
        print(f"Forwarded to bisheng: {message}")


async def _forward_to_client(websocket: WebSocket, service_websocket, agent_id: str) -> None:
    """Relay Bisheng frames to the client, filtered and reshaped.

    A frame is sent as ``{"message", "type", "files"}`` when it carries files,
    non-``answer`` message text, or has ``type == "close"``. Attached files are
    passed through ``process_files`` first. Otherwise, ``intermediate_steps``
    are sent as ``{"step_message", ...}``, but only when the previous Bisheng
    frame carried no message text.
    Raises the websockets ``ConnectionClosed`` error when Bisheng closes.
    """
    last_message = "step"
    while True:
        raw = await service_websocket.recv()
        print(f"Received from bisheng: {raw}")
        data = json.loads(raw)

        files = data.get("files") or []  # tolerate a missing or null "files" field
        steps = data.get("intermediate_steps", "")
        msg = data.get("message", "")
        category = data.get("category", "")
        is_close = data.get("type") == "close"

        if files or (msg and category != "answer") or is_close:
            process_files(files, agent_id)
            result = {"message": msg, "type": "close" if is_close else "stream", "files": files}
            await websocket.send_json(result)
        elif steps and last_message == "step":
            result = {"step_message": steps, "type": "stream", "files": files}
            await websocket.send_json(result)

        last_message = "message" if msg else "step"


async def _bridge_sockets(websocket: WebSocket, service_websocket, agent_id: str, chat_id: str) -> None:
    """Pump both directions concurrently until one stops, then re-raise why it stopped."""
    tasks = [
        asyncio.create_task(_forward_to_service(websocket, service_websocket, agent_id, chat_id)),
        asyncio.create_task(_forward_to_client(websocket, service_websocket, agent_id)),
    ]
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Cancel the direction that is still running (also when we are cancelled ourselves).
        for task in tasks:
            if not task.done():
                task.cancel()
        # Wait for cancellation to finish and mark every task's outcome as retrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
    # Both pumps loop forever, so a finished task always carries an exception; surface it.
    for task in done:
        task.result()


@router.websocket("/ws/{agent_id}/{chat_id}")
async def report_chat(websocket: WebSocket,
                      agent_id: str,
                      chat_id: str,
                      current_user: UserModel = Depends(get_current_user_websocket),
                      db: Session = Depends(get_db)):
    """Proxy a client websocket to the Bisheng chat websocket for one agent/chat.

    Validation failures are reported to the client as a single
    ``{"message": ..., "type": "close"}`` frame, after which the socket is closed.
    The failures are: unknown agent, empty or ``"0"`` chat id, and an agent that
    is not ``AgentType.BISHENG``.

    Otherwise the Bisheng connection authenticates with the current user's
    Bisheng token, sent as an ``access_token_cookie`` cookie. Messages are then
    relayed both ways until either side disconnects, and the client socket is
    closed on the way out.
    """
    from urllib.parse import quote

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        await _reject_websocket(websocket, "Agent not found")
        return
    if chat_id in ("", "0"):
        await _reject_websocket(websocket, "Chat ID not found")
        return
    if agent.agent_type != AgentType.BISHENG:
        await _reject_websocket(websocket, "Agent error")
        return

    token = get_bisheng_token(db, current_user.id)
    # Path/query values come from the client URL; escape them so they cannot inject parameters.
    service_uri = (f"{settings.sgb_websocket_url}/api/v1/chat/{quote(agent_id, safe='')}"
                   f"?type=L1&t=&chat_id={quote(chat_id, safe='')}")
    headers = {'cookie': f"access_token_cookie={token};"}

    await websocket.accept()
    print(f"Client {agent_id} connected")

    try:
        # Connection failures are handled below, so the accepted client socket still gets closed.
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            await _bridge_sockets(websocket, service_websocket, agent_id, chat_id)
    except WebSocketDisconnect as e:
        print(f"WebSocket connection closed with code {e.code}: {e.reason}")
    except websockets.ConnectionClosed as e:
        print(f"Bisheng connection closed: {e}")
    except Exception as e:
        # Endpoint boundary: log and fall through to cleanup instead of crashing the handler.
        print(f"Exception occurred: {e}")
    finally:
        print("Cleaning up resources of bisheng report")
        await _close_client_websocket(websocket)
|
|
|
@router.get("/variables/list", response_model=ResponseList)
async def get_variables(agent_id: str = Query(..., description="The ID of the agent"),
                        db: Session = Depends(get_db),
                        current_user: UserModel = Depends(get_current_user)):
    """List the Bisheng variables configured for an agent.

    Returns a ``ResponseList`` with ``code=404`` when no agent matches
    ``agent_id``. Otherwise returns ``code=200`` with the variable list fetched
    from Bisheng using the current user's Bisheng token. Any error while getting
    the token or calling Bisheng becomes an HTTP 500 whose detail is the error
    text.
    """
    matched_agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not matched_agent:
        return ResponseList(code=404, msg="Agent not found")

    service = BishengService(base_url=settings.sgb_base_url)
    try:
        bisheng_token = get_bisheng_token(db, current_user.id)
        variables = await service.variable_list(bisheng_token, agent_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return ResponseList(code=200, msg="", data=variables)
|