zhangxiao
2024-10-16 30311881800e4840a13f13dd702b093543b2082e
app/api/__init__.py
@@ -21,6 +21,12 @@
    data: dict = {}
class ResponseList(BaseModel):
    code: int = 200
    msg: str = ""
    data: list[dict] = []
def get_current_user(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
@@ -48,11 +54,10 @@
async def get_current_user_websocket(websocket: WebSocket):
    auth_header = websocket.headers.get('Authorization')
    if auth_header is None or not auth_header.startswith('Bearer '):
    token = websocket.query_params.get('token')
    if token is None:
        await websocket.close(code=1008)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
    token = auth_header[len('Bearer '):]
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
@@ -67,4 +72,4 @@
    except jwt.PyJWTError as e:
        print(e)
        await websocket.close(code=1008)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)