From 3fc9f4f33cf90610c71a1de7b00db0f82b988e98 Mon Sep 17 00:00:00 2001 From: zhangqian <zhangqian@123.com> Date: 星期三, 16 十月 2024 23:22:53 +0800 Subject: [PATCH] 文档上传&获取报告生成变量&报告生成接口 --- app/service/bisheng.py | 52 ++++++++-- app/service/ragflow.py | 21 ++- app/api/chat.py | 9 + main.py | 4 app/api/report.py | 105 +++++++++++++++++++++ app/api/agent.py | 10 -- app/api/files.py | 44 ++++++++ 7 files changed, 212 insertions(+), 33 deletions(-) diff --git a/app/api/agent.py b/app/api/agent.py index 2056aec..7423afd 100644 --- a/app/api/agent.py +++ b/app/api/agent.py @@ -16,16 +16,6 @@ router = APIRouter() -# Pydantic 妯″瀷鐢ㄤ簬鍝嶅簲 -class AgentResponse(BaseModel): - id: str - name: str - agent_type: AgentType - - class Config: - orm_mode = True - - @router.get("/list", response_model=ResponseList) async def agent_list(db: Session = Depends(get_db)): agents = db.query(AgentModel).order_by(AgentModel.sort.asc()).all() diff --git a/app/api/chat.py b/app/api/chat.py index 828b7e8..458581a 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -29,11 +29,13 @@ agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() if not agent: ret = {"message": "Agent not found", "type": "close"} - return websocket.send_json(ret) + await websocket.send_json(ret) + return agent_type = agent.agent_type if chat_id == "" or chat_id == "0": ret = {"message": "Chat ID not found", "type": "close"} - return websocket.send_json(ret) + await websocket.send_json(ret) + return if agent_type == AgentType.RAGFLOW: ragflow_service = RagflowService(settings.ragflow_base_url) @@ -50,7 +52,6 @@ if len(chat_history) == 0: result = {"message": "鍐呴儴閿欒锛氬垱寤轰細璇濆け璐�", "type": "close"} await websocket.send_json(result) - continue async for rag_response in ragflow_service.chat(token, chat_id, chat_history): try: print(f"Received from ragflow: {rag_response}") @@ -143,5 +144,5 @@ print(f"Client {chat_id} disconnected") else: ret = {"message": "Agent not found", "type": "close"} - return websocket.send_json(ret) + await websocket.send_json(ret) diff --git a/app/api/files.py b/app/api/files.py new file mode 100644 index 0000000..401428a --- /dev/null +++ b/app/api/files.py @@ -0,0 +1,44 @@ +from fastapi import Depends, APIRouter, HTTPException, UploadFile, File, requests +from sqlalchemy.orm import Session + +from app.api import Response, get_current_user, ResponseList +from app.config.config import settings +from app.models.agent_model import AgentType, AgentModel +from app.models.base_model import get_db +from app.models.user_model import UserModel +from app.service.bisheng import BishengService +from app.service.ragflow import RagflowService +from app.service.token import get_ragflow_token, get_bisheng_token + +router = APIRouter() + + +@router.post("/upload/{agent_id}", response_model=Response) +async def upload_file(agent_id: str, + file: UploadFile = File(...), + db: Session = Depends(get_db), + current_user: UserModel = Depends(get_current_user) + ): + agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() + if not agent: + return Response(code=404, msg="Agent not found") + # 璇诲彇涓婁紶鐨勬枃浠跺唴瀹� + try: + file_content = await file.read() + except Exception as e: + return Response(code=400, msg=str(e)) + + if agent.agent_type == AgentType.RAGFLOW: + pass + + elif agent.agent_type == AgentType.BISHENG: + bisheng_service = BishengService(base_url=settings.bisheng_base_url) + try: + token = get_bisheng_token(db, current_user.id) + result = await bisheng_service.upload(token, file_content) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return Response(code=200, msg="", data=result) + + else: + return Response(code=200, msg="Unsupported agent type") diff --git a/app/api/report.py b/app/api/report.py new file mode 100644 index 0000000..386dcf1 --- /dev/null +++ b/app/api/report.py @@ -0,0 +1,105 @@ +import json + +from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends, HTTPException, Query +import asyncio +import websockets +from sqlalchemy.orm import Session +from app.api import get_current_user_websocket, ResponseList, get_current_user +from app.config.config import settings +from app.models.agent_model import AgentModel, AgentType +from app.models.base_model import get_db +from app.models.user_model import UserModel +from app.service.bisheng import BishengService +from app.service.token import get_bisheng_token + +router = APIRouter() + + +@router.websocket("/ws/{agent_id}/{chat_id}") +async def report_chat(websocket: WebSocket, + agent_id: str, + chat_id: str, + current_user: UserModel = Depends(get_current_user_websocket), + db: Session = Depends(get_db)): + agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() + if not agent: + ret = {"message": "Agent not found", "type": "close"} + return websocket.send_json(ret) + agent_type = agent.agent_type + if chat_id == "" or chat_id == "0": + ret = {"message": "Chat ID not found", "type": "close"} + return websocket.send_json(ret) + + if agent_type != AgentType.BISHENG: + ret = {"message": "Agent error", "type": "close"} + return websocket.send_json(ret) + + token = get_bisheng_token(db, current_user.id) + service_uri = f"{settings.bisheng_websocket_url}/api/v1/chat/{agent_id}?type=L1&t=&chat_id={chat_id}" + headers = {'cookie': f"access_token_cookie={token};"} + + await websocket.accept() + print(f"Client {agent_id} connected") + + async with websockets.connect(service_uri, extra_headers=headers) as service_websocket: + + try: + # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅 + async def forward_to_service(): + while True: + message = await websocket.receive_json() + print(f"Received from client, {chat_id}: {message}") + # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁 + message['flow_id'] = agent_id + message['chat_id'] = chat_id + await service_websocket.send(json.dumps(message)) + print(f"Forwarded to bisheng: {message}") + + # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风 + async def forward_to_client(): + while True: + message = await service_websocket.recv() + print(f"Received from bisheng: {message}") + data = json.loads(message) + files = data.get("files", []) + steps = data.get("intermediate_steps", "") + if len(files) != 0 or steps != "" or data["type"] == "close": + if data["type"] == "close": + t = "close" + else: + t = "stream" + result = {"step_message": steps, "type": t, "files": files} + await websocket.send_json(result) + print(f"Forwarded to client, {chat_id}: {result}") + + # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭� + tasks = [ + asyncio.create_task(forward_to_service()), + asyncio.create_task(forward_to_client()) + ] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # 鍙栨秷鏈畬鎴愮殑浠诲姟 + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + except WebSocketDisconnect: + print(f"Client {chat_id} disconnected") + + +@router.get("/variables/list", response_model=ResponseList) +async def get_variables(agent_id: str = Query(..., description="The ID of the agent"), db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)): + agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first() + if not agent: + return ResponseList(code=404, msg="Agent not found") + bisheng_service = BishengService(base_url=settings.bisheng_base_url) + try: + token = get_bisheng_token(db, current_user.id) + result = await bisheng_service.variable_list(token, agent_id) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return ResponseList(code=200, msg="", data=result) \ No newline at end of file diff --git a/app/service/bisheng.py b/app/service/bisheng.py index 3eb0dfd..e263c65 100644 --- a/app/service/bisheng.py +++ b/app/service/bisheng.py @@ -1,5 +1,4 @@ from datetime import datetime - import httpx from app.config.config import settings @@ -10,6 +9,15 @@ def __init__(self, base_url: str): self.base_url = base_url + def _check_response(self, response: httpx.Response): + if response.status_code not in [200, 201]: + raise Exception(f"Failed to fetch data from Bisheng API: {response.text}") + response_data = response.json() + status_code = response_data.get("status_code", 0) + if status_code != 200: + raise Exception(f"Failed to fetch data from Bisheng API: {response.text}") + return response_data.get("data", {}) + async def register(self, username: str, password: str): public_key = await self.get_public_key_api() password = BishengCrypto(public_key, settings.PRIVATE_KEY).encrypt(password) @@ -19,8 +27,7 @@ json={"user_name": username, "password": password}, headers={'Content-Type': 'application/json'} ) - if response.status_code != 200 and response.status_code != 201: - raise Exception(f"Bisheng registration failed: {response.text}") + self._check_response(response) async def login(self, username: str, password: str) -> str: public_key = await self.get_public_key_api() @@ -31,9 +38,8 @@ json={"user_name": username, "password": password}, headers={'Content-Type': 'application/json'} ) - if response.status_code != 200 and response.status_code != 201: - raise Exception(f"Bisheng login failed: {response.text}") - return response.json().get('data', {}).get('access_token') + data = self._check_response(response) + return data.get('access_token') async def get_public_key_api(self) -> dict: async with httpx.AsyncClient() as client: @@ -41,19 +47,16 @@ f"{self.base_url}/api/v1/user/public_key", headers={'Content-Type': 'application/json'} ) - if response.status_code != 200: - raise Exception(f"Failed to get public key: {response.text}") - return response.json().get('data', {}).get('public_key') + data = self._check_response(response) + return data.get('public_key') async def get_chat_sessions(self, token: str) -> list: url = f"{self.base_url}/api/v1/chat/list?page=1&limit=40" headers = {'cookie': f"access_token_cookie={token};"} async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers) - if response.status_code != 200: - raise Exception(f"Failed to fetch data from Bisheng API: {response.text}") + data = self._check_response(response) - data = response.json().get("data", []) result = [ { "id": item["chat_id"], @@ -63,3 +66,28 @@ for item in data ] return result + + async def variable_list(self, token: str, agent_id: str) -> list: + url = f"{self.base_url}/api/v1/variable/list?flow_id={agent_id}" + headers = {'cookie': f"access_token_cookie={token};"} + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers) + data = self._check_response(response) + return data + + async def upload(self, token: str, file: bytes) -> dict: + url = f"{self.base_url}/api/v1/knowledge/upload" + headers = {'cookie': f"access_token_cookie={token};"} + + # 鍒涘缓琛ㄥ崟鏁版嵁锛屽寘鍚枃浠� + files = {"file": ("file", file)} # 浣跨敤榛樿鏂囦欢鍚� "file" + + async with httpx.AsyncClient() as client: + response = await client.post(url, headers=headers, files=files) + data = self._check_response(response) + file_path = data.get("file_path", "") + result = { + "file_path": file_path + } + + return result diff --git a/app/service/ragflow.py b/app/service/ragflow.py index e4d5657..0fe0c39 100644 --- a/app/service/ragflow.py +++ b/app/service/ragflow.py @@ -87,22 +87,29 @@ "Authorization": token } - data = {"dialog_id": dialog_id, + data = { + "dialog_id": dialog_id, "name": name, "is_new": is_new, "conversation_id": chat_id, - } + } async with httpx.AsyncClient() as client: response = await client.post(url, headers=headers, json=data) if response.status_code != 200: return [] - return [{ - "content": "浣犲ソ锛� 鎴戞槸浣犵殑鍔╃悊锛屾湁浠�涔堝彲浠ュ府鍒颁綘鐨勫悧锛�", - "role": "assistant" - }, + ret_code = response.json().get("retcode") + if ret_code != 0: + return [] + + return [ + { + "content": "浣犲ソ锛� 鎴戞槸浣犵殑鍔╃悊锛屾湁浠�涔堝彲浠ュ府鍒颁綘鐨勫悧锛�", + "role": "assistant" + }, { "content": name, "doc_ids": [], "role": "user" - }] + } + ] diff --git a/main.py b/main.py index d0131b2..2470d4b 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,8 @@ from app.api.chat import router as chat_router from app.api.agent import router as agent_router from app.api.excel import router as excel_router +from app.api.files import router as files_router +from app.api.report import router as report_router from app.models.base_model import init_db init_db() @@ -16,6 +18,8 @@ app.include_router(chat_router, prefix='/api/chat', tags=["chat"]) app.include_router(agent_router, prefix='/api/agent', tags=["agent"]) app.include_router(excel_router, prefix='/api/document', tags=["document"]) +app.include_router(files_router, prefix='/api/files', tags=["files"]) +app.include_router(report_router, prefix='/api/report', tags=["report"]) if __name__ == "__main__": import uvicorn -- Gitblit v1.8.0