From 244c884d0a7c54c4a37de18c1c2c8ff15a506ff7 Mon Sep 17 00:00:00 2001
From: zhangqian <zhangqian@123.com>
Date: 星期二, 15 十月 2024 00:11:31 +0800
Subject: [PATCH] 智能体列表接口,智能体会话记录接口

---
 app/models/agent_model.py |   15 +++++
 app/service/bisheng.py    |   21 +++++++
 requirements.txt          |    0 
 app/service/ragflow.py    |   21 +++++++
 app/api/chat.py           |    2 
 main.py                   |    2 
 app/api/agent.py          |   67 ++++++++++++++++++++++
 app/api/__init__.py       |    6 ++
 8 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/app/api/__init__.py b/app/api/__init__.py
index 51c31da..8bd4579 100644
--- a/app/api/__init__.py
+++ b/app/api/__init__.py
@@ -21,6 +21,12 @@
     data: dict = {}
 
 
+class ResponseList(BaseModel):
+    code: int = 200
+    msg: str = ""
+    data: list[dict] = []
+
+
 def get_current_user(token: str = Depends(oauth2_scheme)):
     try:
         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
diff --git a/app/api/agent.py b/app/api/agent.py
new file mode 100644
index 0000000..cded0f8
--- /dev/null
+++ b/app/api/agent.py
@@ -0,0 +1,67 @@
+from fastapi import Depends, APIRouter, Query, HTTPException
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+
+from app.api import Response, get_current_user, ResponseList
+from app.config.config import settings
+from app.models.agent_model import AgentType, AgentModel
+from app.models.base_model import get_db
+from app.models.user_model import UserModel
+from app.service.bisheng import BishengService
+from app.service.ragflow import RagflowService
+from app.service.token import get_ragflow_token, get_bisheng_token
+
+router = APIRouter()
+
+
+# Pydantic 妯″瀷鐢ㄤ簬鍝嶅簲
+class AgentResponse(BaseModel):
+    id: str
+    name: str
+    agent_type: AgentType
+
+    class Config:
+        orm_mode = True
+
+
+@router.get("/list", response_model=ResponseList)
+async def agent_list(db: Session = Depends(get_db)):
+    agents = db.query(AgentModel).all()
+    result = [
+        {
+            "id": item.id,
+            "name": item.name,
+            "agent_type": item.agent_type
+        }
+        for item in agents
+    ]
+    return ResponseList(code=200, msg="", data=result)
+
+
+@router.get("/{agent_id}/sessions", response_model=ResponseList)
+async def chat_list(agent_id: str, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
+    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
+    if not agent:
+        return ResponseList(code=404, msg="Agent not found")
+
+    if agent.agent_type == AgentType.RAGFLOW:
+        ragflow_service = RagflowService(base_url=settings.ragflow_base_url)
+        try:
+            token = get_ragflow_token(db, current_user.id)
+            result = await ragflow_service.get_chat_sessions(token, agent_id)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+        return ResponseList(code=200, msg="", data=result)
+
+    elif agent.agent_type == AgentType.BISHENG:
+        bisheng_service = BishengService(base_url=settings.bisheng_base_url)
+        try:
+            token = get_bisheng_token(db, current_user.id)
+            result = await bisheng_service.get_chat_sessions(token)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+        return ResponseList(code=200, msg="", data=result)
+
+    else:
+        return ResponseList(code=200, msg="Unsupported agent type")
+
diff --git a/app/api/chat.py b/app/api/chat.py
index 7667736..d5ab09d 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -135,5 +135,3 @@
                 print(f"Client {chat_id} disconnected")
             finally:
                 del client_websockets[chat_id]
-
-
diff --git a/app/models/agent_model.py b/app/models/agent_model.py
new file mode 100644
index 0000000..1b2bdb2
--- /dev/null
+++ b/app/models/agent_model.py
@@ -0,0 +1,15 @@
+from enum import IntEnum
+from sqlalchemy import Column, String, Enum as SQLAlchemyEnum
+from app.models.base_model import Base
+
+
+class AgentType(IntEnum):
+    RAGFLOW = 1
+    BISHENG = 2
+
+
+class AgentModel(Base):
+    __tablename__ = "agent"
+    id = Column(String(255), primary_key=True, index=True)
+    name = Column(String(255), index=True)
+    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 1 ragflow 2 bisheng
diff --git a/app/service/bisheng.py b/app/service/bisheng.py
index b71a932..3eb0dfd 100644
--- a/app/service/bisheng.py
+++ b/app/service/bisheng.py
@@ -1,3 +1,5 @@
+from datetime import datetime
+
 import httpx
 
 from app.config.config import settings
@@ -42,3 +44,22 @@
             if response.status_code != 200:
                 raise Exception(f"Failed to get public key: {response.text}")
             return response.json().get('data', {}).get('public_key')
+
+    async def get_chat_sessions(self, token: str) -> list:
+        url = f"{self.base_url}/api/v1/chat/list?page=1&limit=40"
+        headers = {'cookie': f"access_token_cookie={token};"}
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, headers=headers)
+            if response.status_code != 200:
+                raise Exception(f"Failed to fetch data from Bisheng API: {response.text}")
+
+            data = response.json().get("data", [])
+            result = [
+                {
+                    "id": item["chat_id"],
+                    "name": item["latest_message"]["message"],
+                    "updated_time": int(datetime.strptime(item["update_time"], "%Y-%m-%dT%H:%M:%S").timestamp() * 1000)
+                }
+                for item in data
+            ]
+            return result
diff --git a/app/service/ragflow.py b/app/service/ragflow.py
index df131f1..934af1a 100644
--- a/app/service/ragflow.py
+++ b/app/service/ragflow.py
@@ -59,3 +59,24 @@
                         return
                 else:
                     yield f"Error: {response.status_code}"
+
+    async def get_chat_sessions(self, token: str, dialog_id: str) -> list:
+        url = f"{self.base_url}/v1/conversation/list?dialog_id={dialog_id}"
+        headers = {
+            "Authorization": token
+        }
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, headers=headers)
+            if response.status_code != 200:
+                raise Exception(f"Failed to fetch data from Ragflow API: {response.text}")
+
+            data = response.json().get("data", [])
+            result = [
+                {
+                    "id": item["id"],
+                    "name": item["name"],
+                    "updated_time": item["update_time"]
+                }
+                for item in data
+            ]
+            return result
diff --git a/main.py b/main.py
index 95945d6..d5a8c5d 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
 from fastapi import FastAPI
 from app.api.auth import router as auth_router
 from app.api.chat import router as chat_router
+from app.api.agent import router as agent_router
 from app.models.base_model import init_db
 
 init_db()
@@ -12,6 +13,7 @@
 
 app.include_router(auth_router, prefix='/auth', tags=["auth"])
 app.include_router(chat_router, prefix='/chat', tags=["chat"])
+app.include_router(agent_router, prefix='/agent', tags=["agent"])
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/requirements.txt b/requirements.txt
index da3411d..a1e863b 100644
--- a/requirements.txt
+++ b/requirements.txt
Binary files differ

--
Gitblit v1.8.0