Ensure that environment variables can override config file entries.

Also support parameterized allowed backends when parsing for
storage backends, so that a DB backend can be tested.

Signed-off-by: Ying Li <ying.li@docker.com>
This commit is contained in:
Ying Li 2015-11-19 20:28:13 -08:00
parent b1fdea5b56
commit b25f8546f8
3 changed files with 178 additions and 101 deletions

View File

@ -12,31 +12,24 @@ import (
"github.com/spf13/viper"
)
// Server is a configuration about what addresses a server should listen on
type Server struct {
*ServerTLSOpts
HTTPAddr string
GRPCAddr string
}
// Storage is a configuration about what storage backend a server should use
type Storage struct {
Backend string `mapstructure:"backend"`
URL string `mapstructure:"db_url"`
Backend string
Source string
}
// ParseServer tries to parse out a valid Server from a Viper:
// - Either or both of HTTP and GRPC address must be provided
// ParseServerTLS tries to parse out a valid ServerTLSOpts from a Viper:
// - If TLS is required, both the cert and key must be provided
// - If TLS is not requried, either both the cert and key must be provided or
// neither must be provided
func ParseServer(configuration *viper.Viper, tlsRequired bool) (*Server, error) {
// mapstructure does not support unmarshalling into a pointer
var tlsOpts ServerTLSOpts
err := configuration.UnmarshalKey("server", &tlsOpts)
if err != nil {
return nil, err
func ParseServerTLS(configuration *viper.Viper, tlsRequired bool) (*ServerTLSOpts, error) {
// unmarshalling into objects does not seem to pick up env vars
tlsOpts := ServerTLSOpts{
ServerCertFile: configuration.GetString("server.tls_cert_file"),
ServerKeyFile: configuration.GetString("server.tls_key_file"),
ClientCAFile: configuration.GetString("server.client_ca_file"),
}
cert, key := tlsOpts.ServerCertFile, tlsOpts.ServerKeyFile
if tlsRequired {
if cert == "" || key == "" {
@ -49,20 +42,11 @@ func ParseServer(configuration *viper.Viper, tlsRequired bool) (*Server, error)
}
}
server := Server{
HTTPAddr: configuration.GetString("server.http_addr"),
GRPCAddr: configuration.GetString("server.grpc_addr"),
ServerTLSOpts: &tlsOpts,
}
if cert == "" && key == "" && tlsOpts.ClientCAFile == "" {
server.ServerTLSOpts = nil
return nil, nil
}
if server.HTTPAddr == "" && server.GRPCAddr == "" {
return nil, fmt.Errorf("server must have an HTTP and/or GRPC address")
}
return &server, nil
return &tlsOpts, nil
}
// ParseLogLevel tries to parse out a log level from a Viper. If there is no
@ -79,24 +63,28 @@ func ParseLogLevel(configuration *viper.Viper, defaultLevel logrus.Level) (
// ParseStorage tries to parse out Storage from a Viper. If backend and
// URL are not provided, returns a nil pointer.
func ParseStorage(configuration *viper.Viper) (*Storage, error) {
var store Storage
err := configuration.UnmarshalKey("storage", &store)
if err != nil {
return nil, err
func ParseStorage(configuration *viper.Viper, allowedBackeneds []string) (*Storage, error) {
store := Storage{
Backend: configuration.GetString("storage.backend"),
Source: configuration.GetString("storage.db_url"),
}
if store.Backend == "" && store.URL == "" {
if store.Backend == "" && store.Source == "" {
return nil, nil
}
if store.Source == "" {
return nil, fmt.Errorf("must provide a non-empty database source")
}
store.Backend = strings.ToLower(store.Backend)
if store.Backend != "mysql" {
return nil, fmt.Errorf(
"must specify one of these supported backends: mysql")
for _, backend := range allowedBackeneds {
if backend == store.Backend {
return &store, nil
}
}
if store.URL == "" {
return nil, fmt.Errorf("must provide a non-empty database URL")
}
return &store, nil
return nil, fmt.Errorf(
"must specify one of these supported backends: %s",
strings.Join(allowedBackeneds, ", "))
}
// ParseBugsnag tries to parse out a Bugsnag Configuration from a Viper.
@ -118,7 +106,15 @@ func ParseBugsnag(configuration *viper.Viper) (*bugsnag.Configuration, error) {
return &bugconf, nil
}
// utilities for handling common configurations
// utilities for setting up/acting on common configurations
// SetupViper sets up an instance of viper to also look at environment
// variables
func SetupViper(configuration *viper.Viper, envPrefix string) {
configuration.SetEnvPrefix(envPrefix)
configuration.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
configuration.AutomaticEnv()
}
// SetUpBugsnag configures bugsnag and sets up a logrus hook
func SetUpBugsnag(config *bugsnag.Configuration) error {

View File

@ -3,6 +3,7 @@ package utils
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
@ -12,14 +13,34 @@ import (
"github.com/stretchr/testify/assert"
)
const envPrefix = "NOTARY_TESTING_ENV_PREFIX"
// initializes a viper object with test configuration
func configure(jsonConfig string) *viper.Viper {
config := viper.New()
SetupViper(config, envPrefix)
config.SetConfigType("json")
config.ReadConfig(bytes.NewBuffer([]byte(jsonConfig)))
return config
}
// Sets the environment variables in the given map, prefixed by envPrefix.
func setupEnvironmentVariables(t *testing.T, vars map[string]string) {
for k, v := range vars {
err := os.Setenv(fmt.Sprintf("%s_%s", envPrefix, k), v)
assert.NoError(t, err)
}
}
// Unsets whatever environment variables were set with this map
func cleanupEnvironmentVariables(t *testing.T, vars map[string]string) {
for k := range vars {
err := os.Unsetenv(fmt.Sprintf("%s_%s", envPrefix, k))
assert.NoError(t, err)
}
}
// An error is returned if the log level is not parsable
func TestParseInvalidLogLevel(t *testing.T) {
_, err := ParseLogLevel(configure(`{"logging": {"level": "horatio"}}`),
@ -46,6 +67,17 @@ func TestParseLogLevel(t *testing.T) {
assert.Equal(t, logrus.ErrorLevel, lvl)
}
func TestParseLogLevelWithEnvironmentVariables(t *testing.T) {
vars := map[string]string{"LOGGING_LEVEL": "error"}
setupEnvironmentVariables(t, vars)
defer cleanupEnvironmentVariables(t, vars)
lvl, err := ParseLogLevel(configure(`{}`),
logrus.DebugLevel)
assert.NoError(t, err)
assert.Equal(t, logrus.ErrorLevel, lvl)
}
// An error is returned if there's no API key
func TestParseInvalidBugsnag(t *testing.T) {
_, err := ParseBugsnag(configure(
@ -86,23 +118,51 @@ func TestParseBugsnag(t *testing.T) {
assert.Equal(t, expected, *bugconf)
}
func TestParseBugsnagWithEnvironmentVariables(t *testing.T) {
config := configure(`{
"reporting": {
"bugsnag": {
"api_key": "12345",
"release_stage": "staging"
}
}
}`)
vars := map[string]string{
"REPORTING_BUGSNAG_RELEASE_STAGE": "production",
"REPORTING_BUGSNAG_ENDPOINT": "http://1234.com",
}
setupEnvironmentVariables(t, vars)
defer cleanupEnvironmentVariables(t, vars)
expected := bugsnag.Configuration{
APIKey: "12345",
ReleaseStage: "production",
Endpoint: "http://1234.com",
}
bugconf, err := ParseBugsnag(config)
assert.NoError(t, err)
assert.Equal(t, expected, *bugconf)
}
// If the storage parameters are invalid, an error is returned
func TestParseInvalidStorage(t *testing.T) {
invalids := []string{
`{"storage": {"backend": "memow", "db_url": "1234"}}`,
`{"storage": {"backend": "postgres", "db_url": "1234"}}`,
`{"storage": {"db_url": "12345"}}`,
`{"storage": {"backend": "mysql"}}`,
`{"storage": {"backend": "mysql", "db_url": ""}}`,
`{"storage": {"backend": "sqlite3", "db_url": ""}}`,
}
for _, configJSON := range invalids {
_, err := ParseStorage(configure(configJSON))
_, err := ParseStorage(configure(configJSON), []string{"mysql", "sqlite3"})
assert.Error(t, err, fmt.Sprintf("'%s' should be an error", configJSON))
if strings.Contains(configJSON, "mysql") {
if strings.Contains(configJSON, "mysql") || strings.Contains(configJSON, "sqlite3") {
assert.Contains(t, err.Error(),
"must provide a non-empty database URL")
"must provide a non-empty database source")
} else {
assert.Contains(t, err.Error(),
"must specify one of these supported backends: mysql")
"must specify one of these supported backends: mysql, sqlite3")
}
}
}
@ -111,7 +171,7 @@ func TestParseInvalidStorage(t *testing.T) {
func TestParseNoStorage(t *testing.T) {
empties := []string{`{}`, `{"storage": {}}`}
for _, configJSON := range empties {
store, err := ParseStorage(configure(configJSON))
store, err := ParseStorage(configure(configJSON), []string{"mysql"})
assert.NoError(t, err)
assert.Nil(t, store)
}
@ -127,33 +187,43 @@ func TestParseStorage(t *testing.T) {
expected := Storage{
Backend: "mysql",
URL: "username:passord@tcp(hostname:1234)/dbname",
Source: "username:passord@tcp(hostname:1234)/dbname",
}
store, err := ParseStorage(config)
store, err := ParseStorage(config, []string{"mysql"})
assert.NoError(t, err)
assert.Equal(t, expected, *store)
}
// If the server section is missing or missing HTTP/GRPC addresses, an error is
// returned
func TestParseInvalidOrNoServer(t *testing.T) {
invalids := []string{`{}`, `{"server": {}}`}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "must have an HTTP and/or GRPC address")
func TestParseStorageWithEnvironmentVariables(t *testing.T) {
config := configure(`{
"storage": {
"db_url": "username:passord@tcp(hostname:1234)/dbname"
}
}`)
vars := map[string]string{"STORAGE_BACKEND": "MySQL"}
setupEnvironmentVariables(t, vars)
defer cleanupEnvironmentVariables(t, vars)
expected := Storage{
Backend: "mysql",
Source: "username:passord@tcp(hostname:1234)/dbname",
}
store, err := ParseStorage(config, []string{"mysql"})
assert.NoError(t, err)
assert.Equal(t, expected, *store)
}
// If TLS is required and the parameters are missing, an error is returned
func TestParseInvalidServerNoTLSWhenRequired(t *testing.T) {
func TestParseTLSNoTLSWhenRequired(t *testing.T) {
invalids := []string{
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
`{"server": {"tls_cert_file": "path/to/cert"}}`,
`{"server": {"tls_key_file": "path/to/key"}}`,
}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), true)
_, err := ParseServerTLS(configure(configJSON), true)
assert.Error(t, err)
assert.Contains(t, err.Error(),
"both the TLS certificate and key are mandatory")
@ -161,60 +231,71 @@ func TestParseInvalidServerNoTLSWhenRequired(t *testing.T) {
}
// If TLS is not and the cert/key are partially provided, an error is returned
func TestParseInvalidServerPartialTLS(t *testing.T) {
func TestParseTLSPartialTLS(t *testing.T) {
invalids := []string{
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
`{"server": {"tls_cert_file": "path/to/cert"}}`,
`{"server": {"tls_key_file": "path/to/key"}}`,
}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), false)
_, err := ParseServerTLS(configure(configJSON), false)
assert.Error(t, err)
assert.Contains(t, err.Error(),
"either include both a cert and key file, or neither to disable TLS")
}
}
func TestParseServerNoTLS(t *testing.T) {
func TestParseTLSNoTLSNotRequired(t *testing.T) {
config := configure(`{
"server": {
"http_addr": ":4443",
"grpc_addr": ":7899"
}
"server": {}
}`)
expected := Server{
HTTPAddr: ":4443",
GRPCAddr: ":7899",
ServerTLSOpts: nil,
}
server, err := ParseServer(config, false)
tlsOpts, err := ParseServerTLS(config, false)
assert.NoError(t, err)
assert.Equal(t, expected, *server)
assert.Nil(t, tlsOpts)
}
func TestUnmarshalConfigServerWithTLS(t *testing.T) {
func TestParseTLSWithTLS(t *testing.T) {
config := configure(`{
"server": {
"http_addr": ":4443",
"grpc_addr": ":7899",
"tls_cert_file": "path/to/cert",
"tls_key_file": "path/to/key",
"client_ca_file": "path/to/clientca"
}
}`)
expected := Server{
HTTPAddr: ":4443",
GRPCAddr: ":7899",
ServerTLSOpts: &ServerTLSOpts{
ServerCertFile: "path/to/cert",
ServerKeyFile: "path/to/key",
ClientCAFile: "path/to/clientca",
},
expected := ServerTLSOpts{
ServerCertFile: "path/to/cert",
ServerKeyFile: "path/to/key",
ClientCAFile: "path/to/clientca",
}
server, err := ParseServer(config, false)
tlsOpts, err := ParseServerTLS(config, false)
assert.NoError(t, err)
assert.Equal(t, expected, *server)
assert.Equal(t, expected, *tlsOpts)
}
func TestParseTLSWithEnvironmentVariables(t *testing.T) {
config := configure(`{
"server": {
"tls_cert_file": "path/to/cert",
"client_ca_file": "nosuchfile"
}
}`)
vars := map[string]string{
"SERVER_TLS_KEY_FILE": "path/to/key",
"SERVER_CLIENT_CA_FILE": "path/to/clientca",
}
setupEnvironmentVariables(t, vars)
defer cleanupEnvironmentVariables(t, vars)
expected := ServerTLSOpts{
ServerCertFile: "path/to/cert",
ServerKeyFile: "path/to/key",
ClientCAFile: "path/to/clientca",
}
tlsOpts, err := ParseServerTLS(config, true)
assert.NoError(t, err)
assert.Equal(t, expected, *tlsOpts)
}

View File

@ -44,9 +44,9 @@ func poolFromFile(filename string) (*x509.CertPool, error) {
// ServerTLSOpts generates a tls configuration for servers using the
// provided parameters.
type ServerTLSOpts struct {
ServerCertFile string `mapstructure:"tls_cert_file"`
ServerKeyFile string `mapstructure:"tls_key_file"`
ClientCAFile string `mapstructure:"client_ca_file"`
ServerCertFile string
ServerKeyFile string
ClientCAFile string
}
// ConfigureServerTLS specifies a set of ciphersuites, the server cert and key,
@ -87,11 +87,11 @@ func ConfigureServerTLS(opts *ServerTLSOpts) (*tls.Config, error) {
// ClientTLSOpts is a struct that contains options to pass to
// ConfigureClientTLS
type ClientTLSOpts struct {
RootCAFile string `json:"tls_ca_file"`
ServerName string `json:"hostname"`
InsecureSkipVerify bool `json:"-"`
ClientCertFile string `json:"tls_client_cert"`
ClientKeyFile string `json:"tls_client_key"`
RootCAFile string
ServerName string
InsecureSkipVerify bool
ClientCertFile string
ClientKeyFile string
}
// ConfigureClientTLS generates a tls configuration for clients using the