zhaoqingang
2024-11-12 f9a307e86b771f20bd2dc043a875b2ee86cc5d50
Merge branch 'master' of http://192.168.5.5:10010/r/rag-gateway
3个文件已修改
45 ■■■■ 已修改文件
app/api/agent.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/chat.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/ragflow.py 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/agent.py
@@ -1,6 +1,7 @@
import uuid
from fastapi import Depends, APIRouter, Query, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -51,6 +52,17 @@
        return ResponseList(code=200, msg="Unsupported agent type")
@router.get("/{conversation_id}/session_log")
async def session_log(conversation_id: str, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    """Return the RAGFlow session log of one conversation for the current user.

    The current user's RAGFlow token is looked up via ``get_ragflow_token`` and
    used to fetch the conversation through ``RagflowService.get_session_log``.

    Args:
        conversation_id: Id of the conversation, taken from the URL path.
        db: Database session (injected).
        current_user: Authenticated user (injected).

    Returns:
        A 200 ``JSONResponse`` with body ``{"code": 200, "log": <session log>}``.

    Raises:
        HTTPException: Any ``HTTPException`` raised while resolving the token or
            fetching the log is propagated unchanged. Any other error is
            wrapped as a 500 whose detail is the error text.
    """
    ragflow_service = RagflowService(base_url=settings.fwr_base_url)
    try:
        token = get_ragflow_token(db, current_user.id)
        result = await ragflow_service.get_session_log(token, conversation_id)
    except HTTPException:
        # Already carries an intentional status code and detail; wrapping it
        # would turn e.g. a 4xx into a generic 500 with a mangled message.
        raise
    except Exception as e:
        # Chain the original error so the traceback keeps the real cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(status_code=200, content={"code": 200, "log": result})
@router.get("/get-chat-id/{agent_id}", response_model=Response)
async def get_chat_id(agent_id: str, db: Session = Depends(get_db)):
    agent = db.query(AgentModel).filter(AgentModel.id == agent_id).first()
app/api/chat.py
@@ -64,6 +64,7 @@
                                "doc_ids": message.get("doc_ids", []),
                                "role": "user"
                            })
                    complete_response = ""
                    async for rag_response in ragflow_service.chat(token, chat_id, chat_history):
                        try:
                            if rag_response[:5] == "data:":
@@ -72,8 +73,9 @@
                            else:
                                # 否则,保持原样
                                text = rag_response
                            complete_response += text
                            try:
                                json_data = json.loads(text)
                                json_data = json.loads(complete_response)
                                data = json_data.get("data")
                                if data is True:  # 完成输出
                                    result = {"message": "", "type": "close"}
@@ -85,10 +87,10 @@
                                    reference = data.get("reference", {})
                                    result = {"message": answer, "type": "message", "reference": reference}
                                await websocket.send_json(result)
                            except json.JSONDecodeError:
                                print(f"Error decode ragflow response: {text}")
                                pass
                                complete_response = ""
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e}")
                                print(f"Response text: {text}")
                        except Exception as e2:
                            result = {"message": f"内部错误: {e2}", "type": "close"}
                            await websocket.send_json(result)
app/service/ragflow.py
@@ -102,6 +102,27 @@
            ]
            return result
    async def get_session_log(self, token: str, conversation_id: str) -> dict:
        url = f"{self.base_url}/v1/conversation/get?conversation_id={conversation_id}"
        headers = {"Authorization": token}
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            data = self._handle_response(response)
            session_log = {
                "dialog_id": data.get("dialog_id"),
                "id": data.get("id"),
                "message": [
                    {
                        "content": message.get("content"),
                        "role": message.get("role"),
                    }
                    for message in data.get("message", [])
                ],
                "name": data.get("name"),
                "reference": data.get("reference"),
            }
        return session_log
    async def set_session(self, token: str, dialog_id: str, message: dict, chat_id: str, is_new: bool) -> list:
        url = f"{self.base_url}/v1/conversation/set?dialog_id={dialog_id}"
        headers = {"Authorization": token}