zhangqian
2024-11-19 f37670f13f8faf018a87d5b73b662bb1909ebe87
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import io
from typing import Optional
 
import requests
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query, Form
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from werkzeug.utils import send_file
 
from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
from app.models.agent_model import AgentType, AgentModel
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.basic import BasicService
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.service_token import get_ragflow_token, get_bisheng_token
import urllib.parse
 
router = APIRouter()
 
 
@router.post("/upload/{agent_id}", response_model=Response)
async def upload_file(agent_id: str,
                      file: UploadFile = File(...),
                      chat_id: str = Query(None, description="The ID of the chat"),
                      db: Session = Depends(get_db),
                      current_user: UserModel = Depends(get_current_user)
                      ):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")
    # 读取上传的文件内容
    try:
        file_content = await file.read()
    except Exception as e:
        return Response(code=400, msg=str(e))
 
    if agent.agent_type == AgentType.RAGFLOW:
        token = get_ragflow_token(db, current_user.id)
        ragflow_service = RagflowService(base_url=settings.fwr_base_url)
        # 查询会话是否存在,不存在先创建会话
        history = await ragflow_service.get_session_history(token, chat_id)
        if len(history) == 0:
            message = {"role": "user", "message": file.filename}
            await ragflow_service.set_session(token, agent_id, message, chat_id, True)
 
        ragflow_service = RagflowService(base_url=settings.fwr_base_url)
        token = get_ragflow_token(db, current_user.id)
        doc_ids = await ragflow_service.upload_and_parse(token, chat_id, file.filename, file_content)
        return Response(code=200, msg="", data={"doc_ids": doc_ids, "file_name": file.filename})
 
    elif agent.agent_type == AgentType.BISHENG:
        bisheng_service = BishengService(base_url=settings.sgb_base_url)
        try:
            token = get_bisheng_token(db, current_user.id)
            result = await bisheng_service.upload(token, file.filename, file_content)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
        result["file_name"] = file.filename
        return Response(code=200, msg="", data=result)
    elif agent.agent_type == AgentType.BASIC:
        if agent_id == "basic_excel_talk":
            service = BasicService(base_url=settings.basic_base_url)
            result = await service.excel_talk_upload(chat_id, file.filename, file_content)
 
            return Response(code=200, msg="", data=result)
 
    else:
        return Response(code=200, msg="Unsupported agent type")
 
 
@router.get("/download/", response_model=Response)
async def download_file(
        url: Optional[str] = Query(None, description="URL of the file to download for bisheng"),
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        file_id:  Optional[str] = Query(None, description="Optional file id for basic agents"),
        file_type:  Optional[str] = Query(None, description="Optional file type for basic agents"),
        db: Session = Depends(get_db)
):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")
 
    if agent.agent_type == AgentType.BISHENG:
        url = urllib.parse.unquote(url)
        # 从 URL 中提取文件名
        parsed_url = urllib.parse.urlparse(url)
        filename = urllib.parse.unquote(parsed_url.path.split('/')[-1])
        url = url.replace("http://minio:9000", settings.sgb_base_url)
    elif agent.agent_type == AgentType.RAGFLOW:
        if not doc_id:
            return Response(code=400, msg="doc_id is required")
        url = f"{settings.fwr_base_url}/v1/document/get/{doc_id}"
        filename = doc_name
    elif agent.agent_type == AgentType.BASIC:
        if agent_id == "basic_excel_talk":
            return await download_basic_file(file_id, file_type)
 
    else:
        return Response(code=400, msg="Unsupported agent type")
 
    try:
        # 发送GET请求获取文件内容
        response = requests.get(url, stream=True)
        response.raise_for_status()  # 检查请求是否成功
 
        # 返回流式响应
        return StreamingResponse(
            response.iter_content(chunk_size=1024),
            media_type="application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename*=utf-8''{urllib.parse.quote(filename)}"}
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error downloading file: {e}")
 
 
async def download_basic_file(file_id: str, file_type: str):
    service = BasicService(base_url=settings.basic_base_url)
    if not file_type or not file_id:
        return Response(code=400, msg="file_type and file_id is required")
    if file_type == "image":
        content, filename, mimetype = await service.excel_talk_image_download(file_id)
        return StreamingResponse(
                io.BytesIO(content),
                media_type=mimetype,
                headers={"Content-Disposition": f"attachment; filename={filename}"}
            )
    elif file_type == "excel":
        content, filename, mimetype = await service.excel_talk_excel_download(file_id)
        return StreamingResponse(
            io.BytesIO(content),
            media_type=mimetype,
            headers={"Content-Disposition": f"attachment; filename={filename}"}
        )
    else:
        return Response(code=400, msg="Unsupported file type")