From 355d9dd10db6bd14f71d93ceb36713f52b8e289b Mon Sep 17 00:00:00 2001 From: Gaius Date: Mon, 10 Jul 2023 18:19:14 +0800 Subject: [PATCH] feat: add update model rest api (#2530) Signed-off-by: Gaius --- api/manager/docs.go | 107 +------------------ api/manager/swagger.json | 107 +------------------ api/manager/swagger.yaml | 73 +------------ manager/handlers/model.go | 27 ----- manager/manager.go | 2 +- manager/models/model.go | 1 + manager/router/router.go | 1 - manager/rpcserver/manager_server_v1.go | 1 + manager/service/mocks/service_mock.go | 15 --- manager/service/model.go | 136 ++++++++++++++++++------- manager/service/service.go | 6 +- manager/types/model.go | 15 +-- 12 files changed, 121 insertions(+), 370 deletions(-) diff --git a/api/manager/docs.go b/api/manager/docs.go index d31f75197..9415f44e9 100644 --- a/api/manager/docs.go +++ b/api/manager/docs.go @@ -1116,47 +1116,6 @@ const docTemplate = `{ "description": "Internal Server Error" } } - }, - "post": { - "description": "Create Model by json config", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Model" - ], - "summary": "Create Model", - "parameters": [ - { - "description": "Model", - "name": "Model", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_models.Model" - } - }, - "400": { - "description": "Bad Request" - }, - "404": { - "description": "Not Found" - }, - "500": { - "description": "Internal Server Error" - } - } } }, "/models/{id}": { @@ -3438,6 +3397,9 @@ const docTemplate = `{ "id": { "type": "integer" }, + "name": { + "type": "string" + }, "scheduler_id": { "type": "integer" }, @@ -3890,32 +3852,6 @@ const docTemplate = `{ } } }, - "d7y_io_dragonfly_v2_manager_types.CreateModelRequest": { - "type": "object", - "required": [ - "evaluation", - "scheduler_id", - "type", - "version" - ], - "properties": { - "BIO": { - "type": "string" - }, - "evaluation": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation" - }, - "scheduler_id": { - "type": "integer" - }, - "type": { - "type": "string" - }, - "version": { - "type": "string" - } - } - }, "d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": { "type": "object", "required": [ @@ -4205,34 +4141,6 @@ const docTemplate = `{ } } }, - "d7y_io_dragonfly_v2_manager_types.ModelEvaluation": { - "type": "object", - "properties": { - "f1_score": { - "type": "number", - "maximum": 1, - "minimum": 0 - }, - "mae": { - "type": "number", - "minimum": 0 - }, - "mse": { - "type": "number", - "minimum": 0 - }, - "precision": { - "type": "number", - "maximum": 1, - "minimum": 0 - }, - "recall": { - "type": "number", - "maximum": 1, - "minimum": 0 - } - } - }, "d7y_io_dragonfly_v2_manager_types.PriorityConfig": { "type": "object", "required": [ @@ -4504,17 +4412,10 @@ const docTemplate = `{ "BIO": { "type": "string" }, - "evaluation": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation" - }, - "scheduler_id": { - "type": "integer" - }, "state": { "type": "string", "enum": [ - "active", - "inactive" + "active" ] } } diff --git a/api/manager/swagger.json b/api/manager/swagger.json index 12b5cc45b..678138f75 100644 --- a/api/manager/swagger.json +++ b/api/manager/swagger.json @@ -1110,47 +1110,6 @@ "description": "Internal Server Error" } } - }, - "post": { - "description": "Create Model by json config", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Model" - ], - "summary": "Create Model", - "parameters": [ - { - "description": "Model", - "name": "Model", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_models.Model" - } - }, - "400": { - "description": "Bad Request" - }, - "404": { - "description": "Not Found" - }, - "500": { - "description": "Internal Server Error" - } - } } }, "/models/{id}": { @@ -3432,6 +3391,9 @@ "id": { "type": "integer" }, + "name": { + "type": "string" + }, "scheduler_id": { "type": "integer" }, @@ -3884,32 +3846,6 @@ } } }, - "d7y_io_dragonfly_v2_manager_types.CreateModelRequest": { - "type": "object", - "required": [ - "evaluation", - "scheduler_id", - "type", - "version" - ], - "properties": { - "BIO": { - "type": "string" - }, - "evaluation": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation" - }, - "scheduler_id": { - "type": "integer" - }, - "type": { - "type": "string" - }, - "version": { - "type": "string" - } - } - }, "d7y_io_dragonfly_v2_manager_types.CreateOauthRequest": { "type": "object", "required": [ @@ -4199,34 +4135,6 @@ } } }, - "d7y_io_dragonfly_v2_manager_types.ModelEvaluation": { - "type": "object", - "properties": { - "f1_score": { - "type": "number", - "maximum": 1, - "minimum": 0 - }, - "mae": { - "type": "number", - "minimum": 0 - }, - "mse": { - "type": "number", - "minimum": 0 - }, - "precision": { - "type": "number", - "maximum": 1, - "minimum": 0 - }, - "recall": { - "type": "number", - "maximum": 1, - "minimum": 0 - } - } - }, "d7y_io_dragonfly_v2_manager_types.PriorityConfig": { "type": "object", "required": [ @@ -4498,17 +4406,10 @@ "BIO": { "type": "string" }, - "evaluation": { - "$ref": "#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation" - }, - "scheduler_id": { - "type": "integer" - }, "state": { "type": "string", "enum": [ - "active", - "inactive" + "active" ] } } diff --git a/api/manager/swagger.yaml b/api/manager/swagger.yaml index 7be88fc02..a1aa4865c 100644 --- a/api/manager/swagger.yaml +++ b/api/manager/swagger.yaml @@ -82,6 +82,8 @@ definitions: $ref: '#/definitions/d7y_io_dragonfly_v2_manager_models.JSONMap' id: type: integer + name: + type: string scheduler_id: type: integer state: @@ -383,24 +385,6 @@ definitions: required: - type type: object - d7y_io_dragonfly_v2_manager_types.CreateModelRequest: - properties: - BIO: - type: string - evaluation: - $ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation' - scheduler_id: - type: integer - type: - type: string - version: - type: string - required: - - evaluation - - scheduler_id - - type - - version - type: object d7y_io_dragonfly_v2_manager_types.CreateOauthRequest: properties: bio: @@ -597,27 +581,6 @@ definitions: status: type: string type: object - d7y_io_dragonfly_v2_manager_types.ModelEvaluation: - properties: - f1_score: - maximum: 1 - minimum: 0 - type: number - mae: - minimum: 0 - type: number - mse: - minimum: 0 - type: number - precision: - maximum: 1 - minimum: 0 - type: number - recall: - maximum: 1 - minimum: 0 - type: number - type: object d7y_io_dragonfly_v2_manager_types.PriorityConfig: properties: urls: @@ -803,14 +766,9 @@ definitions: properties: BIO: type: string - evaluation: - $ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.ModelEvaluation' - scheduler_id: - type: integer state: enum: - active - - inactive type: string type: object d7y_io_dragonfly_v2_manager_types.UpdateOauthRequest: @@ -1662,33 +1620,6 @@ paths: summary: Get Models tags: - Model - post: - consumes: - - application/json - description: Create Model by json config - parameters: - - description: Model - in: body - name: Model - required: true - schema: - $ref: '#/definitions/d7y_io_dragonfly_v2_manager_types.CreateModelRequest' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/d7y_io_dragonfly_v2_manager_models.Model' - "400": - description: Bad Request - "404": - description: Not Found - "500": - description: Internal Server Error - summary: Create Model - tags: - - Model /models/{id}: delete: consumes: diff --git a/manager/handlers/model.go b/manager/handlers/model.go index 4c3b51859..c80d28f2e 100644 --- a/manager/handlers/model.go +++ b/manager/handlers/model.go @@ -9,33 +9,6 @@ import ( "d7y.io/dragonfly/v2/manager/types" ) -// @Summary Create Model -// @Description Create Model by json config -// @Tags Model -// @Accept json -// @Produce json -// @Param Model body types.CreateModelRequest true "Model" -// @Success 200 {object} models.Model -// @Failure 400 -// @Failure 404 -// @Failure 500 -// @Router /models [post] -func (h *Handlers) CreateModel(ctx *gin.Context) { - var json types.CreateModelRequest - if err := ctx.ShouldBindJSON(&json); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - model, err := h.service.CreateModel(ctx.Request.Context(), json) - if err != nil { - ctx.Error(err) // nolint: errcheck - return - } - - ctx.JSON(http.StatusOK, model) -} - // @Summary Destroy Model // @Description Destroy by id // @Tags Model diff --git a/manager/manager.go b/manager/manager.go index 7e50b5539..7deabf074 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -147,7 +147,7 @@ func New(cfg *config.Config, d dfpath.Dfpath) (*Server, error) { } // Initialize REST server - restService := service.New(db, cache, job, enforcer, objectStorage) + restService := service.New(cfg, db, cache, job, enforcer, objectStorage) router, err := router.Init(cfg, d.LogDir(), restService, enforcer, EmbedFolder(assets, assetsTargetPath)) if err != nil { return nil, err diff --git a/manager/models/model.go b/manager/models/model.go index d8afce2ec..1a0d48fa8 100644 --- a/manager/models/model.go +++ b/manager/models/model.go @@ -35,6 +35,7 @@ const ( // TODO(Gaius) Add regression analysis parameters. type Model struct { BaseModel + Name string `gorm:"column:name;type:varchar(256);index:uk_model_name,unique;not null;comment:name" json:"name"` Type string `gorm:"column:type;type:varchar(256);index:uk_model,unique;not null;comment:type" json:"type"` BIO string `gorm:"column:bio;type:varchar(1024);comment:biography" json:"bio"` Version string `gorm:"column:version;type:varchar(256);index:uk_model,unique;not null;comment:model version" json:"version"` diff --git a/manager/router/router.go b/manager/router/router.go index 7f64dcd53..5e42ebe04 100644 --- a/manager/router/router.go +++ b/manager/router/router.go @@ -197,7 +197,6 @@ func Init(cfg *config.Config, logDir string, service service.Service, enforcer * // Model model := apiv1.Group("/models", jwt.MiddlewareFunc(), rbac) - model.POST("", h.CreateModel) model.DELETE(":id", h.DestroyModel) model.PATCH(":id", h.UpdateModel) model.GET(":id", h.GetModel) diff --git a/manager/rpcserver/manager_server_v1.go b/manager/rpcserver/manager_server_v1.go index 21536f476..283828f9e 100644 --- a/manager/rpcserver/manager_server_v1.go +++ b/manager/rpcserver/manager_server_v1.go @@ -835,6 +835,7 @@ func (s *managerServerV1) CreateModel(ctx context.Context, req *managerv1.Create // Create model in database. if err := s.db.WithContext(ctx).Model(&scheduler).Association("Models").Append(&models.Model{ + Name: name, Type: typ, Version: fmt.Sprint(version), State: models.ModelVersionStateInactive, diff --git a/manager/service/mocks/service_mock.go b/manager/service/mocks/service_mock.go index db0e6770a..b5b4733c0 100644 --- a/manager/service/mocks/service_mock.go +++ b/manager/service/mocks/service_mock.go @@ -170,21 +170,6 @@ func (mr *MockServiceMockRecorder) CreateConfig(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateConfig", reflect.TypeOf((*MockService)(nil).CreateConfig), arg0, arg1) } -// CreateModel mocks base method. -func (m *MockService) CreateModel(arg0 context.Context, arg1 types.CreateModelRequest) (*models.Model, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateModel", arg0, arg1) - ret0, _ := ret[0].(*models.Model) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateModel indicates an expected call of CreateModel. -func (mr *MockServiceMockRecorder) CreateModel(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateModel", reflect.TypeOf((*MockService)(nil).CreateModel), arg0, arg1) -} - // CreateOauth mocks base method. func (m *MockService) CreateOauth(arg0 context.Context, arg1 types.CreateOauthRequest) (*models.Oauth, error) { m.ctrl.T.Helper() diff --git a/manager/service/model.go b/manager/service/model.go index 1a4f641df..aed592cc2 100644 --- a/manager/service/model.go +++ b/manager/service/model.go @@ -18,33 +18,20 @@ package service import ( "context" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "strings" + + inferencev1 "d7y.io/api/pkg/apis/inference/v1" "d7y.io/dragonfly/v2/manager/models" "d7y.io/dragonfly/v2/manager/types" - "d7y.io/dragonfly/v2/pkg/structure" + "d7y.io/dragonfly/v2/pkg/digest" ) -func (s *service) CreateModel(ctx context.Context, json types.CreateModelRequest) (*models.Model, error) { - evaluation, err := structure.StructToMap(json.Evaluation) - if err != nil { - return nil, err - } - - model := models.Model{ - Type: json.Type, - BIO: json.BIO, - Version: json.Version, - Evaluation: evaluation, - SchedulerID: json.SchedulerID, - } - - if err := s.db.WithContext(ctx).Create(&model).Error; err != nil { - return nil, err - } - - return &model, nil -} - func (s *service) DestroyModel(ctx context.Context, id uint) error { model := models.Model{} if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { @@ -59,23 +46,22 @@ func (s *service) DestroyModel(ctx context.Context, id uint) error { } func (s *service) UpdateModel(ctx context.Context, id uint, json types.UpdateModelRequest) (*models.Model, error) { - var ( - evaluation map[string]any - err error - ) - if json.Evaluation != nil { - evaluation, err = structure.StructToMap(json.Evaluation) - if err != nil { + model := models.Model{} + if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { + return nil, err + } + + // If the model is active, update the model config and + // update the model state. + if json.State == models.ModelVersionStateActive { + if err := s.updateModelStateToActive(ctx, &model); err != nil { return nil, err } } - model := models.Model{} - if err := s.db.WithContext(ctx).First(&model, id).Updates(models.Model{ - BIO: json.BIO, - State: json.State, - Evaluation: evaluation, - SchedulerID: json.SchedulerID, + // Update the model. + if err := s.db.WithContext(ctx).Model(&model).Updates(models.Model{ + BIO: json.BIO, }).Error; err != nil { return nil, err } @@ -105,3 +91,83 @@ func (s *service) GetModels(ctx context.Context, q types.GetModelsQuery) ([]mode return model, count, nil } + +func (s *service) updateModelStateToActive(ctx context.Context, model *models.Model) error { + version, err := strconv.ParseInt(model.Version, 10, 64) + if err != nil { + return err + } + + // Update the model config to object storage. + if err := s.updateModelConfig(ctx, model.Name, version); err != nil { + return err + } + + // Create a transaction to ensure that only one + // version is active at a time. + tx := s.db.WithContext(ctx).Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + if err := tx.Error; err != nil { + return err + } + + if err := tx.Model(&models.Model{}).Where("state = ?", models.ModelVersionStateActive).Update("state", models.ModelVersionStateInactive).Error; err != nil { + tx.Rollback() + return err + } + + if err := tx.Model(model).Update("state", models.ModelVersionStateActive).Error; err != nil { + tx.Rollback() + return err + } + + if tx.Commit().Error != nil { + return err + } + + return nil +} + +func (s *service) updateModelConfig(ctx context.Context, name string, version int64) error { + if !s.config.ObjectStorage.Enable { + return errors.New("object storage is disabled") + } + + objectKey := types.MakeObjectKeyOfModelConfigFile(name) + var pbModelConfig inferencev1.ModelConfig + reader, err := s.objectStorage.GetOject(ctx, s.config.Trainer.BucketName, objectKey) + if err != nil { + return err + } + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + return err + } + + if err := json.Unmarshal(data, &pbModelConfig); err != nil { + return err + } + + switch policyChoice := pbModelConfig.VersionPolicy.PolicyChoice.(type) { + case *inferencev1.ModelVersionPolicy_Specific_: + // If the version already exists, add the version to the existing version list. + policyChoice.Specific.Versions = []int64{version} + default: + return fmt.Errorf("unknown policy choice: %#v", policyChoice) + } + + dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromStrings(pbModelConfig.String())) + if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName, + types.MakeObjectKeyOfModelConfigFile(name), dgst.String(), strings.NewReader(pbModelConfig.String())); err != nil { + return err + } + + return nil +} diff --git a/manager/service/service.go b/manager/service/service.go index 2667c2690..49638dc9e 100644 --- a/manager/service/service.go +++ b/manager/service/service.go @@ -27,6 +27,7 @@ import ( "gorm.io/gorm" "d7y.io/dragonfly/v2/manager/cache" + "d7y.io/dragonfly/v2/manager/config" "d7y.io/dragonfly/v2/manager/database" "d7y.io/dragonfly/v2/manager/job" "d7y.io/dragonfly/v2/manager/models" @@ -124,7 +125,6 @@ type Service interface { GetApplication(context.Context, uint) (*models.Application, error) GetApplications(context.Context, types.GetApplicationsQuery) ([]models.Application, int64, error) - CreateModel(context.Context, types.CreateModelRequest) (*models.Model, error) DestroyModel(context.Context, uint) error UpdateModel(context.Context, uint, types.UpdateModelRequest) (*models.Model, error) GetModel(context.Context, uint) (*models.Model, error) @@ -132,6 +132,7 @@ type Service interface { } type service struct { + config *config.Config db *gorm.DB rdb redis.UniversalClient cache *cache.Cache @@ -141,8 +142,9 @@ type service struct { } // NewREST returns a new REST instence -func New(database *database.Database, cache *cache.Cache, job *job.Job, enforcer *casbin.Enforcer, objectStorage objectstorage.ObjectStorage) Service { +func New(cfg *config.Config, database *database.Database, cache *cache.Cache, job *job.Job, enforcer *casbin.Enforcer, objectStorage objectstorage.ObjectStorage) Service { return &service{ + config: cfg, db: database.DB, rdb: database.RDB, cache: cache, diff --git a/manager/types/model.go b/manager/types/model.go index 3d6af9dd6..7842e6d95 100644 --- a/manager/types/model.go +++ b/manager/types/model.go @@ -41,22 +41,13 @@ type ModelParams struct { ID uint `uri:"id" binding:"required"` } -type CreateModelRequest struct { - Type string `json:"type" binding:"required"` - BIO string `json:"BIO" binding:"omitempty"` - Version string `json:"version" binding:"required"` - Evaluation *ModelEvaluation `json:"evaluation" binding:"required"` - SchedulerID uint `json:"scheduler_id" binding:"required"` -} - type UpdateModelRequest struct { - BIO string `json:"BIO" binding:"omitempty"` - State string `json:"state" binding:"omitempty,oneof=active inactive"` - Evaluation *ModelEvaluation `json:"evaluation" binding:"omitempty"` - SchedulerID uint `json:"scheduler_id" binding:"omitempty"` + BIO string `json:"BIO" binding:"omitempty"` + State string `json:"state" binding:"omitempty,oneof=active"` } type GetModelsQuery struct { + Name string `json:"name" binding:"omitempty"` Type string `json:"type" binding:"omitempty"` Version string `json:"version" binding:"omitempty"` SchedulerID uint `json:"scheduler_id" binding:"omitempty"`