"""Upload and download endpoints that proxy files to the agent backends."""
import urllib.parse
from typing import Optional

import requests
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse

from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
from app.models.agent_model import AgentType, AgentModel
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.service_token import get_ragflow_token, get_bisheng_token

router = APIRouter()

# (connect, read) timeouts in seconds for upstream file downloads, so a stalled
# backend cannot hold a request open forever.
_DOWNLOAD_TIMEOUT = (10, 60)
# Number of bytes relayed to the client per streamed chunk.
_DOWNLOAD_CHUNK_SIZE = 1024


@router.post("/upload/{agent_id}", response_model=Response)
async def upload_file(agent_id: str,
                      file: UploadFile = File(...),
                      chat_id: str = Query(None, description="The ID of the chat"),
                      db: Session = Depends(get_db),
                      current_user: UserModel = Depends(get_current_user)
                      ):
    """Upload a file to the backend that serves the given agent.

    Args:
        agent_id: ID of the agent the file is uploaded for.
        file: The uploaded file.
        chat_id: ID of the chat the file belongs to (used by the RAGFLOW branch).
        db: Database session.
        current_user: The authenticated user; their ID is used to obtain the
            backend token.

    Returns:
        A ``Response`` with code 200 and the upload result (always including
        ``file_name``) on success; code 404 if the agent does not exist; code
        400 if the upload cannot be read or the agent type is unsupported.

    Raises:
        HTTPException: 500 if the Bisheng upload fails.
    """
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")

    # Read the uploaded file content.
    try:
        file_content = await file.read()
    except Exception as e:
        return Response(code=400, msg=str(e))

    if agent.agent_type == AgentType.RAGFLOW:
        token = get_ragflow_token(db, current_user.id)
        ragflow_service = RagflowService(base_url=settings.fwr_base_url)

        # Check whether the session exists; if it does not, create it first.
        # NOTE(review): chat_id defaults to None and is passed through as-is;
        # confirm the service tolerates a missing chat id.
        history = await ragflow_service.get_session_history(token, chat_id)
        if not history:
            message = {"role": "user", "message": file.filename}
            await ragflow_service.set_session(token, agent_id, message, chat_id, True)

        doc_ids = await ragflow_service.upload_and_parse(
            token, chat_id, file.filename, file_content
        )
        return Response(
            code=200, msg="", data={"doc_ids": doc_ids, "file_name": file.filename}
        )

    if agent.agent_type == AgentType.BISHENG:
        bisheng_service = BishengService(base_url=settings.sgb_base_url)
        try:
            token = get_bisheng_token(db, current_user.id)
            result = await bisheng_service.upload(token, file.filename, file_content)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        result["file_name"] = file.filename
        return Response(code=200, msg="", data=result)

    # An unsupported agent type is a client error, matching download_file.
    return Response(code=400, msg="Unsupported agent type")


def _iter_and_close(upstream, chunk_size):
    """Yield the upstream body in chunks and close the connection afterwards."""
    try:
        yield from upstream.iter_content(chunk_size=chunk_size)
    finally:
        upstream.close()


@router.get("/download/", response_model=Response)
async def download_file(
        url: Optional[str] = Query(None, description="URL of the file to download for bisheng"),
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        db: Session = Depends(get_db)
):
    """Stream a file from the agent's backend to the client as an attachment.

    For BISHENG agents the file is fetched from ``url``; for RAGFLOW agents it
    is fetched from the document endpoint built from ``doc_id``.

    Args:
        url: URL of the file to download (required for BISHENG agents).
        agent_id: ID of the agent that owns the file.
        doc_id: Document ID (required for RAGFLOW agents).
        doc_name: Download file name for RAGFLOW agents; falls back to
            ``doc_id`` when omitted.
        db: Database session.

    Returns:
        A ``StreamingResponse`` with the file content on success, or a
        ``Response`` with code 404 (agent not found) or 400 (missing parameter
        or unsupported agent type).

    Raises:
        HTTPException: 400 if the upstream request fails.
    """
    # NOTE(review): this endpoint declares no authentication dependency,
    # unlike upload_file — confirm that is intentional.
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")

    if agent.agent_type == AgentType.BISHENG:
        if not url:
            # Without this guard unquote(None) raises TypeError (HTTP 500).
            return Response(code=400, msg="url is required")
        url = urllib.parse.unquote(url)
        # Extract the file name from the URL path.
        parsed_url = urllib.parse.urlparse(url)
        filename = urllib.parse.unquote(parsed_url.path.split('/')[-1])
        # SECURITY: `url` is caller-supplied and fetched server-side (SSRF
        # risk). Only the internal MinIO host is rewritten here; consider
        # restricting the allowed hosts.
        url = url.replace("http://minio:9000", settings.sgb_base_url)
    elif agent.agent_type == AgentType.RAGFLOW:
        if not doc_id:
            return Response(code=400, msg="doc_id is required")
        url = f"{settings.fwr_base_url}/v1/document/get/{doc_id}"
        # doc_name is optional; fall back to the doc id so the header below
        # can always be built.
        filename = doc_name or doc_id
    else:
        return Response(code=400, msg="Unsupported agent type")

    upstream = None
    try:
        # requests is blocking, so run it in the threadpool rather than on the
        # event loop.
        upstream = await run_in_threadpool(
            requests.get, url, stream=True, timeout=_DOWNLOAD_TIMEOUT
        )
        upstream.raise_for_status()  # Fail early on a non-2xx upstream status.
    except requests.RequestException as e:
        if upstream is not None:
            upstream.close()
        raise HTTPException(
            status_code=400, detail=f"Error downloading file: {e}"
        ) from e

    # Stream the body through; the generator closes the upstream connection
    # once streaming finishes or is aborted.
    return StreamingResponse(
        _iter_and_close(upstream, _DOWNLOAD_CHUNK_SIZE),
        media_type="application/octet-stream",
        headers={"Content-Disposition": f"attachment; filename*=utf-8''{urllib.parse.quote(filename)}"}
    )