From 9683aeeafa2f1067ef061b34124a1c362df07e5e Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 03 四月 2025 14:10:13 +0800
Subject: [PATCH] rg配置修改

---
 app/service/v2/app_driver/chat_dialog.py |   76 ++++++++++++++++++++++++++++++++++----
 1 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/app/service/v2/app_driver/chat_dialog.py b/app/service/v2/app_driver/chat_dialog.py
index 0ad7c3c..9aa750e 100644
--- a/app/service/v2/app_driver/chat_dialog.py
+++ b/app/service/v2/app_driver/chat_dialog.py
@@ -1,20 +1,80 @@
+import json
+
+from Log import logger
 from app.service.v2.app_driver.chat_base import ChatBase
 
 
 class ChatDialog(ChatBase):
 
+    async def chat_completions(self, url, data, headers):
+        complete_response = ""
+        async for line in self.http_stream(url, data, headers):
+            # print(line)
+            if line.startswith("data:"):
+                complete_response = line.strip("data:").strip()
+            else:
+                complete_response += line.strip()
+            try:
+                json_data = json.loads(complete_response)
+                # 澶勭悊 JSON 鏁版嵁
+                # print(json_data)
+                complete_response = ""
+                yield json_data
 
-    def __init__(self, token):
-        self.token = token
+            except json.JSONDecodeError as e:
+                # print(e)
+                # print(complete_response)
+                logger.info("Invalid JSON data------------------")
+                # print(e)
+
+    async def chat_sessions(self, url, data, headers):
+
+        res = await self.http_post(url, data, headers)
+        if res.status_code == 200:
+            return res.json()
+        else:
+            return {}
 
 
-    async def get_headers(self):
+
+    @staticmethod
+    async def request_data(question, session_id=""):
         return {
-            'Content-Type': 'application/json',
-            'Authorization': f'Bearer {self.token}'
+            "question": question,
+            "stream": True,
+            "session_id": session_id
+        }
+
+    @staticmethod
+    async def complex_request_data(question, dataset_ids, session_id=""):
+        return {
+            "question": question,
+            "stream": True,
+            "session_id": session_id,
+            "kb_ids": dataset_ids
         }
 
 
-    async def chat_completions(self):
-        async for rag_response in self.http_stream(token, chat_id, chat_history):
-            ...
\ No newline at end of file
+if __name__ == "__main__":
+    async def aa():
+        chat_id = "6b8ee426c67511efb1510242ac1b0006"
+        token = "ragflow-YzMzE1NDRjYzMyZjExZWY5ZjkxMDI0Mm"
+        base_url = "http://192.168.20.116:11080"
+        url = f"{base_url}/api/v1/chats/{chat_id}/completions"
+        chat = ChatDialog(token)
+        data = {
+            "question": "鐢电綉鎶�鏈�荤粨300瀛�",
+            "stream": True,
+            "session_id": "9969c152cce411ef8a140242ac1b0002"
+        }
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f"Bearer {token}"
+        }
+        async for ans in chat.chat_completions(url, data, headers):
+            print(ans)
+
+
+    import asyncio
+
+    asyncio.run(aa())

--
Gitblit v1.8.0