Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway
| | |
| | | from app.service.basic import BasicService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.service_token import get_bisheng_token, get_ragflow_token |
| | | from app.service.session import SessionService |
| | | |
| | | router = APIRouter() |
| | | |
| | |
| | | # 接收前端消息 |
| | | message = await websocket.receive_json() |
| | | question = message.get("message") |
| | | SessionService(db).create_session( |
| | | session_id=chat_id, |
| | | name=question, |
| | | agent_id=agent_id, |
| | | agent_type=AgentType.BASIC |
| | | ) |
| | | if not question: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | |
| | | service = BasicService(base_url=settings.basic_base_url) |
| | | complete_response = "" |
| | | async for result in service.excel_talk(question, chat_id): |
| | | try: |
| | | if result[:5] == "data:": |
| | |
| | | else: |
| | | # 否则,保持原样 |
| | | text = result |
| | | complete_response += text |
| | | try: |
| | | json_data = json.loads(complete_response) |
| | | output = json_data.get("output", "") |
| | | result = {"message": output, "type": "message"} |
| | | await websocket.send_json(result | json_data) |
| | | complete_response = "" |
| | | data = json.loads(text) |
| | | output = data.get("output", "") |
| | | excel_name = data.get("excel_name", "") |
| | | image_name = data.get("excel_name", "") |
| | | excel_url = None |
| | | image_url = None |
| | | if excel_name: |
| | | excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel" |
| | | if image_name: |
| | | image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image" |
| | | result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url} |
| | | await websocket.send_json(result | data) |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | print(f"Response text: {text}") |
| | |
| | | import io |
| | | from typing import Optional |
| | | |
| | | import requests |
| | |
| | | from pydantic import BaseModel |
| | | from sqlalchemy.orm import Session |
| | | from starlette.responses import StreamingResponse |
| | | from werkzeug.utils import send_file |
| | | |
| | | from app.api import Response, get_current_user, ResponseList |
| | | from app.config.config import settings |
| | |
| | | agent_id: str = Query(..., description="Agent ID"), |
| | | doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"), |
| | | doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"), |
| | | file_id: Optional[str] = Query(None, description="Optional file id for basic agents"), |
| | | file_type: Optional[str] = Query(None, description="Optional file type for basic agents"), |
| | | db: Session = Depends(get_db) |
| | | ): |
| | | agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() |
| | |
| | | return Response(code=400, msg="doc_id is required") |
| | | url = f"{settings.fwr_base_url}/v1/document/get/{doc_id}" |
| | | filename = doc_name |
| | | elif agent.agent_type == AgentType.BASIC: |
| | | if agent_id == "basic_excel_talk": |
| | | return await download_basic_file(file_id, file_type) |
| | | |
| | | else: |
| | | return Response(code=400, msg="Unsupported agent type") |
| | | |
| | |
| | | ) |
| | | except Exception as e: |
| | | raise HTTPException(status_code=400, detail=f"Error downloading file: {e}") |
| | | |
| | | |
| | | async def download_basic_file(file_id: str, file_type: str): |
| | | service = BasicService(base_url=settings.basic_base_url) |
| | | if not file_type or not file_id: |
| | | return Response(code=400, msg="file_type and file_id is required") |
| | | if file_type == "image": |
| | | content, filename, mimetype = await service.excel_talk_image_download(file_id) |
| | | return StreamingResponse( |
| | | io.BytesIO(content), |
| | | media_type=mimetype, |
| | | headers={"Content-Disposition": f"attachment; filename={filename}"} |
| | | ) |
| | | elif file_type == "excel": |
| | | content, filename, mimetype = await service.excel_talk_excel_download(file_id) |
| | | return StreamingResponse( |
| | | io.BytesIO(content), |
| | | media_type=mimetype, |
| | | headers={"Content-Disposition": f"attachment; filename={filename}"} |
| | | ) |
| | | else: |
| | | return Response(code=400, msg="Unsupported file type") |
| | |
| | | from zoneinfo import ZoneInfo |
| | | |
| | | import pytz |
| | | |
| | | from .agent_model import * |
| | | from .dialog_model import * |
| | | from .group_model import * |
| | |
| | | from .organization_model import * |
| | | from .resource_model import * |
| | | from .role_model import * |
| | | from .user_model import * |
| | | from .user_model import * |
| | | |
| | | |
| | | # 获取当前时区的时间 |
| | | def current_time(): |
| | | tz = pytz.timezone('Asia/Shanghai') |
| | | return datetime.now(tz) |
| | |
| | | from enum import IntEnum |
| | | from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime |
| | | |
| | | from app.models import AgentType |
| | | from app.models import AgentType, current_time |
| | | from app.models.base_model import Base |
| | | |
| | | |
| | |
| | | name = Column(String(255)) |
| | | agent_id = Column(String(255)) |
| | | agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False) # 目前只存basic的,ragflow和bisheng的调接口获取 |
| | | create_date = Column(DateTime) # 创建时间 |
| | | update_date = Column(DateTime) # 更新时间 |
| | | |
| | | create_date = Column(DateTime, default=current_time) # 创建时间,默认值为当前时区时间 |
| | | update_date = Column(DateTime, default=current_time, onupdate=current_time) # 更新时间,默认值为当前时区时间,更新时自动更新 |
| | | # to_dict 方法 |
| | | def to_dict(self): |
| | | return { |
| | |
| | | if response.status_code not in [200, 201]: |
| | | raise Exception(f"Failed to fetch data from API: {response.text}") |
| | | response_data = response.json() |
| | | status_code = response_data.get("status_code", 0) |
| | | if status_code != 200: |
| | | raise Exception(f"Failed to fetch data from API: {response.text}") |
| | | return response_data.get("data", {}) |
| | | return response_data |
| | | |
| | | async def download_from_url(self, url: str, params: dict): |
| | | async def download_from_url(self, url, params=None): |
| | | async with httpx.AsyncClient() as client: |
| | | response = await client.get(url, params=params, stream=True) |
| | | if response.status_code == 200: |
| | | content_disposition = response.headers.get('Content-Disposition') |
| | | filename = content_disposition.split('filename=')[-1].strip( |
| | | '"') if content_disposition else 'unknown_filename' |
| | | return response.content, filename, response.headers.get('Content-Type') |
| | | else: |
| | | return None, None, None |
| | | async with client.stream('GET', url, params=params) as response: |
| | | if response.status_code == 200: |
| | | # 获取文件名 |
| | | content_disposition = response.headers.get('Content-Disposition') |
| | | if content_disposition: |
| | | filename = content_disposition.split('filename=')[1].strip('"') |
| | | else: |
| | | filename = 'unknown_filename' |
| | | |
| | | # 获取内容类型 |
| | | content_type = response.headers.get('Content-Type') |
| | | |
| | | # 读取文件内容 |
| | | content = await response.aread() |
| | | return content, filename, content_type |
| | | else: |
| | | raise Exception(f"Failed to download: {response.status_code}") |
| | | |
| | | async def excel_talk_image_download(self, file_id: str): |
| | | url = f"{self.base_url}/exceltalk/download/image" |
| | |
| | | files=files, |
| | | params=params |
| | | ) |
| | | return await self._check_response(response) |
| | | return self._check_response(response) |
| | | |
| | | async def excel_talk(self, question: str, chat_id: str): |
| | | url = f"{self.base_url}/exceltalk/talk" |
| | |
| | | print(e) |
| | | return |
| | | else: |
| | | yield f"Error: {response.status_code}" |
| | | yield f"Error: {response.status_code}" |
New file |
| | |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from app.models import AgentType |
| | | from app.models.session_model import SessionModel |
| | | |
| | | |
| | | class SessionService: |
| | | def __init__(self, db: Session): |
| | | self.db = db |
| | | |
| | | def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel: |
| | | """ |
| | | 创建一个新的会话记录。 |
| | | |
| | | 参数: |
| | | session_id (str): 会话ID。 |
| | | name (str): 会话名称。 |
| | | agent_id (str): 代理ID。 |
| | | agent_type (AgentType): 代理类型。 |
| | | |
| | | 返回: |
| | | SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。 |
| | | """ |
| | | existing_session = self.get_session_by_id(session_id) |
| | | if existing_session: |
| | | return None # 如果会话ID已存在,不进行任何操作 |
| | | |
| | | new_session = SessionModel( |
| | | id=session_id, |
| | | name=name, |
| | | agent_id=agent_id, |
| | | agent_type=agent_type |
| | | ) |
| | | self.db.add(new_session) |
| | | self.db.commit() |
| | | self.db.refresh(new_session) |
| | | return new_session |
| | | |
| | | def get_session_by_id(self, session_id: str) -> SessionModel: |
| | | """ |
| | | 根据会话ID获取会话记录。 |
| | | |
| | | 参数: |
| | | session_id (str): 会话ID。 |
| | | |
| | | 返回: |
| | | SessionModel: 查找到的会话模型实例,如果未找到则返回None。 |
| | | """ |
| | | return self.db.query(SessionModel).filter_by(id=session_id).first() |
| | | |
| | | def update_session(self, session_id: str, **kwargs) -> SessionModel: |
| | | """ |
| | | 更新会话记录。 |
| | | |
| | | 参数: |
| | | session_id (str): 会话ID。 |
| | | kwargs: 需要更新的字段及其值。 |
| | | |
| | | 返回: |
| | | SessionModel: 更新后的会话模型实例。 |
| | | """ |
| | | session = self.get_session_by_id(session_id) |
| | | if session: |
| | | for key, value in kwargs.items(): |
| | | setattr(session, key, value) |
| | | self.db.commit() |
| | | self.db.refresh(session) |
| | | return session |
| | | |
| | | def delete_session(self, session_id: str) -> None: |
| | | """ |
| | | 删除会话记录。 |
| | | |
| | | 参数: |
| | | session_id (str): 会话ID。 |
| | | """ |
| | | session = self.get_session_by_id(session_id) |
| | | if session: |
| | | self.db.delete(session) |
| | | self.db.commit() |