From f95f801f35aa201cbaffd7d881c07edc9398b570 Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: Mon, 03 Mar 2025 16:03:51 +0800
Subject: [PATCH] 增加外接知识库中转接口
---
app/service/v2/chat.py | 90 +++++++++++++++++++++++++++++++++++++++++++-
1 files changed, 87 insertions(+), 3 deletions(-)
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index a35c775..33634a9 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -1,11 +1,13 @@
+import asyncio
import io
import json
import fitz
+from fastapi import HTTPException
from Log import logger
from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \
- DF_CHAT_WORKFLOW, DF_UPLOAD_FILE
+ DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
from app.config.config import settings
from app.config.const import *
from app.models import DialogModel, ApiTokenModel, UserTokenModel
@@ -169,7 +171,7 @@
query = chat_data.query
else:
query = "start new workflow"
- session = await add_session_log(db, session_id, query, chat_id, user_id, mode, conversation_id, 3)
+ session = await add_session_log(db, session_id,query if query else "start new conversation", chat_id, user_id, mode, conversation_id, 3)
if session:
conversation_id = session.conversation_id
try:
@@ -205,6 +207,9 @@
data["outputs"] = await data_process(data.get("outputs", {}))
data["files"] = await data_process(data.get("files", []))
data["process_data"] = ""
+ if data.get("status") == "failed":
+ status = http_500
+ error = data.get("error", "")
node_list.append(ans)
event = [smart_workflow_started, smart_node_started, smart_node_finished][
[workflow_started, node_started, node_finished].index(ans.get("event"))]
@@ -213,6 +218,9 @@
answer_workflow = data.get("outputs", {}).get("output")
download_url = data.get("outputs", {}).get("download_url")
event = smart_workflow_finished
+ if data.get("status") == "failed":
+ status = http_500
+ error = data.get("error", "")
node_list.append(ans)
elif ans.get("event") == message_end:
@@ -234,7 +242,7 @@
except:
...
finally:
- await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow,
+ await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow or error,
"download_url":download_url,
"node_list": node_list, "task_id": task_id, "id": message_id,
"error": error}, conversation_id)
@@ -252,6 +260,7 @@
async def service_chat_sessions(db, chat_id, name):
token = await get_chat_token(db, rg_api_token)
+ # print(token)
if not token:
return {}
url = settings.fwr_base_url + RG_CHAT_SESSIONS.format(chat_id)
@@ -337,3 +346,78 @@
text = await read_word(file)
return await get_str_token(text)
+
+
+def _parse_retrieval_query(query):
+    """Decode the retrieval request carried in ``query`` into a dict.
+
+    ``query`` is normally a JSON object string. A dict is passed through
+    unchanged, and a string that fails to parse is retried once with single
+    quotes replaced by double quotes.
+
+    Raises:
+        HTTPException: 400 when the text cannot be decoded, or when it
+            decodes to something other than a JSON object.
+    """
+    if isinstance(query, dict):
+        return query
+    try:
+        request_data = json.loads(query)
+    except json.JSONDecodeError:
+        # Retry with quotes normalised; presumably some callers send a
+        # Python-style dict literal instead of strict JSON.
+        fixed_json = query.replace("'", '"')
+        try:
+            request_data = json.loads(fixed_json)
+        except json.JSONDecodeError as exc:
+            raise HTTPException(status_code=400, detail=f"Invalid query JSON: {exc}") from exc
+    if not isinstance(request_data, dict):
+        raise HTTPException(status_code=400, detail="Invalid query JSON: expected an object")
+    return request_data
+
+
+async def service_chunk_retrieval(query, top_k, similarity_threshold, api_key):
+    """Forward a chunk-retrieval request and reshape the returned chunks.
+
+    The request is posted to ``settings.fwr_base_url + RG_ORIGINAL_URL``
+    through ``ChatBaseApply.chat_post`` (presumably the RAGFlow retrieval
+    endpoint; confirm against ``RG_ORIGINAL_URL``).
+
+    Args:
+        query: JSON object string (or dict) with the keys ``query`` (the
+            question text) and ``dataset_ids`` (list of dataset ids).
+        top_k: Sent upstream as ``page_size``.
+        similarity_threshold: Sent upstream as ``similarity_threshold``.
+        api_key: Passed to ``ChatBaseApply.get_headers`` to build the
+            request headers.
+
+    Returns:
+        A list of dicts with ``content``, ``score``, ``title`` and a
+        ``metadata`` dict (``document_id``, ``path``, ``highlight``,
+        ``image_id``, ``positions``). Empty when the response has no chunks.
+
+    Raises:
+        HTTPException: 400 for an undecodable ``query``; 500 when the
+            upstream call returns an empty/falsy response.
+    """
+    logger.info(f"chunk retrieval query: {query}")
+    request_data = _parse_retrieval_query(query)
+    payload = {
+        "question": request_data.get("query", ""),
+        "dataset_ids": request_data.get("dataset_ids", []),
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold,
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        # Message restored from mojibake (UTF-8 bytes that had been decoded as GBK).
+        raise HTTPException(status_code=500, detail="服务异常！")
+    logger.info(f"chunk retrieval response: {response}")
+
+    # "data" may be present but null/false on upstream errors, so do not rely
+    # on the .get() default alone.
+    chunks = (response.get("data") or {}).get("chunks") or []
+    records = []
+    for chunk in chunks:
+        document_id = chunk["document_id"]
+        document_name = chunk.get("document_keyword")
+        # A missing document name must not crash the whole response; fall back
+        # to an empty extension in the download path.
+        ext = document_name.split(".")[-1] if document_name else ""
+        records.append(
+            {
+                "content": chunk["content"],
+                "score": chunk["similarity"],
+                "title": chunk.get("document_keyword", "Unknown Document"),
+                "metadata": {
+                    "document_id": document_id,
+                    "path": f"{settings.fwr_base_url}/document/{document_id}?ext={ext}&prefix=document",
+                    "highlight": chunk.get("highlight"),
+                    "image_id": chunk.get("image_id"),
+                    "positions": chunk.get("positions"),
+                },
+            }
+        )
+    return records
+
+async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
+    """Retrieve chunks for a plain-text question from a single dataset.
+
+    The request is posted to ``settings.fwr_base_url + RG_ORIGINAL_URL``
+    through ``ChatBaseApply.chat_post``.
+
+    Args:
+        query: Question text, sent upstream as ``question`` unchanged.
+        knowledge_id: Dataset id; sent upstream as the only entry of
+            ``dataset_ids``.
+        top_k: Sent upstream as ``page_size``.
+        similarity_threshold: Sent upstream as ``similarity_threshold``.
+        api_key: Passed to ``ChatBaseApply.get_headers`` to build the
+            request headers.
+
+    Returns:
+        A list of dicts with ``content``, ``score``, ``title`` and a
+        ``metadata`` dict holding ``document_id``. Empty when the response
+        has no chunks.
+
+    Raises:
+        HTTPException: 500 when the upstream call returns an empty/falsy
+            response.
+    """
+    payload = {
+        "question": query,
+        "dataset_ids": [knowledge_id],
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold,
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        # Message restored from mojibake (UTF-8 bytes that had been decoded as GBK).
+        raise HTTPException(status_code=500, detail="服务异常！")
+
+    # "data" may be present but null/false on upstream errors, so do not rely
+    # on the .get() default alone.
+    chunks = (response.get("data") or {}).get("chunks") or []
+    return [
+        {
+            "content": chunk["content"],
+            "score": chunk["similarity"],
+            "title": chunk.get("document_keyword", "Unknown Document"),
+            "metadata": {"document_id": chunk["document_id"]},
+        }
+        for chunk in chunks
+    ]
+
+
+
+if __name__ == "__main__":
+    # Manual smoke test: run one retrieval against the configured backend and
+    # print the reshaped records.
+    import os
+
+    # Query text restored from mojibake (UTF-8 bytes that had been decoded as GBK).
+    q = json.dumps({"query": "设备", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]})
+    top_k = 2
+    similarity_threshold = 0.5
+    # Credentials must not live in source control; supply the key through the
+    # RAGFLOW_API_KEY environment variable when running this check.
+    api_key = os.environ.get("RAGFLOW_API_KEY", "")
+
+    async def _demo():
+        print(await service_chunk_retrieval(q, top_k, similarity_threshold, api_key))
+
+    asyncio.run(_demo())
\ No newline at end of file
--
Gitblit v1.8.0