certs: check if remote is valid and regenerate if not

Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
This commit is contained in:
Evan Hazlett 2015-03-12 00:25:43 -04:00
parent 89ea1ed4b4
commit d221d8ee97
5 changed files with 177 additions and 15 deletions

View File

@ -36,8 +36,11 @@ type machineConfig struct {
machineName string machineName string
machineDir string machineDir string
caCertPath string caCertPath string
caKeyPath string
clientCertPath string clientCertPath string
clientKeyPath string clientKeyPath string
serverCertPath string
serverKeyPath string
machineUrl string machineUrl string
swarmMaster bool swarmMaster bool
swarmHost string swarmHost string
@ -68,6 +71,25 @@ func (h hostListItemByName) Less(i, j int) bool {
return strings.ToLower(h[i].Name) < strings.ToLower(h[j].Name) return strings.ToLower(h[i].Name) < strings.ToLower(h[j].Name)
} }
func confirmInput(msg string) bool {
fmt.Printf("%s (y/n): ", msg)
var resp string
_, err := fmt.Scanln(&resp)
if err != nil {
log.Fatal(err)
}
if strings.Index(strings.ToLower(resp), "y") == 0 {
return true
}
return false
}
func setupCertificates(caCertPath, caKeyPath, clientCertPath, clientKeyPath string) error { func setupCertificates(caCertPath, caKeyPath, clientCertPath, clientKeyPath string) error {
org := utils.GetUsername() org := utils.GetUsername()
bits := 2048 bits := 2048
@ -207,6 +229,18 @@ var Commands = []cli.Command{
Usage: "List machines", Usage: "List machines",
Action: cmdLs, Action: cmdLs,
}, },
{
Name: "regenerate-certs",
Usage: "Regenerate TLS Certificates for a machine",
Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.",
Action: cmdRegenerateCerts,
Flags: []cli.Flag{
cli.BoolFlag{
Name: "force, f",
Usage: "Force rebuild and do not prompt",
},
},
},
{ {
Name: "restart", Name: "restart",
Usage: "Restart a machine", Usage: "Restart a machine",
@ -369,6 +403,33 @@ func cmdConfig(c *cli.Context) {
dockerHost = fmt.Sprintf("tcp://%s:%s", machineIp, swarmPort) dockerHost = fmt.Sprintf("tcp://%s:%s", machineIp, swarmPort)
} }
u, err := url.Parse(cfg.machineUrl)
if err != nil {
log.Fatal(err)
}
if u.Scheme != "unix" {
// validate cert and regenerate if needed
valid, err := utils.ValidateCertificate(
u.Host,
cfg.caCertPath,
cfg.serverCertPath,
cfg.serverKeyPath,
)
if err != nil {
log.Fatal(err)
}
if !valid {
log.Debugf("invalid certs detected; regenerating for %s", u.Host)
if err := runActionWithContext("configureAuth", c); err != nil {
log.Fatal(err)
}
}
}
fmt.Printf("--tlsverify --tlscacert=%s --tlscert=%s --tlskey=%s -H=%s", fmt.Printf("--tlsverify --tlscacert=%s --tlscert=%s --tlskey=%s -H=%s",
cfg.caCertPath, cfg.clientCertPath, cfg.clientKeyPath, dockerHost) cfg.caCertPath, cfg.clientCertPath, cfg.clientKeyPath, dockerHost)
} }
@ -459,6 +520,16 @@ func cmdLs(c *cli.Context) {
w.Flush() w.Flush()
} }
func cmdRegenerateCerts(c *cli.Context) {
force := c.Bool("force")
if force || confirmInput("Regenerate TLS machine certs? Warning: this is irreversible.") {
log.Infof("Regenerating TLS certificates")
if err := runActionWithContext("configureAuth", c); err != nil {
log.Fatal(err)
}
}
}
func cmdRm(c *cli.Context) { func cmdRm(c *cli.Context) {
if len(c.Args()) == 0 { if len(c.Args()) == 0 {
cli.ShowCommandHelp(c, "rm") cli.ShowCommandHelp(c, "rm")
@ -521,6 +592,32 @@ func cmdEnv(c *cli.Context) {
dockerHost = fmt.Sprintf("tcp://%s:%s", machineIp, swarmPort) dockerHost = fmt.Sprintf("tcp://%s:%s", machineIp, swarmPort)
} }
u, err := url.Parse(cfg.machineUrl)
if err != nil {
log.Fatal(err)
}
if u.Scheme != "unix" {
// validate cert and regenerate if needed
valid, err := utils.ValidateCertificate(
u.Host,
cfg.caCertPath,
cfg.serverCertPath,
cfg.serverKeyPath,
)
if err != nil {
log.Fatal(err)
}
if !valid {
log.Debugf("invalid certs detected; regenerating for %s", u.Host)
if err := runActionWithContext("configureAuth", c); err != nil {
log.Fatal(err)
}
}
}
switch userShell { switch userShell {
case "fish": case "fish":
fmt.Printf("set -x DOCKER_TLS_VERIFY 1;\nset -x DOCKER_CERT_PATH %s;\nset -x DOCKER_HOST %s;\n", fmt.Printf("set -x DOCKER_TLS_VERIFY 1;\nset -x DOCKER_CERT_PATH %s;\nset -x DOCKER_HOST %s;\n",
@ -574,11 +671,12 @@ func cmdSsh(c *cli.Context) {
// We run commands concurrently and communicate back an error if there was one. // We run commands concurrently and communicate back an error if there was one.
func machineCommand(actionName string, machine *Host, errorChan chan<- error) { func machineCommand(actionName string, machine *Host, errorChan chan<- error) {
commands := map[string](func() error){ commands := map[string](func() error){
"start": machine.Start, "configureAuth": machine.ConfigureAuth,
"stop": machine.Stop, "start": machine.Start,
"restart": machine.Restart, "stop": machine.Stop,
"kill": machine.Kill, "restart": machine.Restart,
"upgrade": machine.Upgrade, "kill": machine.Kill,
"upgrade": machine.Upgrade,
} }
log.Debugf("command=%s machine=%s", actionName, machine.Name) log.Debugf("command=%s machine=%s", actionName, machine.Name)
@ -811,8 +909,11 @@ func getMachineConfig(c *cli.Context) (*machineConfig, error) {
machineDir := filepath.Join(utils.GetMachineDir(), machine.Name) machineDir := filepath.Join(utils.GetMachineDir(), machine.Name)
caCert := filepath.Join(machineDir, "ca.pem") caCert := filepath.Join(machineDir, "ca.pem")
caKey := filepath.Join(utils.GetMachineCertDir(), "ca-key.pem")
clientCert := filepath.Join(machineDir, "cert.pem") clientCert := filepath.Join(machineDir, "cert.pem")
clientKey := filepath.Join(machineDir, "key.pem") clientKey := filepath.Join(machineDir, "key.pem")
serverCert := filepath.Join(machineDir, "server.pem")
serverKey := filepath.Join(machineDir, "server-key.pem")
machineUrl, err := machine.GetURL() machineUrl, err := machine.GetURL()
if err != nil { if err != nil {
if err == drivers.ErrHostIsNotRunning { if err == drivers.ErrHostIsNotRunning {
@ -825,8 +926,11 @@ func getMachineConfig(c *cli.Context) (*machineConfig, error) {
machineName: name, machineName: name,
machineDir: machineDir, machineDir: machineDir,
caCertPath: caCert, caCertPath: caCert,
caKeyPath: caKey,
clientCertPath: clientCert, clientCertPath: clientCert,
clientKeyPath: clientKey, clientKeyPath: clientKey,
serverCertPath: serverCert,
serverKeyPath: serverKey,
machineUrl: machineUrl, machineUrl: machineUrl,
swarmMaster: machine.SwarmMaster, swarmMaster: machine.SwarmMaster,
swarmHost: machine.SwarmHost, swarmHost: machine.SwarmHost,

View File

@ -460,6 +460,16 @@ foo3 virtualbox Running tcp://192.168.99.108:2376
foo4 * virtualbox Running tcp://192.168.99.109:2376 foo4 * virtualbox Running tcp://192.168.99.109:2376
``` ```
#### regenerate-certs
Regenerate TLS certificates and update the machine with new certs.
```
$ docker-machine regenerate-certs
Regenerate TLS machine certs? Warning: this is irreversible. (y/n): y
INFO[0013] Regenerating TLS certificates
```
#### restart #### restart
Restart a machine. Oftentimes this is equivalent to Restart a machine. Oftentimes this is equivalent to

View File

@ -245,7 +245,7 @@ func (h *Host) StopDocker() error {
switch h.Driver.GetProviderType() { switch h.Driver.GetProviderType() {
case provider.Local: case provider.Local:
cmd, err = h.GetSSHCommand("if [ -e /var/run/docker.pid ]; then sudo /etc/init.d/docker stop ; fi") cmd, err = h.GetSSHCommand("if [ -e /var/run/docker.pid ] && [ -d /proc/$(cat /var/run/docker.pid) ]; then sudo /etc/init.d/docker stop ; exit 0; fi")
case provider.Remote: case provider.Remote:
cmd, err = h.GetSSHCommand("sudo service docker stop") cmd, err = h.GetSSHCommand("sudo service docker stop")
default: default:
@ -364,7 +364,7 @@ func (h *Host) ConfigureAuth() error {
} }
machineServerKeyPath := path.Join(dockerDir, "server-key.pem") machineServerKeyPath := path.Join(dockerDir, "server-key.pem")
cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee -a %s", string(caCert), machineCaCertPath)) cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(caCert), machineCaCertPath))
if err != nil { if err != nil {
return err return err
} }
@ -372,7 +372,7 @@ func (h *Host) ConfigureAuth() error {
return err return err
} }
cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee -a %s", string(serverKey), machineServerKeyPath)) cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverKey), machineServerKeyPath))
if err != nil { if err != nil {
return err return err
} }
@ -380,7 +380,7 @@ func (h *Host) ConfigureAuth() error {
return err return err
} }
cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee -a %s", string(serverCert), machineServerCertPath)) cmd, err = h.GetSSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverCert), machineServerCertPath))
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,12 +7,33 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"io/ioutil"
"math/big" "math/big"
"net" "net"
"os" "os"
"time" "time"
) )
func getTLSConfig(caCert, cert, key []byte, allowInsecure bool) (*tls.Config, error) {
// TLS config
var tlsConfig tls.Config
tlsConfig.InsecureSkipVerify = allowInsecure
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = certPool
keypair, err := tls.X509KeyPair(cert, key)
if err != nil {
return &tlsConfig, err
}
tlsConfig.Certificates = []tls.Certificate{keypair}
if allowInsecure {
tlsConfig.InsecureSkipVerify = true
}
return &tlsConfig, nil
}
func newCertificate(org string) (*x509.Certificate, error) { func newCertificate(org string) (*x509.Certificate, error) {
now := time.Now() now := time.Now()
// need to set notBefore slightly in the past to account for time // need to set notBefore slightly in the past to account for time
@ -149,3 +170,32 @@ func GenerateCert(hosts []string, certFile, keyFile, caFile, caKeyFile, org stri
return nil return nil
} }
func ValidateCertificate(addr, caCertPath, serverCertPath, serverKeyPath string) (bool, error) {
caCert, err := ioutil.ReadFile(caCertPath)
if err != nil {
return false, err
}
serverCert, err := ioutil.ReadFile(serverCertPath)
if err != nil {
return false, err
}
serverKey, err := ioutil.ReadFile(serverKeyPath)
if err != nil {
return false, err
}
tlsConfig, err := getTLSConfig(caCert, serverCert, serverKey, false)
if err != nil {
return false, err
}
_, err = tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return false, nil
}
return true, nil
}

View File

@ -12,6 +12,8 @@ func TestGenerateCACertificate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// cleanup
defer os.RemoveAll(tmpDir)
os.Setenv("MACHINE_DIR", tmpDir) os.Setenv("MACHINE_DIR", tmpDir)
caCertPath := filepath.Join(tmpDir, "ca.pem") caCertPath := filepath.Join(tmpDir, "ca.pem")
@ -29,9 +31,6 @@ func TestGenerateCACertificate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
os.Setenv("MACHINE_DIR", "") os.Setenv("MACHINE_DIR", "")
// cleanup
_ = os.RemoveAll(tmpDir)
} }
func TestGenerateCert(t *testing.T) { func TestGenerateCert(t *testing.T) {
@ -39,6 +38,8 @@ func TestGenerateCert(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// cleanup
defer os.RemoveAll(tmpDir)
os.Setenv("MACHINE_DIR", tmpDir) os.Setenv("MACHINE_DIR", tmpDir)
caCertPath := filepath.Join(tmpDir, "ca.pem") caCertPath := filepath.Join(tmpDir, "ca.pem")
@ -70,7 +71,4 @@ func TestGenerateCert(t *testing.T) {
if _, err := os.Stat(keyPath); err != nil { if _, err := os.Stat(keyPath); err != nil {
t.Fatalf("key not created at %s", keyPath) t.Fatalf("key not created at %s", keyPath)
} }
// cleanup
_ = os.RemoveAll(tmpDir)
} }