zhangqian
2024-12-12 33a468c0bceff7841abe168a6bc825d6ccf96a6f
查询模型任务时只查询启用的模型
1个文件已添加
2个文件已修改
211 ■■■■■ 已修改文件
db/model.go 186 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
db/task.go 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
service/task.go 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
db/model.go
New file
@@ -0,0 +1,186 @@
package db
import (
    "fmt"
    "gorm.io/gorm"
    "model-engine/pkg/mysqlx"
)
type Model struct {
    BaseModel
    Name        string `json:"name" gorm:"type:varchar(255)"`                   //模型名称
    Description string `json:"description,omitempty" gorm:"type:varchar(1000)"` //模型描述
    Version     string `json:"version" gorm:"type:varchar(255)"`                //版本号
    Enabled     bool   `json:"enabled"`                                         //是否开启
}
func (m *Model) TableName() string {
    return "model"
}
type ModelSearch struct {
    Model
    Orm      *gorm.DB
    PageNum  int
    PageSize int
    Order    string
    Keyword  string
}
func NewModelSearch() *ModelSearch {
    return &ModelSearch{
        Orm:      mysqlx.GetDB(),
        PageNum:  1,
        PageSize: 10,
    }
}
func (slf *ModelSearch) SetOrm(tx *gorm.DB) *ModelSearch {
    slf.Orm = tx
    return slf
}
func (slf *ModelSearch) SetPage(page, size int) *ModelSearch {
    slf.PageNum, slf.PageSize = page, size
    return slf
}
func (slf *ModelSearch) SetOrder(order string) *ModelSearch {
    slf.Order = order
    return slf
}
func (slf *ModelSearch) SetID(id string) *ModelSearch {
    slf.ID = id
    return slf
}
func (slf *ModelSearch) SetKeyword(kw string) *ModelSearch {
    slf.Keyword = kw
    return slf
}
func (slf *ModelSearch) SetEnabled(enabled bool) *ModelSearch {
    slf.Enabled = enabled
    return slf
}
func (slf *ModelSearch) build() *gorm.DB {
    var db = slf.Orm.Table(slf.TableName())
    if slf.Order != "" {
        db = db.Order(slf.Order)
    }
    if slf.ID != "" {
        db = db.Where("id = ?", slf.ID)
    }
    if slf.Keyword != "" {
        kw := "%" + slf.Keyword + "%"
        db = db.Where("name like ?", kw)
    }
    if slf.Enabled {
        db = db.Where("enabled = ?", slf.Enabled)
    }
    return db
}
func (slf *ModelSearch) First() (*Model, error) {
    var (
        record = new(Model)
        db     = slf.build()
    )
    if err := db.First(record).Error; err != nil {
        return record, err
    }
    return record, nil
}
func (slf *ModelSearch) Find() ([]*Model, int64, error) {
    var (
        records = make([]*Model, 0)
        total   int64
        db      = slf.build()
    )
    if err := db.Count(&total).Error; err != nil {
        return records, total, fmt.Errorf("find count err: %v", err)
    }
    if slf.PageNum*slf.PageSize > 0 {
        db = db.Offset((slf.PageNum - 1) * slf.PageSize).Limit(slf.PageSize)
    }
    if err := db.Find(&records).Error; err != nil {
        return records, total, fmt.Errorf("find records err: %v", err)
    }
    return records, total, nil
}
func (slf *ModelSearch) FindAll() ([]*Model, error) {
    var (
        records = make([]*Model, 0)
        db      = slf.build()
    )
    if err := db.Find(&records).Error; err != nil {
        return records, fmt.Errorf("find records err: %v", err)
    }
    return records, nil
}
func (slf *ModelSearch) Count() int64 {
    var (
        count int64
        db    = slf.build()
    )
    if err := db.Count(&count).Error; err != nil {
        return count
    }
    return count
}
func (slf *ModelSearch) Create(record *Model) error {
    var db = slf.build()
    if err := db.Create(record).Error; err != nil {
        return fmt.Errorf("create err: %v, record: %+v", err, record)
    }
    return nil
}
func (slf *ModelSearch) Save(record *Model) error {
    var db = slf.build()
    if err := db.Omit("CreatedAt").Save(record).Error; err != nil {
        return fmt.Errorf("save err: %v, record: %+v", err, record)
    }
    return nil
}
func (slf *ModelSearch) Update(record *Model) error {
    var db = slf.build()
    if err := db.Updates(record).Error; err != nil {
        return fmt.Errorf("update err: %v, record: %+v", err, record)
    }
    return nil
}
func (slf *ModelSearch) Delete() error {
    var db = slf.build()
    return db.Delete(&Model{}).Error
}
const (
    ModelIdDrug   = "drug"   //涉毒
    ModelIdGather = "gather" //聚集
)
db/task.go
@@ -32,6 +32,7 @@
        PageNum  int
        PageSize int
        Keyword  string
        ModelIDs []string
    }
)
@@ -102,6 +103,11 @@
    return slf
}
func (slf *ModelTaskSearch) SetModelIDs(ids []string) *ModelTaskSearch {
    slf.ModelIDs = ids
    return slf
}
func (slf *ModelTaskSearch) SetKeyword(kw string) *ModelTaskSearch {
    slf.Keyword = kw
    return slf
@@ -123,6 +129,10 @@
        db = db.Where("model_id = ?", slf.ModelID)
    }
    if len(slf.ModelIDs) != 0 {
        db = db.Where("model_id in ?", slf.ModelIDs)
    }
    if slf.Keyword != "" {
        kw := "%" + slf.Keyword + "%"
        db = db.Where("name like ?", kw)
service/task.go
@@ -3,5 +3,18 @@
import "model-engine/db"
func GetTasks() (tasks []*db.ModelTask, err error) {
    return db.NewModelTaskSearch().FindAll()
    models, err := db.NewModelSearch().SetEnabled(true).FindAll()
    if err != nil {
        return nil, err
    }
    if len(models) == 0 {
        return nil, nil
    }
    modelIds := make([]string, 0, len(models))
    for _, model := range models {
        modelIds = append(modelIds, model.ID)
    }
    return db.NewModelTaskSearch().SetModelIDs(modelIds).FindAll()
}