zhaoqingang
2025-03-13 30ff0afd5d76a3a5aa48058210ae411253574ada
增加文件多轮问答
6个文件已修改
178 ■■■■ 已修改文件
app/api/__init__.py 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/v2/public_api.py 76 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/user_model.py 23 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/v2/chat.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/chat.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/__init__.py
@@ -1,9 +1,11 @@
import urllib
from datetime import datetime
from typing import Callable, Any
from urllib.parse import urlencode
import jwt
# from cryptography.fernet import Fernet
from fastapi import FastAPI, Depends, HTTPException, Header
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
@@ -11,8 +13,9 @@
from starlette.websockets import WebSocket, WebSocketDisconnect
from Log import logger
from app.models.base_model import SessionLocal
# from app.models.app_model import AppRegisterModel
from app.models.user_model import UserModel
from app.models.user_model import UserModel, UserApiTokenModel
from app.service.auth import SECRET_KEY, ALGORITHM
from app.config.config import settings
@@ -35,6 +38,33 @@
    data: list[dict] = []
def verify_token(token: str) -> Any:
    """
    Look up an active, unexpired API token.

    :param token: raw token string taken from the ``Authorization`` header.
    :return: the matching ``UserApiTokenModel`` row when it exists, has
        ``is_active == 1`` and either has no ``expires_at`` or expires in the
        future; otherwise ``None``. The row object is truthy, so callers that
        only test the result for truthiness keep working, while callers that
        need the owner can read ``.user_id`` from it.
    """
    if not token:
        return None
    db = SessionLocal()
    try:
        db_token = (
            db.query(UserApiTokenModel)
            .filter(UserApiTokenModel.token == token, UserApiTokenModel.is_active == 1)
            .first()
        )
        if db_token is None:
            return None
        # A NULL expires_at means the token does not expire.
        if db_token.expires_at is not None and db_token.expires_at <= datetime.now():
            return None
        return db_token
    finally:
        db.close()


def token_required() -> Callable:
    """
    Build a FastAPI dependency that authenticates a request by API token.

    This is a factory: use it as ``Depends(token_required())`` (note the call),
    so FastAPI receives the inner dependency rather than this function.

    The returned dependency reads the ``Authorization`` header, takes its second
    whitespace-separated part as the token (e.g. ``Bearer <token>``), validates
    it with ``verify_token`` and returns a ``UserModel`` carrying the token
    owner's id.

    :raises HTTPException: 401 when the header is missing, has fewer than two
        parts, or the token is not valid.
    """
    def decorated_function(request: Request) -> Any:
        authorization_str = request.headers.get("Authorization")
        if not authorization_str:
            raise HTTPException(status_code=401, detail="`Authorization` can't be empty")
        authorization_list = authorization_str.split()
        if len(authorization_list) < 2:
            raise HTTPException(status_code=401, detail="Invalid token")
        token_obj = verify_token(authorization_list[1])
        if not token_obj:
            raise HTTPException(status_code=401, detail="Invalid token")
        # Built in memory (not loaded from the DB); only ``id`` is meaningful.
        return UserModel(username="", id=token_obj.user_id)
    return decorated_function
def get_current_user(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
app/api/v2/public_api.py
@@ -1,19 +1,26 @@
import json
import uuid
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse
from Log import logger
from app.api import Response
from app.api import Response, token_required
from app.config.const import IMAGE_TO_TEXT, DOCUMENT_TO_CLEANING, DOCUMENT_TO_REPORT, DIFY, BISHENG, RAGFLOW, \
    DOCUMENT_IA_QUESTIONS, DOCUMENT_TO_REPORT_TITLE, DOCUMENT_TO_TITLE, DOCUMENT_TO_PAPER, \
    DOCUMENT_IA_QUESTIONS_EQUIPMENT
from app.models.base_model import get_db
    DOCUMENT_IA_QUESTIONS_EQUIPMENT, dialog_chat, workflow_chat, advanced_chat, agent_chat, base_chat
from app.models.public_api_model import DfToken
from app.service.v2.api_token import DfTokenDao
from app.service.v2.initialize_data import dialog_menu_sync, create_menu_sync, user_update_app
from app.task.sync_resources import sync_knowledge, sync_dialog, sync_agent, sync_llm, sync_resource
from fastapi import Depends, APIRouter, File, UploadFile
from sqlalchemy.orm import Session
from app.config.const import smart_message_error, http_400, http_500, http_200, complex_dialog_chat
from app.models import UserModel
from app.models.base_model import get_db
from app.models.v2.session_model import ChatData
from app.service.v2.chat import service_chat_dialog, get_chat_info, service_chat_sessions, service_chat_workflow
public_api = APIRouter()
@@ -94,3 +101,62 @@
        return Response(code=500, msg=str(e), data={})
    return Response(code=200, msg="success", data={})
def _sse_error(error: str, status) -> StreamingResponse:
    """Return a single-event SSE response carrying ``error`` and ``status``."""
    error_msg = json.dumps({"message": smart_message_error, "error": error, "status": status})
    # A one-element list sends the whole event as one chunk; a bare str would be
    # iterated (and streamed) character by character.
    return StreamingResponse([f"data: {error_msg}\n\n"], media_type="text/event-stream")


@public_api.post("/chat/{chatId}/completions")
async def api_chat_dialog(chatId: str, dialog: ChatData,
                          current_user: UserModel = Depends(token_required()),
                          db: Session = Depends(get_db)):
    """
    Stream a chat completion for chat ``chatId`` to an API-token-authenticated caller.

    ``token_required`` is a factory, so it must be *called* inside ``Depends``;
    otherwise FastAPI would inject the inner function instead of the user.

    Dispatches on the chat's mode:
    - ``dialog_chat``: requires ``dialog.query``; creates a session via
      ``service_chat_sessions`` when no ``sessionId`` is given, then streams
      ``service_chat_dialog``.
    - ``workflow_chat`` / ``advanced_chat`` / ``agent_chat`` / ``base_chat``:
      uses ``dialog.sessionId`` or a fresh 32-char hex id, then streams
      ``service_chat_workflow``.
    Any failure or unknown mode yields a single SSE error event.
    """
    chat_info = await get_chat_info(db, chatId)
    if not chat_info:
        return _sse_error("\n**ERROR**: parameter exception", http_400)

    if chat_info.mode == dialog_chat:
        if not dialog.query:
            return _sse_error("\n**ERROR**: question cannot be empty.", http_400)
        session_id = dialog.sessionId
        if not session_id:
            session = await service_chat_sessions(db, chatId, dialog.query)
            if not session or session.get("code") != 0:
                return _sse_error("\n**ERROR**: chat agent error", http_500)
            session_id = session.get("data", {}).get("id")
        return StreamingResponse(
            service_chat_dialog(db, chatId, dialog.query, session_id, current_user.id, chat_info.mode),
            media_type="text/event-stream")

    if chat_info.mode in (workflow_chat, advanced_chat, agent_chat, base_chat):
        # chat_info was already fetched and validated above; no need to re-query.
        session_id = dialog.sessionId or uuid.uuid4().hex
        return StreamingResponse(
            service_chat_workflow(db, chatId, dialog, session_id, current_user.id, chat_info.mode),
            media_type="text/event-stream")

    return _sse_error("\n**ERROR**: unknown chat", http_400)
app/models/user_model.py
@@ -254,4 +254,27 @@
            'password': self.password,
            'access_token': self.access_token,
            'refresh_token': self.refresh_token,
        }
class UserApiTokenModel(Base):
    """API token row owned by a user, used to authenticate public-API calls."""
    __tablename__ = "user_api_token"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer)  # id of the user the token belongs to
    token = Column(String(40), index=True)
    # Pass the callable, not datetime.now(): a call here is evaluated once at
    # import time and would stamp every row with the same timestamp.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now)
    expires_at = Column(DateTime)  # NULL is treated as "never expires" by the token check
    is_active = Column(Integer, default=1)  # the token check only accepts is_active == 1

    def to_json(self):
        """Serialize this token row using only columns defined on this model."""
        return {
            'id': self.id,
            'userId': self.user_id,
            'token': self.token,
            'createTime': self.created_at,
            'updateTime': self.updated_at,
            'expiresTime': self.expires_at,
            'isActive': self.is_active,
        }
app/models/v2/chat.py
@@ -23,6 +23,7 @@
class ChatDataRequest(BaseModel):
    sessionId: str
    parentId: Optional[str] = ""
    query: str
    chatMode: Optional[int] = 1  # 1= 普通对话,2=联网,3=知识库,4=深度
    isDeep: Optional[int] = 1  # 1= 普通, 2=深度
@@ -40,6 +41,7 @@
            "files": self.files,
            "isDeep": self.isDeep,
            "optimizeType": self.optimizeType,
            "parentId": self.parentId,
        }
@@ -182,11 +184,16 @@
                'content': self.content,
            }
        else:
            query = {}
            if self.query:
                query = json.loads(self.query)
            return {
                'id': self.id,
                'role': "assistant",
                'answer': self.content,
                'chat_mode': self.chat_mode,
                'node_list': json.loads(self.node_data) if self.node_data else [],
                "parentId": query.get("parentId")
            }
app/service/v2/chat.py
@@ -475,7 +475,12 @@
        logger.error(e)
        return conversation_id, False
async def add_query_files(db, message_id):
    """
    Return the ``files`` recorded with an earlier complex-chat message.

    Used for multi-turn file Q&A: when a follow-up request carries no files but
    names a parent message, the parent's files are reused.

    :param db: database session handed to ``ComplexChatSessionDao``.
    :param message_id: id of the parent message whose stored query is read.
    :return: the parent's ``files`` list, or ``[]`` when the message is not
        found, has no stored query, the stored query is not valid JSON, is not
        a JSON object, or has no/empty ``files``.
    """
    complex_log = await ComplexChatSessionDao(db).get_session_by_id(message_id)
    if not complex_log or not complex_log.query:
        return []
    try:
        query = json.loads(complex_log.query)
    except ValueError as e:
        # Degrade to "no inherited files" instead of aborting the chat stream.
        logger.error(e)
        return []
    if not isinstance(query, dict):
        return []
    return query.get("files") or []
async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
    answer_event = ""
@@ -485,28 +490,34 @@
    message_id = ""
    task_id = ""
    error = ""
    files = []
    node_list = []
    conversation_id = ""
    token = await get_chat_token(db, chat_id)
    chat, url = await get_chat_object(mode)
    query_data = chat_request.to_dict()
    new_message_id = str(uuid.uuid4())
    inputs = {"is_deep": chat_request.isDeep}
    files = chat_request.files
    if chat_request.chatMode == complex_knowledge_chat:
        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
    elif chat_request.chatMode == complex_content_optimization_chat:
        inputs["type"] = chat_request.optimizeType
    elif chat_request.chatMode == complex_dialog_chat:
        if not files and chat_request.parentId:
            files = await add_query_files(db, chat_request.parentId)
    if chat_request.chatMode != complex_content_optimization_chat:
        await add_session_log(db, chat_request.sessionId, chat_request.query if chat_request.query else "未命名会话", chat_id, user_id,
                                mode, "", DF_TYPE)
        conversation_id, message = await add_complex_log(db, str(uuid.uuid4()),chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
        conversation_id, message = await add_complex_log(db, new_message_id, chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query, user_id, mode, DF_TYPE, 1, query_data=query_data)
        if not message:
            yield "data: " + json.dumps({"message": smart_message_error,
                                         "error": "\n**ERROR**: 创建会话失败!", "status": http_500},
                                        ensure_ascii=False) + "\n\n"
            return
    inputs = {"is_deep": chat_request.isDeep}
    if chat_request.chatMode == complex_knowledge_chat:
        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})
    elif chat_request.chatMode == complex_content_optimization_chat:
        inputs["type"] = chat_request.optimizeType
    query_data["parentId"] = new_message_id
    try:
        token = await get_chat_token(db, chat_id)
        chat, url = await get_chat_object(mode)
        async for ans in chat.chat_completions(url,
                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=chat_request.files, inputs=inputs),
                                               await chat.complex_request_data(chat_request.query, conversation_id, str(user_id), files=files, inputs=inputs),
                                               await chat.get_headers(token)):
            # print(ans)
            data = {}
@@ -561,6 +572,7 @@
            yield "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "task_id": task_id, "message_id":message_id,
                 "parent_id": new_message_id,
                 "session_id": chat_request.sessionId},
                ensure_ascii=False) + "\n\n"
@@ -579,7 +591,7 @@
        #                                           "node_list": node_list, "task_id": task_id, "id": message_id,
        #                                           "error": error}, conversation_id)
        if message_id:
            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode, answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE, 2, conversation_id, node_data=node_list, query_data=query_data)
async def service_complex_upload(db, chat_id, file, user_id):
    files = []
main.py
@@ -75,7 +75,7 @@
app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"])
app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"])
app.include_router(label_router, prefix='/api/label', tags=["label"])
app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
app.include_router(public_api, prefix='/v1', tags=["public_api"])
app.include_router(chat_router_v2, prefix='/api/v1', tags=["chat1"])
app.include_router(system_router, prefix='/api/system', tags=["system"])
app.include_router(mind_map_router, prefix='/api/mindmap', tags=["mindmap"])