zhaoqingang
2025-03-07 af86455055918d26a0f6eebc270074c4863db0be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
import asyncio
import io
import json
import uuid
 
import fitz
from fastapi import HTTPException
 
from Log import logger
from app.config.agent_base_url import RG_CHAT_DIALOG, DF_CHAT_AGENT, DF_CHAT_PARAMETERS, RG_CHAT_SESSIONS, \
    DF_CHAT_WORKFLOW, DF_UPLOAD_FILE, RG_ORIGINAL_URL
from app.config.config import settings
from app.config.const import *
from app.models import DialogModel, ApiTokenModel, UserTokenModel, ComplexChatSessionDao, ChatDataRequest
from app.models.v2.session_model import ChatSessionDao, ChatData
from app.service.v2.app_driver.chat_agent import ChatAgent
from app.service.v2.app_driver.chat_data import ChatBaseApply
from app.service.v2.app_driver.chat_dialog import ChatDialog
from app.service.v2.app_driver.chat_workflow import ChatWorkflow
from docx import Document
from dashscope import get_tokenizer  # dashscope版本 >= 1.14.0
 
 
async def update_session_log(db, session_id: str, message: dict, conversation_id: str):
    """Write ``message`` and ``conversation_id`` onto an existing chat session.

    Delegates to ``ChatSessionDao.update_session_by_id`` with ``session=None``.

    Args:
        db: Database handle passed to ``ChatSessionDao``.
        session_id: Id of the session record to update.
        message: Message payload to store on the session.
        conversation_id: Upstream conversation id to store with it.
    """
    session_dao = ChatSessionDao(db)
    await session_dao.update_session_by_id(
        session_id=session_id,
        session=None,
        message=message,
        conversation_id=conversation_id,
    )
 
 
async def add_session_log(db, session_id: str, question: str, chat_id: str, user_id, event_type: str,
                          conversation_id: str, agent_type):
    """Create or update the session record for a newly asked user question.

    The session name is the question truncated to 255 characters, and the question is
    stored as a ``{"role": "user", "content": question}`` message.

    Returns:
        The object returned by ``ChatSessionDao.update_or_insert_by_id``, or ``None`` when
        that call (or building its arguments) raised; the error is logged.
    """
    try:
        dao = ChatSessionDao(db)
        return await dao.update_or_insert_by_id(
            session_id=session_id,
            name=question[:255],
            agent_id=chat_id,
            agent_type=agent_type,
            tenant_id=user_id,
            message={"role": "user", "content": question},
            conversation_id=conversation_id,
            event_type=event_type,
        )
    except Exception as exc:
        logger.error(exc)
        return None
 
 
async def get_app_token(db, app_id):
    """Return the ``access_token`` of the UserTokenModel row with id ``app_id``, or "" if none."""
    record = db.query(UserTokenModel).filter_by(id=app_id).first()
    return record.access_token if record else ""
 
 
async def get_chat_token(db, app_id):
    """Return the ``token`` of the first ApiTokenModel row for ``app_id``, or "" if none."""
    record = db.query(ApiTokenModel).filter_by(app_id=app_id).first()
    return record.token if record else ""
 
 
async def add_chat_token(db, data):
    """Persist a new ApiTokenModel row built from ``data`` (best effort).

    Args:
        db: Database session (presumably a SQLAlchemy ``Session`` — it exposes
            ``add``/``commit``/``rollback``).
        data: Keyword arguments for ``ApiTokenModel``.

    Errors are logged and not re-raised. On failure the session is rolled back so it
    stays usable for subsequent queries instead of remaining in a failed-transaction state.
    """
    try:
        api_token = ApiTokenModel(**data)
        db.add(api_token)
        db.commit()
    except Exception as e:
        logger.error(e)
        try:
            db.rollback()
        except Exception as rollback_error:
            # Rolling back is itself best effort; never let it mask the original failure.
            logger.error(rollback_error)
 
 
async def get_chat_info(db, chat_id: str):
    """Return the DialogModel row ``chat_id`` if its status is ``Dialog_STATSU_ON``, else None."""
    active_dialog = db.query(DialogModel).filter_by(id=chat_id, status=Dialog_STATSU_ON)
    return active_dialog.first()
 
 
async def get_chat_object(mode):
    """Select the Dify driver and endpoint URL for ``mode``.

    Returns:
        ``(ChatWorkflow(), dify_base_url + DF_CHAT_WORKFLOW)`` when ``mode == workflow_chat``,
        otherwise ``(ChatAgent(), dify_base_url + DF_CHAT_AGENT)``.
    """
    if mode == workflow_chat:
        endpoint, driver_cls = DF_CHAT_WORKFLOW, ChatWorkflow
    else:
        endpoint, driver_cls = DF_CHAT_AGENT, ChatAgent
    url = settings.dify_base_url + endpoint
    return driver_cls(), url
 
 
async def service_chat_dialog(db, chat_id: str, question: str, session_id: str, user_id, mode: str):
    """Stream a RAGFlow dialog answer to the client as SSE ``data:`` text.

    The question is logged to the session first. Each upstream completion event is
    translated into ``{"event", "data", "error", "status", "session_id"}`` JSON and
    yielded in slices of at most ``max_chunk_size`` characters. The last dict payload
    received is written back to the session log as the assistant message when the
    stream ends, fails, or is closed by the consumer.

    Args:
        db: Database handle.
        chat_id: RAGFlow dialog id.
        question: User question.
        session_id: Local session id (also passed as the conversation id when logging).
        user_id: Tenant / user id.
        mode: Event type recorded on the session.

    Yields:
        str: SSE-formatted text fragments.
    """
    conversation_id = ""
    token = await get_chat_token(db, rg_api_token)
    url = settings.fwr_base_url + RG_CHAT_DIALOG.format(chat_id)
    chat = ChatDialog()
    session = await add_session_log(db, session_id, question, chat_id, user_id, mode, session_id, 1)
    if session:
        conversation_id = session.conversation_id
    # Latest assistant payload; persisted in ``finally``. Must always stay a dict.
    message = {"role": "assistant", "answer": "", "reference": {}}
    try:
        async for ans in chat.chat_completions(url, await chat.request_data(question, conversation_id),
                                               await chat.get_headers(token)):
            data = {}
            error = ""
            status = http_200
            if ans.get("code", None) == 102:
                error = ans.get("message", "error!")
                status = http_400
                event = smart_message_error
            elif ans.get("data") is True:
                # A boolean ``True`` payload is treated as end-of-stream.
                event = smart_message_end
            else:
                data = ans.get("data", {})
                if isinstance(data, dict):
                    # Strip the upstream session id before forwarding / storing the payload.
                    data.pop("session_id", None)
                    message = data
                event = smart_message_cover
            message_str = "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "session_id": session_id},
                ensure_ascii=False) + "\n\n"
            for i in range(0, len(message_str), max_chunk_size):
                yield message_str[i:i + max_chunk_size]  # send the message in chunks
    except Exception as e:
        logger.error(e)
        try:
            # NOTE(review): this error payload uses a "message" key rather than "event";
            # kept as-is for client compatibility.
            yield "data: " + json.dumps({"message": smart_message_error,
                                         "error": "\n**ERROR**: " + str(e), "status": http_500},
                                        ensure_ascii=False) + "\n\n"
        except Exception:
            # Best effort: failing to report the error must not mask it. Cancellation and
            # generator close (BaseException) are deliberately allowed to propagate.
            pass
    finally:
        message["role"] = "assistant"
        await update_session_log(db, session_id, message, conversation_id)
 
 
async def data_process(data):
    """Recursively rebrand "dify" as "smart" in strings and string dict keys.

    Lists and dicts are modified in place and returned. Strings are returned with every
    "dify" replaced by "smart". Any other value is returned unchanged. A renamed dict key
    is removed and re-inserted under its new name, so it moves to the end of the dict.
    """
    if isinstance(data, list):
        for idx, item in enumerate(data):
            data[idx] = await data_process(item)
        return data
    if isinstance(data, dict):
        for key in list(data):
            processed = await data_process(data[key])
            if isinstance(key, str) and "dify" in key:
                del data[key]
                data[key.replace("dify", "smart")] = processed
            else:
                data[key] = processed
        return data
    if isinstance(data, str):
        return data.replace("dify", "smart")
    return data
 
 
async def service_chat_workflow(db, chat_id: str, chat_data: ChatData, session_id: str, user_id, mode: str):
    """Stream a Dify agent/workflow run to the client as SSE ``data:`` lines.

    The user query (or "start new conversation" when absent) is logged to the session
    first. Upstream events are translated into ``{"event", "data", "error", "status",
    "task_id", "session_id"}`` JSON. Unknown events are skipped. When the stream ends,
    fails, or is closed, the accumulated answer, download URL, workflow node events and
    ids are written to the session log.

    Args:
        db: Database handle.
        chat_id: Dify app id (also used to look up its API token).
        chat_data: Request payload; its ``query`` attribute is used when present.
        session_id: Local session id.
        user_id: Tenant / user id.
        mode: Chat mode; selects the driver via ``get_chat_object``.

    Yields:
        str: SSE-formatted lines terminated by a blank line.
    """
    conversation_id = ""
    answer_event = ""
    answer_agent = ""
    answer_workflow = ""
    download_url = ""
    message_id = ""
    task_id = ""
    error = ""
    node_list = []
    token = await get_chat_token(db, chat_id)
    chat, url = await get_chat_object(mode)
    if hasattr(chat_data, "query"):
        query = chat_data.query
    else:
        query = "start new conversation"
    session = await add_session_log(db, session_id, query if query else "start new conversation", chat_id, user_id,
                                    mode, conversation_id, 3)
    if session:
        conversation_id = session.conversation_id
    try:
        async for ans in chat.chat_completions(url,
                                               await chat.request_data(query, conversation_id, str(user_id), chat_data),
                                               await chat.get_headers(token)):
            data = {}
            status = http_200
            # Events may omit these ids; keep the last known value rather than overwriting
            # it with None (which would then be persisted in ``finally``).
            conversation_id = ans.get("conversation_id") or conversation_id
            task_id = ans.get("task_id") or task_id
            event_name = ans.get("event")
            if event_name == message_error:
                error = ans.get("message", "参数异常!")
                status = http_400
                event = smart_message_error
            elif event_name == message_agent:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_agent += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_event:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_event += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_file:
                data = {"url": ans.get("url", ""), "id": ans.get("id", ""),
                        "type": ans.get("type", "")}
                event = smart_message_file
            elif event_name in [workflow_started, node_started, node_finished]:
                data = ans.get("data")
                if not isinstance(data, dict):
                    data = {}
                # Mutates the event's own dict, so node_list stores the processed version.
                data["inputs"] = await data_process(data.get("inputs", {}))
                data["outputs"] = await data_process(data.get("outputs", {}))
                data["files"] = await data_process(data.get("files", []))
                data["process_data"] = ""
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
                event = [smart_workflow_started, smart_node_started, smart_node_finished][
                    [workflow_started, node_started, node_finished].index(event_name)]
            elif event_name == workflow_finished:
                data = ans.get("data")
                if not isinstance(data, dict):
                    data = {}
                outputs = data.get("outputs")
                if not isinstance(outputs, dict):
                    # ``outputs`` may be present but null/non-dict; reading it must not abort the stream.
                    outputs = {}
                answer_workflow = outputs.get("output", outputs.get("answer"))
                download_url = outputs.get("download_url")
                event = smart_workflow_finished
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
            elif event_name == message_end:
                event = smart_message_end
            else:
                continue

            yield "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "task_id": task_id,
                 "session_id": session_id},
                ensure_ascii=False) + "\n\n"

    except Exception as e:
        logger.error(e)
        try:
            yield "data: " + json.dumps({"message": smart_message_error,
                                         "error": "\n**ERROR**: " + str(e), "status": http_500},
                                        ensure_ascii=False) + "\n\n"
        except Exception:
            # Best effort: failing to report the error must not mask it. Cancellation and
            # generator close (BaseException) are deliberately allowed to propagate.
            pass
    finally:
        await update_session_log(db, session_id, {"role": "assistant",
                                                  "answer": answer_event or answer_agent or answer_workflow or error,
                                                  "download_url": download_url,
                                                  "node_list": node_list, "task_id": task_id, "id": message_id,
                                                  "error": error}, conversation_id)
 
 
async def service_chat_basic(db, chat_id: str, chat_data: ChatData, session_id: str, user_id, mode: str):
    """Placeholder for basic-mode chat; not implemented yet and always returns ``None``."""
    return None
 
 
async def service_chat_parameters(db, chat_id, user_id):
    """Return the stored ``parameters`` of DialogModel ``chat_id``, or ``{}`` if it does not exist.

    ``user_id`` is accepted for interface compatibility but is not used.
    """
    dialog = db.query(DialogModel).filter_by(id=chat_id).first()
    return dialog.parameters if dialog else {}
 
 
async def service_chat_sessions(db, chat_id, name):
    """Call the RAGFlow sessions endpoint of dialog ``chat_id`` with ``{"name": name}``.

    Returns:
        ``{}`` when no RAGFlow API token is configured, otherwise the result of
        ``ChatDialog.chat_sessions``.
    """
    token = await get_chat_token(db, rg_api_token)
    if not token:
        return {}
    sessions_url = settings.fwr_base_url + RG_CHAT_SESSIONS.format(chat_id)
    dialog = ChatDialog()
    headers = await dialog.get_headers(token)
    return await dialog.chat_sessions(sessions_url, {"name": name}, headers)
 
 
async def service_chat_sessions_list(db, chat_id, current, page_size, user_id, keyword):
    """Return one page of ``user_id``'s sessions for agent ``chat_id`` as JSON.

    Returns:
        str: ``{"total": <count>, "rows": [<session.to_dict()>, ...]}`` serialized with ``json.dumps``.
    """
    dao = ChatSessionDao(db)
    total, sessions = await dao.get_session_list(
        user_id=user_id,
        agent_id=chat_id,
        keyword=keyword,
        page=current,
        page_size=page_size,
    )
    rows = [item.to_dict() for item in sessions]
    return json.dumps({"total": total, "rows": rows})
 
 
async def service_chat_session_log(db, session_id):
    """Return the log of session ``session_id`` as a JSON string ("{}" when not found)."""
    record = await ChatSessionDao(db).get_session_by_id(session_id)
    if not record:
        return json.dumps({})
    return json.dumps(record.log_to_json())
 
 
async def service_chat_upload(db, chat_id, file, user_id):
    """Upload files to Dify for ``chat_id`` and annotate each result with a token count.

    Args:
        db: Database handle used to look up the app's API token.
        chat_id: Dify app id.
        file: Iterable of uploaded files exposing async ``read()``, ``filename`` and ``content_type``.
        user_id: User id sent to Dify as the ``user`` form field.

    Returns:
        ``[]`` when the app has no API token, ``""`` when no file was uploaded successfully,
        otherwise a JSON string listing the upload responses. Each response gets a
        ``tokens`` key when token counting succeeds.
        NOTE(review): the mixed return types ([] vs str) are kept for caller compatibility.
    """
    files = []
    token = await get_chat_token(db, chat_id)
    if not token:
        return files
    url = settings.dify_base_url + DF_UPLOAD_FILE
    chat = ChatBaseApply()
    for f in file:
        try:
            file_content = await f.read()
            file_upload = await chat.chat_upload(url, {"file": (f.filename, file_content)}, {"user": str(user_id)},
                                                 {'Authorization': f'Bearer {token}'})
            try:
                file_upload["tokens"] = await read_file(file_content, f.filename, f.content_type)
            except Exception as e:
                # Token counting is best effort: keep the upload, but record why it failed.
                logger.error(e)
            files.append(file_upload)
        except Exception as e:
            logger.error(e)
    return json.dumps(files) if files else ""
 
 
async def get_str_token(input_str):
    """Return the number of tokens in ``input_str`` under dashscope's ``qwen-turbo`` tokenizer.

    The dashscope tokenizer currently supports only the Qwen (Tongyi Qianwen) model family.
    """
    qwen_tokenizer = get_tokenizer('qwen-turbo')
    token_ids = qwen_tokenizer.encode(input_str)
    return len(token_ids)
 
 
async def read_pdf(pdf_stream):
    """Extract the text of every page of an in-memory PDF, in page order.

    Args:
        pdf_stream: PDF content passed to ``fitz.open(stream=...)`` (presumably ``bytes``).

    Returns:
        str: The page texts concatenated with no extra separator.
    """
    with fitz.open(stream=pdf_stream, filetype="pdf") as pdf_document:
        # Join once instead of repeated ``+=`` on a growing string.
        return "".join(page.get_text() for page in pdf_document)
 
 
async def read_word(word_stream):
    """Extract the paragraph text of an in-memory ``.docx`` document.

    Args:
        word_stream: Raw ``.docx`` bytes (wrapped in ``io.BytesIO`` for python-docx).

    Returns:
        str: Paragraph texts joined with newlines. python-docx paragraph text carries no
        trailing newline, so concatenating without a separator would fuse the last word of
        one paragraph into the first word of the next.
    """
    # Open the Word file stream with python-docx.
    doc = Document(io.BytesIO(word_stream))
    return "\n".join(para.text for para in doc.paragraphs)
 
 
async def read_file(file, filename, content_type):
    """Count the tokens of the text extracted from a PDF or DOCX upload.

    The format is detected from ``content_type`` or, failing that, the file extension
    (case-insensitive, so ``REPORT.PDF`` is recognised). Unsupported formats and a missing
    ``filename`` yield empty text, so the result is the token count of "".

    Args:
        file: Raw file content.
        filename: Original file name; may be ``None``.
        content_type: MIME type reported for the upload.

    Returns:
        int: Token count from ``get_str_token``.
    """
    name = (filename or "").lower()
    text = ""
    if content_type == "application/pdf" or name.endswith(".pdf"):
        text = await read_pdf(file)
    elif (content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
          or name.endswith(".docx")):
        text = await read_word(file)
    return await get_str_token(text)
 
 
def _build_retrieval_payload(query, knowledge_id, top_k, similarity_threshold):
    """Build the RAGFlow retrieval request body from a raw ``query`` string.

    ``query`` may be a JSON object (or a single-quoted, Python-repr-style object) with
    ``query`` and ``dataset_ids`` keys. Any other value, including JSON that parses to a
    non-object such as a number, is used verbatim as the question against ``knowledge_id``.
    """
    try:
        request_data = json.loads(query)
    except json.JSONDecodeError:
        # Fallback for dict reprs that use single quotes instead of JSON double quotes.
        try:
            request_data = json.loads(query.replace("'", '"'))
        except json.JSONDecodeError:
            request_data = None
    threshold = similarity_threshold if similarity_threshold else 0.2
    if isinstance(request_data, dict):
        return {
            "question": request_data.get("query", ""),
            "dataset_ids": request_data.get("dataset_ids", []),
            "page_size": top_k,
            "similarity_threshold": threshold,
        }
    return {
        "question": query,
        "dataset_ids": [knowledge_id],
        "page_size": top_k,
        "similarity_threshold": threshold,
    }


def _chunk_to_record(chunk):
    """Convert one RAGFlow retrieval chunk into the record shape returned to callers."""
    keyword = chunk.get("document_keyword")
    # Presumably the document's file name; its suffix is passed as ``ext``. A missing
    # keyword yields an empty ``ext`` instead of crashing on ``None.split``.
    extension = keyword.split(".")[-1] if keyword else ""
    return {
        "content": chunk["content"],
        "score": chunk["similarity"],
        "title": chunk.get("document_keyword", "Unknown Document"),
        "metadata": {
            "document_id": chunk["document_id"],
            "path": f"{settings.fwr_base_url}/document/{chunk['document_id']}?ext={extension}&prefix=document",
            "highlight": chunk.get("highlight"),
            "image_id": chunk.get("image_id"),
            "positions": chunk.get("positions"),
        },
    }


async def service_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
    """Retrieve matching chunks from RAGFlow and shape them for the caller.

    Args:
        query: Either a JSON object string ``{"query": ..., "dataset_ids": [...]}`` (single
            quotes tolerated) or a plain question searched in ``knowledge_id``.
        knowledge_id: Dataset id used when ``query`` is not a JSON object.
        top_k: Page size requested from RAGFlow.
        similarity_threshold: Minimum similarity; a falsy value falls back to 0.2.
        api_key: RAGFlow API key used to build the request headers.

    Returns:
        list[dict]: One record per chunk with ``content``, ``score``, ``title`` and
        ``metadata`` (``document_id``, ``path``, ``highlight``, ``image_id``, ``positions``).

    Raises:
        HTTPException: 500 when RAGFlow returns an empty/falsy response.
        KeyError: If a chunk lacks ``content``, ``similarity`` or ``document_id``.
    """
    payload = _build_retrieval_payload(query, knowledge_id, top_k, similarity_threshold)
    url = settings.fwr_base_url + RG_ORIGINAL_URL
    chat = ChatBaseApply()
    response = await chat.chat_post(url, payload, await chat.get_headers(api_key))
    if not response:
        raise HTTPException(status_code=500, detail="服务异常!")
    chunks = (response.get("data") or {}).get("chunks") or []
    return [_chunk_to_record(chunk) for chunk in chunks]
 
 
async def service_base_chunk_retrieval(query, knowledge_id, top_k, similarity_threshold, api_key):
    """Retrieve chunks for a plain-text ``query`` from the single dataset ``knowledge_id``.

    Unlike ``service_chunk_retrieval``, ``query`` is never parsed as JSON, and
    ``similarity_threshold`` is forwarded as-is (no default substitution).

    Returns:
        list[dict]: ``{"content", "score", "title", "metadata": {"document_id"}}`` per chunk.

    Raises:
        HTTPException: 500 when RAGFlow returns an empty/falsy response.
    """
    request_body = {
        "question": query,
        "dataset_ids": [knowledge_id],
        "page_size": top_k,
        "similarity_threshold": similarity_threshold,
    }
    endpoint = settings.fwr_base_url + RG_ORIGINAL_URL
    client = ChatBaseApply()
    headers = await client.get_headers(api_key)
    response = await client.chat_post(endpoint, request_body, headers)
    if not response:
        raise HTTPException(status_code=500, detail="服务异常!")
    found = response.get("data", {}).get("chunks", [])
    return [
        {
            "content": item["content"],
            "score": item["similarity"],
            "title": item.get("document_keyword", "Unknown Document"),
            "metadata": {"document_id": item["document_id"]},
        }
        for item in found
    ]
 
 
async def add_complex_log(db, message_id, chat_id, session_id, chat_mode, query, user_id, mode, agent_type,
                          message_type, conversation_id="", node_data=None, query_data=None):
    """Record one message of a complex-chat session via ``ComplexChatSessionDao``.

    When ``conversation_id`` is empty, the conversation id stored on an existing session
    ``(session_id, chat_id)`` is reused if one exists. ``node_data`` and ``query_data`` are
    stored JSON-encoded (defaulting to ``[]`` and ``{}``).

    Returns:
        tuple: ``(conversation_id, True)`` on success, or ``(conversation_id, False)`` if any
        step raised (the error is logged).
    """
    node_data = node_data or []
    query_data = query_data or {}
    try:
        dao = ComplexChatSessionDao(db)
        if not conversation_id:
            existing = await dao.get_session_by_session_id(session_id, chat_id)
            if existing:
                conversation_id = existing.conversation_id
        await dao.create_session(
            message_id,
            chat_id=chat_id,
            session_id=session_id,
            chat_mode=chat_mode,
            message_type=message_type,
            content=query,
            event_type=mode,
            tenant_id=user_id,
            conversation_id=conversation_id,
            node_data=json.dumps(node_data),
            query=json.dumps(query_data),
            agent_type=agent_type,
        )
    except Exception as exc:
        logger.error(exc)
        return conversation_id, False
    return conversation_id, True
 
 
 
async def service_complex_chat(db, chat_id, mode, user_id, chat_request: ChatDataRequest):
    """Stream a Dify-backed "complex" chat to the client as SSE ``data:`` lines.

    The user question is recorded first via ``add_complex_log`` (message_type 1). If that
    fails, a single error line is yielded and the stream ends. Otherwise upstream events are
    translated into ``{"event", "data", "error", "status", "task_id", "message_id",
    "session_id"}`` JSON, and unknown events are skipped. When the stream ends, fails, or is
    closed, the accumulated answer and workflow node events are recorded as message_type 2,
    but only if an upstream ``message_id`` was seen.

    Args:
        db: Database handle.
        chat_id: Dify app id (also used to look up its API token).
        mode: Chat mode; selects the driver via ``get_chat_object``.
        user_id: Tenant / user id.
        chat_request: Request payload (query, sessionId, chatMode, isDeep, files, knowledgeId).

    Yields:
        str: SSE-formatted lines terminated by a blank line.
    """
    answer_event = ""
    answer_agent = ""
    answer_workflow = ""
    message_id = ""
    task_id = ""
    error = ""
    node_list = []
    token = await get_chat_token(db, chat_id)
    chat, url = await get_chat_object(mode)
    conversation_id, created = await add_complex_log(
        db, str(uuid.uuid4()), chat_id, chat_request.sessionId, chat_request.chatMode, chat_request.query,
        user_id, mode, DF_TYPE, 1, query_data=chat_request.to_dict())
    if not created:
        yield "data: " + json.dumps({"message": smart_message_error,
                                     "error": "\n**ERROR**: 创建会话失败!", "status": http_500},
                                    ensure_ascii=False) + "\n\n"
        return
    inputs = {"is_deep": chat_request.isDeep}
    if chat_request.chatMode == complex_knowledge_chat:
        inputs["query_json"] = json.dumps({"query": chat_request.query, "dataset_ids": chat_request.knowledgeId})

    try:
        async for ans in chat.chat_completions(url,
                                               await chat.complex_request_data(chat_request.query, conversation_id,
                                                                               str(user_id), files=chat_request.files,
                                                                               inputs=inputs),
                                               await chat.get_headers(token)):
            data = {}
            status = http_200
            # Events may omit these ids; keep the last known value rather than overwriting
            # it with None.
            conversation_id = ans.get("conversation_id") or conversation_id
            task_id = ans.get("task_id") or task_id
            event_name = ans.get("event")
            if event_name == message_error:
                error = ans.get("message", "参数异常!")
                status = http_400
                event = smart_message_error
            elif event_name == message_agent:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_agent += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_event:
                data = {"answer": ans.get("answer", ""), "id": ans.get("message_id", "")}
                answer_event += ans.get("answer", "")
                message_id = ans.get("message_id", "")
                event = smart_message_stream
            elif event_name == message_file:
                data = {"url": ans.get("url", ""), "id": ans.get("id", ""),
                        "type": ans.get("type", "")}
                event = smart_message_file
            elif event_name in [workflow_started, node_started, node_finished]:
                data = ans.get("data")
                if not isinstance(data, dict):
                    data = {}
                # Mutates the event's own dict, so node_list stores the processed version.
                data["inputs"] = await data_process(data.get("inputs", {}))
                data["outputs"] = await data_process(data.get("outputs", {}))
                data["files"] = await data_process(data.get("files", []))
                data["process_data"] = ""
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
                event = [smart_workflow_started, smart_node_started, smart_node_finished][
                    [workflow_started, node_started, node_finished].index(event_name)]
            elif event_name == workflow_finished:
                data = ans.get("data")
                if not isinstance(data, dict):
                    data = {}
                outputs = data.get("outputs")
                if not isinstance(outputs, dict):
                    # ``outputs`` may be present but null/non-dict; reading it must not abort the stream.
                    outputs = {}
                answer_workflow = outputs.get("output", outputs.get("answer"))
                event = smart_workflow_finished
                if data.get("status") == "failed":
                    status = http_500
                    error = data.get("error", "")
                node_list.append(ans)
            elif event_name == message_end:
                event = smart_message_end
            else:
                continue

            yield "data: " + json.dumps(
                {"event": event, "data": data, "error": error, "status": status, "task_id": task_id,
                 "message_id": message_id, "session_id": chat_request.sessionId},
                ensure_ascii=False) + "\n\n"

    except Exception as e:
        logger.error(e)
        try:
            yield "data: " + json.dumps({"message": smart_message_error,
                                         "error": "\n**ERROR**: " + str(e), "status": http_500},
                                        ensure_ascii=False) + "\n\n"
        except Exception:
            # Best effort: failing to report the error must not mask it. Cancellation and
            # generator close (BaseException) are deliberately allowed to propagate.
            pass
    finally:
        if message_id:
            await add_complex_log(db, message_id, chat_id, chat_request.sessionId, chat_request.chatMode,
                                  answer_event or answer_agent or answer_workflow or error, user_id, mode, DF_TYPE,
                                  2, conversation_id, node_data=node_list, query_data=chat_request.to_dict())
 
async def service_complex_upload(db, chat_id, file, user_id):
    """Upload files to Dify for ``chat_id`` without computing token counts.

    Args:
        db: Database handle used to look up the app's API token.
        chat_id: Dify app id.
        file: Iterable of uploaded files exposing async ``read()`` and ``filename``.
        user_id: User id sent to Dify as the ``user`` form field.

    Returns:
        ``[]`` when the app has no API token, ``""`` when no upload succeeded, otherwise a
        JSON string listing the upload responses. Per-file failures are logged and skipped.
    """
    token = await get_chat_token(db, chat_id)
    if not token:
        return []
    upload_url = settings.dify_base_url + DF_UPLOAD_FILE
    uploader = ChatBaseApply()
    uploaded = []
    for upload in file:
        try:
            content = await upload.read()
            result = await uploader.chat_upload(upload_url, {"file": (upload.filename, content)},
                                                {"user": str(user_id)},
                                                {'Authorization': f'Bearer {token}'})
        except Exception as exc:
            logger.error(exc)
            continue
        uploaded.append(result)
    return json.dumps(uploaded) if uploaded else ""
 
if __name__ == "__main__":
    # Manual smoke test for service_chunk_retrieval against a live RAGFlow instance.
    # The API key is read from the RAGFLOW_API_KEY environment variable rather than
    # being committed to source control.
    import os

    dataset_id = "fc68db52f43111efb94a0242ac120004"
    q = json.dumps({"query": "设备", "dataset_ids": [dataset_id]})
    top_k = 2
    similarity_threshold = 0.5
    api_key = os.environ.get("RAGFLOW_API_KEY", "")

    async def _smoke_test():
        # knowledge_id is required by the signature; it is only used if q is not a JSON object.
        records = await service_chunk_retrieval(q, dataset_id, top_k, similarity_threshold, api_key)
        print(records)

    asyncio.run(_smoke_test())