| | |
| | | from fastapi import FastAPI |
| | | import jwt |
| | | from fastapi import FastAPI, Depends, HTTPException |
| | | from fastapi.security import OAuth2PasswordBearer |
| | | from passlib.context import CryptContext |
| | | from pydantic import BaseModel |
| | | from starlette import status |
| | | from starlette.websockets import WebSocket, WebSocketDisconnect |
| | | |
| | | from app.models.user_model import UserModel |
| | | from app.service.auth import SECRET_KEY, ALGORITHM |
| | | |
# Shared auth helpers: the OAuth2 bearer-token extractor (tokenUrl is the
# URL advertised in the OpenAPI schema for obtaining a token) and the bcrypt
# password hasher.
# NOTE(review): main mounts the auth router under '/auth', so "token" may not
# match the real token endpoint — confirm.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

app = FastAPI()
| | | |
| | | |
class Response(BaseModel):
    """Uniform JSON envelope returned by the API routes.

    Used as ``response_model`` on router endpoints; errors are reported by
    returning e.g. ``Response(code=500, msg=...)``.
    """
    # Application-level status code; defaults to 200 (success).
    code: int = 200
    # Human-readable message, e.g. the error text on failure.
    msg: str = ""
    # Payload. pydantic copies non-hashable defaults per instance, so the
    # shared {} literal is not mutated across responses.
    data: dict = {}
| | | |
| | | |
def _credentials_exception(detail: str) -> HTTPException:
    """Build the 401 error shared by every bearer-token authentication failure.

    The ``WWW-Authenticate: Bearer`` header tells the client which auth
    scheme is expected.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the authenticated user from a Bearer JWT (FastAPI dependency).

    The token is extracted by ``oauth2_scheme`` and decoded with
    ``SECRET_KEY`` / ``ALGORITHM``. The ``sub`` claim carries the username
    and the ``user_id`` claim carries the user's id; no database lookup is
    performed.

    Args:
        token: Raw bearer token supplied by ``oauth2_scheme``.

    Returns:
        A ``UserModel`` built from the token claims.

    Raises:
        HTTPException: 401 if the token is invalid or expired, has no
            ``sub`` claim, or has a missing/zero ``user_id`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise _credentials_exception("令牌无效或已过期") from None

    username = payload.get("sub")
    if username is None:
        raise _credentials_exception("无法验证凭证")

    user_id = payload.get("user_id")
    # Tokens issued without a user_id claim (the older create_access_token
    # call only set "sub") would otherwise yield UserModel(id=None), which
    # slipped past the previous `== 0` check and reached per-user lookups.
    if not user_id:
        raise _credentials_exception("用户不存在")
    return UserModel(username=username, id=user_id)
| | | |
| | | |
async def _close_policy_violation(websocket: WebSocket) -> WebSocketDisconnect:
    """Close ``websocket`` with code 1008 and return the exception to raise.

    Returned rather than raised so call sites read
    ``raise await _close_policy_violation(ws)`` and the abort stays visible.
    """
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    return WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)


async def get_current_user_websocket(websocket: WebSocket):
    """Authenticate a WebSocket client from its ``Authorization: Bearer`` header.

    WebSocket counterpart of ``get_current_user``: decodes the JWT with
    ``SECRET_KEY`` / ``ALGORITHM`` and builds a ``UserModel`` from the
    ``sub`` and ``user_id`` claims. On any failure the socket is closed with
    code 1008 (policy violation) before the endpoint body runs.

    NOTE(review): browser WebSocket clients cannot set custom request
    headers, so this only works for non-browser clients or a proxy that
    injects the header — confirm that matches the intended callers.

    Args:
        websocket: The incoming (not yet accepted) client connection.

    Returns:
        A ``UserModel`` built from the token claims.

    Raises:
        WebSocketDisconnect: code 1008 when the header is missing or not a
            Bearer token, the token is invalid/expired, ``sub`` is missing,
            or ``user_id`` is missing/zero.
    """
    auth_header = websocket.headers.get('Authorization')
    if auth_header is None or not auth_header.startswith('Bearer '):
        raise await _close_policy_violation(websocket)
    token = auth_header[len('Bearer '):]

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        print(e)
        raise await _close_policy_violation(websocket)

    username = payload.get("sub")
    user_id = payload.get("user_id")
    # A constructed UserModel is never None, so the old `user is None` test
    # could not fire; check the claims themselves instead. Like
    # get_current_user, a user_id of 0 is rejected, and so is a missing one.
    if username is None or not user_id:
        raise await _close_policy_violation(websocket)
    return UserModel(username=username, id=user_id)
| | |
| | | from typing import Dict |
| | | import json |
| | | |
| | | from fastapi import APIRouter, Depends, HTTPException |
| | | from fastapi import APIRouter, Depends, Request |
| | | from fastapi.security import OAuth2PasswordBearer |
| | | from passlib.context import CryptContext |
| | | from sqlalchemy.orm import Session |
| | | |
| | | from app.api import Response |
| | | from app.api import Response, pwd_context, oauth2_scheme, get_current_user |
| | | from app.config.config import settings |
| | | from app.models.base_model import get_db |
| | | from app.models.token_model import upsert_token |
| | | from app.models.user import User, UserCreate, LoginData |
| | | from app.models.user import UserCreate, LoginData |
| | | from app.models.user_model import UserModel |
| | | from app.service.auth import authenticate_user, create_access_token |
| | | from app.service.bisheng import BishengService |
| | |
| | | |
router = APIRouter()

# NOTE(review): pwd_context and oauth2_scheme are also imported from app.api
# in this module's import list; these assignments rebind (shadow) those names
# with identically configured instances. Keep only one source once the
# import list is settled.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
| | | |
| | | |
| | | |
| | | @router.post("/register", response_model=Response) |
| | |
| | | return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}") |
| | | |
| | | # 创建本地token |
| | | access_token = create_access_token(data={"sub": user.username}) |
| | | access_token = create_access_token(data={"sub": user.username, "user_id": user.id}) |
| | | |
| | | upsert_token(db, user.id, access_token, bisheng_token, ragflow_token) |
| | | |
New file |
| | |
| | | import json |
| | | import uuid |
| | | |
| | | from fastapi import WebSocket, WebSocketDisconnect, APIRouter, Request, Depends |
| | | import asyncio |
| | | import websockets |
| | | from sqlalchemy.orm import Session |
| | | from app.api import get_current_user_websocket |
| | | from app.config.config import settings |
| | | from app.models.base_model import get_db |
| | | from app.models.user_model import UserModel |
| | | from app.service.token import get_bisheng_token |
| | | |
router = APIRouter()

# Registry of connected client WebSockets, keyed by chat_id; entries are
# added and removed by handle_client.
client_websockets = {}
| | | |
| | | |
@router.websocket("/ws/{agent_id}/{chat_id}")
async def handle_client(websocket: WebSocket,
                        agent_id: str,
                        chat_id: str,
                        current_user: UserModel = Depends(get_current_user_websocket),
                        db: Session = Depends(get_db)):
    """Middle-layer WebSocket endpoint relaying a client session to Bisheng.

    Accepts the client, opens an upstream connection to
    ``{settings.bisheng_websocket_url}/api/v1/assistant/chat/{agent_id}``
    authenticated with the user's stored Bisheng token (sent as the
    ``access_token_cookie`` cookie), then pumps messages in both directions
    until either side stops.

    Args:
        websocket: The client connection.
        agent_id: Bisheng assistant id; ``"0"`` selects
            ``settings.bisheng_agent_id``.
        chat_id: Conversation id; ``"0"`` starts a new one with a random hex id.
        current_user: User resolved by ``get_current_user_websocket``.
        db: Database session used to look up the user's Bisheng token.
    """
    from urllib.parse import quote

    await websocket.accept()
    print(f"Client {agent_id} connected")

    token = get_bisheng_token(db, current_user.id)
    if token is None:
        # get_bisheng_token returns None when the user has no token row;
        # connecting anyway would send the literal cookie value "None".
        await websocket.close(code=1008)
        return

    if agent_id == "0":
        agent_id = settings.bisheng_agent_id
    if chat_id == "0":
        chat_id = uuid.uuid4().hex

    # agent_id / chat_id come from the client's URL path: percent-encode them
    # so they cannot add path segments or query parameters upstream.
    service_uri = (
        f"{settings.bisheng_websocket_url}/api/v1/assistant/chat/{quote(agent_id, safe='')}"
        f"?t=&chat_id={quote(chat_id, safe='')}"
    )
    headers = {
        'cookie': f"access_token_cookie={token};"
    }

    # Connect to the Bisheng server.
    # NOTE(review): `extra_headers` is the legacy websockets client keyword;
    # newer websockets releases call it `additional_headers` — confirm
    # against the pinned version.
    async with websockets.connect(service_uri, extra_headers=headers) as service_websocket:
        client_websockets[chat_id] = websocket

        async def forward_to_service():
            """Rewrite each client JSON message into Bisheng's format and send it upstream."""
            while True:
                message = await websocket.receive_json()
                print(f"Received from client {chat_id}: {message}")
                # Add the 'flow_id' and 'chat_id' fields and move the user's
                # text from "message" into inputs.input.
                message['flow_id'] = agent_id
                message['chat_id'] = chat_id
                msg = message.pop("message")
                message['inputs'] = {
                    "data": {"chatId": chat_id, "id": agent_id, "type": "assistant"},
                    "input": msg
                }
                await service_websocket.send(json.dumps(message))
                print(f"Forwarded to bisheng: {message}")

        async def forward_to_client():
            """Pass every message from Bisheng through to the client unchanged."""
            while True:
                message = await service_websocket.recv()
                print(f"Received from service S: {message}")
                await websocket.send_text(message)
                print(f"Forwarded to client {chat_id}: {message}")

        # One task per direction; whichever finishes first ends the session.
        tasks = [
            asyncio.create_task(forward_to_service()),
            asyncio.create_task(forward_to_client()),
        ]
        try:
            done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # asyncio.wait() never raises the tasks' exceptions; re-raise the
            # one that ended the session so the handlers below (or the
            # framework, for genuine errors) actually see it.
            for task in done:
                exc = task.exception()
                if exc is not None:
                    raise exc
        except WebSocketDisconnect:
            print(f"Client {chat_id} disconnected")
        except websockets.ConnectionClosed:
            print(f"Bisheng connection for {chat_id} closed")
        finally:
            # Cancel whatever is still running (also when this handler itself
            # is cancelled) and reap every task so no exception is left
            # unretrieved.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            # Another connection may have re-registered the same chat_id;
            # only remove the entry if it still refers to this socket.
            if client_websockets.get(chat_id) is websocket:
                del client_websockets[chat_id]
| | |
| | | class Settings: |
| | | secret_key: str = '' |
| | | bisheng_base_url: str = '' |
| | | bisheng_websocket_url: str = '' |
| | | ragflow_base_url: str = '' |
| | | database_url: str = '' |
| | | PUBLIC_KEY: str |
| | | PRIVATE_KEY: str |
| | | bisheng_agent_id: str |
| | | |
| | | def __init__(self, **kwargs): |
| | | # Check if all required fields are provided and set them |
| | |
| | | secret_key: your-secret-key |
| | | bisheng_base_url: http://192.168.20.119:13001 |
| | | bisheng_websocket_url: ws://192.168.20.119:13001 |
| | | ragflow_base_url: http://192.168.20.119:11080 |
| | | database_url: mysql+pymysql://root:infini_rag_flow@192.168.20.116:5455/rag_basic |
| | | PUBLIC_KEY: | |
| | | -----BEGIN PUBLIC KEY----- |
| | | MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB |
| | | -----END PUBLIC KEY----- |
| | | PRIVATE_KEY: str |
| | | PRIVATE_KEY: str |
| | | bisheng_agent_id: 29dd57cf-1bd6-440d-af2c-2aac1c954770 |
| | |
| | | |
# JWT signing parameters for the auth service.
# NOTE(review): the shipped config sets secret_key to the placeholder
# "your-secret-key"; any deployment must override it, or tokens can be forged.
SECRET_KEY = settings.secret_key
ALGORITHM = "HS256"
# Access-token lifetime. The name indicates minutes (3000 ≈ 50 hours) —
# confirm against create_access_token. The earlier duplicate assignment of 30
# was dead code, immediately overwritten by this value.
ACCESS_TOKEN_EXPIRE_MINUTES = 3000

# bcrypt password hasher; deprecated="auto" makes passlib treat every scheme
# other than the default as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
| | | |
New file |
| | |
| | | from app.models.token_model import TokenModel |
| | | |
| | | |
def _get_token_row(db, user_id: int):
    """Return the first ``TokenModel`` row for ``user_id``, or ``None`` if there is none."""
    return db.query(TokenModel).filter(TokenModel.user_id == user_id).first()


def get_bisheng_token(db, user_id: int):
    """Return the stored Bisheng token for ``user_id``.

    Args:
        db: SQLAlchemy session used for the ``TokenModel`` query.
        user_id: Id of the user whose token row is looked up.

    Returns:
        The row's ``bisheng_token`` value, or ``None`` when the user has no
        token row.
    """
    token = _get_token_row(db, user_id)
    return token.bisheng_token if token else None


def get_ragflow_token(db, user_id: int):
    """Return the stored Ragflow token for ``user_id``.

    Args:
        db: SQLAlchemy session used for the ``TokenModel`` query.
        user_id: Id of the user whose token row is looked up.

    Returns:
        The row's ``ragflow_token`` value, or ``None`` when the user has no
        token row.
    """
    token = _get_token_row(db, user_id)
    return token.ragflow_token if token else None
| | |
| | | from fastapi import FastAPI |
| | | from app.api.auth import router as auth_router |
| | | from app.api.chat import router as chat_router |
| | | from app.models.base_model import init_db |
| | | |
# Initialise the database at import time, before the routers are mounted.
# NOTE(review): init_db is defined in app.models.base_model — presumably it
# creates the engine/tables; confirm there.
init_db()
| | |
| | | ) |
| | | |
| | | app.include_router(auth_router, prefix='/auth', tags=["auth"]) |
| | | app.include_router(chat_router, prefix='/chat', tags=["chat"]) |
| | | |
| | | if __name__ == "__main__": |
| | | import uvicorn |