zhangqian
2024-10-17 80978b3aec0e7f7a89d3ad671a9c0869c187be7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import httpx
from typing import Union, Dict, List
from app.config.config import settings
from app.utils.rsa_crypto import RagflowCrypto
 
 
class RagflowService:
    """Async HTTP client for a Ragflow server's user and conversation endpoints.

    Every method opens its own short-lived ``httpx.AsyncClient``. Methods that
    take a ``token`` send it verbatim as the ``Authorization`` header; that
    value is what :meth:`login` returns.
    """

    def __init__(self, base_url: str):
        # Server root; endpoint paths such as "/v1/user/login" are appended
        # to it verbatim, so it should not end with a slash.
        self.base_url = base_url

    async def _handle_response(self, response: httpx.Response) -> Union[Dict, List]:
        """Unwrap the ``data`` field of a Ragflow JSON envelope.

        Returns the ``data`` payload when the HTTP status is 200, the body is
        a JSON object with ``retcode == 0``, and ``data`` is a dict or list.
        Every other outcome (non-200 status, non-JSON body, non-object JSON,
        non-zero ``retcode``, missing or scalar ``data``) yields an empty dict,
        so callers can treat failures as "no data".
        """
        if response.status_code != 200:
            return {}

        try:
            data = response.json()
        except ValueError:
            # json.JSONDecodeError / UnicodeDecodeError are ValueError
            # subclasses: a 200 with e.g. an HTML proxy page is a failure,
            # not a crash.
            return {}

        if not isinstance(data, dict) or data.get("retcode") != 0:
            return {}

        # Check the type of the returned payload: only containers are
        # meaningful to callers.
        payload = data.get("data")
        if isinstance(payload, (dict, list)):
            return payload
        return {}

    async def register(self, username: str, password: str):
        """Register a Ragflow user.

        The email is synthesized as ``<username>@example.com``, and the
        password is encrypted with ``RagflowCrypto`` using the configured key
        pair before being sent.

        Raises:
            Exception: if the server does not answer with HTTP 200.
        """
        password = RagflowCrypto(settings.PUBLIC_KEY, settings.PRIVATE_KEY).encrypt(password)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/v1/user/register",
                headers={'Content-Type': 'application/json'},
                json={"nickname": username, "email": f"{username}@example.com", "password": password}
            )
            if response.status_code != 200:
                raise Exception(f"Ragflow registration failed: {response.text}")

    async def login(self, username: str, password: str) -> str:
        """Log in and return the ``Authorization`` header value issued by Ragflow.

        Uses the same synthesized ``<username>@example.com`` email and the
        same password encryption as :meth:`register`.

        Raises:
            Exception: on a non-200 response, or when the response carries no
                ``Authorization`` header.
        """
        password = RagflowCrypto(settings.PUBLIC_KEY, settings.PRIVATE_KEY).encrypt(password)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/v1/user/login",
                headers={'Content-Type': 'application/json'},
                json={"email": f"{username}@example.com", "password": password}
            )
            if response.status_code != 200:
                raise Exception(f"Ragflow login failed: {response.text}")
            authorization = response.headers.get('Authorization')
            if not authorization:
                raise Exception("Authorization header not found in response")
            return authorization

    async def chat(self, token: str, chat_id: str, chat_history: list, timeout: float = 10.0):
        """Stream a completion for an existing conversation.

        Args:
            token: ``Authorization`` header value from :meth:`login`.
            chat_id: Conversation id, sent as ``conversation_id``.
            chat_history: Message list, sent as ``messages``.
            timeout: httpx timeout in seconds (default 10.0, the previous
                hard-coded value). For a streamed body it also bounds the wait
                between chunks, so slow generations may need a larger value.

        Yields:
            Decoded text chunks of the response body as they arrive, or a
            single ``"Error: <status>"`` string when the server does not
            answer with HTTP 200.
        """
        payload = {
            "conversation_id": chat_id,
            "messages": chat_history,
        }
        target_url = f"{self.base_url}/v1/conversation/completion"
        headers = {
            'Content-Type': 'application/json',
            'Authorization': token,
        }
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("POST", target_url, json=payload, headers=headers) as response:
                if response.status_code != 200:
                    yield f"Error: {response.status_code}"
                    return
                # If the consumer stops iterating early, GeneratorExit is
                # raised at this yield and the async-with blocks close the
                # stream and the client; no explicit handler is needed.
                async for chunk in response.aiter_text():
                    yield chunk

    async def get_chat_sessions(self, token: str, dialog_id: str) -> list:
        """List the conversations of a dialog.

        Returns:
            A list of ``{"id", "name", "updated_time"}`` dicts, one per item
            of the server payload (``updated_time`` is read from the server's
            ``update_time`` field). Returns ``[]`` when the request fails or
            the payload is not a list.
        """
        url = f"{self.base_url}/v1/conversation/list"
        headers = {"Authorization": token}
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers, params={"dialog_id": dialog_id})
            data = await self._handle_response(response)
        # Iterating a dict payload would yield its string keys and crash on
        # item["id"]; only a list of session records is usable here.
        if not isinstance(data, list):
            return []
        return [
            {
                "id": item["id"],
                "name": item["name"],
                "updated_time": item["update_time"]
            }
            for item in data
        ]

    async def set_session(self, token: str, dialog_id: str, name: str, chat_id: str, is_new: bool) -> list:
        """Create or update a conversation on the server.

        Args:
            token: ``Authorization`` header value from :meth:`login`.
            dialog_id: Dialog id, sent both as a query parameter and in the body.
            name: Conversation name; also echoed back as the first user message.
            chat_id: Conversation id, sent as ``conversation_id``.
            is_new: Forwarded unchanged as ``is_new`` in the request body.

        Returns:
            When the server returns a non-empty payload, a locally built seed
            history: a fixed assistant greeting followed by ``name`` as a user
            message. Otherwise ``[]``.
        """
        url = f"{self.base_url}/v1/conversation/set"
        headers = {"Authorization": token}
        body = {
            "dialog_id": dialog_id,
            "name": name,
            "is_new": is_new,
            "conversation_id": chat_id,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, params={"dialog_id": dialog_id}, json=body)
            result = await self._handle_response(response)
        if not result:
            return []
        return [
            {
                "content": "你好! 我是你的助理,有什么可以帮到你的吗?",
                "role": "assistant"
            },
            {
                "content": name,
                "doc_ids": [],
                "role": "user"
            }
        ]

    async def get_session_history(self, token: str, chat_id: str) -> list:
        """Fetch the message history of one conversation.

        Returns:
            The ``message`` list from the server payload, or ``[]`` when the
            request fails, the payload is not an object, or it has no
            ``message`` key.
        """
        url = f"{self.base_url}/v1/conversation/get"
        headers = {"Authorization": token}
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers, params={"conversation_id": chat_id})
            data = await self._handle_response(response)
        # _handle_response may legitimately return a list, which has no .get().
        if not isinstance(data, dict):
            return []
        return data.get("message", [])