From aa99acacfe3c21fbd638652f2fba1c1c62e3c414 Mon Sep 17 00:00:00 2001 From: zhangqian <zhangqian@123.com> Date: 星期五, 11 十月 2024 21:38:55 +0800 Subject: [PATCH] websocket接口,转发毕昇对话 --- app/api/__init__.py | 62 ++++++++++++++++++++++++++++++ 1 files changed, 61 insertions(+), 1 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index bcd5c2a..5ddd5aa 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -1,10 +1,70 @@ -from fastapi import FastAPI +import jwt +from fastapi import FastAPI, Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from passlib.context import CryptContext from pydantic import BaseModel +from starlette import status +from starlette.websockets import WebSocket, WebSocketDisconnect + +from app.models.user_model import UserModel +from app.service.auth import SECRET_KEY, ALGORITHM app = FastAPI() + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") class Response(BaseModel): code: int = 200 msg: str = "" data: dict = {} + + +def get_current_user(token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="鏃犳硶楠岃瘉鍑瘉", + headers={"WWW-Authenticate": "Bearer"}, + ) + user = UserModel(username=username, id=payload.get("user_id")) + if user.id == 0: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="鐢ㄦ埛涓嶅瓨鍦�", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="浠ょ墝鏃犳晥鎴栧凡杩囨湡", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def get_current_user_websocket(websocket: WebSocket): + auth_header = websocket.headers.get('Authorization') + if auth_header is None or not auth_header.startswith('Bearer '): + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + token = auth_header[len('Bearer '):] + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + user = UserModel(username=username, id=payload.get("user_id")) + if user is None: + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) + return user + except jwt.PyJWTError as e: + print(e) + await websocket.close(code=1008) + raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) -- Gitblit v1.8.0