feat: manager adds trainer config (#2494)

Signed-off-by: huangmin <2107139596@qq.com>
This commit is contained in:
Min 2023-06-29 18:51:24 +08:00 committed by GitHub
parent b58879a0c3
commit 07366174b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 0 deletions

View File

@ -57,6 +57,9 @@ type Config struct {
// Network configuration. // Network configuration.
Network NetworkConfig `yaml:"network" mapstructure:"network"` Network NetworkConfig `yaml:"network" mapstructure:"network"`
// Trainer configuration.
Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"`
} }
type ServerConfig struct { type ServerConfig struct {
@ -331,6 +334,14 @@ type NetworkConfig struct {
EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"` EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"`
} }
type TrainerConfig struct {
// Enable trainer service.
Enable bool `yaml:"enable" mapstructure:"enable"`
// BucketName is the object storage bucket name of model.
BucketName string `yaml:"bucketName" mapstructure:"bucketName"`
}
// New config instance. // New config instance.
func New() *Config { func New() *Config {
return &Config{ return &Config{
@ -404,6 +415,10 @@ func New() *Config {
Network: NetworkConfig{ Network: NetworkConfig{
EnableIPv6: DefaultNetworkEnableIPv6, EnableIPv6: DefaultNetworkEnableIPv6,
}, },
Trainer: TrainerConfig{
Enable: false,
BucketName: DefaultTrainerBucketName,
},
} }
} }
@ -589,6 +604,11 @@ func (cfg *Config) Validate() error {
} }
} }
if cfg.Trainer.Enable {
if cfg.Trainer.BucketName == "" {
return errors.New("trainer requires parameter bucketName")
}
}
return nil return nil
} }

View File

@ -103,6 +103,11 @@ var (
ValidityPeriod: DefaultCertValidityPeriod, ValidityPeriod: DefaultCertValidityPeriod,
}, },
} }
mockTrainerConfig = TrainerConfig{
Enable: true,
BucketName: DefaultTrainerBucketName,
}
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
@ -207,6 +212,10 @@ func TestConfig_Load(t *testing.T) {
Network: NetworkConfig{ Network: NetworkConfig{
EnableIPv6: true, EnableIPv6: true,
}, },
Trainer: TrainerConfig{
Enable: true,
BucketName: "models",
},
} }
managerConfigYAML := &Config{} managerConfigYAML := &Config{}
@ -838,6 +847,23 @@ func TestConfig_Validate(t *testing.T) {
assert.EqualError(err, "certSpec requires parameter validityPeriod") assert.EqualError(err, "certSpec requires parameter validityPeriod")
}, },
}, },
{
name: "trainer requires parameter bucketName",
config: New(),
mock: func(cfg *Config) {
cfg.Auth.JWT = mockJWTConfig
cfg.Database.Type = DatabaseTypeMysql
cfg.Database.Mysql = mockMysqlConfig
cfg.Database.Redis = mockRedisConfig
cfg.Security = mockSecurityConfig
cfg.Trainer = mockTrainerConfig
cfg.Trainer.BucketName = ""
},
expect: func(t *testing.T, err error) {
assert := assert.New(t)
assert.EqualError(err, "trainer requires parameter bucketName")
},
},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@ -123,3 +123,8 @@ var (
// DefaultNetworkEnableIPv6 is default value of enableIPv6. // DefaultNetworkEnableIPv6 is default value of enableIPv6.
DefaultNetworkEnableIPv6 = false DefaultNetworkEnableIPv6 = false
) )
var (
// DefaultTrainerBucketName is default object storage bucket name of model.
DefaultTrainerBucketName = "models"
)

View File

@ -88,3 +88,7 @@ metrics:
network: network:
enableIPv6: true enableIPv6: true
trainer:
enable: true
bucketName: models