Add shared config file parsing to the utils package.

Signed-off-by: Ying Li <ying.li@docker.com>
This commit is contained in:
Ying Li 2015-11-19 16:19:24 -08:00
parent c4636411bc
commit b1fdea5b56
5 changed files with 394 additions and 59 deletions

View File

@ -73,24 +73,17 @@ func passphraseRetriever(keyName, alias string, createNew bool, attempts int) (p
return passphrase, false, nil
}
// parses and sets up the TLS for the signer http + grpc server
func signerTLS(configuration *viper.Viper, printUsage bool) (*tls.Config, error) {
certFile := configuration.GetString("server.cert_file")
keyFile := configuration.GetString("server.key_file")
if certFile == "" || keyFile == "" {
// validates TLS configuration options and sets up the TLS for the signer
// http + grpc server
func signerTLS(tlsConfig *utils.ServerTLSOpts, printUsage bool) (*tls.Config, error) {
if tlsConfig.ServerCertFile == "" || tlsConfig.ServerKeyFile == "" {
if printUsage {
usage()
}
return nil, fmt.Errorf("Certificate and key are mandatory")
}
clientCAFile := configuration.GetString("server.client_ca_file")
tlsConfig, err := utils.ConfigureServerTLS(&utils.ServerTLSOpts{
ServerCertFile: certFile,
ServerKeyFile: keyFile,
RequireClientAuth: clientCAFile != "",
ClientCAFile: clientCAFile,
})
tlsConfig, err := utils.ConfigureServerTLS(config)
if err != nil {
return nil, fmt.Errorf("Unable to set up TLS: %s", err.Error())
}
@ -115,6 +108,8 @@ func main() {
mainViper.SetConfigType(strings.TrimPrefix(ext, "."))
mainViper.SetConfigName(strings.TrimSuffix(filename, ext))
mainViper.AddConfigPath(configPath)
// set default log level to Error
mainViper.SetDefault("logging.level", "error")
err := mainViper.ReadInConfig()
if err != nil {
logrus.Error("Viper Error: ", err.Error())
@ -122,7 +117,16 @@ func main() {
os.Exit(1)
}
logrus.SetLevel(logrus.Level(mainViper.GetInt("logging.level")))
var config util.ServerConfiguration
err = mainViper.Unmarshal(&config)
if err != nil {
logrus.Fatalf(err.Error())
}
if config.Logging.Level != nil {
fmt.Println("LOGGING level", config.Logging.Level)
logrus.SetLevel(config.Logging.Level.ToLogrus())
}
tlsConfig, err := signerTLS(mainViper, true)
if err != nil {
@ -131,15 +135,15 @@ func main() {
cryptoServices := make(signer.CryptoServiceIndex)
configDBType := strings.ToLower(mainViper.GetString("storage.backend"))
dbURL := mainViper.GetString("storage.db_url")
if configDBType != dbType || dbURL == "" {
emptyStorage := utils.Storage{}
if config.Storage == emptyStorage {
usage()
log.Fatalf("Currently only a MySQL database backend is supported.")
log.Fatalf("Must specify a MySQL backend.")
}
dbSQL, err := sql.Open(configDBType, dbURL)
dbSQL, err := sql.Open(config.Storage.Backend, config.Storage.URL)
if err != nil {
log.Fatalf("failed to open the database: %s, %v", dbURL, err)
log.Fatalf("failed to open the database: %s, %v", config.Storage.URL, err)
}
defaultAlias := mainViper.GetString(defaultAliasEnv)

135
utils/configuration.go Normal file
View File

@ -0,0 +1,135 @@
// Common configuration elements that may be resused
package utils
import (
"fmt"
"strings"
"github.com/Sirupsen/logrus"
bugsnag_hook "github.com/Sirupsen/logrus/hooks/bugsnag"
"github.com/bugsnag/bugsnag-go"
"github.com/spf13/viper"
)
// Server is a configuration about what addresses a server should listen on
type Server struct {
*ServerTLSOpts
HTTPAddr string
GRPCAddr string
}
// Storage is a configuration about what storage backend a server should use
type Storage struct {
Backend string `mapstructure:"backend"`
URL string `mapstructure:"db_url"`
}
// ParseServer tries to parse out a valid Server from a Viper:
// - Either or both of HTTP and GRPC address must be provided
// - If TLS is required, both the cert and key must be provided
// - If TLS is not requried, either both the cert and key must be provided or
// neither must be provided
func ParseServer(configuration *viper.Viper, tlsRequired bool) (*Server, error) {
// mapstructure does not support unmarshalling into a pointer
var tlsOpts ServerTLSOpts
err := configuration.UnmarshalKey("server", &tlsOpts)
if err != nil {
return nil, err
}
cert, key := tlsOpts.ServerCertFile, tlsOpts.ServerKeyFile
if tlsRequired {
if cert == "" || key == "" {
return nil, fmt.Errorf("both the TLS certificate and key are mandatory")
}
} else {
if (cert == "" && key != "") || (cert != "" && key == "") {
return nil, fmt.Errorf(
"either include both a cert and key file, or neither to disable TLS")
}
}
server := Server{
HTTPAddr: configuration.GetString("server.http_addr"),
GRPCAddr: configuration.GetString("server.grpc_addr"),
ServerTLSOpts: &tlsOpts,
}
if cert == "" && key == "" && tlsOpts.ClientCAFile == "" {
server.ServerTLSOpts = nil
}
if server.HTTPAddr == "" && server.GRPCAddr == "" {
return nil, fmt.Errorf("server must have an HTTP and/or GRPC address")
}
return &server, nil
}
// ParseLogLevel tries to parse out a log level from a Viper. If there is no
// configuration, defaults to the provided error level
func ParseLogLevel(configuration *viper.Viper, defaultLevel logrus.Level) (
logrus.Level, error) {
logStr := configuration.GetString("logging.level")
if logStr == "" {
return defaultLevel, nil
}
return logrus.ParseLevel(logStr)
}
// ParseStorage tries to parse out Storage from a Viper. If backend and
// URL are not provided, returns a nil pointer.
func ParseStorage(configuration *viper.Viper) (*Storage, error) {
var store Storage
err := configuration.UnmarshalKey("storage", &store)
if err != nil {
return nil, err
}
if store.Backend == "" && store.URL == "" {
return nil, nil
}
store.Backend = strings.ToLower(store.Backend)
if store.Backend != "mysql" {
return nil, fmt.Errorf(
"must specify one of these supported backends: mysql")
}
if store.URL == "" {
return nil, fmt.Errorf("must provide a non-empty database URL")
}
return &store, nil
}
// ParseBugsnag tries to parse out a Bugsnag Configuration from a Viper.
// If no values are provided, returns a nil pointer.
func ParseBugsnag(configuration *viper.Viper) (*bugsnag.Configuration, error) {
// can't unmarshal because we can't add tags to the bugsnag.Configuration
// struct
bugconf := bugsnag.Configuration{
APIKey: configuration.GetString("reporting.bugsnag.api_key"),
ReleaseStage: configuration.GetString("reporting.bugsnag.release_stage"),
Endpoint: configuration.GetString("reporting.bugsnag.endpoint"),
}
if bugconf.APIKey == "" && bugconf.ReleaseStage == "" && bugconf.Endpoint == "" {
return nil, nil
}
if bugconf.APIKey == "" {
return nil, fmt.Errorf("must provide an API key for bugsnag")
}
return &bugconf, nil
}
// utilities for handling common configurations
// SetUpBugsnag configures bugsnag and sets up a logrus hook
func SetUpBugsnag(config *bugsnag.Configuration) error {
if config != nil {
bugsnag.Configure(*config)
hook, err := bugsnag_hook.NewBugsnagHook()
if err != nil {
return err
}
logrus.AddHook(hook)
logrus.Debug("Adding logrus hook for Bugsnag")
}
return nil
}

220
utils/configuration_test.go Normal file
View File

@ -0,0 +1,220 @@
package utils
import (
"bytes"
"fmt"
"strings"
"testing"
"github.com/Sirupsen/logrus"
"github.com/bugsnag/bugsnag-go"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)
// initializes a viper object with test configuration
func configure(jsonConfig string) *viper.Viper {
config := viper.New()
config.SetConfigType("json")
config.ReadConfig(bytes.NewBuffer([]byte(jsonConfig)))
return config
}
// An error is returned if the log level is not parsable
func TestParseInvalidLogLevel(t *testing.T) {
_, err := ParseLogLevel(configure(`{"logging": {"level": "horatio"}}`),
logrus.DebugLevel)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not a valid logrus Level")
}
// If there is no logging level configured it is set to the default level
func TestParseNoLogLevel(t *testing.T) {
empties := []string{`{}`, `{"logging": {}}`}
for _, configJSON := range empties {
lvl, err := ParseLogLevel(configure(configJSON), logrus.DebugLevel)
assert.NoError(t, err)
assert.Equal(t, logrus.DebugLevel, lvl)
}
}
// If there is logging level configured, it is set to the configured one
func TestParseLogLevel(t *testing.T) {
lvl, err := ParseLogLevel(configure(`{"logging": {"level": "error"}}`),
logrus.DebugLevel)
assert.NoError(t, err)
assert.Equal(t, logrus.ErrorLevel, lvl)
}
// An error is returned if there's no API key
func TestParseInvalidBugsnag(t *testing.T) {
_, err := ParseBugsnag(configure(
`{"reporting": {"bugsnag": {"endpoint": "http://12345"}}}`))
assert.Error(t, err)
assert.Contains(t, err.Error(), "must provide an API key")
}
// If there's no bugsnag, a nil pointer is returned
func TestParseNoBugsnag(t *testing.T) {
empties := []string{`{}`, `{"reporting": {}}`}
for _, configJSON := range empties {
bugconf, err := ParseBugsnag(configure(configJSON))
assert.NoError(t, err)
assert.Nil(t, bugconf)
}
}
func TestParseBugsnag(t *testing.T) {
config := configure(`{
"reporting": {
"bugsnag": {
"api_key": "12345",
"release_stage": "production",
"endpoint": "http://1234.com"
}
}
}`)
expected := bugsnag.Configuration{
APIKey: "12345",
ReleaseStage: "production",
Endpoint: "http://1234.com",
}
bugconf, err := ParseBugsnag(config)
assert.NoError(t, err)
assert.Equal(t, expected, *bugconf)
}
// If the storage parameters are invalid, an error is returned
func TestParseInvalidStorage(t *testing.T) {
invalids := []string{
`{"storage": {"backend": "memow", "db_url": "1234"}}`,
`{"storage": {"db_url": "12345"}}`,
`{"storage": {"backend": "mysql"}}`,
`{"storage": {"backend": "mysql", "db_url": ""}}`,
}
for _, configJSON := range invalids {
_, err := ParseStorage(configure(configJSON))
assert.Error(t, err, fmt.Sprintf("'%s' should be an error", configJSON))
if strings.Contains(configJSON, "mysql") {
assert.Contains(t, err.Error(),
"must provide a non-empty database URL")
} else {
assert.Contains(t, err.Error(),
"must specify one of these supported backends: mysql")
}
}
}
// If there is no storage, a nil pointer is returned
func TestParseNoStorage(t *testing.T) {
empties := []string{`{}`, `{"storage": {}}`}
for _, configJSON := range empties {
store, err := ParseStorage(configure(configJSON))
assert.NoError(t, err)
assert.Nil(t, store)
}
}
func TestParseStorage(t *testing.T) {
config := configure(`{
"storage": {
"backend": "MySQL",
"db_url": "username:passord@tcp(hostname:1234)/dbname"
}
}`)
expected := Storage{
Backend: "mysql",
URL: "username:passord@tcp(hostname:1234)/dbname",
}
store, err := ParseStorage(config)
assert.NoError(t, err)
assert.Equal(t, expected, *store)
}
// If the server section is missing or missing HTTP/GRPC addresses, an error is
// returned
func TestParseInvalidOrNoServer(t *testing.T) {
invalids := []string{`{}`, `{"server": {}}`}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "must have an HTTP and/or GRPC address")
}
}
// If TLS is required and the parameters are missing, an error is returned
func TestParseInvalidServerNoTLSWhenRequired(t *testing.T) {
invalids := []string{
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), true)
assert.Error(t, err)
assert.Contains(t, err.Error(),
"both the TLS certificate and key are mandatory")
}
}
// If TLS is not and the cert/key are partially provided, an error is returned
func TestParseInvalidServerPartialTLS(t *testing.T) {
invalids := []string{
`{"server": {"http_addr": ":443", "tls_cert_file": "path/to/cert"}}`,
`{"server": {"http_addr": ":443", "tls_key_file": "path/to/key"}}`,
}
for _, configJSON := range invalids {
_, err := ParseServer(configure(configJSON), false)
assert.Error(t, err)
assert.Contains(t, err.Error(),
"either include both a cert and key file, or neither to disable TLS")
}
}
func TestParseServerNoTLS(t *testing.T) {
config := configure(`{
"server": {
"http_addr": ":4443",
"grpc_addr": ":7899"
}
}`)
expected := Server{
HTTPAddr: ":4443",
GRPCAddr: ":7899",
ServerTLSOpts: nil,
}
server, err := ParseServer(config, false)
assert.NoError(t, err)
assert.Equal(t, expected, *server)
}
func TestUnmarshalConfigServerWithTLS(t *testing.T) {
config := configure(`{
"server": {
"http_addr": ":4443",
"grpc_addr": ":7899",
"tls_cert_file": "path/to/cert",
"tls_key_file": "path/to/key",
"client_ca_file": "path/to/clientca"
}
}`)
expected := Server{
HTTPAddr: ":4443",
GRPCAddr: ":7899",
ServerTLSOpts: &ServerTLSOpts{
ServerCertFile: "path/to/cert",
ServerKeyFile: "path/to/key",
ClientCAFile: "path/to/clientca",
},
}
server, err := ParseServer(config, false)
assert.NoError(t, err)
assert.Equal(t, expected, *server)
}

View File

@ -44,10 +44,9 @@ func poolFromFile(filename string) (*x509.CertPool, error) {
// ServerTLSOpts generates a tls configuration for servers using the
// provided parameters.
type ServerTLSOpts struct {
ServerCertFile string
ServerKeyFile string
RequireClientAuth bool
ClientCAFile string
ServerCertFile string `mapstructure:"tls_cert_file"`
ServerKeyFile string `mapstructure:"tls_key_file"`
ClientCAFile string `mapstructure:"client_ca_file"`
}
// ConfigureServerTLS specifies a set of ciphersuites, the server cert and key,
@ -73,16 +72,13 @@ func ConfigureServerTLS(opts *ServerTLSOpts) (*tls.Config, error) {
Rand: rand.Reader,
}
if opts.RequireClientAuth {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
if opts.ClientCAFile != "" {
pool, err := poolFromFile(opts.ClientCAFile)
if err != nil {
return nil, err
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
return tlsConfig, nil
@ -91,11 +87,11 @@ func ConfigureServerTLS(opts *ServerTLSOpts) (*tls.Config, error) {
// ClientTLSOpts is a struct that contains options to pass to
// ConfigureClientTLS
type ClientTLSOpts struct {
RootCAFile string
ServerName string
InsecureSkipVerify bool
ClientCertFile string
ClientKeyFile string
RootCAFile string `json:"tls_ca_file"`
ServerName string `json:"hostname"`
InsecureSkipVerify bool `json:"-"`
ClientCertFile string `json:"tls_client_cert"`
ClientKeyFile string `json:"tls_client_key"`
}
// ConfigureClientTLS generates a tls configuration for clients using the

View File

@ -60,10 +60,9 @@ func TestConfigServerTLSFailsIfUnableToLoadCerts(t *testing.T) {
files[i] = "not-real-file"
result, err := ConfigureServerTLS(&ServerTLSOpts{
ServerCertFile: files[0],
ServerKeyFile: files[1],
RequireClientAuth: true,
ClientCAFile: files[2],
ServerCertFile: files[0],
ServerKeyFile: files[1],
ClientCAFile: files[2],
})
assert.Nil(t, result)
assert.Error(t, err)
@ -106,7 +105,7 @@ func TestConfigServerTLSWithEmptyCACertFile(t *testing.T) {
// If server cert and key are provided, and client cert file is provided with
// one cert, a valid tls.Config is returned with the clientCAs set to that
// cert.
// cert. ClientAuth is set to RequireAndVerifyClientCert.
func TestConfigServerTLSWithOneCACert(t *testing.T) {
keypair, err := tls.LoadX509KeyPair(ServerCert, ServerKey)
assert.NoError(t, err)
@ -119,13 +118,13 @@ func TestConfigServerTLSWithOneCACert(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []tls.Certificate{keypair}, tlsConfig.Certificates)
assert.True(t, tlsConfig.PreferServerCipherSuites)
assert.Equal(t, tls.NoClientCert, tlsConfig.ClientAuth)
assert.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth)
assert.Len(t, tlsConfig.ClientCAs.Subjects(), 1)
}
// If server cert and key are provided, and client cert file is provided with
// multiple certs, a valid tls.Config is returned with the clientCAs set to
// the valid cert.
// the valid cert. ClientAuth is set to RequireAndVerifyClientCert.
func TestConfigServerTLSWithMultipleCACerts(t *testing.T) {
tempFilename := generateMultiCert(t)
defer os.RemoveAll(tempFilename)
@ -141,27 +140,8 @@ func TestConfigServerTLSWithMultipleCACerts(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []tls.Certificate{keypair}, tlsConfig.Certificates)
assert.True(t, tlsConfig.PreferServerCipherSuites)
assert.Equal(t, tls.NoClientCert, tlsConfig.ClientAuth)
assert.Len(t, tlsConfig.ClientCAs.Subjects(), 2)
}
// If server cert and key are provided, and client auth is disabled, then
// a valid tls.Config is returned with ClientAuth set to
// RequireAndVerifyClientCert
func TestConfigServerTLSClientAuthEnabled(t *testing.T) {
keypair, err := tls.LoadX509KeyPair(ServerCert, ServerKey)
assert.NoError(t, err)
tlsConfig, err := ConfigureServerTLS(&ServerTLSOpts{
ServerCertFile: ServerCert,
ServerKeyFile: ServerKey,
RequireClientAuth: true,
})
assert.NoError(t, err)
assert.Equal(t, []tls.Certificate{keypair}, tlsConfig.Certificates)
assert.True(t, tlsConfig.PreferServerCipherSuites)
assert.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth)
assert.Nil(t, tlsConfig.ClientCAs)
assert.Len(t, tlsConfig.ClientCAs.Subjects(), 2)
}
// The skipVerify boolean gets set on the tls.Config's InsecureSkipBoolean