From 6202db458678153934fb4a31a041c58764a69138 Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期五, 18 十月 2024 22:59:51 +0800
Subject: [PATCH] 增加文件下载转发接口,把毕昇返回的文件地址改成我们的下载地址
---
app/api/chat.py | 184 +++++++++++++++++++++++++++++++--------------
1 files changed, 127 insertions(+), 57 deletions(-)
diff --git a/app/api/chat.py b/app/api/chat.py
index c7aa2da..2fdbd44 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -1,20 +1,19 @@
import json
import uuid
-from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
+from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket
from app.config.config import settings
+from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
-from app.service.token import get_bisheng_token
+from app.service.ragflow import RagflowService
+from app.service.token import get_bisheng_token, get_ragflow_token
router = APIRouter()
-
-# 瀛樺偍瀹㈡埛绔� WebSocket 杩炴帴
-client_websockets = {}
# 涓棿灞俉ebSocket 鏈嶅姟鍣紝鎺ユ敹瀹㈡埛绔殑杩炴帴
@@ -27,66 +26,137 @@
await websocket.accept()
print(f"Client {agent_id} connected")
- token = get_bisheng_token(db, current_user.id)
+ agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
+ if not agent:
+ ret = {"message": "Agent not found", "type": "close"}
+ await websocket.send_json(ret)
+ return
+ agent_type = agent.agent_type
+ if chat_id == "" or chat_id == "0":
+ ret = {"message": "Chat ID not found", "type": "close"}
+ await websocket.send_json(ret)
+ return
- if agent_id == "0":
- agent_id = settings.bisheng_agent_id
- if chat_id == "0":
- chat_id = uuid.uuid4().hex
-
- # 杩炴帴鍒版湇鍔$
- service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
- headers = {
- 'cookie': f"access_token_cookie={token};"
- }
-
- async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
- client_websockets[chat_id] = websocket
-
+ if agent_type == AgentType.RAGFLOW:
+ ragflow_service = RagflowService(settings.ragflow_base_url)
+ token = get_ragflow_token(db, current_user.id)
try:
- # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅
- async def forward_to_service():
+ async def forward_to_ragflow():
while True:
message = await websocket.receive_json()
print(f"Received from client {chat_id}: {message}")
- # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁
- message['flow_id'] = agent_id
- message['chat_id'] = chat_id
- msg = message["message"]
- del message["message"]
- message['inputs'] = {
- "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
- "input": msg
- }
- await service_websocket.send(json.dumps(message))
- print(f"Forwarded to bisheng: {message}")
+ chat_history = message.get('chatHistory', [])
+ message["role"] = "user"
+ if len(chat_history) == 0:
+ chat_history = await ragflow_service.get_session_history(token, chat_id)
+ if len(chat_history) == 0:
+ chat_history = await ragflow_service.set_session(token, agent_id,
+ message, chat_id, True)
+ if len(chat_history) == 0:
+ result = {"message": "鍐呴儴閿欒锛氬垱寤轰細璇濆け璐�", "type": "close"}
+ await websocket.send_json(result)
+ await websocket.close()
+ return
+ else:
+ chat_history.append({
+ "content": message["message"],
+ "doc_ids": message.get("doc_ids", []),
+ "role": "user"
+ })
+ async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
+ try:
+ if rag_response[:5] == "data:":
+ # 濡傛灉鏄紝鍒欐埅鍙栨帀鍓�5涓瓧绗︼紝骞跺幓闄ら灏剧┖鐧界
+ text = rag_response[5:].strip()
+ else:
+ # 鍚﹀垯锛屼繚鎸佸師鏍�
+ text = rag_response
+ try:
+ json_data = json.loads(text)
+ data = json_data.get("data")
+ if data is True: # 瀹屾垚杈撳嚭
+ result = {"message": "", "type": "close"}
+ elif data is None: # 鍙戠敓閿欒
+ answer = json_data.get("retmsg", json_data.get("retcode"))
+ result = {"message": "鍐呴儴閿欒锛�" + answer, "type": "message"}
+ else: # 姝e父杈撳嚭
+ answer = data.get("answer", "")
+ result = {"message": answer, "type": "message"}
+ await websocket.send_json(result)
+ except json.JSONDecodeError:
+ print(f"Error decode ragflow response: {text}")
+ pass
+ except Exception as e:
+ result = {"message": f"鍐呴儴閿欒锛� {e}", "type": "close"}
+ await websocket.send_json(result)
+ print(f"Error process message of ragflow: {e}")
- # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风
- async def forward_to_client():
- while True:
- message = await service_websocket.recv()
- print(f"Received from service S: {message}")
- await websocket.send_text(message)
- print(f"Forwarded to client {chat_id}: {message}")
-
- # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭�
+ # 鍚姩浠诲姟澶勭悊瀹㈡埛绔秷鎭�
tasks = [
- asyncio.create_task(forward_to_service()),
- asyncio.create_task(forward_to_client())
+ asyncio.create_task(forward_to_ragflow())
]
-
- done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
- # 鍙栨秷鏈畬鎴愮殑浠诲姟
- for task in pending:
- task.cancel()
- try:
- await task
- except asyncio.CancelledError:
- pass
-
+ await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except WebSocketDisconnect:
print(f"Client {chat_id} disconnected")
- finally:
- del client_websockets[chat_id]
+
+ elif agent_type == AgentType.BISHENG:
+ token = get_bisheng_token(db, current_user.id)
+ service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
+ headers = {'cookie': f"access_token_cookie={token};"}
+
+ async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
+
+ try:
+ # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅
+ async def forward_to_service():
+ while True:
+ message = await websocket.receive_json()
+ print(f"Received from client, {chat_id}: {message}")
+ # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁
+ message['flow_id'] = agent_id
+ message['chat_id'] = chat_id
+ msg = message["message"]
+ del message["message"]
+ message['inputs'] = {
+ "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
+ "input": msg
+ }
+ await service_websocket.send(json.dumps(message))
+ print(f"Forwarded to bisheng: {message}")
+
+ # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风
+ async def forward_to_client():
+ while True:
+ message = await service_websocket.recv()
+ print(f"Received from bisheng: {message}")
+ data = json.loads(message)
+ if data["type"] == "close" or data["type"] == "stream" or data["type"] == "end_cover":
+ if data["type"] == "close":
+ t = "close"
+ else:
+ t = "stream"
+ result = {"message": data["message"], "type": t}
+ await websocket.send_json(result)
+ print(f"Forwarded to client, {chat_id}: {result}")
+
+ # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭�
+ tasks = [
+ asyncio.create_task(forward_to_service()),
+ asyncio.create_task(forward_to_client())
+ ]
+ done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+ # 鍙栨秷鏈畬鎴愮殑浠诲姟
+ for task in pending:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ except WebSocketDisconnect:
+ print(f"Client {chat_id} disconnected")
+ else:
+ ret = {"message": "Agent not found", "type": "close"}
+ await websocket.send_json(ret)
--
Gitblit v1.8.0