zhaoqingang
2024-11-19 13c3fdf08558b6ce01dcbdc7716bd77dc9b2e88c
Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway
5个文件已修改
3个文件已添加
225 ■■■■■ 已修改文件
app/api/agent.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/excel_talk.py 69 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/files.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.yaml 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/basic.py 68 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/agent.py
@@ -10,6 +10,7 @@
from app.config.config import settings
from app.models.agent_model import AgentType, AgentModel
from app.models.base_model import get_db
from app.models.session_model import SessionModel
from app.models.user_model import UserModel
from app.service.bisheng import BishengService
from app.service.dialog import get_session_history
@@ -57,6 +58,12 @@
            raise HTTPException(status_code=500, detail=str(e))
        return ResponseList(code=200, msg="", data=result)
    elif agent.agent_type == AgentType.BASIC:
        offset = (page - 1) * limit
        records = db.query(SessionModel).filter(SessionModel.agent_id == agent_id).offset(offset).limit(limit).all()
        result = [item.to_dict() for item in records]
        return ResponseList(code=200, msg="", data=result)
    else:
        return ResponseList(code=200, msg="Unsupported agent type")
app/api/chat.py
@@ -11,6 +11,7 @@
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
@@ -196,6 +197,45 @@
                            await task
                        except asyncio.CancelledError:
                            pass
    elif agent_type == AgentType.BASIC:
        try:
            while True:
                # 接收前端消息
                message = await websocket.receive_json()
                question = message.get("message")
                if not question:
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
                    continue
                service = BasicService(base_url=settings.basic_base_url)
                complete_response = ""
                async for result in service.excel_talk(question, chat_id):
                    try:
                        if result[:5] == "data:":
                            # 如果是,则截取掉前5个字符,并去除首尾空白符
                            text = result[5:].strip()
                        else:
                            # 否则,保持原样
                            text = result
                        complete_response += text
                        try:
                            json_data = json.loads(complete_response)
                            output = json_data.get("output", "")
                            result = {"message": output, "type": "message"}
                            await websocket.send_json(result | json_data)
                            complete_response = ""
                        except json.JSONDecodeError as e:
                            print(f"Error decoding JSON: {e}")
                            print(f"Response text: {text}")
                    except Exception as e2:
                        result = {"message": f"内部错误: {e2}", "type": "close"}
                        await websocket.send_json(result)
                        print(f"Error process message of basic agent: {e2}")
        except Exception as e:
            await websocket.send_json({"message": str(e), "type": "error"})
        finally:
            await websocket.close()
            print(f"Client {agent_id} disconnected")
    else:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
app/api/excel_talk.py
New file
@@ -0,0 +1,69 @@
import asyncio
import json
from enum import Enum
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.websockets import WebSocket, WebSocketDisconnect
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models import UserModel, AgentModel
from app.models.base_model import get_db
from app.service.basic import BasicService
router = APIRouter()
# class CompletionRequest(BaseModel):
#     id: Optional[str] = None
#     app_id: str
#     message: str
#
# class DownloadRequest(BaseModel):
#     file_id: str
#     app_id: str
#     file_type: Optional[str] = None
class AdvancedAgentID(Enum):
    """String identifiers for the advanced agents.

    NOTE(review): nothing in the visible part of this module references
    these members -- confirm whether they are still needed.
    """
    EXCEL_TALK = "excel_talk"
    QUESTIONS_TALK = "questions_talk"
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Serve one WebSocket chat session answered by ``BasicService.excel_talk``.

    The connection is accepted, the agent is looked up by ``agent_id`` and,
    if it exists, every incoming JSON message carrying a non-empty
    ``"message"`` field is forwarded to the service. The streamed answer is
    collected and sent back as a single ``{"message": ..., "type": "response"}``
    frame.

    Args:
        websocket: The client connection.
        agent_id: Primary key of the agent to talk to.
        chat_id: Conversation identifier forwarded to the service.
        current_user: Authenticated user resolved by the dependency.
        db: Database session used for the agent lookup.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        # Close explicitly so the client is not left with a half-open socket.
        await websocket.close()
        return

    service = BasicService(base_url=settings.basic_base_url)
    client_gone = False
    try:
        while True:
            # Receive a message from the front end.
            message = await websocket.receive_json()
            question = message.get("message")
            if not question:
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue
            # excel_talk is an async generator, so it must be iterated rather
            # than awaited; the chunks are joined to keep the original
            # one-response-per-question message shape.
            chunks = [chunk async for chunk in service.excel_talk(question, chat_id)]
            # Send the result back to the front end.
            await websocket.send_json({"message": "".join(chunks), "type": "response"})
    except WebSocketDisconnect:
        # Normal client disconnect: nothing can be sent any more.
        client_gone = True
    except Exception as e:
        try:
            await websocket.send_json({"message": str(e), "type": "error"})
        except (WebSocketDisconnect, RuntimeError):
            # The socket died while the error was being reported.
            client_gone = True
    finally:
        if not client_gone:
            try:
                await websocket.close()
            except (WebSocketDisconnect, RuntimeError):
                # Already closed; nothing left to clean up.
                pass
        print(f"Client {agent_id} disconnected")
app/api/files.py
@@ -1,7 +1,7 @@
from typing import Optional
import requests
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query
from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query, Form
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
@@ -11,6 +11,7 @@
from app.models.agent_model import AgentType, AgentModel
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.basic import BasicService
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.service_token import get_ragflow_token, get_bisheng_token
@@ -58,6 +59,12 @@
            raise HTTPException(status_code=500, detail=str(e))
        result["file_name"] = file.filename
        return Response(code=200, msg="", data=result)
    elif agent.agent_type == AgentType.BASIC:
        if agent_id == "basic_excel_talk":
            service = BasicService(base_url=settings.basic_base_url)
            result = await service.excel_talk_upload(chat_id, file.filename, file_content)
            return Response(code=200, msg="", data=result)
    else:
        return Response(code=200, msg="Unsupported agent type")
app/config/config.py
@@ -15,6 +15,7 @@
    PUBLIC_KEY: str
    PRIVATE_KEY: str
    PASSWORD_KEY: str
    basic_base_url: str = ''
    def __init__(self, **kwargs):
        # Check if all required fields are provided and set them
        for field in self.__annotations__.keys():
app/config/config.yaml
@@ -12,4 +12,5 @@
PRIVATE_KEY: str
fetch_sgb_agent: 报告生成,文档智能
fetch_fwr_agent: 知识问答,智能问答
PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions=
PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions=
basic_base_url: http://192.168.20.231:8000
app/models/session_model.py
New file
@@ -0,0 +1,28 @@
import json
from datetime import datetime
from enum import IntEnum
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime
from app.models import AgentType
from app.models.base_model import Base
class SessionModel(Base):
    """ORM mapping for rows of the ``sessions`` table."""

    __tablename__ = "sessions"

    id = Column(String(255), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    # Only basic-agent sessions are stored for now; ragflow and bisheng
    # sessions are fetched through their own APIs.
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)
    create_date = Column(DateTime)  # creation time
    update_date = Column(DateTime)  # last update time

    def to_dict(self):
        """Return the mapped column values as a plain dictionary."""
        # The tuple order fixes the key order of the resulting dict.
        exported = (
            'id',
            'name',
            'agent_type',
            'agent_id',
            'create_date',
            'update_date',
        )
        return {column: getattr(self, column) for column in exported}
app/service/basic.py
New file
@@ -0,0 +1,68 @@
import httpx
class BasicService:
    """Async HTTP client for the "basic" agent backend (excel-talk endpoints)."""

    def __init__(self, base_url: str):
        """Remember the backend base URL that every endpoint path is appended to."""
        self.base_url = base_url

    def _check_response(self, response: httpx.Response):
        """Check the response and handle errors.

        Returns:
            The ``"data"`` member of the JSON payload, or ``{}`` if absent.

        Raises:
            Exception: If the HTTP status is not 200/201, or the payload's
                own ``"status_code"`` field is not 200.
        """
        if response.status_code not in [200, 201]:
            raise Exception(f"Failed to fetch data from API: {response.text}")

        response_data = response.json()
        status_code = response_data.get("status_code", 0)
        if status_code != 200:
            raise Exception(f"Failed to fetch data from API: {response.text}")
        return response_data.get("data", {})

    async def download_from_url(self, url: str, params: dict):
        """GET ``url`` and return ``(content, filename, content_type)``.

        The file name is taken from the ``Content-Disposition`` header and
        falls back to ``'unknown_filename'`` when the header is missing.
        Returns ``(None, None, None)`` on any non-200 response.
        """
        async with httpx.AsyncClient() as client:
            # httpx's ``get`` has no ``stream`` keyword (that is the requests
            # API) and raised TypeError; the body is read in full here anyway.
            response = await client.get(url, params=params)
            if response.status_code == 200:
                content_disposition = response.headers.get('Content-Disposition')
                filename = content_disposition.split('filename=')[-1].strip(
                    '"') if content_disposition else 'unknown_filename'
                return response.content, filename, response.headers.get('Content-Type')
            else:
                return None, None, None

    async def excel_talk_image_download(self, file_id: str):
        """Download an image from the excel-talk backend by its name."""
        url = f"{self.base_url}/exceltalk/download/image"
        return await self.download_from_url(url, params={'images_name': file_id})

    async def excel_talk_excel_download(self, file_id: str):
        """Download an Excel file from the excel-talk backend by its name."""
        url = f"{self.base_url}/exceltalk/download/excel"
        return await self.download_from_url(url, params={'excel_name': file_id})

    async def excel_talk_upload(self, chat_id: str, filename: str, file_content: bytes):
        """Upload one file for ``chat_id`` and return the checked response data.

        Raises:
            Exception: Propagated from ``_check_response`` on a failed upload.
        """
        url = f"{self.base_url}/exceltalk/upload/files"
        params = {'chat_id': chat_id, 'is_col': '0'}

        # Build the multipart form payload.
        files = [('files', (filename, file_content, 'application/octet-stream'))]

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                files=files,
                params=params
            )
            # _check_response is synchronous and returns plain data; awaiting
            # its result raised TypeError.
            return self._check_response(response)

    async def excel_talk(self, question: str, chat_id: str):
        """Stream the backend's answer to ``question`` as text chunks.

        This is an async generator: iterate it with ``async for``. On a
        non-200 response a single ``"Error: <status>"`` string is yielded.
        """
        url = f"{self.base_url}/exceltalk/talk"
        params = {'chat_id': chat_id}
        data = {"query": question}
        headers = {'Content-Type': 'application/json'}
        # Long timeout: the answer is streamed and may take a while.
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream("POST", url, params=params, json=data, headers=headers) as response:
                if response.status_code == 200:
                    try:
                        async for answer in response.aiter_text():
                            print(f"response of ragflow chat: {answer}")
                            yield answer
                    except GeneratorExit as e:
                        # The consumer stopped iterating early; end quietly.
                        print(e)
                        return
                else:
                    yield f"Error: {response.status_code}"