From f95f801f35aa201cbaffd7d881c07edc9398b570 Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期一, 03 三月 2025 16:03:51 +0800
Subject: [PATCH] 增加外接知识库中转接口

---
 app/service/v2/chat.py |   80 +++++++++++++++++++++++++++++++++++++++
 1 files changed, 79 insertions(+), 1 deletions(-)

diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index a24f88d..33634a9 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -1,11 +1,13 @@
+import asyncio
 import io
 import json
 
 import fitz
+from fastapi import HTTPException
 
 from Log import logger
 from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \
-    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE
+    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
 from app.config.config import settings
 from app.config.const import *
 from app.models import DialogModel, ApiTokenModel, UserTokenModel
@@ -258,6 +260,7 @@
 
 async def service_chat_sessions(db, chat_id, name):
     token = await get_chat_token(db, rg_api_token)
+    # print(token)
     if not token:
         return {}
     url = settings.fwr_base_url + RG_CHAT_SESSIONS.format(chat_id)
@@ -343,3 +346,78 @@
         text = await read_word(file)
 
     return await get_str_token(text)
+
+
+async def service_chunk_retrieval(query, top_k, similarity_threshold, api_key):
+    print(query)
+
+    try:
+        request_data = json.loads(query)
+    except json.JSONDecodeError as e:
+        fixed_json = query.replace("'", '"')
+        print("Fixed JSON:", fixed_json)
+        request_data = json.loads(fixed_json)
+    payload = {
+        "question": request_data.get("query", ""),
+        "dataset_ids": request_data.get("dataset_ids", []),
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    # url = "http://192.168.20.116:11080/" + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await  chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        raise HTTPException(status_code=500, detail="鏈嶅姟寮傚父锛�")
+    print(response)
+    records = [
+        {
+            "content": chunk["content"],
+            "score": chunk["similarity"],
+            "title": chunk.get("document_keyword", "Unknown Document"),
+            "metadata": {"document_id": chunk["document_id"],
+                         "path": f"{settings.fwr_base_url}/document/{chunk['document_id']}?ext={chunk.get('document_keyword').split('.')[-1]}&prefix=document",
+                         'highlight': chunk.get("highlight") , "image_id":  chunk.get("image_id"), "positions": chunk.get("positions"),}
+        }
+        for chunk in response.get("data", {}).get("chunks", [])
+    ]
+    return records
+
+async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
+    # request_data = json.loads(query)
+    payload = {
+        "question": query,
+        "dataset_ids": [knowledge_id],
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    # url = "http://192.168.20.116:11080/" + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        raise HTTPException(status_code=500, detail="鏈嶅姟寮傚父锛�")
+    records = [
+        {
+            "content": chunk["content"],
+            "score": chunk["similarity"],
+            "title": chunk.get("document_keyword", "Unknown Document"),
+            "metadata": {"document_id": chunk["document_id"]}
+        }
+        for chunk in response.get("data", {}).get("chunks", [])
+    ]
+    return records
+
+
+
+if __name__ == "__main__":
+    q = json.dumps({"query": "璁惧", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]})
+    top_k = 2
+    similarity_threshold = 0.5
+    api_key = "ragflow-Y4MGYwY2JlZjM2YjExZWY4ZWU5MDI0Mm"
+    # a = service_chunk_retrieval(q, top_k, similarity_threshold, api_key)
+    # print(a)
+    async def a():
+        b = await service_chunk_retrieval(q, top_k, similarity_threshold, api_key)
+        print(b)
+    asyncio.run(a())
\ No newline at end of file

--
Gitblit v1.8.0