zhangqian
2024-10-11 aa99acacfe3c21fbd638652f2fba1c1c62e3c414
websocket接口,转发毕昇对话
6个文件已修改
2个文件已添加
193 ■■■■■ 已修改文件
app/api/__init__.py 62 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/auth.py 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 92 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.yaml 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/auth.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/token.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/__init__.py
@@ -1,10 +1,70 @@
from fastapi import FastAPI
import jwt
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from app.models.user_model import UserModel
from app.service.auth import SECRET_KEY, ALGORITHM
# FastAPI application object created at import time of this package.
app = FastAPI()
# Password-hashing context restricted to the bcrypt scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 bearer scheme; get_current_user depends on it to obtain the raw token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Response envelope model: a status code, a message and a data payload.
# (Documented with comments rather than a docstring so the generated schema
# is left untouched.)
class Response(BaseModel):
    code: int = 200  # status code carried in the body; defaults to 200
    msg: str = ""  # message text; defaults to an empty string
    data: dict = {}  # payload; defaults to an empty dict
def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 response used for every authentication failure."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token of an HTTP request into a ``UserModel``.

    The token is decoded with ``SECRET_KEY`` / ``ALGORITHM`` and the user is
    built from its ``sub`` (username) and ``user_id`` claims.

    Args:
        token: Raw bearer token supplied by ``oauth2_scheme``.

    Returns:
        A ``UserModel`` carrying the username and id found in the token.

    Raises:
        HTTPException: 401 when the token cannot be decoded, or when the
            ``sub`` or ``user_id`` claim is missing (or ``user_id`` is 0).
    """
    # Only the decode call can raise PyJWTError, so keep the try minimal.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise _unauthorized("令牌无效或已过期")

    username: str = payload.get("sub")
    if username is None:
        raise _unauthorized("无法验证凭证")

    # Tokens without a "user_id" claim decode to None here; reject them
    # together with the explicit 0 instead of returning a user with no id.
    user_id = payload.get("user_id")
    if user_id is None or user_id == 0:
        raise _unauthorized("用户不存在")

    return UserModel(username=username, id=user_id)
async def _reject_websocket(websocket: WebSocket):
    """Close *websocket* with code 1008 and raise ``WebSocketDisconnect``.

    This coroutine never returns normally.
    """
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)


async def get_current_user_websocket(websocket: WebSocket):
    """Authenticate a WebSocket connection from its ``Authorization`` header.

    Expects ``Authorization: Bearer <jwt>``. The token is decoded with
    ``SECRET_KEY`` / ``ALGORITHM`` and the user is built from its ``sub``
    and ``user_id`` claims.

    Args:
        websocket: The incoming WebSocket connection.

    Returns:
        A ``UserModel`` carrying the username and id found in the token.

    Raises:
        WebSocketDisconnect: With code 1008, after closing the socket, when
            the header is missing or malformed, the token cannot be decoded,
            or the ``sub`` / ``user_id`` claim is missing (or ``user_id``
            is 0).
    """
    auth_header = websocket.headers.get('Authorization')
    if auth_header is None or not auth_header.startswith('Bearer '):
        await _reject_websocket(websocket)

    token = auth_header[len('Bearer '):]
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        print(e)
        await _reject_websocket(websocket)

    username: str = payload.get("sub")
    # A constructed UserModel is never None, so validate the claims
    # themselves before building the user.
    user_id = payload.get("user_id")
    if username is None or user_id is None or user_id == 0:
        await _reject_websocket(websocket)

    return UserModel(username=username, id=user_id)
app/api/auth.py
@@ -1,16 +1,13 @@
from typing import Dict
import json
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from app.api import Response
from app.api import Response, pwd_context, oauth2_scheme, get_current_user
from app.config.config import settings
from app.models.base_model import get_db
from app.models.token_model import upsert_token
from app.models.user import User, UserCreate, LoginData
from app.models.user import UserCreate, LoginData
from app.models.user_model import UserModel
from app.service.auth import authenticate_user, create_access_token
from app.service.bisheng import BishengService
@@ -18,8 +15,7 @@
router = APIRouter()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@router.post("/register", response_model=Response)
@@ -74,7 +70,7 @@
        return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}")
    # 创建本地token
    access_token = create_access_token(data={"sub": user.username})
    access_token = create_access_token(data={"sub": user.username, "user_id": user.id})
    upsert_token(db, user.id, access_token, bisheng_token, ragflow_token)
app/api/chat.py
New file
@@ -0,0 +1,92 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
# Router that collects the chat WebSocket route defined in this module.
router = APIRouter()
# Client WebSocket connections, keyed by chat_id (filled in by handle_client).
client_websockets = {}
# 中间层WebSocket 服务器,接收客户端的连接
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Relay a client WebSocket conversation to the Bisheng assistant chat.

    Accepts the client connection, opens a second WebSocket to
    ``settings.bisheng_websocket_url`` authenticated with the user's stored
    Bisheng token, and then forwards messages in both directions until either
    side stops.

    Args:
        websocket: The client connection.
        agent_id: Assistant id; ``"0"`` selects ``settings.bisheng_agent_id``.
        chat_id: Conversation id; ``"0"`` generates a fresh random id.
        current_user: User resolved by ``get_current_user_websocket``.
        db: Database session used to look up the Bisheng token.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    if not token:
        # Without a stored token the upstream cookie would literally be
        # "access_token_cookie=None;", so refuse the session instead.
        print(f"No bisheng token for user {current_user.id}")
        await websocket.close(code=1008)
        return

    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex

    # Connect to the upstream Bisheng service.
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:

        async def forward_to_service():
            """Reshape each client message and send it to Bisheng."""
            while True:
                message = await websocket.receive_json()
                print(f"Received from client {chat_id}: {message}")
                # Attach routing fields and move the user's text into 'inputs'.
                message['flow_id'] = agent_id
                message['chat_id'] = chat_id
                msg = message.pop("message")
                message['inputs'] = {
                    "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                    "input": msg
                }
                await service_websocket.send(json.dumps(message))
                print(f"Forwarded to bisheng: {message}")

        async def forward_to_client():
            """Pass every Bisheng message through to the client unchanged."""
            while True:
                message = await service_websocket.recv()
                print(f"Received from service S: {message}")
                await websocket.send_text(message)
                print(f"Forwarded to client {chat_id}: {message}")

        client_websockets[chat_id] = websocket
        # One task per direction; the relay ends when either of them ends.
        tasks = [
            asyncio.create_task(forward_to_service()),
            asyncio.create_task(forward_to_client())
        ]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # Exceptions raised inside a task stay in the task object, so
            # they have to be read here to be reported at all.
            for task in done:
                exc = task.exception()
                if isinstance(exc, WebSocketDisconnect):
                    print(f"Client {chat_id} disconnected")
                elif exc is not None:
                    print(f"Relay for chat {chat_id} stopped: {exc!r}")
        finally:
            # Stop whatever is still running (also when this handler itself
            # is cancelled) and collect results so nothing is left unretrieved.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            # Only drop the registry entry if it is still this connection;
            # another client may have reused the same chat_id.
            if client_websockets.get(chat_id) is websocket:
                del client_websockets[chat_id]
app/config/config.py
@@ -6,10 +6,12 @@
class Settings:
    secret_key: str = ''
    bisheng_base_url: str = ''
    bisheng_websocket_url: str = ''
    ragflow_base_url: str = ''
    database_url: str = ''
    PUBLIC_KEY: str
    PRIVATE_KEY: str
    bisheng_agent_id: str
    def __init__(self, **kwargs):
        # Check if all required fields are provided and set them
app/config/config.yaml
@@ -1,9 +1,11 @@
secret_key: your-secret-key
bisheng_base_url: http://192.168.20.119:13001
bisheng_websocket_url: ws://192.168.20.119:13001
ragflow_base_url: http://192.168.20.119:11080
database_url: mysql+pymysql://root:infini_rag_flow@192.168.20.116:5455/rag_basic
PUBLIC_KEY: |
  -----BEGIN PUBLIC KEY-----
  MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB
  -----END PUBLIC KEY-----
PRIVATE_KEY: str
PRIVATE_KEY: str
bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770
app/service/auth.py
@@ -8,7 +8,7 @@
SECRET_KEY = settings.secret_key
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
ACCESS_TOKEN_EXPIRE_MINUTES = 3000
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
app/service/token.py
New file
@@ -0,0 +1,15 @@
from app.models.token_model import TokenModel
def get_bisheng_token(db, user_id: int):
    """Return the Bisheng token stored for *user_id*.

    Args:
        db: Object exposing ``query(...)``; presumably a SQLAlchemy session
            (confirm against the caller).
        user_id: Id matched against ``TokenModel.user_id``.

    Returns:
        The ``bisheng_token`` of the first matching row, or ``None`` when no
        row matches.
    """
    record = (
        db.query(TokenModel)
        .filter(TokenModel.user_id == user_id)
        .first()
    )
    return record.bisheng_token if record else None
def get_ragflow_token(db, user_id: int):
    """Return the Ragflow token stored for *user_id*.

    Args:
        db: Object exposing ``query(...)``; presumably a SQLAlchemy session
            (confirm against the caller).
        user_id: Id matched against ``TokenModel.user_id``.

    Returns:
        The ``ragflow_token`` of the first matching row, or ``None`` when no
        row matches.
    """
    matches = db.query(TokenModel).filter(TokenModel.user_id == user_id)
    record = matches.first()
    if record:
        return record.ragflow_token
    return None
main.py
@@ -1,5 +1,6 @@
from fastapi import FastAPI
from app.api.auth import router as auth_router
from app.api.chat import router as chat_router
from app.models.base_model import init_db
init_db()
@@ -10,6 +11,7 @@
)
app.include_router(auth_router, prefix='/auth', tags=["auth"])
app.include_router(chat_router, prefix='/chat', tags=["chat"])
if __name__ == "__main__":
    import uvicorn