zhaoqingang
2025-03-06 e26a7859a8900b152e10961d91fa6ad19a8deb9c
首页通用对话增加
3个文件已添加
11个文件已修改
801 ■■■■■ 已修改文件
app/api/v2/chat.py 55 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/v2/mindmap.py 31 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/const.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/env_conf/admin.yaml 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/__init__.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/v2/chat.py 234 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/v2/mindmap.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/app_driver/chat_agent.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/chat.py 204 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/mindmap.py 198 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/task/fetch_agent.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/utils/common.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
requirements.txt 补丁 | 查看 | 原始文档 | blame | 历史
app/api/v2/chat.py
@@ -1,22 +1,20 @@
import json
import uuid
from typing import List
from typing import List
from fastapi import Depends, APIRouter, File, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse, Response
from werkzeug.http import HTTP_STATUS_CODES
from app.api import get_current_user, get_api_key
from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \
    smart_message_error, http_400, http_500, http_200
from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat
from app.models import UserModel
from app.models.base_model import get_db
from app.models.v2.chat import RetrievalRequest
from app.models.v2.chat import RetrievalRequest, ChatDataRequest, ComplexChatDao
from app.models.v2.session_model import ChatData
from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_basic, \
    service_chat_workflow, service_chat_parameters, service_chat_sessions, service_chat_upload, \
    service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_base_chunk_retrieval
    service_chat_sessions_list, service_chat_session_log, service_chunk_retrieval, service_complex_chat, \
    service_complex_upload
chat_router_v2 = APIRouter()
@@ -37,7 +35,7 @@
                                 media_type="text/event-stream")
    if not session_id:
        session = await service_chat_sessions(db, chatId, dialog.query)
        print(session)
        # print(session)
        if not session or session.get("code") != 0:
            error_msg = json.dumps(
                {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500})
@@ -77,7 +75,7 @@
                             media_type="text/event-stream")
@chat_router_v2.post("/complex/{chatId}/completions")
@chat_router_v2.post("/develop/{chatId}/completions")
async def api_chat_dialog(chatId:str, dialog: ChatData, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): #  current_user: UserModel = Depends(get_current_user)
    chat_info = await get_chat_info(db, chatId)
    if not chat_info:
@@ -128,16 +126,39 @@
    return Response(data, media_type="application/json", status_code=http_200)
# @chat_router_v2.post("/conversation/mindmap")
# async def api_conversation_mindmap(chatId:str, current:int=1, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)): #  current_user: UserModel = Depends(get_current_user)
#     data = await service_chat_sessions_list(db, chatId, current, pageSize, current_user.id, keyword)
#     return Response(data, media_type="application/json", status_code=http_200)
@chat_router_v2.post("/retrieval")
async def retrieve_chunks(request_data: RetrievalRequest, api_key: str = Depends(get_api_key)):
    """Return the chunk records retrieved for ``request_data.query`` as ``{"records": [...]}``.

    ``top_k`` and ``score_threshold`` come from ``request_data.retrieval_setting``;
    ``api_key`` is resolved by the ``get_api_key`` dependency and forwarded as-is.
    """
    setting = request_data.retrieval_setting
    found = await service_chunk_retrieval(
        request_data.query,
        request_data.knowledge_id,
        setting.top_k,
        setting.score_threshold,
        api_key,
    )
    return {"records": found}
@chat_router_v2.post("/complex/chat/completions")
async def api_complex_chat_completions(chat: ChatDataRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Stream a home-page ("complex") chat answer as ``text/event-stream``.

    The backing chat is looked up by ``chat.chatMode``. When none is configured a
    single SSE error event is returned. A new session id (32 hex chars) is
    generated when the request does not carry one.
    """
    backing_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chat.chatMode)
    if not backing_chat:
        error_msg = json.dumps(
            {"message": smart_message_error, "error": "\n**ERROR**: 网络异常,无法生成对话结果!", "status": http_500})
        return StreamingResponse(f"data: {error_msg}\n\n",
                                 media_type="text/event-stream")
    if not chat.sessionId:
        # uuid4().hex is the dash-free form of str(uuid4()).
        chat.sessionId = uuid.uuid4().hex
    stream = service_complex_chat(db, backing_chat.id, backing_chat.mode, current_user.id, chat)
    return StreamingResponse(stream, media_type="text/event-stream")
@chat_router_v2.post("/complex/upload/{chatMode}")
async def api_complex_upload(chatMode: int, file: List[UploadFile] = File(...), current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Upload files for the complex chat configured for ``chatMode``.

    Status codes:
    - 500 with ``{}`` when no chat is configured for the mode;
    - 400 with ``{}`` when the upload service returns nothing;
    - 200 with the service's JSON payload otherwise.
    """
    backing_chat = await ComplexChatDao(db).get_complex_chat_by_mode(chatMode)
    if not backing_chat:
        return Response("{}", media_type="application/json", status_code=http_500)
    payload = await service_complex_upload(db, backing_chat.id, file, current_user.id)
    if not payload:
        return Response("{}", media_type="application/json", status_code=http_400)
    return Response(payload, media_type="application/json", status_code=http_200)
app/api/v2/mindmap.py
New file
@@ -0,0 +1,31 @@
import json
import uuid
from typing import List
from fastapi import Depends, APIRouter, File, UploadFile
from sqlalchemy.orm import Session
from werkzeug.http import HTTP_STATUS_CODES
from app.api import get_current_user, get_api_key, Response
from app.config.const import dialog_chat, advanced_chat, base_chat, agent_chat, workflow_chat, basic_chat, \
    smart_message_error, http_400, http_500, http_200, complex_mindmap_chat
from app.models import UserModel
from app.models.base_model import get_db
from app.models.v2.chat import RetrievalRequest, ComplexChatDao
from app.models.v2.mindmap import MindmapRequest
from app.models.v2.session_model import ChatData
from app.service.v2.mindmap import service_chat_mindmap
mind_map_router = APIRouter()
@mind_map_router.post("/create", response_model=Response)
async def api_chat_mindmap(mindmap: MindmapRequest, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Create a mind map for an existing chat message.

    Uses the complex chat registered under ``complex_mindmap_chat``. The result is
    wrapped in the project ``Response`` envelope with ``code`` 200 on success and
    500 on failure.
    """
    mindmap_chat = await ComplexChatDao(db).get_complex_chat_by_mode(complex_mindmap_chat)
    if not mindmap_chat:
        return Response(code=500, msg="网络异常!failure", data={})
    created = await service_chat_mindmap(db, mindmap.messageId, mindmap.query, mindmap_chat.id, current_user.id)
    if not created:
        return Response(code=500, msg="create failure", data={})
    return Response(code=200, msg="create success", data=created)
app/config/const.py
@@ -113,3 +113,10 @@
###-------------------------------system-------------------------------------------------
SYSTEM_ID = 1
### --------------------------------complex mode----------------------------------------------
# Home-page ("complex") chat modes; matched against the `chat_mode` column of `complex_chat`.
complex_dialog_chat = 1  # document and basic chat
complex_network_chat = 2  # web-connected chat
complex_knowledge_chat = 3  # knowledge-base chat
# complex_deep_chat = 4  # deep chat (not defined at the moment)
complex_mindmap_chat = 5  # mind-map generation
app/config/env_conf/admin.yaml
@@ -3,10 +3,10 @@
  password: gAAAAABnvAq8bErFiR9x_ZcODjUeOdrDo8Z5UVOzyqo6SxIhAvLpw81kciQN0frwIFVfY9wrxH1WqrpTICpEwfH7r2SkLjS7SQ==
chat_server:
  id: fe24dd2c9be611ef92880242ac160006
  account: user@example.com
  password: gAAAAABnvs3e3fZOYfUUAJ6uT80dkhNeN7rhylzZErTWRZThNSLzMbZGetPCe9A2BJ86V0nZBLMNNu8w6rWp4dC7JxYxByJcow==
  id: 2c039666c29d11efa4670242ac1b0006
  account: zhao1@example.com
  password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ==
workflow_server:
  account: basic@mail.com
  password: gAAAAABnvs5i7xUn9pb2szCozJciGSiWPGv80PH_2HFFzNM2r1ZLTOQqftnUso_bvchtmwAmccfNrf53sf9_WMFVTc0hjTKRRQ==
  account: admin@basic.com
  password: gAAAAABnpFLtotY2OIRH12BJh4MzMgn5Zil7-DVpIeuqlFwvr0g6g_n4ULogn-LNhCbtk6cCDkzZlqAHvBSX2e_zf7AsoyzbiQ==
app/models/__init__.py
@@ -17,6 +17,8 @@
from .menu_model import *
from .label_model import *
from .v2.session_model import *
from .v2.chat import *
from .v2.mindmap import *
from .system import *
app/models/v2/chat.py
@@ -1,5 +1,14 @@
from pydantic import BaseModel
import json
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, TEXT
from sqlalchemy.orm import Session
from app.config.const import Dialog_STATSU_DELETE
from app.models.base_model import Base
from app.utils.common import current_time
class RetrievalSetting(BaseModel):
@@ -12,5 +21,228 @@
    query: str
    retrieval_setting: RetrievalSetting
class ChatDataRequest(BaseModel):
    """Request body of a home-page ("complex") chat turn.

    ``to_dict`` is the snapshot stored alongside the logged messages.
    """
    # Optional: the completions endpoint generates a new id when it is empty,
    # so clients may omit it (previously a missing value was rejected with 422).
    sessionId: Optional[str] = ""
    query: str
    chatMode: Optional[int] = 1  # 1 = normal chat, 2 = web-connected, 3 = knowledge base, 4 = deep
    isDeep: Optional[int] = 1  # 1 = normal, 2 = deep
    # pydantic copies field defaults per instance, so these [] literals are not shared.
    knowledgeId: Optional[list] = []
    files: Optional[list] = []

    def to_dict(self):
        """Return the request fields as a plain dict (used as the logged query snapshot)."""
        return {
            "sessionId": self.sessionId,
            "query": self.query,
            "chatMode": self.chatMode,
            "isDeep": self.isDeep,
            "knowledgeId": self.knowledgeId,
            "files": self.files,
        }
class ComplexChatModel(Base):
    """ORM row describing one configured home-page ("complex") chat backend."""
    __tablename__ = 'complex_chat'
    __mapper_args__ = {
        # "order_by": 'SEQ'
    }
    id = Column(String(36), primary_key=True)  # id
    # Pass the callable, not its result: `datetime.now()` would be evaluated once
    # at import time and every row would get that same stale timestamp.
    create_date = Column(DateTime, default=datetime.now)  # creation time
    update_date = Column(DateTime, default=datetime.now, onupdate=datetime.now)  # last update time
    tenant_id = Column(String(36))  # creator
    name = Column(String(255))  # name
    description = Column(Text)  # description
    icon = Column(Text, default="intelligentFrame1")  # icon
    status = Column(String(1), default="1")  # status
    dialog_type = Column(String(1))  # platform
    mode = Column(String(36))
    parameters = Column(Text)
    chat_mode = Column(Integer)  # 1 = normal chat, 2 = web-connected, 3 = knowledge base, 4 = deep

    def to_json(self):
        """Serialize the row for API responses (``tenant_id`` is exposed as ``user_id``)."""
        return {
            'id': self.id,
            'create_date': self.create_date.strftime('%Y-%m-%d %H:%M:%S'),
            'update_date': self.update_date.strftime('%Y-%m-%d %H:%M:%S'),
            'user_id': self.tenant_id,
            'name': self.name,
            'description': self.description,
            'icon': self.icon,
            'status': self.status,
            'agentType': self.dialog_type,
            'mode': self.mode,
        }
class ComplexChatDao:
    """Data-access object for ``complex_chat`` rows, bound to one SQLAlchemy session."""

    def __init__(self, db: Session):
        self.db = db

    def _active_chats(self):
        """Return every row whose status is not the deleted marker."""
        return self.db.query(ComplexChatModel).filter(ComplexChatModel.status != Dialog_STATSU_DELETE).all()

    async def create_complex_chat(self, chat_id: str, **kwargs) -> ComplexChatModel:
        """Insert and return a new row with primary key ``chat_id``; kwargs fill the other columns."""
        row = ComplexChatModel(id=chat_id, create_date=current_time(), update_date=current_time(), **kwargs)
        db = self.db
        db.add(row)
        db.commit()
        db.refresh(row)
        return row

    async def get_complex_chat_by_id(self, chat_id: str) -> ComplexChatModel | None:
        """Return the row with primary key ``chat_id``, or None."""
        return self.db.query(ComplexChatModel).filter_by(id=chat_id).first()

    async def update_complex_chat_by_id(self, chat_id: str, session, message: dict, conversation_id=None) -> ComplexChatModel | None:
        """Refresh ``update_date`` of a row and return it (None when it does not exist).

        ``session`` may be the already-loaded row; when falsy the row is looked up by
        ``chat_id``. ``message`` and ``conversation_id`` are accepted but not used yet.
        A failed commit is rolled back silently and the row is still returned.
        """
        row = session if session else await self.get_complex_chat_by_id(chat_id)
        if not row:
            return row
        try:
            # TODO: persist `message` / `conversation_id` once the model has columns for them.
            row.update_date = current_time()
            self.db.commit()
            self.db.refresh(row)
        except Exception:
            self.db.rollback()
        return row

    async def update_or_insert_by_id(self, chat_id: str, **kwargs) -> ComplexChatModel:
        """Touch row ``chat_id`` if it exists, otherwise create it from kwargs."""
        existing = await self.get_complex_chat_by_id(chat_id)
        if not existing:
            return await self.create_complex_chat(chat_id, **kwargs)
        return await self.update_complex_chat_by_id(chat_id, existing, kwargs.get("message"))

    async def delete_complex_chat(self, chat_id: str) -> None:
        """Delete row ``chat_id`` if present; a missing row is a no-op."""
        row = await self.get_complex_chat_by_id(chat_id)
        if not row:
            return
        self.db.delete(row)
        self.db.commit()

    async def aget_complex_chat_ids(self) -> List:
        """Async form of ``get_complex_chat_ids``."""
        return [row.id for row in self._active_chats()]

    def get_complex_chat_ids(self) -> List:
        """Return the ids of all rows that are not marked deleted."""
        return [row.id for row in self._active_chats()]

    async def get_complex_chat_by_mode(self, chat_mode: int) -> ComplexChatModel | None:
        """Return the first non-deleted row configured for ``chat_mode``, or None."""
        query = self.db.query(ComplexChatModel)
        return query.filter(ComplexChatModel.chat_mode == chat_mode,
                            ComplexChatModel.status != Dialog_STATSU_DELETE).first()
class ComplexChatSessionModel(Base):
    """ORM row for one message (user, assistant or system) of a home-page chat session."""
    __tablename__ = "complex_chat_sessions"
    id = Column(String(36), primary_key=True)  # message id
    chat_id = Column(String(36))
    session_id = Column(String(36), index=True)
    create_date = Column(DateTime, default=current_time, index=True)  # creation time (current_time() at insert)
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # refreshed on every update
    tenant_id = Column(Integer, index=True)  # creator
    agent_type = Column(Integer)  # 1=rg, 3=basic, 4=df
    message_type = Column(Integer)  # 1=user, 2=bot, 3=system
    content = Column(TEXT)
    mindmap = Column(TEXT)
    query = Column(TEXT)
    node_data = Column(TEXT)
    event_type = Column(String(16))
    conversation_id = Column(String(36))
    chat_mode = Column(Integer)  # 1 = normal chat, 2 = web-connected, 3 = knowledge base, 4 = deep

    @staticmethod
    def _loads(raw, fallback):
        """Decode a JSON text column, returning ``fallback`` when it is empty."""
        return json.loads(raw) if raw else fallback

    def to_dict(self):
        """Summarize the row for list views.

        The previous version read ``name``, ``agent_id`` and ``session_type``, which are
        not columns of this table, so it always raised AttributeError.
        NOTE(review): ``name`` now uses the message text and ``session_type`` is fixed
        at 0 — confirm this mapping with the consumers.
        """
        return {
            'id': self.id,
            'session_id': self.session_id,
            'name': self.content,
            'agent_type': self.agent_type,
            'chat_id': self.chat_id,
            'event_type': self.event_type,
            'session_type': 0,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }

    def log_to_json(self):
        """Serialize the row for a message-log view.

        The previous version decoded a non-existent ``message`` column. ``query`` and
        ``node_data`` are decoded as JSON here — assumes they are written with
        json.dumps by the service layer (TODO confirm no other writer).
        """
        return {
            'id': self.id,
            'session_id': self.session_id,
            'name': self.content,
            'agent_type': self.agent_type,
            'chat_id': self.chat_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': self.content,
            'query': self._loads(self.query, {}),
            'node_data': self._loads(self.node_data, []),
        }
class ComplexChatSessionDao:
    """Data-access object for ``complex_chat_sessions`` rows (one row per chat message)."""

    def __init__(self, db: Session):
        self.db = db

    async def get_session_by_session_id(self, session_id: str, chat_id: str) -> ComplexChatSessionModel | None:
        """Return one assistant message (message_type=2) of ``session_id`` under ``chat_id``, or None.

        No ordering is applied, so which matching row is returned is up to the database.
        """
        session = self.db.query(ComplexChatSessionModel).filter_by(chat_id=chat_id, session_id=session_id, message_type=2).first()
        return session

    async def create_session(self, message_id: str, **kwargs) -> ComplexChatSessionModel:
        """Insert and return a message row with primary key ``message_id``; kwargs fill the other columns."""
        new_session = ComplexChatSessionModel(
            id=message_id,
            create_date=current_time(),
            update_date=current_time(),
            **kwargs
        )
        self.db.add(new_session)
        self.db.commit()
        self.db.refresh(new_session)
        return new_session

    async def get_session_by_id(self, message_id: str) -> ComplexChatSessionModel | None:
        """Return the message row with primary key ``message_id``, or None."""
        session = self.db.query(ComplexChatSessionModel).filter_by(id=message_id).first()
        return session

    async def update_session_by_id(self, message_id: str, session=None, **fields) -> ComplexChatSessionModel | None:
        """Apply ``fields`` to message row ``message_id`` and refresh its ``update_date``.

        ``session`` may be the already-loaded row; when falsy it is looked up. Only keys
        that are columns of the table (other than the primary key) are written; anything
        else, e.g. a legacy ``message`` kwarg, is ignored. Best effort: a failed commit is
        rolled back and logged, and the row is still returned. Returns None if the row
        does not exist.
        """
        import logging

        if not session:
            session = await self.get_session_by_id(message_id)
        if not session:
            return session
        writable = set(ComplexChatSessionModel.__table__.columns.keys()) - {"id"}
        try:
            for key, value in fields.items():
                if key in writable:
                    setattr(session, key, value)
            session.update_date = current_time()
            self.db.commit()
            self.db.refresh(session)
        except Exception:
            self.db.rollback()
            logging.getLogger(__name__).exception("failed to update complex chat message %s", message_id)
        return session

    async def update_mindmap_by_id(self, message_id: str, mindmap: str) -> ComplexChatSessionModel | None:
        """Store ``mindmap`` on message ``message_id`` and return the row (None if it does not exist)."""
        return await self.update_session_by_id(message_id, mindmap=mindmap)

    async def update_or_insert_by_id(self, session_id: str, **kwargs) -> ComplexChatSessionModel:
        """Update message row ``session_id`` (its primary key) with kwargs, or create it if missing.

        Previously this called a method that did not exist and raised AttributeError
        whenever the row was already present.
        """
        existing_session = await self.get_session_by_id(session_id)
        if existing_session:
            return await self.update_session_by_id(session_id, existing_session, **kwargs)
        return await self.create_session(session_id, **kwargs)

    async def delete_session(self, session_id: str) -> None:
        """Delete message row ``session_id`` if present."""
        session = await self.get_session_by_id(session_id)
        if session:
            self.db.delete(session)
            self.db.commit()

    async def get_session_list(self, user_id: int, agent_id: str, keyword: str, page: int, page_size: int) -> tuple[int, list]:
        """Return ``(total, rows)`` of a user's messages, newest update first, paginated.

        ``page`` is 1-based; values below 1 are treated as 1.
        """
        query = self.db.query(ComplexChatSessionModel).filter(ComplexChatSessionModel.tenant_id == user_id)
        if agent_id:
            # The table has no `agent_id` column (the old filter raised AttributeError);
            # the owning chat is stored in `chat_id`.
            query = query.filter(ComplexChatSessionModel.chat_id == agent_id)
        if keyword:
            # No `name` column either: search the message text. autoescape keeps any
            # '%' or '_' typed by the user literal instead of acting as wildcards.
            query = query.filter(ComplexChatSessionModel.content.contains(keyword, autoescape=True))
        total = query.count()
        offset = max(page - 1, 0) * page_size
        session_list = query.order_by(ComplexChatSessionModel.update_date.desc()).offset(offset).limit(page_size).all()
        return total, session_list
app/models/v2/mindmap.py
@@ -1,12 +1,16 @@
import json
from typing import Optional, Type, List
from datetime import datetime
from typing import List
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text
from sqlalchemy.orm import Session
from app.config.const import Dialog_STATSU_DELETE
from app.models.base_model import Base
from app.utils.common import current_time
class MindmapRequest(BaseModel):
    """Body of the mind-map creation request: the source message id and the user query."""

    messageId: str
    query: str
class ChatData(BaseModel):
    """Loosely-typed chat payload: only ``sessionId`` is declared; undeclared fields are kept."""

    sessionId: Optional[str] = ""

    class Config:
        # Keep any extra, undeclared fields on the instance instead of dropping them.
        extra = 'allow'
app/service/v2/app_driver/chat_agent.py
@@ -9,6 +9,7 @@
    async def chat_completions(self, url, data, headers):
        complete_response = ""
        # print(data)
        async for line in self.http_stream(url, data, headers):
            # logger.error(line)
            if line.startswith("data:"):
@@ -46,6 +47,21 @@
            "files": files
        }
    @staticmethod
    async def complex_request_data(query: str, conversation_id: str, user: str, files: list=None, inputs: dict=None) -> dict:
        if not files:
            files = []
        if not inputs:
            inputs = {}
        return {
            "inputs": inputs,
            "query": query,
            "response_mode": "streaming",
            "conversation_id": conversation_id,
            "user": user,
            "files": files
        }
if __name__ == "__main__":
    async def aa():
app/service/v2/chat.py
@@ -1,6 +1,7 @@
import asyncio
import io
import json
import uuid
import fitz
from fastapi import HTTPException
@@ -10,7 +11,7 @@
    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
from app.config.config import settings
from app.config.const import *
from app.models import DialogModel, ApiTokenModel, UserTokenModel
from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest
from app.models.v2.session_model import ChatSessionDao, ChatData
from app.service.v2.app_driver.chat_agent import ChatAgent
from app.service.v2.app_driver.chat_data import ChatBaseApply
@@ -47,12 +48,12 @@
        logger.error(e)
    return None
async def get_app_token(db, app_id):
    """Return the stored access token of ``UserTokenModel`` row ``app_id``, or "" if there is none."""
    record = db.query(UserTokenModel).filter_by(id=app_id).first()
    return record.access_token if record else ""
async def get_chat_token(db, app_id):
@@ -69,7 +70,6 @@
        db.commit()
    except Exception as e:
        logger.error(e)
async def get_chat_info(db, chat_id: str):
@@ -134,6 +134,7 @@
        message["role"] = "assistant"
        await update_session_log(db, session_id, message, conversation_id)
async def data_process(data):
    if isinstance(data, str):
        return data.replace("dify", "smart")
@@ -170,8 +171,9 @@
    if hasattr(chat_data, "query"):
        query = chat_data.query
    else:
        query = "start new workflow"
    session = await add_session_log(db, session_id,query if query else "start new conversation", chat_id, user_id, mode, conversation_id, 3)
        query = "start new conversation"
    session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id,
                                    mode, conversation_id, 3)
    if session:
        conversation_id = session.conversation_id
    try:
@@ -215,7 +217,7 @@
                    [workflow_started, node_started, node_finished].index(ans.get("event"))]
            elif ans.get("event") == workflow_finished:
                data = ans.get("data", {})
                answer_workflow = data.get("outputs", {}).get("output")
                answer_workflow = data.get("outputs", {}).get("output", data.get("outputs", {}).get("answer"))
                download_url = data.get("outputs", {}).get("download_url")
                event = smart_workflow_finished
                if data.get("status") == "failed":
@@ -242,8 +244,9 @@
        except:
            ...
    finally:
        await update_session_log(db, session_id, {"role": "assistant", "answer": answer_event or answer_agent or answer_workflow or error,
                                                  "download_url":download_url,
        await update_session_log(db, session_id, {"role": "assistant",
                                                  "answer": answer_event or answer_agent or answer_workflow or error,
                                                  "download_url": download_url,
                                                  "node_list": node_list, "task_id": task_id, "id": message_id,
                                                  "error": error}, conversation_id)
@@ -257,6 +260,7 @@
    if not chat_info:
        return {}
    return chat_info.parameters
async def service_chat_sessions(db, chat_id, name):
    token = await get_chat_token(db, rg_api_token)
@@ -276,14 +280,12 @@
        page=current,
        page_size=page_size
    )
    return json.dumps({"total":total, "rows": [session.to_dict() for session in session_list]})
    return json.dumps({"total": total, "rows": [session.to_dict() for session in session_list]})
async def service_chat_session_log(db, session_id):
    """Return the log of chat session ``session_id`` as a JSON string ("{}" when it does not exist).

    The block previously carried two returns; the first dereferenced a possibly
    missing session (AttributeError on None) and made the None-safe one unreachable.
    """
    session_log = await ChatSessionDao(db).get_session_by_id(session_id)
    return json.dumps(session_log.log_to_json() if session_log else {})
async def service_chat_upload(db, chat_id, file, user_id):
@@ -316,6 +318,7 @@
    tokens = tokenizer.encode(input_str)
    return len(tokens)
async def read_pdf(pdf_stream):
    text = ""
    with fitz.open(stream=pdf_stream, filetype="pdf") as pdf_document:
@@ -335,6 +338,7 @@
    return text
async def read_file(file, filename, content_type):
    text = ""
    if content_type == "application/pdf" or filename.endswith('.pdf'):
@@ -349,7 +353,7 @@
async def service_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
    print(query)
    # print(query)
    try:
        request_data = json.loads(query)
@@ -357,7 +361,7 @@
            "question": request_data.get("query", ""),
            "dataset_ids": request_data.get("dataset_ids", []),
            "page_size": top_k,
            "similarity_threshold": similarity_threshold
            "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
        }
    except json.JSONDecodeError as e:
        fixed_json = query.replace("'", '"')
@@ -367,15 +371,16 @@
                "question": request_data.get("query", ""),
                "dataset_ids": request_data.get("dataset_ids", []),
                "page_size": top_k,
                "similarity_threshold": similarity_threshold
                "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
            }
        except Exception:
            payload = {
                "question":query,
                "dataset_ids":[knowledge_id],
                "question": query,
                "dataset_ids": [knowledge_id],
                "page_size": top_k,
                "similarity_threshold": similarity_threshold
                "similarity_threshold": similarity_threshold if similarity_threshold else 0.2
            }
    # print(payload)
    url = settings.fwr_base_url + RG_ORIGINAL_URL
    chat = ChatBaseApply()
    response = await  chat.chat_post(url, payload, await chat.get_headers(api_key))
@@ -388,11 +393,15 @@
            "title": chunk.get("document_keyword", "Unknown Document"),
            "metadata": {"document_id": chunk["document_id"],
                         "path": f"{settings.fwr_base_url}/document/{chunk['document_id']}?ext={chunk.get('document_keyword').split('.')[-1]}&prefix=document",
                         'highlight': chunk.get("highlight") , "image_id":  chunk.get("image_id"), "positions": chunk.get("positions"),}
                         'highlight': chunk.get("highlight"), "image_id": chunk.get("image_id"),
                         "positions": chunk.get("positions"), }
        }
        for chunk in response.get("data", {}).get("chunks", [])
    ]
    # print(len(records))
    # print(records)
    return records
async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
    # request_data = json.loads(query)
@@ -420,15 +429,170 @@
    return records
async def add_complex_log(db, message_id, chat_id, session_id, chat_mode, query, user_id, mode, agent_type, message_type, conversation_id="", node_data=None, query_data=None):
    """Persist one complex-chat message row and report the conversation id used.

    When ``conversation_id`` is empty it is taken from an existing assistant
    message of the same ``session_id``/``chat_id`` (if any). ``node_data`` and
    ``query_data`` are stored as JSON text (defaulting to ``[]`` / ``{}``).

    Returns:
        ``(conversation_id, ok)`` where ``ok`` is False if anything raised; the
        error is logged rather than propagated.
    """
    node_payload = node_data if node_data else []
    query_payload = query_data if query_data else {}
    try:
        dao = ComplexChatSessionDao(db)
        if not conversation_id:
            previous = await dao.get_session_by_session_id(session_id, chat_id)
            if previous:
                conversation_id = previous.conversation_id
        await dao.create_session(
            message_id,
            chat_id=chat_id,
            session_id=session_id,
            chat_mode=chat_mode,
            message_type=message_type,
            content=query,
            event_type=mode,
            tenant_id=user_id,
            conversation_id=conversation_id,
            node_data=json.dumps(node_payload),
            query=json.dumps(query_payload),
            agent_type=agent_type,
        )
    except Exception as e:
        logger.error(e)
        return conversation_id, False
    return conversation_id, True
async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
    """Stream one home-page ("complex") chat turn as server-sent event lines.

    Steps:
    1. Log the user's query (message_type=1); if that fails, emit a single error event.
    2. Send the query to the backend returned by ``get_chat_object(mode)``.
    3. Re-emit every recognised upstream event as an SSE ``data:`` line; unknown
       events are skipped.
    4. In ``finally``, when the upstream produced a message id, log the assistant
       answer (message_type=2). The answer is therefore stored even if streaming
       fails midway.

    Args:
        db: database session.
        chat_id: id of the complex chat backing this mode.
        mode: agent mode used to select the backend client and URL.
        user_id: id of the requesting user.
        chat_request: the parsed request body.

    Yields:
        SSE-formatted strings.
    """
    answer_event = ""
    answer_agent = ""
    answer_workflow = ""
    download_url = ""
    message_id = ""
    task_id = ""
    error = ""
    files = []  # file events seen in this turn (collected; not persisted yet)
    node_list = []
    token = await get_chat_token(db, chat_id)
    chat, url = await get_chat_object(mode)
    conversation_id, logged = await add_complex_log(db, str(uuid.uuid4()), chat_id, chat_request.sessionId,
                                                    chat_request.chatMode, chat_request.query, user_id, mode,
                                                    DF_TYPE, 1, query_data=chat_request.to_dict())
    if not logged:
        yield "data: " + json.dumps({"message": smart_message_error,
                                     "error": "\n**ERROR**: 创建会话失败!", "status": http_500},
                                    ensure_ascii=False) + "\n\n"
        return
    inputs = {"is_deep": chat_request.isDeep}
    if chat_request.chatMode == complex_knowledge_chat:
        # Knowledge-base mode passes the query plus the selected dataset ids as one JSON string input.
        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
    try:
        async for ans in chat.chat_completions(url,
                                               await chat.complex_request_data(chat_request.query, conversation_id,
                                                                               str(user_id), files=chat_request.files,
                                                                               inputs=inputs),
                                               await chat.get_headers(token)):
            data = {}
            status = http_200
            # Not every upstream event carries these ids; keep the last known value
            # instead of overwriting it with None (which lost the conversation id
            # used for the final log entry).
            conversation_id = ans.get("conversation_id") or conversation_id
            task_id = ans.get("task_id") or task_id
            event_name = ans.get("event")
            if event_name == message_error:
                error = ans.get("message", "参数异常!")
                status = http_400
                event = smart_message_error
            elif event_name == message_agent:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_agent += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_event:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_event += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_file:
                data = {"url": ans.get("url", ""), "id": ans.get("id", ""),
                        "type": ans.get("type", "")}
                files.append(data)
                event = smart_message_file
            elif event_name in [workflow_started, node_started, node_finished]:
                data = ans.get("data", {})
                data["inputs"] = await data_process(data.get("inputs", {}))
                data["outputs"] = await data_process(data.get("outputs", {}))
                data["files"] = await data_process(data.get("files", []))
                data["process_data"] = ""
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
                event = [smart_workflow_started, smart_node_started, smart_node_finished][
                    [workflow_started, node_started, node_finished].index(event_name)]
            elif event_name == workflow_finished:
                data = ans.get("data", {})
                outputs = data.get("outputs", {})
                answer_workflow = outputs.get("output", outputs.get("answer"))
                download_url = outputs.get("download_url")
                event = smart_workflow_finished
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
            elif event_name == message_end:
                event = smart_message_end
            else:
                continue
            yield "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id": message_id,
                 "session_id": chat_request.sessionId},
                ensure_ascii=False) + "\n\n"
    except Exception as e:
        logger.error(e)
        try:
            yield "data: " + json.dumps({"message": smart_message_error,
                                         "error": "\n**ERROR**: " + str(e), "status": http_500},
                                        ensure_ascii=False) + "\n\n"
        except Exception:
            # Best effort: the client may already be gone.
            ...
    finally:
        if message_id:
            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode,
                                  answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2,
                                  conversation_id, node_data=node_list, query_data=chat_request.to_dict())
async def service_complex_upload(db, chat_id, file, user_id):
    """Upload each file in ``file`` to ``settings.dify_base_url + DF_UPLOAD_FILE``.

    Files that fail to upload are logged and skipped.

    Returns:
        A JSON array string of the upload responses. Returns "" when nothing was
        uploaded, including when no API token exists for ``chat_id``. That path
        used to return ``[]``, so callers always get a str now; both values are
        falsy, which is what the caller checks.
    """
    token = await get_chat_token(db, chat_id)
    if not token:
        return ""
    url = settings.dify_base_url + DF_UPLOAD_FILE
    chat = ChatBaseApply()
    uploaded = []
    for f in file:
        try:
            content = await f.read()
            result = await chat.chat_upload(url, {"file": (f.filename, content)}, {"user": str(user_id)},
                                            {'Authorization': f'Bearer {token}'})
        except Exception as e:
            logger.error(e)
            continue
        uploaded.append(result)
    return json.dumps(uploaded) if uploaded else ""
if __name__ == "__main__":
    # Manual smoke test for service_chunk_retrieval; needs a reachable retrieval backend.
    import os

    knowledge_id = "fc68db52f43111efb94a0242ac120004"
    q = json.dumps({"query": "设备", "dataset_ids": [knowledge_id]})
    top_k = 2
    similarity_threshold = 0.5
    # Read the key from the environment instead of committing a credential to source.
    api_key = os.environ.get("RAGFLOW_API_KEY", "")

    async def a():
        # service_chunk_retrieval takes (query, knowledge_id, top_k, similarity_threshold, api_key);
        # the previous call omitted knowledge_id and raised TypeError.
        b = await service_chunk_retrieval(q, knowledge_id, top_k, similarity_threshold, api_key)
        print(b)

    asyncio.run(a())
app/service/v2/mindmap.py
New file
@@ -0,0 +1,198 @@
import json
from Log import logger
from app.config.agent_base_url import DF_CHAT_AGENT
from app.config.config import settings
from app.config.const import message_error, message_event, complex_knowledge_chat
from app.models import ComplexChatSessionDao, ChatData
from app.service.v2.app_driver.chat_agent import ChatAgent
from app.service.v2.chat import get_chat_token
async def service_chat_mindmap_v1(db, message_id, message, mindmap_chat_id, user_id):
    """Build or extend the markdown mind map of a stored complex-chat message (v1).

    If the session already has a mind map, ``message`` is first sent to the
    session's own chat app (same conversation and uploaded files), and the answer
    becomes the mind-map prompt. Otherwise the stored message content is used.
    The prompt is then sent to the mind-map app ``mindmap_chat_id``, and the
    fenced markdown outline is extracted from its answer. For an existing map,
    the outline is spliced in under the ``- {message}`` line via
    ``mindmap_to_merge``.

    Args:
        db: database session.
        message_id: id of the complex-chat session record.
        message: node text the user wants expanded.
        mindmap_chat_id: id of the chat app that generates mind maps.
        user_id: id of the requesting user.

    Returns:
        ``{"mindmap": <markdown>}`` on success. Returns ``{}`` if the session is
        missing, either request fails or reports an error event, or the answer
        contains no outline.
    """
    res = {}
    complex_log = ComplexChatSessionDao(db)
    session = await complex_log.get_session_by_id(message_id)
    if not session:
        return res
    chat = ChatAgent()
    url = settings.dify_base_url + DF_CHAT_AGENT
    mindmap_query = ""
    if session.mindmap:
        token = await get_chat_token(db, session.chat_id)
        chat_request = json.loads(session.query)
        try:
            async for ans in chat.chat_completions(
                    url,
                    await chat.request_data(message, session.conversation_id, str(user_id), ChatData(),
                                            chat_request.get("files", [])),
                    await chat.get_headers(token)):
                event = ans.get("event")
                if event == message_error:
                    return res
                if event == message_event:
                    mindmap_query += ans.get("answer", "")
        except Exception as e:
            logger.error(e)
            return res
    else:
        mindmap_query = session.content

    mindmap_str = ""
    try:
        token = await get_chat_token(db, mindmap_chat_id)
        async for ans in chat.chat_completions(
                url,
                await chat.request_data(mindmap_query, "", str(user_id), ChatData()),
                await chat.get_headers(token)):
            event = ans.get("event")
            if event == message_error:
                return res
            if event == message_event:
                mindmap_str += ans.get("answer", "")
    except Exception as e:
        logger.error(e)
        return res

    # The outline is expected inside a ```markdown fence. Fall back to the raw
    # answer when there is no fence instead of raising IndexError.
    parts = mindmap_str.split("```")
    body = parts[1] if len(parts) > 1 else mindmap_str
    # Remove the fence's language tag as a prefix. str.lstrip("markdown\n") strips
    # any leading run of those *characters* and can eat the start of the outline.
    if body.startswith("markdown"):
        body = body[len("markdown"):]
    mindmap_str = body.lstrip("\n")
    if not mindmap_str.strip():
        return res
    if session.mindmap:
        node_list = await mindmap_to_merge(session.mindmap, mindmap_str, f"- {message}")
        mindmap_str = "\n".join(node_list)
    res["mindmap"] = mindmap_str
    await complex_log.update_mindmap_by_id(message_id, mindmap_str)
    return res
async def service_chat_mindmap(db, message_id, message, mindmap_chat_id, user_id):
    """Build or extend the JSON mind map of a stored complex-chat message.

    If the session already has a mind map, ``message`` is first sent to the
    session's own chat app with the original request's ``isDeep`` flag, files and
    (for knowledge chats) dataset ids. The answer becomes the mind-map prompt;
    otherwise the stored message content is used. The prompt is sent to the
    mind-map app ``mindmap_chat_id``, and the answer (optionally wrapped in a
    ```json fence) is parsed. For an existing map, the new tree is grafted onto
    the leaf titled ``message``. The resulting JSON is stored and its markdown
    rendering is returned.

    Args:
        db: database session.
        message_id: id of the complex-chat session record.
        message: node title the user wants expanded.
        mindmap_chat_id: id of the chat app that generates mind maps.
        user_id: id of the requesting user.

    Returns:
        ``{"mindmap": <markdown>}`` on success. Returns ``{}`` if the session is
        missing, a request fails or reports an error event, or the generated mind
        map is not valid JSON (in which case nothing is persisted).
    """
    res = {}
    complex_log = ComplexChatSessionDao(db)
    session = await complex_log.get_session_by_id(message_id)
    if not session:
        return res
    token = await get_chat_token(db, session.chat_id)
    chat = ChatAgent()
    url = settings.dify_base_url + DF_CHAT_AGENT
    mindmap_query = ""
    if session.mindmap:
        chat_request = json.loads(session.query)
        inputs = {"is_deep": chat_request.get("isDeep", 1)}
        if session.chat_mode == complex_knowledge_chat:
            inputs["query_json"] = json.dumps(
                {"query": chat_request.get("query", ""), "dataset_ids": chat_request.get("knowledgeId", [])})
        try:
            async for ans in chat.chat_completions(
                    url,
                    await chat.complex_request_data(message, session.conversation_id, str(user_id),
                                                    files=chat_request.get("files", []), inputs=inputs),
                    await chat.get_headers(token)):
                event = ans.get("event")
                if event == message_error:
                    return res
                if event == message_event:
                    mindmap_query += ans.get("answer", "")
        except Exception as e:
            logger.error(e)
            return res
    else:
        mindmap_query = session.content

    mindmap_str = ""
    try:
        token = await get_chat_token(db, mindmap_chat_id)
        async for ans in chat.chat_completions(
                url,
                await chat.complex_request_data(mindmap_query, "", str(user_id)),
                await chat.get_headers(token)):
            event = ans.get("event")
            if event == message_error:
                return res
            if event == message_event:
                mindmap_str += ans.get("answer", "")
    except Exception as e:
        logger.error(e)
        return res

    if "```json" in mindmap_str:
        # Take the text between the ```json fence and the next closing ```.
        mindmap_str = mindmap_str.split("```json", 1)[1].split("```", 1)[0]
    mindmap_str = mindmap_str.replace("\n", "")
    try:
        if session.mindmap:
            mindmap_str = await mindmap_merge_dict(session.mindmap, mindmap_str, message)
        res_str = await mindmap_join_str(mindmap_str)
    except Exception as e:
        logger.error(e)
        return res
    if not res_str:
        # mindmap_join_str returns "" when mindmap_str is not valid JSON; storing
        # it would break every later merge of this session's mind map.
        return res
    res["mindmap"] = res_str
    await complex_log.update_mindmap_by_id(message_id, mindmap_str)
    return res
async def mindmap_merge_dict(parent, child, target_node):
    """Graft the children of one JSON mind map onto matching leaves of another.

    Both maps are JSON strings of nested ``{"title": ..., "items": [...]}`` nodes.
    Every leaf of ``parent`` (a node without ``"items"``) whose title equals
    ``target_node`` receives ``child``'s top-level ``items``. Nodes that already
    have ``items`` are left untouched.

    Args:
        parent: JSON string of the existing mind map.
        child: JSON string of the newly generated sub-map. A falsy value, or a
            map without ``items``, means there is nothing to graft.
        target_node: title of the leaf to expand.

    Returns:
        The (possibly updated) parent, re-serialised with ``json.dumps``.

    Raises:
        json.JSONDecodeError: if ``parent`` or a non-empty ``child`` is not valid JSON.
    """
    parent_dict = json.loads(parent)
    if not child:
        return json.dumps(parent_dict)
    child_dict = json.loads(child)
    child_items = child_dict.get("items") if isinstance(child_dict, dict) else None
    if not child_items:
        # Previously a KeyError; a child without items simply adds nothing.
        return json.dumps(parent_dict)
    # Explicit stack instead of recursion, so very deep maps cannot hit the recursion limit.
    stack = [parent_dict]
    while stack:
        node = stack.pop()
        if "items" in node:
            stack.extend(node["items"])
        elif node.get("title") == target_node:
            node["items"] = child_items
    return json.dumps(parent_dict)
async def mindmap_join_str(mindmap_json):
    """Render a JSON mind map as a markdown outline.

    The root becomes a ``#`` heading and its children ``##`` headings. Every
    deeper node becomes a ``- `` bullet, indented by two spaces for each level
    beyond the third. Nodes are emitted in pre-order, one per line, each line
    terminated by a newline.

    Args:
        mindmap_json: JSON string of nested ``{"title": ..., "items": [...]}`` nodes.

    Returns:
        The markdown text, or ``""`` (after logging the error) when
        ``mindmap_json`` cannot be parsed.
    """
    try:
        root = json.loads(mindmap_json)
    except Exception as e:
        logger.error(e)
        return ""

    def render(node, depth):
        # Yield this node's line, then its subtree, depth-first.
        if depth > 2:
            yield f"{' ' * ((depth - 3) * 2)}- {node['title']}\n"
        else:
            yield f"{'#' * depth} {node['title']}\n"
        for sub in node.get("items", []):
            yield from render(sub, depth + 1)

    return "".join(render(root, 1))
async def mindmap_to_merge(parent, child, target_node):
    """Splice a markdown sub-outline into an indented markdown outline.

    Args:
        parent: existing outline, one node per line; nesting is given by leading
            whitespace.
        child: markdown outline to insert. Its first line is replaced by
            ``target_node``. Lines starting with ``#`` headings become ``-``
            list items indented by heading depth, and other non-empty lines keep
            the indent of the preceding heading.
        target_node: the unindented line to expand, e.g. ``"- topic"``.

    Returns:
        The merged outline as a list of lines. The parent line equal to
        ``target_node`` (ignoring its leading indent) is replaced by the converted
        child lines, indented relative to it. If no parent line matches, the
        converted lines are appended at the end; the original code overwrote the
        first parent line in that case.
    """
    parent_list = parent.split("\n")
    child_list = child.split("\n")
    child_list[0] = target_node

    index = None
    level = 0
    for i, node in enumerate(parent_list):
        stripped = node.lstrip()
        # Exact match after the indent; endswith() also matched "- X - topic".
        if stripped == target_node:
            level = len(node) - len(stripped)
            index = i
            break

    new_node_list = []
    tmp_level = 0
    for line in child_list:
        if not line:
            continue
        heading = line.lstrip()
        # Only a leading "#" marks a heading; "#" inside text (e.g. "C#") must survive.
        if heading.startswith("#"):
            title = heading.lstrip("#")
            tmp_level = len(heading) - len(title) - 1
            new_node_list.append(" " * (level + tmp_level) + "-" + title)
        else:
            new_node_list.append(" " * (level + tmp_level) + line)

    if index is None:
        return parent_list + new_node_list
    return parent_list[:index] + new_node_list + parent_list[index + 1:]
if __name__ == '__main__':
    # Manual demo of mindmap_merge_dict on a sample mind map.
    import asyncio

    a = '{  "title": "全生命周期管理",  "items": [    {      "title": "设备规划与采购",      "items": [        {          "title": "需求分析与选型"    ,"items": [{"title": "rererer"}, {"title": "trtrtrtrt"}]    },        {          "title": "供应商选择与合同管理"        }      ]    },    {      "title": "设备安装与调试",      "items": [        {          "title": "安装规范"        },        {          "title": "调试测试"        }      ]    },    {      "title": "设备使用",      "items": [        {          "title": "操作培训"        },        {          "title": "操作规程与记录"        }      ]    },    {      "title": "设备维护与维修",      "items": [        {          "title": "定期维护"        },        {          "title": "故障诊断"        },        {          "title": "备件管理"        }      ]    },    {      "title": "设备更新与改造",      "items": [        {          "title": "技术评估"        },        {          "title": "更新计划"        },        {          "title": "改造方案"        }      ]    },    {      "title": "设备报废",      "items": [        {          "title": "报废评估"        },        {          "title": "报废处理"        }      ]    },    {      "title": "信息化管理",      "items": [        {          "title": "设备管理系统"        },        {          "title": "数据分析"        },        {          "title": "远程监控"        }      ]    },    {      "title": "安全管理",      "items": [        {          "title": "安全培训"        },        {          "title": "安全检查"        },        {          "title": "应急预案"        }      ]    },    {      "title": "环境保护",      "items": [        {          "title": "环保设备"        },        {          "title": "废物处理"        },        {          "title": "节能减排"        }      ]    },    {      "title": "具体实践案例",      "items": [        {          "title": "高压开关设备润滑脂选用研究"        },        {          "title": "环保型 C4 混气 GIS 设备运维技术研究"        }      ]    },    {      "title": "总结",      "items": [        {          "title": "提高运营效率和竞争力"        }      ]    }  ]}'
    # mindmap_merge_dict is a coroutine function: it must be run to get the JSON
    # string (calling it directly only printed a coroutine object).
    b = asyncio.run(mindmap_merge_dict(a, {}, "设备规划与采购"))
    print(b)
app/task/fetch_agent.py
@@ -9,7 +9,7 @@
from app.config.config import settings
from app.config.const import RAGFLOW, BISHENG, DIFY, ENV_CONF_PATH, Dialog_STATSU_DELETE, Dialog_STATSU_ON
from app.models import KnowledgeModel
from app.models import KnowledgeModel, ComplexChatDao
from app.models.dialog_model import DialogModel
from app.models.user_model import UserAppModel
from app.models.agent_model import AgentModel
@@ -239,7 +239,7 @@
        db.close()
def get_data_from_ragflow_v2(names: List[str], tenant_id) -> List[Dict]:
def get_data_from_ragflow_v2(base_db, names: List[str], tenant_id) -> List[Dict]:
    db = SessionRagflow()
    para = {
        "user_input_form": [],
@@ -251,6 +251,8 @@
        }
    }
    try:
        chat_ids = ComplexChatDao(base_db).get_complex_chat_ids()
        # print(chat_ids)
        if names:
            query = db.query(Dialog.id, Dialog.name, Dialog.description, Dialog.status, Dialog.tenant_id) \
                .filter(Dialog.name.in_(names), Dialog.status == "1")
@@ -261,15 +263,17 @@
        results = query.all()
        formatted_results = [
            {"id": row[0], "name": row[1], "description": row[2], "status": "1" if row[3] == "1" else "2",
             "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results]
             "user_id": str(row[4]), "mode": "agent-dialog", "parameters": para} for row in results if row[0] not in chat_ids]
        return formatted_results
    finally:
        db.close()
def get_data_from_dy_v2(names: List[str]) -> List[Dict]:
def get_data_from_dy_v2(base_db, names: List[str]) -> List[Dict]:
    db = SessionDify()
    try:
        chat_ids = ComplexChatDao(base_db).get_complex_chat_ids()
        # print(chat_ids)
        if names:
            query = db.query(DfApps.id, DfApps.name, DfApps.description, DfApps.status, DfApps.tenant_id, DfApps.mode) \
                .filter(DfApps.name.in_(names))
@@ -279,7 +283,7 @@
        results = query.all()
        formatted_results = [
            {"id": str(row[0]), "name": row[1], "description": row[2], "status": "1",
             "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results]
             "user_id": str(row[4]), "mode": row[5], "parameters": {}} for row in results if str(row[0]) not in chat_ids]
        return formatted_results
    finally:
        db.close()
@@ -342,11 +346,11 @@
        for app in app_register:
            try:
                if app["id"] == RAGFLOW:
                    ragflow_data = get_data_from_ragflow_v2([], app["name"])
                    ragflow_data = get_data_from_ragflow_v2(db, [], app["name"])
                    if ragflow_data:
                        update_ids_in_local_v2(ragflow_data, "1")
                elif app["id"] == DIFY:
                    dify_data = get_data_from_dy_v2([])
                    dify_data = get_data_from_dy_v2(db, [])
                    if dify_data:
                        update_ids_in_local_v2(dify_data, "4")
            except Exception as e:
app/utils/common.py
New file
@@ -0,0 +1,6 @@
import pytz
from datetime import datetime
def current_time():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai."""
    shanghai = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz=shanghai)
main.py
@@ -15,6 +15,7 @@
from app.api.organization import dept_router
from app.api.system import system_router
from app.api.v2.chat import chat_router_v2
from app.api.v2.mindmap import mind_map_router
from app.api.v2.public_api import public_api
from app.api.report import router as report_router
from app.api.resource import menu_router
@@ -77,6 +78,7 @@
app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"])
app.include_router(system_router, prefix='/api/system', tags=["system"])
app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"])
app.mount("/static", StaticFiles(directory="app/images"), name="static")
if __name__ == "__main__":
requirements.txt
Binary files differ