import urllib
|
from urllib.parse import urlencode
|
|
import jwt
|
from cryptography.fernet import Fernet
|
from fastapi import FastAPI, Depends, HTTPException
|
from fastapi.security import OAuth2PasswordBearer
|
from passlib.context import CryptContext
|
from pydantic import BaseModel
|
from starlette import status
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
from Log import logger
|
from app.models.app_model import AppRegisterModel
|
from app.models.user_model import UserModel
|
from app.service.auth import SECRET_KEY, ALGORITHM
|
from app.config.config import settings
|
|
# ASGI application object for this module.
app = FastAPI()

# Pulls the bearer token out of the ``Authorization`` header; ``tokenUrl`` is the
# relative URL the generated OpenAPI docs advertise for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# bcrypt password hashing. With deprecated="auto", passlib treats every scheme
# other than the default as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# NOTE(review): Fernet-based encryption is disabled; `Fernet` and `settings` are
# not referenced elsewhere in the visible part of this file - confirm before removing.
# cipher_suite = Fernet(settings.HASH_SUB_KEY)
|
|
|
# Response envelope carrying a single object payload.
# (Comments rather than a docstring: a model docstring would become the
# OpenAPI schema description.)
class Response(BaseModel):
    code: int = 200  # status code, defaults to 200
    msg: str = ""  # message text, empty by default
    data: dict = dict()  # object payload, empty dict by default
|
|
|
# Response envelope carrying a list of objects as its payload.
# (Comments rather than a docstring: a model docstring would become the
# OpenAPI schema description.)
class ResponseList(BaseModel):
    code: int = 200  # status code, defaults to 200
    msg: str = ""  # message text, empty by default
    data: list[dict] = list()  # list payload, empty list by default
|
|
|
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the authenticated user from an OAuth2 bearer token.

    The token is verified with ``SECRET_KEY``/``ALGORITHM`` and the user is
    built directly from its ``sub`` and ``user_id`` claims; no database lookup
    happens here.

    :param token: bearer token extracted by ``oauth2_scheme``.
    :return: ``UserModel`` with the token's username and user id.
    :raises HTTPException: 401 when the token is invalid or expired, when it
        lacks the ``sub`` or ``user_id`` claim, or when ``user_id`` is 0.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure is a 401 carrying the Bearer challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise _unauthorized("令牌无效或已过期") from exc

    username = payload.get("sub")
    user_id = payload.get("user_id")
    # A token without either claim cannot identify a user; reject it here
    # instead of letting UserModel(id=None) fail or slip through.
    if username is None or user_id is None:
        raise _unauthorized("无法验证凭证")

    user = UserModel(username=username, id=user_id)
    if user.id == 0:
        raise _unauthorized("用户不存在")
    return user
|
|
|
async def get_current_user_websocket(websocket: WebSocket):
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    The JWT is verified with ``SECRET_KEY``/``ALGORITHM`` and the user is
    built from its ``sub`` and ``user_id`` claims. On any failure the socket
    is closed with code 1008 (policy violation) and ``WebSocketDisconnect``
    is raised so the endpoint stops.

    :param websocket: the incoming connection.
    :return: ``UserModel`` for the authenticated user.
    :raises WebSocketDisconnect: with code 1008 when the token is missing,
        invalid or expired, lacks ``sub``/``user_id``, or has ``user_id == 0``.
    """

    async def _policy_violation() -> WebSocketDisconnect:
        # Close the socket with 1008, then hand back the exception for the
        # caller to raise (keeps control flow explicit at each call site).
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)

    # NOTE(review): tokens in query strings may end up in proxy/access logs.
    token = websocket.query_params.get("token")
    if not token:
        raise await _policy_violation()

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        logger.warning(f"WebSocket token rejected: {exc}")
        raise await _policy_violation() from exc

    username = payload.get("sub")
    user_id = payload.get("user_id")
    if username is None or user_id is None:
        raise await _policy_violation()

    user = UserModel(username=username, id=user_id)
    # Same "user does not exist" rule as the HTTP bearer dependency.
    if user.id == 0:
        raise await _policy_violation()
    return user
|
|
|
def format_file_url(agent_id: str, file_url: str, doc_id: str = None, doc_name: str = None) -> str:
    """Build the relative download URL for a file or a stored document.

    ``file_url`` takes precedence: when it is non-empty, the result points the
    download endpoint at that URL. Otherwise, when ``doc_id`` is non-empty, the
    result references the document by id and name. Every query value is
    percent-encoded so characters such as ``&``, ``=``, ``#`` or spaces cannot
    break the query string.

    :param agent_id: agent the download belongs to; non-str values are
        converted with ``str``.
    :param file_url: direct URL of the file; may be empty or None.
    :param doc_id: document id, used only when ``file_url`` is empty.
    :param doc_name: document name sent with ``doc_id``; None is sent as an
        empty value.
    :return: a ``./api/files/download/?...`` URL, or ``file_url`` unchanged
        (empty/None) when neither ``file_url`` nor ``doc_id`` is set.
    """
    quote = urllib.parse.quote
    # agent_id was previously interpolated raw; encode it like the other values.
    encoded_agent_id = quote(str(agent_id), safe="")

    if file_url:
        # ':' and '/' stay literal so the embedded URL keeps its shape; all
        # other reserved characters are escaped.
        encoded_file_url = quote(file_url, safe=":/")
        return f"./api/files/download/?url={encoded_file_url}&agent_id={encoded_agent_id}"

    if doc_id:
        encoded_doc_id = quote(str(doc_id), safe="")
        # doc_name defaults to None, and quote(None) raises TypeError.
        encoded_doc_name = quote("" if doc_name is None else str(doc_name), safe="")
        return (
            f"./api/files/download/?doc_id={encoded_doc_id}"
            f"&doc_name={encoded_doc_name}&agent_id={encoded_agent_id}"
        )

    return file_url
|
|
|
def process_files(files, agent_id):
    """Rewrite each file's ``file_url`` in place into a download-endpoint URL.

    Entries without a truthy ``file_url`` are left untouched. A failure on one
    entry is logged and processing continues with the rest (best effort).

    :param files: list of file dicts; modified in place. Empty or None is a
        no-op.
    :param agent_id: agent id forwarded to ``format_file_url``.
    :return: None; results are written back into ``files``.
    """
    if not files:
        return  # nothing to rewrite

    for entry in files:
        if "file_url" not in entry or not entry["file_url"]:
            continue
        try:
            entry["file_url"] = format_file_url(agent_id, entry["file_url"])
        except Exception as exc:
            # Best effort: keep this entry's original URL, record the failure
            # through the module logger instead of stdout, move on.
            logger.error(f"Error processing file URL: {exc}")
|
|
|
if __name__ == "__main__":
    # Manual smoke check: show two sample entries before and after their
    # file_url values are rewritten in place.
    sample_files = [{"file_url": "aaa.com"}, {"file_url": "bbb.com"}]
    print(sample_files)
    process_files(sample_files, 11111)
    print(sample_files)
|