From 30ff0afd5d76a3a5aa48058210ae411253574ada Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期四, 13 三月 2025 14:55:30 +0800 Subject: [PATCH] 增加文件多轮问答 --- app/service/v2/chat.py | 36 ++++++++---- app/models/user_model.py | 23 +++++++ app/api/v2/public_api.py | 76 +++++++++++++++++++++++- main.py | 2 app/api/__init__.py | 34 ++++++++++ app/models/v2/chat.py | 7 ++ 6 files changed, 158 insertions(+), 20 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index 2a679e8..6cb4b05 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -1,9 +1,11 @@ import urllib +from datetime import datetime +from typing import Callable, Any from urllib.parse import urlencode import jwt # from cryptography.fernet import Fernet -from fastapi import FastAPI, Depends, HTTPException, Header +from fastapi import FastAPI, Depends, HTTPException, Header, Request from fastapi.security import OAuth2PasswordBearer from passlib.context import CryptContext from pydantic import BaseModel @@ -11,8 +13,9 @@ from starlette.websockets import WebSocket, WebSocketDisconnect from Log import logger +from app.models.base_model import SessionLocal # from app.models.app_model import AppRegisterModel -from app.models.user_model import UserModel +from app.models.user_model import UserModel, UserApiTokenModel from app.service.auth import SECRET_KEY, ALGORITHM from app.config.config import settings @@ -35,6 +38,33 @@ data: list[dict] = [] +def verify_token(token: str) -> Any: + """ + 楠岃瘉 Token 鏄惁鏈夋晥 + """ + db = SessionLocal() + try: + db_token = db.query(UserApiTokenModel).filter(UserApiTokenModel.token == token, UserApiTokenModel.is_active == 1).first() + return db_token is not None and (db_token.expires_at is None or db_token.expires_at > datetime.now()) + finally: + db.close() + +def token_required()-> Callable: + def decorated_function(request: Request)-> Any: + authorization_str = request.headers.get("Authorization") + if not authorization_str: + raise HTTPException(status_code=401, detail="Authorization` can't be empty") + authorization_list = authorization_str.split() + if len(authorization_list) < 2: + raise HTTPException(status_code=401, detail="Invalid token") + token = authorization_list[1] + objs = verify_token(token) + if not objs: + raise HTTPException(status_code=401, detail="Invalid token") + user = UserModel(username="", id=objs.user_id) + return user + return decorated_function + def get_current_user(token: str = Depends(oauth2_scheme)): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) diff --git a/app/api/v2/public_api.py b/app/api/v2/public_api.py index 98aba70..c86b787 100644 --- a/app/api/v2/public_api.py +++ b/app/api/v2/public_api.py @@ -1,19 +1,26 @@ import json +import uuid -from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse +from starlette.responses import StreamingResponse + from Log import logger -from app.api import Response +from app.api import Response, token_required from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT, DIFY, BISHENG, RAGFLOW, \ DOCUMENT_IA_QUESTIONS, DOCUMENT_TO_REPORT_TITLE, DOCUMENT_TO_TITLE, DOCUMENT_TO_PAPER, \ - DOCUMENT_IA_QUESTIONS_EQUIPMENT - -from app.models.base_model import get_db + DOCUMENT_IA_QUESTIONS_EQUIPMENT, dialog_chat, workflow_chat, advanced_chat, agent_chat, base_chat from app.models.public_api_model import DfToken from app.service.v2.api_token import DfTokenDao from app.service.v2.initialize_data import dialog_menu_sync, create_menu_sync, user_update_app from app.task.sync_resources import sync_knowledge, sync_dialog, sync_agent, sync_llm, sync_resource +from fastapi import Depends, APIRouter, File, UploadFile +from sqlalchemy.orm import Session +from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat +from app.models import UserModel +from app.models.base_model import get_db +from app.models.v2.session_model import ChatData +from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_sessions, service_chat_workflow public_api = APIRouter() @@ -94,3 +101,62 @@ return Response(code=500, msg=str(e), data={}) return Response(code=200, msg="success", data={}) + + +@public_api.post("/chat/{chatId}/completions") +async def api_chat_dialog(chatId:str, dialog: ChatData, current_user: UserModel = Depends(token_required),db: Session = Depends(get_db)): # current_user: UserModel = Depends(get_current_user) + chat_info = await get_chat_info(db, chatId) + if not chat_info: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + if chat_info.mode == dialog_chat: + + session_id = dialog.sessionId + if not dialog.query: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: question cannot be empty.", "status": http_400}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + if not session_id: + session = await service_chat_sessions(db, chatId, dialog.query) + # print(session) + if not session or session.get("code") != 0: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + session_id = session.get("data", {}).get("id") + return StreamingResponse(service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode), + media_type="text/event-stream") + elif chat_info.mode == workflow_chat: + chat_info = await get_chat_info(db, chatId) + if not chat_info: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + session_id = dialog.sessionId + if not session_id: + session_id = str(uuid.uuid4()).replace("-", "") + return StreamingResponse(service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode), + media_type="text/event-stream") + + elif chat_info.mode == advanced_chat or chat_info.mode == agent_chat or chat_info.mode == base_chat: + chat_info = await get_chat_info(db, chatId) + if not chat_info: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") + session_id = dialog.sessionId + if not session_id: + session_id = str(uuid.uuid4()).replace("-", "") + return StreamingResponse(service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode), + media_type="text/event-stream") + else: + error_msg = json.dumps( + {"message": smart_message_error, "error": "\n**ERROR**: unknown chat", "status": http_400}) + return StreamingResponse(f"data: {error_msg}\n\n", + media_type="text/event-stream") diff --git a/app/models/user_model.py b/app/models/user_model.py index bb3a382..9a23226 100644 --- a/app/models/user_model.py +++ b/app/models/user_model.py @@ -254,4 +254,27 @@ 'password': self.password, 'access_token': self.access_token, 'refresh_token': self.refresh_token, + } + + + +class UserApiTokenModel(Base): + __tablename__ = "user_api_token" + id = Column(Integer, primary_key=True) + user_id = Column(Integer) + token = Column(String(40), index=True) + created_at = Column(DateTime, default=datetime.now()) + updated_at = Column(DateTime, default=datetime.now()) + expires_at = Column(DateTime) + is_active = Column(Integer, default=1) + + def to_json(self): + return { + 'id': self.id, + 'account': self.username, + 'createTime': self.created_at, + 'updateTime': self.updated_at, + 'password': self.password, + 'access_token': self.access_token, + 'refresh_token': self.refresh_token, } \ No newline at end of file diff --git a/app/models/v2/chat.py b/app/models/v2/chat.py index e41edcf..7aed562 100644 --- a/app/models/v2/chat.py +++ b/app/models/v2/chat.py @@ -23,6 +23,7 @@ class ChatDataRequest(BaseModel): sessionId: str + parentId: Optional[str] = "" query: str chatMode: Optional[int] = 1 # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害 isDeep: Optional[int] = 1 # 1= 鏅��, 2=娣卞害 @@ -40,6 +41,7 @@ "files": self.files, "isDeep": self.isDeep, "optimizeType": self.optimizeType, + "parentId": self.parentId, } @@ -182,11 +184,16 @@ 'content': self.content, } else: + query = {} + if self.query: + query = json.loads(self.query) return { 'id': self.id, 'role': "assistant", 'answer': self.content, + 'chat_mode': self.chat_mode, 'node_list': json.loads(self.node_data) if self.node_data else [], + "parentId": query.get("parentId") } diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py index 0942681..96672a3 100644 --- a/app/service/v2/chat.py +++ b/app/service/v2/chat.py @@ -475,7 +475,12 @@ logger.error(e) return conversation_id, False - +async def add_query_files(db, message_id): + query = {} + complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id) + if complex_log: + query = json.loads(complex_log.query) + return query.get("files", []) async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest): answer_event = "" @@ -485,28 +490,34 @@ message_id = "" task_id = "" error = "" - files = [] node_list = [] conversation_id = "" - token = await get_chat_token(db, chat_id) - chat, url = await get_chat_object(mode) + query_data = chat_request.to_dict() + new_message_id = str(uuid.uuid4()) + inputs = {"is_deep": chat_request.isDeep} + files = chat_request.files + if chat_request.chatMode == complex_knowledge_chat: + inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) + elif chat_request.chatMode == complex_content_optimization_chat: + inputs["type"] = chat_request.optimizeType + elif chat_request.chatMode == complex_dialog_chat: + if not files and chat_request.parentId: + files = await add_query_files(db, chat_request.parentId) if chat_request.chatMode != complex_content_optimization_chat: await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id, mode, "", DF_TYPE) - conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict()) + conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data) if not message: yield "data: " + json.dumps({"message": smart_message_error, "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500}, ensure_ascii=False) + "\n\n" return - inputs = {"is_deep": chat_request.isDeep} - if chat_request.chatMode == complex_knowledge_chat: - inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId}) - elif chat_request.chatMode == complex_content_optimization_chat: - inputs["type"] = chat_request.optimizeType + query_data["parentId"] = new_message_id try: + token = await get_chat_token(db, chat_id) + chat, url = await get_chat_object(mode) async for ans in chat.chat_completions(url, - await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs), + await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs), await chat.get_headers(token)): # print(ans) data = {} @@ -561,6 +572,7 @@ yield "data: " + json.dumps( {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id, + "parent_id": new_message_id, "session_id": chat_request.sessionId}, ensure_ascii=False) + "\n\n" @@ -579,7 +591,7 @@ # "node_list": node_list, "task_id": task_id, "id": message_id, # "error": error}, conversation_id) if message_id: - await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict()) + await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data) async def service_complex_upload(db, chat_id, file, user_id): files = [] diff --git a/main.py b/main.py index 87597e7..3ec7a9d 100644 --- a/main.py +++ b/main.py @@ -75,7 +75,7 @@ app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"]) app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"]) app.include_router(label_router, prefix='/api/label', tags=["label"]) -app.include_router(public_api, prefix='/v1/api', tags=["public_api"]) +app.include_router(public_api, prefix='/v1', tags=["public_api"]) app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"]) app.include_router(system_router, prefix='/api/system', tags=["system"]) app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"]) -- Gitblit v1.8.0