
import urllib
from datetime import datetime
from typing import Callable, Any
from urllib.parse import urlencode

import jwt
# from cryptography.fernet import Fernet
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel

from starlette.websockets import WebSocket, WebSocketDisconnect

from Log import logger
from app.models.base_model import SessionLocal
# from app.models.app_model import AppRegisterModel
from app.models.user_model import UserModel, UserApiTokenModel
from app.service.auth import SECRET_KEY, ALGORITHM
from app.config.config import settings


data: list[dict] = []

def verify_token(token: str) -> Any:
    """
    Validate an API token and return the matching record, or None if it is
    missing, inactive, or expired.
    """
    db = SessionLocal()
    try:
        db_token = db.query(UserApiTokenModel).filter(
            UserApiTokenModel.token == token, UserApiTokenModel.is_active == 1).first()
        if db_token and (db_token.expires_at is None or db_token.expires_at > datetime.now()):
            return db_token
        return None
    finally:
        db.close()
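# Example (hypothetical token value): verify_token("0123abcd...") yields the active,
# unexpired UserApiTokenModel row for that token, or None when no such row exists.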

def token_required() -> Callable:
    def decorated_function(request: Request) -> Any:
        authorization_str = request.headers.get("Authorization")
        if not authorization_str:
            raise HTTPException(status_code=401, detail="Authorization header can't be empty")
        authorization_list = authorization_str.split()
        if len(authorization_list) < 2:
            raise HTTPException(status_code=401, detail="Invalid token")
        token = authorization_list[1]
        db_token = verify_token(token)
        if not db_token:
            raise HTTPException(status_code=401, detail="Invalid token")
        user = UserModel(username="", id=db_token.user_id)
        return user
    return decorated_function
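# Usage note: `token_required` is a dependency *factory*, so routes must call it when
# declaring the dependency; FastAPI then injects the inner `decorated_function`.
# A minimal, purely illustrative route (the "/v1/me" path is not part of this codebase):
#     @app.get("/v1/me")
#     async def me(current_user: UserModel = Depends(token_required())):
#         return {"user_id": current_user.id}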

def get_current_user(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

import json
import uuid

from fastapi import APIRouter, Depends, File, UploadFile
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse

from Log import logger
from app.api import Response, token_required
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT, DIFY, BISHENG, RAGFLOW, \
    DOCUMENT_IA_QUESTIONS, DOCUMENT_TO_REPORT_TITLE, DOCUMENT_TO_TITLE, DOCUMENT_TO_PAPER, \
    DOCUMENT_IA_QUESTIONS_EQUIPMENT, dialog_chat, workflow_chat, advanced_chat, agent_chat, base_chat
from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat
from app.models import UserModel
from app.models.base_model import get_db
from app.models.public_api_model import DfToken
from app.models.v2.session_model import ChatData
from app.service.v2.api_token import DfTokenDao
from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_sessions, service_chat_workflow
from app.service.v2.initialize_data import dialog_menu_sync, create_menu_sync, user_update_app
from app.task.sync_resources import sync_knowledge, sync_dialog, sync_agent, sync_llm, sync_resource

public_api = APIRouter()


        return Response(code=500, msg=str(e), data={})

    return Response(code=200, msg="success", data={})

@public_api.post("/chat/{chatId}/completions")
async def api_chat_dialog(chatId: str, dialog: ChatData, current_user: UserModel = Depends(token_required()),
                          db: Session = Depends(get_db)):  # current_user: UserModel = Depends(get_current_user)
    chat_info = await get_chat_info(db, chatId)
    if not chat_info:
        error_msg = json.dumps(
            {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400})
        return StreamingResponse(f"data: {error_msg}\n\n", media_type="text/event-stream")
    if chat_info.mode == dialog_chat:
        session_id = dialog.sessionId
        if not dialog.query:
            error_msg = json.dumps(
                {"message": smart_message_error, "error": "\n**ERROR**: question cannot be empty.", "status": http_400})
            return StreamingResponse(f"data: {error_msg}\n\n", media_type="text/event-stream")
        if not session_id:
            session = await service_chat_sessions(db, chatId, dialog.query)
            # print(session)
            if not session or session.get("code") != 0:
                error_msg = json.dumps(
                    {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500})
                return StreamingResponse(f"data: {error_msg}\n\n", media_type="text/event-stream")
            session_id = session.get("data", {}).get("id")
        return StreamingResponse(service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode),
                                 media_type="text/event-stream")
    elif chat_info.mode in (workflow_chat, advanced_chat, agent_chat, base_chat):
        # These modes share the same workflow-style streaming call; chat_info was already
        # fetched and validated above, so it is not fetched again here.
        session_id = dialog.sessionId
        if not session_id:
            session_id = str(uuid.uuid4()).replace("-", "")
        return StreamingResponse(service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode),
                                 media_type="text/event-stream")
    else:
        error_msg = json.dumps(
            {"message": smart_message_error, "error": "\n**ERROR**: unknown chat", "status": http_400})
        return StreamingResponse(f"data: {error_msg}\n\n", media_type="text/event-stream")
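# Example call (hypothetical host and token, assuming the router is mounted under the
# "/v1" prefix as in the app setup below); the endpoint streams Server-Sent Events,
# so each chunk arrives as a "data: {...}\n\n" line:
#   curl -N -X POST "http://localhost:8000/v1/chat/<chatId>/completions" \
#        -H "Authorization: Bearer <api_token>" \
#        -H "Content-Type: application/json" \
#        -d '{"query": "hello", "sessionId": ""}'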

            'password': self.password,
            'access_token': self.access_token,
            'refresh_token': self.refresh_token,
        }


class UserApiTokenModel(Base):
    __tablename__ = "user_api_token"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer)
    token = Column(String(40), index=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now)
    expires_at = Column(DateTime)
    is_active = Column(Integer, default=1)

    def to_json(self):
        return {
            'id': self.id,
            'user_id': self.user_id,
            'token': self.token,
            'createTime': self.created_at,
            'updateTime': self.updated_at,
            'expiresAt': self.expires_at,
            'isActive': self.is_active,
        }
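# A minimal sketch of issuing a token for a user (assumes SessionLocal from
# app.models.base_model; secrets.token_hex(20) yields 40 hex chars, matching String(40)):
#     import secrets
#     db = SessionLocal()
#     db.add(UserApiTokenModel(user_id=1, token=secrets.token_hex(20), expires_at=None))
#     db.commit()
#     db.close()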


class ChatDataRequest(BaseModel):
    sessionId: str
    parentId: Optional[str] = ""
    query: str
    chatMode: Optional[int] = 1  # 1 = plain chat, 2 = web search, 3 = knowledge base, 4 = deep
    isDeep: Optional[int] = 1    # 1 = normal, 2 = deep
    # The fields below are referenced by to_dict() and service_complex_chat(); their
    # types and defaults here are assumptions.
    files: Optional[list] = []
    optimizeType: Optional[int] = 0
    knowledgeId: Optional[list] = []

            "files": self.files,
            "isDeep": self.isDeep,
            "optimizeType": self.optimizeType,
            "parentId": self.parentId,
        }



                'content': self.content,
            }
        else:
            query = {}
            if self.query:
                query = json.loads(self.query)
            return {
                'id': self.id,
                'role': "assistant",
                'answer': self.content,
                'chat_mode': self.chat_mode,
                'node_list': json.loads(self.node_data) if self.node_data else [],
                "parentId": query.get("parentId")
            }



        logger.error(e)
        return conversation_id, False


async def add_query_files(db, message_id):
    """Return the file attachments recorded on an earlier message, so a follow-up
    question that carries no files of its own can reuse them."""
    query = {}
    complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id)
    if complex_log:
        query = json.loads(complex_log.query)
    return query.get("files", [])

async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
    answer_event = ""
    answer_agent = ""
    answer_workflow = ""
    message_id = ""
    task_id = ""
    error = ""
    files = []
    node_list = []
    conversation_id = ""
    token = await get_chat_token(db, chat_id)
    chat, url = await get_chat_object(mode)
    query_data = chat_request.to_dict()
    new_message_id = str(uuid.uuid4())
    inputs = {"is_deep": chat_request.isDeep}
    files = chat_request.files
    if chat_request.chatMode == complex_knowledge_chat:
        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
    elif chat_request.chatMode == complex_content_optimization_chat:
        inputs["type"] = chat_request.optimizeType
    elif chat_request.chatMode == complex_dialog_chat:
        if not files and chat_request.parentId:
            # Follow-up questions without attachments inherit the parent's files.
            files = await add_query_files(db, chat_request.parentId)
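    # For example (hypothetical values), a knowledge-base question (chatMode ==
    # complex_knowledge_chat) ends up with inputs like
    #   {"is_deep": 1, "query_json": "{\"query\": \"...\", \"dataset_ids\": [\"kb1\"]}"}
    # while a content-optimization request carries {"is_deep": 1, "type": <optimizeType>}.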
    if chat_request.chatMode != complex_content_optimization_chat:
        await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "Untitled session",
                              chat_id, user_id, mode, "", DF_TYPE)
    conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId,
                                                     chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1,
                                                     query_data=query_data)
    if not message:
        yield "data: " + json.dumps({"message": smart_message_error,
                                     "error": "\n**ERROR**: failed to create the session!", "status": http_500},
                                    ensure_ascii=False) + "\n\n"
        return
    query_data["parentId"] = new_message_id
    try:
        async for ans in chat.chat_completions(url,
                                               await chat.complex_request_data(chat_request.query, conversation_id,
                                                                               str(user_id), files=files, inputs=inputs),
                                               await chat.get_headers(token)):
            # print(ans)
            data = {}

            yield "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "task_id": task_id,
                 "message_id": message_id, "parent_id": new_message_id, "session_id": chat_request.sessionId},
                ensure_ascii=False) + "\n\n"
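            # Each iteration emits one SSE frame; a typical frame looks like
            # (hypothetical values):
            #   data: {"event": "message", "data": {"answer": "..."}, "error": "",
            #          "status": 200, "task_id": "...", "message_id": "...",
            #          "parent_id": "<new_message_id>", "session_id": "<sessionId>"}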


    # "node_list": node_list, "task_id": task_id, "id": message_id,
    # "error": error}, conversation_id)
    if message_id:
        await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode,
                              answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2,
                              conversation_id, node_data=node_list, query_data=query_data)


async def service_complex_upload(db, chat_id, file, user_id):
    files = []

app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"])
app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"])
app.include_router(label_router, prefix='/api/label', tags=["label"])
app.include_router(public_api, prefix='/v1', tags=["public_api"])
app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"])
app.include_router(system_router, prefix='/api/system', tags=["system"])
app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"])