zhangxiao
2024-10-23 81d420d88e87ccbbe3b0e7681bea17f31239fcdb
Merge remote-tracking branch 'origin/master'
6个文件已修改
125 ■■■■■ 已修改文件
app/api/__init__.py 46 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/auth.py 15 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/files.py 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/report.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/token_model.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
pip_install.sh 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/__init__.py
@@ -1,3 +1,6 @@
import urllib
from urllib.parse import urlencode
import jwt
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
@@ -72,4 +75,45 @@
    except jwt.PyJWTError as e:
        print(e)
        await websocket.close(code=1008)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
def format_file_url(agent_id: str, file_url: str, doc_id: str = None, doc_name: str = None) -> str:
    if file_url:
        # 对 file_url 进行 URL 编码
        encoded_file_url = urllib.parse.quote(file_url, safe=':/')
        return f"./api/files/download/?url={encoded_file_url}&agent_id={agent_id}"
    if doc_id:
        # 对 doc_id 和 doc_name 进行 URL 编码
        encoded_doc_id = urllib.parse.quote(doc_id, safe='')
        encoded_doc_name = urllib.parse.quote(doc_name, safe='')
        return f"./api/files/download/?doc_id={encoded_doc_id}&doc_name={encoded_doc_name}&agent_id={agent_id}"
    return file_url
def process_files(files, agent_id):
    """
    处理文件列表,格式化每个文件的 URL。
    :param files: 文件列表,每个文件是一个字典
    :param agent_id: 代理 ID
    """
    if not files:
        return  # 如果文件列表为空,直接返回
    for file in files:
        if "file_url" in file and file["file_url"]:
            try:
                file["file_url"] = format_file_url(agent_id, file["file_url"])
            except Exception as e:
                # 记录异常信息,但继续处理其他文件
                print(f"Error processing file URL: {e}")
if __name__=="__main__":
    files1 = [{"file_url": "aaa.com"}, {"file_url":"bbb.com"}]
    print(files1)
    process_files(files1,11111)
    print(files1)
app/api/auth.py
@@ -1,10 +1,10 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.api import Response, pwd_context
from app.api import Response, pwd_context, get_current_user
from app.config.config import settings
from app.models.base_model import get_db
from app.models.token_model import upsert_token
from app.models.token_model import upsert_token, get_token
from app.models.user import UserCreate, LoginData
from app.models.user_model import UserModel
from app.service.auth import authenticate_user, create_access_token
@@ -76,3 +76,14 @@
        "username": user.username,
        "nickname": "",
    })
@router.get("/token", response_model=Response)
async def token_api(db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    # 查询现有记录
    token = get_token(db, current_user.id)
    if token is None:
        return Response(code=400, msg="token not found")
    return Response(code=200, msg="success", data={
        "ragflow_token": token.ragflow_token,
    })
app/api/files.py
@@ -1,5 +1,10 @@
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, requests, Query
from typing import Optional
import requests
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
@@ -9,6 +14,7 @@
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.token import get_ragflow_token, get_bisheng_token
import urllib.parse
router = APIRouter()
@@ -55,3 +61,44 @@
    else:
        return Response(code=200, msg="Unsupported agent type")
@router.get("/download/", response_model=Response)
async def download_file(
        url: Optional[str] = Query(None, description="URL of the file to download for bisheng"),
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        db: Session = Depends(get_db)
):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")
    if agent.agent_type == AgentType.BISHENG:
        url = urllib.parse.unquote(url)
        # 从 URL 中提取文件名
        parsed_url = urllib.parse.urlparse(url)
        filename = urllib.parse.unquote(parsed_url.path.split('/')[-1])
        url = url.replace("http://minio:9000", settings.bisheng_base_url)
    elif agent.agent_type == AgentType.RAGFLOW:
        if not doc_id:
            return Response(code=400, msg="doc_id is required")
        url = f"{settings.ragflow_base_url}/v1/document/get/{doc_id}"
        filename = doc_name
    else:
        return Response(code=400, msg="Unsupported agent type")
    try:
        # 发送GET请求获取文件内容
        response = requests.get(url, stream=True)
        response.raise_for_status()  # 检查请求是否成功
        # 返回流式响应
        return StreamingResponse(
            response.iter_content(chunk_size=1024),
            media_type="application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename*=utf-8''{urllib.parse.quote(filename)}"}
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error downloading file: {e}")
app/api/report.py
@@ -4,7 +4,7 @@
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket, ResponseList, get_current_user
from app.api import get_current_user_websocket, ResponseList, get_current_user, format_file_url, process_files
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
@@ -68,6 +68,7 @@
                            t = "close"
                        else:
                            t = "stream"
                        process_files(files, agent_id)
                        result = {"step_message": steps, "type": t, "files": files}
                        await websocket.send_json(result)
                        print(f"Forwarded to client, {chat_id}: {result}")
app/models/token_model.py
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import Type
from sqlalchemy import Column, Integer, String, DateTime, Text
from sqlalchemy import Column, Integer, DateTime, Text
from sqlalchemy.orm import Session
from app.models.base_model import Base
@@ -49,3 +50,7 @@
    except Exception as e:
        # 异常处理
        db.rollback()  # 回滚事务
def get_token(db: Session, user_id: int) -> Type[TokenModel] | None:
    return db.query(TokenModel).filter_by(user_id=user_id).first()
pip_install.sh
@@ -1,4 +1,5 @@
pip install PyMySQL & pip install fastapi & pip install sqlalchemy & pip install PyJWT & pip install rsa & pip install httpx & pip install uvicorn & pip install bcrypt & pip install PyYAML & pip install pycryptodomex & pip install passlib
pip install werkzeug
pip install xlwings
pip install python-multipart
pip install openpyxl
pip install python-multipart
pip install requests