import logging

import jwt
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect

from app.models.user_model import UserModel
from app.service.auth import SECRET_KEY, ALGORITHM
|
|
# The application instance for this module.
app = FastAPI()

# Password hashing context with bcrypt as the only configured scheme.
# deprecated="auto" is passlib's shorthand for "every scheme except the
# default is deprecated".
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Dependency that yields the Bearer token from the incoming request.
# "token" is the token endpoint URL it advertises.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="token",
)
|
|
|
class Response(BaseModel):
    """Response envelope whose payload is a single mapping.

    Attributes:
        code: Application status code; defaults to 200.
        msg: Human-readable message; defaults to the empty string.
        data: Payload mapping; each instance gets its own empty dict by default.
    """

    code: int = 200
    msg: str = ""
    # default_factory builds a fresh dict per instance rather than declaring a
    # shared mutable object as the class-level default.
    data: dict = Field(default_factory=dict)
|
|
|
class ResponseList(BaseModel):
    """Response envelope whose payload is a list of mappings.

    Attributes:
        code: Application status code; defaults to 200.
        msg: Human-readable message; defaults to the empty string.
        data: Payload list; each instance gets its own empty list by default.
    """

    code: int = 200
    msg: str = ""
    # default_factory builds a fresh list per instance rather than declaring a
    # shared mutable object as the class-level default.
    data: list[dict] = Field(default_factory=list)
|
|
|
def _credentials_exception(detail: str) -> HTTPException:
    """Build a 401 HTTPException that carries the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_user(token: str = Depends(oauth2_scheme)) -> UserModel:
    """Resolve the authenticated user from a Bearer JWT.

    The token is decoded and verified with ``SECRET_KEY`` and ``ALGORITHM``.
    The ``sub`` claim supplies the username and the ``user_id`` claim supplies
    the user id.

    Args:
        token: Bearer token supplied by the ``oauth2_scheme`` dependency.

    Returns:
        A ``UserModel`` built from the token's ``sub`` and ``user_id`` claims.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token fails decoding or verification (invalid or expired);
            - the ``sub`` or ``user_id`` claim is missing;
            - the resulting user id is 0.
    """
    # Keep the try narrow: only the decode step raises jwt.PyJWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise _credentials_exception("令牌无效或已过期") from exc

    username = payload.get("sub")
    user_id = payload.get("user_id")
    # Reject a token that lacks user_id. Passing id=None into UserModel would
    # either fail validation (surfacing as a 500) or yield an id-less user.
    if username is None or user_id is None:
        raise _credentials_exception("无法验证凭证")

    user = UserModel(username=username, id=user_id)
    # An id of 0 is treated as "user does not exist".
    if user.id == 0:
        raise _credentials_exception("用户不存在")
    return user
|
|
|
def _websocket_user_from_token(token: str):
    """Return the ``UserModel`` encoded in *token*, or ``None`` to reject it.

    Applies the same acceptance rules as the HTTP ``get_current_user``
    dependency. The token is rejected if any of these hold:
        - decoding or verification fails;
        - the ``sub`` or ``user_id`` claim is missing;
        - the user id is 0.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        logging.getLogger(__name__).info("WebSocket token rejected: %s", exc)
        return None

    username = payload.get("sub")
    user_id = payload.get("user_id")
    if username is None or user_id is None:
        return None

    user = UserModel(username=username, id=user_id)
    # An id of 0 means "user does not exist". The original `user is None`
    # check could never fire, because a constructor call never returns None.
    return None if user.id == 0 else user


async def get_current_user_websocket(websocket: WebSocket) -> UserModel:
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    Args:
        websocket: The incoming connection; its ``token`` query parameter must
            hold a JWT accepted by the same rules as ``get_current_user``.

    Returns:
        The ``UserModel`` encoded in the token.

    Raises:
        WebSocketDisconnect: with code 1008 (policy violation) when the token
            is absent, empty, or rejected. The socket is closed with the same
            code before raising.
    """
    token = websocket.query_params.get("token")
    user = _websocket_user_from_token(token) if token else None
    if user is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
    return user
|