From 282a631b9ceee9a634ee1d93751a5254ed37ccef Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期二, 18 三月 2025 10:10:48 +0800
Subject: [PATCH] 首页知识库对话-rg

---
 app/service/v2/chat.py     |   35 ++++++++++-
 app/service/dialog.py      |    1 
 app/service/v2/mindmap.py  |   69 +++++++++++++++++------
 app/task/fetch_agent.py    |   12 ++-
 app/api/v2/chat.py         |    2 
 app/models/dialog_model.py |    5 +
 app/api/v2/mindmap.py      |   11 +++
 app/service/knowledge.py   |    4 
 app/models/v2/chat.py      |   14 +++-
 9 files changed, 119 insertions(+), 34 deletions(-)

diff --git a/app/api/v2/chat.py b/app/api/v2/chat.py
index 207d967..c10df3e 100644
--- a/app/api/v2/chat.py
+++ b/app/api/v2/chat.py
@@ -42,7 +42,7 @@
             return StreamingResponse(f"data: {error_msg}\n\n",
                                      media_type="text/event-stream")
         session_id = session.get("data", {}).get("id")
-    return StreamingResponse(service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode),
+    return StreamingResponse(service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode, chat_info.get_kb_ids()),
                              media_type="text/event-stream")
 
 @chat_router_v2.post("/agent/{chatId}/completions")
diff --git a/app/api/v2/mindmap.py b/app/api/v2/mindmap.py
index 9f9892b..d3db944 100644
--- a/app/api/v2/mindmap.py
+++ b/app/api/v2/mindmap.py
@@ -14,7 +14,7 @@
 from app.models.v2.chat import RetrievalRequest, ComplexChatDao
 from app.models.v2.mindmap import MindmapRequest
 from app.models.v2.session_model import ChatData
-from app.service.v2.mindmap import service_chat_mindmap
+from app.service.v2.mindmap import service_chat_mindmap, service_message_mindmap_parse
 
 mind_map_router = APIRouter()
 
@@ -28,4 +28,13 @@
             return Response(code=500, msg="create failure", data={})
     else:
         return Response(code=500, msg="缃戠粶寮傚父锛乫ailure", data={})
+    return Response(code=200, msg="create success", data=data)
+
+
+@mind_map_router.get("/{messageId}/parse", response_model=Response)
+async def api_chat_mindmap(messageId: str, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): #  current_user: UserModel = Depends(get_current_user)
+
+    data = await service_message_mindmap_parse(db, messageId, current_user.id)
+    if not data:
+        return Response(code=500, msg="create failure", data={})
     return Response(code=200, msg="create success", data=data)
\ No newline at end of file
diff --git a/app/models/dialog_model.py b/app/models/dialog_model.py
index 50a8f8e..a9aaf89 100644
--- a/app/models/dialog_model.py
+++ b/app/models/dialog_model.py
@@ -1,3 +1,4 @@
+import json
 from datetime import datetime
 from typing import Optional
 
@@ -24,6 +25,7 @@
     # agent_id = Column(String(36))
     mode = Column(String(36))
     parameters = Column(Text)
+    kb_ids = Column(String(128))
 
     def get_id(self):
         return str(self.id)
@@ -43,6 +45,9 @@
             'mode': self.mode,
         }
 
+    def get_kb_ids(self):
+        return json.loads(self.kb_ids) if self.kb_ids else []
+
 
 class ConversationModel(Base):
     __tablename__ = 'conversation'
diff --git a/app/models/v2/chat.py b/app/models/v2/chat.py
index 2945e87..e7fa060 100644
--- a/app/models/v2/chat.py
+++ b/app/models/v2/chat.py
@@ -6,7 +6,7 @@
 from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, TEXT
 from sqlalchemy.orm import Session
 
-from app.config.const import Dialog_STATSU_DELETE, Dialog_STATSU_ON
+from app.config.const import Dialog_STATSU_DELETE, Dialog_STATSU_ON, complex_knowledge_chat
 from app.models.base_model import Base
 from app.utils.common import current_time
 
@@ -187,14 +187,20 @@
             query = {}
             if self.query:
                 query = json.loads(self.query)
-            return {
+
+            res = {
                 'id': self.id,
                 'role': "assistant",
                 'answer': self.content,
                 'chat_mode': self.chat_mode,
-                'node_list': json.loads(self.node_data) if self.node_data else [],
-                "parentId": query.get("parentId")
+                "parentId": query.get("parentId"),
+                "isDeep": query.get("isDeep", 1),
             }
+            if self.chat_mode == complex_knowledge_chat:
+                res['reference'] = json.loads(self.node_data) if self.node_data else {}
+            else:
+                res['node_list'] = json.loads(self.node_data) if self.node_data else []
+            return res
 
 
 class ComplexChatSessionDao:
diff --git a/app/service/dialog.py b/app/service/dialog.py
index 34e4711..82b2b22 100644
--- a/app/service/dialog.py
+++ b/app/service/dialog.py
@@ -245,6 +245,7 @@
             if app_dialog:
                 dialog.name = app_dialog["name"]
                 dialog.description = app_dialog["description"]
+                dialog.kb_ids = app_dialog["kb_ids"]
                 dialog.update_date = datetime.now()
                 db.add(dialog)
                 db.commit()
diff --git a/app/service/knowledge.py b/app/service/knowledge.py
index 4fb834f..f6250b7 100644
--- a/app/service/knowledge.py
+++ b/app/service/knowledge.py
@@ -17,8 +17,8 @@
         klg_list = [j.id for i in user.groups for j in i.knowledges]
         query = query.filter(or_(KnowledgeModel.id.in_(klg_list), KnowledgeModel.tenant_id == str(user_id)))
 
-    if location:
-        query = query.filter(or_(KnowledgeModel.permission == "team", KnowledgeModel.tenant_id == str(user_id)))
+        if location:
+            query = query.filter(or_(KnowledgeModel.permission == "team", KnowledgeModel.tenant_id == str(user_id)))
 
     if keyword:
         query = query.filter(KnowledgeModel.name.like('%{}%'.format(keyword)))
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index 38683a8..3982bdc 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -6,6 +6,7 @@
 
 import fitz
 from fastapi import HTTPException
+from sqlalchemy import or_
 
 from Log import logger
 from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \
@@ -13,7 +14,7 @@
 from app.config.config import settings
 from app.config.const import *
 from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest, \
-    ComplexChatDao
+    ComplexChatDao, KnowledgeModel, UserModel
 from app.models.v2.session_model import ChatSessionDao, ChatData
 from app.service.v2.app_driver.chat_agent import ChatAgent
 from app.service.v2.app_driver.chat_data import ChatBaseApply
@@ -87,17 +88,45 @@
         return ChatAgent(), url
 
 
-async def service_chat_dialog(db, chat_id: str, question: str, session_id: str, user_id, mode: str):
+
+async def get_user_kb(db, user_id: int, kb_ids: list) -> list:
+    res = []
+    user = db.query(UserModel).filter(UserModel.id == user_id).first()
+    if user is None:
+        return res
+    query = db.query(KnowledgeModel)
+    if user.permission != "admin":
+        klg_list = [j.id for i in user.groups for j in i.knowledges]
+        query = query.filter(or_(KnowledgeModel.id.in_(klg_list), KnowledgeModel.tenant_id == str(user_id)))
+        kb_list= query.all()
+        for kb in kb_list:
+            if kb.id in kb_ids:
+                if kb.permission == "team":
+                    res.append(kb.id)
+                elif kb.tenant_id == str(user_id):
+                    res.append(kb.id)
+        return res
+    else:
+        return kb_ids
+
+
+async def service_chat_dialog(db, chat_id: str, question: str, session_id: str, user_id: int, mode: str, kb_ids: list):
     conversation_id = ""
     token = await get_chat_token(db, rg_api_token)
     url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id)
+    kb_id = await get_user_kb(db, user_id, kb_ids)
+    if not kb_id:
+        yield "data: " + json.dumps({"message": smart_message_error,
+                                     "error": "\n**ERROR**: The agent has no knowledge base to work with!", "status": http_400},
+                                    ensure_ascii=False) + "\n\n"
+        return
     chat = ChatDialog()
     session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, RG_TYPE)
     if session:
         conversation_id = session.conversation_id
     message = {"role": "assistant", "answer": "", "reference": {}}
     try:
-        async for ans in chat.chat_completions(url, await chat.request_data(question, conversation_id),
+        async for ans in chat.chat_completions(url, await chat.complex_request_data(question, kb_id, conversation_id),
                                                await chat.get_headers(token)):
             data = {}
             error = ""
diff --git a/app/service/v2/mindmap.py b/app/service/v2/mindmap.py
index f6e576c..ff93e47 100644
--- a/app/service/v2/mindmap.py
+++ b/app/service/v2/mindmap.py
@@ -1,10 +1,11 @@
 import json
 from Log import logger
-from app.config.agent_base_url import DF_CHAT_AGENT
+from app.config.agent_base_url import DF_CHAT_AGENT, RG_CHAT_DIALOG
 from app.config.config import settings
-from app.config.const import message_error, message_event, complex_knowledge_chat
+from app.config.const import message_error, message_event, complex_knowledge_chat, rg_api_token, workflow_finished
 from app.models import ComplexChatSessionDao, ChatData
 from app.service.v2.app_driver.chat_agent import ChatAgent
+from app.service.v2.app_driver.chat_dialog import ChatDialog
 from app.service.v2.chat import get_chat_token
 
 
@@ -77,23 +78,41 @@
         if session.mindmap:
             inputs = {"is_deep": chat_request.get("isDeep", 1)}
             if session.chat_mode == complex_knowledge_chat:
-                inputs["query_json"] = json.dumps(
-                    {"query": chat_request.get("query", ""), "dataset_ids": chat_request.get("knowledgeId", [])})
-            try:
-                async for ans in chat.chat_completions(url,
-                                                       await chat.complex_request_data(message, session.conversation_id,
-                                                                               str(user_id), files=chat_request.get("files", []), inputs=inputs),
-                                                       await chat.get_headers(token)):
-                    if ans.get("event") == message_error:
-                        return res
-                    elif ans.get("event") == message_event:
-                        mindmap_query += ans.get("answer", "")
-                    else:
-                        continue
+                token = await get_chat_token(db, rg_api_token)
+                # print(token)
+                dialog_url = settings.fwr_base_url + RG_CHAT_DIALOG.format(session.chat_id)
+                dialog_chat = ChatDialog()
+                try:
+                    async for ans in dialog_chat.chat_completions(dialog_url, await dialog_chat.complex_request_data(f"绠�瑕佹�荤粨锛歿message}",
+                                                                                                chat_request["knowledgeId"],
+                                                                                                session.conversation_id),
+                                                           await dialog_chat.get_headers(token)):
+                        if ans.get("code", None) == 102:
+                            return res
+                        else:
+                            if isinstance(ans.get("data"), bool) and ans.get("data") is True:
+                                break
+                            else:
+                                data = ans.get("data", {})
+                                mindmap_query = data.get("answer", "")
+                except Exception as e:
+                    logger.error(e)
+            else:
+                try:
+                    async for ans in chat.chat_completions(url,
+                                                           await chat.complex_request_data(message, session.conversation_id,
+                                                                                   str(user_id), files=chat_request.get("files", []), inputs=inputs),
+                                                           await chat.get_headers(token)):
+                        if ans.get("event") == message_error:
+                            return res
+                        elif ans.get("event") == workflow_finished:
+                            mindmap_query = ans.get("data", {}).get("outputs", {}).get("answer", "")
+                        else:
+                            continue
 
-            except Exception as e:
-                logger.error(e)
-                return res
+                except Exception as e:
+                    logger.error(e)
+                    return res
         else:
             mindmap_query = session.content
         # print("-----------------", mindmap_query)
@@ -107,6 +126,7 @@
                                                    await chat.complex_request_data(mindmap_query, "",
                                                                            str(user_id)),
                                                    await chat.get_headers(token)):
+                # print(ans)
                 if ans.get("event") == message_error:
                     return res
                 elif ans.get("event") == message_event:
@@ -195,6 +215,19 @@
     return parent_list[:index]+new_node_list+parent_list[index+1:]
 
 
+async def service_message_mindmap_parse(db, message_id, user_id):
+    res = {}
+    complex_log = ComplexChatSessionDao(db)
+    session = await complex_log.get_session_by_id(message_id)
+
+    if session.mindmap:
+        try:
+            res_str = await mindmap_join_str(session.mindmap)
+            res["mindmap"] = res_str
+        except Exception as e:
+            logger.error(e)
+    return res
+
 
 if __name__ == '__main__':
     a = '{  "title": "鍏ㄧ敓鍛藉懆鏈熺鐞�",  "items": [    {      "title": "璁惧瑙勫垝涓庨噰璐�",      "items": [        {          "title": "闇�姹傚垎鏋愪笌閫夊瀷"    ,"items": [{"title": "rererer"}, {"title": "trtrtrtrt"}]    },        {          "title": "渚涘簲鍟嗛�夋嫨涓庡悎鍚岀鐞�"        }      ]    },    {      "title": "璁惧瀹夎涓庤皟璇�",      "items": [        {          "title": "瀹夎瑙勮寖"        },        {          "title": "璋冭瘯娴嬭瘯"        }      ]    },    {      "title": "璁惧浣跨敤",      "items": [        {          "title": "鎿嶄綔鍩硅"        },        {          "title": "鎿嶄綔瑙勭▼涓庤褰�"        }      ]    },    {      "title": "璁惧缁存姢涓庣淮淇�",      "items": [        {          "title": "瀹氭湡缁存姢"        },        {          "title": "鏁呴殰璇婃柇"        },        {          "title": "澶囦欢绠$悊"        }      ]    },    {      "title": "璁惧鏇存柊涓庢敼閫�",      "items": [        {          "title": "鎶�鏈瘎浼�"        },        {          "title": "鏇存柊璁″垝"        },        {          "title": "鏀归�犳柟妗�"        }      ]    },    {      "title": "璁惧鎶ュ簾",      "items": [        {          "title": "鎶ュ簾璇勪及"        },        {          "title": "鎶ュ簾澶勭悊"        }      ]    },    {      "title": "淇℃伅鍖栫鐞�",      "items": [        {          "title": "璁惧绠$悊绯荤粺"        },        {          "title": "鏁版嵁鍒嗘瀽"        },        {          "title": "杩滅▼鐩戞帶"        }      ]    },    {      "title": "瀹夊叏绠$悊",      "items": [        {          "title": "瀹夊叏鍩硅"        },        {          "title": "瀹夊叏妫�鏌�"        },        {          "title": "搴旀�ラ妗�"        }      ]    },    {      "title": "鐜淇濇姢",      "items": [        {          "title": "鐜繚璁惧"        },        {          "title": "搴熺墿澶勭悊"        },        {          "title": "鑺傝兘鍑忔帓"        }      ]    },    {      "title": "鍏蜂綋瀹炶返妗堜緥",      "items": [        {          "title": "楂樺帇寮�鍏宠澶囨鼎婊戣剛閫夌敤鐮旂┒"        },        {          "title": "鐜繚鍨� C4 娣锋皵 GIS 璁惧杩愮淮鎶�鏈爺绌�"        }      ]    },    {      "title": "鎬荤粨",      "items": [        {          "title": "鎻愰珮杩愯惀鏁堢巼鍜岀珵浜夊姏"        }      ]    }  ]}'
diff --git a/app/task/fetch_agent.py b/app/task/fetch_agent.py
index 8ad5215..295b7b0 100644
--- a/app/task/fetch_agent.py
+++ b/app/task/fetch_agent.py
@@ -43,6 +43,7 @@
     status = Column(String(1), nullable=False)
     description = Column(String(255), nullable=False)
     tenant_id = Column(String(36), nullable=False)
+    kb_ids = Column(String(128), nullable=False)
 
 
 class DfApps(Base):
@@ -257,13 +258,13 @@
             query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id) \
                 .filter(Dialog.name.in_(names), Dialog.status == "1")
         else:
-            query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id).filter(
+            query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id, Dialog.kb_ids).filter(
                 Dialog.status == "1", Dialog.tenant_id == tenant_id)
 
         results = query.all()
         formatted_results = [
             {"id": row[0], "name": row[1], "description": row[2], "status": "1" if row[3] == "1" else "2",
-             "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results if row[0] not in chat_ids]
+             "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para, "kb_ids": row[5]} for row in results if row[0] not in chat_ids]
         return formatted_results
     finally:
         db.close()
@@ -301,13 +302,14 @@
                 existing_agent.name = row["name"]
                 existing_agent.description = row["description"]
                 existing_agent.mode = row["mode"]
+                existing_agent.kb_ids = row.get("kb_ids", "")
                 if existing_agent.status == Dialog_STATSU_DELETE:
                     existing_agent.status = Dialog_STATSU_ON
                 if row["parameters"]:
                     existing_agent.parameters = json.dumps(row["parameters"])
             else:
                 existing = DialogModel(id=row["id"], status=row["status"], name=row["name"],
-                                       description=row["description"],
+                                       description=row["description"], kb_ids=row.get("kb_ids", ""),
                                        tenant_id=get_rag_user_id(db, row["user_id"], type_dict[dialog_type]),
                                        dialog_type=dialog_type, mode=row["mode"], parameters=json.dumps(row["parameters"]))
                 db.add(existing)
@@ -411,10 +413,10 @@
 def get_one_from_ragflow_dialog(dialog_id):
     db = SessionRagflow()
     try:
-        row = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id) \
+        row = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id, Dialog.kb_ids) \
             .filter(Dialog.id==dialog_id).first()
         return {"id": row[0], "name": row[1], "description": row[2], "status": str(row[3]),
-                "user_id": str(row[4])} if row else {}
+                "user_id": str(row[4]), "kb_ids": row[5]} if row else {}
     finally:
         db.close()
 

--
Gitblit v1.8.0