zhangqian
2024-10-11 aa99acacfe3c21fbd638652f2fba1c1c62e3c414
app/api/__init__.py
@@ -1,10 +1,70 @@
import logging

import jwt
from fastapi import FastAPI
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect

from app.models.user_model import UserModel
from app.service.auth import SECRET_KEY, ALGORITHM
# FastAPI application instance exposed by this package.
app = FastAPI()
# Passlib hashing context configured with bcrypt; deprecated="auto" marks every
# scheme other than the default as deprecated. NOTE(review): not used in the
# visible part of this module — presumably used elsewhere for password hashing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Dependency that extracts the bearer token from the Authorization header.
# tokenUrl="token" is the token-issuing endpoint advertised in the OpenAPI schema
# (that route is not defined in this view).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Response(BaseModel):
    """Generic response envelope carrying a code, a message, and a payload dict.

    Defaults: ``code=200``, an empty ``msg`` and an empty ``data`` dict.
    """

    code: int = 200  # presumably an application status code mirroring HTTP codes — confirm with callers
    msg: str = ""  # human-readable message
    data: dict = {}  # NOTE(review): mutable default; pydantic copies defaults per instance, but Field(default_factory=dict) would make that explicit
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the authenticated user from an OAuth2 bearer token.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM`` and mapped to a
    ``UserModel`` using its ``sub`` (username) and ``user_id`` claims.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The ``UserModel`` built from the token claims.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            cannot be decoded or has expired, when the ``sub`` or ``user_id``
            claim is missing, or when ``user_id`` is 0.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure is a 401 carrying the Bearer challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Keep the try body to the one call that can raise PyJWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        raise _unauthorized("令牌无效或已过期") from err

    username = payload.get("sub")
    user_id = payload.get("user_id")
    # A token without user_id would otherwise yield a user with id=None that
    # slips past the id == 0 check below.
    if username is None or user_id is None:
        raise _unauthorized("无法验证凭证")

    user = UserModel(username=username, id=user_id)
    # id 0 is treated as "user does not exist".
    if user.id == 0:
        raise _unauthorized("用户不存在")
    return user
_logger = logging.getLogger(__name__)


async def get_current_user_websocket(websocket: WebSocket):
    """Authenticate a WebSocket connection from its ``Authorization`` header.

    Expects ``Authorization: Bearer <jwt>``. The scheme is matched
    case-insensitively, the same way FastAPI's ``OAuth2PasswordBearer`` parses
    the header for HTTP requests. The token is decoded with
    ``SECRET_KEY``/``ALGORITHM`` and mapped to a ``UserModel`` from its ``sub``
    and ``user_id`` claims.

    On any failure the socket is closed with code 1008 (policy violation) and
    ``WebSocketDisconnect`` is raised so the endpoint stops processing.

    Args:
        websocket: The incoming WebSocket connection.

    Returns:
        The ``UserModel`` built from the token claims.

    Raises:
        WebSocketDisconnect: If the header is missing or malformed, the token
            is invalid or expired, a required claim is missing, or
            ``user_id`` is 0.
    """
    policy_violation = status.WS_1008_POLICY_VIOLATION

    async def reject() -> WebSocketDisconnect:
        # Close first so the client receives 1008 rather than a dropped socket,
        # then hand back the exception for the caller to raise.
        await websocket.close(code=policy_violation)
        return WebSocketDisconnect(code=policy_violation)

    auth_header = websocket.headers.get("Authorization") or ""
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise await reject()

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        _logger.warning("WebSocket token rejected: %s", err)
        raise await reject() from err

    username = payload.get("sub")
    user_id = payload.get("user_id")
    if username is None or user_id is None:
        raise await reject()

    user = UserModel(username=username, id=user_id)
    # Same "user does not exist" rule (id == 0) that get_current_user applies.
    if user.id == 0:
        raise await reject()
    return user