zhaoqingang
2024-12-11 a791022ff1311e1fb76930c398d6ff91036d0456
新增加标签功能
13个文件已修改
5个文件已添加
405 ■■■■ 已修改文件
alembic/versions/b2f03e852b6e_agent_type_add.py 30 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
alembic/versions/c437168c1da4_label_tabel_add.py 53 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/auth.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/dialog.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/label.py 45 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/user.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/__init__.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/dialog_model.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/label_model.py 46 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/token_model.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/auth.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/dialog.py 46 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/knowledge.py 21 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/label.py 64 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/service_token.py 23 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/app_register.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
alembic/versions/b2f03e852b6e_agent_type_add.py
New file
@@ -0,0 +1,30 @@
"""agent type add
Revision ID: b2f03e852b6e
Revises: c437168c1da4
Create Date: 2024-12-11 17:56:12.125274
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# This revision adds the ``dialogs.dialog_type`` column and is applied on top of
# c437168c1da4 (the migration that creates the ``label`` / ``label_worker`` tables).
revision: str = 'b2f03e852b6e'
down_revision: Union[str, None] = 'c437168c1da4'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the nullable, one-character ``dialog_type`` column to ``dialogs``."""
    dialog_type_column = sa.Column('dialog_type', sa.String(length=1), nullable=True)
    op.add_column('dialogs', dialog_type_column)
def downgrade() -> None:
    """Revert this revision by dropping ``dialogs.dialog_type``."""
    op.drop_column(table_name='dialogs', column_name='dialog_type')
alembic/versions/c437168c1da4_label_tabel_add.py
New file
@@ -0,0 +1,53 @@
"""label tabel add
Revision ID: c437168c1da4
Revises: 2f304d60542b
Create Date: 2024-12-11 15:01:45.049315
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# This revision creates the ``label`` and ``label_worker`` tables and sits
# directly on top of revision 2f304d60542b.
revision: str = 'c437168c1da4'
down_revision: Union[str, None] = '2f304d60542b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``label`` table and the ``label_worker`` association table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Label definitions; ``name`` is made globally unique by ix_label_name below.
    op.create_table('label',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), nullable=True),
    sa.Column('name', sa.String(length=128), nullable=False),
    sa.Column('status', sa.String(length=10), nullable=True),
    sa.Column('creator', sa.Integer(), nullable=True),
    sa.Column('label_type', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    # NOTE(review): ``id`` is already the primary key, so this extra non-unique
    # index (emitted for ``index=True`` on the model) is most likely redundant.
    op.create_index(op.f('ix_label_id'), 'label', ['id'], unique=False)
    op.create_index(op.f('ix_label_name'), 'label', ['name'], unique=True)
    # One row per (label, labelled object) pair; the unique constraint prevents
    # attaching the same label to the same object twice.
    # NOTE(review): no ForeignKey from label_id to label.id is declared, so the
    # database does not cascade label deletions to these rows.
    op.create_table('label_worker',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('label_id', sa.Integer(), nullable=True),
    sa.Column('object_id', sa.String(length=36), nullable=True),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('label_id', 'object_id', name='label_object_id_ix')
    )
    op.create_index(op.f('ix_label_worker_id'), 'label_worker', ['id'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``label_worker`` and ``label`` tables created by this revision."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Reverse order of upgrade(): each table's indexes go before the table itself.
    op.drop_index(op.f('ix_label_worker_id'), table_name='label_worker')
    op.drop_table('label_worker')
    op.drop_index(op.f('ix_label_name'), table_name='label')
    op.drop_index(op.f('ix_label_id'), table_name='label')
    op.drop_table('label')
    # ### end Alembic commands ###
app/api/auth.py
@@ -16,7 +16,7 @@
from app.models.user import UserCreate, LoginData
from app.models.user_model import UserModel
from app.service.auth import authenticate_user, create_access_token, is_valid_password, save_register_user, \
    update_user_token, UserAppDao
    update_user_token, UserAppDao, update_user_info
from app.service.bisheng import BishengService
from app.service.v2.app_register import AppRegisterDao
from app.service.difyService import DifyService
@@ -83,13 +83,13 @@
    access_token = create_access_token(data={"sub": user.username, "user_id": user.id})
    upsert_token(db, user.id, access_token, bisheng_token, ragflow_token)
    # print(111)
    return Response(code=200, msg="Login successful", data={
        "access_token": access_token,
        "token_type": "bearer",
        "username": user.username,
        "nickname": "",
        "user": user.to_login_json()
        # "user": user.to_login_json()
    })
@@ -120,9 +120,12 @@
            logger.error("未知注册应用---")
            continue
        try:
            name = login_data.username
            user_app = await UserAppDao(db).get_data_by_id(user.id, app["id"])
            if user_app:
                name  = user_app.username
            else:
                await update_user_info(db, user.id)
            token = await service.login(name, login_data.password)
            token_dict[app["id"]] = token
        except Exception as e:
app/api/chat.py
@@ -41,13 +41,16 @@
        ret = {"message": "Agent not found", "type": "close"}
        await websocket.send_json(ret)
        return
    print(1111)
    agent_type = agent.agent_type
    print(agent_type)
    if chat_id == "" or chat_id == "0":
        ret = {"message": "Chat ID not found", "type": "close"}
        await websocket.send_json(ret)
        return
    if agent_type == AgentType.RAGFLOW:
        print(222)
        ragflow_service = RagflowService(settings.fwr_base_url)
        token = await get_ragflow_token(db, current_user.id)
        try:
app/api/dialog.py
@@ -14,8 +14,10 @@
async def dialog_list(current: int,
                      pageSize: int,
                      keyword: str = "",
                      label: int =0,
                      status: str ="",
                      current_user: UserModel = Depends(get_current_user),
                      db=Depends(get_db)):
    if current and not pageSize:
        return ResponseList(code=400, msg="缺少参数")
    return Response(code=200, msg="", data=await get_dialog_list(db, current_user.id, keyword, pageSize, current))
    return Response(code=200, msg="", data=await get_dialog_list(db, current_user.id, keyword, label, status, pageSize, current))
app/api/label.py
New file
@@ -0,0 +1,45 @@
# coding:utf-8
from fastapi import APIRouter, Depends
from app.api import Response, get_current_user
from app.models.base_model import get_db
from app.models.label_model import LabelData, LabelModel, SignLabelData
from app.models.user_model import UserModel
from app.service.label import create_label_service, label_list_service, delete_role_service, sign_label_service
# Router for the label endpoints below; main.py mounts it under the "/api/label" prefix.
label_router = APIRouter()
@label_router.get("/list", response_model=Response)
async def get_label_list(keyword: str = "", labelType: int = 1, current_user: UserModel = Depends(get_current_user),
                         db=Depends(get_db)):
    """List labels, optionally filtered by name substring and label type.

    Query parameters:
        keyword: substring matched against the label name; empty means no filter.
        labelType: label type to filter on; ``0`` disables the type filter.

    The parameters are annotated so FastAPI validates and converts them. Without
    the annotations they are treated as ``Any`` and arrive as raw strings. In that
    case ``labelType=0`` became the truthy string ``"0"`` and filtered on type 0
    instead of returning all types.
    """
    return Response(code=200, msg="", data=await label_list_service(db, keyword, labelType))
@label_router.post("/add_label", response_model=Response)
async def add_label_api(label: LabelData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
    """Create a label owned by the current user.

    Responds 400 when ``labelName`` is empty or already taken, 500 when the
    insert fails, and 200 with the new name otherwise.
    """
    wanted_name = label.labelName
    if not wanted_name:
        return Response(code=400, msg="The labelName cannot be empty!")

    duplicate = db.query(LabelModel).filter(LabelModel.name == wanted_name).first()
    if duplicate:
        return Response(code=400, msg="label already created")

    created = await create_label_service(db, wanted_name, label.labelType, current_user.id)
    if created:
        # NOTE(review): the payload key is "roleName" although it carries a label
        # name; kept as-is for client compatibility.
        return Response(code=200, msg="label create successfully", data={"roleName": wanted_name})
    return Response(code=500, msg="label create failure", data={})
@label_router.delete("/delete_label", response_model=Response)
async def delete_label_api(labelId: int, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
    """Delete the label identified by the ``labelId`` query parameter.

    ``delete_role_service`` is the label-deletion service despite its name.
    """
    if await delete_role_service(db, labelId):
        return Response(code=200, msg="label delete successfully", data={})
    return Response(code=500, msg="label delete failure", data={})
@label_router.post("/sign_label", response_model=Response)
async def sign_label_api(sign: SignLabelData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
    """Attach the labels in ``sign.labelIdList`` to the object ``sign.objectId``."""
    succeeded = await sign_label_service(db, sign.objectId, sign.labelIdList)
    if succeeded:
        return Response(code=200, msg="label sign add successfully", data={})
    return Response(code=500, msg="label add failure", data={})
app/api/user.py
@@ -99,7 +99,7 @@
@user_router.get("/menus", response_model=ResponseList)
async def user_menus(current_user: UserModel = Depends(get_current_user),db=Depends(get_db)):
async def user_menus(keyword="", current_user: UserModel = Depends(get_current_user),db=Depends(get_db)):
    menus = await get_user_menus(db,  current_user.id)
    # return Response(code=200, msg="successfully", data=menus)
    # result = [item.to_dict() for item in agents]
app/models/__init__.py
@@ -15,6 +15,7 @@
from .session_model import SessionModel
from .public_api_model import *
from .menu_model import *
from .label_model import *
# 获取当前时区的时间
app/models/dialog_model.py
@@ -18,7 +18,7 @@
    description = Column(Text)                 # 说明
    icon = Column(Text)                         # 图标
    status = Column(String(1))                 # 状态
    # dialog_type = Column(String(1))            #    # 平台
    dialog_type = Column(String(1))            #    # 平台
    def get_id(self):
        return str(self.id)
@@ -32,7 +32,8 @@
            'name': self.name,
            'description': self.description,
            'icon': self.icon,
            'status': self.status
            'status': self.status,
            'agentType': self.dialog_type,
        }
app/models/label_model.py
New file
@@ -0,0 +1,46 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, constr
from sqlalchemy import Column, Integer, String, DateTime, Table, ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship, backref
from app.models.base_model import Base
class LabelModel(Base):
    """ORM model for the ``label`` table (user-defined labels).

    The column layout mirrors the ``c437168c1da4`` migration.
    """
    __tablename__ = 'label'
    id = Column(Integer, primary_key=True, index=True)
    # Pass the callable, not ``datetime.now()``. A call here is evaluated once at
    # import time, so every row would get the process start-up timestamp.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    name = Column(String(128), unique=True, nullable=False, index=True)
    status = Column(String(10), default="1")
    creator = Column(Integer)  # id of the creating user (create_label_service passes user_id)
    label_type = Column(Integer, default=1)

    def to_json(self):
        """Return the camelCase dict used by the label API responses."""
        return {
            'labelId': self.id,
            'labelName': self.name,
            'labelType': self.label_type
        }
class LabelWorkerModel(Base):
    """ORM model for ``label_worker``: one row per (label, labelled object) pair."""
    __tablename__ = 'label_worker'
    # An object can carry a given label at most once.
    __table_args__ = (UniqueConstraint('label_id', 'object_id', name='label_object_id_ix'),)
    id = Column(Integer, primary_key=True, index=True)
    label_id = Column(Integer)  # LabelModel.id; no ForeignKey is declared
    # Id of the labelled object; the dialog service matches it against DialogModel.id.
    object_id = Column(String(36))
class LabelData(BaseModel):
    """Request body for the ``/add_label`` endpoint."""
    labelName: str
    labelType: Optional[int] = 1  # stored as LabelModel.label_type
class SignLabelData(BaseModel):
    """Request body for the ``/sign_label`` endpoint."""
    # Label ids to attach to ``objectId``.
    # NOTE(review): element type is not validated; consider List[int].
    labelIdList: list
    objectId:str
app/models/token_model.py
@@ -90,6 +90,8 @@
async def get_token(db: Session, user_id: int):
    # return db.query(TokenModel).filter_by(user_id=user_id).first()
    return {i.app_type.replace("app", "token"): i.access_token for i in await UserAppDao(db).get_user_datas(user_id)}
    res = {i.app_type.replace("app", "token"): i.access_token for i in await UserAppDao(db).get_user_datas(user_id)}
    if not res:
        token = db.query(TokenModel).filter_by(user_id=user_id).first()
        res = {"ragflow_token": token.ragflow_token, "bisheng_token": token.bisheng_token}
    return res
app/service/auth.py
@@ -9,8 +9,10 @@
from Log import logger
from app.config.config import settings
from app.config.const import RAGFLOW, BISHENG, DIFY
from app.models import RoleModel, GroupModel
from app.models.user_model import UserModel, UserAppModel
from app.service.v2.app_register import AppRegisterDao
SECRET_KEY = settings.secret_key
ALGORITHM = "HS256"
@@ -103,6 +105,40 @@
    return True
async def update_user_info(db, user_id):
    """Upsert the per-application account records of an existing user.

    One entry is built for every registered application (RAGFLOW, BISHENG,
    DIFY) from the user's stored ids and username. Each entry is written
    through ``UserAppDao.update_and_insert_data``, and the collected dict is
    finally passed to ``save_register_user``. Unknown applications are logged
    and skipped.

    Args:
        db: SQLAlchemy session.
        user_id: primary key of the ``UserModel`` row to process.
    """
    app_register = AppRegisterDao(db).get_apps()
    user = db.query(UserModel).filter(UserModel.id == user_id).first()
    if user is None:
        # Everything below dereferences ``user``; stop instead of raising AttributeError.
        logger.error(f"update_user_info: user {user_id} not found")
        return

    register_dict = {}
    for app in app_register:
        app_id = app["id"]
        if app_id == RAGFLOW:
            register_dict[app_id] = {"id": user.ragflow_id, "name": user.username,
                                     "email": f"{user.username}@example.com"}
        elif app_id == BISHENG:
            register_dict[app_id] = {"id": user.bisheng_id, "name": user.username, "email": ""}
        elif app_id == DIFY:
            register_dict[app_id] = {"id": "", "name": user.username, "email": ""}
        else:
            logger.error("未知注册应用---")

    # Upsert each application once, after the dict is complete. Inside the loop,
    # every earlier entry was re-written once per remaining app. A per-entry try
    # keeps one failing app from blocking the others.
    dao = UserAppDao(db)
    for app_type, info in register_dict.items():
        try:
            await dao.update_and_insert_data(info.get("name"), user.password, info.get("email"), user_id,
                                             str(info.get("id")), app_type)
        except Exception as e:
            logger.error(e)

    # NOTE(review): confirm save_register_user updates this existing user rather
    # than inserting a second row for the same username.
    await save_register_user(db, user.username, user.password, user.email, register_dict)
class UserAppDao:
    def __init__(self, db: Session):
        self.db = db
app/service/dialog.py
@@ -1,43 +1,55 @@
from datetime import datetime
from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel, group_dialog_table
from sqlalchemy import or_
from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel, group_dialog_table, LabelWorkerModel, \
    LabelModel
from app.models.user_model import UserModel
from Log import logger
async def get_dialog_list(db, user_id, keyword, page_size, page_index):
async def get_dialog_list(db, user_id, keyword, label, status, page_size, page_index):
    user = db.query(UserModel).filter(UserModel.id == user_id).first()
    if user is None:
        return {"rows": []}
    if user.permission == "admin":
        query = db.query(DialogModel)
    query = db.query(DialogModel)
    id_list = []
    if label:
        id_list = [i.object_id for i in db.query(LabelWorkerModel).filter(LabelWorkerModel.label_id==label).all()]
    if user.permission != "admin":
        dia_list = [j.id for i in user.groups for j in i.dialogs if not label or j.id in id_list]
        query = query.filter(or_(DialogModel.tenant_id == user_id, DialogModel.id.in_(dia_list)))
    else:
        group_list = [i.id for i in user.groups]
        query = db.query(DialogModel)
        query = query.filter(DialogModel.tenant_id == user_id)
        query = query.union(
            db.query(DialogModel).join(
                group_dialog_table,
                DialogModel.id == group_dialog_table.c.dialog_id
            ).filter(
                group_dialog_table.c.group_id.in_(group_list)
            )
        )
        if label:
            query = query.filter(or_(DialogModel.id.in_(id_list)))
    if keyword:
        query = query.filter(DialogModel.name.like('%{}%'.format(keyword)))
    if status:
        print(status)
        query = query.filter(DialogModel.status == status)
    total = query.count()
    if page_size:
        query = query.limit(page_size).offset((page_index - 1) * page_size)
    rows = []
    user_id_set = set()
    dialog_id_set = set()
    label_dict = {}
    for kld in query.all():
        user_id_set.add(kld.tenant_id)
        dialog_id_set.add(kld.id)
        rows.append(kld.to_json())
    print(rows)
    user_dict = {i.id: i.to_dict() for i in db.query(UserModel).filter(UserModel.id.in_(user_id_set)).all()}
    user_dict = {str(i.id): i.to_dict() for i in db.query(UserModel).filter(UserModel.id.in_(user_id_set)).all()}
    for i in db.query(LabelModel.id, LabelModel.name, LabelWorkerModel.object_id).outerjoin(LabelWorkerModel,
                                                                LabelModel.id == LabelWorkerModel.label_id).filter(
            LabelWorkerModel.object_id.in_(dialog_id_set)).all():
        label_dict[i.object_id] = label_dict.get(i.object_id, []) +[{"labelId": i.id, "labelName": i.name}]
    for r in rows:
        r["user"] = user_dict.get(r["user_id"], {})
        r["label"] = label_dict.get(r["id"], [])
    return {"total":  total, "rows": rows}
app/service/knowledge.py
@@ -1,3 +1,5 @@
from sqlalchemy import or_
from app.models import KnowledgeModel, group_knowledge_table
from app.models.user_model import UserModel
from Log import logger
@@ -7,21 +9,10 @@
    user = db.query(UserModel).filter(UserModel.id == user_id).first()
    if user is None:
        return {"rows": []}
    if user.permission == "admin":
        query = db.query(KnowledgeModel)
    else:
        group_list = [i.id for i in user.groups]
        query = db.query(KnowledgeModel)
        query = query.filter(KnowledgeModel.tenant_id == user_id)
        query = query.union(
            db.query(KnowledgeModel).join(
                group_knowledge_table,
                KnowledgeModel.id == group_knowledge_table.c.knowledge_id
            ).filter(
                group_knowledge_table.c.group_id.in_(group_list)
            )
        )
    query = db.query(KnowledgeModel)
    if user.permission != "admin":
        klg_list = [j.id for i in user.groups for j in i.knowledges]
        query = query.filter(or_(KnowledgeModel.tenant_id == user_id, KnowledgeModel.id.in_(klg_list)))
    if keyword:
        query = query.filter(KnowledgeModel.name.like('%{}%'.format(keyword)))
    total = query.count()
app/service/label.py
New file
@@ -0,0 +1,64 @@
import uuid
from streamlit.time_util import adjust_years
from Log import logger
from app.models.label_model import LabelModel, LabelWorkerModel
from app.models.role_model import RoleModel
async def label_list_service(db, keyword: str, label_type):
    """Return labels filtered by name substring and type, newest first.

    Args:
        db: SQLAlchemy session.
        keyword: substring matched with SQL ``LIKE`` against the label name; falsy means no filter.
        label_type: exact ``label_type`` to match; falsy means no filter.

    Returns:
        ``{"total": <number of matching labels>, "rows": [LabelModel.to_json(), ...]}``.
    """
    query = db.query(LabelModel)
    if keyword:
        query = query.filter(LabelModel.name.like('%{}%'.format(keyword)))
    if label_type:
        query = query.filter(LabelModel.label_type == label_type)
    labels = query.order_by(LabelModel.id.desc()).all()
    # The result is not paginated, so its length is the total; no second COUNT query needed.
    return {"total": len(labels), "rows": [label.to_json() for label in labels]}
async def create_label_service(db, label_name, label_type, user_id):
    """Insert a new label row and report whether it was committed.

    Any database error is logged and the session rolled back.

    Returns:
        ``True`` on success, ``False`` on failure.
    """
    try:
        new_label = LabelModel(name=label_name, creator=user_id, label_type=label_type)
        db.add(new_label)
        db.commit()
        db.refresh(new_label)
    except Exception as err:
        logger.error(err)
        db.rollback()
        return False
    else:
        return True
async def delete_role_service(db, label_id: int):
    """Delete a label together with every association row that references it.

    Despite its name, this is the label-deletion service. ``label_worker`` has
    no foreign key to ``label``, so its rows must be removed explicitly;
    otherwise they outlive the label as orphans. Both deletes share one commit.

    Args:
        db: SQLAlchemy session.
        label_id: ``LabelModel.id`` of the label to remove.

    Returns:
        ``True`` on success (also when no such label exists), ``False`` if the
        database raised; the session is rolled back in that case.
    """
    try:
        db.query(LabelWorkerModel).filter(LabelWorkerModel.label_id == label_id).delete()
        db.query(LabelModel).filter(LabelModel.id == label_id).delete()
        db.commit()
    except Exception as e:
        logger.error(e)
        db.rollback()
        return False
    return True
async def sign_label_service(db, object_id, label_list):
    """Make the labels attached to ``object_id`` match ``label_list``.

    Association rows whose label is no longer listed are deleted. A
    ``LabelWorkerModel`` row is inserted for every listed label that is not
    attached yet. The operation is best-effort: database errors are logged and
    rolled back without aborting the remaining work.

    Args:
        db: SQLAlchemy session.
        object_id: id of the labelled object (``LabelWorkerModel.object_id``).
        label_list: label ids that should end up attached to the object.

    Returns:
        Always ``True``; failures are only logged.
    """
    label_list = label_list or []
    attached = []
    stale_rows = []
    for row in db.query(LabelWorkerModel).filter(LabelWorkerModel.object_id == object_id).all():
        if row.label_id in label_list:
            attached.append(row.label_id)
        else:
            stale_rows.append(row)

    # Detach labels the caller no longer lists (these were collected but never removed before).
    if stale_rows:
        try:
            for row in stale_rows:
                db.delete(row)
            db.commit()
        except Exception as e:
            logger.error(e)
            db.rollback()

    for label_id in label_list:
        if label_id in attached:
            continue
        # Record the id first so a duplicate in label_list is attempted only once;
        # a second insert would violate the label_object_id_ix unique constraint.
        attached.append(label_id)
        try:
            db.add(LabelWorkerModel(label_id=label_id, object_id=object_id))
            db.commit()
        except Exception as e:
            logger.error(e)
            db.rollback()
    return True
app/service/service_token.py
@@ -9,19 +9,24 @@
async def get_bisheng_token(db, user_id: int):
    # token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
    token = await UserAppDao.get_data_by_id(user_id, BISHENG)
    token = await UserAppDao(db).get_data_by_id(user_id, BISHENG)
    if not token:
        return None
    return token.access_token
        token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
        if not token:
            return None
    else:
        return token.access_token
    return token.bisheng_token
async def get_ragflow_token(db, user_id: int):
    token = await UserAppDao.get_data_by_id(user_id, RAGFLOW)
    token = await UserAppDao(db).get_data_by_id(user_id, RAGFLOW)
    if not token:
        return None
    return token.access_token
        token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first()
        if not token:
            return None
    else:
        return token.access_token
    return token.ragflow_token
async def get_ragflow_new_token(db, user_id: int, app_type):
    user = db.query(UserModel).filter(UserModel.id == user_id).first()
app/service/v2/app_register.py
@@ -1,6 +1,8 @@
from datetime import datetime
from app.models.public_api_model import AppRegisterModel
from Log import logger
from app.models import current_time
# from app.models import current_time
from sqlalchemy.orm import Session
from typing import Type
@@ -18,7 +20,7 @@
        logger.error("更新数据: app register---------------------------")
        try:
            self.db.query(AppRegisterModel).filter(AppRegisterModel.id==app_id).update({"status":status, "updated_at": current_time()})
            self.db.query(AppRegisterModel).filter(AppRegisterModel.id==app_id).update({"status":status, "updated_at": datetime.now()})
            self.db.commit()
        except Exception as e:
            logger.error(e)
main.py
@@ -12,6 +12,7 @@
from app.api.excel import router as excel_router
from app.api.files import router as files_router
from app.api.knowledge import knowledge_router
from app.api.label import label_router
from app.api.llm import llm_router
from app.api.organization import dept_router
from app.api.v2.public_api import public_api
@@ -74,7 +75,7 @@
app.include_router(llm_router, prefix='/api/llm', tags=["llm"])
app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"])
app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"])
# app.include_router(sync_router, prefix='/api/sync', tags=["sync"])
app.include_router(label_router, prefix='/api/label', tags=["label"])
app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
app.mount("/static", StaticFiles(directory="app/images"), name="static")