zhangqian
2024-10-12 1963c42487b3980cb8513a2cc7669da0876c3037
app/api/chat.py
@@ -1,7 +1,7 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
@@ -9,7 +9,8 @@
from app.config.config import settings
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token
router = APIRouter()
@@ -27,21 +28,52 @@
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    elif agent_id == "1":
        agent_id = settings.ragflow_agent_id
        chat_id = settings.ragflow_chat_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex
    # 连接到服务端
    client_websockets[chat_id] = websocket
    if agent_id == settings.ragflow_agent_id:
        ragflow_service = RagflowService(settings.ragflow_base_url)
        token = get_ragflow_token(db, current_user.id)
        try:
            async def forward_to_ragflow():
                while True:
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    async for rag_response in ragflow_service.chat(token, chat_id, message["chatHistory"]):
                        print(f"Received from ragflow: {rag_response}")
                        json_str = rag_response[5:].strip()
                        json_data = json.loads(json_str)
                        if json_data.get("data") is not True:
                            answer = json_data.get("data", {}).get("answer", "")
                            result = {"message": answer, "type": "stream"}
                        else:
                            result = {"message": "", "type": "close"}
                        await websocket.send_json(result)
                        print(f"Forwarded to client {chat_id}: {result}")
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_ragflow())
            ]
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]
    else:
        token = get_bisheng_token(db, current_user.id)
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
        headers = {'cookie': f"access_token_cookie={token};"}
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
        try:
            # 处理客户端发来的消息
@@ -61,7 +93,6 @@
                    await service_websocket.send(json.dumps(message))
                    print(f"Forwarded to bisheng: {message}")
            # 监听毕昇发来的消息并转发给客户端
            async def forward_to_client():
                while True:
@@ -75,7 +106,6 @@
                asyncio.create_task(forward_to_service()),
                asyncio.create_task(forward_to_client())
            ]
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # 取消未完成的任务
@@ -90,3 +120,5 @@
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]