zhangqian
2024-10-18 6202db458678153934fb4a31a041c58764a69138
app/api/files.py
@@ -1,5 +1,10 @@
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, requests, Query
from typing import Optional
import requests
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
@@ -9,6 +14,7 @@
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.token import get_ragflow_token, get_bisheng_token
import urllib.parse
router = APIRouter()
@@ -16,7 +22,7 @@
@router.post("/upload/{agent_id}", response_model=Response)
async def upload_file(agent_id: str,
                      file: UploadFile = File(...),
                      chat_id: str = Query(..., description="The ID of the chat"),
                      chat_id: str = Query(None, description="The ID of the chat"),
                      db: Session = Depends(get_db),
                      current_user: UserModel = Depends(get_current_user)
                      ):
@@ -55,3 +61,44 @@
    else:
        return Response(code=200, msg="Unsupported agent type")
@router.get("/download/", response_model=Response)
async def download_file(
        url: Optional[str] = Query(None, description="URL of the file to download for bisheng"),
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        db: Session = Depends(get_db)
):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")
    if agent.agent_type == AgentType.BISHENG:
        url = urllib.parse.unquote(url)
        # 从 URL 中提取文件名
        parsed_url = urllib.parse.urlparse(url)
        filename = urllib.parse.unquote(parsed_url.path.split('/')[-1])
        url = url.replace("http://minio:9000", settings.bisheng_base_url)
    elif agent.agent_type == AgentType.RAGFLOW:
        if not doc_id:
            return Response(code=400, msg="doc_id is required")
        url = f"{settings.ragflow_base_url}/v1/document/get/{doc_id}"
        filename = doc_name
    else:
        return Response(code=400, msg="Unsupported agent type")
    try:
        # 发送GET请求获取文件内容
        response = requests.get(url, stream=True)
        response.raise_for_status()  # 检查请求是否成功
        # 返回流式响应
        return StreamingResponse(
            response.iter_content(chunk_size=1024),
            media_type="application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename*=utf-8''{urllib.parse.quote(filename)}"}
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error downloading file: {e}")