zhaoqingang
2024-11-25 dc478b065693dd24e4cae719186d6aafb2d24f6d
dify 接入
1个文件已添加
8个文件已修改
291 ■■■■■ 已修改文件
app/api/agent.py 24 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 88 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/files.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/config/config.yaml 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/agent_model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/session_model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/difyService.py 153 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/session.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/agent.py
@@ -161,6 +161,30 @@
                data.append(tmp_data)
        return JSONResponse(status_code=200, content={"code": 200, "data": data})
    elif agent.agent_type == AgentType.DIFY:
        data = []
        session = db.query(SessionModel).filter(SessionModel.id == conversation_id).first()
        if session:
            tmp_data = {}
            for i in session.log_to_json().get("message", []):
                if i.get("role") == "user":
                    tmp_data["question"] = i.get("content")
                elif i.get("role") == "assistant":
                    if isinstance(i.get("content"), dict):
                        tmp_data["answer"] = i.get("content", {}).get("answer")
                        if "file_name" in i.get("content", {}):
                            tmp_data["files"] = [{"file_name": i.get("content", {}).get("file_name"),
                                                  "file_url": i.get("content", {}).get("file_url")}]
                    else:
                        tmp_data["answer"] = i.get("content")
                    data.append(tmp_data)
                    tmp_data = {}
            if tmp_data:
                data.append(tmp_data)
        return JSONResponse(status_code=200, content={"code": 200, "data": data})
    else:
        return JSONResponse(status_code=200, content={"code": 200, "log": "Unsupported agent type"})
app/api/chat.py
@@ -14,6 +14,7 @@
from app.models.user_model import UserModel
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.difyService import DifyService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
from app.service.session import SessionService
@@ -286,6 +287,93 @@
        finally:
            await websocket.close()
            print(f"Client {agent_id} disconnected")
    if agent_type == AgentType.DIFY:
        dify_service = DifyService(settings.dify_base_url)
        # token = get_dify_token(db, current_user.id)
        token = settings.dify_api_token
        try:
            async def forward_to_dify():
                while True:
                    conversation_id = ""
                    receive_message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {receive_message}")
                    upload_file_id = receive_message.get('upload_file_id', [])
                    question = receive_message.get('message', "")
                    if not question and not image_url:
                        await websocket.send_json({"message": "Invalid request", "type": "error"})
                        continue
                    try:
                        session = SessionService(db).create_session(
                            chat_id,
                            question,
                            agent_id,
                            AgentType.DIFY,
                            current_user.id
                        )
                        conversation_id = session.conversation_id
                    except Exception as e:
                        logger.error(e)
                    complete_response = ""
                    async for rag_response in dify_service.chat(token, chat_id, question, upload_file_id, conversation_id):
                        try:
                            if rag_response[:5] == "data:":
                                # 如果是,则截取掉前5个字符,并去除首尾空白符
                                text = rag_response[5:].strip()
                            else:
                                # 否则,保持原样
                                text = rag_response
                            complete_response += text
                            try:
                                data = json.loads(complete_response)
                                # data = json_data.get("data")
                                if "answer" not in  data:  # 信息过滤
                                    continue
                                else:  # 正常输出
                                    answer = data.get("answer", "")
                                    result = {"message": answer, "type": "message"}
                                    try:
                                        SessionService(db).update_session(chat_id,
                                                                          message={"role": "assistant", "content": data, "conversation_id": data.get("conversation_id")})
                                    except Exception as e:
                                        logger.error(e)
                                await websocket.send_json(result)
                                complete_response = ""
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                # print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of ragflow: {e2}")
                    try:
                        dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1)
                        await update_session_history(db, dialog_chat_history, current_user.id)
                    except Exception as e:
                        logger.error(e)
                        logger.error("-----------------保存ragflow的历史会话异常-----------------")
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_dify())
            ]
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect as e1:
            print(f"Client {chat_id} disconnected: {e1}")
            await websocket.close()
        except Exception as e:
            print(f"Exception occurred: {e}")
        finally:
            print("Cleaning up resources of ragflow")
            # 取消所有任务
            for task in tasks:
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
    else:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
app/api/files.py
@@ -15,6 +15,7 @@
from app.models.user_model import UserModel
from app.service.basic import BasicService
from app.service.bisheng import BishengService
from app.service.difyService import DifyService
from app.service.ragflow import RagflowService
from app.service.service_token import get_ragflow_token, get_bisheng_token
import urllib.parse
@@ -93,6 +94,22 @@
            service = BasicService(base_url=settings.basic_paper_url)
            result = await service.paper_file_upload(chat_id, file.filename, file_content)
        elif agent.agent_type == AgentType.DIFY:
            file = file[0]
            # 读取上传的文件内容
            try:
                file_content = await file.read()
            except Exception as e:
                return Response(code=400, msg=str(e))
            dify_service = DifyService(base_url=settings.dify_base_url)
            try:
                token = get_bisheng_token(db, current_user.id)
                result = await dify_service.upload(token, file.filename, file_content)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))
            result["file_name"] = file.filename
            return Response(code=200, msg="", data=result)
        return Response(code=200, msg="", data=result)
app/config/config.py
@@ -17,6 +17,8 @@
    PASSWORD_KEY: str
    basic_base_url: str = ''
    basic_paper_url: str = ''
    dify_base_url: str = ''
    dify_api_token: str = ''
    def __init__(self, **kwargs):
        # Check if all required fields are provided and set them
        for field in self.__annotations__.keys():
app/config/config.yaml
@@ -14,4 +14,6 @@
fetch_fwr_agent: 知识问答,智能问答
PASSWORD_KEY: VKinqB-8XMrwCLLrcf_PyHyo12_4PVKvWzaHjNFions=
basic_base_url: http://192.168.20.231:8000
basic_paper_url: http://192.168.20.231:8000
basic_paper_url: http://192.168.20.231:8000
dify_base_url: http://192.168.20.116
dify_api_token: app-YmOAMDsPpDDlqryMHnc9TzTO
app/models/agent_model.py
@@ -9,6 +9,7 @@
    RAGFLOW = 1
    BISHENG = 2
    BASIC = 3
    DIFY = 4
class AgentModel(Base):
app/models/session_model.py
@@ -17,6 +17,7 @@
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    # to_dict 方法
    def to_dict(self):
app/service/difyService.py
New file
@@ -0,0 +1,153 @@
import json
from datetime import datetime
import httpx
from typing import Union, Dict, List
from fastapi import HTTPException
from starlette import status
from watchdog.observers.fsevents2 import message
# from Log import logger
from app.config.config import settings
from app.utils.rsa_crypto import RagflowCrypto
class DifyService:
    """Async HTTP client for a Dify application.

    ``chat`` and ``upload`` talk to Dify's app-facing API under
    ``{base_url}/v1`` using a ``Bearer`` app token.

    ``register``, ``login`` and ``get_session_history`` were carried over from
    the RAGFlow client: they use RAGFlow-style endpoints, ``RagflowCrypto`` and
    the ``retcode`` envelope handled by ``_handle_response``.
    NOTE(review): confirm a Dify deployment actually serves those endpoints.
    """

    def __init__(self, base_url: str):
        # Server root; every request URL is built as f"{base_url}/v1/...".
        self.base_url = base_url

    def _handle_response(self, response: "httpx.Response") -> Union[Dict, List]:
        """Unwrap a RAGFlow-style ``{"retcode": ..., "data": ...}`` JSON envelope.

        Returns:
            The ``data`` payload when the status is 200, ``retcode`` is 0 and
            the payload is a dict or list; ``{}`` in every other case.

        Raises:
            HTTPException: 401 when the envelope's ``retcode`` is 401.
        """
        if response.status_code != 200:
            return {}
        data = response.json()
        ret_code = data.get("retcode")
        if ret_code == 401:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="登录过期",
            )
        if ret_code != 0:
            return {}
        payload = data.get("data")
        if isinstance(payload, (dict, list)):
            return payload
        return {}

    async def register(self, username: str, password: str):
        """Register a user through the RAGFlow-style ``/v1/user/register`` endpoint.

        The password is encrypted with ``RagflowCrypto`` and the e-mail is
        synthesized as ``{username}@example.com``.

        Raises:
            Exception: when the server answers with a non-200 status.
        """
        password = RagflowCrypto(settings.PUBLIC_KEY, settings.PRIVATE_KEY).encrypt(password)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/v1/user/register",
                headers={'Content-Type': 'application/json'},
                json={"nickname": username, "email": f"{username}@example.com", "password": password}
            )
            if response.status_code != 200:
                raise Exception(f"Ragflow registration failed: {response.text}")
            return self._handle_response(response)

    async def login(self, username: str, password: str) -> str:
        """Log in through the RAGFlow-style ``/v1/user/login`` endpoint.

        Returns:
            The ``Authorization`` header value from the login response.

        Raises:
            Exception: on a non-200 status or a missing ``Authorization`` header.
        """
        password = RagflowCrypto(settings.PUBLIC_KEY, settings.PRIVATE_KEY).encrypt(password)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/v1/user/login",
                headers={'Content-Type': 'application/json'},
                json={"email": f"{username}@example.com", "password": password}
            )
            if response.status_code != 200:
                raise Exception(f"Ragflow login failed: {response.text}")
            authorization = response.headers.get('Authorization')
            if not authorization:
                raise Exception("Authorization header not found in response")
            return authorization

    @staticmethod
    def _build_files(upload_file_id: Union[str, List[str], None]) -> List[Dict]:
        """Build the ``files`` payload for previously uploaded local files.

        Accepts one upload id, a list of ids, or an empty value. Empty ids are
        dropped, and an empty result means the message carries no attachment.
        """
        if not upload_file_id:
            return []
        ids = [upload_file_id] if isinstance(upload_file_id, str) else list(upload_file_id)
        return [
            {"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}
            for file_id in ids
            if file_id
        ]

    async def chat(self, token: str, chat_id: str, message: str,
                   upload_file_id: Union[str, List[str], None], conversation_id: str):
        """Stream one chat turn from Dify's ``/v1/chat-messages`` endpoint.

        Args:
            token: Dify app API key, sent as ``Authorization: Bearer <token>``.
            chat_id: Sent as Dify's ``user`` identifier.
            message: The user's query text.
            upload_file_id: Id, or list of ids, of files uploaded via ``upload``.
                Empty means no attachment.
            conversation_id: Existing conversation id. ``None`` or ``""`` is
                sent as ``""``.

        Yields:
            One SSE ``data: ...`` line per event, with the prefix kept.
            If Dify answers with a non-200 status, yields ``"Error: <status>"``
            once instead.
        """
        target_url = f"{self.base_url}/v1/chat-messages"
        data = {
            "inputs": {},
            "query": message,
            "response_mode": "streaming",
            # Normalize None so a JSON null is never sent as the conversation id.
            "conversation_id": conversation_id or "",
            "user": chat_id,
            # Only attach files the caller actually uploaded.
            "files": self._build_files(upload_file_id),
        }
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {token}'
        }
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream("POST", target_url, content=json.dumps(data), headers=headers) as response:
                if response.status_code != 200:
                    yield f"Error: {response.status_code}"
                    return
                # Iterate by line rather than by network chunk. A chunk may
                # hold several events or half of one, and the stream also
                # carries non-data lines (blank separators, keep-alive pings)
                # that are not JSON.
                async for line in response.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    print(f"response of dify chat: {line}")
                    yield line

    async def get_session_history(self, token: str, chat_id: str, is_all: int = 0):
        """Fetch a conversation via the RAGFlow-style ``/v1/conversation/get`` endpoint.

        Returns:
            The unwrapped payload when ``is_all`` is truthy. Otherwise its
            ``message`` list, or ``[]`` when the payload is not a dict.
        """
        url = f"{self.base_url}/v1/conversation/get"
        headers = {"Authorization": token}
        async with httpx.AsyncClient() as client:
            # params= URL-encodes the id instead of splicing it into the URL.
            response = await client.get(url, headers=headers, params={"conversation_id": chat_id})
        data = self._handle_response(response)
        if is_all:
            return data
        return data.get("message", []) if isinstance(data, dict) else []

    async def upload(self, token: str, filename: str, file: bytes, user: str = "") -> dict:
        """Upload a file to Dify so it can be attached to a chat message.

        Args:
            token: Dify app API key (Bearer auth, the same key ``chat`` uses).
            filename: File name reported in the multipart part.
            file: Raw file bytes.
            user: Optional Dify end-user identifier, sent as the ``user``
                form field when non-empty.

        Returns:
            ``{"file_path": <JSON body of Dify's response>}``. The response's
            ``id`` is what ``chat`` takes as ``upload_file_id``.

        Raises:
            Exception: when Dify answers with a non-2xx status.
        """
        # App-token service endpoint, matching ``chat``; the console endpoint
        # expects a console login session rather than an app key.
        url = f"{self.base_url}/v1/files/upload"
        # Deliberately no Content-Type: httpx must generate the multipart
        # header with its boundary. Forcing application/json corrupts the upload.
        headers = {'Authorization': f'Bearer {token}'}
        files = {"file": (filename, file)}
        form = {"user": user} if user else None
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, files=files, data=form)
        # Accept any 2xx (this endpoint can answer 201 Created). The body is the
        # file object itself, not a retcode envelope, so _handle_response does
        # not apply here.
        if not 200 <= response.status_code < 300:
            raise Exception(f"Dify file upload failed: {response.status_code} {response.text}")
        return {"file_path": response.json()}
if __name__ == "__main__":
    # Manual smoke test. The previous harness called a method DifyService
    # does not define (get_knowledge_list) and hard-coded a credential.
    import asyncio

    async def _smoke_test():
        """Stream one chat turn against the configured Dify app and print each event line."""
        service = DifyService(settings.dify_base_url)
        async for line in service.chat(settings.dify_api_token, "smoke-test", "hello", "", ""):
            print(line)

    asyncio.run(_smoke_test())
app/service/session.py
@@ -88,6 +88,7 @@
                self.db.commit()
                self.db.refresh(session)
            except Exception as e:
                logger.error(e)
                self.db.rollback()
        return session