From 370120fd4154ce6c5f69d16a4a343a016cf2e816 Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: Tue, 04 Mar 2025 09:53:17 +0800
Subject: [PATCH] Add chunk retrieval services and improve workflow error reporting

---
 app/service/v2/chat.py |   93 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 3 deletions(-)
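
Notes: service_chunk_retrieval accepts its query either as a JSON object
string or as plain text. For illustration only (values reuse the smoke-test
data at the bottom of this patch), these are the three input shapes its
fallback chain accepts:

    import json

    # 1. Proper JSON: parsed as-is; "query" and "dataset_ids" come from the object.
    q1 = json.dumps({"query": "device", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]})
    # 2. Single-quoted pseudo-JSON: repaired by swapping ' for " and reparsed.
    q2 = "{'query': 'device', 'dataset_ids': ['fc68db52f43111efb94a0242ac120004']}"
    # 3. Plain text: used verbatim as the question against the supplied knowledge_id.
    q3 = "device"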

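For reference, a minimal sketch of how the new retrieval service could be
exposed over HTTP. The route path, request model, and api_key parameter are
illustrative assumptions, not part of this patch:

    from fastapi import APIRouter
    from pydantic import BaseModel

    from app.service.v2.chat import service_chunk_retrieval

    router = APIRouter()

    class RetrievalRequest(BaseModel):
        query: str
        knowledge_id: str
        top_k: int = 5
        similarity_threshold: float = 0.5

    @router.post("/v2/chunk/retrieval")  # hypothetical path
    async def chunk_retrieval(req: RetrievalRequest, api_key: str):
        # service_chunk_retrieval raises HTTPException(500) on an empty
        # backend response, so failures already surface as HTTP errors.
        records = await service_chunk_retrieval(
            req.query, req.knowledge_id, req.top_k, req.similarity_threshold, api_key
        )
        return {"records": records}
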
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index a35c775..68a7cd3 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -1,11 +1,13 @@
+import asyncio
 import io
 import json
 
 import fitz
+from fastapi import HTTPException
 
 from Log import logger
 from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \
-    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE
+    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
 from app.config.config import settings
 from app.config.const import *
 from app.models import DialogModel, ApiTokenModel, UserTokenModel
@@ -169,7 +171,7 @@
         query = chat_data.query
     else:
         query = "start new workflow"
-    session = await add_session_log(db, session_id, query, chat_id, user_id, mode, conversation_id, 3)
+    session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id, mode, conversation_id, 3)
     if session:
         conversation_id = session.conversation_id
     try:
@@ -205,6 +207,9 @@
                 data["outputs"] = await data_process(data.get("outputs", {}))
                 data["files"] = await data_process(data.get("files", []))
                 data["process_data"] = ""
+                if data.get("status") == "failed":
+                    status = http_500
+                    error = data.get("error", "")
                 node_list.append(ans)
                 event = [smart_workflow_started, smart_node_started, smart_node_finished][
                     [workflow_started, node_started, node_finished].index(ans.get("event"))]
@@ -213,6 +218,9 @@
                 answer_workflow = data.get("outputs", {}).get("output")
                 download_url = data.get("outputs", {}).get("download_url")
                 event = smart_workflow_finished
+                if data.get("status") == "failed":
+                    status = http_500
+                    error = data.get("error", "")
                 node_list.append(ans)
 
             elif ans.get("event") == message_end:
@@ -234,7 +242,7 @@
         except:
             ...
     finally:
-        await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow,
+        await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow or error,
                                                   "download_url":download_url,
                                                   "node_list": node_list, "task_id": task_id, "id": message_id,
                                                   "error": error}, conversation_id)
@@ -337,3 +345,82 @@
         text = await read_word(file)
 
     return await get_str_token(text)
+
+
+async def service_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
+    logger.info(f"chunk retrieval query: {query}")
+    # The query may arrive as a JSON object string (sometimes single-quoted);
+    # fall back to treating it as plain text against the given knowledge base.
+    try:
+        request_data = json.loads(query)
+    except json.JSONDecodeError:
+        try:
+            request_data = json.loads(query.replace("'", '"'))
+        except json.JSONDecodeError:
+            request_data = {"query": query, "dataset_ids": [knowledge_id]}
+    payload = {
+        "question": request_data.get("query", ""),
+        "dataset_ids": request_data.get("dataset_ids", []),
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        raise HTTPException(status_code=500, detail="Service error!")
+    records = [
+        {
+            "content": chunk["content"],
+            "score": chunk["similarity"],
+            "title": chunk.get("document_keyword", "Unknown Document"),
+            "metadata": {
+                "document_id": chunk["document_id"],
+                "path": f"{settings.fwr_base_url}/document/{chunk['document_id']}"
+                        f"?ext={chunk.get('document_keyword', '').split('.')[-1]}&prefix=document",
+                "highlight": chunk.get("highlight"),
+                "image_id": chunk.get("image_id"),
+                "positions": chunk.get("positions"),
+            }
+        }
+        for chunk in response.get("data", {}).get("chunks", [])
+    ]
+    return records
+
+async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
+    payload = {
+        "question": query,
+        "dataset_ids": [knowledge_id],
+        "page_size": top_k,
+        "similarity_threshold": similarity_threshold
+    }
+    url = settings.fwr_base_url + RG_ORIGINAL_URL
+    chat = ChatBaseApply()
+    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
+    if not response:
+        raise HTTPException(status_code=500, detail="Service error!")
+    records = [
+        {
+            "content": chunk["content"],
+            "score": chunk["similarity"],
+            "title": chunk.get("document_keyword", "Unknown Document"),
+            "metadata": {"document_id": chunk["document_id"]}
+        }
+        for chunk in response.get("data", {}).get("chunks", [])
+    ]
+    return records
+
+
+if __name__ == "__main__":
+    # Ad-hoc smoke test against a running RAGFlow backend.
+    q = json.dumps({"query": "device", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]})
+    knowledge_id = "fc68db52f43111efb94a0242ac120004"
+    top_k = 2
+    similarity_threshold = 0.5
+    api_key = "ragflow-Y4MGYwY2JlZjM2YjExZWY4ZWU5MDI0Mm"
+
+    async def main():
+        records = await service_chunk_retrieval(q, knowledge_id, top_k, similarity_threshold, api_key)
+        print(records)
+
+    asyncio.run(main())

--
Gitblit v1.8.0