From aa99acacfe3c21fbd638652f2fba1c1c62e3c414 Mon Sep 17 00:00:00 2001 From: zhangqian <zhangqian@123.com> Date: 星期五, 11 十月 2024 21:38:55 +0800 Subject: [PATCH] websocket接口,转发毕昇对话 --- app/config/config.py | 2 app/api/chat.py | 92 +++++++++++++++++++++++ main.py | 2 app/service/token.py | 15 +++ app/config/config.yaml | 4 app/api/__init__.py | 62 +++++++++++++++ app/api/auth.py | 14 +-- app/service/auth.py | 2 8 files changed, 181 insertions(+), 12 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index bcd5c2a..5ddd5aa 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -1,10 +1,70 @@ -from fastapi import FastAPI +import jwt +from fastapi import FastAPI, Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from passlib.context import CryptContext from pydantic import BaseModel +from starlette import status +from starlette.websockets import WebSocket, WebSocketDisconnect + +from app.models.user_model import UserModel +from app.service.auth import SECRET_KEY, ALGORITHM app = FastAPI() + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") class Response(BaseModel): code: int = 200 msg: str = "" data: dict = {} + + +def get_current_user(token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="鏃犳硶楠岃瘉鍑瘉", + headers={"WWW-Authenticate": "Bearer"}, + ) + user = UserModel(username=username, id=payload.get("user_id")) + if user.id == 0: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="鐢ㄦ埛涓嶅瓨鍦�", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="浠ょ墝鏃犳晥鎴栧凡杩囨湡", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def get_current_user_websocket(websocket: WebSocket): + auth_header = websocket.headers.get('Authorization') + if auth_header is None or not auth_header.startswith('Bearer '): + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + token = auth_header[len('Bearer '):] + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + user = UserModel(username=username, id=payload.get("user_id")) + if user is None: + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + return user + except jwt.PyJWTError as e: + print(e) + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) diff --git a/app/api/auth.py b/app/api/auth.py index 86da42f..feef6d9 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,16 +1,13 @@ -from typing import Dict -import json - -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, Request from fastapi.security import OAuth2PasswordBearer from passlib.context import CryptContext from sqlalchemy.orm import Session -from app.api import Response +from app.api import Response, pwd_context, oauth2_scheme, get_current_user from app.config.config import settings from app.models.base_model import get_db from app.models.token_model import upsert_token -from app.models.user import User, UserCreate, LoginData +from app.models.user import UserCreate, LoginData from app.models.user_model import UserModel from app.service.auth import authenticate_user, create_access_token from app.service.bisheng import BishengService @@ -18,8 +15,7 @@ router = APIRouter() -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + @router.post("/register", response_model=Response) @@ -74,7 +70,7 @@ return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}") # 鍒涘缓鏈湴token - access_token = create_access_token(data={"sub": user.username}) + access_token = create_access_token(data={"sub": user.username, "user_id": user.id}) upsert_token(db, user.id, access_token, bisheng_token, ragflow_token) diff --git a/app/api/chat.py b/app/api/chat.py new file mode 100644 index 0000000..c7aa2da --- /dev/null +++ b/app/api/chat.py @@ -0,0 +1,92 @@ +import json +import uuid + +from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends +import asyncio +import websockets +from sqlalchemy.orm import Session +from app.api import get_current_user_websocket +from app.config.config import settings +from app.models.base_model import get_db +from app.models.user_model import UserModel +from app.service.token import get_bisheng_token + +router = APIRouter() + +# 瀛樺偍瀹㈡埛绔� WebSocket 杩炴帴 +client_websockets = {} + + +# 涓棿灞俉ebSocket 鏈嶅姟鍣紝鎺ユ敹瀹㈡埛绔殑杩炴帴 +@router.websocket("/ws/{agent_id}/{chat_id}") +async def handle_client(websocket: WebSocket, + agent_id: str, + chat_id: str, + current_user: UserModel = Depends(get_current_user_websocket), + db: Session = Depends(get_db)): + await websocket.accept() + print(f"Client {agent_id} connected") + + token = get_bisheng_token(db, current_user.id) + + if agent_id == "0": + agent_id = settings.bisheng_agent_id + if chat_id == "0": + chat_id = uuid.uuid4().hex + + # 杩炴帴鍒版湇鍔$ + service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}" + headers = { + 'cookie': f"access_token_cookie={token};" + } + + async with websockets.connect(service_uri, extra_headers=headers) as service_websocket: + client_websockets[chat_id] = websocket + + try: + # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅 + async def forward_to_service(): + while True: + message = await websocket.receive_json() + print(f"Received from client {chat_id}: {message}") + # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁 + message['flow_id'] = agent_id + message['chat_id'] = chat_id + msg = message["message"] + del message["message"] + message['inputs'] = { + "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"}, + "input": msg + } + await service_websocket.send(json.dumps(message)) + print(f"Forwarded to bisheng: {message}") + + + # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风 + async def forward_to_client(): + while True: + message = await service_websocket.recv() + print(f"Received from service S: {message}") + await websocket.send_text(message) + print(f"Forwarded to client {chat_id}: {message}") + + # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭� + tasks = [ + asyncio.create_task(forward_to_service()), + asyncio.create_task(forward_to_client()) + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # 鍙栨秷鏈畬鎴愮殑浠诲姟 + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + except WebSocketDisconnect: + print(f"Client {chat_id} disconnected") + finally: + del client_websockets[chat_id] diff --git a/app/config/config.py b/app/config/config.py index 8044029..0833ef2 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -6,10 +6,12 @@ class Settings: secret_key: str = '' bisheng_base_url: str = '' + bisheng_websocket_url: str = '' ragflow_base_url: str = '' database_url: str = '' PUBLIC_KEY: str PRIVATE_KEY: str + bisheng_agent_id: str def __init__(self, **kwargs): # Check if all required fields are provided and set them diff --git a/app/config/config.yaml b/app/config/config.yaml index bb7328d..db7638b 100644 --- a/app/config/config.yaml +++ b/app/config/config.yaml @@ -1,9 +1,11 @@ secret_key: your-secret-key bisheng_base_url: http://192.168.20.119:13001 +bisheng_websocket_url: ws://192.168.20.119:13001 ragflow_base_url: http://192.168.20.119:11080 database_url: mysql+pymysql://root:infini_rag_flow@192.168.20.116:5455/rag_basic PUBLIC_KEY: | -----BEGIN PUBLIC KEY----- MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB -----END PUBLIC KEY----- -PRIVATE_KEY: str \ No newline at end of file +PRIVATE_KEY: str +bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770 \ No newline at end of file diff --git a/app/service/auth.py b/app/service/auth.py index 3ebccb1..09a4917 100644 --- a/app/service/auth.py +++ b/app/service/auth.py @@ -8,7 +8,7 @@ SECRET_KEY = settings.secret_key ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +ACCESS_TOKEN_EXPIRE_MINUTES = 3000 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/app/service/token.py b/app/service/token.py new file mode 100644 index 0000000..9c54fed --- /dev/null +++ b/app/service/token.py @@ -0,0 +1,15 @@ +from app.models.token_model import TokenModel + + +def get_bisheng_token(db, user_id: int): + token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first() + if not token: + return None + return token.bisheng_token + + +def get_ragflow_token(db, user_id: int): + token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first() + if not token: + return None + return token.ragflow_token diff --git a/main.py b/main.py index c6fb2cd..95945d6 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ from fastapi import FastAPI from app.api.auth import router as auth_router +from app.api.chat import router as chat_router from app.models.base_model import init_db init_db() @@ -10,6 +11,7 @@ ) app.include_router(auth_router, prefix='/auth', tags=["auth"]) +app.include_router(chat_router, prefix='/chat', tags=["chat"]) if __name__ == "__main__": import uvicorn -- Gitblit v1.8.0