diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index c206607b8d..d581c9988e 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" _ "expvar" "flag" "fmt" @@ -24,6 +25,7 @@ import ( "github.com/docker/notary/server" "github.com/docker/notary/server/storage" "github.com/docker/notary/signer" + "github.com/docker/notary/utils" "github.com/docker/notary/version" "github.com/spf13/viper" ) @@ -46,6 +48,28 @@ func init() { flag.BoolVar(&debug, "debug", false, "Enable the debugging server on localhost:8080") } +// optionally sets up TLS for the server - if no TLS configuration is +// specified, TLS is not enabled. +func serverTLS(configuration *viper.Viper) (*tls.Config, error) { + tlsCertFile := configuration.GetString("server.tls_cert_file") + tlsKeyFile := configuration.GetString("server.tls_key_file") + + if tlsCertFile == "" && tlsKeyFile == "" { + return nil, nil + } else if tlsCertFile == "" || tlsKeyFile == "" { + return nil, fmt.Errorf("Partial TLS configuration found. Either include both a cert and key file in the configuration, or include neither to disable TLS.") + } + + tlsConfig, err := utils.ConfigureServerTLS(&utils.ServerTLSOpts{ + ServerCertFile: tlsCertFile, + ServerKeyFile: tlsKeyFile, + }) + if err != nil { + return nil, fmt.Errorf("Unable to set up TLS: %s", err.Error()) + } + return tlsConfig, nil +} + func main() { flag.Usage = usage flag.Parse() @@ -151,12 +175,17 @@ func main() { logrus.Debug("Using memory backend") ctx = context.WithValue(ctx, "metaStore", storage.NewMemStorage()) } + + tlsConfig, err := serverTLS(mainViper) + if err != nil { + logrus.Fatal(err.Error()) + } + logrus.Info("Starting Server") err = server.Run( ctx, mainViper.GetString("server.addr"), - mainViper.GetString("server.tls_cert_file"), - mainViper.GetString("server.tls_key_file"), + tlsConfig, trust, mainViper.GetString("auth.type"), mainViper.Get("auth.options"), diff --git a/cmd/notary-server/main_test.go b/cmd/notary-server/main_test.go new file mode 100644 index 0000000000..bd04a125dc --- /dev/null +++ b/cmd/notary-server/main_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "bytes" + "crypto/tls" + "fmt" + "strings" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +const ( + Cert = "../../fixtures/notary-server.crt" + Key = "../../fixtures/notary-server.key" + Root = "../../fixtures/root-ca.crt" +) + +// initializes a viper object with test configuration +func configure(jsonConfig []byte) *viper.Viper { + config := viper.New() + config.SetConfigType("json") + config.ReadConfig(bytes.NewBuffer(jsonConfig)) + return config +} + +// If neither the cert nor the key are provided, a nil tls config is returned. +func TestServerTLSMissingCertAndKey(t *testing.T) { + tlsConfig, err := serverTLS(configure([]byte(`{"server": {}}`))) + assert.NoError(t, err) + assert.Nil(t, tlsConfig) +} + +func TestServerTLSMissingCertAndOrKey(t *testing.T) { + configs := []string{ + fmt.Sprintf(`{"tls_cert_file": "%s"}`, Cert), + fmt.Sprintf(`{"tls_key_file": "%s"}`, Key), + } + for _, serverConfig := range configs { + config := configure( + []byte(fmt.Sprintf(`{"server": %s}`, serverConfig))) + tlsConfig, err := serverTLS(config) + assert.Error(t, err) + assert.Nil(t, tlsConfig) + assert.True(t, + strings.Contains(err.Error(), "Partial TLS configuration found.")) + } +} + +// The rest of the functionality of serverTLS depends upon +// utils.ConfigureServerTLS, so this test just asserts that if successful, +// the correct tls.Config is returned based on all the configuration parameters +func TestServerTLSSuccess(t *testing.T) { + keypair, err := tls.LoadX509KeyPair(Cert, Key) + assert.NoError(t, err, "Unable to load cert and key for testing") + + config := fmt.Sprintf( + `{"server": {"tls_cert_file": "%s", "tls_key_file": "%s"}}`, + Cert, Key) + tlsConfig, err := serverTLS(configure([]byte(config))) + assert.NoError(t, err) + assert.Equal(t, []tls.Certificate{keypair}, tlsConfig.Certificates) +} + +// The rest of the functionality of singerTLS depends upon +// utils.ConfigureServerTLS, so this test just asserts that if it fails, +// the error is propogated. +func TestServerTLSFailure(t *testing.T) { + config := fmt.Sprintf( + `{"server": {"tls_cert_file": "non-exist", "tls_key_file": "%s"}}`, + Key) + tlsConfig, err := serverTLS(configure([]byte(config))) + assert.Error(t, err) + assert.Nil(t, tlsConfig) + assert.True(t, strings.Contains(err.Error(), "Unable to set up TLS")) +} diff --git a/server/server.go b/server/server.go index 670b2d153a..c6a17d8797 100644 --- a/server/server.go +++ b/server/server.go @@ -29,7 +29,7 @@ func init() { // Run sets up and starts a TLS server that can be cancelled using the // given configuration. The context it is passed is the context it should // use directly for the TLS server, and generate children off for requests -func Run(ctx context.Context, addr, tlsCertFile, tlsKeyFile string, trust signed.CryptoService, authMethod string, authOpts interface{}) error { +func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.CryptoService, authMethod string, authOpts interface{}) error { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { @@ -41,18 +41,9 @@ func Run(ctx context.Context, addr, tlsCertFile, tlsKeyFile string, trust signed return err } - if tlsCertFile != "" && tlsKeyFile != "" { - tlsConfig, err := utils.ConfigureServerTLS(&utils.ServerTLSOpts{ - ServerCertFile: tlsCertFile, - ServerKeyFile: tlsKeyFile, - }) - if err != nil { - return err - } + if tlsConfig != nil { logrus.Info("Enabling TLS") lsnr = tls.NewListener(lsnr, tlsConfig) - } else if tlsCertFile != "" || tlsKeyFile != "" { - return fmt.Errorf("Partial TLS configuration found. Either include both a cert and key file in the configuration, or include neither to disable TLS.") } var ac auth.AccessController diff --git a/server/server_test.go b/server/server_test.go index 759b5ad49e..a90cbcf8eb 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -14,8 +14,7 @@ func TestRunBadAddr(t *testing.T) { err := Run( context.Background(), "testAddr", - "../fixtures/notary-server.crt", - "../fixtures/notary-server.crt", + nil, signed.NewEd25519(), "", nil, @@ -31,8 +30,7 @@ func TestRunReservedPort(t *testing.T) { err := Run( ctx, "localhost:80", - "../fixtures/notary-server.crt", - "../fixtures/notary-server.crt", + nil, signed.NewEd25519(), "", nil,