From aa99acacfe3c21fbd638652f2fba1c1c62e3c414 Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期五, 11 十月 2024 21:38:55 +0800
Subject: [PATCH] websocket接口,转发毕昇对话

---
 app/config/config.py   |    2 
 app/api/chat.py        |   92 +++++++++++++++++++++++
 main.py                |    2 
 app/service/token.py   |   15 +++
 app/config/config.yaml |    4 
 app/api/__init__.py    |   62 +++++++++++++++
 app/api/auth.py        |   14 +--
 app/service/auth.py    |    2 
 8 files changed, 181 insertions(+), 12 deletions(-)

diff --git a/app/api/__init__.py b/app/api/__init__.py
index bcd5c2a..5ddd5aa 100644
--- a/app/api/__init__.py
+++ b/app/api/__init__.py
@@ -1,10 +1,70 @@
-from fastapi import FastAPI
+import jwt
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from passlib.context import CryptContext
 from pydantic import BaseModel
+from starlette import status
+from starlette.websockets import WebSocket, WebSocketDisconnect
+
+from app.models.user_model import UserModel
+from app.service.auth import SECRET_KEY, ALGORITHM
 
 app = FastAPI()
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
 
 class Response(BaseModel):
     code: int = 200
     msg: str = ""
     data: dict = {}
+
+
+def get_current_user(token: str = Depends(oauth2_scheme)):
+    try:
+        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        username: str = payload.get("sub")
+        if username is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="鏃犳硶楠岃瘉鍑瘉",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+        user = UserModel(username=username, id=payload.get("user_id"))
+        if user.id == 0:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="鐢ㄦ埛涓嶅瓨鍦�",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+        return user
+    except jwt.PyJWTError:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="浠ょ墝鏃犳晥鎴栧凡杩囨湡",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+
+async def get_current_user_websocket(websocket: WebSocket):
+    auth_header = websocket.headers.get('Authorization')
+    if auth_header is None or not auth_header.startswith('Bearer '):
+        await websocket.close(code=1008)
+        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+    token = auth_header[len('Bearer '):]
+    try:
+        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        username: str = payload.get("sub")
+        if username is None:
+            await websocket.close(code=1008)
+            raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+        user = UserModel(username=username, id=payload.get("user_id"))
+        if user is None:
+            await websocket.close(code=1008)
+            raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
+        return user
+    except jwt.PyJWTError as e:
+        print(e)
+        await websocket.close(code=1008)
+        raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)
diff --git a/app/api/auth.py b/app/api/auth.py
index 86da42f..feef6d9 100644
--- a/app/api/auth.py
+++ b/app/api/auth.py
@@ -1,16 +1,13 @@
-from typing import Dict
-import json
-
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, Request
 from fastapi.security import OAuth2PasswordBearer
 from passlib.context import CryptContext
 from sqlalchemy.orm import Session
 
-from app.api import Response
+from app.api import Response, pwd_context, oauth2_scheme, get_current_user
 from app.config.config import settings
 from app.models.base_model import get_db
 from app.models.token_model import upsert_token
-from app.models.user import User, UserCreate, LoginData
+from app.models.user import UserCreate, LoginData
 from app.models.user_model import UserModel
 from app.service.auth import authenticate_user, create_access_token
 from app.service.bisheng import BishengService
@@ -18,8 +15,7 @@
 
 router = APIRouter()
 
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
 
 
 @router.post("/register", response_model=Response)
@@ -74,7 +70,7 @@
         return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}")
 
     # 鍒涘缓鏈湴token
-    access_token = create_access_token(data={"sub": user.username})
+    access_token = create_access_token(data={"sub": user.username, "user_id": user.id})
 
     upsert_token(db, user.id, access_token, bisheng_token, ragflow_token)
 
diff --git a/app/api/chat.py b/app/api/chat.py
new file mode 100644
index 0000000..c7aa2da
--- /dev/null
+++ b/app/api/chat.py
@@ -0,0 +1,92 @@
+import json
+import uuid
+
+from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends
+import asyncio
+import websockets
+from sqlalchemy.orm import Session
+from app.api import get_current_user_websocket
+from app.config.config import settings
+from app.models.base_model import get_db
+from app.models.user_model import UserModel
+from app.service.token import get_bisheng_token
+
+router = APIRouter()
+
+# 瀛樺偍瀹㈡埛绔� WebSocket 杩炴帴
+client_websockets = {}
+
+
+# 涓棿灞俉ebSocket 鏈嶅姟鍣紝鎺ユ敹瀹㈡埛绔殑杩炴帴
+@router.websocket("/ws/{agent_id}/{chat_id}")
+async def handle_client(websocket: WebSocket,
+                        agent_id: str,
+                        chat_id: str,
+                        current_user: UserModel = Depends(get_current_user_websocket),
+                        db: Session = Depends(get_db)):
+    await websocket.accept()
+    print(f"Client {agent_id} connected")
+
+    token = get_bisheng_token(db, current_user.id)
+
+    if agent_id == "0":
+        agent_id = settings.bisheng_agent_id
+    if chat_id == "0":
+        chat_id = uuid.uuid4().hex
+
+    # 杩炴帴鍒版湇鍔$
+    service_uri = f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}?t=&chat_id={chat_id}"
+    headers = {
+        'cookie':  f"access_token_cookie={token};"
+    }
+
+    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
+        client_websockets[chat_id] = websocket
+
+        try:
+            # 澶勭悊瀹㈡埛绔彂鏉ョ殑娑堟伅
+            async def forward_to_service():
+                while True:
+                    message = await websocket.receive_json()
+                    print(f"Received from client {chat_id}: {message}")
+                    # 娣诲姞 'agent_id' 鍜� 'chat_id' 瀛楁
+                    message['flow_id'] = agent_id
+                    message['chat_id'] = chat_id
+                    msg = message["message"]
+                    del message["message"]
+                    message['inputs'] = {
+                        "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
+                        "input": msg
+                    }
+                    await service_websocket.send(json.dumps(message))
+                    print(f"Forwarded to bisheng: {message}")
+
+
+            # 鐩戝惉姣曟槆鍙戞潵鐨勬秷鎭苟杞彂缁欏鎴风
+            async def forward_to_client():
+                while True:
+                    message = await service_websocket.recv()
+                    print(f"Received from service S: {message}")
+                    await websocket.send_text(message)
+                    print(f"Forwarded to client {chat_id}: {message}")
+
+            # 鍚姩涓や釜浠诲姟锛屽垎鍒鐞嗗鎴风鍜屾湇鍔$鐨勬秷鎭�
+            tasks = [
+                asyncio.create_task(forward_to_service()),
+                asyncio.create_task(forward_to_client())
+            ]
+
+            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+            # 鍙栨秷鏈畬鎴愮殑浠诲姟
+            for task in pending:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+        except WebSocketDisconnect:
+            print(f"Client {chat_id} disconnected")
+        finally:
+            del client_websockets[chat_id]
diff --git a/app/config/config.py b/app/config/config.py
index 8044029..0833ef2 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -6,10 +6,12 @@
 class Settings:
     secret_key: str = ''
     bisheng_base_url: str = ''
+    bisheng_websocket_url: str = ''
     ragflow_base_url: str = ''
     database_url: str = ''
     PUBLIC_KEY: str
     PRIVATE_KEY: str
+    bisheng_agent_id: str
 
     def __init__(self, **kwargs):
         # Check if all required fields are provided and set them
diff --git a/app/config/config.yaml b/app/config/config.yaml
index bb7328d..db7638b 100644
--- a/app/config/config.yaml
+++ b/app/config/config.yaml
@@ -1,9 +1,11 @@
 secret_key: your-secret-key
 bisheng_base_url: http://192.168.20.119:13001
+bisheng_websocket_url: ws://192.168.20.119:13001
 ragflow_base_url: http://192.168.20.119:11080
 database_url: mysql+pymysql://root:infini_rag_flow@192.168.20.116:5455/rag_basic
 PUBLIC_KEY: |
   -----BEGIN PUBLIC KEY-----
   MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB
   -----END PUBLIC KEY-----
-PRIVATE_KEY: str
\ No newline at end of file
+PRIVATE_KEY: str
+bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770
\ No newline at end of file
diff --git a/app/service/auth.py b/app/service/auth.py
index 3ebccb1..09a4917 100644
--- a/app/service/auth.py
+++ b/app/service/auth.py
@@ -8,7 +8,7 @@
 
 SECRET_KEY = settings.secret_key
 ALGORITHM = "HS256"
-ACCESS_TOKEN_EXPIRE_MINUTES = 30
+ACCESS_TOKEN_EXPIRE_MINUTES = 3000
 
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
diff --git a/app/service/token.py b/app/service/token.py
new file mode 100644
index 0000000..9c54fed
--- /dev/null
+++ b/app/service/token.py
@@ -0,0 +1,15 @@
+from app.models.token_model import TokenModel
+
+
+def get_bisheng_token(db, user_id: int):
+    token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
+    if not token:
+        return None
+    return token.bisheng_token
+
+
+def get_ragflow_token(db, user_id: int):
+    token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
+    if not token:
+        return None
+    return token.ragflow_token
diff --git a/main.py b/main.py
index c6fb2cd..95945d6 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 from fastapi import FastAPI
 from app.api.auth import router as auth_router
+from app.api.chat import router as chat_router
 from app.models.base_model import init_db
 
 init_db()
@@ -10,6 +11,7 @@
 )
 
 app.include_router(auth_router, prefix='/auth', tags=["auth"])
+app.include_router(chat_router, prefix='/chat', tags=["chat"])
 
 if __name__ == "__main__":
     import uvicorn

--
Gitblit v1.8.0