feat: add update model rest api (#2530)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2023-07-10 18:19:14 +08:00 committed by GitHub
parent a8f7c56b4a
commit 355d9dd10d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 121 additions and 370 deletions

View File

@ -1116,47 +1116,6 @@ const docTemplate = `{
"description": "Internal Server Error" "description": "Internal Server Error"
} }
} }
},
"post": {
"description": "Create Model by json config",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Model"
],
"summary": "Create Model",
"parameters": [
{
"description": "Model",
"name": "Model",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_models.Model"
}
},
"400": {
"description": "Bad Request"
},
"404": {
"description": "Not Found"
},
"500": {
"description": "Internal Server Error"
}
}
} }
}, },
"/models/{id}": { "/models/{id}": {
@ -3438,6 +3397,9 @@ const docTemplate = `{
"id": { "id": {
"type": "integer" "type": "integer"
}, },
"name": {
"type": "string"
},
"scheduler_id": { "scheduler_id": {
"type": "integer" "type": "integer"
}, },
@ -3890,32 +3852,6 @@ const docTemplate = `{
} }
} }
}, },
"d7y_io_dragonfly_v2_manager_types.CreateModelRequest": {
"type": "object",
"required": [
"evaluation",
"scheduler_id",
"type",
"version"
],
"properties": {
"BIO": {
"type": "string"
},
"evaluation": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation"
},
"scheduler_id": {
"type": "integer"
},
"type": {
"type": "string"
},
"version": {
"type": "string"
}
}
},
"d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": { "d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@ -4205,34 +4141,6 @@ const docTemplate = `{
} }
} }
}, },
"d7y_io_dragonfly_v2_manager_types.ModelEvaluation": {
"type": "object",
"properties": {
"f1_score": {
"type": "number",
"maximum": 1,
"minimum": 0
},
"mae": {
"type": "number",
"minimum": 0
},
"mse": {
"type": "number",
"minimum": 0
},
"precision": {
"type": "number",
"maximum": 1,
"minimum": 0
},
"recall": {
"type": "number",
"maximum": 1,
"minimum": 0
}
}
},
"d7y_io_dragonfly_v2_manager_types.PriorityConfig": { "d7y_io_dragonfly_v2_manager_types.PriorityConfig": {
"type": "object", "type": "object",
"required": [ "required": [
@ -4504,17 +4412,10 @@ const docTemplate = `{
"BIO": { "BIO": {
"type": "string" "type": "string"
}, },
"evaluation": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation"
},
"scheduler_id": {
"type": "integer"
},
"state": { "state": {
"type": "string", "type": "string",
"enum": [ "enum": [
"active", "active"
"inactive"
] ]
} }
} }

View File

@ -1110,47 +1110,6 @@
"description": "Internal Server Error" "description": "Internal Server Error"
} }
} }
},
"post": {
"description": "Create Model by json config",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Model"
],
"summary": "Create Model",
"parameters": [
{
"description": "Model",
"name": "Model",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_models.Model"
}
},
"400": {
"description": "Bad Request"
},
"404": {
"description": "Not Found"
},
"500": {
"description": "Internal Server Error"
}
}
} }
}, },
"/models/{id}": { "/models/{id}": {
@ -3432,6 +3391,9 @@
"id": { "id": {
"type": "integer" "type": "integer"
}, },
"name": {
"type": "string"
},
"scheduler_id": { "scheduler_id": {
"type": "integer" "type": "integer"
}, },
@ -3884,32 +3846,6 @@
} }
} }
}, },
"d7y_io_dragonfly_v2_manager_types.CreateModelRequest": {
"type": "object",
"required": [
"evaluation",
"scheduler_id",
"type",
"version"
],
"properties": {
"BIO": {
"type": "string"
},
"evaluation": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation"
},
"scheduler_id": {
"type": "integer"
},
"type": {
"type": "string"
},
"version": {
"type": "string"
}
}
},
"d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": { "d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@ -4199,34 +4135,6 @@
} }
} }
}, },
"d7y_io_dragonfly_v2_manager_types.ModelEvaluation": {
"type": "object",
"properties": {
"f1_score": {
"type": "number",
"maximum": 1,
"minimum": 0
},
"mae": {
"type": "number",
"minimum": 0
},
"mse": {
"type": "number",
"minimum": 0
},
"precision": {
"type": "number",
"maximum": 1,
"minimum": 0
},
"recall": {
"type": "number",
"maximum": 1,
"minimum": 0
}
}
},
"d7y_io_dragonfly_v2_manager_types.PriorityConfig": { "d7y_io_dragonfly_v2_manager_types.PriorityConfig": {
"type": "object", "type": "object",
"required": [ "required": [
@ -4498,17 +4406,10 @@
"BIO": { "BIO": {
"type": "string" "type": "string"
}, },
"evaluation": {
"$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation"
},
"scheduler_id": {
"type": "integer"
},
"state": { "state": {
"type": "string", "type": "string",
"enum": [ "enum": [
"active", "active"
"inactive"
] ]
} }
} }

View File

@ -82,6 +82,8 @@ definitions:
$ref: '#/definitions/d7y_io_dragonfly_v2_manager_models.JSONMap' $ref: '#/definitions/d7y_io_dragonfly_v2_manager_models.JSONMap'
id: id:
type: integer type: integer
name:
type: string
scheduler_id: scheduler_id:
type: integer type: integer
state: state:
@ -383,24 +385,6 @@ definitions:
required: required:
- type - type
type: object type: object
d7y_io_dragonfly_v2_manager_types.CreateModelRequest:
properties:
BIO:
type: string
evaluation:
$ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation'
scheduler_id:
type: integer
type:
type: string
version:
type: string
required:
- evaluation
- scheduler_id
- type
- version
type: object
d7y_io_dragonfly_v2_manager_types.CreateOauthRequest: d7y_io_dragonfly_v2_manager_types.CreateOauthRequest:
properties: properties:
bio: bio:
@ -597,27 +581,6 @@ definitions:
status: status:
type: string type: string
type: object type: object
d7y_io_dragonfly_v2_manager_types.ModelEvaluation:
properties:
f1_score:
maximum: 1
minimum: 0
type: number
mae:
minimum: 0
type: number
mse:
minimum: 0
type: number
precision:
maximum: 1
minimum: 0
type: number
recall:
maximum: 1
minimum: 0
type: number
type: object
d7y_io_dragonfly_v2_manager_types.PriorityConfig: d7y_io_dragonfly_v2_manager_types.PriorityConfig:
properties: properties:
urls: urls:
@ -803,14 +766,9 @@ definitions:
properties: properties:
BIO: BIO:
type: string type: string
evaluation:
$ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation'
scheduler_id:
type: integer
state: state:
enum: enum:
- active - active
- inactive
type: string type: string
type: object type: object
d7y_io_dragonfly_v2_manager_types.UpdateOauthRequest: d7y_io_dragonfly_v2_manager_types.UpdateOauthRequest:
@ -1662,33 +1620,6 @@ paths:
summary: Get Models summary: Get Models
tags: tags:
- Model - Model
post:
consumes:
- application/json
description: Create Model by json config
parameters:
- description: Model
in: body
name: Model
required: true
schema:
$ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest'
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/d7y_io_dragonfly_v2_manager_models.Model'
"400":
description: Bad Request
"404":
description: Not Found
"500":
description: Internal Server Error
summary: Create Model
tags:
- Model
/models/{id}: /models/{id}:
delete: delete:
consumes: consumes:

View File

@ -9,33 +9,6 @@ import (
"d7y.io/dragonfly/v2/manager/types" "d7y.io/dragonfly/v2/manager/types"
) )
// @Summary Create Model
// @Description Create Model by json config
// @Tags Model
// @Accept json
// @Produce json
// @Param Model body types.CreateModelRequest true "Model"
// @Success 200 {object} models.Model
// @Failure 400
// @Failure 404
// @Failure 500
// @Router /models [post]
func (h *Handlers) CreateModel(ctx *gin.Context) {
var json types.CreateModelRequest
if err := ctx.ShouldBindJSON(&json); err != nil {
ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()})
return
}
model, err := h.service.CreateModel(ctx.Request.Context(), json)
if err != nil {
ctx.Error(err) // nolint: errcheck
return
}
ctx.JSON(http.StatusOK, model)
}
// @Summary Destroy Model // @Summary Destroy Model
// @Description Destroy by id // @Description Destroy by id
// @Tags Model // @Tags Model

View File

@ -147,7 +147,7 @@ func New(cfg *config.Config, d dfpath.Dfpath) (*Server, error) {
} }
// Initialize REST server // Initialize REST server
restService := service.New(db, cache, job, enforcer, objectStorage) restService := service.New(cfg, db, cache, job, enforcer, objectStorage)
router, err := router.Init(cfg, d.LogDir(), restService, enforcer, EmbedFolder(assets, assetsTargetPath)) router, err := router.Init(cfg, d.LogDir(), restService, enforcer, EmbedFolder(assets, assetsTargetPath))
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -35,6 +35,7 @@ const (
// TODO(Gaius) Add regression analysis parameters. // TODO(Gaius) Add regression analysis parameters.
type Model struct { type Model struct {
BaseModel BaseModel
Name string `gorm:"column:name;type:varchar(256);index:uk_model_name,unique;not null;comment:name" json:"name"`
Type string `gorm:"column:type;type:varchar(256);index:uk_model,unique;not null;comment:type" json:"type"` Type string `gorm:"column:type;type:varchar(256);index:uk_model,unique;not null;comment:type" json:"type"`
BIO string `gorm:"column:bio;type:varchar(1024);comment:biography" json:"bio"` BIO string `gorm:"column:bio;type:varchar(1024);comment:biography" json:"bio"`
Version string `gorm:"column:version;type:varchar(256);index:uk_model,unique;not null;comment:model version" json:"version"` Version string `gorm:"column:version;type:varchar(256);index:uk_model,unique;not null;comment:model version" json:"version"`

View File

@ -197,7 +197,6 @@ func Init(cfg *config.Config, logDir string, service service.Service, enforcer *
// Model // Model
model := apiv1.Group("/models", jwt.MiddlewareFunc(), rbac) model := apiv1.Group("/models", jwt.MiddlewareFunc(), rbac)
model.POST("", h.CreateModel)
model.DELETE(":id", h.DestroyModel) model.DELETE(":id", h.DestroyModel)
model.PATCH(":id", h.UpdateModel) model.PATCH(":id", h.UpdateModel)
model.GET(":id", h.GetModel) model.GET(":id", h.GetModel)

View File

@ -835,6 +835,7 @@ func (s *managerServerV1) CreateModel(ctx context.Context, req *managerv1.Create
// Create model in database. // Create model in database.
if err := s.db.WithContext(ctx).Model(&scheduler).Association("Models").Append(&models.Model{ if err := s.db.WithContext(ctx).Model(&scheduler).Association("Models").Append(&models.Model{
Name: name,
Type: typ, Type: typ,
Version: fmt.Sprint(version), Version: fmt.Sprint(version),
State: models.ModelVersionStateInactive, State: models.ModelVersionStateInactive,

View File

@ -170,21 +170,6 @@ func (mr *MockServiceMockRecorder) CreateConfig(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConfig", reflect.TypeOf((*MockService)(nil).CreateConfig), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConfig", reflect.TypeOf((*MockService)(nil).CreateConfig), arg0, arg1)
} }
// CreateModel mocks base method.
func (m *MockService) CreateModel(arg0 context.Context, arg1 types.CreateModelRequest) (*models.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateModel", arg0, arg1)
ret0, _ := ret[0].(*models.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateModel indicates an expected call of CreateModel.
func (mr *MockServiceMockRecorder) CreateModel(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateModel", reflect.TypeOf((*MockService)(nil).CreateModel), arg0, arg1)
}
// CreateOauth mocks base method. // CreateOauth mocks base method.
func (m *MockService) CreateOauth(arg0 context.Context, arg1 types.CreateOauthRequest) (*models.Oauth, error) { func (m *MockService) CreateOauth(arg0 context.Context, arg1 types.CreateOauthRequest) (*models.Oauth, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -18,33 +18,20 @@ package service
import ( import (
"context" "context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"strings"
inferencev1 "d7y.io/api/pkg/apis/inference/v1"
"d7y.io/dragonfly/v2/manager/models" "d7y.io/dragonfly/v2/manager/models"
"d7y.io/dragonfly/v2/manager/types" "d7y.io/dragonfly/v2/manager/types"
"d7y.io/dragonfly/v2/pkg/structure" "d7y.io/dragonfly/v2/pkg/digest"
) )
func (s *service) CreateModel(ctx context.Context, json types.CreateModelRequest) (*models.Model, error) {
evaluation, err := structure.StructToMap(json.Evaluation)
if err != nil {
return nil, err
}
model := models.Model{
Type: json.Type,
BIO: json.BIO,
Version: json.Version,
Evaluation: evaluation,
SchedulerID: json.SchedulerID,
}
if err := s.db.WithContext(ctx).Create(&model).Error; err != nil {
return nil, err
}
return &model, nil
}
func (s *service) DestroyModel(ctx context.Context, id uint) error { func (s *service) DestroyModel(ctx context.Context, id uint) error {
model := models.Model{} model := models.Model{}
if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil {
@ -59,23 +46,22 @@ func (s *service) DestroyModel(ctx context.Context, id uint) error {
} }
func (s *service) UpdateModel(ctx context.Context, id uint, json types.UpdateModelRequest) (*models.Model, error) { func (s *service) UpdateModel(ctx context.Context, id uint, json types.UpdateModelRequest) (*models.Model, error) {
var ( model := models.Model{}
evaluation map[string]any if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil {
err error return nil, err
) }
if json.Evaluation != nil {
evaluation, err = structure.StructToMap(json.Evaluation) // If the model is active, update the model config and
if err != nil { // update the model state.
if json.State == models.ModelVersionStateActive {
if err := s.updateModelStateToActive(ctx, &model); err != nil {
return nil, err return nil, err
} }
} }
model := models.Model{} // Update the model.
if err := s.db.WithContext(ctx).First(&model, id).Updates(models.Model{ if err := s.db.WithContext(ctx).Model(&model).Updates(models.Model{
BIO: json.BIO, BIO: json.BIO,
State: json.State,
Evaluation: evaluation,
SchedulerID: json.SchedulerID,
}).Error; err != nil { }).Error; err != nil {
return nil, err return nil, err
} }
@ -105,3 +91,83 @@ func (s *service) GetModels(ctx context.Context, q types.GetModelsQuery) ([]mode
return model, count, nil return model, count, nil
} }
func (s *service) updateModelStateToActive(ctx context.Context, model *models.Model) error {
version, err := strconv.ParseInt(model.Version, 10, 64)
if err != nil {
return err
}
// Update the model config to object storage.
if err := s.updateModelConfig(ctx, model.Name, version); err != nil {
return err
}
// Create a transaction to ensure that only one
// version is active at a time.
tx := s.db.WithContext(ctx).Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if err := tx.Error; err != nil {
return err
}
if err := tx.Model(&models.Model{}).Where("state = ?", models.ModelVersionStateActive).Update("state", models.ModelVersionStateInactive).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Model(model).Update("state", models.ModelVersionStateActive).Error; err != nil {
tx.Rollback()
return err
}
if tx.Commit().Error != nil {
return err
}
return nil
}
func (s *service) updateModelConfig(ctx context.Context, name string, version int64) error {
if !s.config.ObjectStorage.Enable {
return errors.New("object storage is disabled")
}
objectKey := types.MakeObjectKeyOfModelConfigFile(name)
var pbModelConfig inferencev1.ModelConfig
reader, err := s.objectStorage.GetOject(ctx, s.config.Trainer.BucketName, objectKey)
if err != nil {
return err
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return err
}
if err := json.Unmarshal(data, &pbModelConfig); err != nil {
return err
}
switch policyChoice := pbModelConfig.VersionPolicy.PolicyChoice.(type) {
case *inferencev1.ModelVersionPolicy_Specific_:
// If the version already exists, add the version to the existing version list.
policyChoice.Specific.Versions = []int64{version}
default:
return fmt.Errorf("unknown policy choice: %#v", policyChoice)
}
dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromStrings(pbModelConfig.String()))
if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName,
types.MakeObjectKeyOfModelConfigFile(name), dgst.String(), strings.NewReader(pbModelConfig.String())); err != nil {
return err
}
return nil
}

View File

@ -27,6 +27,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"d7y.io/dragonfly/v2/manager/cache" "d7y.io/dragonfly/v2/manager/cache"
"d7y.io/dragonfly/v2/manager/config"
"d7y.io/dragonfly/v2/manager/database" "d7y.io/dragonfly/v2/manager/database"
"d7y.io/dragonfly/v2/manager/job" "d7y.io/dragonfly/v2/manager/job"
"d7y.io/dragonfly/v2/manager/models" "d7y.io/dragonfly/v2/manager/models"
@ -124,7 +125,6 @@ type Service interface {
GetApplication(context.Context, uint) (*models.Application, error) GetApplication(context.Context, uint) (*models.Application, error)
GetApplications(context.Context, types.GetApplicationsQuery) ([]models.Application, int64, error) GetApplications(context.Context, types.GetApplicationsQuery) ([]models.Application, int64, error)
CreateModel(context.Context, types.CreateModelRequest) (*models.Model, error)
DestroyModel(context.Context, uint) error DestroyModel(context.Context, uint) error
UpdateModel(context.Context, uint, types.UpdateModelRequest) (*models.Model, error) UpdateModel(context.Context, uint, types.UpdateModelRequest) (*models.Model, error)
GetModel(context.Context, uint) (*models.Model, error) GetModel(context.Context, uint) (*models.Model, error)
@ -132,6 +132,7 @@ type Service interface {
} }
type service struct { type service struct {
config *config.Config
db *gorm.DB db *gorm.DB
rdb redis.UniversalClient rdb redis.UniversalClient
cache *cache.Cache cache *cache.Cache
@ -141,8 +142,9 @@ type service struct {
} }
// NewREST returns a new REST instence // NewREST returns a new REST instence
func New(database *database.Database, cache *cache.Cache, job *job.Job, enforcer *casbin.Enforcer, objectStorage objectstorage.ObjectStorage) Service { func New(cfg *config.Config, database *database.Database, cache *cache.Cache, job *job.Job, enforcer *casbin.Enforcer, objectStorage objectstorage.ObjectStorage) Service {
return &service{ return &service{
config: cfg,
db: database.DB, db: database.DB,
rdb: database.RDB, rdb: database.RDB,
cache: cache, cache: cache,

View File

@ -41,22 +41,13 @@ type ModelParams struct {
ID uint `uri:"id" binding:"required"` ID uint `uri:"id" binding:"required"`
} }
type CreateModelRequest struct {
Type string `json:"type" binding:"required"`
BIO string `json:"BIO" binding:"omitempty"`
Version string `json:"version" binding:"required"`
Evaluation *ModelEvaluation `json:"evaluation" binding:"required"`
SchedulerID uint `json:"scheduler_id" binding:"required"`
}
type UpdateModelRequest struct { type UpdateModelRequest struct {
BIO string `json:"BIO" binding:"omitempty"` BIO string `json:"BIO" binding:"omitempty"`
State string `json:"state" binding:"omitempty,oneof=active inactive"` State string `json:"state" binding:"omitempty,oneof=active"`
Evaluation *ModelEvaluation `json:"evaluation" binding:"omitempty"`
SchedulerID uint `json:"scheduler_id" binding:"omitempty"`
} }
type GetModelsQuery struct { type GetModelsQuery struct {
Name string `json:"name" binding:"omitempty"`
Type string `json:"type" binding:"omitempty"` Type string `json:"type" binding:"omitempty"`
Version string `json:"version" binding:"omitempty"` Version string `json:"version" binding:"omitempty"`
SchedulerID uint `json:"scheduler_id" binding:"omitempty"` SchedulerID uint `json:"scheduler_id" binding:"omitempty"`