import logging
from typing import NoReturn

import jwt
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect

from app.models.user_model import UserModel
from app.service.auth import SECRET_KEY, ALGORITHM

logger = logging.getLogger(__name__)

app = FastAPI()
# bcrypt password-hashing context; not referenced by the code in this module,
# presumably imported elsewhere — verify before removing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Pulls the bearer token out of the Authorization header for HTTP routes;
# tokenUrl is the path advertised to clients for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


class Response(BaseModel):
    """Generic API response envelope: status code, message and data payload."""

    code: int = 200
    msg: str = ""
    data: dict = {}


def _credentials_error(detail: str) -> HTTPException:
    """Build a 401 error carrying a ``WWW-Authenticate: Bearer`` challenge.

    A fresh headers dict is created on every call so no instance shares
    mutable state with another.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_user(token: str = Depends(oauth2_scheme)) -> UserModel:
    """Resolve the authenticated user from a bearer JWT on an HTTP request.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``; the ``sub`` claim
    supplies the username and the ``user_id`` claim supplies the id.

    Args:
        token: Raw JWT extracted by ``oauth2_scheme``.

    Returns:
        A ``UserModel`` built from the token claims.

    Raises:
        HTTPException: 401 if the token is invalid/expired, lacks a ``sub``
            claim, or yields a user whose id is 0.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        raise _credentials_error("令牌无效或已过期") from err

    username = payload.get("sub")
    if username is None:
        raise _credentials_error("无法验证凭证")

    user = UserModel(username=username, id=payload.get("user_id"))
    # An id of 0 is treated as "user does not exist".
    if user.id == 0:
        raise _credentials_error("用户不存在")
    return user


async def _reject_websocket(websocket: WebSocket) -> NoReturn:
    """Close *websocket* with a policy-violation code and abort the caller.

    Raises:
        WebSocketDisconnect: always, with code 1008 (policy violation).
    """
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)


async def get_current_user_websocket(websocket: WebSocket) -> UserModel:
    """Resolve the authenticated user for a WebSocket connection.

    Expects an ``Authorization: Bearer <jwt>`` header. The scheme is matched
    case-insensitively, as for the HTTP dependency. On any failure the socket
    is closed with code 1008 and ``WebSocketDisconnect`` is raised.

    Args:
        websocket: The incoming WebSocket connection.

    Returns:
        A ``UserModel`` built from the token's ``sub`` and ``user_id`` claims.

    Raises:
        WebSocketDisconnect: the header is missing or malformed, the token is
            invalid/expired, the ``sub`` claim is absent, or the user id is 0.
    """
    auth_header = websocket.headers.get("Authorization")
    if auth_header is None:
        await _reject_websocket(websocket)

    scheme, _, token = auth_header.partition(" ")
    if scheme.lower() != "bearer" or not token:
        await _reject_websocket(websocket)

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        logger.warning("Rejected WebSocket token: %s", err)
        await _reject_websocket(websocket)

    username = payload.get("sub")
    if username is None:
        await _reject_websocket(websocket)

    user = UserModel(username=username, id=payload.get("user_id"))
    # Same "user does not exist" rule as get_current_user; the previous
    # `user is None` test could never be true.
    if user.id == 0:
        await _reject_websocket(websocket)
    return user