From 30ff0afd5d76a3a5aa48058210ae411253574ada Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 13 三月 2025 14:55:30 +0800
Subject: [PATCH] 增加文件多轮问答

---
 app/service/v2/chat.py   |   36 ++++++++----
 app/models/user_model.py |   23 +++++++
 app/api/v2/public_api.py |   76 +++++++++++++++++++++++-
 main.py                  |    2 
 app/api/__init__.py      |   34 ++++++++++
 app/models/v2/chat.py    |    7 ++
 6 files changed, 158 insertions(+), 20 deletions(-)

diff --git a/app/api/__init__.py b/app/api/__init__.py
index 2a679e8..6cb4b05 100644
--- a/app/api/__init__.py
+++ b/app/api/__init__.py
@@ -1,9 +1,11 @@
 import urllib
+from datetime import datetime
+from typing import Callable, Any
 from urllib.parse import urlencode
 
 import jwt
 # from cryptography.fernet import Fernet
-from fastapi import FastAPI, Depends, HTTPException, Header
+from fastapi import FastAPI, Depends, HTTPException, Header, Request
 from fastapi.security import OAuth2PasswordBearer
 from passlib.context import CryptContext
 from pydantic import BaseModel
@@ -11,8 +13,9 @@
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
 from Log import logger
+from app.models.base_model import SessionLocal
 # from app.models.app_model import AppRegisterModel
-from app.models.user_model import UserModel
+from app.models.user_model import UserModel, UserApiTokenModel
 from app.service.auth import SECRET_KEY, ALGORITHM
 from app.config.config import settings
 
@@ -35,6 +38,33 @@
     data: list[dict] = []
 
 
+def verify_token(token: str) -> Any:
+    """
+    楠岃瘉 Token 鏄惁鏈夋晥
+    """
+    db = SessionLocal()
+    try:
+        db_token = db.query(UserApiTokenModel).filter(UserApiTokenModel.token == token, UserApiTokenModel.is_active == 1).first()
+        return db_token is not None and (db_token.expires_at is None or db_token.expires_at > datetime.now())
+    finally:
+        db.close()
+
+def token_required()-> Callable:
+    def decorated_function(request: Request)-> Any:
+        authorization_str = request.headers.get("Authorization")
+        if not authorization_str:
+            raise HTTPException(status_code=401, detail="Authorization` can't be empty")
+        authorization_list = authorization_str.split()
+        if len(authorization_list) < 2:
+            raise HTTPException(status_code=401, detail="Invalid token")
+        token = authorization_list[1]
+        objs = verify_token(token)
+        if not objs:
+            raise HTTPException(status_code=401, detail="Invalid token")
+        user = UserModel(username="", id=objs.user_id)
+        return user
+    return decorated_function
+
 def get_current_user(token: str = Depends(oauth2_scheme)):
     try:
         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
diff --git a/app/api/v2/public_api.py b/app/api/v2/public_api.py
index 98aba70..c86b787 100644
--- a/app/api/v2/public_api.py
+++ b/app/api/v2/public_api.py
@@ -1,19 +1,26 @@
 import json
+import uuid
 
-from fastapi import APIRouter, Depends
 from fastapi.responses import JSONResponse
+from starlette.responses import StreamingResponse
+
 from Log import logger
-from app.api import Response
+from app.api import Response, token_required
 
 from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT, DIFY, BISHENG, RAGFLOW, \
     DOCUMENT_IA_QUESTIONS, DOCUMENT_TO_REPORT_TITLE, DOCUMENT_TO_TITLE, DOCUMENT_TO_PAPER, \
-    DOCUMENT_IA_QUESTIONS_EQUIPMENT
-
-from app.models.base_model import get_db
+    DOCUMENT_IA_QUESTIONS_EQUIPMENT, dialog_chat, workflow_chat, advanced_chat, agent_chat, base_chat
 from app.models.public_api_model import DfToken
 from app.service.v2.api_token import DfTokenDao
 from app.service.v2.initialize_data import dialog_menu_sync, create_menu_sync, user_update_app
 from app.task.sync_resources import sync_knowledge, sync_dialog, sync_agent, sync_llm, sync_resource
+from fastapi import Depends, APIRouter, File, UploadFile
+from sqlalchemy.orm import Session
+from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat
+from app.models import UserModel
+from app.models.base_model import get_db
+from app.models.v2.session_model import ChatData
+from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_sessions, service_chat_workflow
 
 public_api = APIRouter()
 
@@ -94,3 +101,62 @@
         return Response(code=500, msg=str(e), data={})
 
     return Response(code=200, msg="success", data={})
+
+
+@public_api.post("/chat/{chatId}/completions")
+async def api_chat_dialog(chatId:str, dialog: ChatData, current_user: UserModel = Depends(token_required),db: Session = Depends(get_db)): #  current_user: UserModel = Depends(get_current_user)
+    chat_info = await get_chat_info(db, chatId)
+    if not chat_info:
+        error_msg = json.dumps(
+            {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400})
+        return StreamingResponse(f"data: {error_msg}\n\n",
+                                 media_type="text/event-stream")
+    if chat_info.mode == dialog_chat:
+
+        session_id = dialog.sessionId
+        if not dialog.query:
+            error_msg = json.dumps(
+                {"message": smart_message_error, "error": "\n**ERROR**: question cannot be empty.", "status": http_400})
+            return StreamingResponse(f"data: {error_msg}\n\n",
+                                     media_type="text/event-stream")
+        if not session_id:
+            session = await service_chat_sessions(db, chatId, dialog.query)
+            # print(session)
+            if not session or session.get("code") != 0:
+                error_msg = json.dumps(
+                    {"message": smart_message_error, "error": "\n**ERROR**: chat agent error", "status": http_500})
+                return StreamingResponse(f"data: {error_msg}\n\n",
+                                         media_type="text/event-stream")
+            session_id = session.get("data", {}).get("id")
+        return StreamingResponse(service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode),
+                                 media_type="text/event-stream")
+    elif chat_info.mode == workflow_chat:
+        chat_info = await get_chat_info(db, chatId)
+        if not chat_info:
+            error_msg = json.dumps(
+                {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400})
+            return StreamingResponse(f"data: {error_msg}\n\n",
+                                     media_type="text/event-stream")
+        session_id = dialog.sessionId
+        if not session_id:
+            session_id = str(uuid.uuid4()).replace("-", "")
+        return StreamingResponse(service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode),
+                                 media_type="text/event-stream")
+
+    elif chat_info.mode == advanced_chat or chat_info.mode == agent_chat or chat_info.mode == base_chat:
+        chat_info = await get_chat_info(db, chatId)
+        if not chat_info:
+            error_msg = json.dumps(
+                {"message": smart_message_error, "error": "\n**ERROR**: parameter exception", "status": http_400})
+            return StreamingResponse(f"data: {error_msg}\n\n",
+                                     media_type="text/event-stream")
+        session_id = dialog.sessionId
+        if not session_id:
+            session_id = str(uuid.uuid4()).replace("-", "")
+        return StreamingResponse(service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode),
+                                 media_type="text/event-stream")
+    else:
+        error_msg = json.dumps(
+            {"message": smart_message_error, "error": "\n**ERROR**: unknown chat", "status": http_400})
+        return StreamingResponse(f"data: {error_msg}\n\n",
+                                 media_type="text/event-stream")
diff --git a/app/models/user_model.py b/app/models/user_model.py
index bb3a382..9a23226 100644
--- a/app/models/user_model.py
+++ b/app/models/user_model.py
@@ -254,4 +254,27 @@
             'password': self.password,
             'access_token': self.access_token,
             'refresh_token': self.refresh_token,
+        }
+
+
+
+class UserApiTokenModel(Base):
+    __tablename__ = "user_api_token"
+    id = Column(Integer, primary_key=True)
+    user_id = Column(Integer)
+    token = Column(String(40), index=True)
+    created_at = Column(DateTime, default=datetime.now())
+    updated_at = Column(DateTime, default=datetime.now())
+    expires_at = Column(DateTime)
+    is_active = Column(Integer, default=1)
+
+    def to_json(self):
+        return {
+            'id': self.id,
+            'account': self.username,
+            'createTime': self.created_at,
+            'updateTime': self.updated_at,
+            'password': self.password,
+            'access_token': self.access_token,
+            'refresh_token': self.refresh_token,
         }
\ No newline at end of file
diff --git a/app/models/v2/chat.py b/app/models/v2/chat.py
index e41edcf..7aed562 100644
--- a/app/models/v2/chat.py
+++ b/app/models/v2/chat.py
@@ -23,6 +23,7 @@
 
 class ChatDataRequest(BaseModel):
     sessionId: str
+    parentId: Optional[str] = ""
     query: str
     chatMode: Optional[int] = 1  # 1= 鏅�氬璇濓紝2=鑱旂綉锛�3=鐭ヨ瘑搴�,4=娣卞害
     isDeep: Optional[int] = 1  # 1= 鏅��, 2=娣卞害
@@ -40,6 +41,7 @@
             "files": self.files,
             "isDeep": self.isDeep,
             "optimizeType": self.optimizeType,
+            "parentId": self.parentId,
         }
 
 
@@ -182,11 +184,16 @@
                 'content': self.content,
             }
         else:
+            query = {}
+            if self.query:
+                query = json.loads(self.query)
             return {
                 'id': self.id,
                 'role': "assistant",
                 'answer': self.content,
+                'chat_mode': self.chat_mode,
                 'node_list': json.loads(self.node_data) if self.node_data else [],
+                "parentId": query.get("parentId")
             }
 
 
diff --git a/app/service/v2/chat.py b/app/service/v2/chat.py
index 0942681..96672a3 100644
--- a/app/service/v2/chat.py
+++ b/app/service/v2/chat.py
@@ -475,7 +475,12 @@
         logger.error(e)
         return conversation_id, False
 
-
+async def add_query_files(db, message_id):
+    query = {}
+    complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id)
+    if complex_log:
+        query = json.loads(complex_log.query)
+    return query.get("files", [])
 
 async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
     answer_event = ""
@@ -485,28 +490,34 @@
     message_id = ""
     task_id = ""
     error = ""
-    files = []
     node_list = []
     conversation_id = ""
-    token = await get_chat_token(db, chat_id)
-    chat, url = await get_chat_object(mode)
+    query_data = chat_request.to_dict()
+    new_message_id = str(uuid.uuid4())
+    inputs = {"is_deep": chat_request.isDeep}
+    files = chat_request.files
+    if chat_request.chatMode == complex_knowledge_chat:
+        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
+    elif chat_request.chatMode == complex_content_optimization_chat:
+        inputs["type"] = chat_request.optimizeType
+    elif chat_request.chatMode == complex_dialog_chat:
+        if not files and chat_request.parentId:
+            files = await add_query_files(db, chat_request.parentId)
     if chat_request.chatMode != complex_content_optimization_chat:
         await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "鏈懡鍚嶄細璇�", chat_id, user_id,
                                 mode, "", DF_TYPE)
-        conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
+        conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data)
         if not message:
             yield "data: " + json.dumps({"message": smart_message_error,
                                          "error": "\n**ERROR**: 鍒涘缓浼氳瘽澶辫触锛�", "status": http_500},
                                         ensure_ascii=False) + "\n\n"
             return
-    inputs = {"is_deep": chat_request.isDeep}
-    if chat_request.chatMode == complex_knowledge_chat:
-        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
-    elif chat_request.chatMode == complex_content_optimization_chat:
-        inputs["type"] = chat_request.optimizeType
+    query_data["parentId"] = new_message_id
     try:
+        token = await get_chat_token(db, chat_id)
+        chat, url = await get_chat_object(mode)
         async for ans in chat.chat_completions(url,
-                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs),
+                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs),
                                                await chat.get_headers(token)):
             # print(ans)
             data = {}
@@ -561,6 +572,7 @@
 
             yield "data: " + json.dumps(
                 {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id,
+                 "parent_id": new_message_id,
                  "session_id": chat_request.sessionId},
                 ensure_ascii=False) + "\n\n"
 
@@ -579,7 +591,7 @@
         #                                           "node_list": node_list, "task_id": task_id, "id": message_id,
         #                                           "error": error}, conversation_id)
         if message_id:
-            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
+            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data)
 
 async def service_complex_upload(db, chat_id, file, user_id):
     files = []
diff --git a/main.py b/main.py
index 87597e7..3ec7a9d 100644
--- a/main.py
+++ b/main.py
@@ -75,7 +75,7 @@
 app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"])
 app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"])
 app.include_router(label_router, prefix='/api/label', tags=["label"])
-app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
+app.include_router(public_api, prefix='/v1', tags=["public_api"])
 app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"])
 app.include_router(system_router, prefix='/api/system', tags=["system"])
 app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"])

--
Gitblit v1.8.0