zhangxiao
2024-10-16 30311881800e4840a13f13dd702b093543b2082e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import json
import uuid
 
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token
 
# FastAPI router for this module; `handle_client` registers its WebSocket
# route on it via `@router.websocket`. Presumably included into the app
# elsewhere — verify against the application setup.
router = APIRouter()
 
 
async def _send_close(websocket: WebSocket, message: str) -> None:
    """Send a terminal ``{"message": ..., "type": "close"}`` frame, then close the socket.

    ``send_json`` is a coroutine and must be awaited; returning it un-awaited
    means the frame is never actually sent.
    """
    await websocket.send_json({"message": message, "type": "close"})
    await websocket.close()


async def _reap_proxy_tasks(done, pending, chat_id: str) -> None:
    """Cancel still-running relay tasks and report how the finished ones ended.

    ``asyncio.wait`` does not raise the exceptions of finished tasks; they stay
    stored on the task. They must be fetched explicitly, otherwise a client
    ``WebSocketDisconnect`` (or any other failure) is silently lost and asyncio
    later logs "Task exception was never retrieved".

    :param done: tasks reported as finished by ``asyncio.wait``.
    :param pending: tasks still running, which are cancelled here.
    :param chat_id: used only in log output.
    """
    for task in pending:
        task.cancel()
    # return_exceptions=True absorbs each cancelled task's CancelledError.
    await asyncio.gather(*pending, return_exceptions=True)

    for task in done:
        if task.cancelled():
            continue
        exc = task.exception()
        if isinstance(exc, WebSocketDisconnect):
            print(f"Client {chat_id} disconnected")
        elif exc is not None:
            print(f"Proxy task for chat {chat_id} stopped: {exc!r}")


async def _proxy_ragflow_chat(websocket: WebSocket, db: Session,
                              current_user: UserModel, chat_id: str) -> None:
    """Relay client questions to RAGFlow and stream the answers back.

    Each client JSON message must contain a ``chatHistory`` field, which is
    passed to ``RagflowService.chat``. Each streamed item has its first five
    characters dropped (presumably an SSE ``data:`` prefix — confirm against
    ``RagflowService``) and the remainder parsed as JSON:

    * ``data is True``  -> generation finished, a ``close`` frame is sent;
    * ``data is None``  -> RAGFlow error, ``retmsg``/``retcode`` is streamed;
    * otherwise         -> ``data["answer"]`` is streamed.

    Returns when the client disconnects or the relay fails; failures are logged.
    """
    ragflow_service = RagflowService(settings.ragflow_base_url)
    token = get_ragflow_token(db, current_user.id)
    try:
        while True:
            message = await websocket.receive_json()
            print(f"Received from client {chat_id}: {message}")
            async for rag_response in ragflow_service.chat(token, chat_id, message["chatHistory"]):
                try:
                    print(f"Received from ragflow: {rag_response}")
                    json_data = json.loads(rag_response[5:].strip())
                    data = json_data.get("data")
                    if data is True:  # generation finished
                        result = {"message": "", "type": "close"}
                    elif data is None:  # RAGFlow reported an error
                        answer = json_data.get("retmsg", json_data.get("retcode"))
                        # f-string: retcode may be a non-str (e.g. int) that "+" cannot join.
                        result = {"message": f"内部错误:{answer}", "type": "stream"}
                    else:  # normal partial answer
                        result = {"message": data.get("answer", ""), "type": "stream"}
                    await websocket.send_json(result)
                    print(f"Forwarded to client {chat_id}: {result}")
                except WebSocketDisconnect:
                    # The client is gone: do not try to send it an error frame.
                    raise
                except Exception as e:
                    result = {"message": f"内部错误: {e}", "type": "close"}
                    await websocket.send_json(result)
                    print(f"Error process message of ragflow: {e}")
    except WebSocketDisconnect:
        print(f"Client {chat_id} disconnected")
    except Exception as e:
        # Keep the previous "log and end the session" outcome, but make it visible.
        print(f"Ragflow proxy for chat {chat_id} stopped: {e!r}")


async def _proxy_bisheng_chat(websocket: WebSocket, db: Session, current_user: UserModel,
                              agent_id: str, chat_id: str) -> None:
    """Bridge the client socket to Bisheng's assistant chat WebSocket.

    Client -> Bisheng: sets ``flow_id``/``chat_id`` and moves the client's
    ``message`` field into an ``inputs`` envelope before forwarding.
    Bisheng -> client: frames of type ``close``, ``stream`` or ``end_cover``
    are forwarded as ``{"message", "type"}`` (``end_cover`` is relabelled
    ``stream``); all other frames are dropped.

    The relay stops as soon as either direction ends; the other direction is
    cancelled and the Bisheng connection is closed.
    """
    token = get_bisheng_token(db, current_user.id)
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {'cookie': f"access_token_cookie={token};"}

    try:
        service_websocket = await websockets.connect(service_uri, extra_headers=headers)
    except (OSError, asyncio.TimeoutError, websockets.WebSocketException) as e:
        # Tell the client instead of ending the session without any message.
        print(f"Failed to connect to bisheng for chat {chat_id}: {e!r}")
        await _send_close(websocket, f"内部错误: {e}")
        return

    async def forward_to_service():
        """Forward client messages to Bisheng, reshaped into its input format."""
        while True:
            message = await websocket.receive_json()
            print(f"Received from client, {chat_id}: {message}")
            # Bisheng identifies the assistant via 'flow_id' and the session via 'chat_id'.
            message['flow_id'] = agent_id
            message['chat_id'] = chat_id
            msg = message.pop("message")
            message['inputs'] = {
                "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                "input": msg
            }
            await service_websocket.send(json.dumps(message))
            print(f"Forwarded to bisheng: {message}")

    async def forward_to_client():
        """Forward Bisheng's close/stream/end_cover frames to the client."""
        while True:
            message = await service_websocket.recv()
            print(f"Received from bisheng: {message}")
            data = json.loads(message)
            frame_type = data.get("type")
            if frame_type not in ("close", "stream", "end_cover"):
                continue  # other frame types are not relayed
            t = "close" if frame_type == "close" else "stream"
            result = {"message": data["message"], "type": t}
            await websocket.send_json(result)
            print(f"Forwarded to client, {chat_id}: {result}")

    try:
        # Run both directions; whichever ends first stops the relay.
        tasks = [
            asyncio.create_task(forward_to_service()),
            asyncio.create_task(forward_to_client()),
        ]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        await _reap_proxy_tasks(done, pending, chat_id)
    finally:
        await service_websocket.close()


# Middle-layer WebSocket server: accepts client connections and proxies them to
# the agent backend (RAGFlow or Bisheng) selected by the agent's type.
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Accept a client WebSocket and proxy it to the agent's chat backend.

    :param websocket: the client connection.
    :param agent_id: id of the ``AgentModel`` to chat with.
    :param chat_id: backend chat/session id; ``"0"`` (or empty) is rejected.
    :param current_user: authenticated user, used to fetch backend tokens.
    :param db: SQLAlchemy session used for the agent lookup and token fetch.

    On an unknown agent, an invalid chat id or an unsupported agent type, a
    ``{"message": ..., "type": "close"}`` frame is sent and the socket closed.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        await _send_close(websocket, "Agent not found")
        return
    if chat_id in ("", "0"):
        await _send_close(websocket, "Chat ID not found")
        return

    if agent.agent_type == AgentType.RAGFLOW:
        await _proxy_ragflow_chat(websocket, db, current_user, chat_id)
    elif agent.agent_type == AgentType.BISHENG:
        await _proxy_bisheng_chat(websocket, db, current_user, agent_id, chat_id)
    else:
        await _send_close(websocket, "Agent not found")