Merge remote-tracking branch 'origin/master'
| | |
| | | import urllib |
| | | from urllib.parse import urlencode |
| | | |
| | | import jwt |
| | | from fastapi import FastAPI, Depends, HTTPException |
| | | from fastapi.security import OAuth2PasswordBearer |
| | |
| | | except jwt.PyJWTError as e: |
| | | print(e) |
| | | await websocket.close(code=1008) |
| | | raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) |
| | | raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION) |
| | | |
| | | |
| | | def format_file_url(agent_id: str, file_url: str, doc_id: str = None, doc_name: str = None) -> str: |
| | | if file_url: |
| | | # 对 file_url 进行 URL 编码 |
| | | encoded_file_url = urllib.parse.quote(file_url, safe=':/') |
| | | return f"./api/files/download/?url={encoded_file_url}&agent_id={agent_id}" |
| | | |
| | | if doc_id: |
| | | # 对 doc_id 和 doc_name 进行 URL 编码 |
| | | encoded_doc_id = urllib.parse.quote(doc_id, safe='') |
| | | encoded_doc_name = urllib.parse.quote(doc_name, safe='') |
| | | return f"./api/files/download/?doc_id={encoded_doc_id}&doc_name={encoded_doc_name}&agent_id={agent_id}" |
| | | |
| | | return file_url |
| | | |
| | | |
def process_files(files, agent_id):
    """Rewrite the ``file_url`` of every file entry in place.

    Each entry that has a truthy ``file_url`` is passed through
    ``format_file_url``; entries without one are left untouched. A failure on
    one entry is printed and does not stop processing of the remaining entries.

    :param files: list of file dicts; an empty list or ``None`` is a no-op
    :param agent_id: agent ID forwarded to ``format_file_url``
    :return: ``None`` -- ``files`` is modified in place
    """
    if not files:
        return  # nothing to process

    for entry in files:
        has_url = "file_url" in entry and entry["file_url"]
        if not has_url:
            continue
        try:
            entry["file_url"] = format_file_url(agent_id, entry["file_url"])
        except Exception as err:
            # Best-effort: report the problem and move on to the next entry.
            print(f"Error processing file URL: {err}")
if __name__ == "__main__":
    # Manual smoke test: print a sample file list before and after URL rewriting.
    sample_files = [{"file_url": "aaa.com"}, {"file_url": "bbb.com"}]
    print(sample_files)

    process_files(sample_files, 11111)
    print(sample_files)
| | |
| | | from fastapi import APIRouter, Depends |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from app.api import Response, pwd_context |
| | | from app.api import Response, pwd_context, get_current_user |
| | | from app.config.config import settings |
| | | from app.models.base_model import get_db |
| | | from app.models.token_model import upsert_token |
| | | from app.models.token_model import upsert_token, get_token |
| | | from app.models.user import UserCreate, LoginData |
| | | from app.models.user_model import UserModel |
| | | from app.service.auth import authenticate_user, create_access_token |
| | |
| | | "username": user.username, |
| | | "nickname": "", |
| | | }) |
| | | |
| | | |
@router.get("/token", response_model=Response)
async def token_api(db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    """Return the stored RAGFlow token of the authenticated user.

    Responds with code 400 ("token not found") when ``get_token`` finds no
    record for the user; otherwise code 200 with ``{"ragflow_token": ...}``
    as ``data``.
    """
    record = get_token(db, current_user.id)
    if record is not None:
        payload = {"ragflow_token": record.ragflow_token}
        return Response(code=200, msg="success", data=payload)
    return Response(code=400, msg="token not found")
| | |
| | | from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, requests, Query |
| | | from typing import Optional |
| | | |
| | | import requests |
| | | from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query |
| | | from pydantic import BaseModel |
| | | from sqlalchemy.orm import Session |
| | | from starlette.responses import StreamingResponse |
| | | |
| | | from app.api import Response, get_current_user, ResponseList |
| | | from app.config.config import settings |
| | |
| | | from app.service.bisheng import BishengService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.token import get_ragflow_token, get_bisheng_token |
| | | import urllib.parse |
| | | |
| | | router = APIRouter() |
| | | |
| | |
| | | |
| | | else: |
| | | return Response(code=200, msg="Unsupported agent type") |
| | | |
| | | |
@router.get("/download/", response_model=Response)
async def download_file(
        url: Optional[str] = Query(None, description="URL of the file to download for bisheng"),
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        db: Session = Depends(get_db)
):
    """Proxy a file download from the backend that owns the given agent.

    * BISHENG agents: ``url`` is required. The file name is taken from the last
      path segment, and ``http://minio:9000`` is rewritten to
      ``settings.bisheng_base_url``.
    * RAGFLOW agents: ``doc_id`` is required. The file is fetched from
      ``{settings.ragflow_base_url}/v1/document/get/{doc_id}``. ``doc_name`` is
      used as the download name, falling back to ``doc_id`` when absent.

    Returns ``Response(code=404)`` for an unknown agent, and ``Response(code=400)``
    for missing parameters or an unsupported agent type. Otherwise returns a
    ``StreamingResponse`` of the upstream body as an attachment.

    Raises ``HTTPException(400)`` when the upstream request fails.
    """
    # Local import: starlette is already a dependency of this module.
    from starlette.concurrency import run_in_threadpool

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")

    if agent.agent_type == AgentType.BISHENG:
        # Without this check, unquote(None) raises TypeError (HTTP 500).
        if not url:
            return Response(code=400, msg="url is required")
        # NOTE(review): FastAPI has already percent-decoded the query value once;
        # this second unquote also decodes escapes that belonged to the original
        # file URL. Kept for compatibility -- confirm it is intended.
        url = urllib.parse.unquote(url)
        # Extract the file name from the URL path
        parsed_url = urllib.parse.urlparse(url)
        filename = urllib.parse.unquote(parsed_url.path.split('/')[-1])
        # NOTE(review): this route declares no auth dependency and fetches whatever
        # URL the caller supplies (possible SSRF). Consider restricting it to
        # settings.bisheng_base_url.
        url = url.replace("http://minio:9000", settings.bisheng_base_url)
    elif agent.agent_type == AgentType.RAGFLOW:
        if not doc_id:
            return Response(code=400, msg="doc_id is required")
        # Escape doc_id so it cannot add path segments or a query string upstream.
        url = f"{settings.ragflow_base_url}/v1/document/get/{urllib.parse.quote(doc_id, safe='')}"
        # doc_name is optional; quote(None) below would otherwise fail.
        filename = doc_name or doc_id
    else:
        return Response(code=400, msg="Unsupported agent type")

    # (connect, read) timeouts in seconds; the read timeout bounds each wait for
    # data, not the total download time.
    timeout = (10, 300)
    upstream = None
    try:
        # requests is blocking: run it in a worker thread so the event loop is not stalled.
        upstream = await run_in_threadpool(requests.get, url, stream=True, timeout=timeout)
        upstream.raise_for_status()  # surface HTTP errors from the upstream server
        disposition = f"attachment; filename*=utf-8''{urllib.parse.quote(filename)}"
    except Exception as e:
        # Release the pooled connection before reporting the failure.
        if upstream is not None:
            upstream.close()
        raise HTTPException(status_code=400, detail=f"Error downloading file: {e}") from e

    # Stream the upstream body to the client in 1 KiB chunks.
    return StreamingResponse(
        upstream.iter_content(chunk_size=1024),
        media_type="application/octet-stream",
        headers={"Content-Disposition": disposition},
    )
| | |
| | | import asyncio |
| | | import websockets |
| | | from sqlalchemy.orm import Session |
| | | from app.api import get_current_user_websocket, ResponseList, get_current_user |
| | | from app.api import get_current_user_websocket, ResponseList, get_current_user, format_file_url, process_files |
| | | from app.config.config import settings |
| | | from app.models.agent_model import AgentModel, AgentType |
| | | from app.models.base_model import get_db |
| | |
| | | t = "close" |
| | | else: |
| | | t = "stream" |
| | | process_files(files, agent_id) |
| | | result = {"step_message": steps, "type": t, "files": files} |
| | | await websocket.send_json(result) |
| | | print(f"Forwarded to client, {chat_id}: {result}") |
| | |
| | | from datetime import datetime |
| | | from typing import Type |
| | | |
| | | from sqlalchemy import Column, Integer, String, DateTime, Text |
| | | from sqlalchemy import Column, Integer, DateTime, Text |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from app.models.base_model import Base |
| | |
| | | except Exception as e: |
| | | # 异常处理 |
| | | db.rollback() # 回滚事务 |
| | | |
| | | |
def get_token(db: Session, user_id: int) -> TokenModel | None:
    """Return the first token record whose ``user_id`` matches, or ``None``.

    ``Query.first()`` yields a model *instance* (or ``None``), not a class, so
    the return type is ``TokenModel | None`` rather than
    ``Type[TokenModel] | None``.

    :param db: SQLAlchemy session used for the query
    :param user_id: ID of the user whose token is looked up
    """
    return db.query(TokenModel).filter_by(user_id=user_id).first()
| | |
pip install PyMySQL fastapi sqlalchemy PyJWT rsa httpx uvicorn bcrypt PyYAML pycryptodomex passlib werkzeug xlwings python-multipart openpyxl requests