zhangqian
2024-10-16 3fc9f4f33cf90610c71a1de7b00db0f82b988e98
app/api/chat.py
@@ -1,20 +1,19 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token
# APIRouter instance created for this module.
router = APIRouter()
# Registry of client WebSocket connections; the handler below stores each
# connection under its chat_id.
client_websockets = {}
# 中间层WebSocket 服务器,接收客户端的连接
@@ -27,66 +26,123 @@
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    agent_type = agent.agent_type
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
        return
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex
    # 连接到服务端
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
    if agent_type == AgentType.RAGFLOW:
        ragflow_service = RagflowService(settings.ragflow_base_url)
        token = get_ragflow_token(db, current_user.id)
        try:
            # 处理客户端发来的消息
            async def forward_to_service():
            async def forward_to_ragflow():
                while True:
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    # 添加 'agent_id' 和 'chat_id' 字段
                    message['flow_id'] = agent_id
                    message['chat_id'] = chat_id
                    msg = message["message"]
                    del message["message"]
                    message['inputs'] = {
                        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                        "input": msg
                    }
                    await service_websocket.send(json.dumps(message))
                    print(f"Forwarded to bisheng: {message}")
                    chat_history = message.get('chatHistory', [])
                    if len(chat_history) == 0:
            # 监听毕昇发来的消息并转发给客户端
            async def forward_to_client():
                while True:
                    message = await service_websocket.recv()
                    print(f"Received from service S: {message}")
                    await websocket.send_text(message)
                    print(f"Forwarded to client {chat_id}: {message}")
            # 启动两个任务,分别处理客户端和服务端的消息
                        chat_history = await ragflow_service.set_session(token, agent_id, message["message"], chat_id, True)
                        if len(chat_history) == 0:
                            result = {"message": "内部错误:创建会话失败", "type": "close"}
                            await websocket.send_json(result)
                    async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
                        try:
                            print(f"Received from ragflow: {rag_response}")
                            if rag_response[:5] == "data:":
                                # 如果是,则截取掉前5个字符,并去除首尾空白符
                                json_str = rag_response[5:].strip()
                            else:
                                # 否则,保持原样
                                json_str = rag_response
                            json_data = json.loads(json_str)
                            data = json_data.get("data")
                            if data is True:  # 完成输出
                                result = {"message": "", "type": "close"}
                            elif data is None:  # 发生错误
                                answer = json_data.get("retmsg", json_data.get("retcode"))
                                result = {"message": "内部错误:" + answer, "type": "stream"}
                            else:  # 正常输出
                                answer = data.get("answer", "")
                                result = {"message": answer, "type": "stream"}
                            await websocket.send_json(result)
                            print(f"Forwarded to client {chat_id}: {result}")
                        except Exception as e:
                            result = {"message": f"内部错误: {e}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of ragflow: {e}")
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_service()),
                asyncio.create_task(forward_to_client())
                asyncio.create_task(forward_to_ragflow())
            ]
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # 取消未完成的任务
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]
    elif agent_type == AgentType.BISHENG:
        token = get_bisheng_token(db, current_user.id)
        service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
        headers = {'cookie': f"access_token_cookie={token};"}
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            try:
                # 处理客户端发来的消息
                async def forward_to_service():
                    while True:
                        message = await websocket.receive_json()
                        print(f"Received from client, {chat_id}: {message}")
                        # 添加 'agent_id' 和 'chat_id' 字段
                        message['flow_id'] = agent_id
                        message['chat_id'] = chat_id
                        msg = message["message"]
                        del message["message"]
                        message['inputs'] = {
                            "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                            "input": msg
                        }
                        await service_websocket.send(json.dumps(message))
                        print(f"Forwarded to bisheng: {message}")
                # 监听毕昇发来的消息并转发给客户端
                async def forward_to_client():
                    while True:
                        message = await service_websocket.recv()
                        print(f"Received from bisheng: {message}")
                        data = json.loads(message)
                        if data["type"] == "close" or data["type"] == "stream" or data["type"] == "end_cover":
                            if data["type"] == "close":
                                t = "close"
                            else:
                                t = "stream"
                            result = {"message": data["message"], "type": t}
                            await websocket.send_json(result)
                            print(f"Forwarded to client, {chat_id}: {result}")
                # 启动两个任务,分别处理客户端和服务端的消息
                tasks = [
                    asyncio.create_task(forward_to_service()),
                    asyncio.create_task(forward_to_client())
                ]
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                # 取消未完成的任务
                for task in pending:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            except WebSocketDisconnect:
                print(f"Client {chat_id} disconnected")
    else:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)