xuyonghao
2024-11-14 b4a0d6ac3982621bf72aec7f0536d0d283446525
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import uuid
 
from fastapi import Depends, APIRouter, Query, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
 
from app.api import Response, get_current_user, ResponseList
from app.config.config import settings
from app.models.agent_model import AgentType, AgentModel
from app.models.base_model import get_db
from app.models.user_model import UserModel
from app.service.bisheng import BishengService
from app.service.ragflow import RagflowService
from app.service.token import get_ragflow_token, get_bisheng_token
 
# Router collecting the agent endpoints declared in this module (agent list,
# per-agent sessions, session logs, chat-id issuance). Presumably mounted by the
# application via include_router under some prefix — confirm at the call site.
router = APIRouter()
 
 
@router.get("/list", response_model=ResponseList)
async def agent_list(db: Session = Depends(get_db)):
    """Return every agent, serialized with ``to_dict``, ordered by ``sort`` ascending.

    Args:
        db: Request-scoped SQLAlchemy session provided by ``get_db``.

    Returns:
        ``ResponseList`` with ``code=200``, an empty ``msg`` and the list of
        agent dicts as ``data``.
    """
    ordering = AgentModel.sort.asc()
    rows = db.query(AgentModel).order_by(ordering).all()
    return ResponseList(code=200, msg="", data=[row.to_dict() for row in rows])
 
 
@router.get("/{agent_id}/sessions", response_model=ResponseList)
async def chat_list(agent_id: str, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    """List the current user's chat sessions for a single agent.

    The agent's ``agent_type`` selects the backend. RAGFlow sessions are
    requested for ``agent_id``. Bisheng sessions are requested with the
    user's token only; no agent id is passed.

    Args:
        agent_id: Primary key of the ``AgentModel`` row to look up.
        db: Request-scoped SQLAlchemy session.
        current_user: Authenticated user. Its ``id`` is used to fetch the
            backend token.

    Returns:
        One of three ``ResponseList`` values:
        - ``code=200`` with the backend's sessions as ``data``.
        - ``code=404`` if the agent does not exist.
        - ``code=200`` with the message "Unsupported agent type" for any
          other agent type.

    Raises:
        HTTPException: re-raised unchanged if the token helper or backend
            service raises one. Any other error becomes a 500 whose
            ``detail`` is the error text.
    """
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return ResponseList(code=404, msg="Agent not found")

    if agent.agent_type == AgentType.RAGFLOW:
        ragflow_service = RagflowService(base_url=settings.fwr_base_url)
        try:
            token = get_ragflow_token(db, current_user.id)
            result = await ragflow_service.get_chat_sessions(token, agent_id)
        except HTTPException:
            # Keep any status the helpers chose deliberately (e.g. an auth
            # failure) instead of masking it as a generic 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return ResponseList(code=200, msg="", data=result)

    elif agent.agent_type == AgentType.BISHENG:
        bisheng_service = BishengService(base_url=settings.sgb_base_url)
        try:
            token = get_bisheng_token(db, current_user.id)
            result = await bisheng_service.get_chat_sessions(token)
        except HTTPException:
            # Same rationale as above: do not rewrap deliberate HTTP errors.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return ResponseList(code=200, msg="", data=result)

    else:
        return ResponseList(code=200, msg="Unsupported agent type")
 
 
def _combine_ragflow_log(messages, references):
    """Pair RAGFlow ``user``/``assistant`` messages into question/answer records.

    An ``assistant`` message is paired with the pending (truthy) ``user``
    message before it. A later ``user`` message replaces an unanswered one.
    The n-th pair receives ``references[n]``, or ``None`` once the reference
    list is exhausted. Assistant messages with no pending question are
    skipped and do not consume a reference.

    Returns:
        A list of ``{'question', 'answer', 'reference'}`` dicts.
    """
    combined_logs = []
    last_question = None
    reference_index = 0
    for message in messages:
        if message['role'] == 'user':
            last_question = message['message']
        elif message['role'] == 'assistant' and last_question:
            if reference_index < len(references):
                reference = references[reference_index]
            else:
                reference = None
            combined_logs.append({
                'question': last_question,
                'answer': message['message'],
                'reference': reference,
            })
            last_question = None
            reference_index += 1
    return combined_logs


def _combine_bisheng_log(messages):
    """Pair Bisheng ``question``/``answer`` messages into question/answer records.

    Pairing rules match the RAGFlow variant: an ``answer`` is paired with the
    pending (truthy) ``question``, and unpaired answers are skipped. No
    reference data is attached.

    Returns:
        A list of ``{'question', 'answer'}`` dicts.
    """
    combined_logs = []
    last_question = None
    for message in messages:
        if message['role'] == 'question':
            last_question = message['message']
        elif message['role'] == 'answer' and last_question:
            combined_logs.append({
                'question': last_question,
                'answer': message['message'],
            })
            last_question = None
    return combined_logs


@router.get("/{agent_id}/{conversation_id}/session_log")
async def session_log(agent_id: str, conversation_id: str, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    """Return one conversation's history as question/answer pairs.

    The backend is chosen by the agent's ``agent_type``. Every reply uses
    HTTP status 200; the outcome is carried in the body's ``code`` field.

    Response bodies (keys differ per backend and are kept for compatibility):
        - RAGFlow: ``{"code": 200, "data": [...]}``, where each pair carries
          its ``reference`` entry.
        - Bisheng: ``{"code": 200, "log": [...]}``.
        - Backend result missing the expected keys:
          ``{"code": 400, "message": "Invalid result structure"}``.
        - Other agent types:
          ``{"code": 200, "log": "Unsupported agent type"}``.
        - Unknown agent: ``Response(code=404, msg="Agent not found")``.

    Raises:
        HTTPException: re-raised unchanged if raised while fetching the token
            or log. Any other error becomes a 500 whose ``detail`` is the
            error text.
    """
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if not agent:
        return Response(code=404, msg="Agent not found")

    if agent.agent_type == AgentType.RAGFLOW:
        ragflow_service = RagflowService(base_url=settings.fwr_base_url)
        try:
            token = get_ragflow_token(db, current_user.id)
            result = await ragflow_service.get_session_log(token, conversation_id)
            if 'session_log' in result and 'reference' in result:
                combined_logs = _combine_ragflow_log(result['session_log'], result['reference'])
                return JSONResponse(status_code=200, content={"code": 200, "data": combined_logs})
            return JSONResponse(status_code=200, content={"code": 400, "message": "Invalid result structure"})
        except HTTPException:
            # Keep deliberate HTTP errors from helpers instead of rewrapping as 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
    # elif (not a separate if): the branches are mutually exclusive, so the
    # else below must only apply to unsupported agent types.
    elif agent.agent_type == AgentType.BISHENG:
        bisheng_service = BishengService(base_url=settings.sgb_base_url)
        try:
            token = get_bisheng_token(db, current_user.id)
            result = await bisheng_service.get_session_log(token, agent_id, conversation_id)
            if 'session_log' in result:
                combined_logs = _combine_bisheng_log(result['session_log'])
                return JSONResponse(status_code=200, content={"code": 200, "log": combined_logs})
            return JSONResponse(status_code=200, content={"code": 400, "message": "Invalid result structure"})
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
    else:
        return JSONResponse(status_code=200, content={"code": 200, "log": "Unsupported agent type"})
 
 
@router.get("/get-chat-id/{agent_id}", response_model=Response)
async def get_chat_id(agent_id: str, db: Session = Depends(get_db)):
    """Issue a fresh random chat id for an existing agent.

    ``agent_id`` is only used to confirm the agent exists. The returned id
    is a new ``uuid4`` in hex form and is not persisted by this endpoint.

    Returns:
        ``Response`` with ``code=200`` and ``{"chat_id": <hex>}`` as
        ``data``, or ``code=404`` when no agent matches.
    """
    matching_agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
    if matching_agent:
        return Response(code=200, msg="", data={"chat_id": uuid.uuid4().hex})
    return Response(code=404, msg="Agent not found")