zhaoqingang
2025-02-18 c966f133bea46f5d1b804296b270955478fa0ac4
app/api/excel.py
@@ -1,10 +1,10 @@
import random
import string
from fastapi import APIRouter, File, UploadFile, Form, BackgroundTasks, Depends, Request
from fastapi import APIRouter, File, UploadFile, Form, BackgroundTasks, Depends, Request, WebSocket
from fastapi.responses import JSONResponse, FileResponse
from sqlalchemy.orm import Session
from starlette.websockets import WebSocket
# from starlette.websockets import WebSocket
from app.api import get_current_user, get_current_user_websocket, Response
from app.models import UserModel, AgentType
@@ -52,14 +52,15 @@
    return prefix + random_part
def db_create_session(db: Session, user_id: str, message: str = "合并Excel", upload_filenames: list = None):
    """Create a session record for a basic Excel-merge request and return it.

    Args:
        db: SQLAlchemy session handed to ``SessionService``.
        user_id: Identifier of the requesting user; converted with ``int()``
            before being passed on, so it must be a numeric string.
        message: The user's instruction text. It is passed to
            ``create_session`` as the second positional argument (presumably
            the session title -- confirm against ``SessionService``) and is
            stored as the content of the initial user message. Defaults to
            the previously hard-coded value so callers that pass only
            ``db`` and ``user_id`` keep working.
        upload_filenames: Names of the files the user uploaded for this
            request. ``None`` is treated as "no files".

    Returns:
        Whatever ``SessionService.create_session`` returns for the new record.

    Raises:
        ValueError: If ``user_id`` cannot be converted to ``int``.
    """
    db_id = generate_db_id()

    # Normalise a missing list so the stored message always carries a list
    # rather than None.
    filenames = [] if upload_filenames is None else upload_filenames

    # First message of the session: the user's instruction plus the files
    # that were uploaded alongside it.
    first_message = {
        "role": "user",
        "content": message,
        "upload_filenames": filenames,
    }

    session = SessionService(db).create_session(
        db_id,
        message,
        "basic_excel_merge",
        AgentType.BASIC,
        int(user_id),
        first_message,
    )
    return session
@@ -102,11 +103,12 @@
    user_excel = EXCEL_FILES_PATH
    create_dir_if_not_exists(user_source)
    create_dir_if_not_exists(user_excel)
    while True:
        data = await websocket.receive_text()
        # data = await websocket.receive_text()
        receive_message = await websocket.receive_json()
        try:
            if data == "\"合并Excel\"":
            if receive_message.get("message") == "合并Excel":
                upload_filenames = receive_message.get('upload_filenames', [])
                merge_file = run_conformity(user_source, user_excel)
                if merge_file is not None:
@@ -124,7 +126,7 @@
                        "type": "close",
                    })
                    # 创建会话记录
                    session = db_create_session(db, user_id)
                    session = db_create_session(db, user_id, receive_message.get("message"), upload_filenames)
                    # 更新会话记录
                    if session:
                        session_id = session.id
@@ -143,8 +145,8 @@
                    await websocket.send_json({"error": "合并失败", "type": "stream", "files": []})
                    await websocket.close()
            else:
                print(f"Received data: {data}")
                await websocket.send_json({"error": "未知指令", "data": str(data)})
                print(f"Received data: {receive_message.get('message')}")
                await websocket.send_json({"error": "未知指令", "data": str(receive_message.get('message'))})
                await websocket.close()
        except Exception as e:
            await websocket.send_json({"error": str(e)})