Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway
| | |
| | | from app.config.config import settings |
| | | from app.models.agent_model import AgentType, AgentModel |
| | | from app.models.base_model import get_db |
| | | from app.models.session_model import SessionModel |
| | | from app.models.user_model import UserModel |
| | | from app.service.bisheng import BishengService |
| | | from app.service.dialog import get_session_history |
| | |
| | | raise HTTPException(status_code=500, detail=str(e)) |
| | | return ResponseList(code=200, msg="", data=result) |
| | | |
| | | elif agent.agent_type == AgentType.BASIC: |
| | | offset = (page - 1) * limit |
| | | records = db.query(SessionModel).filter(SessionModel.agent_id == agent_id).offset(offset).limit(limit).all() |
| | | result = [item.to_dict() for item in records] |
| | | return ResponseList(code=200, msg="", data=result) |
| | | |
| | | else: |
| | | return ResponseList(code=200, msg="Unsupported agent type") |
| | | |
| | |
| | | from app.models.base_model import get_db |
| | | from app.models.user_model import UserModel |
| | | from app.service.dialog import update_session_history |
| | | from app.service.basic import BasicService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.service_token import get_bisheng_token, get_ragflow_token |
| | | |
| | |
| | | await task |
| | | except asyncio.CancelledError: |
| | | pass |
| | | elif agent_type == AgentType.BASIC: |
| | | try: |
| | | while True: |
| | | # 接收前端消息 |
| | | message = await websocket.receive_json() |
| | | question = message.get("message") |
| | | if not question: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | | |
| | | service = BasicService(base_url=settings.basic_base_url) |
| | | complete_response = "" |
| | | async for result in service.excel_talk(question, chat_id): |
| | | try: |
| | | if result[:5] == "data:": |
| | | # 如果是,则截取掉前5个字符,并去除首尾空白符 |
| | | text = result[5:].strip() |
| | | else: |
| | | # 否则,保持原样 |
| | | text = result |
| | | complete_response += text |
| | | try: |
| | | json_data = json.loads(complete_response) |
| | | output = json_data.get("output", "") |
| | | result = {"message": output, "type": "message"} |
| | | await websocket.send_json(result | json_data) |
| | | complete_response = "" |
| | | except json.JSONDecodeError as e: |
| | | print(f"Error decoding JSON: {e}") |
| | | print(f"Response text: {text}") |
| | | except Exception as e2: |
| | | result = {"message": f"内部错误: {e2}", "type": "close"} |
| | | await websocket.send_json(result) |
| | | print(f"Error process message of basic agent: {e2}") |
| | | except Exception as e: |
| | | await websocket.send_json({"message": str(e), "type": "error"}) |
| | | finally: |
| | | await websocket.close() |
| | | print(f"Client {agent_id} disconnected") |
| | | else: |
| | | ret = {"message": "Agent not found", "type": "close"} |
| | | await websocket.send_json(ret) |
New file |
| | |
| | | import asyncio |
| | | import json |
| | | from enum import Enum |
| | | |
| | | from fastapi import APIRouter, Depends |
| | | from sqlalchemy.orm import Session |
| | | from starlette.websockets import WebSocket, WebSocketDisconnect |
| | | |
| | | from app.api import get_current_user_websocket |
| | | from app.config.config import settings |
| | | from app.models import UserModel, AgentModel |
| | | from app.models.base_model import get_db |
| | | from app.service.basic import BasicService |
| | | |
| | | router = APIRouter() |
| | | |
| | | # class CompletionRequest(BaseModel): |
| | | # id: Optional[str] = None |
| | | # app_id: str |
| | | # message: str |
| | | # |
| | | # class DownloadRequest(BaseModel): |
| | | # file_id: str |
| | | # app_id: str |
| | | # file_type: Optional[str] = None |
| | | |
| | | |
class AdvancedAgentID(Enum):
    """Identifiers of the advanced agents known to this module.

    Each member's value is the string form of the identifier.
    """

    # Conversational analysis over uploaded Excel files.
    EXCEL_TALK = "excel_talk"
    # NOTE(review): not referenced in the visible code -- confirm intended use.
    QUESTIONS_TALK = "questions_talk"
| | | |
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Serve an Excel-talk chat session over a WebSocket.

    The client sends JSON objects of the form ``{"message": "<question>"}``.
    For every question the answer produced by ``BasicService.excel_talk`` is
    collected and sent back as ``{"message": <answer>, "type": "response"}``.

    Args:
        websocket: The connection to serve; it is accepted here.
        agent_id: Primary key of the agent to talk to; looked up in the database.
        chat_id: Conversation identifier forwarded to the backend service.
        current_user: Resolved by the auth dependency; not used in the body.
        db: Database session used for the agent lookup.

    Protocol messages sent to the client:
        * ``type == "close"``: the agent does not exist; the socket is closed.
        * ``type == "error"``: invalid request (loop continues) or an
          unexpected failure (socket is closed afterwards).
        * ``type == "response"``: the complete answer to one question.
    """
    await websocket.accept()
    print(f"Client {agent_id} connected")

    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        await websocket.send_json({"message": "Agent not found", "type": "close"})
        # Close explicitly instead of relying on the server to tear down the
        # connection once the handler returns.
        await websocket.close()
        return

    service = BasicService(base_url=settings.basic_base_url)
    try:
        while True:
            # Receive the next message from the front end.
            message = await websocket.receive_json()
            question = message.get("message")
            if not question:
                await websocket.send_json({"message": "Invalid request", "type": "error"})
                continue

            # excel_talk is an async generator, so it must be iterated rather
            # than awaited. The chunks are joined so the client still gets a
            # single "response" message per question.
            chunks = [chunk async for chunk in service.excel_talk(question, chat_id)]
            result = "".join(chunks)

            # Send the result back to the front end.
            await websocket.send_json({"message": result, "type": "response"})
    except WebSocketDisconnect:
        # The peer went away; there is nobody left to notify.
        pass
    except Exception as e:
        try:
            await websocket.send_json({"message": str(e), "type": "error"})
        except (RuntimeError, WebSocketDisconnect):
            # The connection is already unusable; reporting is best-effort.
            pass
    finally:
        try:
            await websocket.close()
        except RuntimeError:
            # Raised when the socket has already been closed or disconnected.
            pass
        print(f"Client {agent_id} disconnected")
| | | |
| | | |
| | | |
| | |
| | | from typing import Optional |
| | | |
| | | import requests |
| | | from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query |
| | | from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query, Form |
| | | from pydantic import BaseModel |
| | | from sqlalchemy.orm import Session |
| | | from starlette.responses import StreamingResponse |
| | |
| | | from app.models.agent_model import AgentType, AgentModel |
| | | from app.models.base_model import get_db |
| | | from app.models.user_model import UserModel |
| | | from app.service.basic import BasicService |
| | | from app.service.bisheng import BishengService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.service_token import get_ragflow_token, get_bisheng_token |
| | |
| | | raise HTTPException(status_code=500, detail=str(e)) |
| | | result["file_name"] = file.filename |
| | | return Response(code=200, msg="", data=result) |
| | | elif agent.agent_type == AgentType.BASIC: |
| | | if agent_id == "basic_excel_talk": |
| | | service = BasicService(base_url=settings.basic_base_url) |
| | | result = await service.excel_talk_upload(chat_id, file.filename, file_content) |
| | | |
| | | return Response(code=200, msg="", data=result) |
| | | |
| | | else: |
| | | return Response(code=200, msg="Unsupported agent type") |
| | |
| | | PUBLIC_KEY: str |
| | | PRIVATE_KEY: str |
| | | PASSWORD_KEY: str |
| | | basic_base_url: str = '' |
| | | def __init__(self, **kwargs): |
| | | # Check if all required fields are provided and set them |
| | | for field in self.__annotations__.keys(): |
| | |
| | | PRIVATE_KEY: str |
| | | fetch_sgb_agent: 报告生成,文档智能 |
| | | fetch_fwr_agent: 知识问答,智能问答 |
| | | PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions= |
| | | PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions= |
| | | basic_base_url: http://192.168.20.231:8000 |
New file |
| | |
| | | import json |
| | | from datetime import datetime |
| | | from enum import IntEnum |
| | | from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime |
| | | |
| | | from app.models import AgentType |
| | | from app.models.base_model import Base |
| | | |
| | | |
class SessionModel(Base):
    """ORM mapping for the ``sessions`` table (one row per chat session)."""

    __tablename__ = "sessions"

    id = Column(String(255), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    # Only "basic" sessions are stored for now; ragflow and bisheng sessions
    # are fetched through their own APIs.
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)
    create_date = Column(DateTime)  # creation time
    update_date = Column(DateTime)  # last update time

    def to_dict(self):
        """Return the row as a plain dict keyed by column name.

        Values are returned as stored on the instance, with no conversion.
        """
        exported = ('id', 'name', 'agent_type', 'agent_id', 'create_date', 'update_date')
        return {column: getattr(self, column) for column in exported}
New file |
| | |
| | | import httpx |
| | | |
| | | |
class BasicService:
    """Async HTTP client for the "basic" agent backend.

    Every call opens a short-lived ``httpx.AsyncClient`` and talks to an
    endpoint under ``{base_url}/exceltalk``.
    """

    def __init__(self, base_url: str):
        """Store the backend root URL; endpoint paths are appended verbatim."""
        self.base_url = base_url

    def _check_response(self, response: "httpx.Response"):
        """Validate a JSON API response and return its ``data`` payload.

        Args:
            response: The HTTP response to validate.

        Returns:
            The value under ``"data"`` in the JSON body, or ``{}`` if absent.

        Raises:
            Exception: If the HTTP status is not 200/201, or the body's
                ``status_code`` field is not 200.
        """
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to fetch data from API: {response.text}")
        response_data = response.json()
        # The backend also reports an application-level status in the body.
        status_code = response_data.get("status_code", 0)
        if status_code != 200:
            raise Exception(f"Failed to fetch data from API: {response.text}")
        return response_data.get("data", {})

    @staticmethod
    def _filename_from_disposition(content_disposition):
        """Extract the file name from a Content-Disposition header value.

        Returns ``'unknown_filename'`` when the header is missing or empty.
        """
        if not content_disposition:
            return 'unknown_filename'
        return content_disposition.split('filename=')[-1].strip('"')

    async def download_from_url(self, url: str, params: dict):
        """Download a file with a GET request.

        Args:
            url: Absolute URL to fetch.
            params: Query-string parameters.

        Returns:
            ``(content_bytes, filename, content_type)`` on HTTP 200,
            otherwise ``(None, None, None)``.
        """
        async with httpx.AsyncClient() as client:
            # ``AsyncClient.get`` takes no ``stream`` keyword (passing it raises
            # TypeError). The body is read in full below, so a plain GET is
            # the correct call.
            response = await client.get(url, params=params)
            if response.status_code != 200:
                return None, None, None
            filename = self._filename_from_disposition(
                response.headers.get('Content-Disposition'))
            return response.content, filename, response.headers.get('Content-Type')

    async def excel_talk_image_download(self, file_id: str):
        """Download an image by name; see ``download_from_url`` for the return value."""
        url = f"{self.base_url}/exceltalk/download/image"
        return await self.download_from_url(url, params={'images_name': file_id})

    async def excel_talk_excel_download(self, file_id: str):
        """Download an Excel file by name; see ``download_from_url`` for the return value."""
        url = f"{self.base_url}/exceltalk/download/excel"
        return await self.download_from_url(url, params={'excel_name': file_id})

    async def excel_talk_upload(self, chat_id: str, filename: str, file_content: bytes):
        """Upload one file for a chat and return the API's ``data`` payload.

        Args:
            chat_id: Conversation the file belongs to (sent as a query parameter).
            filename: Name reported to the backend for the uploaded file.
            file_content: Raw bytes of the file.

        Raises:
            Exception: Propagated from ``_check_response`` on a failed upload.
        """
        url = f"{self.base_url}/exceltalk/upload/files"
        params = {'chat_id': chat_id, 'is_col': '0'}

        # Multipart form field "files" carrying a single file.
        files = [('files', (filename, file_content, 'application/octet-stream'))]

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                files=files,
                params=params,
            )
            # _check_response is synchronous; awaiting its return value would
            # raise TypeError.
            return self._check_response(response)

    async def excel_talk(self, question: str, chat_id: str):
        """Stream the backend's answer to ``question`` as text chunks.

        This is an async generator: iterate it with ``async for``.

        Args:
            question: The user's query, sent as ``{"query": question}``.
            chat_id: Conversation identifier (sent as a query parameter).

        Yields:
            Text chunks of the response body on HTTP 200; otherwise a single
            ``"Error: <status>"`` string.
        """
        url = f"{self.base_url}/exceltalk/talk"
        params = {'chat_id': chat_id}
        data = {"query": question}
        headers = {'Content-Type': 'application/json'}
        # Generous timeout: answers can take a long time to start streaming.
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream("POST", url, params=params, json=data, headers=headers) as response:
                if response.status_code != 200:
                    yield f"Error: {response.status_code}"
                    return
                # If the consumer stops early, the surrounding ``async with``
                # blocks release the connection; no GeneratorExit handler needed.
                async for answer in response.aiter_text():
                    print(f"response of basic excel talk: {answer}")
                    yield answer