| | |
| | | from app.service.basic import BasicService |
| | | from app.service.ragflow import RagflowService |
| | | from app.service.service_token import get_bisheng_token, get_ragflow_token |
| | | from app.service.session import SessionService |
| | | |
# Module-level router; presumably this module's endpoints (e.g. the websocket
# chat handler) register on it via decorators not visible here.
router = APIRouter()
| | | |
| | |
| | | # 接收前端消息 |
| | | message = await websocket.receive_json() |
| | | question = message.get("message") |
| | | SessionService(db).create_session( |
| | | session_id=chat_id, |
| | | name=question, |
| | | agent_id=agent_id, |
| | | agent_type=AgentType.BASIC |
| | | ) |
| | | if not question: |
| | | await websocket.send_json({"message": "Invalid request", "type": "error"}) |
| | | continue |
| | |
| | | from zoneinfo import ZoneInfo |
| | | |
| | | import pytz |
| | | |
| | | from .agent_model import * |
| | | from .dialog_model import * |
| | | from .group_model import * |
| | |
| | | from .organization_model import * |
| | | from .resource_model import * |
| | | from .role_model import * |
| | | from .user_model import * |
| | | from .user_model import * |
| | | |
| | | |
def current_time():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai.

    The zone is fixed to Asia/Shanghai (UTC+8), not the server's local zone.
    It is used as a zero-argument ``default``/``onupdate`` callable for
    SQLAlchemy timestamp columns, so it must keep accepting no positional
    arguments.

    Returns:
        datetime: ``datetime.now()`` with ``tzinfo`` set to Asia/Shanghai.
    """
    # No ``datetime`` import is visible at this module's top; import it here
    # instead of depending on it arriving via the ``from ... import *`` lines.
    from datetime import datetime

    # The stdlib ZoneInfo (already imported by this module) replaces pytz.
    # For ``datetime.now(tz)`` both give the same wall-clock time and offset.
    return datetime.now(ZoneInfo("Asia/Shanghai"))
| | |
| | | from enum import IntEnum |
| | | from sqlalchemy import Column, String, Enum as SQLAlchemyEnum, Integer, DateTime |
| | | |
| | | from app.models import AgentType |
| | | from app.models import AgentType, current_time |
| | | from app.models.base_model import Base |
| | | |
| | | |
| | |
| | | name = Column(String(255)) |
| | | agent_id = Column(String(255)) |
| | | agent_type = Column(SQLAlchemyEnum(AgentType), nullable=False) # 目前只存basic的,ragflow和bisheng的调接口获取 |
| | | create_date = Column(DateTime) # 创建时间 |
| | | update_date = Column(DateTime) # 更新时间 |
| | | |
| | | create_date = Column(DateTime, default=current_time) # 创建时间,默认值为当前时区时间 |
| | | update_date = Column(DateTime, default=current_time, onupdate=current_time) # 更新时间,默认值为当前时区时间,更新时自动更新 |
| | | # to_dict 方法 |
| | | def to_dict(self): |
| | | return { |
New file |
| | |
from typing import Optional

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.models import AgentType
from app.models.session_model import SessionModel
| | | |
| | | |
class SessionService:
    """CRUD helpers for ``SessionModel`` rows, bound to one SQLAlchemy session.

    Every mutating method commits immediately. If a commit fails, the
    transaction is rolled back before the exception propagates, so the same
    ``db`` session remains usable for later calls.
    """

    def __init__(self, db: Session):
        """
        Args:
            db: SQLAlchemy ``Session`` used for all queries and commits.
        """
        self.db = db

    def _commit(self) -> None:
        """Commit the current transaction; roll back and re-raise on failure."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the Session unusable until it is
            # rolled back. Do it here so callers that reuse ``db`` (e.g. across
            # loop iterations) are not poisoned by a single failure.
            self.db.rollback()
            raise

    def create_session(
        self,
        session_id: str,
        name: str,
        agent_id: str,
        agent_type: AgentType,
    ) -> Optional[SessionModel]:
        """Create a new session record.

        Args:
            session_id: Session ID, stored as the row's ``id``.
            name: Session name.
            agent_id: Agent ID.
            agent_type: Agent type.

        Returns:
            The newly created (and refreshed) ``SessionModel``, or ``None`` if
            a session with ``session_id`` already exists.

        Raises:
            IntegrityError: If the insert violates a constraint other than the
                ``id`` already being taken.
        """
        if self.get_session_by_id(session_id) is not None:
            return None  # Session ID already exists: do nothing.

        new_session = SessionModel(
            id=session_id,
            name=name,
            agent_id=agent_id,
            agent_type=agent_type,
        )
        self.db.add(new_session)
        try:
            self._commit()
        except IntegrityError:
            # Another writer may have inserted the same id between the lookup
            # above and this commit. Keep the documented contract (None when
            # the id exists); any other integrity failure is re-raised.
            if self.get_session_by_id(session_id) is not None:
                return None
            raise
        self.db.refresh(new_session)
        return new_session

    def get_session_by_id(self, session_id: str) -> Optional[SessionModel]:
        """Look up a session record by its ID.

        Args:
            session_id: Session ID.

        Returns:
            The matching ``SessionModel``, or ``None`` if not found.
        """
        return self.db.query(SessionModel).filter_by(id=session_id).first()

    def update_session(self, session_id: str, **kwargs) -> Optional[SessionModel]:
        """Update fields of a session record and commit.

        Args:
            session_id: Session ID.
            **kwargs: Attribute names and the values to assign. NOTE: keys are
                not validated against the model's columns.

        Returns:
            The updated (and refreshed) ``SessionModel``, or ``None`` if no
            session with ``session_id`` exists.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            return None
        for key, value in kwargs.items():
            setattr(session, key, value)
        self._commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: str) -> None:
        """Delete a session record if it exists; a missing ID is a no-op.

        Args:
            session_id: Session ID.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            return
        self.db.delete(session)
        self._commit()