From aa99acacfe3c21fbd638652f2fba1c1c62e3c414 Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期五, 11 十月 2024 21:38:55 +0800
Subject: [PATCH] websocket接口,转发毕昇对话
---
app/config/config.py | 2
app/api/chat.py | 92 +++++++++++++++++++++++
main.py | 2
app/service/token.py | 15 +++
app/config/config.yaml | 4
app/api/__init__.py | 62 +++++++++++++++
app/api/auth.py | 14 +--
app/service/auth.py | 2
8 files changed, 181 insertions(+), 12 deletions(-)
diff --git a/app/api/__init__.py b/app/api/__init__.py
index bcd5c2a..5ddd5aa 100644
--- a/app/api/__init__.py
+++ b/app/api/__init__.py
@@ -1,10 +1,70 @@
-from fastapi import FastAPI
+import jwt
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from passlib.context import CryptContext
from pydantic import BaseModel
+from starlette import status
+from starlette.websockets import WebSocket, WebSocketDisconnect
+
+from app.models.user_model import UserModel
+from app.service.auth import SECRET_KEY, ALGORITHM
app = FastAPI()
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Response(BaseModel):
code: int = 200
msg: str = ""
data: dict = {}
+
+
+def get_current_user(token: str = Depends(oauth2_scheme)):
+ try:
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+ username: str = payload.get("sub")
+ if username is None:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="鏃犳硶楠岃瘉鍑瘉",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ user = UserModel(username=username, id=payload.get("user_id"))
+ if user.id == 0:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="鐢ㄦ埛涓嶅瓨鍦�",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ return user
+ except jwt.PyJWTError:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="浠ょ墝鏃犳晥鎴栧凡杩囨湡",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+
+async def get_current_user_websocket(websocket: WebSocket):
+ auth_header = websocket.headers.get('Authorization')
+ if auth_header is None or not auth_header.startswith('Bearer '):
+ await websocket.close(code=1008)
+ raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+ token = auth_header[len('Bearer '):]
+ try:
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+ username: str = payload.get("sub")
+ if username is None:
+ await websocket.close(code=1008)
+ raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+ user = UserModel(username=username, id=payload.get("user_id"))
+ if user is None:
+ await websocket.close(code=1008)
+ raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+ return user
+ except jwt.PyJWTError as e:
+ print(e)
+ await websocket.close(code=1008)
+ raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
diff --git a/app/api/auth.py b/app/api/auth.py
index 86da42f..feef6d9 100644
--- a/app/api/auth.py
+++ b/app/api/auth.py
@@ -1,16 +1,13 @@
-from typing import Dict
-import json
-
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from sqlalchemy.orm import Session
-from app.api import Response
+from app.api import Response, pwd_context, oauth2_scheme, get_current_user
from app.config.config import settings
from app.models.base_model import get_db
from app.models.token_model import upsert_token
-from app.models.user import User, UserCreate, LoginData
+from app.models.user import UserCreate, LoginData
from app.models.user_model import UserModel
from app.service.auth import authenticate_user, create_access_token
from app.service.bisheng import BishengService
@@ -18,8 +15,7 @@
router = APIRouter()
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
@router.post("/register", response_model=Response)
@@ -74,7 +70,7 @@
return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}")
# 鍒涘缓鏈湴token
- access_token = create_access_token(data={"sub": user.username})
+ access_token = create_access_token(data={"sub": user.username, "user_id": user.id})
upsert_token(db, user.id, access_token, bisheng_token, ragflow_token)
diff --git a/app/api/chat.py b/app/api/chat.py
new file mode 100644
index 0000000..c7aa2da
--- /dev/null
+++ b/app/api/chat.py
@@ -0,0 +1,92 @@
+import json
+import uuid
+
+from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
+import asyncio
+import websockets
+from sqlalchemy.orm import Session
+from app.api import get_current_user_websocket
+from app.config.config import settings
+from app.models.base_model import get_db
+from app.models.user_model import UserModel
+from app.service.token import get_bisheng_token
+
+router = APIRouter()
+
+# 瀛樺偍瀹㈡埛绔� WebSocket 杩炴帴
+client_websockets = {}
+
+
+# 涓棿灞俉ebSocket 鏈嶅姟鍣紝鎺ユ敹瀹㈡埛绔殑杩炴帴
+@router.websocket("/ws/{agent_id}/{chat_id}")
+async def handle_client(websocket: WebSocket,
+ agent_id: str,
+ chat_id: str,
+ current_user: UserModel = Depends(get_current_user_websocket),
+ db: Session = Depends(get_db)):
+ await websocket.accept()
+ print(f"Client {agent_id} connected")
+
+ token = get_bisheng_token(db, current_user.id)
+
+ if agent_id == "0":
+ agent_id = settings.bisheng_agent_id
+ if chat_id == "0":
+ chat_id = uuid.uuid4().hex
+
+ # 杩炴帴鍒版湇鍔$
+ service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
+ headers = {
+ 'cookie': f"access_token_cookie={token};"
+ }
+
+ async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
+ client_websockets[chat_id] = websocket
+
+ try:
+ # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅
+ async def forward_to_service():
+ while True:
+ message = await websocket.receive_json()
+ print(f"Received from client {chat_id}: {message}")
+ # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁
+ message['flow_id'] = agent_id
+ message['chat_id'] = chat_id
+ msg = message["message"]
+ del message["message"]
+ message['inputs'] = {
+ "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
+ "input": msg
+ }
+ await service_websocket.send(json.dumps(message))
+ print(f"Forwarded to bisheng: {message}")
+
+
+ # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风
+ async def forward_to_client():
+ while True:
+ message = await service_websocket.recv()
+ print(f"Received from service S: {message}")
+ await websocket.send_text(message)
+ print(f"Forwarded to client {chat_id}: {message}")
+
+ # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭�
+ tasks = [
+ asyncio.create_task(forward_to_service()),
+ asyncio.create_task(forward_to_client())
+ ]
+
+ done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+ # 鍙栨秷鏈畬鎴愮殑浠诲姟
+ for task in pending:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ except WebSocketDisconnect:
+ print(f"Client {chat_id} disconnected")
+ finally:
+ del client_websockets[chat_id]
diff --git a/app/config/config.py b/app/config/config.py
index 8044029..0833ef2 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -6,10 +6,12 @@
class Settings:
secret_key: str = ''
bisheng_base_url: str = ''
+ bisheng_websocket_url: str = ''
ragflow_base_url: str = ''
database_url: str = ''
PUBLIC_KEY: str
PRIVATE_KEY: str
+ bisheng_agent_id: str
def __init__(self, **kwargs):
# Check if all required fields are provided and set them
diff --git a/app/config/config.yaml b/app/config/config.yaml
index bb7328d..db7638b 100644
--- a/app/config/config.yaml
+++ b/app/config/config.yaml
@@ -1,9 +1,11 @@
secret_key: your-secret-key
bisheng_base_url: http://192.168.20.119:13001
+bisheng_websocket_url: ws://192.168.20.119:13001
ragflow_base_url: http://192.168.20.119:11080
database_url: mysql+pymysql://root:infini_rag_flow@192.168.20.116:5455/rag_basic
PUBLIC_KEY: |
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB
-----END PUBLIC KEY-----
-PRIVATE_KEY: str
\ No newline at end of file
+PRIVATE_KEY: str
+bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770
\ No newline at end of file
diff --git a/app/service/auth.py b/app/service/auth.py
index 3ebccb1..09a4917 100644
--- a/app/service/auth.py
+++ b/app/service/auth.py
@@ -8,7 +8,7 @@
SECRET_KEY = settings.secret_key
ALGORITHM = "HS256"
-ACCESS_TOKEN_EXPIRE_MINUTES = 30
+ACCESS_TOKEN_EXPIRE_MINUTES = 3000
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
diff --git a/app/service/token.py b/app/service/token.py
new file mode 100644
index 0000000..9c54fed
--- /dev/null
+++ b/app/service/token.py
@@ -0,0 +1,15 @@
+from app.models.token_model import TokenModel
+
+
+def get_bisheng_token(db, user_id: int):
+ token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
+ if not token:
+ return None
+ return token.bisheng_token
+
+
+def get_ragflow_token(db, user_id: int):
+ token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
+ if not token:
+ return None
+ return token.ragflow_token
diff --git a/main.py b/main.py
index c6fb2cd..95945d6 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
from fastapi import FastAPI
from app.api.auth import router as auth_router
+from app.api.chat import router as chat_router
from app.models.base_model import init_db
init_db()
@@ -10,6 +11,7 @@
)
app.include_router(auth_router, prefix='/auth', tags=["auth"])
+app.include_router(chat_router, prefix='/chat', tags=["chat"])
if __name__ == "__main__":
import uvicorn
--
Gitblit v1.8.0