app/api/auth.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
app/api/chat.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
app/config/config.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
app/config/config.yaml | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
app/service/auth.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
app/service/ragflow.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
app/api/auth.py
@@ -1,9 +1,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.security import OAuth2PasswordBearer from passlib.context import CryptContext from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from app.api import Response, pwd_context, oauth2_scheme, get_current_user from app.api import Response, pwd_context from app.config.config import settings from app.models.base_model import get_db from app.models.token_model import upsert_token @@ -14,8 +12,6 @@ from app.service.ragflow import RagflowService router = APIRouter() @router.post("/register", response_model=Response) app/api/chat.py
@@ -1,7 +1,7 @@ import json import uuid from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Depends import asyncio import websockets from sqlalchemy.orm import Session @@ -9,7 +9,8 @@ from app.config.config import settings from app.models.base_model import get_db from app.models.user_model import UserModel from app.service.token import get_bisheng_token from app.service.ragflow import RagflowService from app.service.token import get_bisheng_token, get_ragflow_token router = APIRouter() @@ -27,21 +28,52 @@ await websocket.accept() print(f"Client {agent_id} connected") token = get_bisheng_token(db, current_user.id) if agent_id == "0": agent_id = settings.bisheng_agent_id elif agent_id == "1": agent_id = settings.ragflow_agent_id chat_id = settings.ragflow_chat_id if chat_id == "0": chat_id = uuid.uuid4().hex # 连接到服务端 client_websockets[chat_id] = websocket if agent_id == settings.ragflow_agent_id: ragflow_service = RagflowService(settings.ragflow_base_url) token = get_ragflow_token(db, current_user.id) try: async def forward_to_ragflow(): while True: message = await websocket.receive_json() print(f"Received from client {chat_id}: {message}") async for rag_response in ragflow_service.chat(token, chat_id, message["chatHistory"]): print(f"Received from ragflow: {rag_response}") json_str = rag_response[5:].strip() json_data = json.loads(json_str) if json_data.get("data") is not True: answer = json_data.get("data", {}).get("answer", "") result = {"message": answer, "type": "stream"} else: result = {"message": "", "type": "close"} await websocket.send_json(result) print(f"Forwarded to client {chat_id}: {result}") # 启动任务处理客户端消息 tasks = [ asyncio.create_task(forward_to_ragflow()) ] await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) except WebSocketDisconnect: print(f"Client {chat_id} disconnected") finally: del client_websockets[chat_id] else: token = get_bisheng_token(db, current_user.id) service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}" headers = { 'cookie': f"access_token_cookie={token};" } headers = {'cookie': f"access_token_cookie={token};"} async with websockets.connect(service_uri, extra_headers=headers) as service_websocket: client_websockets[chat_id] = websocket try: # 处理客户端发来的消息 @@ -61,7 +93,6 @@ await service_websocket.send(json.dumps(message)) print(f"Forwarded to bisheng: {message}") # 监听毕昇发来的消息并转发给客户端 async def forward_to_client(): while True: @@ -75,7 +106,6 @@ asyncio.create_task(forward_to_service()), asyncio.create_task(forward_to_client()) ] done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) # 取消未完成的任务 @@ -90,3 +120,5 @@ print(f"Client {chat_id} disconnected") finally: del client_websockets[chat_id] app/config/config.py
@@ -11,7 +11,8 @@ PUBLIC_KEY: str PRIVATE_KEY: str bisheng_agent_id: str ragflow_agent_id: str ragflow_chat_id: str def __init__(self, **kwargs): # Check if all required fields are provided and set them for field in self.__annotations__.keys(): app/config/config.yaml
@@ -9,3 +9,5 @@ -----END PUBLIC KEY----- PRIVATE_KEY: str bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770 ragflow_agent_id: 690f42554ac84ed7b8bf7605db603b2f ragflow_chat_id: e1d131a1b89b488e97c2194d9e2d345c app/service/auth.py
@@ -35,7 +35,7 @@ if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=15) expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({"exp": expire}) encoded_jwt = encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt app/service/ragflow.py
@@ -29,4 +29,30 @@ ) if response.status_code != 200: raise Exception(f"Ragflow login failed: {response.text}") return response.json().get('data', {}).get('access_token') # 从响应头中提取 Authorization 字段 authorization = response.headers.get('Authorization') if not authorization: raise Exception("Authorization header not found in response") return authorization async def chat(self, token: str, chat_id: str, chat_history: list): data = { "conversation_id": chat_id, "messages": chat_history } target_url = f"{self.base_url}/v1/conversation/completion" async with httpx.AsyncClient() as client: headers = { 'Content-Type': 'application/json', 'Authorization': token } # 创建流式请求 async with client.stream("POST", target_url, json=data, headers=headers) as response: # 检查响应状态码 if response.status_code == 200: # 流式读取响应 async for answer in response.aiter_text(): yield answer else: yield f"Error: {response.status_code}"