zhaoqingang
2024-11-19 5ef590b70cc8e2de16083af2ee2d977daae5587c
会话缓存本地数据库
6个文件已修改
83 ■■■■■ 已修改文件
app/api/agent.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/dialog_model.py 27 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/dialog.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/group.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/ragflow.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/agent.py
@@ -12,6 +12,7 @@
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.bisheng import BishengService
from app.service.dialog import get_session_history
from app.service.ragflow import RagflowService
from app.service.service_token import get_ragflow_token, get_bisheng_token
@@ -41,6 +42,8 @@
        try:
            token = get_ragflow_token(db, current_user.id)
            result = await ragflow_service.get_chat_sessions(token, agent_id)
            if not result:
                result = await get_session_history(db, current_user.id, agent_id)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
        return ResponseList(code=200, msg="", data=result)
app/api/chat.py
@@ -10,6 +10,7 @@
from app.models.agent_model import AgentModel, AgentType
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.dialog import update_session_history
from app.service.ragflow import RagflowService
from app.service.service_token import get_bisheng_token, get_ragflow_token
@@ -44,6 +45,7 @@
        try:
            async def forward_to_ragflow():
                while True:
                    is_new = False
                    message = await websocket.receive_json()
                    print(f"Received from client {chat_id}: {message}")
                    chat_history = message.get('chatHistory', [])
@@ -51,9 +53,10 @@
                    if len(chat_history) == 0:
                        chat_history = await ragflow_service.get_session_history(token, chat_id)
                        if len(chat_history) == 0:
                            is_new = True
                            chat_history = await ragflow_service.set_session(token, agent_id,
                                                                             message, chat_id, True)
                            print("chat_history------------------------", chat_history)
                            # print("chat_history------------------------", chat_history)
                            if len(chat_history) == 0:
                                result = {"message": "内部错误:创建会话失败", "type": "close"}
                                await websocket.send_json(result)
@@ -91,11 +94,14 @@
                                complete_response = ""
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                print(f"Response text: {text}")
                                # print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
                            print(f"Error process message of ragflow: {e2}")
                    dialog_chat_history = await ragflow_service.get_session_history(token, chat_id, 1)
                    await update_session_history(db, dialog_chat_history, current_user.id, is_new)
            # 启动任务处理客户端消息
            tasks = [
                asyncio.create_task(forward_to_ragflow())
app/models/dialog_model.py
@@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Table, ForeignKey, DateTime, BigInteger, Text, Float, Boolean
from sqlalchemy import Column, Integer, String, BigInteger, ForeignKey, DateTime, Text, JSON
from sqlalchemy.orm import relationship, backref
from app.models.base_model import Base
@@ -33,4 +33,29 @@
            'description': self.description,
            'icon': self.icon,
            'status': self.status
        }
class ConversationModel(Base):
    """ORM model for the ``conversation`` table: a local copy of chat conversations.

    Rows are written by ``app.service.dialog.update_session_history`` from the
    payload returned by ``RagflowService.get_session_history(..., is_all=1)``.
    They are listed by ``app.service.dialog.get_session_history``.
    """
    __tablename__ = 'conversation'
    id = Column(String(32), primary_key=True)  # conversation id (primary key)
    create_date = Column(DateTime)             # creation time
    create_time = Column(BigInteger)           # creation timestamp as an integer; NOTE(review): unit (s vs ms) not visible here — confirm
    update_date = Column(DateTime)             # last update time
    update_time = Column(BigInteger)           # last update timestamp as an integer; listings are ordered by it, newest first
    tenant_id = Column(Integer)              # id of the owning local user (update_session_history stores the current user's id here)
    dialog_id = Column(String(32))             # id of the dialog/agent this conversation belongs to
    name = Column(String(255))                 # conversation name
    message = Column(JSON)                 # JSON chat messages (the payload's "message" field)
    reference = Column(JSON)                         # JSON "reference" field of the same payload; structure not visible here

    def get_id(self):
        """Return the primary key as a string."""
        return str(self.id)

    def to_json(self):
        """Return the summary dict used for session listings.

        Note the key is ``updated_time`` (not ``update_time``); it carries the
        ``update_time`` column value.
        """
        return {
            'id': self.id,
            'updated_time': self.update_time,
            'name': self.name,
        }
app/service/dialog.py
@@ -1,4 +1,5 @@
from app.models import KnowledgeModel, GroupModel, DialogModel
from app.api.user import user_list
from app.models import KnowledgeModel, GroupModel, DialogModel, ConversationModel
from app.models.user_model import UserModel
from Log import logger
@@ -19,4 +20,35 @@
                    dialog_list.append(k)
                    kld_set.add(k.id)
    return {"rows": [kld.to_json() for kld in dialog_list]}
    return {"rows": [kld.to_json() for kld in dialog_list]}
async def update_session_history(db, data: dict, user_id, is_new):
    """Cache a conversation payload in the local ``conversation`` table.

    :param db: SQLAlchemy session used for the write. It is committed on
        success and rolled back on failure.
    :param data: conversation payload, e.g. the result of
        ``RagflowService.get_session_history(token, chat_id, 1)``. Its ``id``
        becomes the primary key. Keys that are not ``ConversationModel``
        columns are ignored. The dict itself is not modified.
    :param user_id: local user id, stored in ``tenant_id``.
    :param is_new: ``True`` right after the conversation was created, in which
        case a new row is inserted. Otherwise the row is upserted: updated if it
        exists, inserted if it was never cached.

    Errors are logged and rolled back rather than raised, so a caching failure
    never interrupts the chat stream.

    NOTE(review): the DB calls are synchronous and block the event loop.
    NOTE(review): ``create_date``/``update_date`` are stored as received; if the
    payload carries them as strings, confirm the DB driver accepts that format.
    """
    session_id = data.get("id") if isinstance(data, dict) else None
    if not session_id:
        # Without a primary key neither an insert nor an update can succeed.
        logger.error("更新会话记录失败!{}".format(data))
        return

    # Keep only mapped columns: the upstream payload may carry extra keys, and
    # SQLAlchemy rejects unknown keyword arguments / update keys.
    columns = {column.name for column in ConversationModel.__table__.columns}
    fields = {key: value for key, value in data.items() if key in columns}
    fields["tenant_id"] = user_id

    try:
        conversation = ConversationModel(**fields)
        if is_new:
            db.add(conversation)
        else:
            # merge() looks the row up by primary key and copies only the
            # attributes set above. Conversations created before caching
            # existed are therefore inserted instead of silently skipped.
            db.merge(conversation)
        db.commit()
    except Exception:
        logger.exception("更新会话记录失败!{}".format(session_id))
        db.rollback()
async def get_session_history(db, user_id, dialog_id):
    """Return the cached conversations of one user for one dialog/agent.

    Rows are matched on ``tenant_id == user_id`` and ``dialog_id``. They are
    ordered by ``update_time``, newest first, and serialized with
    ``ConversationModel.to_json``.

    NOTE(review): the query is synchronous and blocks the event loop.
    """
    sessions = (
        db.query(ConversationModel)
        .filter(
            ConversationModel.tenant_id == user_id,
            ConversationModel.dialog_id == dialog_id,
        )
        .order_by(ConversationModel.update_time.desc())
        .all()
    )
    return [session.to_json() for session in sessions]
app/service/group.py
@@ -95,6 +95,7 @@
            for user1 in new_users:
                for user2 in new_users:
                    if user1 != user2:
                        print(user1, user2)
                        await ragflow_service.add_user_tenant(token, user_dict[user1]["rg_id"],
                                                              user_dict[user2]["email"],
                                                              user_dict[user2]["rg_id"])
app/service/ragflow.py
@@ -146,12 +146,15 @@
                }
            ] if data else []
    async def get_session_history(self, token: str, chat_id: str) -> list:
    async def get_session_history(self, token: str, chat_id: str, is_all: int=0):
        url = f"{self.base_url}/v1/conversation/get?conversation_id={chat_id}"
        headers = {"Authorization": token}
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            data = self._handle_response(response)
            # print("----------------data----------------------:", data)
            if is_all:
                return data
            return data.get("message", [])
    async def upload_and_parse(self, token: str, chat_id: str, filename: str, file: bytes) -> str:
@@ -172,6 +175,7 @@
        data = {"email": email, "user_id": user_id}
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(url, headers=headers, json=data)
            print(response)
            if response.status_code != 200:
                raise Exception(f"Ragflow add user to tenant failed: {response.text}")