From a791022ff1311e1fb76930c398d6ff91036d0456 Mon Sep 17 00:00:00 2001 From: zhaoqingang <zhaoqg0118@163.com> Date: 星期三, 11 十二月 2024 17:57:52 +0800 Subject: [PATCH] 新增加标签功能 --- app/service/v2/app_register.py | 6 app/service/label.py | 64 +++++++++ app/service/service_token.py | 23 ++- app/api/chat.py | 3 app/service/knowledge.py | 21 -- app/api/user.py | 2 alembic/versions/c437168c1da4_label_tabel_add.py | 53 +++++++ app/api/auth.py | 9 app/service/auth.py | 36 +++++ app/models/__init__.py | 1 alembic/versions/b2f03e852b6e_agent_type_add.py | 30 ++++ app/service/dialog.py | 46 ++++-- app/api/dialog.py | 4 app/api/label.py | 45 ++++++ main.py | 3 app/models/label_model.py | 46 ++++++ app/models/dialog_model.py | 5 app/models/token_model.py | 8 18 files changed, 351 insertions(+), 54 deletions(-) diff --git a/alembic/versions/b2f03e852b6e_agent_type_add.py b/alembic/versions/b2f03e852b6e_agent_type_add.py new file mode 100644 index 0000000..a19c34c --- /dev/null +++ b/alembic/versions/b2f03e852b6e_agent_type_add.py @@ -0,0 +1,30 @@ +"""agent type add + +Revision ID: b2f03e852b6e +Revises: c437168c1da4 +Create Date: 2024-12-11 17:56:12.125274 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b2f03e852b6e' +down_revision: Union[str, None] = 'c437168c1da4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('dialogs', sa.Column('dialog_type', sa.String(length=1), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('dialogs', 'dialog_type') + # ### end Alembic commands ### diff --git a/alembic/versions/c437168c1da4_label_tabel_add.py b/alembic/versions/c437168c1da4_label_tabel_add.py new file mode 100644 index 0000000..91026a3 --- /dev/null +++ b/alembic/versions/c437168c1da4_label_tabel_add.py @@ -0,0 +1,53 @@ +"""label tabel add + +Revision ID: c437168c1da4 +Revises: 2f304d60542b +Create Date: 2024-12-11 15:01:45.049315 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c437168c1da4' +down_revision: Union[str, None] = '2f304d60542b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('label', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('name', sa.String(length=128), nullable=False), + sa.Column('status', sa.String(length=10), nullable=True), + sa.Column('creator', sa.Integer(), nullable=True), + sa.Column('label_type', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_label_id'), 'label', ['id'], unique=False) + op.create_index(op.f('ix_label_name'), 'label', ['name'], unique=True) + op.create_table('label_worker', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('label_id', sa.Integer(), nullable=True), + sa.Column('object_id', sa.String(length=36), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('label_id', 'object_id', name='label_object_id_ix') + ) + op.create_index(op.f('ix_label_worker_id'), 'label_worker', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_label_worker_id'), table_name='label_worker') + op.drop_table('label_worker') + op.drop_index(op.f('ix_label_name'), table_name='label') + op.drop_index(op.f('ix_label_id'), table_name='label') + op.drop_table('label') + # ### end Alembic commands ### diff --git a/app/api/auth.py b/app/api/auth.py index f850882..ba86a89 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -16,7 +16,7 @@ from app.models.user import UserCreate, LoginData from app.models.user_model import UserModel from app.service.auth import authenticate_user, create_access_token, is_valid_password, save_register_user, \ - update_user_token, UserAppDao + update_user_token, UserAppDao, update_user_info from app.service.bisheng import BishengService from app.service.v2.app_register import AppRegisterDao from app.service.difyService import DifyService @@ -83,13 +83,13 @@ access_token = create_access_token(data={"sub": user.username, "user_id": user.id}) upsert_token(db, user.id, access_token, bisheng_token, ragflow_token) - + # print(111) return Response(code=200, msg="Login successful", data={ "access_token": access_token, "token_type": "bearer", "username": user.username, "nickname": "", - "user": user.to_login_json() + # "user": user.to_login_json() }) @@ -120,9 +120,12 @@ logger.error("鏈煡娉ㄥ唽搴旂敤---") continue try: + name = login_data.username user_app = await UserAppDao(db).get_data_by_id(user.id, app["id"]) if user_app: name = user_app.username + else: + await update_user_info(db, user.id) token = await service.login(name, login_data.password) token_dict[app["id"]] = token except Exception as e: diff --git a/app/api/chat.py b/app/api/chat.py index cde74ff..2628bac 100644 --- a/app/api/chat.py +++ b/app/api/chat.py @@ -41,13 +41,16 @@ ret = {"message": "Agent not found", "type": "close"} await websocket.send_json(ret) return + print(1111) agent_type = agent.agent_type + print(agent_type) if chat_id == "" or chat_id == "0": ret = {"message": "Chat ID not found", "type": "close"} await websocket.send_json(ret) return if agent_type == AgentType.RAGFLOW: + print(222) ragflow_service = RagflowService(settings.fwr_base_url) token = await get_ragflow_token(db, current_user.id) try: diff --git a/app/api/dialog.py b/app/api/dialog.py index e50cfd3..cb2b1e4 100644 --- a/app/api/dialog.py +++ b/app/api/dialog.py @@ -14,8 +14,10 @@ async def dialog_list(current: int, pageSize: int, keyword: str = "", + label: int =0, + status: str ="", current_user: UserModel = Depends(get_current_user), db=Depends(get_db)): if current and not pageSize: return ResponseList(code=400, msg="缂哄皯鍙傛暟") - return Response(code=200, msg="", data=await get_dialog_list(db, current_user.id, keyword, pageSize, current)) + return Response(code=200, msg="", data=await get_dialog_list(db, current_user.id, keyword, label, status, pageSize, current)) diff --git a/app/api/label.py b/app/api/label.py new file mode 100644 index 0000000..85744a5 --- /dev/null +++ b/app/api/label.py @@ -0,0 +1,45 @@ +# coding:utf-8 + +from fastapi import APIRouter, Depends +from app.api import Response, get_current_user +from app.models.base_model import get_db +from app.models.label_model import LabelData, LabelModel, SignLabelData +from app.models.user_model import UserModel +from app.service.label import create_label_service, label_list_service, delete_role_service, sign_label_service + +label_router = APIRouter() + +@label_router.get("/list", response_model=Response) +async def get_label_list(keyword="", labelType=1,current_user: UserModel = Depends(get_current_user), + db=Depends(get_db)): + + return Response(code=200, msg="", data=await label_list_service(db, keyword,labelType)) + + +@label_router.post("/add_label", response_model=Response) +async def add_label_api(label: LabelData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)): + if not label.labelName: + return Response(code=400, msg="The labelName cannot be empty!") + db_role = db.query(LabelModel).filter(LabelModel.name == label.labelName).first() + if db_role: + return Response(code=400, msg="label already created") + is_create = await create_label_service(db, label.labelName, label.labelType, current_user.id) + if not is_create: + return Response(code=500, msg="label create failure", data={}) + return Response(code=200, msg="label create successfully", data={"roleName": label.labelName}) + + +@label_router.delete("/delete_label", response_model=Response) +async def delete_label_api(labelId: int, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)): + is_delete = await delete_role_service(db, labelId) + if not is_delete: + return Response(code=500, msg="label delete failure", data={}) + return Response(code=200, msg="label delete successfully", data={}) + + +@label_router.post("/sign_label", response_model=Response) +async def sign_label_api(sign: SignLabelData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)): + is_add = await sign_label_service(db, sign.objectId, sign.labelIdList) + if not is_add: + return Response(code=500, msg="label add failure", data={}) + return Response(code=200, msg="label sign add successfully", data={}) \ No newline at end of file diff --git a/app/api/user.py b/app/api/user.py index b5bd520..b442af8 100644 --- a/app/api/user.py +++ b/app/api/user.py @@ -99,7 +99,7 @@ @user_router.get("/menus", response_model=ResponseList) -async def user_menus(current_user: UserModel = Depends(get_current_user),db=Depends(get_db)): +async def user_menus(keyword="", current_user: UserModel = Depends(get_current_user),db=Depends(get_db)): menus = await get_user_menus(db, current_user.id) # return Response(code=200, msg="successfully", data=menus) # result = [item.to_dict() for item in agents] diff --git a/app/models/__init__.py b/app/models/__init__.py index c6370a9..126abc8 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -15,6 +15,7 @@ from .session_model import SessionModel from .public_api_model import * from .menu_model import * +from .label_model import * # 鑾峰彇褰撳墠鏃跺尯鐨勬椂闂� diff --git a/app/models/dialog_model.py b/app/models/dialog_model.py index de1e2b5..f08ca78 100644 --- a/app/models/dialog_model.py +++ b/app/models/dialog_model.py @@ -18,7 +18,7 @@ description = Column(Text) # 璇存槑 icon = Column(Text) # 鍥炬爣 status = Column(String(1)) # 鐘舵�� - # dialog_type = Column(String(1)) # # 骞冲彴 + dialog_type = Column(String(1)) # # 骞冲彴 def get_id(self): return str(self.id) @@ -32,7 +32,8 @@ 'name': self.name, 'description': self.description, 'icon': self.icon, - 'status': self.status + 'status': self.status, + 'agentType': self.dialog_type, } diff --git a/app/models/label_model.py b/app/models/label_model.py new file mode 100644 index 0000000..41a6fe3 --- /dev/null +++ b/app/models/label_model.py @@ -0,0 +1,46 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, constr +from sqlalchemy import Column, Integer, String, DateTime, Table, ForeignKey, UniqueConstraint +from sqlalchemy.orm import relationship, backref + +from app.models.base_model import Base + + +class LabelModel(Base): + __tablename__ = 'label' + id = Column(Integer, primary_key=True, index=True) + created_at = Column(DateTime, default=datetime.now()) + updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now()) + name = Column(String(128), unique=True, nullable=False, index=True) + status = Column(String(10), default="1") + creator = Column(Integer) + label_type = Column(Integer, default=1) + + def to_json(self): + return { + 'labelId': self.id, + 'labelName': self.name, + 'labelType': self.label_type + } + + +class LabelWorkerModel(Base): + __tablename__ = 'label_worker' + __table_args__ = (UniqueConstraint('label_id', 'object_id', name='label_object_id_ix'),) + id = Column(Integer, primary_key=True, index=True) + label_id = Column(Integer) + object_id = Column(String(36)) + + + +class LabelData(BaseModel): + labelName: str + labelType: Optional[int] = 1 + + +class SignLabelData(BaseModel): + labelIdList: list + objectId:str + diff --git a/app/models/token_model.py b/app/models/token_model.py index fad57e0..febe111 100644 --- a/app/models/token_model.py +++ b/app/models/token_model.py @@ -90,6 +90,8 @@ async def get_token(db: Session, user_id: int): - # return db.query(TokenModel).filter_by(user_id=user_id).first() - - return {i.app_type.replace("app", "token"): i.access_token for i in await UserAppDao(db).get_user_datas(user_id)} + res = {i.app_type.replace("app", "token"): i.access_token for i in await UserAppDao(db).get_user_datas(user_id)} + if not res: + token = db.query(TokenModel).filter_by(user_id=user_id).first() + res = {"ragflow_token": token.ragflow_token, "bisheng_token": token.bisheng_token} + return res diff --git a/app/service/auth.py b/app/service/auth.py index 46c42dd..a0a1952 100644 --- a/app/service/auth.py +++ b/app/service/auth.py @@ -9,8 +9,10 @@ from Log import logger from app.config.config import settings +from app.config.const import RAGFLOW, BISHENG, DIFY from app.models import RoleModel, GroupModel from app.models.user_model import UserModel, UserAppModel +from app.service.v2.app_register import AppRegisterDao SECRET_KEY = settings.secret_key ALGORITHM = "HS256" @@ -103,6 +105,40 @@ return True +async def update_user_info(db, user_id): + app_register = AppRegisterDao(db).get_apps() + register_dict = {} + user = db.query(UserModel).filter(UserModel.id==user_id).first() + for app in app_register: + if app["id"] == RAGFLOW: + register_dict[app['id']] = {"id": user.ragflow_id, "name": user.username, "email": f"{user.username}@example.com"} + elif app["id"] == BISHENG: + register_dict[app['id']] = {"id": user.bisheng_id, "name": user.username, "email": ""} + elif app["id"] == DIFY: + register_dict[app['id']] = {"id": "", "name": user.username, "email": ""} + else: + logger.error("鏈煡娉ㄥ唽搴旂敤---") + continue + + try: + for k, v in register_dict.items(): + await UserAppDao(db).update_and_insert_data(v.get("name"), user.password, v.get("email"), user_id, + str(v.get("id")), k) + except Exception as e: + logger.error(e) + + # 瀛樺偍鐢ㄦ埛淇℃伅 + # hashed_password = pwd_context.hash(user.password) + # db_user = UserModel(username=user.username, hashed_password=hashed_password, email=user.email) + # db_user.password = db_user.encrypted_password(user.password) + # for k, v in register_dict.items(): + # setattr(db_user, k.replace("app", "id"), v) + # db.add(db_user) + # db.commit() + # db.refresh(db_user) + + is_sava = await save_register_user(db, user.username, user.password, user.email, register_dict) + class UserAppDao: def __init__(self, db: Session): self.db = db diff --git a/app/service/dialog.py b/app/service/dialog.py index d78fbbe..623477a 100644 --- a/app/service/dialog.py +++ b/app/service/dialog.py @@ -1,43 +1,55 @@ from datetime import datetime -from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel, group_dialog_table +from sqlalchemy import or_ + +from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel, group_dialog_table, LabelWorkerModel, \ + LabelModel from app.models.user_model import UserModel from Log import logger -async def get_dialog_list(db, user_id, keyword, page_size, page_index): +async def get_dialog_list(db, user_id, keyword, label, status, page_size, page_index): user = db.query(UserModel).filter(UserModel.id == user_id).first() if user is None: return {"rows": []} - if user.permission == "admin": - query = db.query(DialogModel) + query = db.query(DialogModel) + id_list = [] + if label: + id_list = [i.object_id for i in db.query(LabelWorkerModel).filter(LabelWorkerModel.label_id==label).all()] + if user.permission != "admin": + dia_list = [j.id for i in user.groups for j in i.dialogs if not label or j.id in id_list] + query = query.filter(or_(DialogModel.tenant_id == user_id, DialogModel.id.in_(dia_list))) else: - group_list = [i.id for i in user.groups] - query = db.query(DialogModel) - query = query.filter(DialogModel.tenant_id == user_id) - query = query.union( - db.query(DialogModel).join( - group_dialog_table, - DialogModel.id == group_dialog_table.c.dialog_id - ).filter( - group_dialog_table.c.group_id.in_(group_list) - ) - ) + if label: + query = query.filter(or_(DialogModel.id.in_(id_list))) + if keyword: query = query.filter(DialogModel.name.like('%{}%'.format(keyword))) + + if status: + print(status) + query = query.filter(DialogModel.status == status) total = query.count() if page_size: query = query.limit(page_size).offset((page_index - 1) * page_size) rows = [] user_id_set = set() + dialog_id_set = set() + label_dict = {} for kld in query.all(): user_id_set.add(kld.tenant_id) + dialog_id_set.add(kld.id) rows.append(kld.to_json()) - print(rows) - user_dict = {i.id: i.to_dict() for i in db.query(UserModel).filter(UserModel.id.in_(user_id_set)).all()} + user_dict = {str(i.id): i.to_dict() for i in db.query(UserModel).filter(UserModel.id.in_(user_id_set)).all()} + for i in db.query(LabelModel.id, LabelModel.name, LabelWorkerModel.object_id).outerjoin(LabelWorkerModel, + LabelModel.id == LabelWorkerModel.label_id).filter( + LabelWorkerModel.object_id.in_(dialog_id_set)).all(): + + label_dict[i.object_id] = label_dict.get(i.object_id, []) +[{"labelId": i.id, "labelName": i.name}] for r in rows: r["user"] = user_dict.get(r["user_id"], {}) + r["label"] = label_dict.get(r["id"], []) return {"total": total, "rows": rows} diff --git a/app/service/knowledge.py b/app/service/knowledge.py index 91022c3..b927feb 100644 --- a/app/service/knowledge.py +++ b/app/service/knowledge.py @@ -1,3 +1,5 @@ +from sqlalchemy import or_ + from app.models import KnowledgeModel, group_knowledge_table from app.models.user_model import UserModel from Log import logger @@ -7,21 +9,10 @@ user = db.query(UserModel).filter(UserModel.id == user_id).first() if user is None: return {"rows": []} - if user.permission == "admin": - query = db.query(KnowledgeModel) - else: - group_list = [i.id for i in user.groups] - query = db.query(KnowledgeModel) - query = query.filter(KnowledgeModel.tenant_id == user_id) - - query = query.union( - db.query(KnowledgeModel).join( - group_knowledge_table, - KnowledgeModel.id == group_knowledge_table.c.knowledge_id - ).filter( - group_knowledge_table.c.group_id.in_(group_list) - ) - ) + query = db.query(KnowledgeModel) + if user.permission != "admin": + klg_list = [j.id for i in user.groups for j in i.knowledges] + query = query.filter(or_(KnowledgeModel.tenant_id == user_id, KnowledgeModel.id.in_(klg_list))) if keyword: query = query.filter(KnowledgeModel.name.like('%{}%'.format(keyword))) total = query.count() diff --git a/app/service/label.py b/app/service/label.py new file mode 100644 index 0000000..54448b5 --- /dev/null +++ b/app/service/label.py @@ -0,0 +1,64 @@ +import uuid + +from streamlit.time_util import adjust_years + +from Log import logger +from app.models.label_model import LabelModel, LabelWorkerModel +from app.models.role_model import RoleModel + + +async def label_list_service(db, keyword: str, label_type): + query = db.query(LabelModel) + if keyword: + query = query.filter(LabelModel.name.like('%{}%'.format(keyword))) + if label_type: + query = query.filter(LabelModel.label_type==label_type) + labels = query.order_by(LabelModel.id.desc()).all() + return {"total": query.count(), "rows": [label.to_json() for label in labels]} + + + +async def create_label_service(db, label_name, label_type, user_id): + try: + label_model = LabelModel(name=label_name,creator=user_id, label_type=label_type) + db.add(label_model) + db.commit() + db.refresh(label_model) + except Exception as e: + logger.error(e) + db.rollback() + return False + return True + + +async def delete_role_service(db, label_id: int): + try: + db.query(LabelModel).filter(LabelModel.id == label_id).delete() + db.commit() + except Exception as e: + logger.error(e) + db.rollback() + return False + return True + + +async def sign_label_service(db, object_id, label_list): + delete_list = [] + has_list = [] + for i in db.query(LabelWorkerModel).filter(LabelWorkerModel.object_id == object_id).all(): + if i.label_id not in label_list: + delete_list.append(i.id) + else: + has_list.append(i.label_id) + for label_id in label_list: + if label_id in has_list: + continue + try: + label = LabelWorkerModel(label_id=label_id, object_id=object_id) + db.add(label) + db.commit() + except Exception as e: + logger.error(e) + db.rollback() + # return False + return True \ No newline at end of file diff --git a/app/service/service_token.py b/app/service/service_token.py index 8fffc99..215353b 100644 --- a/app/service/service_token.py +++ b/app/service/service_token.py @@ -9,19 +9,24 @@ async def get_bisheng_token(db, user_id: int): - # token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first() - token = await UserAppDao.get_data_by_id(user_id, BISHENG) + token = await UserAppDao(db).get_data_by_id(user_id, BISHENG) if not token: - return None - return token.access_token - + token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first() + if not token: + return None + else: + return token.access_token + return token.bisheng_token async def get_ragflow_token(db, user_id: int): - token = await UserAppDao.get_data_by_id(user_id, RAGFLOW) + token = await UserAppDao(db).get_data_by_id(user_id, RAGFLOW) if not token: - return None - return token.access_token - + token = db.query(TokenModel).filter(TokenModel.user_id == user_id).first() + if not token: + return None + else: + return token.access_token + return token.ragflow_token async def get_ragflow_new_token(db, user_id: int, app_type): user = db.query(UserModel).filter(UserModel.id == user_id).first() diff --git a/app/service/v2/app_register.py b/app/service/v2/app_register.py index e30816b..88c5393 100644 --- a/app/service/v2/app_register.py +++ b/app/service/v2/app_register.py @@ -1,6 +1,8 @@ +from datetime import datetime + from app.models.public_api_model import AppRegisterModel from Log import logger -from app.models import current_time +# from app.models import current_time from sqlalchemy.orm import Session from typing import Type @@ -18,7 +20,7 @@ logger.error("鏇存柊鏁版嵁: app register---------------------------") try: - self.db.query(AppRegisterModel).filter(AppRegisterModel.id==app_id).update({"status":status, "updated_at": current_time()}) + self.db.query(AppRegisterModel).filter(AppRegisterModel.id==app_id).update({"status":status, "updated_at": datetime.now()}) self.db.commit() except Exception as e: logger.error(e) diff --git a/main.py b/main.py index 9c00b97..a9e4f0f 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from app.api.excel import router as excel_router from app.api.files import router as files_router from app.api.knowledge import knowledge_router +from app.api.label import label_router from app.api.llm import llm_router from app.api.organization import dept_router from app.api.v2.public_api import public_api @@ -74,7 +75,7 @@ app.include_router(llm_router, prefix='/api/llm', tags=["llm"]) app.include_router(dialog_router, prefix='/api/dialog', tags=["dialog"]) app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"]) -# app.include_router(sync_router, prefix='/api/sync', tags=["sync"]) +app.include_router(label_router, prefix='/api/label', tags=["label"]) app.include_router(public_api, prefix='/v1/api', tags=["public_api"]) app.mount("/static", StaticFiles(directory="app/images"), name="static") -- Gitblit v1.8.0