feat: manager adds trainer config (#2494)
Signed-off-by: huangmin <2107139596@qq.com>
This commit is contained in:
parent
b58879a0c3
commit
07366174b5
|
|
@ -57,6 +57,9 @@ type Config struct {
|
||||||
|
|
||||||
// Network configuration.
|
// Network configuration.
|
||||||
Network NetworkConfig `yaml:"network" mapstructure:"network"`
|
Network NetworkConfig `yaml:"network" mapstructure:"network"`
|
||||||
|
|
||||||
|
// Trainer configuration.
|
||||||
|
Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
|
|
@ -331,6 +334,14 @@ type NetworkConfig struct {
|
||||||
EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"`
|
EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TrainerConfig struct {
|
||||||
|
// Enable trainer service.
|
||||||
|
Enable bool `yaml:"enable" mapstructure:"enable"`
|
||||||
|
|
||||||
|
// BucketName is the object storage bucket name of model.
|
||||||
|
BucketName string `yaml:"bucketName" mapstructure:"bucketName"`
|
||||||
|
}
|
||||||
|
|
||||||
// New config instance.
|
// New config instance.
|
||||||
func New() *Config {
|
func New() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
|
|
@ -404,6 +415,10 @@ func New() *Config {
|
||||||
Network: NetworkConfig{
|
Network: NetworkConfig{
|
||||||
EnableIPv6: DefaultNetworkEnableIPv6,
|
EnableIPv6: DefaultNetworkEnableIPv6,
|
||||||
},
|
},
|
||||||
|
Trainer: TrainerConfig{
|
||||||
|
Enable: false,
|
||||||
|
BucketName: DefaultTrainerBucketName,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -589,6 +604,11 @@ func (cfg *Config) Validate() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Trainer.Enable {
|
||||||
|
if cfg.Trainer.BucketName == "" {
|
||||||
|
return errors.New("trainer requires parameter bucketName")
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,11 @@ var (
|
||||||
ValidityPeriod: DefaultCertValidityPeriod,
|
ValidityPeriod: DefaultCertValidityPeriod,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mockTrainerConfig = TrainerConfig{
|
||||||
|
Enable: true,
|
||||||
|
BucketName: DefaultTrainerBucketName,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
|
|
@ -207,6 +212,10 @@ func TestConfig_Load(t *testing.T) {
|
||||||
Network: NetworkConfig{
|
Network: NetworkConfig{
|
||||||
EnableIPv6: true,
|
EnableIPv6: true,
|
||||||
},
|
},
|
||||||
|
Trainer: TrainerConfig{
|
||||||
|
Enable: true,
|
||||||
|
BucketName: "models",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
managerConfigYAML := &Config{}
|
managerConfigYAML := &Config{}
|
||||||
|
|
@ -838,6 +847,23 @@ func TestConfig_Validate(t *testing.T) {
|
||||||
assert.EqualError(err, "certSpec requires parameter validityPeriod")
|
assert.EqualError(err, "certSpec requires parameter validityPeriod")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "trainer requires parameter bucketName",
|
||||||
|
config: New(),
|
||||||
|
mock: func(cfg *Config) {
|
||||||
|
cfg.Auth.JWT = mockJWTConfig
|
||||||
|
cfg.Database.Type = DatabaseTypeMysql
|
||||||
|
cfg.Database.Mysql = mockMysqlConfig
|
||||||
|
cfg.Database.Redis = mockRedisConfig
|
||||||
|
cfg.Security = mockSecurityConfig
|
||||||
|
cfg.Trainer = mockTrainerConfig
|
||||||
|
cfg.Trainer.BucketName = ""
|
||||||
|
},
|
||||||
|
expect: func(t *testing.T, err error) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
assert.EqualError(err, "trainer requires parameter bucketName")
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
|
|
||||||
|
|
@ -123,3 +123,8 @@ var (
|
||||||
// DefaultNetworkEnableIPv6 is default value of enableIPv6.
|
// DefaultNetworkEnableIPv6 is default value of enableIPv6.
|
||||||
DefaultNetworkEnableIPv6 = false
|
DefaultNetworkEnableIPv6 = false
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultTrainerBucketName is default object storage bucket name of model.
|
||||||
|
DefaultTrainerBucketName = "models"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -88,3 +88,7 @@ metrics:
|
||||||
|
|
||||||
network:
|
network:
|
||||||
enableIPv6: true
|
enableIPv6: true
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
enable: true
|
||||||
|
bucketName: models
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue