zhaoqingang
2024-11-20 67626f1c71d76c5e3d1646259024cb4a452c2890
Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway
1个文件已添加
5个文件已修改
191 ■■■■ 已修改文件
app/api/chat.py 26 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/files.py 30 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/__init__.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/basic.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/session.py 80 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py
@@ -14,6 +14,7 @@
from app.service.basic import BasicService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
from app.service.session import SessionService
router = APIRouter()
@@ -203,12 +204,17 @@
                # 接收前端消息
                message = await websocket.receive_json()
                question = message.get("message")
                SessionService(db).create_session(
                    session_id=chat_id,
                    name=question,
                    agent_id=agent_id,
                    agent_type=AgentType.BASIC
                )
                if not question:
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
                    continue
                service = BasicService(base_url=settings.basic_base_url)
                complete_response = ""
                async for result in service.excel_talk(question, chat_id):
                    try:
                        if result[:5] == "data:":
@@ -217,13 +223,19 @@
                        else:
                            # 否则,保持原样
                            text = result
                        complete_response += text
                        try:
                            json_data = json.loads(complete_response)
                            output = json_data.get("output", "")
                            result = {"message": output, "type": "message"}
                            await websocket.send_json(result | json_data)
                            complete_response = ""
                            data = json.loads(text)
                            output = data.get("output", "")
                            excel_name = data.get("excel_name", "")
                            image_name = data.get("excel_name", "")
                            excel_url = None
                            image_url = None
                            if excel_name:
                                excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel"
                            if image_name:
                                image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image"
                            result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url}
                            await websocket.send_json(result | data)
                        except json.JSONDecodeError as e:
                            print(f"Error decoding JSON: {e}")
                            print(f"Response text: {text}")
app/api/files.py
@@ -1,3 +1,4 @@
import io
from typing import Optional
import requests
@@ -5,6 +6,7 @@
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from werkzeug.utils import send_file
from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
@@ -76,6 +78,8 @@
        agent_id: str = Query(..., description="Agent ID"),
        doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"),
        doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"),
        file_id:  Optional[str] = Query(None, description="Optional file id for basic agents"),
        file_type:  Optional[str] = Query(None, description="Optional file type for basic agents"),
        db: Session = Depends(get_db)
):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
@@ -93,6 +97,10 @@
            return Response(code=400, msg="doc_id is required")
        url = f"{settings.fwr_base_url}/v1/document/get/{doc_id}"
        filename = doc_name
    elif agent.agent_type == AgentType.BASIC:
        if agent_id == "basic_excel_talk":
            return await download_basic_file(file_id, file_type)
    else:
        return Response(code=400, msg="Unsupported agent type")
@@ -109,3 +117,25 @@
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error downloading file: {e}")
async def download_basic_file(file_id: str, file_type: str):
    service = BasicService(base_url=settings.basic_base_url)
    if not file_type or not file_id:
        return Response(code=400, msg="file_type and file_id is required")
    if file_type == "image":
        content, filename, mimetype = await service.excel_talk_image_download(file_id)
        return StreamingResponse(
                io.BytesIO(content),
                media_type=mimetype,
                headers={"Content-Disposition": f"attachment; filename={filename}"}
            )
    elif file_type == "excel":
        content, filename, mimetype = await service.excel_talk_excel_download(file_id)
        return StreamingResponse(
            io.BytesIO(content),
            media_type=mimetype,
            headers={"Content-Disposition": f"attachment; filename={filename}"}
        )
    else:
        return Response(code=400, msg="Unsupported file type")
app/models/__init__.py
@@ -1,3 +1,7 @@
from zoneinfo import ZoneInfo
import pytz
from .agent_model import *
from .dialog_model import *
from .group_model import *
@@ -6,4 +10,10 @@
from .organization_model import *
from .resource_model import *
from .role_model import *
from .user_model import *
from .user_model import *
# 获取当前时区的时间
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
app/models/session_model.py
@@ -3,7 +3,7 @@
from enum import IntEnum
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime
from app.models import AgentType
from app.models import AgentType, current_time
from app.models.base_model import Base
@@ -13,9 +13,8 @@
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime)  # 创建时间
    update_date = Column(DateTime)  # 更新时间
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    # to_dict 方法
    def to_dict(self):
        return {
app/service/basic.py
@@ -10,21 +10,27 @@
        if response.status_code not in [200, 201]:
            raise Exception(f"Failed to fetch data from API: {response.text}")
        response_data = response.json()
        status_code = response_data.get("status_code", 0)
        if status_code != 200:
            raise Exception(f"Failed to fetch data from API: {response.text}")
        return response_data.get("data", {})
        return response_data
    async def download_from_url(self, url: str, params: dict):
    async def download_from_url(self, url, params=None):
        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, stream=True)
            if response.status_code == 200:
                content_disposition = response.headers.get('Content-Disposition')
                filename = content_disposition.split('filename=')[-1].strip(
                    '"') if content_disposition else 'unknown_filename'
                return response.content, filename, response.headers.get('Content-Type')
            else:
                return None, None, None
            async with client.stream('GET', url, params=params) as response:
                if response.status_code == 200:
                    # 获取文件名
                    content_disposition = response.headers.get('Content-Disposition')
                    if content_disposition:
                        filename = content_disposition.split('filename=')[1].strip('"')
                    else:
                        filename = 'unknown_filename'
                    # 获取内容类型
                    content_type = response.headers.get('Content-Type')
                    # 读取文件内容
                    content = await response.aread()
                    return content, filename, content_type
                else:
                    raise Exception(f"Failed to download: {response.status_code}")
    async def excel_talk_image_download(self, file_id: str):
        url = f"{self.base_url}/exceltalk/download/image"
@@ -47,7 +53,7 @@
                files=files,
                params=params
            )
            return await self._check_response(response)
            return self._check_response(response)
    async def excel_talk(self, question: str, chat_id: str):
        url = f"{self.base_url}/exceltalk/talk"
@@ -65,4 +71,4 @@
                        print(e)
                        return
                else:
                    yield f"Error: {response.status_code}"
                    yield f"Error: {response.status_code}"
app/service/session.py
New file
@@ -0,0 +1,80 @@
from sqlalchemy.orm import Session
from app.models import AgentType
from app.models.session_model import SessionModel
class SessionService:
    def __init__(self, db: Session):
        self.db = db
    def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel:
        """
        创建一个新的会话记录。
        参数:
            session_id (str): 会话ID。
            name (str): 会话名称。
            agent_id (str): 代理ID。
            agent_type (AgentType): 代理类型。
        返回:
            SessionModel: 新创建的会话模型实例,如果会话ID已存在则返回None。
        """
        existing_session = self.get_session_by_id(session_id)
        if existing_session:
            return None  # 如果会话ID已存在,不进行任何操作
        new_session = SessionModel(
            id=session_id,
            name=name,
            agent_id=agent_id,
            agent_type=agent_type
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session
    def get_session_by_id(self, session_id: str) -> SessionModel:
        """
        根据会话ID获取会话记录。
        参数:
            session_id (str): 会话ID。
        返回:
            SessionModel: 查找到的会话模型实例,如果未找到则返回None。
        """
        return self.db.query(SessionModel).filter_by(id=session_id).first()
    def update_session(self, session_id: str, **kwargs) -> SessionModel:
        """
        更新会话记录。
        参数:
            session_id (str): 会话ID。
            kwargs: 需要更新的字段及其值。
        返回:
            SessionModel: 更新后的会话模型实例。
        """
        session = self.get_session_by_id(session_id)
        if session:
            for key, value in kwargs.items():
                setattr(session, key, value)
            self.db.commit()
            self.db.refresh(session)
        return session
    def delete_session(self, session_id: str) -> None:
        """
        删除会话记录。
        参数:
            session_id (str): 会话ID。
        """
        session = self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()