zhaoqingang
2025-01-02 b991b79b608e3b811399cb59b2776ce23ba6d1e0
tmp test
4个文件已修改
2个文件已添加
104 ■■■■■ 已修改文件
app/api/v2/chat.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/models/v2/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
app/models/v2/session_model.py 65 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/app_driver/chat_dialog.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/service/v2/chat.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/api/v2/chat.py
@@ -0,0 +1,15 @@
from fastapi import Depends, APIRouter
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.api import get_current_user
from app.models import UserModel
from app.models.base_model import get_db
from app.models.v2.session_model import ChatDialogData
from app.service.v2.chat import service_chat_dialog
chat1_router = APIRouter()
@chat1_router.get("/chat_dialog")
async def api_chat_dialog(dialog: ChatDialogData, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_user)):
    return StreamingResponse(await service_chat_dialog(dialog.question, dialog.sessionId), media_type="text/event-stream")
app/models/v2/__init__.py
app/models/v2/session_model.py
New file
@@ -0,0 +1,65 @@
import json
from datetime import datetime
from enum import IntEnum
from typing import Optional
import pytz
from pydantic import BaseModel
from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime, JSON, TEXT
from app.models.agent_model import AgentType
# from app.models import current_time
from app.models.base_model import Base
def current_time():
    tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(tz)
class SessionModel(Base):
    __tablename__ = "sessions"
    id = Column(String(255), primary_key=True)
    name = Column(String(255))
    agent_id = Column(String(255))
    agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False)  # 目前只存basic的,ragflow和bisheng的调接口获取
    create_date = Column(DateTime, default=current_time)  # 创建时间,默认值为当前时区时间
    update_date = Column(DateTime, default=current_time, onupdate=current_time)  # 更新时间,默认值为当前时区时间,更新时自动更新
    tenant_id = Column(Integer)  # 创建人
    message = Column(TEXT)  # 说明
    conversation_id = Column(String(64))
    # to_dict 方法
    def to_dict(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
        }
    def log_to_json(self):
        return {
            'id': self.id,
            'name': self.name,
            'agent_type': self.agent_type,
            'agent_id': self.agent_id,
            'create_date': self.create_date.strftime("%Y-%m-%d %H:%M:%S"),
            'update_date': self.update_date.strftime("%Y-%m-%d %H:%M:%S"),
            'message': json.loads(self.message)
        }
    def add_message(self, message: dict):
        if self.message is None:
            self.message = '[]'
        try:
            msg = json.loads(self.message)
            msg.append(message)
        except Exception as e:
            return
        self.message = json.dumps(msg)
class ChatDialogData(BaseModel):
    sessionId: Optional[str] = ""
    question: str
app/service/v2/app_driver/chat_dialog.py
@@ -17,4 +17,5 @@
    async def chat_completions(self):
        async for rag_response in self.http_stream(token, chat_id, chat_history):
            ...
            yield rag_response
app/service/v2/chat.py
@@ -0,0 +1,19 @@
async def service_chat_dialog(question: str, session_id: str):
    if session_id:
        ...
    try:
        for ans in chat(dia, msg, True, **req):
            yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        ConversationService.update_by_id(conv.id, conv.to_dict())
    except Exception as e:
        yield "data:" + json.dumps({"code": 500, "message": str(e),
                                    "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                   ensure_ascii=False) + "\n\n"
    yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
main.py
@@ -15,6 +15,7 @@
from app.api.label import label_router
from app.api.llm import llm_router
from app.api.organization import dept_router
from app.api.v2.chat import chat1_router
from app.api.v2.public_api import public_api
from app.api.report import router as report_router
from app.api.resource import menu_router
@@ -85,6 +86,7 @@
app.include_router(canvas_router, prefix='/api/canvas', tags=["canvas"])
app.include_router(label_router, prefix='/api/label', tags=["label"])
app.include_router(public_api, prefix='/v1/api', tags=["public_api"])
app.include_router(chat1_router, prefix='/v1/chat', tags=["chat1"])
app.mount("/static", StaticFiles(directory="app/images"), name="static")
if __name__ == "__main__":