zhangqian
2024-11-19 f37670f13f8faf018a87d5b73b662bb1909ebe87
app/api/chat.py
@@ -1,20 +1,21 @@
import json
import uuid
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends
import asyncio
import websockets
from sqlalchemy.orm import Session
from app.api import get_current_user_websocket
from app.config.config import settings
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.token import get_bisheng_token
from app.service.dialog import update_session_history
from app.service.basic import BasicService
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
router = APIRouter()
# Module-level registry of connected client WebSockets, keyed by chat_id.
client_websockets = dict()
# 中间层WebSocket 服务器,接收客户端的连接
@@ -24,69 +25,217 @@
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    tasks = []
    await websocket.accept()
    print(f"Client {agent_id} connected")
    token = get_bisheng_token(db, current_user.id)
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    agent_type = agent.agent_type
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
        return
    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex
    # 连接到服务端
    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
    headers = {
        'cookie':  f"access_token_cookie={token};"
    }
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket
    if agent_type == AgentType.RAGFLOW:
        ragflow_service = RagflowService(settings.fwr_base_url)
        token = get_ragflow_token(db, current_user.id)
        try:
            # 处理客户端发来的消息
            async def forward_to_service():
            async def forward_to_ragflow():
                while True:
                    is_new = False
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    # 添加 'agent_id' 和 'chat_id' 字段
                    message['flow_id'] = agent_id
                    message['chat_id'] = chat_id
                    msg = message["message"]
                    del message["message"]
                    message['inputs'] = {
                        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                        "input": msg
                    }
                    await service_websocket.send(json.dumps(message))
                    print(f"Forwarded to bisheng: {message}")
                    chat_history = message.get('chatHistory', [])
                    message["role"] = "user"
                    if len(chat_history) == 0:
                        chat_history = await ragflow_service.get_session_history(token, chat_id)
                        if len(chat_history) == 0:
                            is_new = True
                            chat_history = await ragflow_service.set_session(token, agent_id,
                                                                             message, chat_id, True)
                            # print("chat_history------------------------", chat_history)
                            if len(chat_history) == 0:
                                result = {"message": "内部错误:创建会话失败", "type": "close"}
                                await websocket.send_json(result)
                                await websocket.close()
                                return
                        else:
                            chat_history.append({
                                "content": message["message"],
                                "doc_ids": message.get("doc_ids", []),
                                "role": "user"
                            })
                    complete_response = ""
                    async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
                        try:
                            if rag_response[:5] == "data:":
                                # 如果是,则截取掉前5个字符,并去除首尾空白符
                                text = rag_response[5:].strip()
                            else:
                                # 否则,保持原样
                                text = rag_response
                            complete_response += text
                            try:
                                json_data = json.loads(complete_response)
                                data = json_data.get("data")
                                if data is True:  # 完成输出
                                    result = {"message": "", "type": "close"}
                                elif data is None:  # 发生错误
                                    answer = json_data.get("retmsg", json_data.get("retcode"))
                                    result = {"message": "内部错误:" + answer, "type": "message"}
                                else:  # 正常输出
                                    answer = data.get("answer", "")
                                    reference = data.get("reference", {})
                                    result = {"message": answer, "type": "message", "reference": reference}
                                await websocket.send_json(result)
                                complete_response = ""
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                # print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of ragflow: {e2}")
                    dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1)
                    await update_session_history(db, dialog_chat_history, current_user.id, is_new)
            # 监听毕昇发来的消息并转发给客户端
            async def forward_to_client():
                while True:
                    message = await service_websocket.recv()
                    print(f"Received from service S: {message}")
                    await websocket.send_text(message)
                    print(f"Forwarded to client {chat_id}: {message}")
            # 启动两个任务,分别处理客户端和服务端的消息
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_service()),
                asyncio.create_task(forward_to_client())
                asyncio.create_task(forward_to_ragflow())
            ]
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except WebSocketDisconnect as e1:
            print(f"Client {chat_id} disconnected: {e1}")
            await websocket.close()
        except Exception as e:
            print(f"Exception occurred: {e}")
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # 取消未完成的任务
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        finally:
            del client_websockets[chat_id]
            print("Cleaning up resources of ragflow")
            # 取消所有任务
            for task in tasks:
                if not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
    elif agent_type == AgentType.BISHENG:
        token = get_bisheng_token(db, current_user.id)
        service_uri = f"{settings.sgb_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
        headers = {'cookie': f"access_token_cookie={token};"}
        async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
            try:
                # 处理客户端发来的消息
                async def forward_to_service():
                    while True:
                        message = await websocket.receive_json()
                        print(f"Received from client, {chat_id}: {message}")
                        # 添加 'agent_id' 和 'chat_id' 字段
                        message['flow_id'] = agent_id
                        message['chat_id'] = chat_id
                        msg = message["message"]
                        del message["message"]
                        message['inputs'] = {
                            "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                            "input": msg
                        }
                        await service_websocket.send(json.dumps(message))
                        print(f"Forwarded to bisheng: {message}")
                # 监听毕昇发来的消息并转发给客户端
                async def forward_to_client():
                    while True:
                        message = await service_websocket.recv()
                        print(f"Received from bisheng: {message}")
                        data = json.loads(message)
                        if data["type"] == "close" or data["type"] == "stream" or data["type"] == "end_cover":
                            if data["type"] == "close":
                                t = "close"
                            else:
                                t = "stream"
                            result = {"message": data["message"], "type": t}
                            await websocket.send_json(result)
                            print(f"Forwarded to client, {chat_id}: {result}")
                # 启动两个任务,分别处理客户端和服务端的消息
                tasks = [
                    asyncio.create_task(forward_to_service()),
                    asyncio.create_task(forward_to_client())
                ]
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                # 取消未完成的任务
                for task in pending:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            except WebSocketDisconnect as e:
                print(f"WebSocket connection closed with code {e.code}: {e.reason}")
                await websocket.close()
                await service_websocket.close()
            except Exception as e:
                print(f"Exception occurred: {e}")
            finally:
                print("Cleaning up resources of bisheng")
                # 取消所有任务
                for task in tasks:
                    if not task.done():
                        task.cancel()
                        try:
                            await task
                        except asyncio.CancelledError:
                            pass
    elif agent_type == AgentType.BASIC:
        try:
            while True:
                # 接收前端消息
                message = await websocket.receive_json()
                question = message.get("message")
                if not question:
                    await websocket.send_json({"message": "Invalid request", "type": "error"})
                    continue
                service = BasicService(base_url=settings.basic_base_url)
                complete_response = ""
                async for result in service.excel_talk(question, chat_id):
                    try:
                        if result[:5] == "data:":
                            # 如果是,则截取掉前5个字符,并去除首尾空白符
                            text = result[5:].strip()
                        else:
                            # 否则,保持原样
                            text = result
                        complete_response += text
                        try:
                            json_data = json.loads(complete_response)
                            output = json_data.get("output", "")
                            result = {"message": output, "type": "message"}
                            await websocket.send_json(result | json_data)
                            complete_response = ""
                        except json.JSONDecodeError as e:
                            print(f"Error decoding JSON: {e}")
                            print(f"Response text: {text}")
                    except Exception as e2:
                        result = {"message": f"内部错误: {e2}", "type": "close"}
                        await websocket.send_json(result)
                        print(f"Error process message of basic agent: {e2}")
        except Exception as e:
            await websocket.send_json({"message": str(e), "type": "error"})
        finally:
            await websocket.close()
            print(f"Client {agent_id} disconnected")
    else:
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)