mirror of https://github.com/docker/docs.git
Ensure that environment variables can override config file entries.
Also support parameterized allowed backends when parsing for storage backends, so that a DB backend can be tested. Signed-off-by: Ying Li <ying.li@docker.com>
This commit is contained in:
parent
b1fdea5b56
commit
b25f8546f8
|
@ -12,31 +12,24 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Server is a configuration about what addresses a server should listen on
|
||||
type Server struct {
|
||||
*ServerTLSOpts
|
||||
HTTPAddr string
|
||||
GRPCAddr string
|
||||
}
|
||||
|
||||
// Storage is a configuration about what storage backend a server should use
|
||||
type Storage struct {
|
||||
Backend string `mapstructure:"backend"`
|
||||
URL string `mapstructure:"db_url"`
|
||||
Backend string
|
||||
Source string
|
||||
}
|
||||
|
||||
// ParseServer tries to parse out a valid Server from a Viper:
|
||||
// - Either or both of HTTP and GRPC address must be provided
|
||||
// ParseServerTLS tries to parse out a valid ServerTLSOpts from a Viper:
|
||||
// - If TLS is required, both the cert and key must be provided
|
||||
// - If TLS is not requried, either both the cert and key must be provided or
|
||||
// neither must be provided
|
||||
func ParseServer(configuration *viper.Viper, tlsRequired bool) (*Server, error) {
|
||||
// mapstructure does not support unmarshalling into a pointer
|
||||
var tlsOpts ServerTLSOpts
|
||||
err := configuration.UnmarshalKey("server", &tlsOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func ParseServerTLS(configuration *viper.Viper, tlsRequired bool) (*ServerTLSOpts, error) {
|
||||
// unmarshalling into objects does not seem to pick up env vars
|
||||
tlsOpts := ServerTLSOpts{
|
||||
ServerCertFile: configuration.GetString("server.tls_cert_file"),
|
||||
ServerKeyFile: configuration.GetString("server.tls_key_file"),
|
||||
ClientCAFile: configuration.GetString("server.client_ca_file"),
|
||||
}
|
||||
|
||||
cert, key := tlsOpts.ServerCertFile, tlsOpts.ServerKeyFile
|
||||
if tlsRequired {
|
||||
if cert == "" || key == "" {
|
||||
|
@ -49,20 +42,11 @@ func ParseServer(configuration *viper.Viper, tlsRequired bool) (*Server, error)
|
|||
}
|
||||
}
|
||||
|
||||
server := Server{
|
||||
HTTPAddr: configuration.GetString("server.http_addr"),
|
||||
GRPCAddr: configuration.GetString("server.grpc_addr"),
|
||||
ServerTLSOpts: &tlsOpts,
|
||||
}
|
||||
if cert == "" && key == "" && tlsOpts.ClientCAFile == "" {
|
||||
server.ServerTLSOpts = nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if server.HTTPAddr == "" && server.GRPCAddr == "" {
|
||||
return nil, fmt.Errorf("server must have an HTTP and/or GRPC address")
|
||||
}
|
||||
|
||||
return &server, nil
|
||||
return &tlsOpts, nil
|
||||
}
|
||||
|
||||
// ParseLogLevel tries to parse out a log level from a Viper. If there is no
|
||||
|
@ -79,24 +63,28 @@ func ParseLogLevel(configuration *viper.Viper, defaultLevel logrus.Level) (
|
|||
|
||||
// ParseStorage tries to parse out Storage from a Viper. If backend and
|
||||
// URL are not provided, returns a nil pointer.
|
||||
func ParseStorage(configuration *viper.Viper) (*Storage, error) {
|
||||
var store Storage
|
||||
err := configuration.UnmarshalKey("storage", &store)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func ParseStorage(configuration *viper.Viper, allowedBackeneds []string) (*Storage, error) {
|
||||
store := Storage{
|
||||
Backend: configuration.GetString("storage.backend"),
|
||||
Source: configuration.GetString("storage.db_url"),
|
||||
}
|
||||
if store.Backend == "" && store.URL == "" {
|
||||
if store.Backend == "" && store.Source == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if store.Source == "" {
|
||||
return nil, fmt.Errorf("must provide a non-empty database source")
|
||||
}
|
||||
store.Backend = strings.ToLower(store.Backend)
|
||||
if store.Backend != "mysql" {
|
||||
return nil, fmt.Errorf(
|
||||
"must specify one of these supported backends: mysql")
|
||||
for _, backend := range allowedBackeneds {
|
||||
if backend == store.Backend {
|
||||
return &store, nil
|
||||
}
|
||||
|
||||
}
|
||||
if store.URL == "" {
|
||||
return nil, fmt.Errorf("must provide a non-empty database URL")
|
||||
}
|
||||
return &store, nil
|
||||
return nil, fmt.Errorf(
|
||||
"must specify one of these supported backends: %s",
|
||||
strings.Join(allowedBackeneds, ", "))
|
||||
}
|
||||
|
||||
// ParseBugsnag tries to parse out a Bugsnag Configuration from a Viper.
|
||||
|
@ -118,7 +106,15 @@ func ParseBugsnag(configuration *viper.Viper) (*bugsnag.Configuration, error) {
|
|||
return &bugconf, nil
|
||||
}
|
||||
|
||||
// utilities for handling common configurations
|
||||
// utilities for setting up/acting on common configurations
|
||||
|
||||
// SetupViper sets up an instance of viper to also look at environment
|
||||
// variables
|
||||
func SetupViper(configuration *viper.Viper, envPrefix string) {
|
||||
configuration.SetEnvPrefix(envPrefix)
|
||||
configuration.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
configuration.AutomaticEnv()
|
||||
}
|
||||
|
||||
// SetUpBugsnag configures bugsnag and sets up a logrus hook
|
||||
func SetUpBugsnag(config *bugsnag.Configuration) error {
|
||||
|
|
|
@ -3,6 +3,7 @@ package utils
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -12,14 +13,34 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const envPrefix = "NOTARY_TESTING_ENV_PREFIX"
|
||||
|
||||
// initializes a viper object with test configuration
|
||||
func configure(jsonConfig string) *viper.Viper {
|
||||
config := viper.New()
|
||||
SetupViper(config, envPrefix)
|
||||
config.SetConfigType("json")
|
||||
config.ReadConfig(bytes.NewBuffer([]byte(jsonConfig)))
|
||||
return config
|
||||
}
|
||||
|
||||
// Sets the environment variables in the given map, prefixed by envPrefix.
|
||||
func setupEnvironmentVariables(t *testing.T, vars map[string]string) {
|
||||
for k, v := range vars {
|
||||
err := os.Setenv(fmt.Sprintf("%s_%s", envPrefix, k), v)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unsets whatever environment variables were set with this map
|
||||
func cleanupEnvironmentVariables(t *testing.T, vars map[string]string) {
|
||||
for k := range vars {
|
||||
err := os.Unsetenv(fmt.Sprintf("%s_%s", envPrefix, k))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// An error is returned if the log level is not parsable
|
||||
func TestParseInvalidLogLevel(t *testing.T) {
|
||||
_, err := ParseLogLevel(configure(`{"logging": {"level": "horatio"}}`),
|
||||
|
@ -46,6 +67,17 @@ func TestParseLogLevel(t *testing.T) {
|
|||
assert.Equal(t, logrus.ErrorLevel, lvl)
|
||||
}
|
||||
|
||||
func TestParseLogLevelWithEnvironmentVariables(t *testing.T) {
|
||||
vars := map[string]string{"LOGGING_LEVEL": "error"}
|
||||
setupEnvironmentVariables(t, vars)
|
||||
defer cleanupEnvironmentVariables(t, vars)
|
||||
|
||||
lvl, err := ParseLogLevel(configure(`{}`),
|
||||
logrus.DebugLevel)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, logrus.ErrorLevel, lvl)
|
||||
}
|
||||
|
||||
// An error is returned if there's no API key
|
||||
func TestParseInvalidBugsnag(t *testing.T) {
|
||||
_, err := ParseBugsnag(configure(
|
||||
|
@ -86,23 +118,51 @@ func TestParseBugsnag(t *testing.T) {
|
|||
assert.Equal(t, expected, *bugconf)
|
||||
}
|
||||
|
||||
func TestParseBugsnagWithEnvironmentVariables(t *testing.T) {
|
||||
config := configure(`{
|
||||
"reporting": {
|
||||
"bugsnag": {
|
||||
"api_key": "12345",
|
||||
"release_stage": "staging"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
vars := map[string]string{
|
||||
"REPORTING_BUGSNAG_RELEASE_STAGE": "production",
|
||||
"REPORTING_BUGSNAG_ENDPOINT": "http://1234.com",
|
||||
}
|
||||
setupEnvironmentVariables(t, vars)
|
||||
defer cleanupEnvironmentVariables(t, vars)
|
||||
|
||||
expected := bugsnag.Configuration{
|
||||
APIKey: "12345",
|
||||
ReleaseStage: "production",
|
||||
Endpoint: "http://1234.com",
|
||||
}
|
||||
|
||||
bugconf, err := ParseBugsnag(config)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *bugconf)
|
||||
}
|
||||
|
||||
// If the storage parameters are invalid, an error is returned
|
||||
func TestParseInvalidStorage(t *testing.T) {
|
||||
invalids := []string{
|
||||
`{"storage": {"backend": "memow", "db_url": "1234"}}`,
|
||||
`{"storage": {"backend": "postgres", "db_url": "1234"}}`,
|
||||
`{"storage": {"db_url": "12345"}}`,
|
||||
`{"storage": {"backend": "mysql"}}`,
|
||||
`{"storage": {"backend": "mysql", "db_url": ""}}`,
|
||||
`{"storage": {"backend": "sqlite3", "db_url": ""}}`,
|
||||
}
|
||||
for _, configJSON := range invalids {
|
||||
_, err := ParseStorage(configure(configJSON))
|
||||
_, err := ParseStorage(configure(configJSON), []string{"mysql", "sqlite3"})
|
||||
assert.Error(t, err, fmt.Sprintf("'%s' should be an error", configJSON))
|
||||
if strings.Contains(configJSON, "mysql") {
|
||||
if strings.Contains(configJSON, "mysql") || strings.Contains(configJSON, "sqlite3") {
|
||||
assert.Contains(t, err.Error(),
|
||||
"must provide a non-empty database URL")
|
||||
"must provide a non-empty database source")
|
||||
} else {
|
||||
assert.Contains(t, err.Error(),
|
||||
"must specify one of these supported backends: mysql")
|
||||
"must specify one of these supported backends: mysql, sqlite3")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +171,7 @@ func TestParseInvalidStorage(t *testing.T) {
|
|||
func TestParseNoStorage(t *testing.T) {
|
||||
empties := []string{`{}`, `{"storage": {}}`}
|
||||
for _, configJSON := range empties {
|
||||
store, err := ParseStorage(configure(configJSON))
|
||||
store, err := ParseStorage(configure(configJSON), []string{"mysql"})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, store)
|
||||
}
|
||||
|
@ -127,33 +187,43 @@ func TestParseStorage(t *testing.T) {
|
|||
|
||||
expected := Storage{
|
||||
Backend: "mysql",
|
||||
URL: "username:passord@tcp(hostname:1234)/dbname",
|
||||
Source: "username:passord@tcp(hostname:1234)/dbname",
|
||||
}
|
||||
|
||||
store, err := ParseStorage(config)
|
||||
store, err := ParseStorage(config, []string{"mysql"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *store)
|
||||
}
|
||||
|
||||
// If the server section is missing or missing HTTP/GRPC addresses, an error is
|
||||
// returned
|
||||
func TestParseInvalidOrNoServer(t *testing.T) {
|
||||
invalids := []string{`{}`, `{"server": {}}`}
|
||||
for _, configJSON := range invalids {
|
||||
_, err := ParseServer(configure(configJSON), false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "must have an HTTP and/or GRPC address")
|
||||
func TestParseStorageWithEnvironmentVariables(t *testing.T) {
|
||||
config := configure(`{
|
||||
"storage": {
|
||||
"db_url": "username:passord@tcp(hostname:1234)/dbname"
|
||||
}
|
||||
}`)
|
||||
|
||||
vars := map[string]string{"STORAGE_BACKEND": "MySQL"}
|
||||
setupEnvironmentVariables(t, vars)
|
||||
defer cleanupEnvironmentVariables(t, vars)
|
||||
|
||||
expected := Storage{
|
||||
Backend: "mysql",
|
||||
Source: "username:passord@tcp(hostname:1234)/dbname",
|
||||
}
|
||||
|
||||
store, err := ParseStorage(config, []string{"mysql"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *store)
|
||||
}
|
||||
|
||||
// If TLS is required and the parameters are missing, an error is returned
|
||||
func TestParseInvalidServerNoTLSWhenRequired(t *testing.T) {
|
||||
func TestParseTLSNoTLSWhenRequired(t *testing.T) {
|
||||
invalids := []string{
|
||||
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
|
||||
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
|
||||
`{"server": {"tls_cert_file": "path/to/cert"}}`,
|
||||
`{"server": {"tls_key_file": "path/to/key"}}`,
|
||||
}
|
||||
for _, configJSON := range invalids {
|
||||
_, err := ParseServer(configure(configJSON), true)
|
||||
_, err := ParseServerTLS(configure(configJSON), true)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(),
|
||||
"both the TLS certificate and key are mandatory")
|
||||
|
@ -161,60 +231,71 @@ func TestParseInvalidServerNoTLSWhenRequired(t *testing.T) {
|
|||
}
|
||||
|
||||
// If TLS is not and the cert/key are partially provided, an error is returned
|
||||
func TestParseInvalidServerPartialTLS(t *testing.T) {
|
||||
func TestParseTLSPartialTLS(t *testing.T) {
|
||||
invalids := []string{
|
||||
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
|
||||
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
|
||||
`{"server": {"tls_cert_file": "path/to/cert"}}`,
|
||||
`{"server": {"tls_key_file": "path/to/key"}}`,
|
||||
}
|
||||
for _, configJSON := range invalids {
|
||||
_, err := ParseServer(configure(configJSON), false)
|
||||
_, err := ParseServerTLS(configure(configJSON), false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(),
|
||||
"either include both a cert and key file, or neither to disable TLS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseServerNoTLS(t *testing.T) {
|
||||
func TestParseTLSNoTLSNotRequired(t *testing.T) {
|
||||
config := configure(`{
|
||||
"server": {
|
||||
"http_addr": ":4443",
|
||||
"grpc_addr": ":7899"
|
||||
}
|
||||
"server": {}
|
||||
}`)
|
||||
|
||||
expected := Server{
|
||||
HTTPAddr: ":4443",
|
||||
GRPCAddr: ":7899",
|
||||
ServerTLSOpts: nil,
|
||||
}
|
||||
|
||||
server, err := ParseServer(config, false)
|
||||
tlsOpts, err := ParseServerTLS(config, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *server)
|
||||
assert.Nil(t, tlsOpts)
|
||||
}
|
||||
|
||||
func TestUnmarshalConfigServerWithTLS(t *testing.T) {
|
||||
func TestParseTLSWithTLS(t *testing.T) {
|
||||
config := configure(`{
|
||||
"server": {
|
||||
"http_addr": ":4443",
|
||||
"grpc_addr": ":7899",
|
||||
"tls_cert_file": "path/to/cert",
|
||||
"tls_key_file": "path/to/key",
|
||||
"client_ca_file": "path/to/clientca"
|
||||
}
|
||||
}`)
|
||||
|
||||
expected := Server{
|
||||
HTTPAddr: ":4443",
|
||||
GRPCAddr: ":7899",
|
||||
ServerTLSOpts: &ServerTLSOpts{
|
||||
ServerCertFile: "path/to/cert",
|
||||
ServerKeyFile: "path/to/key",
|
||||
ClientCAFile: "path/to/clientca",
|
||||
},
|
||||
expected := ServerTLSOpts{
|
||||
ServerCertFile: "path/to/cert",
|
||||
ServerKeyFile: "path/to/key",
|
||||
ClientCAFile: "path/to/clientca",
|
||||
}
|
||||
|
||||
server, err := ParseServer(config, false)
|
||||
tlsOpts, err := ParseServerTLS(config, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *server)
|
||||
assert.Equal(t, expected, *tlsOpts)
|
||||
}
|
||||
|
||||
func TestParseTLSWithEnvironmentVariables(t *testing.T) {
|
||||
config := configure(`{
|
||||
"server": {
|
||||
"tls_cert_file": "path/to/cert",
|
||||
"client_ca_file": "nosuchfile"
|
||||
}
|
||||
}`)
|
||||
|
||||
vars := map[string]string{
|
||||
"SERVER_TLS_KEY_FILE": "path/to/key",
|
||||
"SERVER_CLIENT_CA_FILE": "path/to/clientca",
|
||||
}
|
||||
setupEnvironmentVariables(t, vars)
|
||||
defer cleanupEnvironmentVariables(t, vars)
|
||||
|
||||
expected := ServerTLSOpts{
|
||||
ServerCertFile: "path/to/cert",
|
||||
ServerKeyFile: "path/to/key",
|
||||
ClientCAFile: "path/to/clientca",
|
||||
}
|
||||
|
||||
tlsOpts, err := ParseServerTLS(config, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, *tlsOpts)
|
||||
}
|
||||
|
|
|
@ -44,9 +44,9 @@ func poolFromFile(filename string) (*x509.CertPool, error) {
|
|||
// ServerTLSOpts generates a tls configuration for servers using the
|
||||
// provided parameters.
|
||||
type ServerTLSOpts struct {
|
||||
ServerCertFile string `mapstructure:"tls_cert_file"`
|
||||
ServerKeyFile string `mapstructure:"tls_key_file"`
|
||||
ClientCAFile string `mapstructure:"client_ca_file"`
|
||||
ServerCertFile string
|
||||
ServerKeyFile string
|
||||
ClientCAFile string
|
||||
}
|
||||
|
||||
// ConfigureServerTLS specifies a set of ciphersuites, the server cert and key,
|
||||
|
@ -87,11 +87,11 @@ func ConfigureServerTLS(opts *ServerTLSOpts) (*tls.Config, error) {
|
|||
// ClientTLSOpts is a struct that contains options to pass to
|
||||
// ConfigureClientTLS
|
||||
type ClientTLSOpts struct {
|
||||
RootCAFile string `json:"tls_ca_file"`
|
||||
ServerName string `json:"hostname"`
|
||||
InsecureSkipVerify bool `json:"-"`
|
||||
ClientCertFile string `json:"tls_client_cert"`
|
||||
ClientKeyFile string `json:"tls_client_key"`
|
||||
RootCAFile string
|
||||
ServerName string
|
||||
InsecureSkipVerify bool
|
||||
ClientCertFile string
|
||||
ClientKeyFile string
|
||||
}
|
||||
|
||||
// ConfigureClientTLS generates a tls configuration for clients using the
|
||||
|
|
Loading…
Reference in New Issue