diff --git a/commands/config.go b/commands/config.go index 311b40b796..d6980b2a54 100644 --- a/commands/config.go +++ b/commands/config.go @@ -14,6 +14,19 @@ import ( "github.com/docker/machine/libmachine/state" ) +// For when the cert is computed to be invalid. +type ErrCertInvalid struct { + wrappedErr error + hostUrl string +} + +func (e ErrCertInvalid) Error() string { + return fmt.Sprintf(`There was an error validating certificates for host %q: %s +You can attempt to regenerate them using 'docker-machine regenerate-certs name'. +Be advised that this will trigger a Docker daemon restart which will stop running containers. +`, e.hostUrl, e.wrappedErr) +} + func cmdConfig(c *cli.Context) error { // Ensure that log messages always go to stderr when this command is // being run (it is intended to be run in a subshell) @@ -72,29 +85,19 @@ func runConnectionBoilerplate(h *host.Host, c *cli.Context) (string, *auth.AuthO authOptions := h.HostOptions.AuthOptions - if err := checkCert(u.Host, authOptions, c); err != nil { + if err := checkCert(u.Host, authOptions); err != nil { return "", &auth.AuthOptions{}, fmt.Errorf("Error checking and/or regenerating the certs: %s", err) } return dockerHost, authOptions, nil } -func checkCert(hostUrl string, authOptions *auth.AuthOptions, c *cli.Context) error { - valid, err := cert.ValidateCertificate( - hostUrl, - authOptions.CaCertPath, - authOptions.ServerCertPath, - authOptions.ServerKeyPath, - ) - if err != nil { - return fmt.Errorf("Error attempting to validate the certificates: %s", err) - } - - if !valid { - log.Errorf("Invalid certs detected; regenerating for %s", hostUrl) - - if err := runActionWithContext("configureAuth", c); err != nil { - return fmt.Errorf("Error attempting to regenerate the certs: %s", err) +func checkCert(hostUrl string, authOptions *auth.AuthOptions) error { + valid, err := cert.ValidateCertificate(hostUrl, authOptions) + if !valid || err != nil { + return ErrCertInvalid{ + wrappedErr: err, + hostUrl: hostUrl, } } diff --git a/commands/config_test.go b/commands/config_test.go new file mode 100644 index 0000000000..ae1389dbf8 --- /dev/null +++ b/commands/config_test.go @@ -0,0 +1,54 @@ +package commands + +import ( + "errors" + "testing" + + "github.com/docker/machine/libmachine/auth" + "github.com/docker/machine/libmachine/cert" + "github.com/stretchr/testify/assert" +) + +type FakeValidateCertificate struct { + IsValid bool + Err error +} + +type FakeCertGenerator struct { + fakeValidateCertificate *FakeValidateCertificate +} + +func (fcg FakeCertGenerator) GenerateCACertificate(certFile, keyFile, org string, bits int) error { + return nil +} + +func (fcg FakeCertGenerator) GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org string, bits int) error { + return nil +} + +func (fcg FakeCertGenerator) ValidateCertificate(addr string, authOptions *auth.AuthOptions) (bool, error) { + return fcg.fakeValidateCertificate.IsValid, fcg.fakeValidateCertificate.Err +} + +func TestCheckCert(t *testing.T) { + errCertsExpired := errors.New("Certs have expired") + + cases := []struct { + hostUrl string + authOptions *auth.AuthOptions + valid bool + checkErr error + expectedErr error + }{ + {"192.168.99.100:2376", &auth.AuthOptions{}, true, nil, nil}, + {"192.168.99.100:2376", &auth.AuthOptions{}, false, nil, ErrCertInvalid{wrappedErr: nil, hostUrl: "192.168.99.100:2376"}}, + {"192.168.99.100:2376", &auth.AuthOptions{}, false, errCertsExpired, ErrCertInvalid{wrappedErr: errCertsExpired, hostUrl: "192.168.99.100:2376"}}, + } + + for _, c := range cases { + fcg := FakeCertGenerator{fakeValidateCertificate: &FakeValidateCertificate{c.valid, c.checkErr}} + cert.SetCertGenerator(fcg) + err := checkCert(c.hostUrl, c.authOptions) + assert.Equal(t, c.expectedErr, err) + } +} diff --git a/libmachine/cert/cert.go b/libmachine/cert/cert.go index f7c8b06902..f3457d4398 100644 --- a/libmachine/cert/cert.go +++ b/libmachine/cert/cert.go @@ -7,7 +7,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "io/ioutil" "math/big" "net" @@ -16,18 +15,41 @@ import ( "errors" + "github.com/docker/machine/libmachine/auth" "github.com/docker/machine/libmachine/log" ) -type ErrValidatingCert struct { - wrappedErr error +var defaultGenerator = NewX509CertGenerator() + +type CertGenerator interface { + GenerateCACertificate(certFile, keyFile, org string, bits int) error + GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org string, bits int) error + ValidateCertificate(addr string, authOptions *auth.AuthOptions) (bool, error) } -func (e ErrValidatingCert) Error() string { - return fmt.Sprintf("There was an error validating the cert: %s", e.wrappedErr) +type X509CertGenerator struct{} + +func NewX509CertGenerator() CertGenerator { + return &X509CertGenerator{} } -func getTLSConfig(caCert, cert, key []byte, allowInsecure bool) (*tls.Config, error) { +func GenerateCACertificate(certFile, keyFile, org string, bits int) error { + return defaultGenerator.GenerateCACertificate(certFile, keyFile, org, bits) +} + +func GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org string, bits int) error { + return defaultGenerator.GenerateCert(hosts, certFile, keyFile, caFile, caKeyFile, org, bits) +} + +func ValidateCertificate(addr string, authOptions *auth.AuthOptions) (bool, error) { + return defaultGenerator.ValidateCertificate(addr, authOptions) +} + +func SetCertGenerator(cg CertGenerator) { + defaultGenerator = cg +} + +func (xcg *X509CertGenerator) getTLSConfig(caCert, cert, key []byte, allowInsecure bool) (*tls.Config, error) { // TLS config var tlsConfig tls.Config tlsConfig.InsecureSkipVerify = allowInsecure @@ -48,7 +70,7 @@ func getTLSConfig(caCert, cert, key []byte, allowInsecure bool) (*tls.Config, er return &tlsConfig, nil } -func newCertificate(org string) (*x509.Certificate, error) { +func (xcg *X509CertGenerator) newCertificate(org string) (*x509.Certificate, error) { now := time.Now() // need to set notBefore slightly in the past to account for time // skew in the VMs otherwise the certs sometimes are not yet valid @@ -78,8 +100,8 @@ func newCertificate(org string) (*x509.Certificate, error) { // GenerateCACertificate generates a new certificate authority from the specified org // and bit size and stores the resulting certificate and key file // in the arguments. -func GenerateCACertificate(certFile, keyFile, org string, bits int) error { - template, err := newCertificate(org) +func (xcg *X509CertGenerator) GenerateCACertificate(certFile, keyFile, org string, bits int) error { + template, err := xcg.newCertificate(org) if err != nil { return err } @@ -123,8 +145,8 @@ func GenerateCACertificate(certFile, keyFile, org string, bits int) error { // certificate authority files and stores the result in the certificate // file and key provided. The provided host names are set to the // appropriate certificate fields. -func GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org string, bits int) error { - template, err := newCertificate(org) +func (xcg *X509CertGenerator) GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org string, bits int) error { + template, err := xcg.newCertificate(org) if err != nil { return err } @@ -183,28 +205,32 @@ func GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org stri } // ValidateCertificate validate the certificate installed on the vm. -func ValidateCertificate(addr, caCertPath, serverCertPath, serverKeyPath string) (bool, error) { +func (xcg *X509CertGenerator) ValidateCertificate(addr string, authOptions *auth.AuthOptions) (bool, error) { + caCertPath := authOptions.CaCertPath + serverCertPath := authOptions.ServerCertPath + serverKeyPath := authOptions.ServerKeyPath + log.Debugf("Reading CA certificate from %s", caCertPath) caCert, err := ioutil.ReadFile(caCertPath) if err != nil { - return false, ErrValidatingCert{err} + return false, err } log.Debugf("Reading server certificate from %s", serverCertPath) serverCert, err := ioutil.ReadFile(serverCertPath) if err != nil { - return false, ErrValidatingCert{err} + return false, err } log.Debugf("Reading server key from %s", serverKeyPath) serverKey, err := ioutil.ReadFile(serverKeyPath) if err != nil { - return false, ErrValidatingCert{err} + return false, err } - tlsConfig, err := getTLSConfig(caCert, serverCert, serverKey, false) + tlsConfig, err := xcg.getTLSConfig(caCert, serverCert, serverKey, false) if err != nil { - return false, ErrValidatingCert{err} + return false, err } dialer := &net.Dialer{ @@ -213,8 +239,7 @@ func ValidateCertificate(addr, caCertPath, serverCertPath, serverKeyPath string) _, err = tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) if err != nil { - log.Debugf("Certificates are not valid: %s", err) - return false, nil + return false, err } return true, nil