import urllib.parse  # binds `urllib` and guarantees the `parse` submodule is loaded
from datetime import datetime
from typing import Callable, Any
from urllib.parse import urlencode

import jwt
# from cryptography.fernet import Fernet
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect

from Log import logger
from app.models.base_model import SessionLocal
# from app.models.app_model import AppRegisterModel
from app.models.user_model import UserModel, UserApiTokenModel
from app.service.auth import SECRET_KEY, ALGORITHM
from app.config.config import settings

app = FastAPI()

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# cipher_suite = Fernet(settings.HASH_SUB_KEY)


class Response(BaseModel):
    """Response body with a status code, a message and a dict payload."""

    code: int = 200
    msg: str = ""
    data: dict = {}  # pydantic copies field defaults per instance, so this is not shared


class ResponseList(BaseModel):
    """Response body with a status code, a message and a list-of-dicts payload."""

    code: int = 200
    msg: str = ""
    data: list[dict] = []  # pydantic copies field defaults per instance, so this is not shared


def _get_active_token(token: str) -> Any:
    """Return the active, unexpired ``UserApiTokenModel`` row for *token*, or ``None``.

    A row qualifies when ``is_active == 1`` and ``expires_at`` is either NULL
    (treated as "never expires") or later than the current local time.
    """
    db = SessionLocal()
    try:
        db_token = (
            db.query(UserApiTokenModel)
            .filter(UserApiTokenModel.token == token, UserApiTokenModel.is_active == 1)
            .first()
        )
        if db_token is None:
            return None
        if db_token.expires_at is not None and db_token.expires_at <= datetime.now():
            return None
        return db_token
    finally:
        db.close()


def verify_token(token: str) -> Any:
    """Check whether an API token is valid.

    :param token: raw API token string.
    :return: ``True`` if an active, unexpired token row exists, else ``False``.
    """
    return _get_active_token(token) is not None


def token_required() -> Callable:
    """Build a request dependency that authenticates via an API token.

    The returned callable reads the ``Authorization`` header, takes its second
    whitespace-separated part as the token, and resolves it to a ``UserModel``
    carrying the token owner's ``user_id`` (``username`` is left empty).

    :raises HTTPException: 401 when the header is missing, malformed, or the
        token is not active/unexpired.
    """

    def decorated_function(request: Request) -> Any:
        authorization_str = request.headers.get("Authorization")
        if not authorization_str:
            raise HTTPException(status_code=401, detail="Authorization` can't be empty")
        authorization_list = authorization_str.split()
        if len(authorization_list) < 2:
            raise HTTPException(status_code=401, detail="Invalid token")
        # Fetch the row itself (not just a bool) so the owner's user_id is available.
        db_token = _get_active_token(authorization_list[1])
        if db_token is None:
            raise HTTPException(status_code=401, detail="Invalid token")
        return UserModel(username="", id=db_token.user_id)

    return decorated_function


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 ``HTTPException`` carrying a ``WWW-Authenticate: Bearer`` header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the current user from an OAuth2 bearer JWT.

    The JWT must decode with ``SECRET_KEY``/``ALGORITHM`` and contain:
    ``lex`` (a ``"%Y-%m-%d %H:%M:%S"`` timestamp that must not be in the past),
    ``sub`` (the username) and a truthy ``user_id``.

    :return: a ``UserModel`` built from ``sub`` and ``user_id``.
    :raises HTTPException: 401 for any invalid, expired or incomplete token.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        raise _unauthorized("令牌无效或已过期") from err

    expired_time = payload.get("lex")
    if not expired_time:
        raise _unauthorized("令牌无效或已过期")
    try:
        expires_at = datetime.strptime(expired_time, "%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError) as err:
        # A malformed expiry claim is an invalid token, not a server error.
        raise _unauthorized("令牌无效或已过期") from err
    if expires_at < datetime.now():
        raise _unauthorized("系统授权已过期!")

    username: str = payload.get("sub")
    if username is None:
        raise _unauthorized("无法验证凭证")

    user_id = payload.get("user_id")
    # Reject both 0 and a missing claim (None); previously only 0 was rejected.
    if not user_id:
        raise _unauthorized("用户不存在")
    return UserModel(username=username, id=user_id)


async def _reject_websocket(websocket: WebSocket):
    """Close *websocket* with code 1008 (policy violation) and raise ``WebSocketDisconnect``.

    Always raises; never returns normally.
    """
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)


async def get_current_user_websocket(websocket: WebSocket):
    """Resolve the current user from a JWT passed as the ``token`` query parameter.

    On any failure (missing token, decode error, missing ``sub`` or falsy
    ``user_id``) the socket is closed with code 1008 and ``WebSocketDisconnect``
    is raised.

    NOTE(review): unlike ``get_current_user`` this does not check the ``lex``
    expiry claim — confirm whether that is intentional.
    """
    token = websocket.query_params.get('token')
    if token is None:
        await _reject_websocket(websocket)
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        logger.warning(f"WebSocket token rejected: {e}")
        await _reject_websocket(websocket)

    username: str = payload.get("sub")
    user_id = payload.get("user_id")
    # Mirror get_current_user: a token without a usable user id is rejected.
    if username is None or not user_id:
        await _reject_websocket(websocket)
    return UserModel(username=username, id=user_id)


def format_file_url(agent_id: str, file_url: str, doc_id: str = None, doc_name: str = None) -> str:
    """Build a relative download URL for a file.

    :param agent_id: agent identifier, URL-encoded into the query string.
    :param file_url: source URL; takes precedence over *doc_id* when truthy.
    :param doc_id: document id, used only when *file_url* is falsy.
    :param doc_name: document name accompanying *doc_id*; ``None`` is treated as ``""``.
    :return: the download URL, or *file_url* unchanged when neither
        *file_url* nor *doc_id* is truthy.
    """
    # Encode agent_id too so it cannot inject extra query parameters.
    encoded_agent_id = urllib.parse.quote(str(agent_id), safe='')
    if file_url:
        # Keep ':' and '/' so the embedded URL stays readable.
        encoded_file_url = urllib.parse.quote(file_url, safe=':/')
        return f"./api/files/download/?url={encoded_file_url}&agent_id={encoded_agent_id}"
    if doc_id:
        encoded_doc_id = urllib.parse.quote(doc_id, safe='')
        # quote(None) raises TypeError, so fall back to an empty name.
        encoded_doc_name = urllib.parse.quote(doc_name or '', safe='')
        return (
            f"./api/files/download/?doc_id={encoded_doc_id}"
            f"&doc_name={encoded_doc_name}&agent_id={encoded_agent_id}"
        )
    return file_url


def process_files(files, agent_id):
    """Rewrite each file's ``file_url`` in place into a download URL.

    :param files: list of file dicts; entries with a truthy ``"file_url"`` are rewritten.
    :param agent_id: agent ID passed through to ``format_file_url``.
    """
    if not files:
        return  # nothing to process
    for file in files:
        if not file.get("file_url"):
            continue
        try:
            file["file_url"] = format_file_url(agent_id, file["file_url"])
        except Exception as e:
            # Best-effort: log the failure and keep processing the remaining files.
            logger.warning(f"Error processing file URL: {e}")


def get_api_key(authorization: str = Header(...)):
    """Extract the token from a ``Bearer <token>`` ``Authorization`` header.

    :raises HTTPException: 401 when the header does not start with ``"Bearer "``
        or carries no token after it.
    """
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format.")
    # split() (like token_required) tolerates repeated spaces and rejects "Bearer ".
    parts = authorization.split()
    if len(parts) < 2:
        raise HTTPException(status_code=401, detail="Invalid Authorization header format.")
    return parts[1]


if __name__ == "__main__":
    files1 = [{"file_url": "aaa.com"}, {"file_url": "bbb.com"}]
    print(files1)
    process_files(files1, 11111)
    print(files1)