feat: manager adds trainer config (#2494)
Signed-off-by: huangmin <2107139596@qq.com>
This commit is contained in:
parent
b58879a0c3
commit
07366174b5
|
|
@ -57,6 +57,9 @@ type Config struct {
|
|||
|
||||
// Network configuration.
|
||||
Network NetworkConfig `yaml:"network" mapstructure:"network"`
|
||||
|
||||
// Trainer configuration.
|
||||
Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
|
|
@ -331,6 +334,14 @@ type NetworkConfig struct {
|
|||
EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"`
|
||||
}
|
||||
|
||||
type TrainerConfig struct {
|
||||
// Enable trainer service.
|
||||
Enable bool `yaml:"enable" mapstructure:"enable"`
|
||||
|
||||
// BucketName is the object storage bucket name of model.
|
||||
BucketName string `yaml:"bucketName" mapstructure:"bucketName"`
|
||||
}
|
||||
|
||||
// New config instance.
|
||||
func New() *Config {
|
||||
return &Config{
|
||||
|
|
@ -404,6 +415,10 @@ func New() *Config {
|
|||
Network: NetworkConfig{
|
||||
EnableIPv6: DefaultNetworkEnableIPv6,
|
||||
},
|
||||
Trainer: TrainerConfig{
|
||||
Enable: false,
|
||||
BucketName: DefaultTrainerBucketName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -589,6 +604,11 @@ func (cfg *Config) Validate() error {
|
|||
}
|
||||
}
|
||||
|
||||
if cfg.Trainer.Enable {
|
||||
if cfg.Trainer.BucketName == "" {
|
||||
return errors.New("trainer requires parameter bucketName")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -103,6 +103,11 @@ var (
|
|||
ValidityPeriod: DefaultCertValidityPeriod,
|
||||
},
|
||||
}
|
||||
|
||||
mockTrainerConfig = TrainerConfig{
|
||||
Enable: true,
|
||||
BucketName: DefaultTrainerBucketName,
|
||||
}
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
|
|
@ -207,6 +212,10 @@ func TestConfig_Load(t *testing.T) {
|
|||
Network: NetworkConfig{
|
||||
EnableIPv6: true,
|
||||
},
|
||||
Trainer: TrainerConfig{
|
||||
Enable: true,
|
||||
BucketName: "models",
|
||||
},
|
||||
}
|
||||
|
||||
managerConfigYAML := &Config{}
|
||||
|
|
@ -838,6 +847,23 @@ func TestConfig_Validate(t *testing.T) {
|
|||
assert.EqualError(err, "certSpec requires parameter validityPeriod")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trainer requires parameter bucketName",
|
||||
config: New(),
|
||||
mock: func(cfg *Config) {
|
||||
cfg.Auth.JWT = mockJWTConfig
|
||||
cfg.Database.Type = DatabaseTypeMysql
|
||||
cfg.Database.Mysql = mockMysqlConfig
|
||||
cfg.Database.Redis = mockRedisConfig
|
||||
cfg.Security = mockSecurityConfig
|
||||
cfg.Trainer = mockTrainerConfig
|
||||
cfg.Trainer.BucketName = ""
|
||||
},
|
||||
expect: func(t *testing.T, err error) {
|
||||
assert := assert.New(t)
|
||||
assert.EqualError(err, "trainer requires parameter bucketName")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
|
|
|||
|
|
@ -123,3 +123,8 @@ var (
|
|||
// DefaultNetworkEnableIPv6 is default value of enableIPv6.
|
||||
DefaultNetworkEnableIPv6 = false
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultTrainerBucketName is default object storage bucket name of model.
|
||||
DefaultTrainerBucketName = "models"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -88,3 +88,7 @@ metrics:
|
|||
|
||||
network:
|
||||
enableIPv6: true
|
||||
|
||||
trainer:
|
||||
enable: true
|
||||
bucketName: models
|
||||
|
|
|
|||
Loading…
Reference in New Issue