| | |
| | | import urllib |
| | | from datetime import datetime |
| | | from typing import Callable, Any |
| | | from urllib.parse import urlencode |
| | | |
| | | import jwt |
| | | from fastapi import FastAPI, Depends, HTTPException |
| | | # from cryptography.fernet import Fernet |
| | | from fastapi import FastAPI, Depends, HTTPException, Header, Request |
| | | from fastapi.security import OAuth2PasswordBearer |
| | | from passlib.context import CryptContext |
| | | from pydantic import BaseModel |
| | | from starlette import status |
| | | from starlette.websockets import WebSocket, WebSocketDisconnect |
| | | |
| | | from app.models.user_model import UserModel |
| | | from Log import logger |
| | | from app.models.base_model import SessionLocal |
| | | # from app.models.app_model import AppRegisterModel |
| | | from app.models.user_model import UserModel, UserApiTokenModel |
| | | from app.service.auth import SECRET_KEY, ALGORITHM |
| | | from app.config.config import settings |
| | | |
# Application object for this module.
app = FastAPI()

# Password hashing context restricted to the bcrypt scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Bearer-token extractor used as a dependency; "token" is the token URL
# advertised to clients.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
| | | # cipher_suite = Fernet(settings.HASH_SUB_KEY) |
| | | |
| | | |
| | | class Response(BaseModel): |
| | |
| | | data: dict = {} |
| | | |
| | | |
class ResponseList(BaseModel):
    """Response envelope whose payload is a list of dicts."""

    # Status code carried in the response body; 200 unless overridden.
    code: int = 200
    # Message text; empty unless overridden.
    msg: str = ""
    # Payload items; an empty list unless overridden.
    data: list[dict] = []
| | | |
| | | |
def verify_token(token: str) -> Any:
    """Look up an API token and return its record when it is still usable.

    A token is usable when a row with ``is_active == 1`` matches it and its
    ``expires_at`` is either unset or later than ``datetime.now()``.

    :param token: raw token string to look up.
    :return: the matching ``UserApiTokenModel`` row when the token is usable,
        otherwise ``None``. The result keeps the truthiness of the previous
        boolean return, so ``if verify_token(t):`` callers are unaffected.
    """
    db = SessionLocal()
    try:
        db_token = (
            db.query(UserApiTokenModel)
            .filter(
                UserApiTokenModel.token == token,
                UserApiTokenModel.is_active == 1,
            )
            .first()
        )
        if db_token is None:
            return None

        expires_at = db_token.expires_at
        if expires_at is not None and expires_at <= datetime.now():
            return None

        # Return the row itself rather than a bare bool: token_required reads
        # ``.user_id`` from this result, which a bool cannot provide.
        return db_token
    finally:
        # Always release the session, whether or not the lookup succeeded.
        db.close()
| | | |
def token_required() -> Callable:
    """Build a request dependency that authenticates via an API token.

    The returned callable reads the ``Authorization`` header, takes the
    second whitespace-separated part as the token, validates it with
    ``verify_token`` and returns a ``UserModel`` carrying the token owner's id.

    :return: a callable taking the incoming ``Request`` and returning a
        ``UserModel`` with an empty ``username`` and ``id`` set to the token
        record's ``user_id``.
    :raises HTTPException: (from the returned callable) 401 when the header is
        missing, has fewer than two parts, or the token is rejected.
    """
    def decorated_function(request: Request) -> Any:
        authorization_str = request.headers.get("Authorization")
        if not authorization_str:
            # Message previously contained a stray, unmatched backtick.
            raise HTTPException(status_code=401, detail="`Authorization` can't be empty")

        parts = authorization_str.split()
        if len(parts) < 2:
            raise HTTPException(status_code=401, detail="Invalid token")

        token_record = verify_token(parts[1])
        if not token_record:
            raise HTTPException(status_code=401, detail="Invalid token")

        # NOTE(review): this reads ``.user_id`` from verify_token's result, so
        # verify_token must hand back the token record, not a plain bool.
        return UserModel(username="", id=token_record.user_id)

    return decorated_function
| | | |
| | | def get_current_user(token: str = Depends(oauth2_scheme)): |
| | | try: |
| | | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) |
| | | expired_time = payload.get("lex") |
| | | if not expired_time: |
| | | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌无效或已过期", |
| | | headers={"WWW-Authenticate": "Bearer"}) |
| | | if datetime.strptime(expired_time, "%Y-%m-%d %H:%M:%S") < datetime.now(): |
| | | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="系统授权已过期!", |
| | | headers={"WWW-Authenticate": "Bearer"}) |
| | | |
| | | username: str = payload.get("sub") |
| | | if username is None: |
| | | raise HTTPException( |
| | |
| | | except jwt.PyJWTError as e: |
| | | print(e) |
| | | await websocket.close(code=1008) |
| | | raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) |
| | | raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) |
| | | |
| | | |
def format_file_url(agent_id: str, file_url: str, doc_id: str = None, doc_name: str = None) -> str:
    """Build the relative download URL for a file or a document.

    :param agent_id: agent identifier appended as the ``agent_id`` query value;
        converted with ``str()`` and percent-encoded.
    :param file_url: source URL of the file. When truthy it takes precedence
        over ``doc_id`` and is percent-encoded, keeping ``:`` and ``/``.
    :param doc_id: document identifier, used only when ``file_url`` is falsy.
    :param doc_name: document name sent along with ``doc_id``; a missing name
        is sent as an empty value.
    :return: ``./api/files/download/?...`` URL, or ``file_url`` unchanged
        (``None`` or ``""``) when neither ``file_url`` nor ``doc_id`` is given.
    """
    # Encode the agent id too, so characters such as "&" or "=" cannot break
    # the query string; plain alphanumeric ids come out unchanged.
    encoded_agent_id = urllib.parse.quote(str(agent_id), safe='')

    if file_url:
        encoded_file_url = urllib.parse.quote(file_url, safe=':/')
        return f"./api/files/download/?url={encoded_file_url}&agent_id={encoded_agent_id}"

    if doc_id:
        encoded_doc_id = urllib.parse.quote(str(doc_id), safe='')
        # doc_name defaults to None, which quote() rejects; send it as empty.
        encoded_doc_name = urllib.parse.quote("" if doc_name is None else str(doc_name), safe='')
        return (
            f"./api/files/download/?doc_id={encoded_doc_id}"
            f"&doc_name={encoded_doc_name}&agent_id={encoded_agent_id}"
        )

    return file_url
| | | |
| | | |
def process_files(files, agent_id):
    """
    Rewrite the ``file_url`` of every entry in ``files`` in place.

    Entries without a ``file_url`` key, or with a falsy value for it, are left
    untouched. A failure on one entry is printed and does not stop the loop.

    :param files: list of file entries, each expected to be a dict
    :param agent_id: agent id forwarded to ``format_file_url``
    :return: always ``None``; the entries are modified in place
    """
    if not files:
        # Nothing to do for None or an empty list.
        return

    for entry in files:
        if "file_url" not in entry:
            continue
        if not entry["file_url"]:
            continue
        try:
            entry["file_url"] = format_file_url(agent_id, entry["file_url"])
        except Exception as e:
            # Report the problem but keep processing the remaining entries.
            print(f"Error processing file URL: {e}")
| | | |
def get_api_key(authorization: str = Header(...)):
    """Extract the API key from a ``Bearer`` Authorization header.

    :param authorization: value of the required ``Authorization`` header.
    :return: the first whitespace-separated part following ``Bearer``.
    :raises HTTPException: 401 when the header does not start with
        ``"Bearer "`` or carries no key after the scheme.
    """
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format.")

    # Split on any run of whitespace so "Bearer  key" still yields the key,
    # and reject a header that is only the scheme instead of returning "".
    parts = authorization.split()
    if len(parts) < 2:
        raise HTTPException(status_code=401, detail="Invalid Authorization header format.")
    return parts[1]
| | | |
| | | |
| | | |
if __name__ == "__main__":
    # Manual smoke check: print the entries before and after URL rewriting.
    sample_files = [{"file_url": url} for url in ("aaa.com", "bbb.com")]
    print(sample_files)

    process_files(sample_files, 11111)
    print(sample_files)