diff --git a/manager/config/config.go b/manager/config/config.go index 9dc15286a..dc268a4ba 100644 --- a/manager/config/config.go +++ b/manager/config/config.go @@ -57,6 +57,9 @@ type Config struct { // Network configuration. Network NetworkConfig `yaml:"network" mapstructure:"network"` + + // Trainer configuration. + Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"` } type ServerConfig struct { @@ -331,6 +334,14 @@ type NetworkConfig struct { EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"` } +type TrainerConfig struct { + // Enable trainer service. + Enable bool `yaml:"enable" mapstructure:"enable"` + + // BucketName is the object storage bucket name of model. + BucketName string `yaml:"bucketName" mapstructure:"bucketName"` +} + // New config instance. func New() *Config { return &Config{ @@ -404,6 +415,10 @@ func New() *Config { Network: NetworkConfig{ EnableIPv6: DefaultNetworkEnableIPv6, }, + Trainer: TrainerConfig{ + Enable: false, + BucketName: DefaultTrainerBucketName, + }, } } @@ -589,6 +604,11 @@ func (cfg *Config) Validate() error { } } + if cfg.Trainer.Enable { + if cfg.Trainer.BucketName == "" { + return errors.New("trainer requires parameter bucketName") + } + } return nil } diff --git a/manager/config/config_test.go b/manager/config/config_test.go index 6c4d652c7..63eaa7a0f 100644 --- a/manager/config/config_test.go +++ b/manager/config/config_test.go @@ -103,6 +103,11 @@ var ( ValidityPeriod: DefaultCertValidityPeriod, }, } + + mockTrainerConfig = TrainerConfig{ + Enable: true, + BucketName: DefaultTrainerBucketName, + } ) func TestConfig_Load(t *testing.T) { @@ -207,6 +212,10 @@ func TestConfig_Load(t *testing.T) { Network: NetworkConfig{ EnableIPv6: true, }, + Trainer: TrainerConfig{ + Enable: true, + BucketName: "models", + }, } managerConfigYAML := &Config{} @@ -838,6 +847,23 @@ func TestConfig_Validate(t *testing.T) { assert.EqualError(err, "certSpec requires parameter validityPeriod") }, }, + { + name: "trainer requires parameter bucketName", + config: New(), + mock: func(cfg *Config) { + cfg.Auth.JWT = mockJWTConfig + cfg.Database.Type = DatabaseTypeMysql + cfg.Database.Mysql = mockMysqlConfig + cfg.Database.Redis = mockRedisConfig + cfg.Security = mockSecurityConfig + cfg.Trainer = mockTrainerConfig + cfg.Trainer.BucketName = "" + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "trainer requires parameter bucketName") + }, + }, } for _, tc := range tests { diff --git a/manager/config/constants.go b/manager/config/constants.go index 063dd1df2..96955a7db 100644 --- a/manager/config/constants.go +++ b/manager/config/constants.go @@ -123,3 +123,8 @@ var ( // DefaultNetworkEnableIPv6 is default value of enableIPv6. DefaultNetworkEnableIPv6 = false ) + +var ( + // DefaultTrainerBucketName is default object storage bucket name of model. + DefaultTrainerBucketName = "models" +) diff --git a/manager/config/testdata/manager.yaml b/manager/config/testdata/manager.yaml index a08373e41..edcad8042 100644 --- a/manager/config/testdata/manager.yaml +++ b/manager/config/testdata/manager.yaml @@ -88,3 +88,7 @@ metrics: network: enableIPv6: true + +trainer: + enable: true + bucketName: models