From 13c3fdf08558b6ce01dcbdc7716bd77dc9b2e88c Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期二, 19 十一月 2024 17:13:19 +0800
Subject: [PATCH] Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway

---
 app/models/session_model.py |   28 +++++++
 app/config/config.py        |    1 
 app/api/chat.py             |   40 ++++++++++
 app/api/excel_talk.py       |   69 +++++++++++++++++
 app/config/config.yaml      |    3 
 app/service/basic.py        |   68 +++++++++++++++++
 app/api/agent.py            |    7 +
 app/api/files.py            |    9 ++
 8 files changed, 223 insertions(+), 2 deletions(-)

diff --git a/app/api/agent.py b/app/api/agent.py
index 3178144..4e410d4 100644
--- a/app/api/agent.py
+++ b/app/api/agent.py
@@ -10,6 +10,7 @@
 from app.config.config import settings
 from app.models.agent_model import AgentType, AgentModel
 from app.models.base_model import get_db
+from app.models.session_model import SessionModel
 from app.models.user_model import UserModel
 from app.service.bisheng import BishengService
 from app.service.dialog import get_session_history
@@ -57,6 +58,12 @@
             raise HTTPException(status_code=500, detail=str(e))
         return ResponseList(code=200, msg="", data=result)
 
+    elif agent.agent_type == AgentType.BASIC:
+        offset = (page - 1) * limit
+        records = db.query(SessionModel).filter(SessionModel.agent_id == agent_id).offset(offset).limit(limit).all()
+        result = [item.to_dict() for item in records]
+        return ResponseList(code=200, msg="", data=result)
+
     else:
         return ResponseList(code=200, msg="Unsupported agent type")
 
diff --git a/app/api/chat.py b/app/api/chat.py
index ea1be48..b5bfd6a 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -11,6 +11,7 @@
 from app.models.base_model import get_db
 from app.models.user_model import UserModel
 from app.service.dialog import update_session_history
+from app.service.basic import BasicService
 from app.service.ragflow import RagflowService
 from app.service.service_token import get_bisheng_token, get_ragflow_token
 
@@ -196,6 +197,45 @@
                             await task
                         except asyncio.CancelledError:
                             pass
+    elif agent_type == AgentType.BASIC:
+        try:
+            while True:
+                # 鎺ユ敹鍓嶇娑堟伅
+                message = await websocket.receive_json()
+                question = message.get("message")
+                if not question:
+                    await websocket.send_json({"message": "Invalid request", "type": "error"})
+                    continue
+
+                service = BasicService(base_url=settings.basic_base_url)
+                complete_response = ""
+                async for result in service.excel_talk(question, chat_id):
+                    try:
+                        if result[:5] == "data:":
+                            # 濡傛灉鏄紝鍒欐埅鍙栨帀鍓�5涓瓧绗︼紝骞跺幓闄ら灏剧┖鐧界
+                            text = result[5:].strip()
+                        else:
+                            # 鍚﹀垯锛屼繚鎸佸師鏍�
+                            text = result
+                        complete_response += text
+                        try:
+                            json_data = json.loads(complete_response)
+                            output = json_data.get("output", "")
+                            result = {"message": output, "type": "message"}
+                            await websocket.send_json(result | json_data)
+                            complete_response = ""
+                        except json.JSONDecodeError as e:
+                            print(f"Error decoding JSON: {e}")
+                            print(f"Response text: {text}")
+                    except Exception as e2:
+                        result = {"message": f"鍐呴儴閿欒锛� {e2}", "type": "close"}
+                        await websocket.send_json(result)
+                        print(f"Error process message of basic agent: {e2}")
+        except Exception as e:
+            await websocket.send_json({"message": str(e), "type": "error"})
+        finally:
+            await websocket.close()
+            print(f"Client {agent_id} disconnected")
     else:
         ret = {"message": "Agent not found", "type": "close"}
         await websocket.send_json(ret)
diff --git a/app/api/excel_talk.py b/app/api/excel_talk.py
new file mode 100644
index 0000000..ad4a416
--- /dev/null
+++ b/app/api/excel_talk.py
@@ -0,0 +1,69 @@
+import asyncio
+import json
+from enum import Enum
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.orm import Session
+from starlette.websockets import WebSocket, WebSocketDisconnect
+
+from app.api import get_current_user_websocket
+from app.config.config import settings
+from app.models import UserModel, AgentModel
+from app.models.base_model import get_db
+from app.service.basic import BasicService
+
+router = APIRouter()
+
+# class CompletionRequest(BaseModel):
+#     id: Optional[str] = None
+#     app_id: str
+#     message: str
+#
+# class DownloadRequest(BaseModel):
+#     file_id: str
+#     app_id: str
+#     file_type: Optional[str] = None
+
+
+class AdvancedAgentID(Enum):
+    EXCEL_TALK = "excel_talk"
+    QUESTIONS_TALK = "questions_talk"
+
+@router.websocket("/ws/{agent_id}/{chat_id}")
+async def handle_client(websocket: WebSocket,
+                        agent_id: str,
+                        chat_id: str,
+                        current_user: UserModel = Depends(get_current_user_websocket),
+                        db: Session = Depends(get_db)):
+    await websocket.accept()
+    print(f"Client {agent_id} connected")
+
+    service = BasicService(base_url=settings.basic_base_url)
+
+    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
+    if not agent:
+        ret = {"message": "Agent not found", "type": "close"}
+        await websocket.send_json(ret)
+        return
+    try:
+        while True:
+            # 鎺ユ敹鍓嶇娑堟伅
+            message = await websocket.receive_json()
+            question = message.get("message")
+            if not question:
+                await websocket.send_json({"message": "Invalid request", "type": "error"})
+                continue
+
+            # 璋冪敤 excel_talk 鏂规硶
+            result = await service.excel_talk(question, chat_id)
+
+            # 灏嗙粨鏋滃彂閫佸洖鍓嶇
+            await websocket.send_json({"message": result, "type": "response"})
+    except Exception as e:
+        await websocket.send_json({"message": str(e), "type": "error"})
+    finally:
+        await websocket.close()
+        print(f"Client {agent_id} disconnected")
+
+
+
diff --git a/app/api/files.py b/app/api/files.py
index eed49d8..fe80f4a 100644
--- a/app/api/files.py
+++ b/app/api/files.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 import requests
-from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query
+from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, Query, Form
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
 from starlette.responses import StreamingResponse
@@ -11,6 +11,7 @@
 from app.models.agent_model import AgentType, AgentModel
 from app.models.base_model import get_db
 from app.models.user_model import UserModel
+from app.service.basic import BasicService
 from app.service.bisheng import BishengService
 from app.service.ragflow import RagflowService
 from app.service.service_token import get_ragflow_token, get_bisheng_token
@@ -58,6 +59,12 @@
             raise HTTPException(status_code=500, detail=str(e))
         result["file_name"] = file.filename
         return Response(code=200, msg="", data=result)
+    elif agent.agent_type == AgentType.BASIC:
+        if agent_id == "basic_excel_talk":
+            service = BasicService(base_url=settings.basic_base_url)
+            result = await service.excel_talk_upload(chat_id, file.filename, file_content)
+
+            return Response(code=200, msg="", data=result)
 
     else:
         return Response(code=200, msg="Unsupported agent type")
diff --git a/app/config/config.py b/app/config/config.py
index 7d6c676..435ee86 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -15,6 +15,7 @@
     PUBLIC_KEY: str
     PRIVATE_KEY: str
     PASSWORD_KEY: str
+    basic_base_url: str = ''
     def __init__(self, **kwargs):
         # Check if all required fields are provided and set them
         for field in self.__annotations__.keys():
diff --git a/app/config/config.yaml b/app/config/config.yaml
index d260302..e1f16c5 100644
--- a/app/config/config.yaml
+++ b/app/config/config.yaml
@@ -12,4 +12,5 @@
 PRIVATE_KEY: str
 fetch_sgb_agent: 鎶ュ憡鐢熸垚,鏂囨。鏅鸿兘
 fetch_fwr_agent: 鐭ヨ瘑闂瓟,鏅鸿兘闂瓟
-PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions=
\ No newline at end of file
+PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions=
+basic_base_url: http://192.168.20.231:8000
diff --git a/app/models/session_model.py b/app/models/session_model.py
new file mode 100644
index 0000000..21bfb7e
--- /dev/null
+++ b/app/models/session_model.py
@@ -0,0 +1,28 @@
+import json
+from datetime import datetime
+from enum import IntEnum
+from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime
+
+from app.models import AgentType
+from app.models.base_model import Base
+
+
+class SessionModel(Base):
+    __tablename__ = "sessions"
+    id = Column(String(255), primary_key=True)
+    name = Column(String(255))
+    agent_id = Column(String(255))
+    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 鐩墠鍙瓨basic鐨勶紝ragflow鍜宐isheng鐨勮皟鎺ュ彛鑾峰彇
+    create_date = Column(DateTime)  # 鍒涘缓鏃堕棿
+    update_date = Column(DateTime)  # 鏇存柊鏃堕棿
+
+    # to_dict 鏂规硶
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'name': self.name,
+            'agent_type': self.agent_type,
+            'agent_id': self.agent_id,
+            'create_date': self.create_date,
+            'update_date': self.update_date,
+        }
diff --git a/app/service/basic.py b/app/service/basic.py
new file mode 100644
index 0000000..30ac727
--- /dev/null
+++ b/app/service/basic.py
@@ -0,0 +1,68 @@
+import httpx
+
+
+class BasicService:
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    def _check_response(self, response: httpx.Response):
+        """妫�鏌ュ搷搴斿苟澶勭悊閿欒"""
+        if response.status_code not in [200, 201]:
+            raise Exception(f"Failed to fetch data from API: {response.text}")
+        response_data = response.json()
+        status_code = response_data.get("status_code", 0)
+        if status_code != 200:
+            raise Exception(f"Failed to fetch data from API: {response.text}")
+        return response_data.get("data", {})
+
+    async def download_from_url(self, url: str, params: dict):
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, params=params, stream=True)
+            if response.status_code == 200:
+                content_disposition = response.headers.get('Content-Disposition')
+                filename = content_disposition.split('filename=')[-1].strip(
+                    '"') if content_disposition else 'unknown_filename'
+                return response.content, filename, response.headers.get('Content-Type')
+            else:
+                return None, None, None
+
+    async def excel_talk_image_download(self, file_id: str):
+        url = f"{self.base_url}/exceltalk/download/image"
+        return await self.download_from_url(url, params={'images_name': file_id})
+
+    async def excel_talk_excel_download(self, file_id: str):
+        url = f"{self.base_url}/exceltalk/download/excel"
+        return await self.download_from_url(url, params={'excel_name': file_id})
+
+    async def excel_talk_upload(self, chat_id: str, filename: str, file_content: bytes):
+        url = f"{self.base_url}/exceltalk/upload/files"
+        params = {'chat_id': chat_id, 'is_col': '0'}
+
+        # 鍒涘缓 FormData 瀵硅薄
+        files = [('files', (filename, file_content, 'application/octet-stream'))]
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                url,
+                files=files,
+                params=params
+            )
+            return await self._check_response(response)
+
+    async def excel_talk(self, question: str, chat_id: str):
+        url = f"{self.base_url}/exceltalk/talk"
+        params = {'chat_id': chat_id}
+        data = {"query": question}
+        headers = {'Content-Type': 'application/json'}
+        async with httpx.AsyncClient(timeout=300.0) as client:
+            async with client.stream("POST", url, params=params, json=data, headers=headers) as response:
+                if response.status_code == 200:
+                    try:
+                        async for answer in response.aiter_text():
+                            print(f"response of ragflow chat: {answer}")
+                            yield answer
+                    except GeneratorExit as e:
+                        print(e)
+                        return
+                else:
+                    yield f"Error: {response.status_code}"

--
Gitblit v1.8.0