feat: manager adds trainer config (#2494)

Signed-off-by: huangmin <2107139596@qq.com>
This commit is contained in:
Min 2023-06-29 18:51:24 +08:00 committed by GitHub
parent b58879a0c3
commit 07366174b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 0 deletions

View File

@ -57,6 +57,9 @@ type Config struct {
// Network configuration.
Network NetworkConfig `yaml:"network" mapstructure:"network"`
// Trainer configuration.
Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"`
}
type ServerConfig struct {
@ -331,6 +334,14 @@ type NetworkConfig struct {
EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"`
}
type TrainerConfig struct {
// Enable trainer service.
Enable bool `yaml:"enable" mapstructure:"enable"`
// BucketName is the object storage bucket name of model.
BucketName string `yaml:"bucketName" mapstructure:"bucketName"`
}
// New config instance.
func New() *Config {
return &Config{
@ -404,6 +415,10 @@ func New() *Config {
Network: NetworkConfig{
EnableIPv6: DefaultNetworkEnableIPv6,
},
Trainer: TrainerConfig{
Enable: false,
BucketName: DefaultTrainerBucketName,
},
}
}
@ -589,6 +604,11 @@ func (cfg *Config) Validate() error {
}
}
if cfg.Trainer.Enable {
if cfg.Trainer.BucketName == "" {
return errors.New("trainer requires parameter bucketName")
}
}
return nil
}

View File

@ -103,6 +103,11 @@ var (
ValidityPeriod: DefaultCertValidityPeriod,
},
}
mockTrainerConfig = TrainerConfig{
Enable: true,
BucketName: DefaultTrainerBucketName,
}
)
func TestConfig_Load(t *testing.T) {
@ -207,6 +212,10 @@ func TestConfig_Load(t *testing.T) {
Network: NetworkConfig{
EnableIPv6: true,
},
Trainer: TrainerConfig{
Enable: true,
BucketName: "models",
},
}
managerConfigYAML := &Config{}
@ -838,6 +847,23 @@ func TestConfig_Validate(t *testing.T) {
assert.EqualError(err, "certSpec requires parameter validityPeriod")
},
},
{
name: "trainer requires parameter bucketName",
config: New(),
mock: func(cfg *Config) {
cfg.Auth.JWT = mockJWTConfig
cfg.Database.Type = DatabaseTypeMysql
cfg.Database.Mysql = mockMysqlConfig
cfg.Database.Redis = mockRedisConfig
cfg.Security = mockSecurityConfig
cfg.Trainer = mockTrainerConfig
cfg.Trainer.BucketName = ""
},
expect: func(t *testing.T, err error) {
assert := assert.New(t)
assert.EqualError(err, "trainer requires parameter bucketName")
},
},
}
for _, tc := range tests {

View File

@ -123,3 +123,8 @@ var (
// DefaultNetworkEnableIPv6 is default value of enableIPv6.
DefaultNetworkEnableIPv6 = false
)
var (
// DefaultTrainerBucketName is default object storage bucket name of model.
DefaultTrainerBucketName = "models"
)

View File

@ -88,3 +88,7 @@ metrics:
network:
enableIPv6: true
trainer:
enable: true
bucketName: models