From 5580958d49e5aab48908000614e47ecb75ff4797 Mon Sep 17 00:00:00 2001
From: zhaoqingang <zhaoqg0118@163.com>
Date: 星期四, 28 十一月 2024 19:14:26 +0800
Subject: [PATCH] 智能数据问题优化

---
 app/models/postgresql_base_model.py |   17 +++++
 requirements.txt                    |    0 
 app/config/config.py                |    1 
 app/api/chat.py                     |   41 ++++++++-----
 app/config/config.yaml              |    3 
 app/service/basic.py                |   15 +++-
 app/models/app_token_model.py       |    9 +++
 app/api/agent.py                    |   12 ++++
 app/api/auth.py                     |   54 +++++++++++++++++
 9 files changed, 128 insertions(+), 24 deletions(-)

diff --git a/app/api/agent.py b/app/api/agent.py
index 2f27f23..698ca0d 100644
--- a/app/api/agent.py
+++ b/app/api/agent.py
@@ -154,12 +154,24 @@
                 if i.get("role") == "user":
                     tmp_data["question"]=i.get("content")
                 elif i.get("role") == "assistant":
+
                     if isinstance(i.get("content"), dict):
                         tmp_data["answer"] = i.get("content", {}).get("message")
                         if "file_name" in i.get("content", {}):
                             tmp_data["files"] = [{"file_name":i.get("content", {}).get("file_name"), "file_url":i.get("content", {}).get("file_url")}]
                     else:
                         tmp_data["answer"] = i.get("content")
+
+                    if "excel_url" in i:
+                        tmp_data["excel_url"] = i.get("excel_url")
+                    if "image_url" in i:
+                        tmp_data["image_url"] = i.get("image_url")
+                    if "sql" in i:
+                        tmp_data["sql"] = i.get("sql")
+                    if "code" in i:
+                        tmp_data["code"] = i.get("code")
+                    if "e" in i:
+                        tmp_data["e"] = i.get("e")
                     data.append(tmp_data)
                     tmp_data = {}
 
diff --git a/app/api/auth.py b/app/api/auth.py
index 72b0bbf..6e74966 100644
--- a/app/api/auth.py
+++ b/app/api/auth.py
@@ -2,16 +2,20 @@
 
 from fastapi import APIRouter, Depends
 from sqlalchemy.orm import Session
-
+from sqlalchemy.ext.asyncio import AsyncSession
 from app.api import Response, pwd_context, get_current_user
 from app.config.config import settings
+from app.models.app_token_model import AppToken
 from app.models.base_model import get_db
+from app.models.postgresql_base_model import get_pdb
 from app.models.token_model import upsert_token, get_token
 from app.models.user import UserCreate, LoginData
 from app.models.user_model import UserModel
 from app.service.auth import authenticate_user, create_access_token
 from app.service.bisheng import BishengService
 from app.service.ragflow import RagflowService
+from sqlalchemy.future import select
+
 
 router = APIRouter()
 
@@ -91,3 +95,51 @@
     return Response(code=200, msg="success", data={
         "ragflow_token": token.ragflow_token,
     })
+
+
+@router.post("/login_test", response_model=Response)
+async def login_test(login_data: LoginData, db: Session = Depends(get_db), pdb: AsyncSession = Depends(get_pdb)):
+    user = authenticate_user(db, login_data.username, login_data.password)
+    if not user:
+        return Response(code=400, msg="Incorrect username or password")
+
+    bisheng_service = BishengService(settings.sgb_base_url)
+    ragflow_service = RagflowService(settings.fwr_base_url)
+
+    # 鐧诲綍鍒版瘯鏄�
+    try:
+        bisheng_token = await bisheng_service.login(login_data.username, login_data.password)
+    except Exception as e:
+        return Response(code=500, msg=f"Failed to login with Bisheng: {str(e)}")
+
+    # 鐧诲綍鍒皉agflow
+    try:
+        ragflow_token = await ragflow_service.login(login_data.username, login_data.password)
+    except Exception as e:
+        return Response(code=500, msg=f"Failed to login with Ragflow: {str(e)}")
+
+    # 鍒涘缓鏈湴token
+    access_token = create_access_token(data={"sub": user.username, "user_id": user.id})
+
+    upsert_token(db, user.id, access_token, bisheng_token, ragflow_token)
+    result = await pdb.execute(select(AppToken).where(AppToken.id == user.id))
+    db_app_token = result.scalars().first()
+    if not db_app_token:
+        app_token_str = json.dumps({"rag_token": ragflow_token, "bs_token":bisheng_token})
+        # print(app_token_str)
+        app_token = AppToken(id=user.id, token=access_token.decode(), app_token=app_token_str)
+        pdb.add(app_token)
+        await pdb.commit()
+        await pdb.refresh(app_token)
+    else:
+        db_app_token.token = access_token.decode()
+        db_app_token.app_token = json.dumps({"rag_token": ragflow_token, "bs_token":bisheng_token})
+        await pdb.commit()
+        await pdb.refresh(db_app_token)
+    return Response(code=200, msg="Login successful", data={
+        "access_token": access_token,
+        "token_type": "bearer",
+        "username": user.username,
+        "nickname": "",
+        # "user": user.to_login_json()
+    })
diff --git a/app/api/chat.py b/app/api/chat.py
index 7e85a8e..f21f07c 100644
--- a/app/api/chat.py
+++ b/app/api/chat.py
@@ -251,9 +251,12 @@
                         await websocket.send_json(result)
 
                 else:
+                    message_data = {}
                     logger.error("---------------------excel_talk-----------------------------")
+                    excel_url = ""
+                    image_url = ""
                     async for data in service.excel_talk(question, chat_id):
-                        logger.error(data)
+                        # logger.error(data)
                         output = data.get("output", "")
                         excel_name = data.get("excel_name", "")
                         image_name = data.get("image_name", "")
@@ -263,27 +266,31 @@
                                 return None
                             return (f"/api/files/download/?agent_id={agent_id}&file_id={name}"
                                     f"&file_type={file_type}")
-                        excel_url = build_file_url(excel_name, 'excel')
-                        image_url = build_file_url(image_name, 'image')
-                        if excel_url or data.get("e", ""):
-                            try:
-                                SessionService(db).update_session(chat_id,
-                                                                  message={
-                                                                      "content": output,
-                                                                      "excel_url": excel_url,
-                                                                      "image_url": image_url,
-                                                                      "sql": data.get("sql", ""),
-                                                                      "code": data.get("code", ""),
-                                                                      "e": data.get("e", ""),
-                                                                      "role": "assistant"})
-                            except Exception as e:
-                                logger.error(f"Unexpected error when update_session: {e}")
+                        if excel_name:
+                            excel_url = build_file_url(excel_name, 'excel')
+                        if image_name:
+                            image_url = build_file_url(image_name, 'image')
+                        if data["type"] == "message":
+                            message_data = {
+                                "content": output,
+                                "excel_url": excel_url,
+                                "image_url": image_url,
+                                "sql": data.get("sql", ""),
+                                "code": data.get("code", ""),
+                                "e": data.get("e", ""),
+                                "role": "assistant"}
+
                         # 鍙戦�佺粨鏋滅粰瀹㈡埛绔�
-                        data["type"] = "message"
+                        # data["type"] = "message"
                         data["message"] = output
                         data["excel_url"] = excel_url
                         data["image_url"] = image_url
                         await websocket.send_json(data)
+                    if message_data:
+                        try:
+                            SessionService(db).update_session(chat_id,message=message_data)
+                        except Exception as e:
+                            logger.error(f"Unexpected error when update_session: {e}")
         except Exception as e:
             logger.error(e)
             await websocket.send_json({"message": "鍑虹幇閿欒锛�", "type": "error"})
diff --git a/app/config/config.py b/app/config/config.py
index 3c97edc..46855be 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -19,6 +19,7 @@
     basic_paper_url: str = ''
     dify_base_url: str = ''
     dify_api_token: str = ''
+    postgresql_database_url: str = ''
     def __init__(self, **kwargs):
         # Check if all required fields are provided and set them
         for field in self.__annotations__.keys():
diff --git a/app/config/config.yaml b/app/config/config.yaml
index f8a53b0..fe9963c 100644
--- a/app/config/config.yaml
+++ b/app/config/config.yaml
@@ -16,4 +16,5 @@
 basic_base_url: http://192.168.20.231:8000
 basic_paper_url: http://192.168.20.231:8000
 dify_base_url: http://192.168.20.116
-dify_api_token: app-YmOAMDsPpDDlqryMHnc9TzTO
\ No newline at end of file
+dify_api_token: app-YmOAMDsPpDDlqryMHnc9TzTO
+postgresql_database_url: postgresql+asyncpg://kong:kongpass@192.168.20.119:5432/kong
\ No newline at end of file
diff --git a/app/models/app_token_model.py b/app/models/app_token_model.py
new file mode 100644
index 0000000..9a0af3d
--- /dev/null
+++ b/app/models/app_token_model.py
@@ -0,0 +1,9 @@
+from sqlalchemy import Column, Integer, String
+from app.models.postgresql_base_model import PostgresqlBase
+
+class AppToken(PostgresqlBase):
+    __tablename__ = "app_service_token"
+
+    id = Column(Integer, primary_key=True, index=True)
+    token = Column(String, unique=True, index=True)
+    app_token = Column(String)
\ No newline at end of file
diff --git a/app/models/postgresql_base_model.py b/app/models/postgresql_base_model.py
new file mode 100644
index 0000000..f139db8
--- /dev/null
+++ b/app/models/postgresql_base_model.py
@@ -0,0 +1,17 @@
+import os
+
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+from app.config.config import settings
+
+DATABASE_URL = os.getenv('POSTGRESQL_DATABASE_URL') or settings.postgresql_database_url
+
+engine = create_async_engine(DATABASE_URL, echo=True)
+PostgresqlSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession)
+PostgresqlBase = declarative_base()
+
+
+async def get_pdb() -> AsyncSession:
+    async with PostgresqlSessionLocal() as session:
+        yield session
\ No newline at end of file
diff --git a/app/service/basic.py b/app/service/basic.py
index b3ad295..29bb02a 100644
--- a/app/service/basic.py
+++ b/app/service/basic.py
@@ -51,7 +51,7 @@
         url = f"{self.base_url}/exceltalk/upload/files"
         params = {'chat_id': chat_id, 'is_col': '0'}
 
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=300) as client:
             response = await client.post(
                 url,
                 files=files,
@@ -73,18 +73,23 @@
                         if decoded_line.startswith("data:"):
                             decoded_line = decoded_line[5:]
                         answer = json.loads(decoded_line)
+                        answer["type"] = "message"
                         yield answer
                     except GeneratorExit as e:
                         logger.error("------------except GeneratorExit as e:---------------------")
                         logger.error(e)
                         print(e)
                         yield {"message": "鍐呴儴閿欒", "type": "close"}
-                    finally:
-                        # 鍦ㄦ墍鏈夋暟鎹帴鏀跺畬姣曞悗杩斿洖close
-                        yield {"message": "", "type": "close"}
+                    # finally:
+                    #     # 鍦ㄦ墍鏈夋暟鎹帴鏀跺畬姣曞悗杩斿洖close
+                    #     yield {"message": "", "type": "close"}
 
                 else:
-                    yield f"Error: {response.status_code}"
+                    continue
+                    # yield f"Error: {response.status_code}"
+            else:
+            # 鍦ㄦ墍鏈夋暟鎹帴鏀跺畬姣曞悗杩斿洖close
+                yield {"message": "", "type": "close"}
 
     async def questions_talk(self, question, chat_id: str):
         logger.error("---------------questions_talk--------------------------")
diff --git a/requirements.txt b/requirements.txt
index a83ed48..a400033 100644
--- a/requirements.txt
+++ b/requirements.txt
Binary files differ

--
Gitblit v1.8.0