From 67626f1c71d76c5e3d1646259024cb4a452c2890 Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期三, 20 十一月 2024 15:59:11 +0800 Subject: [PATCH] Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway --- app/models/__init__.py | 12 ++ app/models/session_model.py | 7 - app/api/chat.py | 26 ++++- app/service/basic.py | 36 +++++--- app/service/session.py | 80 ++++++++++++++++++++ app/api/files.py | 30 +++++++ 6 files changed, 164 insertions(+), 27 deletions(-) diff --git a/app/api/chat.py b/app/api/chat.py index b5bfd6a..e0abd8d 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -14,6 +14,7 @@ from app.service.basic import BasicService from app.service.ragflow import RagflowService from app.service.service_token import get_bisheng_token, get_ragflow_token +from app.service.session import SessionService router = APIRouter() @@ -203,12 +204,17 @@ # 鎺ユ敹鍓嶇娑堟伅 message = await websocket.receive_json() question = message.get("message") + SessionService(db).create_session( + session_id=chat_id, + name=question, + agent_id=agent_id, + agent_type=AgentType.BASIC + ) if not question: await websocket.send_json({"message": "Invalid request", "type": "error"}) continue service = BasicService(base_url=settings.basic_base_url) - complete_response = "" async for result in service.excel_talk(question, chat_id): try: if result[:5] == "data:": @@ -217,13 +223,19 @@ else: # 鍚﹀垯锛屼繚鎸佸師鏍� text = result - complete_response += text try: - json_data = json.loads(complete_response) - output = json_data.get("output", "") - result = {"message": output, "type": "message"} - await websocket.send_json(result | json_data) - complete_response = "" + data = json.loads(text) + output = data.get("output", "") + excel_name = data.get("excel_name", "") + image_name = data.get("excel_name", "") + excel_url = None + image_url = None + if excel_name: + excel_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={excel_name}&file_type=excel" + if image_name: + image_url = f"/api/files/download/?agent_id=basic_excel_talk&file_id={image_name}&file_type=image" + result = {"message": output, "type": "message", "excel_url": excel_url, "image_url": image_url} + await websocket.send_json(result | data) except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") print(f"Response text: {text}") diff --git a/app/api/files.py b/app/api/files.py index fe80f4a..968b6b3 100644 --- a/app/api/files.py +++ b/app/api/files.py @@ -1,3 +1,4 @@ +import io from typing import Optional import requests @@ -5,6 +6,7 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from starlette.responses import StreamingResponse +from werkzeug.utils import send_file from app.api import Response, get_current_user, ResponseList from app.config.config import settings @@ -76,6 +78,8 @@ agent_id: str = Query(..., description="Agent ID"), doc_id: Optional[str] = Query(None, description="Optional doc id for ragflow agents"), doc_name: Optional[str] = Query(None, description="Optional doc name for ragflow agents"), + file_id: Optional[str] = Query(None, description="Optional file id for basic agents"), + file_type: Optional[str] = Query(None, description="Optional file type for basic agents"), db: Session = Depends(get_db) ): agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() @@ -93,6 +97,10 @@ return Response(code=400, msg="doc_id is required") url = f"{settings.fwr_base_url}/v1/document/get/{doc_id}" filename = doc_name + elif agent.agent_type == AgentType.BASIC: + if agent_id == "basic_excel_talk": + return await download_basic_file(file_id, file_type) + else: return Response(code=400, msg="Unsupported agent type") @@ -109,3 +117,25 @@ ) except Exception as e: raise HTTPException(status_code=400, detail=f"Error downloading file: {e}") + + +async def download_basic_file(file_id: str, file_type: str): + service = BasicService(base_url=settings.basic_base_url) + if not file_type or not file_id: + return Response(code=400, msg="file_type and file_id is required") + if file_type == "image": + content, filename, mimetype = await service.excel_talk_image_download(file_id) + return StreamingResponse( + io.BytesIO(content), + media_type=mimetype, + headers={"Content-Disposition": f"attachment; filename={filename}"} + ) + elif file_type == "excel": + content, filename, mimetype = await service.excel_talk_excel_download(file_id) + return StreamingResponse( + io.BytesIO(content), + media_type=mimetype, + headers={"Content-Disposition": f"attachment; filename={filename}"} + ) + else: + return Response(code=400, msg="Unsupported file type") \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py index 1ec93e6..008613d 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,3 +1,7 @@ +from zoneinfo import ZoneInfo + +import pytz + from .agent_model import * from .dialog_model import * from .group_model import * @@ -6,4 +10,10 @@ from .organization_model import * from .resource_model import * from .role_model import * -from .user_model import * \ No newline at end of file +from .user_model import * + + +# 鑾峰彇褰撳墠鏃跺尯鐨勬椂闂� +def current_time(): + tz = pytz.timezone('Asia/Shanghai') + return datetime.now(tz) diff --git a/app/models/session_model.py b/app/models/session_model.py index 21bfb7e..44d0b74 100644 --- a/app/models/session_model.py +++ b/app/models/session_model.py @@ -3,7 +3,7 @@ from enum import IntEnum from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime -from app.models import AgentType +from app.models import AgentType, current_time from app.models.base_model import Base @@ -13,9 +13,8 @@ name = Column(String(255)) agent_id = Column(String(255)) agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False) # 鐩墠鍙瓨basic鐨勶紝ragflow鍜宐isheng鐨勮皟鎺ュ彛鑾峰彇 - create_date = Column(DateTime) # 鍒涘缓鏃堕棿 - update_date = Column(DateTime) # 鏇存柊鏃堕棿 - + create_date = Column(DateTime, default=current_time) # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿 + update_date = Column(DateTime, default=current_time, onupdate=current_time) # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊 # to_dict 鏂规硶 def to_dict(self): return { diff --git a/app/service/basic.py b/app/service/basic.py index 30ac727..93adf6c 100644 --- a/app/service/basic.py +++ b/app/service/basic.py @@ -10,21 +10,27 @@ if response.status_code not in [200, 201]: raise Exception(f"Failed to fetch data from API: {response.text}") response_data = response.json() - status_code = response_data.get("status_code", 0) - if status_code != 200: - raise Exception(f"Failed to fetch data from API: {response.text}") - return response_data.get("data", {}) + return response_data - async def download_from_url(self, url: str, params: dict): + async def download_from_url(self, url, params=None): async with httpx.AsyncClient() as client: - response = await client.get(url, params=params, stream=True) - if response.status_code == 200: - content_disposition = response.headers.get('Content-Disposition') - filename = content_disposition.split('filename=')[-1].strip( - '"') if content_disposition else 'unknown_filename' - return response.content, filename, response.headers.get('Content-Type') - else: - return None, None, None + async with client.stream('GET', url, params=params) as response: + if response.status_code == 200: + # 鑾峰彇鏂囦欢鍚� + content_disposition = response.headers.get('Content-Disposition') + if content_disposition: + filename = content_disposition.split('filename=')[1].strip('"') + else: + filename = 'unknown_filename' + + # 鑾峰彇鍐呭绫诲瀷 + content_type = response.headers.get('Content-Type') + + # 璇诲彇鏂囦欢鍐呭 + content = await response.aread() + return content, filename, content_type + else: + raise Exception(f"Failed to download: {response.status_code}") async def excel_talk_image_download(self, file_id: str): url = f"{self.base_url}/exceltalk/download/image" @@ -47,7 +53,7 @@ files=files, params=params ) - return await self._check_response(response) + return self._check_response(response) async def excel_talk(self, question: str, chat_id: str): url = f"{self.base_url}/exceltalk/talk" @@ -65,4 +71,4 @@ print(e) return else: - yield f"Error: {response.status_code}" + yield f"Error: {response.status_code}" \ No newline at end of file diff --git a/app/service/session.py b/app/service/session.py new file mode 100644 index 0000000..b3b698f --- /dev/null +++ b/app/service/session.py @@ -0,0 +1,80 @@ +from sqlalchemy.orm import Session + +from app.models import AgentType +from app.models.session_model import SessionModel + + +class SessionService: + def __init__(self, db: Session): + self.db = db + + def create_session(self, session_id: str, name: str, agent_id: str, agent_type: AgentType) -> SessionModel: + """ + 鍒涘缓涓�涓柊鐨勪細璇濊褰曘�� + + 鍙傛暟: + session_id (str): 浼氳瘽ID銆� + name (str): 浼氳瘽鍚嶇О銆� + agent_id (str): 浠g悊ID銆� + agent_type (AgentType): 浠g悊绫诲瀷銆� + + 杩斿洖: + SessionModel: 鏂板垱寤虹殑浼氳瘽妯″瀷瀹炰緥锛屽鏋滀細璇滻D宸插瓨鍦ㄥ垯杩斿洖None銆� + """ + existing_session = self.get_session_by_id(session_id) + if existing_session: + return None # 濡傛灉浼氳瘽ID宸插瓨鍦紝涓嶈繘琛屼换浣曟搷浣� + + new_session = SessionModel( + id=session_id, + name=name, + agent_id=agent_id, + agent_type=agent_type + ) + self.db.add(new_session) + self.db.commit() + self.db.refresh(new_session) + return new_session + + def get_session_by_id(self, session_id: str) -> SessionModel: + """ + 鏍规嵁浼氳瘽ID鑾峰彇浼氳瘽璁板綍銆� + + 鍙傛暟: + session_id (str): 浼氳瘽ID銆� + + 杩斿洖: + SessionModel: 鏌ユ壘鍒扮殑浼氳瘽妯″瀷瀹炰緥锛屽鏋滄湭鎵惧埌鍒欒繑鍥濶one銆� + """ + return self.db.query(SessionModel).filter_by(id=session_id).first() + + def update_session(self, session_id: str, **kwargs) -> SessionModel: + """ + 鏇存柊浼氳瘽璁板綍銆� + + 鍙傛暟: + session_id (str): 浼氳瘽ID銆� + kwargs: 闇�瑕佹洿鏂扮殑瀛楁鍙婂叾鍊笺�� + + 杩斿洖: + SessionModel: 鏇存柊鍚庣殑浼氳瘽妯″瀷瀹炰緥銆� + """ + session = self.get_session_by_id(session_id) + if session: + for key, value in kwargs.items(): + setattr(session, key, value) + self.db.commit() + self.db.refresh(session) + return session + + def delete_session(self, session_id: str) -> None: + """ + 鍒犻櫎浼氳瘽璁板綍銆� + + 鍙傛暟: + session_id (str): 浼氳瘽ID銆� + """ + session = self.get_session_by_id(session_id) + if session: + self.db.delete(session) + self.db.commit() \ No newline at end of file -- Gitblit v1.8.0