From e26a7859a8900b152e10961d91fa6ad19a8deb9c Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 06 三月 2025 14:41:27 +0800
Subject: [PATCH] 首页通用对话增加
---
app/service/v2/app_driver/chat_agent.py | 16 +
app/models/v2/mindmap.py | 18
app/api/v2/chat.py | 55 ++-
app/models/v2/chat.py | 234 ++++++++++++++++
app/models/__init__.py | 2
app/service/v2/chat.py | 204 +++++++++++++-
requirements.txt | 0
app/service/v2/mindmap.py | 198 ++++++++++++++
app/config/const.py | 7
app/task/fetch_agent.py | 18
app/utils/common.py | 6
main.py | 2
app/config/env_conf/admin.yaml | 10
app/api/v2/mindmap.py | 31 ++
14 files changed, 744 insertions(+), 57 deletions(-)
diff --git a/app/api/v2/chat.py b/app/api/v2/chat.py
index b65541e..207d967 100644
--- a/app/api/v2/chat.py
+++ b/app/api/v2/chat.py
@@ -1,22 +1,20 @@
import json
import uuid
-from typing import List
+from typing import List
from fastapi import Depends, APIRouter, File, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse, Response
-from werkzeug.http import HTTP_STATUS_CODES
-
from app.api import get_current_user, get_api_key
-from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \
- smart_message_error, http_400, http_500, http_200
+from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat
from app.models import UserModel
from app.models.base_model import get_db
-from app.models.v2.chat import RetrievalRequest
+from app.models.v2.chat import RetrievalRequest, ChatDataRequest, ComplexChatDao
from app.models.v2.session_model import ChatData
from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_basic, \
service_chat_workflow, service_chat_parameters, service_chat_sessions, service_chat_upload, \
- service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_base_chunk_retrieval
+ service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_complex_chat, \
+ service_complex_upload
chat_router_v2 = APIRouter()
@@ -37,7 +35,7 @@
media_type="text/event-stream")
if not session_id:
session = await service_chat_sessions(db, chatId, dialog.query)
- print(session)
+ # print(session)
if not session or session.get("code") != 0:
error_msg = json.dumps(
{"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500})
@@ -77,7 +75,7 @@
media_type="text/event-stream")
-@chat_router_v2.post("/complex/{chatId}/completions")
+@chat_router_v2.post("/develop/{chatId}/completions")
async def api_chat_dialog(chatId:str, dialog: ChatData, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user)
chat_info = await get_chat_info(db, chatId)
if not chat_info:
@@ -128,16 +126,39 @@
return Response(data, media_type="application/json", status_code=http_200)
-
-# @chat_router_v2.post("/conversation/mindmap")
-# async def api_conversation_mindmap(chatId:str, current:int=1, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user)
-# data = await service_chat_sessions_list(db, chatId, current, pageSize, current_user.id, keyword)
-# return Response(data, media_type="application/json", status_code=http_200)
-
-
-
@chat_router_v2.post("/retrieval")
async def retrieve_chunks(request_data: RetrievalRequest, api_key: str = Depends(get_api_key)):
records = await service_chunk_retrieval(request_data.query, request_data.knowledge_id, request_data.retrieval_setting.top_k, request_data.retrieval_setting.score_threshold, api_key)
return {"records": records}
+
+@chat_router_v2.post("/complex/chat/completions")
+async def api_complex_chat_completions(chat: ChatDataRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user)
+ complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chat.chatMode)
+ if complex_chat:
+ if not chat.sessionId:
+ chat.sessionId = str(uuid.uuid4()).replace("-", "")
+ return StreamingResponse(service_complex_chat(db, complex_chat.id, complex_chat.mode, current_user.id, chat),
+ media_type="text/event-stream")
+ else:
+ error_msg = json.dumps(
+ {"message": smart_message_error, "error": "\n**ERROR**: 缃戠粶寮傚父锛屾棤娉曠敓鎴愬璇濈粨鏋滐紒", "status": http_500})
+ return StreamingResponse(f"data: {error_msg}\n\n",
+ media_type="text/event-stream")
+
+
+@chat_router_v2.post("/complex/upload/{chatMode}")
+async def api_complex_upload(chatMode:int, file: List[UploadFile] = File(...), current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user)
+ status_code = http_200
+ complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chatMode)
+ if complex_chat:
+ data = await service_complex_upload(db, complex_chat.id, file, current_user.id)
+ if not data:
+ status_code = http_400
+ data = "{}"
+ else:
+ status_code = http_500
+ data = "{}"
+ return Response(data, media_type="application/json", status_code=status_code)
+
+
diff --git a/app/api/v2/mindmap.py b/app/api/v2/mindmap.py
new file mode 100644
index 0000000..9f9892b
--- /dev/null
+++ b/app/api/v2/mindmap.py
@@ -0,0 +1,31 @@
+import json
+import uuid
+from typing import List
+
+from fastapi import Depends, APIRouter, File, UploadFile
+from sqlalchemy.orm import Session
+from werkzeug.http import HTTP_STATUS_CODES
+
+from app.api import get_current_user, get_api_key, Response
+from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \
+ smart_message_error, http_400, http_500, http_200, complex_mindmap_chat
+from app.models import UserModel
+from app.models.base_model import get_db
+from app.models.v2.chat import RetrievalRequest, ComplexChatDao
+from app.models.v2.mindmap import MindmapRequest
+from app.models.v2.session_model import ChatData
+from app.service.v2.mindmap import service_chat_mindmap
+
+mind_map_router = APIRouter()
+
+
+@mind_map_router.post("/create", response_model=Response)
+async def api_chat_mindmap(mindmap: MindmapRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user)
+ complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(complex_mindmap_chat)
+ if complex_chat:
+ data = await service_chat_mindmap(db, mindmap.messageId, mindmap.query, complex_chat.id,current_user.id)
+ if not data:
+ return Response(code=500, msg="create failure", data={})
+ else:
+ return Response(code=500, msg="缃戠粶寮傚父锛乫ailure", data={})
+ return Response(code=200, msg="create success", data=data)
\ No newline at end of file
diff --git a/app/config/const.py b/app/config/const.py
index a5d31c5..a173fe6 100644
--- a/app/config/const.py
+++ b/app/config/const.py
@@ -113,3 +113,10 @@
###-------------------------------system-------------------------------------------------
SYSTEM_ID = 1
+
+### --------------------------------complex mode----------------------------------------------
+complex_dialog_chat = 1 # 鏂囨。鍜屽熀纭�瀵硅瘽
+complex_network_chat = 2 # 鑱旂綉瀵硅瘽
+complex_knowledge_chat = 3 # 鐭ヨ瘑搴撳璇�
+# complex_deep_chat = 4
+complex_mindmap_chat = 5
\ No newline at end of file
diff --git a/app/config/env_conf/admin.yaml b/app/config/env_conf/admin.yaml
index 30ad3bf..7c12711 100644
--- a/app/config/env_conf/admin.yaml
+++ b/app/config/env_conf/admin.yaml
@@ -3,10 +3,10 @@
password: gAAAAABnvAq8bErFiR9x_ZcODjUeOdrDo8Z5UVOzyqo6SxIhAvLpw81kciQN0frwIFVfY9wrxH1WqrpTICpEwfH7r2SkLjS7SQ==
chat_server:
- id: fe24dd2c9be611ef92880242ac160006
- account: user@example.com
- password: gAAAAABnvs3e3fZOYfUUAJ6uT80dkhNeN7rhylzZErTWRZThNSLzMbZGetPCe9A2BJ86V0nZBLMNNu8w6rWp4dC7JxYxByJcow==
+ id: 2c039666c29d11efa4670242ac1b0006
+ account: zhao1@example.com
+ password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ==
workflow_server:
- account: basic@mail.com
- password: gAAAAABnvs5i7xUn9pb2szCozJciGSiWPGv80PH_2HFFzNM2r1ZLTOQqftnUso_bvchtmwAmccfNrf53sf9_WMFVTc0hjTKRRQ==
\ No newline at end of file
+ account: admin@basic.com
+ password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ==
\ No newline at end of file
diff --git a/app/models/__init__.py b/app/models/__init__.py
index 2f90c68..3047b96 100644
--- a/app/models/__init__.py
+++ b/app/models/__init__.py
@@ -17,6 +17,8 @@
from .menu_model import *
from .label_model import *
from .v2.session_model import *
+from .v2.chat import *
+from .v2.mindmap import *
from .system import *
diff --git a/app/models/v2/chat.py b/app/models/v2/chat.py
index a4818e9..1240c85 100644
--- a/app/models/v2/chat.py
+++ b/app/models/v2/chat.py
@@ -1,5 +1,14 @@
-from pydantic import BaseModel
+import json
+from datetime import datetime
+from typing import List, Optional
+from pydantic import BaseModel
+from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, TEXT
+from sqlalchemy.orm import Session
+
+from app.config.const import Dialog_STATSU_DELETE
+from app.models.base_model import Base
+from app.utils.common import current_time
class RetrievalSetting(BaseModel):
@@ -12,5 +21,228 @@
query: str
retrieval_setting: RetrievalSetting
+class ChatDataRequest(BaseModel):
+ sessionId: str
+ query: str
+ chatMode: Optional[int] = 1 # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害
+ isDeep: Optional[int] = 1 # 1= 鏅��, 2=娣卞害
+ knowledgeId: Optional[list] = []
+ files: Optional[list] = []
+ def to_dict(self):
+ return {
+ "sessionId": self.sessionId,
+ "query": self.query,
+ "chatMode": self.chatMode,
+ "knowledgeId": self.knowledgeId,
+ "files": self.files,
+ }
+
+
+
+
+class ComplexChatModel(Base):
+ __tablename__ = 'complex_chat'
+ __mapper_args__ = {
+ # "order_by": 'SEQ'
+ }
+ id = Column(String(36), primary_key=True) # id
+ create_date = Column(DateTime, default=datetime.now()) # 鍒涘缓鏃堕棿
+ update_date = Column(DateTime, default=datetime.now(), onupdate=datetime.now()) # 鏇存柊鏃堕棿
+ tenant_id = Column(String(36)) # 鍒涘缓浜�
+ name = Column(String(255)) # 鍚嶇О
+ description = Column(Text) # 璇存槑
+ icon = Column(Text, default="intelligentFrame1") # 鍥炬爣
+ status = Column(String(1), default="1") # 鐘舵��
+ dialog_type = Column(String(1)) # 骞冲彴
+ mode = Column(String(36))
+ parameters = Column(Text)
+ chat_mode = Column(Integer) #1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害
+
+ def to_json(self):
+ return {
+ 'id': self.id,
+ 'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'),
+ 'update_date': self.update_date.strftime('%Y-%m-%d %H:%M:%S'),
+ 'user_id': self.tenant_id,
+ 'name': self.name,
+ 'description': self.description,
+ 'icon': self.icon,
+ 'status': self.status,
+ 'agentType': self.dialog_type,
+ 'mode': self.mode,
+ }
+
+class ComplexChatDao:
+ def __init__(self, db: Session):
+ self.db = db
+
+ async def create_complex_chat(self, chat_id: str, **kwargs) -> ComplexChatModel:
+ new_session = ComplexChatModel(
+ id=chat_id,
+ create_date=current_time(),
+ update_date=current_time(),
+ **kwargs
+ )
+ self.db.add(new_session)
+ self.db.commit()
+ self.db.refresh(new_session)
+ return new_session
+
+ async def get_complex_chat_by_id(self, chat_id: str) -> ComplexChatModel | None:
+ session = self.db.query(ComplexChatModel).filter_by(id=chat_id).first()
+ return session
+
+ async def update_complex_chat_by_id(self, chat_id: str, session, message: dict, conversation_id=None) -> ComplexChatModel | None:
+ if not session:
+ session = await self.get_complex_chat_by_id(chat_id)
+ if session:
+ try:
+ # TODO
+ session.update_date = current_time()
+ self.db.commit()
+ self.db.refresh(session)
+ except Exception as e:
+ # logger.error(e)
+ self.db.rollback()
+ return session
+
+ async def update_or_insert_by_id(self, chat_id: str, **kwargs) -> ComplexChatModel:
+ existing_session = await self.get_complex_chat_by_id(chat_id)
+ if existing_session:
+ return await self.update_complex_chat_by_id(chat_id, existing_session, kwargs.get("message"))
+
+ existing_session = await self.create_complex_chat(chat_id, **kwargs)
+ return existing_session
+
+ async def delete_complex_chat(self, chat_id: str) -> None:
+ session = await self.get_complex_chat_by_id(chat_id)
+ if session:
+ self.db.delete(session)
+ self.db.commit()
+
+ async def aget_complex_chat_ids(self) -> List:
+ session_list = self.db.query(ComplexChatModel).filter(ComplexChatModel.status!=Dialog_STATSU_DELETE).all()
+
+ return [i.id for i in session_list]
+
+ def get_complex_chat_ids(self) -> List:
+ session_list = self.db.query(ComplexChatModel).filter(ComplexChatModel.status!=Dialog_STATSU_DELETE).all()
+
+ return [i.id for i in session_list]
+
+ async def get_complex_chat_by_mode(self, chat_mode: int) -> ComplexChatModel | None:
+ session = self.db.query(ComplexChatModel).filter(ComplexChatModel.chat_mode==chat_mode, ComplexChatModel.status!=Dialog_STATSU_DELETE).first()
+ return session
+
+
+
+class ComplexChatSessionModel(Base):
+ __tablename__ = "complex_chat_sessions"
+
+
+ id = Column(String(36), primary_key=True)
+ chat_id = Column(String(36))
+ session_id = Column(String(36), index=True)
+ create_date = Column(DateTime, default=current_time, index=True) # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿
+ update_date = Column(DateTime, default=current_time, onupdate=current_time) # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊
+ tenant_id = Column(Integer, index=True) # 鍒涘缓浜�
+ agent_type = Column(Integer) # 1=rg锛� 3=basic锛�4=df
+ message_type = Column(Integer) # 1=鐢ㄦ埛锛�2=鏈哄櫒浜猴紝3=绯荤粺
+ content = Column(TEXT)
+ mindmap = Column(TEXT)
+ query = Column(TEXT)
+ node_data = Column(TEXT)
+ event_type = Column(String(16))
+ conversation_id = Column(String(36))
+ chat_mode = Column(Integer) # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害
+
+ # to_dict 鏂规硶
+ def to_dict(self):
+ return {
+ 'session_id': self.id,
+ 'name': self.name,
+ 'agent_type': self.agent_type,
+ 'chat_id': self.agent_id,
+ 'event_type': self.event_type,
+ 'session_type': self.session_type if self.session_type else 0,
+ 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
+ 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
+ }
+
+ def log_to_json(self):
+ return {
+ 'id': self.id,
+ 'name': self.name,
+ 'agent_type': self.agent_type,
+ 'chat_id': self.agent_id,
+ 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
+ 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
+ 'message': json.loads(self.message)
+ }
+
+
+class ComplexChatSessionDao:
+ def __init__(self, db: Session):
+ self.db = db
+
+ async def get_session_by_session_id(self, session_id: str, chat_id:str) -> ComplexChatSessionModel | None:
+ session = self.db.query(ComplexChatSessionModel).filter_by(chat_id=chat_id, session_id=session_id, message_type=2).first()
+ return session
+
+ async def create_session(self, message_id: str, **kwargs) -> ComplexChatSessionModel:
+ new_session = ComplexChatSessionModel(
+ id=message_id,
+ create_date=current_time(),
+ update_date=current_time(),
+ **kwargs
+ )
+ self.db.add(new_session)
+ self.db.commit()
+ self.db.refresh(new_session)
+ return new_session
+
+ async def get_session_by_id(self, message_id: str) -> ComplexChatSessionModel | None:
+ session = self.db.query(ComplexChatSessionModel).filter_by(id=message_id).first()
+ return session
+
+ async def update_mindmap_by_id(self, message_id: str, mindmap:str) -> ComplexChatSessionModel | None:
+ # print(message)
+
+ session = await self.get_session_by_id(message_id)
+ if session:
+ try:
+
+ session.mindmap = mindmap
+ session.update_date = current_time()
+ self.db.commit()
+ self.db.refresh(session)
+ except Exception as e:
+ # logger.error(e)
+ self.db.rollback()
+ return session
+
+ async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ComplexChatSessionModel:
+ existing_session = await self.get_session_by_id(session_id)
+ if existing_session:
+ return await self.update_session_by_id(session_id, existing_session, kwargs.get("message"))
+
+ existing_session = await self.create_session(session_id, **kwargs)
+ return existing_session
+
+ async def delete_session(self, session_id: str) -> None:
+ session = await self.get_session_by_id(session_id)
+ if session:
+ self.db.delete(session)
+ self.db.commit()
+
+ async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any:
+ query = self.db.query(ComplexChatSessionModel).filter(ComplexChatSessionModel.tenant_id==user_id)
+ if agent_id:
+ query = query.filter(ComplexChatSessionModel.agent_id==agent_id)
+ if keyword:
+ query = query.filter(ComplexChatSessionModel.name.like('%{}%'.format(keyword)))
+ total = query.count()
+ session_list = query.order_by(ComplexChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all()
+ return total, session_list
\ No newline at end of file
diff --git a/app/models/v2/mindmap.py b/app/models/v2/mindmap.py
index cead556..743bb5f 100644
--- a/app/models/v2/mindmap.py
+++ b/app/models/v2/mindmap.py
@@ -1,12 +1,16 @@
import json
-from typing import Optional, Type, List
+from datetime import datetime
+from typing import List
+
from pydantic import BaseModel
+from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text
+from sqlalchemy.orm import Session
+from app.config.const import Dialog_STATSU_DELETE
+from app.models.base_model import Base
+from app.utils.common import current_time
+class MindmapRequest(BaseModel):
+ messageId: str
+ query:str
-
-class ChatData(BaseModel):
- sessionId: Optional[str] = ""
-
- class Config:
- extra = 'allow' # 鍏佽鍏朵粬鍔ㄦ�佸瓧娈�
\ No newline at end of file
diff --git a/app/service/v2/app_driver/chat_agent.py b/app/service/v2/app_driver/chat_agent.py
index 45f804e..5fa0bfa 100644
--- a/app/service/v2/app_driver/chat_agent.py
+++ b/app/service/v2/app_driver/chat_agent.py
@@ -9,6 +9,7 @@
async def chat_completions(self, url, data, headers):
complete_response = ""
+ # print(data)
async for line in self.http_stream(url, data, headers):
# logger.error(line)
if line.startswith("data:"):
@@ -46,6 +47,21 @@
"files": files
}
+ @staticmethod
+ async def complex_request_data(query: str, conversation_id: str, user: str, files: list=None, inputs: dict=None) -> dict:
+ if not files:
+ files = []
+ if not inputs:
+ inputs = {}
+ return {
+ "inputs": inputs,
+ "query": query,
+ "response_mode": "streaming",
+ "conversation_id": conversation_id,
+ "user": user,
+ "files": files
+ }
+
if __name__ == "__main__":
async def aa():
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index 68a7cd3..05a37be 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -1,6 +1,7 @@
import asyncio
import io
import json
+import uuid
import fitz
from fastapi import HTTPException
@@ -10,7 +11,7 @@
DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
from app.config.config import settings
from app.config.const import *
-from app.models import DialogModel, ApiTokenModel, UserTokenModel
+from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest
from app.models.v2.session_model import ChatSessionDao, ChatData
from app.service.v2.app_driver.chat_agent import ChatAgent
from app.service.v2.app_driver.chat_data import ChatBaseApply
@@ -47,12 +48,12 @@
logger.error(e)
return None
+
async def get_app_token(db, app_id):
app_token = db.query(UserTokenModel).filter_by(id=app_id).first()
if app_token:
return app_token.access_token
return ""
-
async def get_chat_token(db, app_id):
@@ -69,7 +70,6 @@
db.commit()
except Exception as e:
logger.error(e)
-
async def get_chat_info(db, chat_id: str):
@@ -134,6 +134,7 @@
message["role"] = "assistant"
await update_session_log(db, session_id, message, conversation_id)
+
async def data_process(data):
if isinstance(data, str):
return data.replace("dify", "smart")
@@ -170,8 +171,9 @@
if hasattr(chat_data, "query"):
query = chat_data.query
else:
- query = "start new workflow"
- session = await add_session_log(db, session_id,query if query else "start new conversation", chat_id, user_id, mode, conversation_id, 3)
+ query = "start new conversation"
+ session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id,
+ mode, conversation_id, 3)
if session:
conversation_id = session.conversation_id
try:
@@ -215,7 +217,7 @@
[workflow_started, node_started, node_finished].index(ans.get("event"))]
elif ans.get("event") == workflow_finished:
data = ans.get("data", {})
- answer_workflow = data.get("outputs", {}).get("output")
+ answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer"))
download_url = data.get("outputs", {}).get("download_url")
event = smart_workflow_finished
if data.get("status") == "failed":
@@ -242,8 +244,9 @@
except:
...
finally:
- await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow or error,
- "download_url":download_url,
+ await update_session_log(db, session_id, {"role": "assistant",
+ "answer": answer_event or answer_agent or answer_workflow or error,
+ "download_url": download_url,
"node_list": node_list, "task_id": task_id, "id": message_id,
"error": error}, conversation_id)
@@ -257,6 +260,7 @@
if not chat_info:
return {}
return chat_info.parameters
+
async def service_chat_sessions(db, chat_id, name):
token = await get_chat_token(db, rg_api_token)
@@ -276,14 +280,12 @@
page=current,
page_size=page_size
)
- return json.dumps({"total":total, "rows": [session.to_dict() for session in session_list]})
-
+ return json.dumps({"total": total, "rows": [session.to_dict() for session in session_list]})
async def service_chat_session_log(db, session_id):
session_log = await ChatSessionDao(db).get_session_by_id(session_id)
- return json.dumps(session_log.log_to_json())
-
+ return json.dumps(session_log.log_to_json() if session_log else {})
async def service_chat_upload(db, chat_id, file, user_id):
@@ -316,6 +318,7 @@
tokens = tokenizer.encode(input_str)
return len(tokens)
+
async def read_pdf(pdf_stream):
text = ""
with fitz.open(stream=pdf_stream, filetype="pdf") as pdf_document:
@@ -335,6 +338,7 @@
return text
+
async def read_file(file, filename, content_type):
text = ""
if content_type == "application/pdf" or filename.endswith('.pdf'):
@@ -349,7 +353,7 @@
async def service_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
- print(query)
+ # print(query)
try:
request_data = json.loads(query)
@@ -357,7 +361,7 @@
"question": request_data.get("query", ""),
"dataset_ids": request_data.get("dataset_ids", []),
"page_size": top_k,
- "similarity_threshold": similarity_threshold
+ "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
}
except json.JSONDecodeError as e:
fixed_json = query.replace("'", '"')
@@ -367,15 +371,16 @@
"question": request_data.get("query", ""),
"dataset_ids": request_data.get("dataset_ids", []),
"page_size": top_k,
- "similarity_threshold": similarity_threshold
+ "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
}
except Exception:
payload = {
- "question":query,
- "dataset_ids":[knowledge_id],
+ "question": query,
+ "dataset_ids": [knowledge_id],
"page_size": top_k,
- "similarity_threshold": similarity_threshold
+ "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
}
+ # print(payload)
url = settings.fwr_base_url + RG_ORIGINAL_URL
chat = ChatBaseApply()
response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
@@ -388,11 +393,15 @@
"title": chunk.get("document_keyword", "Unknown Document"),
"metadata": {"document_id": chunk["document_id"],
"path": f"{settings.fwr_base_url}/document/{chunk['document_id']}?ext={chunk.get('document_keyword').split('.')[-1]}&prefix=document",
- 'highlight': chunk.get("highlight") , "image_id": chunk.get("image_id"), "positions": chunk.get("positions"),}
+ 'highlight': chunk.get("highlight"), "image_id": chunk.get("image_id"),
+ "positions": chunk.get("positions"), }
}
for chunk in response.get("data", {}).get("chunks", [])
]
+ # print(len(records))
+ # print(records)
return records
+
async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
# request_data = json.loads(query)
@@ -420,15 +429,170 @@
return records
+async def add_complex_log(db, message_id, chat_id, session_id, chat_mode, query, user_id, mode, agent_type, message_type, conversation_id="", node_data=None, query_data=None):
+ if not node_data:
+ node_data = []
+ if not query_data:
+ query_data = {}
+ try:
+ complex_log = ComplexChatSessionDao(db)
+ if not conversation_id:
+ session = await complex_log.get_session_by_session_id(session_id, chat_id)
+ if session:
+ conversation_id = session.conversation_id
+ await complex_log.create_session(message_id,
+ chat_id=chat_id,
+ session_id=session_id,
+ chat_mode=chat_mode,
+ message_type=message_type,
+ content=query,
+ event_type=mode,
+ tenant_id=user_id,
+ conversation_id=conversation_id,
+ node_data=json.dumps(node_data),
+ query=json.dumps(query_data),
+ agent_type=agent_type)
+ return conversation_id, True
+
+ except Exception as e:
+ logger.error(e)
+ return conversation_id, False
+
+
+
+async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
+ answer_event = ""
+ answer_agent = ""
+ answer_workflow = ""
+ download_url = ""
+ message_id = ""
+ task_id = ""
+ error = ""
+ files = []
+ node_list = []
+ token = await get_chat_token(db, chat_id)
+ chat, url = await get_chat_object(mode)
+ conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
+ if not message:
+ yield "data: " + json.dumps({"message": smart_message_error,
+ "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500},
+ ensure_ascii=False) + "\n\n"
+ return
+ inputs = {"is_deep": chat_request.isDeep}
+ if chat_request.chatMode == complex_knowledge_chat:
+ inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
+
+ try:
+ async for ans in chat.chat_completions(url,
+ await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs),
+ await chat.get_headers(token)):
+ # print(ans)
+ data = {}
+ status = http_200
+ conversation_id = ans.get("conversation_id")
+ task_id = ans.get("task_id")
+ if ans.get("event") == message_error:
+ error = ans.get("message", "鍙傛暟寮傚父锛�")
+ status = http_400
+ event = smart_message_error
+ elif ans.get("event") == message_agent:
+ data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
+ answer_agent += ans.get("answer", "")
+ message_id = ans.get("message_id", "")
+ event = smart_message_stream
+ elif ans.get("event") == message_event:
+ data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
+ answer_event += ans.get("answer", "")
+ message_id = ans.get("message_id", "")
+ event = smart_message_stream
+ elif ans.get("event") == message_file:
+ data = {"url": ans.get("url", ""), "id": ans.get("id", ""),
+ "type": ans.get("type", "")}
+ files.append(data)
+ event = smart_message_file
+ elif ans.get("event") in [workflow_started, node_started, node_finished]:
+ data = ans.get("data", {})
+ data["inputs"] = await data_process(data.get("inputs", {}))
+ data["outputs"] = await data_process(data.get("outputs", {}))
+ data["files"] = await data_process(data.get("files", []))
+ data["process_data"] = ""
+ if data.get("status") == "failed":
+ status = http_500
+ error = data.get("error", "")
+ node_list.append(ans)
+ event = [smart_workflow_started, smart_node_started, smart_node_finished][
+ [workflow_started, node_started, node_finished].index(ans.get("event"))]
+ elif ans.get("event") == workflow_finished:
+ data = ans.get("data", {})
+ answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer"))
+ download_url = data.get("outputs", {}).get("download_url")
+ event = smart_workflow_finished
+ if data.get("status") == "failed":
+ status = http_500
+ error = data.get("error", "")
+ node_list.append(ans)
+
+ elif ans.get("event") == message_end:
+ event = smart_message_end
+ else:
+ continue
+
+ yield "data: " + json.dumps(
+ {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id,
+ "session_id": chat_request.sessionId},
+ ensure_ascii=False) + "\n\n"
+
+ except Exception as e:
+ logger.error(e)
+ try:
+ yield "data: " + json.dumps({"message": smart_message_error,
+ "error": "\n**ERROR**: " + str(e), "status": http_500},
+ ensure_ascii=False) + "\n\n"
+ except:
+ ...
+ finally:
+ # await update_session_log(db, session_id, {"role": "assistant",
+ # "answer": answer_event or answer_agent or answer_workflow or error,
+ # "download_url": download_url,
+ # "node_list": node_list, "task_id": task_id, "id": message_id,
+ # "error": error}, conversation_id)
+ if message_id:
+ await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
+
+async def service_complex_upload(db, chat_id, file, user_id):
+ files = []
+ token = await get_chat_token(db, chat_id)
+ if not token:
+ return files
+ url = settings.dify_base_url + DF_UPLOAD_FILE
+ chat = ChatBaseApply()
+ for f in file:
+ try:
+ file_content = await f.read()
+ file_upload = await chat.chat_upload(url, {"file": (f.filename, file_content)}, {"user": str(user_id)},
+ {'Authorization': f'Bearer {token}'})
+ # try:
+ # tokens = await read_file(file_content, f.filename, f.content_type)
+ # file_upload["tokens"] = tokens
+ # except:
+ # ...
+ files.append(file_upload)
+ except Exception as e:
+ logger.error(e)
+ return json.dumps(files) if files else ""
if __name__ == "__main__":
q = json.dumps({"query": "璁惧", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]})
top_k = 2
similarity_threshold = 0.5
api_key = "ragflow-Y4MGYwY2JlZjM2YjExZWY4ZWU5MDI0Mm"
+
+
# a = service_chunk_retrieval(q, top_k, similarity_threshold, api_key)
# print(a)
async def a():
b = await service_chunk_retrieval(q, top_k, similarity_threshold, api_key)
print(b)
- asyncio.run(a())
\ No newline at end of file
+
+
+ asyncio.run(a())
diff --git a/app/service/v2/mindmap.py b/app/service/v2/mindmap.py
new file mode 100644
index 0000000..dd5b994
--- /dev/null
+++ b/app/service/v2/mindmap.py
@@ -0,0 +1,198 @@
+import json
+from Log import logger
+from app.config.agent_base_url import DF_CHAT_AGENT
+from app.config.config import settings
+from app.config.const import message_error, message_event, complex_knowledge_chat
+from app.models import ComplexChatSessionDao, ChatData
+from app.service.v2.app_driver.chat_agent import ChatAgent
+from app.service.v2.chat import get_chat_token
+
+
+async def service_chat_mindmap_v1(db, message_id, message, mindmap_chat_id, user_id):
+ res = {}
+ mindmap_query = ""
+ complex_log = ComplexChatSessionDao(db)
+ session = await complex_log.get_session_by_id(message_id)
+ if session:
+ token = await get_chat_token(db, session.chat_id)
+ chat = ChatAgent()
+ url = settings.dify_base_url + DF_CHAT_AGENT
+ if session.mindmap:
+ chat_request = json.loads(session.query)
+ try:
+ async for ans in chat.chat_completions(url,
+ await chat.request_data(message, session.conversation_id,
+ str(user_id), ChatData(), chat_request.get("files", [])),
+ await chat.get_headers(token)):
+ if ans.get("event") == message_error:
+ return res
+ elif ans.get("event") == message_event:
+ mindmap_query += ans.get("answer", "")
+ else:
+ continue
+
+ except Exception as e:
+ logger.error(e)
+ return res
+ else:
+ mindmap_query = session.content
+
+ try:
+ mindmap_str = ""
+ token = await get_chat_token(db, mindmap_chat_id)
+ async for ans in chat.chat_completions(url,
+ await chat.request_data(mindmap_query, "",
+ str(user_id), ChatData()),
+ await chat.get_headers(token)):
+ if ans.get("event") == message_error:
+ return res
+ elif ans.get("event") == message_event:
+ mindmap_str += ans.get("answer", "")
+ else:
+ continue
+
+ except Exception as e:
+ logger.error(e)
+ return res
+ mindmap_list = mindmap_str.split("```")
+ mindmap_str = mindmap_list[1].lstrip("markdown\n")
+ if session.mindmap:
+ node_list = await mindmap_to_merge(session.mindmap, mindmap_str, f"- {message}")
+ mindmap_str = "\n".join(node_list)
+ res["mindmap"] = mindmap_str
+ await complex_log.update_mindmap_by_id(message_id, mindmap_str)
+ return res
+
+
+async def service_chat_mindmap(db, message_id, message, mindmap_chat_id, user_id):
+ res = {}
+ mindmap_query = ""
+ complex_log = ComplexChatSessionDao(db)
+ session = await complex_log.get_session_by_id(message_id)
+ if session:
+ token = await get_chat_token(db, session.chat_id)
+ chat = ChatAgent()
+ url = settings.dify_base_url + DF_CHAT_AGENT
+ if session.mindmap:
+ chat_request = json.loads(session.query)
+ inputs = {"is_deep": chat_request.get("isDeep", 1)}
+ if session.chat_mode == complex_knowledge_chat:
+ inputs["query_json"] = json.dumps(
+ {"query": chat_request.get("query", ""), "dataset_ids": chat_request.get("knowledgeId", [])})
+ try:
+ async for ans in chat.chat_completions(url,
+ await chat.complex_request_data(message, session.conversation_id,
+ str(user_id), files=chat_request.get("files", []), inputs=inputs),
+ await chat.get_headers(token)):
+ if ans.get("event") == message_error:
+ return res
+ elif ans.get("event") == message_event:
+ mindmap_query += ans.get("answer", "")
+ else:
+ continue
+
+ except Exception as e:
+ logger.error(e)
+ return res
+ else:
+ mindmap_query = session.content
+
+ try:
+ mindmap_str = ""
+ token = await get_chat_token(db, mindmap_chat_id)
+ async for ans in chat.chat_completions(url,
+ await chat.complex_request_data(mindmap_query, "",
+ str(user_id)),
+ await chat.get_headers(token)):
+ if ans.get("event") == message_error:
+ return res
+ elif ans.get("event") == message_event:
+ mindmap_str += ans.get("answer", "")
+ else:
+ continue
+
+ except Exception as e:
+ logger.error(e)
+ return res
+ if "```json" in mindmap_str:
+ mindmap_list = mindmap_str.split("```")
+ mindmap_str = mindmap_list[1].lstrip("json")
+ mindmap_str = mindmap_str.replace("\n", "")
+ if session.mindmap:
+ mindmap_str = await mindmap_merge_dict(session.mindmap, mindmap_str, message)
+
+ try:
+ res_str = await mindmap_join_str(mindmap_str)
+ res["mindmap"] = res_str
+ except Exception as e:
+ logger.error(e)
+ return res
+ await complex_log.update_mindmap_by_id(message_id, mindmap_str)
+ return res
+
+async def mindmap_merge_dict(parent, child, target_node):
+ parent_dict = json.loads(parent)
+ if child:
+ child_dict = json.loads(child)
+ def iter_dict(node):
+ if "items" not in node:
+ if node["title"] == target_node:
+ node["items"] = child_dict["items"]
+ return
+ else:
+ for i in node["items"]:
+ iter_dict(i)
+ iter_dict(parent_dict)
+
+ return json.dumps(parent_dict)
+
+
+async def mindmap_join_str(mindmap_json):
+ try:
+ parent_dict = json.loads(mindmap_json)
+ except Exception as e:
+ logger.error(e)
+ return ""
+ def join_node(node, level):
+ mindmap_str = ""
+ if level <= 2:
+ mindmap_str += f"{'#'*level} {node['title']}\n"
+ else:
+ mindmap_str += f"{' '*(level-3)*2}- {node['title']}\n"
+ for i in node.get("items", []):
+ mindmap_str += join_node(i, level+1)
+ return mindmap_str
+
+ return join_node(parent_dict, 1)
+
+async def mindmap_to_merge(parent, child, target_node):
+ level = 0
+ index = 0
+ new_node_list = []
+ parent_list= parent.split("\n")
+ child_list= child.split("\n")
+ child_list[0] = target_node
+ for i, node in enumerate(parent_list):
+ if node.endswith(target_node):
+ level = len(node) - len(target_node)
+ index = i
+ break
+ tmp_level = 0
+ for child in child_list:
+ if "#" in child:
+ childs = child.split("#")
+ tmp_level = len(childs) - 2
+ new_node_list.append(" "*(level+tmp_level)+ "-"+childs[-1])
+ elif len(child) == 0:
+ continue
+ else:
+ new_node_list.append(" "*(level+tmp_level)+child)
+
+ return parent_list[:index]+new_node_list+parent_list[index+1:]
+
+
+
+if __name__ == '__main__':
+ a = '{ "title": "鍏ㄧ敓鍛藉懆鏈熺鐞�", "items": [ { "title": "璁惧瑙勫垝涓庨噰璐�", "items": [ { "title": "闇�姹傚垎鏋愪笌閫夊瀷" ,"items": [{"title": "rererer"}, {"title": "trtrtrtrt"}] }, { "title": "渚涘簲鍟嗛�夋嫨涓庡悎鍚岀鐞�" } ] }, { "title": "璁惧瀹夎涓庤皟璇�", "items": [ { "title": "瀹夎瑙勮寖" }, { "title": "璋冭瘯娴嬭瘯" } ] }, { "title": "璁惧浣跨敤", "items": [ { "title": "鎿嶄綔鍩硅" }, { "title": "鎿嶄綔瑙勭▼涓庤褰�" } ] }, { "title": "璁惧缁存姢涓庣淮淇�", "items": [ { "title": "瀹氭湡缁存姢" }, { "title": "鏁呴殰璇婃柇" }, { "title": "澶囦欢绠$悊" } ] }, { "title": "璁惧鏇存柊涓庢敼閫�", "items": [ { "title": "鎶�鏈瘎浼�" }, { "title": "鏇存柊璁″垝" }, { "title": "鏀归�犳柟妗�" } ] }, { "title": "璁惧鎶ュ簾", "items": [ { "title": "鎶ュ簾璇勪及" }, { "title": "鎶ュ簾澶勭悊" } ] }, { "title": "淇℃伅鍖栫鐞�", "items": [ { "title": "璁惧绠$悊绯荤粺" }, { "title": "鏁版嵁鍒嗘瀽" }, { "title": "杩滅▼鐩戞帶" } ] }, { "title": "瀹夊叏绠$悊", "items": [ { "title": "瀹夊叏鍩硅" }, { "title": "瀹夊叏妫�鏌�" }, { "title": "搴旀�ラ妗�" } ] }, { "title": "鐜淇濇姢", "items": [ { "title": "鐜繚璁惧" }, { "title": "搴熺墿澶勭悊" }, { "title": "鑺傝兘鍑忔帓" } ] }, { "title": "鍏蜂綋瀹炶返妗堜緥", "items": [ { "title": "楂樺帇寮�鍏宠澶囨鼎婊戣剛閫夌敤鐮旂┒" }, { "title": "鐜繚鍨� C4 娣锋皵 GIS 璁惧杩愮淮鎶�鏈爺绌�" } ] }, { "title": "鎬荤粨", "items": [ { "title": "鎻愰珮杩愯惀鏁堢巼鍜岀珵浜夊姏" } ] } ]}'
+ b = mindmap_merge_dict(a, {}, "璁惧瑙勫垝涓庨噰璐�")
+ print(b)
\ No newline at end of file
diff --git a/app/task/fetch_agent.py b/app/task/fetch_agent.py
index a5a7bfb..8ad5215 100644
--- a/app/task/fetch_agent.py
+++ b/app/task/fetch_agent.py
@@ -9,7 +9,7 @@
from app.config.config import settings
from app.config.const import RAGFLOW, BISHENG, DIFY, ENV_CONF_PATH, Dialog_STATSU_DELETE, Dialog_STATSU_ON
-from app.models import KnowledgeModel
+from app.models import KnowledgeModel, ComplexChatDao
from app.models.dialog_model import DialogModel
from app.models.user_model import UserAppModel
from app.models.agent_model import AgentModel
@@ -239,7 +239,7 @@
db.close()
-def get_data_from_ragflow_v2(names: List[str], tenant_id) -> List[Dict]:
+def get_data_from_ragflow_v2(base_db, names: List[str], tenant_id) -> List[Dict]:
db = SessionRagflow()
para = {
"user_input_form": [],
@@ -251,6 +251,8 @@
}
}
try:
+ chat_ids = ComplexChatDao(base_db).get_complex_chat_ids()
+ # print(chat_ids)
if names:
query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id) \
.filter(Dialog.name.in_(names), Dialog.status == "1")
@@ -261,15 +263,17 @@
results = query.all()
formatted_results = [
{"id": row[0], "name": row[1], "description": row[2], "status": "1" if row[3] == "1" else "2",
- "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results]
+ "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results if row[0] not in chat_ids]
return formatted_results
finally:
db.close()
-def get_data_from_dy_v2(names: List[str]) -> List[Dict]:
+def get_data_from_dy_v2(base_db, names: List[str]) -> List[Dict]:
db = SessionDify()
try:
+ chat_ids = ComplexChatDao(base_db).get_complex_chat_ids()
+ # print(chat_ids)
if names:
query = db.query(DfApps.id, DfApps.name, DfApps.description, DfApps.status, DfApps.tenant_id, DfApps.mode) \
.filter(DfApps.name.in_(names))
@@ -279,7 +283,7 @@
results = query.all()
formatted_results = [
{"id": str(row[0]), "name": row[1], "description": row[2], "status": "1",
- "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results]
+ "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results if str(row[0]) not in chat_ids]
return formatted_results
finally:
db.close()
@@ -342,11 +346,11 @@
for app in app_register:
try:
if app["id"] == RAGFLOW:
- ragflow_data = get_data_from_ragflow_v2([], app["name"])
+ ragflow_data = get_data_from_ragflow_v2(db, [], app["name"])
if ragflow_data:
update_ids_in_local_v2(ragflow_data, "1")
elif app["id"] == DIFY:
- dify_data = get_data_from_dy_v2([])
+ dify_data = get_data_from_dy_v2(db, [])
if dify_data:
update_ids_in_local_v2(dify_data, "4")
except Exception as e:
diff --git a/app/utils/common.py b/app/utils/common.py
new file mode 100644
index 0000000..e3ac4a6
--- /dev/null
+++ b/app/utils/common.py
@@ -0,0 +1,6 @@
+import pytz
+from datetime import datetime
+
+def current_time():
+ tz = pytz.timezone('Asia/Shanghai')
+ return datetime.now(tz)
diff --git a/main.py b/main.py
index d0ba8d6..87597e7 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,7 @@
from app.api.organization import dept_router
from app.api.system import system_router
from app.api.v2.chat import chat_router_v2
+from app.api.v2.mindmap import mind_map_router
from app.api.v2.public_api import public_api
from app.api.report import router as report_router
from app.api.resource import menu_router
@@ -77,6 +78,7 @@
app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"])
app.include_router(system_router, prefix='/api/system', tags=["system"])
+app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"])
app.mount("/static", StaticFiles(directory="app/images"), name="static")
if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 4d645ea..5ba14b3 100644
--- a/requirements.txt
+++ b/requirements.txt
Binary files differ
--
Gitblit v1.8.0