From 50f9b062456bd595d4fee86e7c90e0cac8904960 Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 07 十一月 2024 18:22:48 +0800
Subject: [PATCH] 用户组接口

---
 app/service/bisheng.py    |   10 +
 app/models/user_model.py  |    8 +
 app/service/ragflow.py    |   14 ++
 app/service/group.py      |  119 +++++++++++++++++++
 main.py                   |    4 
 app/models/group_model.py |   52 ++++++++
 app/api/group.py          |   93 +++++++++++++++
 app/api/user.py           |   35 +++++
 app/api/auth.py           |    8 
 app/models/user.py        |    7 +
 10 files changed, 345 insertions(+), 5 deletions(-)

diff --git a/app/api/auth.py b/app/api/auth.py
index 14f2c06..860fca7 100644
--- a/app/api/auth.py
+++ b/app/api/auth.py
@@ -1,3 +1,5 @@
+import json
+
 from fastapi import APIRouter, Depends
 from sqlalchemy.orm import Session
 
@@ -25,19 +27,19 @@
 
     # 娉ㄥ唽鍒版瘯鏄�
     try:
-        await bisheng_service.register(user.username, user.password)
+        bisheng_info = await bisheng_service.register(user.username, user.password)
     except Exception as e:
         return Response(code=500, msg=f"Failed to register with Bisheng: {str(e)}")
 
     # 娉ㄥ唽鍒皉agflow
     try:
-        await ragflow_service.register(user.username, user.password)
+        ragflow_info = await ragflow_service.register(user.username, user.password)
     except Exception as e:
         return Response(code=500, msg=f"Failed to register with Ragflow: {str(e)}")
 
     # 瀛樺偍鐢ㄦ埛淇℃伅
     hashed_password = pwd_context.hash(user.password)
-    db_user = UserModel(username=user.username, hashed_password=hashed_password)
+    db_user = UserModel(username=user.username, hashed_password=hashed_password, email=ragflow_info.get("email",  f"{user.username}@example.com"),ragflow_id=ragflow_info.get("id"),bisheng_id=bisheng_info.get("user_id"))
     db.add(db_user)
     db.commit()
     db.refresh(db_user)
diff --git a/app/api/group.py b/app/api/group.py
new file mode 100644
index 0000000..4485d31
--- /dev/null
+++ b/app/api/group.py
@@ -0,0 +1,93 @@
+from ast import parse
+
+from fastapi import APIRouter, Depends
+from app.api import Response, pwd_context, get_current_user, ResponseList
+from app.config.config import settings
+from app.models.base_model import get_db
+from app.models.group_model import GroupInfoModel, UserGroupModel, GroupData, GroupUsers
+from app.models.user import PageParameter
+from app.models.user_model import UserModel
+from app.service.bisheng import BishengService
+from app.service.group import create_group, group_list, edit_group_data, delete_group_data, get_group_users, \
+    save_user_to_group
+from app.service.token import get_bisheng_token
+
+group_router = APIRouter()
+
+
+@group_router.post("/group_list", response_model=Response)
+async def user_group_list(paras: PageParameter, current_user: UserModel = Depends(get_current_user),
+                          db=Depends(get_db)):
+    return Response(code=200, msg="", data=await group_list(db, paras.page_size, paras.page_index, paras.keyword))
+
+
+@group_router.post("/add_group", response_model=Response)
+async def add_group(group: GroupData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+    if not group.group_name:
+        return Response(code=400, msg="The group_name cannot be empty!")
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_name == group.group_name).first()
+    if db_group:
+        return Response(code=200, msg="group already created")
+    is_create = await create_group(db, group.group_name, group.group_description)
+    if not is_create:
+        return Response(code=200, msg="group create failure", data={})
+    return Response(code=200, msg="group create successfully", data={"group_name": group.group_name})
+
+
+@group_router.post("/edit_group", response_model=Response)
+async def edit_group(group: GroupData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+    if not group.group_name:
+        return Response(code=400, msg="The group_name cannot be empty!")
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_name == group.group_name).first()
+    if db_group:
+        return Response(code=200, msg="group_name already created")
+    is_edit = await edit_group_data(db, group.id,
+                                    {"group_name": group.group_name, "group_description": group.group_description})
+    if not is_edit:
+        return Response(code=200, msg="group edit failure", data={})
+    return Response(code=200, msg="group edit successfully", data={"group_name": group.group_name})
+
+
+@group_router.post("/edit_group_status", response_model=Response)
+async def edit_group_status(group: GroupData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+    if group.group_status not in [0, 1]:
+        return Response(code=400, msg="The status cannot be {}!".format(group.group_status))
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group.id).first()
+    if not db_group:
+        return Response(code=200, msg="group does not exist")
+    is_edit = await edit_group_data(db, group.id,
+                                    {"group_status": group.group_status})
+    if not is_edit:
+        return Response(code=200, msg="group status edit failure", data={})
+    return Response(code=200, msg="group status edit successfully", data={"group_name": group.group_name})
+
+
+@group_router.post("/delete_group", response_model=Response)
+async def delete_group(group: GroupData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group.id).first()
+    if not db_group:
+        return Response(code=200, msg="group does not exist")
+    is_edit = await delete_group_data(db, group.id)
+    if not is_edit:
+        return Response(code=200, msg="group delete failure", data={})
+    return Response(code=200, msg="group delete successfully", data={})
+
+
+@group_router.post("/group_users", response_model=Response)
+async def group_users(group: GroupData, current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group.id).first()
+    if not db_group:
+        return Response(code=200, data={})
+    return Response(code=200, msg="success", data=await get_group_users(db, group.id))
+
+
+@group_router.post("/save_group_user", response_model=Response)
+async def save_group_user(group_user: GroupUsers, current_user: UserModel = Depends(get_current_user),
+                          db=Depends(get_db)):
+    db_group = db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group_user.id).first()
+    if not db_group:
+        return Response(code=200, msg="group does not exist")
+    is_success = await save_user_to_group(db, current_user.id, group_user.id, group_user.user_list)
+    if not is_success:
+        return Response(code=500, msg="save user to group failure", data={})
+    return Response(code=200, msg="success", data={})
diff --git a/app/api/user.py b/app/api/user.py
new file mode 100644
index 0000000..ca725f4
--- /dev/null
+++ b/app/api/user.py
@@ -0,0 +1,35 @@
+from fastapi import APIRouter, Depends
+from app.api import Response, pwd_context, get_current_user, ResponseList
+from app.config.config import settings
+from app.models.base_model import get_db
+from app.models.group_model import UserGroupModel
+from app.models.user_model import UserModel
+from app.service.bisheng import BishengService
+from app.service.ragflow import RagflowService
+from app.service.token import get_bisheng_token
+
+user_router = APIRouter()
+
+
+@user_router.post("/list", response_model=Response)
+async def user_list(current_user: UserModel = Depends(get_current_user), db=Depends(get_db)):
+
+    bisheng_service = BishengService(settings.sgb_base_url)
+    ragflow_service = RagflowService(settings.fwr_base_url)
+    db_user = db.query(UserModel).filter(UserGroupModel.group_name == UserModel.username).first()
+    if db_user:
+        return Response(code=200, msg="Username already registered")
+    # 娉ㄥ唽鍒版瘯鏄�
+    try:
+        token = get_bisheng_token(db, current_user.id)
+        print(token)
+        result = await bisheng_service.user_list(token)
+        print(result)
+    except Exception as e:
+        return Response(code=500, msg=f"Failed to register with Bisheng: {str(e)}")
+
+
+    return ResponseList(code=200, msg="", data=result)
+
+
+
diff --git a/app/models/group_model.py b/app/models/group_model.py
new file mode 100644
index 0000000..ec03d7d
--- /dev/null
+++ b/app/models/group_model.py
@@ -0,0 +1,52 @@
+from datetime import datetime
+from enum import IntEnum
+from typing import Optional
+
+from sqlalchemy import Column, Integer, String, DateTime, Enum, Index
+from pydantic import BaseModel
+from app.models.base_model import Base
+
+class GroupStatus(IntEnum):
+    NO = 1
+    OFF = 0
+
+
+
+class GroupInfoModel(Base):
+    __tablename__ = "group_info"
+    group_id = Column(Integer, primary_key=True, index=True)
+    group_name = Column(String(255), unique=True, nullable=False, index=True)
+    group_description = Column(String(255))
+    group_status = Column(Integer, nullable=False, default=1)
+    created_at = Column(DateTime, default=datetime.now())
+    updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now())
+
+
+    def to_dict(self):
+        return {
+            'id': self.group_id,
+            'name': self.group_name,
+            'group_description': self.group_description,
+            'group_status': self.group_status,
+            'created_at': self.created_at.strftime("%Y.%m.%d %H:%M")
+        }
+
+
+class UserGroupModel(Base):
+    __tablename__ = "user_group"
+    id = Column(Integer, primary_key=True)
+    group_id = Column(Integer, nullable=False)
+    user_id = Column(Integer, nullable=False)
+    Index('ix_user_group_id', group_id, user_id, unique=True)
+
+
+class GroupData(BaseModel):
+    id: Optional[int] = None
+    group_name: Optional[str] = ""
+    group_description: Optional[str] = ""
+    group_status: Optional[int] = None
+
+class GroupUsers(BaseModel):
+    id: int
+    user_list: list
+
diff --git a/app/models/user.py b/app/models/user.py
index 3335170..45874f7 100644
--- a/app/models/user.py
+++ b/app/models/user.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from pydantic import BaseModel
 
 
@@ -21,3 +23,8 @@
     token_type: str
     bisheng_token: str
     ragflow_token: str
+
+class PageParameter(BaseModel):
+    page_index: int
+    page_size: int
+    keyword: Optional[str] = ""
\ No newline at end of file
diff --git a/app/models/user_model.py b/app/models/user_model.py
index 3a07ae2..5d5d51b 100644
--- a/app/models/user_model.py
+++ b/app/models/user_model.py
@@ -7,4 +7,10 @@
     __tablename__ = "user"
     id = Column(Integer, primary_key=True, index=True)
     username = Column(String(255), unique=True, index=True)
-    hashed_password = Column(String(255))
\ No newline at end of file
+    hashed_password = Column(String(255))
+    compellation = Column(String(255), nullable=False, default="")
+    phone = Column(String(255), nullable=False, default="")
+    email = Column(String(255), nullable=False, default="")
+    description = Column(String(255), nullable=False, default="")
+    ragflow_id = Column(String(32), unique=True, index=True)
+    bisheng_id = Column(Integer, unique=True, index=True)
\ No newline at end of file
diff --git a/app/service/bisheng.py b/app/service/bisheng.py
index e7c92fa..ad5d084 100644
--- a/app/service/bisheng.py
+++ b/app/service/bisheng.py
@@ -33,7 +33,7 @@
                 json={"user_name": username, "password": password},
                 headers={'Content-Type': 'application/json'}
             )
-            self._check_response(response)
+            return self._check_response(response)
 
     async def login(self, username: str, password: str) -> str:
         public_key = await self.get_public_key_api()
@@ -96,3 +96,11 @@
             }
 
             return result
+
+    async def user_list(self, token: str) -> list:
+        url = f"{self.base_url}/api/v1/user/list"
+        headers = {'cookie': f"access_token_cookie={token};"}
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, headers=headers)
+            data = self._check_response(response)
+            return data
diff --git a/app/service/group.py b/app/service/group.py
new file mode 100644
index 0000000..dbe7752
--- /dev/null
+++ b/app/service/group.py
@@ -0,0 +1,119 @@
+from sqlalchemy.testing.pickleable import Order
+
+from app.config.config import settings
+from app.models.group_model import GroupInfoModel, UserGroupModel
+from app.models.user_model import UserModel
+from app.service.ragflow import RagflowService
+from app.service.token import get_ragflow_token
+
+
+async def group_list(db, page_size: int, page_index: int, keyword: str):
+    query = db.query(GroupInfoModel)
+    if keyword:
+        query = query.filter(GroupInfoModel.group_name.like('%{}%'.format(keyword)))
+    items = query.order_by(GroupInfoModel.group_id.desc()).limit(page_size).offset(
+        (page_index - 1) * page_size).all()
+    items_list = [item.to_dict() for item in items]
+    groups = [i["id"] for i in items_list]
+    group_dict = {}
+    for group_user in db.query(UserGroupModel.group_id, UserModel.id, UserModel.username).outerjoin(UserModel,
+                                                                                                    UserModel.id == UserGroupModel.user_id).filter(
+        UserGroupModel.group_id.in_(groups)).all():
+        if group_user.group_id in group_dict:
+            group_dict[group_user.group_id].append({"user_id": group_user.id, "user_name": group_user.username})
+        else:
+            group_dict[group_user.group_id] = [{"user_id": group_user.id, "user_name": group_user.username}]
+    for item in items_list:
+        item["users"] = group_dict.get(item["id"], [])
+    return {"total": query.count(), "items": items_list}
+
+
+async def create_group(db, group_name: str, group_description: str):
+    try:
+        group_model = GroupInfoModel(group_name=group_name, group_description=group_description)
+        db.add(group_model)
+        db.commit()
+        db.refresh(group_model)
+    except Exception as e:
+        print(e)
+        db.rollback()
+        return False
+    return True
+
+
+async def edit_group_data(db, group_id: int, data):
+    try:
+        db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group_id).update(data)
+        db.commit()
+    except Exception as e:
+        print(e)
+        db.rollback()
+        return False
+    return True
+
+
+async def delete_group_data(db, group_id: int):
+    try:
+        db.query(GroupInfoModel).filter(GroupInfoModel.group_id == group_id).delete()
+        db.commit()
+    except Exception as e:
+        print(e)
+        db.rollback()
+        return False
+    return True
+
+
+async def get_group_users(db, group_id):
+    not_group_user = []
+    in_group_user = []
+    user_list = [i.user_id for i in
+                 db.query(UserGroupModel.user_id).filter(UserGroupModel.group_id.__eq__(group_id)).all()]
+    for u in db.query(UserModel.id, UserModel.username).order_by(UserModel.id.desc()).all():
+        if u.id in user_list:
+            in_group_user.append({"user_id": u.id, "user_name": u.username})
+        else:
+            not_group_user.append({"user_id": u.id, "user_name": u.username})
+    return {"in_group": in_group_user, "not_in_group": not_group_user}
+
+
+async def save_user_to_group(db, user_id, group_id, user_list):
+    group_user_list = [i.user_id for i in
+                       db.query(UserGroupModel.user_id).filter(UserGroupModel.group_id.__eq__(group_id)).all()]
+    new_users = set([i for i in user_list if i not in group_user_list])
+    delete_user = [i for i in group_user_list if i not in user_list]
+    if new_users:
+
+        user_dict = {i.id: {"rg_id": i.ragflow_id, "email": i.email} for i in
+                     db.query(UserModel.id, UserModel.email, UserModel.ragflow_id).filter(
+                         UserModel.id.in_(user_list)).all()}
+        ragflow_service = RagflowService(settings.fwr_base_url)
+        token = get_ragflow_token(db, user_id)
+
+        try:
+            for old_user in group_user_list:
+                if old_user in delete_user:
+                    continue
+                for new_user in new_users:
+                    await ragflow_service.add_user_tenant(token, user_dict[old_user]["rg_id"], user_dict[new_user]["email"],
+                                                          user_dict[new_user]["rg_id"])
+                    await ragflow_service.add_user_tenant(token, user_dict[new_user]["rg_id"], user_dict[old_user]["email"],
+                                                  user_dict[old_user]["rg_id"])
+            for user1 in new_users:
+                for user2 in new_users:
+                    if user1 != user2:
+                        await ragflow_service.add_user_tenant(token, user_dict[user1]["rg_id"],
+                                                              user_dict[user2]["email"],
+                                                              user_dict[user2]["rg_id"])
+        except Exception as e:
+            print(e)
+            return False
+    try:
+        for user in new_users:
+            db_user = UserGroupModel(group_id=group_id, user_id=user)
+            db.add(db_user)
+        db.query(UserGroupModel).filter(UserGroupModel.group_id.__eq__(group_id), UserGroupModel.user_id.in_(delete_user)).delete()
+        db.commit()
+    except Exception as e:
+        print(e)
+        return False
+    return True
\ No newline at end of file
diff --git a/app/service/ragflow.py b/app/service/ragflow.py
index 7ce287d..94bd9ac 100644
--- a/app/service/ragflow.py
+++ b/app/service/ragflow.py
@@ -1,6 +1,7 @@
 import httpx
 from typing import Union, Dict, List
 
+from Tools.scripts.objgraph import ignore
 from fastapi import HTTPException
 from starlette import status
 
@@ -44,6 +45,7 @@
             )
             if response.status_code != 200:
                 raise Exception(f"Ragflow registration failed: {response.text}")
+            return self._handle_response(response)
 
     async def login(self, username: str, password: str) -> str:
         password = RagflowCrypto(settings.PUBLIC_KEY, settings.PRIVATE_KEY).encrypt(password)
@@ -145,3 +147,15 @@
             response = await client.post(url, headers=headers, files=files, data=data)
             data = self._handle_response(response)
             return data
+
+    async def add_user_tenant(self, token: str, tenant_id: str, email: str, user_id: str) -> str:
+        url = f"{self.base_url}/v1/tenant/{tenant_id}/user"
+        headers = {"Authorization": token}
+        data = {"email": email, "user_id": user_id}
+        print(url)
+        print(data)
+        async with httpx.AsyncClient(timeout=60) as client:
+            response = await client.post(url, headers=headers, json=data)
+            print(response.text)
+            if response.status_code != 200:
+                raise Exception(f"Ragflow add user to tenant failed: {response.text}")
diff --git a/main.py b/main.py
index e629cbf..2c8564f 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,8 @@
 from app.api.excel import router as excel_router
 from app.api.files import router as files_router
 from app.api.report import router as report_router
+from app.api.user import user_router
+from app.api.group import group_router
 from app.models.base_model import init_db
 from app.task.fetch_agent import sync_agents, initialize_agents
 
@@ -36,6 +38,8 @@
 app.include_router(excel_router, prefix='/api/document', tags=["document"])
 app.include_router(files_router, prefix='/api/files', tags=["files"])
 app.include_router(report_router, prefix='/api/report', tags=["report"])
+app.include_router(user_router, prefix='/api/user', tags=["user"])
+app.include_router(group_router, prefix='/api/group', tags=["group"])
 
 if __name__ == "__main__":
     import uvicorn

--
Gitblit v1.8.0