| | |
| | | from fastapi import FastAPI |
| | | import jwt |
| | | from fastapi import FastAPI, Depends, HTTPException |
| | | from fastapi.security import OAuth2PasswordBearer |
| | | from passlib.context import CryptContext |
| | | from pydantic import BaseModel |
| | | from starlette import status |
| | | from starlette.websockets import WebSocket, WebSocketDisconnect |
| | | |
| | | from app.models.user_model import UserModel |
| | | from app.service.auth import SECRET_KEY, ALGORITHM |
| | | |
app = FastAPI()

# Password hashing context; "bcrypt" is the only scheme configured here.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
# Bearer-token dependency; tokenUrl names the "token" endpoint
# (presumably defined elsewhere in the app — verify).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
| | | |
| | | |
class Response(BaseModel):
    """Reply model holding a status code, a message and one dict payload.

    Every field has a default, so ``Response()`` yields
    ``code=200, msg="", data={}``.
    """

    # Status code carried in the body; 200 unless overridden.
    code: int = 200
    # Free-form message text; empty by default.
    msg: str = ""
    # Single-object payload. Pydantic copies non-hashable defaults per
    # instance, so this empty dict is not shared between models.
    data: dict = dict()
| | | |
| | | |
class ResponseList(BaseModel):
    """Reply model holding a status code, a message and a list of dicts.

    Every field has a default, so ``ResponseList()`` yields
    ``code=200, msg="", data=[]``.
    """

    # Status code carried in the body; 200 unless overridden.
    code: int = 200
    # Free-form message text; empty by default.
    msg: str = ""
    # Multi-object payload. Pydantic copies non-hashable defaults per
    # instance, so this empty list is not shared between models.
    data: list[dict] = list()
| | | |
| | | |
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the authenticated user from an OAuth2 bearer token.

    The token is decoded with ``SECRET_KEY`` / ``ALGORITHM``. Its ``sub``
    claim becomes the username and its ``user_id`` claim becomes the id.

    Args:
        token: Bearer token supplied by the ``oauth2_scheme`` dependency.

    Returns:
        UserModel: user built from the token claims.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the token cannot be decoded (invalid or expired), has no ``sub``
            claim, or yields a user whose id is 0.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure path is a 401 carrying the Bearer challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Only decoding can raise PyJWTError, so keep the try limited to it and
    # chain the original error for debugging.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise _unauthorized("令牌无效或已过期") from exc

    username = payload.get("sub")
    if username is None:
        raise _unauthorized("无法验证凭证")

    user = UserModel(username=username, id=payload.get("user_id"))
    # An id of 0 is reported to the client as "user does not exist".
    if user.id == 0:
        raise _unauthorized("用户不存在")
    return user
| | | |
| | | |
async def get_current_user_websocket(websocket: WebSocket):
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    Mirrors ``get_current_user`` for WebSockets. The token is decoded with
    ``SECRET_KEY`` / ``ALGORITHM``. Its ``sub`` claim becomes the username
    and its ``user_id`` claim becomes the id.

    Args:
        websocket: the incoming connection; only its query parameters are
            read, and it is closed on failure.

    Returns:
        UserModel: user built from the token claims.

    Raises:
        WebSocketDisconnect: code 1008 (policy violation), raised after the
            socket has been closed with the same code. This happens when the
            ``token`` parameter is missing, the token cannot be decoded, it
            has no ``sub`` claim, or the resulting user id is 0.
    """
    import logging

    async def _reject() -> WebSocketDisconnect:
        # Close with 1008 and hand back the matching exception for the caller
        # to raise, so every failure path closes and raises the same way.
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)

    token = websocket.query_params.get("token")
    if token is None:
        raise await _reject()

    # Only decoding can raise PyJWTError; chain it so the cause is kept.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        logging.getLogger(__name__).warning(
            "WebSocket token rejected: %s", exc
        )
        raise await _reject() from exc

    username = payload.get("sub")
    if username is None:
        raise await _reject()

    user = UserModel(username=username, id=payload.get("user_id"))
    # The old `user is None` test could never be true: a constructor call
    # never returns None. Reject id 0 instead, matching get_current_user.
    if user.id == 0:
        raise await _reject()
    return user