From e26a7859a8900b152e10961d91fa6ad19a8deb9c Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期四, 06 三月 2025 14:41:27 +0800 Subject: [PATCH] 首页通用对话增加 --- app/service/v2/app_driver/chat_agent.py | 16 + app/models/v2/mindmap.py | 18 app/api/v2/chat.py | 55 ++- app/models/v2/chat.py | 234 ++++++++++++++++ app/models/__init__.py | 2 app/service/v2/chat.py | 204 +++++++++++++- requirements.txt | 0 app/service/v2/mindmap.py | 198 ++++++++++++++ app/config/const.py | 7 app/task/fetch_agent.py | 18 app/utils/common.py | 6 main.py | 2 app/config/env_conf/admin.yaml | 10 app/api/v2/mindmap.py | 31 ++ 14 files changed, 744 insertions(+), 57 deletions(-) diff --git a/app/api/v2/chat.py b/app/api/v2/chat.py index b65541e..207d967 100644 --- a/app/api/v2/chat.py +++ b/app/api/v2/chat.py @@ -1,22 +1,20 @@ import json import uuid -from typing import List +from typing import List from fastapi import Depends, APIRouter, File, UploadFile from sqlalchemy.orm import Session from starlette.responses import StreamingResponse, Response -from werkzeug.http import HTTP_STATUS_CODES - from app.api import get_current_user, get_api_key -from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \ - smart_message_error, http_400, http_500, http_200 +from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat from app.models import UserModel from app.models.base_model import get_db -from app.models.v2.chat import RetrievalRequest +from app.models.v2.chat import RetrievalRequest, ChatDataRequest, ComplexChatDao from app.models.v2.session_model import ChatData from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_basic, \ service_chat_workflow, service_chat_parameters, service_chat_sessions, service_chat_upload, \ - service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_base_chunk_retrieval + service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_complex_chat, \ + service_complex_upload chat_router_v2 = APIRouter() @@ -37,7 +35,7 @@ media_type="text/event-stream") if not session_id: session = await service_chat_sessions(db, chatId, dialog.query) - print(session) + # print(session) if not session or session.get("code") != 0: error_msg = json.dumps( {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500}) @@ -77,7 +75,7 @@ media_type="text/event-stream") -@chat_router_v2.post("/complex/{chatId}/completions") +@chat_router_v2.post("/develop/{chatId}/completions") async def api_chat_dialog(chatId:str, dialog: ChatData, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) chat_info = await get_chat_info(db, chatId) if not chat_info: @@ -128,16 +126,39 @@ return Response(data, media_type="application/json", status_code=http_200) - -# @chat_router_v2.post("/conversation/mindmap") -# async def api_conversation_mindmap(chatId:str, current:int=1, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) -# data = await service_chat_sessions_list(db, chatId, current, pageSize, current_user.id, keyword) -# return Response(data, media_type="application/json", status_code=http_200) - - - @chat_router_v2.post("/retrieval") async def retrieve_chunks(request_data: RetrievalRequest, api_key: str = Depends(get_api_key)): records = await service_chunk_retrieval(request_data.query, request_data.knowledge_id, request_data.retrieval_setting.top_k, request_data.retrieval_setting.score_threshold, api_key) return {"records": records} + +@chat_router_v2.post("/complex/chat/completions") +async def api_complex_chat_completions(chat: ChatDataRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) + complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chat.chatMode) + if complex_chat: + if not chat.sessionId: + chat.sessionId = str(uuid.uuid4()).replace("-", "") + return StreamingResponse(service_complex_chat(db, complex_chat.id, complex_chat.mode, current_user.id, chat), + media_type="text/event-stream") + else: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: 缃戠粶寮傚父锛屾棤娉曠敓鎴愬璇濈粨鏋滐紒", "status": http_500}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + + +@chat_router_v2.post("/complex/upload/{chatMode}") +async def api_complex_upload(chatMode:int, file: List[UploadFile] = File(...), current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) + status_code = http_200 + complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chatMode) + if complex_chat: + data = await service_complex_upload(db, complex_chat.id, file, current_user.id) + if not data: + status_code = http_400 + data = "{}" + else: + status_code = http_500 + data = "{}" + return Response(data, media_type="application/json", status_code=status_code) + + diff --git a/app/api/v2/mindmap.py b/app/api/v2/mindmap.py new file mode 100644 index 0000000..9f9892b --- /dev/null +++ b/app/api/v2/mindmap.py @@ -0,0 +1,31 @@ +import json +import uuid +from typing import List + +from fastapi import Depends, APIRouter, File, UploadFile +from sqlalchemy.orm import Session +from werkzeug.http import HTTP_STATUS_CODES + +from app.api import get_current_user, get_api_key, Response +from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \ + smart_message_error, http_400, http_500, http_200, complex_mindmap_chat +from app.models import UserModel +from app.models.base_model import get_db +from app.models.v2.chat import RetrievalRequest, ComplexChatDao +from app.models.v2.mindmap import MindmapRequest +from app.models.v2.session_model import ChatData +from app.service.v2.mindmap import service_chat_mindmap + +mind_map_router = APIRouter() + + +@mind_map_router.post("/create", response_model=Response) +async def api_chat_mindmap(mindmap: MindmapRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) + complex_chat = await ComplexChatDao(db).get_complex_chat_by_mode(complex_mindmap_chat) + if complex_chat: + data = await service_chat_mindmap(db, mindmap.messageId, mindmap.query, complex_chat.id,current_user.id) + if not data: + return Response(code=500, msg="create failure", data={}) + else: + return Response(code=500, msg="缃戠粶寮傚父锛乫ailure", data={}) + return Response(code=200, msg="create success", data=data) \ No newline at end of file diff --git a/app/config/const.py b/app/config/const.py index a5d31c5..a173fe6 100644 --- a/app/config/const.py +++ b/app/config/const.py @@ -113,3 +113,10 @@ ###-------------------------------system------------------------------------------------- SYSTEM_ID = 1 + +### --------------------------------complex mode---------------------------------------------- +complex_dialog_chat = 1 # 鏂囨。鍜屽熀纭�瀵硅瘽 +complex_network_chat = 2 # 鑱旂綉瀵硅瘽 +complex_knowledge_chat = 3 # 鐭ヨ瘑搴撳璇� +# complex_deep_chat = 4 +complex_mindmap_chat = 5 \ No newline at end of file diff --git a/app/config/env_conf/admin.yaml b/app/config/env_conf/admin.yaml index 30ad3bf..7c12711 100644 --- a/app/config/env_conf/admin.yaml +++ b/app/config/env_conf/admin.yaml @@ -3,10 +3,10 @@ password: gAAAAABnvAq8bErFiR9x_ZcODjUeOdrDo8Z5UVOzyqo6SxIhAvLpw81kciQN0frwIFVfY9wrxH1WqrpTICpEwfH7r2SkLjS7SQ== chat_server: - id: fe24dd2c9be611ef92880242ac160006 - account: user@example.com - password: gAAAAABnvs3e3fZOYfUUAJ6uT80dkhNeN7rhylzZErTWRZThNSLzMbZGetPCe9A2BJ86V0nZBLMNNu8w6rWp4dC7JxYxByJcow== + id: 2c039666c29d11efa4670242ac1b0006 + account: zhao1@example.com + password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ== workflow_server: - account: basic@mail.com - password: gAAAAABnvs5i7xUn9pb2szCozJciGSiWPGv80PH_2HFFzNM2r1ZLTOQqftnUso_bvchtmwAmccfNrf53sf9_WMFVTc0hjTKRRQ== \ No newline at end of file + account: admin@basic.com + password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ== \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py index 2f90c68..3047b96 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -17,6 +17,8 @@ from .menu_model import * from .label_model import * from .v2.session_model import * +from .v2.chat import * +from .v2.mindmap import * from .system import * diff --git a/app/models/v2/chat.py b/app/models/v2/chat.py index a4818e9..1240c85 100644 --- a/app/models/v2/chat.py +++ b/app/models/v2/chat.py @@ -1,5 +1,14 @@ -from pydantic import BaseModel +import json +from datetime import datetime +from typing import List, Optional +from pydantic import BaseModel +from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, TEXT +from sqlalchemy.orm import Session + +from app.config.const import Dialog_STATSU_DELETE +from app.models.base_model import Base +from app.utils.common import current_time class RetrievalSetting(BaseModel): @@ -12,5 +21,228 @@ query: str retrieval_setting: RetrievalSetting +class ChatDataRequest(BaseModel): + sessionId: str + query: str + chatMode: Optional[int] = 1 # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害 + isDeep: Optional[int] = 1 # 1= 鏅��, 2=娣卞害 + knowledgeId: Optional[list] = [] + files: Optional[list] = [] + def to_dict(self): + return { + "sessionId": self.sessionId, + "query": self.query, + "chatMode": self.chatMode, + "knowledgeId": self.knowledgeId, + "files": self.files, + } + + + + +class ComplexChatModel(Base): + __tablename__ = 'complex_chat' + __mapper_args__ = { + # "order_by": 'SEQ' + } + id = Column(String(36), primary_key=True) # id + create_date = Column(DateTime, default=datetime.now()) # 鍒涘缓鏃堕棿 + update_date = Column(DateTime, default=datetime.now(), onupdate=datetime.now()) # 鏇存柊鏃堕棿 + tenant_id = Column(String(36)) # 鍒涘缓浜� + name = Column(String(255)) # 鍚嶇О + description = Column(Text) # 璇存槑 + icon = Column(Text, default="intelligentFrame1") # 鍥炬爣 + status = Column(String(1), default="1") # 鐘舵�� + dialog_type = Column(String(1)) # 骞冲彴 + mode = Column(String(36)) + parameters = Column(Text) + chat_mode = Column(Integer) #1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害 + + def to_json(self): + return { + 'id': self.id, + 'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'), + 'update_date': self.update_date.strftime('%Y-%m-%d %H:%M:%S'), + 'user_id': self.tenant_id, + 'name': self.name, + 'description': self.description, + 'icon': self.icon, + 'status': self.status, + 'agentType': self.dialog_type, + 'mode': self.mode, + } + +class ComplexChatDao: + def __init__(self, db: Session): + self.db = db + + async def create_complex_chat(self, chat_id: str, **kwargs) -> ComplexChatModel: + new_session = ComplexChatModel( + id=chat_id, + create_date=current_time(), + update_date=current_time(), + **kwargs + ) + self.db.add(new_session) + self.db.commit() + self.db.refresh(new_session) + return new_session + + async def get_complex_chat_by_id(self, chat_id: str) -> ComplexChatModel | None: + session = self.db.query(ComplexChatModel).filter_by(id=chat_id).first() + return session + + async def update_complex_chat_by_id(self, chat_id: str, session, message: dict, conversation_id=None) -> ComplexChatModel | None: + if not session: + session = await self.get_complex_chat_by_id(chat_id) + if session: + try: + # TODO + session.update_date = current_time() + self.db.commit() + self.db.refresh(session) + except Exception as e: + # logger.error(e) + self.db.rollback() + return session + + async def update_or_insert_by_id(self, chat_id: str, **kwargs) -> ComplexChatModel: + existing_session = await self.get_complex_chat_by_id(chat_id) + if existing_session: + return await self.update_complex_chat_by_id(chat_id, existing_session, kwargs.get("message")) + + existing_session = await self.create_complex_chat(chat_id, **kwargs) + return existing_session + + async def delete_complex_chat(self, chat_id: str) -> None: + session = await self.get_complex_chat_by_id(chat_id) + if session: + self.db.delete(session) + self.db.commit() + + async def aget_complex_chat_ids(self) -> List: + session_list = self.db.query(ComplexChatModel).filter(ComplexChatModel.status!=Dialog_STATSU_DELETE).all() + + return [i.id for i in session_list] + + def get_complex_chat_ids(self) -> List: + session_list = self.db.query(ComplexChatModel).filter(ComplexChatModel.status!=Dialog_STATSU_DELETE).all() + + return [i.id for i in session_list] + + async def get_complex_chat_by_mode(self, chat_mode: int) -> ComplexChatModel | None: + session = self.db.query(ComplexChatModel).filter(ComplexChatModel.chat_mode==chat_mode, ComplexChatModel.status!=Dialog_STATSU_DELETE).first() + return session + + + +class ComplexChatSessionModel(Base): + __tablename__ = "complex_chat_sessions" + + + id = Column(String(36), primary_key=True) + chat_id = Column(String(36)) + session_id = Column(String(36), index=True) + create_date = Column(DateTime, default=current_time, index=True) # 鍒涘缓鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿 + update_date = Column(DateTime, default=current_time, onupdate=current_time) # 鏇存柊鏃堕棿锛岄粯璁ゅ�间负褰撳墠鏃跺尯鏃堕棿锛屾洿鏂版椂鑷姩鏇存柊 + tenant_id = Column(Integer, index=True) # 鍒涘缓浜� + agent_type = Column(Integer) # 1=rg锛� 3=basic锛�4=df + message_type = Column(Integer) # 1=鐢ㄦ埛锛�2=鏈哄櫒浜猴紝3=绯荤粺 + content = Column(TEXT) + mindmap = Column(TEXT) + query = Column(TEXT) + node_data = Column(TEXT) + event_type = Column(String(16)) + conversation_id = Column(String(36)) + chat_mode = Column(Integer) # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害 + + # to_dict 鏂规硶 + def to_dict(self): + return { + 'session_id': self.id, + 'name': self.name, + 'agent_type': self.agent_type, + 'chat_id': self.agent_id, + 'event_type': self.event_type, + 'session_type': self.session_type if self.session_type else 0, + 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"), + 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"), + } + + def log_to_json(self): + return { + 'id': self.id, + 'name': self.name, + 'agent_type': self.agent_type, + 'chat_id': self.agent_id, + 'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"), + 'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"), + 'message': json.loads(self.message) + } + + +class ComplexChatSessionDao: + def __init__(self, db: Session): + self.db = db + + async def get_session_by_session_id(self, session_id: str, chat_id:str) -> ComplexChatSessionModel | None: + session = self.db.query(ComplexChatSessionModel).filter_by(chat_id=chat_id, session_id=session_id, message_type=2).first() + return session + + async def create_session(self, message_id: str, **kwargs) -> ComplexChatSessionModel: + new_session = ComplexChatSessionModel( + id=message_id, + create_date=current_time(), + update_date=current_time(), + **kwargs + ) + self.db.add(new_session) + self.db.commit() + self.db.refresh(new_session) + return new_session + + async def get_session_by_id(self, message_id: str) -> ComplexChatSessionModel | None: + session = self.db.query(ComplexChatSessionModel).filter_by(id=message_id).first() + return session + + async def update_mindmap_by_id(self, message_id: str, mindmap:str) -> ComplexChatSessionModel | None: + # print(message) + + session = await self.get_session_by_id(message_id) + if session: + try: + + session.mindmap = mindmap + session.update_date = current_time() + self.db.commit() + self.db.refresh(session) + except Exception as e: + # logger.error(e) + self.db.rollback() + return session + + async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ComplexChatSessionModel: + existing_session = await self.get_session_by_id(session_id) + if existing_session: + return await self.update_session_by_id(session_id, existing_session, kwargs.get("message")) + + existing_session = await self.create_session(session_id, **kwargs) + return existing_session + + async def delete_session(self, session_id: str) -> None: + session = await self.get_session_by_id(session_id) + if session: + self.db.delete(session) + self.db.commit() + + async def get_session_list(self, user_id: int, agent_id: str, keyword:str, page: int, page_size: int) -> any: + query = self.db.query(ComplexChatSessionModel).filter(ComplexChatSessionModel.tenant_id==user_id) + if agent_id: + query = query.filter(ComplexChatSessionModel.agent_id==agent_id) + if keyword: + query = query.filter(ComplexChatSessionModel.name.like('%{}%'.format(keyword))) + total = query.count() + session_list = query.order_by(ComplexChatSessionModel.update_date.desc()).offset((page-1)*page_size).limit(page_size).all() + return total, session_list \ No newline at end of file diff --git a/app/models/v2/mindmap.py b/app/models/v2/mindmap.py index cead556..743bb5f 100644 --- a/app/models/v2/mindmap.py +++ b/app/models/v2/mindmap.py @@ -1,12 +1,16 @@ import json -from typing import Optional, Type, List +from datetime import datetime +from typing import List + from pydantic import BaseModel +from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text +from sqlalchemy.orm import Session +from app.config.const import Dialog_STATSU_DELETE +from app.models.base_model import Base +from app.utils.common import current_time +class MindmapRequest(BaseModel): + messageId: str + query:str - -class ChatData(BaseModel): - sessionId: Optional[str] = "" - - class Config: - extra = 'allow' # 鍏佽鍏朵粬鍔ㄦ�佸瓧娈� \ No newline at end of file diff --git a/app/service/v2/app_driver/chat_agent.py b/app/service/v2/app_driver/chat_agent.py index 45f804e..5fa0bfa 100644 --- a/app/service/v2/app_driver/chat_agent.py +++ b/app/service/v2/app_driver/chat_agent.py @@ -9,6 +9,7 @@ async def chat_completions(self, url, data, headers): complete_response = "" + # print(data) async for line in self.http_stream(url, data, headers): # logger.error(line) if line.startswith("data:"): @@ -46,6 +47,21 @@ "files": files } + @staticmethod + async def complex_request_data(query: str, conversation_id: str, user: str, files: list=None, inputs: dict=None) -> dict: + if not files: + files = [] + if not inputs: + inputs = {} + return { + "inputs": inputs, + "query": query, + "response_mode": "streaming", + "conversation_id": conversation_id, + "user": user, + "files": files + } + if __name__ == "__main__": async def aa(): diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py index 68a7cd3..05a37be 100644 --- a/app/service/v2/chat.py +++ b/app/service/v2/chat.py @@ -1,6 +1,7 @@ import asyncio import io import json +import uuid import fitz from fastapi import HTTPException @@ -10,7 +11,7 @@ DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL from app.config.config import settings from app.config.const import * -from app.models import DialogModel, ApiTokenModel, UserTokenModel +from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest from app.models.v2.session_model import ChatSessionDao, ChatData from app.service.v2.app_driver.chat_agent import ChatAgent from app.service.v2.app_driver.chat_data import ChatBaseApply @@ -47,12 +48,12 @@ logger.error(e) return None + async def get_app_token(db, app_id): app_token = db.query(UserTokenModel).filter_by(id=app_id).first() if app_token: return app_token.access_token return "" - async def get_chat_token(db, app_id): @@ -69,7 +70,6 @@ db.commit() except Exception as e: logger.error(e) - async def get_chat_info(db, chat_id: str): @@ -134,6 +134,7 @@ message["role"] = "assistant" await update_session_log(db, session_id, message, conversation_id) + async def data_process(data): if isinstance(data, str): return data.replace("dify", "smart") @@ -170,8 +171,9 @@ if hasattr(chat_data, "query"): query = chat_data.query else: - query = "start new workflow" - session = await add_session_log(db, session_id,query if query else "start new conversation", chat_id, user_id, mode, conversation_id, 3) + query = "start new conversation" + session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id, + mode, conversation_id, 3) if session: conversation_id = session.conversation_id try: @@ -215,7 +217,7 @@ [workflow_started, node_started, node_finished].index(ans.get("event"))] elif ans.get("event") == workflow_finished: data = ans.get("data", {}) - answer_workflow = data.get("outputs", {}).get("output") + answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) download_url = data.get("outputs", {}).get("download_url") event = smart_workflow_finished if data.get("status") == "failed": @@ -242,8 +244,9 @@ except: ... finally: - await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow or error, - "download_url":download_url, + await update_session_log(db, session_id, {"role": "assistant", + "answer": answer_event or answer_agent or answer_workflow or error, + "download_url": download_url, "node_list": node_list, "task_id": task_id, "id": message_id, "error": error}, conversation_id) @@ -257,6 +260,7 @@ if not chat_info: return {} return chat_info.parameters + async def service_chat_sessions(db, chat_id, name): token = await get_chat_token(db, rg_api_token) @@ -276,14 +280,12 @@ page=current, page_size=page_size ) - return json.dumps({"total":total, "rows": [session.to_dict() for session in session_list]}) - + return json.dumps({"total": total, "rows": [session.to_dict() for session in session_list]}) async def service_chat_session_log(db, session_id): session_log = await ChatSessionDao(db).get_session_by_id(session_id) - return json.dumps(session_log.log_to_json()) - + return json.dumps(session_log.log_to_json() if session_log else {}) async def service_chat_upload(db, chat_id, file, user_id): @@ -316,6 +318,7 @@ tokens = tokenizer.encode(input_str) return len(tokens) + async def read_pdf(pdf_stream): text = "" with fitz.open(stream=pdf_stream, filetype="pdf") as pdf_document: @@ -335,6 +338,7 @@ return text + async def read_file(file, filename, content_type): text = "" if content_type == "application/pdf" or filename.endswith('.pdf'): @@ -349,7 +353,7 @@ async def service_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key): - print(query) + # print(query) try: request_data = json.loads(query) @@ -357,7 +361,7 @@ "question": request_data.get("query", ""), "dataset_ids": request_data.get("dataset_ids", []), "page_size": top_k, - "similarity_threshold": similarity_threshold + "similarity_threshold": similarity_threshold if similarity_threshold else 0.2 } except json.JSONDecodeError as e: fixed_json = query.replace("'", '"') @@ -367,15 +371,16 @@ "question": request_data.get("query", ""), "dataset_ids": request_data.get("dataset_ids", []), "page_size": top_k, - "similarity_threshold": similarity_threshold + "similarity_threshold": similarity_threshold if similarity_threshold else 0.2 } except Exception: payload = { - "question":query, - "dataset_ids":[knowledge_id], + "question": query, + "dataset_ids": [knowledge_id], "page_size": top_k, - "similarity_threshold": similarity_threshold + "similarity_threshold": similarity_threshold if similarity_threshold else 0.2 } + # print(payload) url = settings.fwr_base_url + RG_ORIGINAL_URL chat = ChatBaseApply() response = await chat.chat_post(url, payload, await chat.get_headers(api_key)) @@ -388,11 +393,15 @@ "title": chunk.get("document_keyword", "Unknown Document"), "metadata": {"document_id": chunk["document_id"], "path": f"{settings.fwr_base_url}/document/{chunk['document_id']}?ext={chunk.get('document_keyword').split('.')[-1]}&prefix=document", - 'highlight': chunk.get("highlight") , "image_id": chunk.get("image_id"), "positions": chunk.get("positions"),} + 'highlight': chunk.get("highlight"), "image_id": chunk.get("image_id"), + "positions": chunk.get("positions"), } } for chunk in response.get("data", {}).get("chunks", []) ] + # print(len(records)) + # print(records) return records + async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key): # request_data = json.loads(query) @@ -420,15 +429,170 @@ return records +async def add_complex_log(db, message_id, chat_id, session_id, chat_mode, query, user_id, mode, agent_type, message_type, conversation_id="", node_data=None, query_data=None): + if not node_data: + node_data = [] + if not query_data: + query_data = {} + try: + complex_log = ComplexChatSessionDao(db) + if not conversation_id: + session = await complex_log.get_session_by_session_id(session_id, chat_id) + if session: + conversation_id = session.conversation_id + await complex_log.create_session(message_id, + chat_id=chat_id, + session_id=session_id, + chat_mode=chat_mode, + message_type=message_type, + content=query, + event_type=mode, + tenant_id=user_id, + conversation_id=conversation_id, + node_data=json.dumps(node_data), + query=json.dumps(query_data), + agent_type=agent_type) + return conversation_id, True + + except Exception as e: + logger.error(e) + return conversation_id, False + + + +async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest): + answer_event = "" + answer_agent = "" + answer_workflow = "" + download_url = "" + message_id = "" + task_id = "" + error = "" + files = [] + node_list = [] + token = await get_chat_token(db, chat_id) + chat, url = await get_chat_object(mode) + conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict()) + if not message: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, + ensure_ascii=False) + "\n\n" + return + inputs = {"is_deep": chat_request.isDeep} + if chat_request.chatMode == complex_knowledge_chat: + inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) + + try: + async for ans in chat.chat_completions(url, + await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs), + await chat.get_headers(token)): + # print(ans) + data = {} + status = http_200 + conversation_id = ans.get("conversation_id") + task_id = ans.get("task_id") + if ans.get("event") == message_error: + error = ans.get("message", "鍙傛暟寮傚父锛�") + status = http_400 + event = smart_message_error + elif ans.get("event") == message_agent: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_agent += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_event: + data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")} + answer_event += ans.get("answer", "") + message_id = ans.get("message_id", "") + event = smart_message_stream + elif ans.get("event") == message_file: + data = {"url": ans.get("url", ""), "id": ans.get("id", ""), + "type": ans.get("type", "")} + files.append(data) + event = smart_message_file + elif ans.get("event") in [workflow_started, node_started, node_finished]: + data = ans.get("data", {}) + data["inputs"] = await data_process(data.get("inputs", {})) + data["outputs"] = await data_process(data.get("outputs", {})) + data["files"] = await data_process(data.get("files", [])) + data["process_data"] = "" + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) + event = [smart_workflow_started, smart_node_started, smart_node_finished][ + [workflow_started, node_started, node_finished].index(ans.get("event"))] + elif ans.get("event") == workflow_finished: + data = ans.get("data", {}) + answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer")) + download_url = data.get("outputs", {}).get("download_url") + event = smart_workflow_finished + if data.get("status") == "failed": + status = http_500 + error = data.get("error", "") + node_list.append(ans) + + elif ans.get("event") == message_end: + event = smart_message_end + else: + continue + + yield "data: " + json.dumps( + {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, + "session_id": chat_request.sessionId}, + ensure_ascii=False) + "\n\n" + + except Exception as e: + logger.error(e) + try: + yield "data: " + json.dumps({"message": smart_message_error, + "error": "\n**ERROR**: " + str(e), "status": http_500}, + ensure_ascii=False) + "\n\n" + except: + ... + finally: + # await update_session_log(db, session_id, {"role": "assistant", + # "answer": answer_event or answer_agent or answer_workflow or error, + # "download_url": download_url, + # "node_list": node_list, "task_id": task_id, "id": message_id, + # "error": error}, conversation_id) + if message_id: + await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict()) + +async def service_complex_upload(db, chat_id, file, user_id): + files = [] + token = await get_chat_token(db, chat_id) + if not token: + return files + url = settings.dify_base_url + DF_UPLOAD_FILE + chat = ChatBaseApply() + for f in file: + try: + file_content = await f.read() + file_upload = await chat.chat_upload(url, {"file": (f.filename, file_content)}, {"user": str(user_id)}, + {'Authorization': f'Bearer {token}'}) + # try: + # tokens = await read_file(file_content, f.filename, f.content_type) + # file_upload["tokens"] = tokens + # except: + # ... + files.append(file_upload) + except Exception as e: + logger.error(e) + return json.dumps(files) if files else "" if __name__ == "__main__": q = json.dumps({"query": "璁惧", "dataset_ids": ["fc68db52f43111efb94a0242ac120004"]}) top_k = 2 similarity_threshold = 0.5 api_key = "ragflow-Y4MGYwY2JlZjM2YjExZWY4ZWU5MDI0Mm" + + # a = service_chunk_retrieval(q, top_k, similarity_threshold, api_key) # print(a) async def a(): b = await service_chunk_retrieval(q, top_k, similarity_threshold, api_key) print(b) - asyncio.run(a()) \ No newline at end of file + + + asyncio.run(a()) diff --git a/app/service/v2/mindmap.py b/app/service/v2/mindmap.py new file mode 100644 index 0000000..dd5b994 --- /dev/null +++ b/app/service/v2/mindmap.py @@ -0,0 +1,198 @@ +import json +from Log import logger +from app.config.agent_base_url import DF_CHAT_AGENT +from app.config.config import settings +from app.config.const import message_error, message_event, complex_knowledge_chat +from app.models import ComplexChatSessionDao, ChatData +from app.service.v2.app_driver.chat_agent import ChatAgent +from app.service.v2.chat import get_chat_token + + +async def service_chat_mindmap_v1(db, message_id, message, mindmap_chat_id, user_id): + res = {} + mindmap_query = "" + complex_log = ComplexChatSessionDao(db) + session = await complex_log.get_session_by_id(message_id) + if session: + token = await get_chat_token(db, session.chat_id) + chat = ChatAgent() + url = settings.dify_base_url + DF_CHAT_AGENT + if session.mindmap: + chat_request = json.loads(session.query) + try: + async for ans in chat.chat_completions(url, + await chat.request_data(message, session.conversation_id, + str(user_id), ChatData(), chat_request.get("files", [])), + await chat.get_headers(token)): + if ans.get("event") == message_error: + return res + elif ans.get("event") == message_event: + mindmap_query += ans.get("answer", "") + else: + continue + + except Exception as e: + logger.error(e) + return res + else: + mindmap_query = session.content + + try: + mindmap_str = "" + token = await get_chat_token(db, mindmap_chat_id) + async for ans in chat.chat_completions(url, + await chat.request_data(mindmap_query, "", + str(user_id), ChatData()), + await chat.get_headers(token)): + if ans.get("event") == message_error: + return res + elif ans.get("event") == message_event: + mindmap_str += ans.get("answer", "") + else: + continue + + except Exception as e: + logger.error(e) + return res + mindmap_list = mindmap_str.split("```") + mindmap_str = mindmap_list[1].lstrip("markdown\n") + if session.mindmap: + node_list = await mindmap_to_merge(session.mindmap, mindmap_str, f"- {message}") + mindmap_str = "\n".join(node_list) + res["mindmap"] = mindmap_str + await complex_log.update_mindmap_by_id(message_id, mindmap_str) + return res + + +async def service_chat_mindmap(db, message_id, message, mindmap_chat_id, user_id): + res = {} + mindmap_query = "" + complex_log = ComplexChatSessionDao(db) + session = await complex_log.get_session_by_id(message_id) + if session: + token = await get_chat_token(db, session.chat_id) + chat = ChatAgent() + url = settings.dify_base_url + DF_CHAT_AGENT + if session.mindmap: + chat_request = json.loads(session.query) + inputs = {"is_deep": chat_request.get("isDeep", 1)} + if session.chat_mode == complex_knowledge_chat: + inputs["query_json"] = json.dumps( + {"query": chat_request.get("query", ""), "dataset_ids": chat_request.get("knowledgeId", [])}) + try: + async for ans in chat.chat_completions(url, + await chat.complex_request_data(message, session.conversation_id, + str(user_id), files=chat_request.get("files", []), inputs=inputs), + await chat.get_headers(token)): + if ans.get("event") == message_error: + return res + elif ans.get("event") == message_event: + mindmap_query += ans.get("answer", "") + else: + continue + + except Exception as e: + logger.error(e) + return res + else: + mindmap_query = session.content + + try: + mindmap_str = "" + token = await get_chat_token(db, mindmap_chat_id) + async for ans in chat.chat_completions(url, + await chat.complex_request_data(mindmap_query, "", + str(user_id)), + await chat.get_headers(token)): + if ans.get("event") == message_error: + return res + elif ans.get("event") == message_event: + mindmap_str += ans.get("answer", "") + else: + continue + + except Exception as e: + logger.error(e) + return res + if "```json" in mindmap_str: + mindmap_list = mindmap_str.split("```") + mindmap_str = mindmap_list[1].lstrip("json") + mindmap_str = mindmap_str.replace("\n", "") + if session.mindmap: + mindmap_str = await mindmap_merge_dict(session.mindmap, mindmap_str, message) + + try: + res_str = await mindmap_join_str(mindmap_str) + res["mindmap"] = res_str + except Exception as e: + logger.error(e) + return res + await complex_log.update_mindmap_by_id(message_id, mindmap_str) + return res + +async def mindmap_merge_dict(parent, child, target_node): + parent_dict = json.loads(parent) + if child: + child_dict = json.loads(child) + def iter_dict(node): + if "items" not in node: + if node["title"] == target_node: + node["items"] = child_dict["items"] + return + else: + for i in node["items"]: + iter_dict(i) + iter_dict(parent_dict) + + return json.dumps(parent_dict) + + +async def mindmap_join_str(mindmap_json): + try: + parent_dict = json.loads(mindmap_json) + except Exception as e: + logger.error(e) + return "" + def join_node(node, level): + mindmap_str = "" + if level <= 2: + mindmap_str += f"{'#'*level} {node['title']}\n" + else: + mindmap_str += f"{' '*(level-3)*2}- {node['title']}\n" + for i in node.get("items", []): + mindmap_str += join_node(i, level+1) + return mindmap_str + + return join_node(parent_dict, 1) + +async def mindmap_to_merge(parent, child, target_node): + level = 0 + index = 0 + new_node_list = [] + parent_list= parent.split("\n") + child_list= child.split("\n") + child_list[0] = target_node + for i, node in enumerate(parent_list): + if node.endswith(target_node): + level = len(node) - len(target_node) + index = i + break + tmp_level = 0 + for child in child_list: + if "#" in child: + childs = child.split("#") + tmp_level = len(childs) - 2 + new_node_list.append(" "*(level+tmp_level)+ "-"+childs[-1]) + elif len(child) == 0: + continue + else: + new_node_list.append(" "*(level+tmp_level)+child) + + return parent_list[:index]+new_node_list+parent_list[index+1:] + + + +if __name__ == '__main__': + a = '{ "title": "鍏ㄧ敓鍛藉懆鏈熺鐞�", "items": [ { "title": "璁惧瑙勫垝涓庨噰璐�", "items": [ { "title": "闇�姹傚垎鏋愪笌閫夊瀷" ,"items": [{"title": "rererer"}, {"title": "trtrtrtrt"}] }, { "title": "渚涘簲鍟嗛�夋嫨涓庡悎鍚岀鐞�" } ] }, { "title": "璁惧瀹夎涓庤皟璇�", "items": [ { "title": "瀹夎瑙勮寖" }, { "title": "璋冭瘯娴嬭瘯" } ] }, { "title": "璁惧浣跨敤", "items": [ { "title": "鎿嶄綔鍩硅" }, { "title": "鎿嶄綔瑙勭▼涓庤褰�" } ] }, { "title": "璁惧缁存姢涓庣淮淇�", "items": [ { "title": "瀹氭湡缁存姢" }, { "title": "鏁呴殰璇婃柇" }, { "title": "澶囦欢绠$悊" } ] }, { "title": "璁惧鏇存柊涓庢敼閫�", "items": [ { "title": "鎶�鏈瘎浼�" }, { "title": "鏇存柊璁″垝" }, { "title": "鏀归�犳柟妗�" } ] }, { "title": "璁惧鎶ュ簾", "items": [ { "title": "鎶ュ簾璇勪及" }, { "title": "鎶ュ簾澶勭悊" } ] }, { "title": "淇℃伅鍖栫鐞�", "items": [ { "title": "璁惧绠$悊绯荤粺" }, { "title": "鏁版嵁鍒嗘瀽" }, { "title": "杩滅▼鐩戞帶" } ] }, { "title": "瀹夊叏绠$悊", "items": [ { "title": "瀹夊叏鍩硅" }, { "title": "瀹夊叏妫�鏌�" }, { "title": "搴旀�ラ妗�" } ] }, { "title": "鐜淇濇姢", "items": [ { "title": "鐜繚璁惧" }, { "title": "搴熺墿澶勭悊" }, { "title": "鑺傝兘鍑忔帓" } ] }, { "title": "鍏蜂綋瀹炶返妗堜緥", "items": [ { "title": "楂樺帇寮�鍏宠澶囨鼎婊戣剛閫夌敤鐮旂┒" }, { "title": "鐜繚鍨� C4 娣锋皵 GIS 璁惧杩愮淮鎶�鏈爺绌�" } ] }, { "title": "鎬荤粨", "items": [ { "title": "鎻愰珮杩愯惀鏁堢巼鍜岀珵浜夊姏" } ] } ]}' + b = mindmap_merge_dict(a, {}, "璁惧瑙勫垝涓庨噰璐�") + print(b) \ No newline at end of file diff --git a/app/task/fetch_agent.py b/app/task/fetch_agent.py index a5a7bfb..8ad5215 100644 --- a/app/task/fetch_agent.py +++ b/app/task/fetch_agent.py @@ -9,7 +9,7 @@ from app.config.config import settings from app.config.const import RAGFLOW, BISHENG, DIFY, ENV_CONF_PATH, Dialog_STATSU_DELETE, Dialog_STATSU_ON -from app.models import KnowledgeModel +from app.models import KnowledgeModel, ComplexChatDao from app.models.dialog_model import DialogModel from app.models.user_model import UserAppModel from app.models.agent_model import AgentModel @@ -239,7 +239,7 @@ db.close() -def get_data_from_ragflow_v2(names: List[str], tenant_id) -> List[Dict]: +def get_data_from_ragflow_v2(base_db, names: List[str], tenant_id) -> List[Dict]: db = SessionRagflow() para = { "user_input_form": [], @@ -251,6 +251,8 @@ } } try: + chat_ids = ComplexChatDao(base_db).get_complex_chat_ids() + # print(chat_ids) if names: query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id) \ .filter(Dialog.name.in_(names), Dialog.status == "1") @@ -261,15 +263,17 @@ results = query.all() formatted_results = [ {"id": row[0], "name": row[1], "description": row[2], "status": "1" if row[3] == "1" else "2", - "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results] + "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results if row[0] not in chat_ids] return formatted_results finally: db.close() -def get_data_from_dy_v2(names: List[str]) -> List[Dict]: +def get_data_from_dy_v2(base_db, names: List[str]) -> List[Dict]: db = SessionDify() try: + chat_ids = ComplexChatDao(base_db).get_complex_chat_ids() + # print(chat_ids) if names: query = db.query(DfApps.id, DfApps.name, DfApps.description, DfApps.status, DfApps.tenant_id, DfApps.mode) \ .filter(DfApps.name.in_(names)) @@ -279,7 +283,7 @@ results = query.all() formatted_results = [ {"id": str(row[0]), "name": row[1], "description": row[2], "status": "1", - "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results] + "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results if str(row[0]) not in chat_ids] return formatted_results finally: db.close() @@ -342,11 +346,11 @@ for app in app_register: try: if app["id"] == RAGFLOW: - ragflow_data = get_data_from_ragflow_v2([], app["name"]) + ragflow_data = get_data_from_ragflow_v2(db, [], app["name"]) if ragflow_data: update_ids_in_local_v2(ragflow_data, "1") elif app["id"] == DIFY: - dify_data = get_data_from_dy_v2([]) + dify_data = get_data_from_dy_v2(db, []) if dify_data: update_ids_in_local_v2(dify_data, "4") except Exception as e: diff --git a/app/utils/common.py b/app/utils/common.py new file mode 100644 index 0000000..e3ac4a6 --- /dev/null +++ b/app/utils/common.py @@ -0,0 +1,6 @@ +import pytz +from datetime import datetime + +def current_time(): + tz = pytz.timezone('Asia/Shanghai') + return datetime.now(tz) diff --git a/main.py b/main.py index d0ba8d6..87597e7 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from app.api.organization import dept_router from app.api.system import system_router from app.api.v2.chat import chat_router_v2 +from app.api.v2.mindmap import mind_map_router from app.api.v2.public_api import public_api from app.api.report import router as report_router from app.api.resource import menu_router @@ -77,6 +78,7 @@ app.include_router(public_api, prefix='/v1/api', tags=["public_api"]) app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"]) app.include_router(system_router, prefix='/api/system', tags=["system"]) +app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"]) app.mount("/static", StaticFiles(directory="app/images"), name="static") if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 4d645ea..5ba14b3 100644 --- a/requirements.txt +++ b/requirements.txt Binary files differ -- Gitblit v1.8.0