zhangqian
2024-10-12 1963c42487b3980cb8513a2cc7669da0876c3037
app/api/chat.py
@@ -1,7 +1,7 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
@@ -9,7 +9,8 @@
from app.config.config import settings
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token
router = APIRouter()
@@ -27,66 +28,97 @@
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    elif agent_id == "1":
        agent_id = settings.ragflow_agent_id
        chat_id = settings.ragflow_chat_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex
    # 连接到服务端
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
    client_websockets[chat_id] = websocket
    if agent_id == settings.ragflow_agent_id:
        ragflow_service = RagflowService(settings.ragflow_base_url)
        token = get_ragflow_token(db, current_user.id)
        try:
            # 处理客户端发来的消息
            async def forward_to_service():
            async def forward_to_ragflow():
                while True:
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    # 添加 'agent_id' 和 'chat_id' 字段
                    message['flow_id'] = agent_id
                    message['chat_id'] = chat_id
                    msg = message["message"]
                    del message["message"]
                    message['inputs'] = {
                        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                        "input": msg
                    }
                    await service_websocket.send(json.dumps(message))
                    print(f"Forwarded to bisheng: {message}")
                    async for rag_response in ragflow_service.chat(token, chat_id, message["chatHistory"]):
                        print(f"Received from ragflow: {rag_response}")
                        json_str = rag_response[5:].strip()
                        json_data = json.loads(json_str)
                        if json_data.get("data") is not True:
                            answer = json_data.get("data", {}).get("answer", "")
                            result = {"message": answer, "type": "stream"}
                        else:
                            result = {"message": "", "type": "close"}
                        await websocket.send_json(result)
                        print(f"Forwarded to client {chat_id}: {result}")
            # 监听毕昇发来的消息并转发给客户端
            async def forward_to_client():
                while True:
                    message = await service_websocket.recv()
                    print(f"Received from service S: {message}")
                    await websocket.send_text(message)
                    print(f"Forwarded to client {chat_id}: {message}")
            # 启动两个任务,分别处理客户端和服务端的消息
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_service()),
                asyncio.create_task(forward_to_client())
                asyncio.create_task(forward_to_ragflow())
            ]
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # 取消未完成的任务
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]
    else:
        token = get_bisheng_token(db, current_user.id)
        service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
        headers = {'cookie': f"access_token_cookie={token};"}
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            try:
                # 处理客户端发来的消息
                async def forward_to_service():
                    while True:
                        message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {message}")
                        # 添加 'agent_id' 和 'chat_id' 字段
                        message['flow_id'] = agent_id
                        message['chat_id'] = chat_id
                        msg = message["message"]
                        del message["message"]
                        message['inputs'] = {
                            "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                            "input": msg
                        }
                        await service_websocket.send(json.dumps(message))
                        print(f"Forwarded to bisheng: {message}")
                # 监听毕昇发来的消息并转发给客户端
                async def forward_to_client():
                    while True:
                        message = await service_websocket.recv()
                        print(f"Received from service S: {message}")
                        await websocket.send_text(message)
                        print(f"Forwarded to client {chat_id}: {message}")
                # 启动两个任务,分别处理客户端和服务端的消息
                tasks = [
                    asyncio.create_task(forward_to_service()),
                    asyncio.create_task(forward_to_client())
                ]
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                # 取消未完成的任务
                for task in pending:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            except WebSocketDisconnect:
                print(f"Client {chat_id} disconnected")
            finally:
                del client_websockets[chat_id]