zhaoqingang
2024-12-31 6b4093952e555e1eb2713bd85133a5f697cda1e0
app/models/token_model.py
@@ -1,8 +1,11 @@
from datetime import datetime
from typing import Type
from sqlalchemy import Column, Integer, String, DateTime, Text
from sqlalchemy import Column, Integer, DateTime, Text
from sqlalchemy.orm import Session
from Log import logger
from app.config.const import RAGFLOW
from app.models.base_model import Base
@@ -10,9 +13,9 @@
    __tablename__ = "token"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, index=True)
    token = Column(Text(10000), unique=True, index=True)
    bisheng_token = Column(Text(10000), unique=True, index=True)
    ragflow_token = Column(Text(10000), unique=True, index=True)
    token = Column(Text(10000))
    bisheng_token = Column(Text(10000))
    ragflow_token = Column(Text(10000))
    created_at = Column(DateTime, default=datetime.utcnow)
@@ -20,8 +23,8 @@
    # 参数验证
    if not isinstance(user_id, int) or user_id <= 0:
        return
    if not access_token or not bisheng_token or not ragflow_token:
        return
    # if not access_token or not bisheng_token or not ragflow_token:
    #     return
    db_token = None
    try:
        # 查询现有记录
@@ -49,3 +52,41 @@
    except Exception as e:
        # 异常处理
        db.rollback()  # 回滚事务
async def update_token(db: Session, user_id: int, access_token: str, token: dict):
    # 参数验证
    if not isinstance(user_id, int) or user_id <= 0:
        return
    db_token = None
    print(token)
    try:
        # 查询现有记录
        db_token = db.query(TokenModel).filter_by(user_id=user_id).first()
        if db_token:
            # 记录存在,进行更新
            db_token.token = access_token
            for k, v in token.items():
                setattr(db_token, k.replace("app", "token"), v)
        else:
            # 记录不存在,进行插入
            db_token = TokenModel(
                user_id=user_id,
                token=access_token,
            )
            for k, v in token.items():
                setattr(db_token, k.replace("app", "token"), v)
            db.add(db_token)
        # 提交事务
        db.commit()
        db.refresh(db_token)
    except Exception as e:
        logger.error(e)
        # 异常处理
        db.rollback()  # 回滚事务
def get_token(db: Session, user_id: int) -> Type[TokenModel] | None:
    return db.query(TokenModel).filter_by(user_id=user_id).first()