zhangqian
2024-10-12 1963c42487b3980cb8513a2cc7669da0876c3037
websocket对话接口兼容ragflow流式对话
6个文件已修改
175 ■■■■■ 已修改文件
app/api/auth.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 130 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.yaml 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/auth.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/ragflow.py 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/auth.py
@@ -1,9 +1,7 @@
from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.api import Response, pwd_context, oauth2_scheme, get_current_user
from app.api import Response, pwd_context
from app.config.config import settings
from app.models.base_model import get_db
from app.models.token_model import upsert_token
@@ -14,8 +12,6 @@
from app.service.ragflow import RagflowService
router = APIRouter()
@router.post("/register", response_model=Response)
app/api/chat.py
@@ -1,7 +1,7 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
@@ -9,7 +9,8 @@
from app.config.config import settings
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
from app.service.ragflow import RagflowService
from app.service.token import get_bisheng_token, get_ragflow_token
router = APIRouter()
@@ -27,66 +28,97 @@
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    elif agent_id == "1":
        agent_id = settings.ragflow_agent_id
        chat_id = settings.ragflow_chat_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex
    # 连接到服务端
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
    client_websockets[chat_id] = websocket
    if agent_id == settings.ragflow_agent_id:
        ragflow_service = RagflowService(settings.ragflow_base_url)
        token = get_ragflow_token(db, current_user.id)
        try:
            # 处理客户端发来的消息
            async def forward_to_service():
            async def forward_to_ragflow():
                while True:
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    # 添加 'agent_id' 和 'chat_id' 字段
                    message['flow_id'] = agent_id
                    message['chat_id'] = chat_id
                    msg = message["message"]
                    del message["message"]
                    message['inputs'] = {
                        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                        "input": msg
                    }
                    await service_websocket.send(json.dumps(message))
                    print(f"Forwarded to bisheng: {message}")
                    async for rag_response in ragflow_service.chat(token, chat_id, message["chatHistory"]):
                        print(f"Received from ragflow: {rag_response}")
                        json_str = rag_response[5:].strip()
                        json_data = json.loads(json_str)
                        if json_data.get("data") is not True:
                            answer = json_data.get("data", {}).get("answer", "")
                            result = {"message": answer, "type": "stream"}
                        else:
                            result = {"message": "", "type": "close"}
                        await websocket.send_json(result)
                        print(f"Forwarded to client {chat_id}: {result}")
            # 监听毕昇发来的消息并转发给客户端
            async def forward_to_client():
                while True:
                    message = await service_websocket.recv()
                    print(f"Received from service S: {message}")
                    await websocket.send_text(message)
                    print(f"Forwarded to client {chat_id}: {message}")
            # 启动两个任务,分别处理客户端和服务端的消息
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_service()),
                asyncio.create_task(forward_to_client())
                asyncio.create_task(forward_to_ragflow())
            ]
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # 取消未完成的任务
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]
    else:
        token = get_bisheng_token(db, current_user.id)
        service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
        headers = {'cookie': f"access_token_cookie={token};"}
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            try:
                # 处理客户端发来的消息
                async def forward_to_service():
                    while True:
                        message = await websocket.receive_json()
                        print(f"Received from client {chat_id}: {message}")
                        # 添加 'agent_id' 和 'chat_id' 字段
                        message['flow_id'] = agent_id
                        message['chat_id'] = chat_id
                        msg = message["message"]
                        del message["message"]
                        message['inputs'] = {
                            "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                            "input": msg
                        }
                        await service_websocket.send(json.dumps(message))
                        print(f"Forwarded to bisheng: {message}")
                # 监听毕昇发来的消息并转发给客户端
                async def forward_to_client():
                    while True:
                        message = await service_websocket.recv()
                        print(f"Received from service S: {message}")
                        await websocket.send_text(message)
                        print(f"Forwarded to client {chat_id}: {message}")
                # 启动两个任务,分别处理客户端和服务端的消息
                tasks = [
                    asyncio.create_task(forward_to_service()),
                    asyncio.create_task(forward_to_client())
                ]
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                # 取消未完成的任务
                for task in pending:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            except WebSocketDisconnect:
                print(f"Client {chat_id} disconnected")
            finally:
                del client_websockets[chat_id]
app/config/config.py
@@ -11,7 +11,8 @@
    PUBLIC_KEY: str
    PRIVATE_KEY: str
    bisheng_agent_id: str
    ragflow_agent_id: str
    ragflow_chat_id: str
    def __init__(self, **kwargs):
        # Check if all required fields are provided and set them
        for field in self.__annotations__.keys():
app/config/config.yaml
@@ -8,4 +8,6 @@
  MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB
  -----END PUBLIC KEY-----
PRIVATE_KEY: str
bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770
bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770
ragflow_agent_id: 690f42554ac84ed7b8bf7605db603b2f
ragflow_chat_id: e1d131a1b89b488e97c2194d9e2d345c
app/service/auth.py
@@ -35,7 +35,7 @@
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
        expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
app/service/ragflow.py
@@ -29,4 +29,30 @@
            )
            if response.status_code != 200:
                raise Exception(f"Ragflow login failed: {response.text}")
            return response.json().get('data', {}).get('access_token')
                # 从响应头中提取 Authorization 字段
            authorization = response.headers.get('Authorization')
            if not authorization:
                raise Exception("Authorization header not found in response")
            return authorization
    async def chat(self, token: str, chat_id: str, chat_history: list):
        data = {
            "conversation_id": chat_id,
            "messages": chat_history
        }
        target_url = f"{self.base_url}/v1/conversation/completion"
        async with httpx.AsyncClient() as client:
            headers = {
                'Content-Type': 'application/json',
                'Authorization': token
            }
            # 创建流式请求
            async with client.stream("POST", target_url, json=data, headers=headers) as response:
                # 检查响应状态码
                if response.status_code == 200:
                    # 流式读取响应
                    async for answer in response.aiter_text():
                        yield answer
                else:
                    yield f"Error: {response.status_code}"