diff --git a/host_test.go b/host_test.go index 3f05937ae2..3b2fdcb81f 100644 --- a/host_test.go +++ b/host_test.go @@ -21,11 +21,7 @@ const ( hostTestPrivateKey = "test-key" ) -var ( - tmpDir string -) - -func init() { +func getTestStore() (*Store, error) { tmpDir, err := ioutil.TempDir("", "machine-test-") if err != nil { fmt.Println(err) @@ -33,9 +29,6 @@ func init() { } os.Setenv("MACHINE_DIR", tmpDir) -} - -func getTestStore() (*Store, error) { return NewStore(tmpDir, hostTestCaCert, hostTestPrivateKey), nil } @@ -123,6 +116,14 @@ func TestValidateHostnameInvalid(t *testing.T) { } func TestGenerateClientCertificate(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "machine-test-") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + os.Setenv("MACHINE_DIR", tmpDir) + caCertPath := filepath.Join(tmpDir, "ca.pem") caKeyPath := filepath.Join(tmpDir, "key.pem") testOrg := "test-org" diff --git a/utils/certs_test.go b/utils/certs_test.go index d4b585bf78..042f47e035 100644 --- a/utils/certs_test.go +++ b/utils/certs_test.go @@ -1 +1,76 @@ package utils + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +func TestGenerateCACertificate(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "machine-test-") + if err != nil { + t.Fatal(err) + } + + os.Setenv("MACHINE_DIR", tmpDir) + caCertPath := filepath.Join(tmpDir, "ca.pem") + caKeyPath := filepath.Join(tmpDir, "key.pem") + testOrg := "test-org" + bits := 2048 + if err := GenerateCACertificate(caCertPath, caKeyPath, testOrg, bits); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(caCertPath); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(caKeyPath); err != nil { + t.Fatal(err) + } + os.Setenv("MACHINE_DIR", "") + + // cleanup + _ = os.RemoveAll(tmpDir) +} + +func TestGenerateCert(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "machine-test-") + if err != nil { + t.Fatal(err) + } + + os.Setenv("MACHINE_DIR", tmpDir) + caCertPath := filepath.Join(tmpDir, "ca.pem") + caKeyPath := filepath.Join(tmpDir, "key.pem") + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "cert-key.pem") + testOrg := "test-org" + bits := 2048 + if err := GenerateCACertificate(caCertPath, caKeyPath, testOrg, bits); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(caCertPath); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(caKeyPath); err != nil { + t.Fatal(err) + } + os.Setenv("MACHINE_DIR", "") + + if err := GenerateCert([]string{}, certPath, keyPath, caCertPath, caKeyPath, testOrg, bits); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(certPath); err != nil { + t.Fatalf("certificate not created at %s", certPath) + } + + if _, err := os.Stat(keyPath); err != nil { + t.Fatalf("key not created at %s", keyPath) + } + + // cleanup + _ = os.RemoveAll(tmpDir) +} diff --git a/utils/utils_test.go b/utils/utils_test.go index e94fc223a7..e313d51470 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -11,6 +11,7 @@ import ( ) func TestGetBaseDir(t *testing.T) { + // reset any override env var homeDir := GetHomeDir() baseDir := GetBaseDir() @@ -27,6 +28,7 @@ func TestGetCustomBaseDir(t *testing.T) { if strings.Index(root, baseDir) != 0 { t.Fatalf("expected base dir with prefix %s; received %s", root, baseDir) } + os.Setenv("MACHINE_DIR", "") } func TestGetDockerDir(t *testing.T) { @@ -45,6 +47,7 @@ func TestGetDockerDir(t *testing.T) { if filename != ".docker" { t.Fatalf("expected docker dir \".docker\"; received %s", filename) } + os.Setenv("MACHINE_DIR", "") } func TestGetMachineDir(t *testing.T) { @@ -63,6 +66,7 @@ func TestGetMachineDir(t *testing.T) { if filename != "machines" { t.Fatalf("expected machine dir \"machines\"; received %s", filename) } + os.Setenv("MACHINE_DIR", "") } func TestGetMachineClientCertDir(t *testing.T) { @@ -81,6 +85,7 @@ func TestGetMachineClientCertDir(t *testing.T) { if filename != ".client" { t.Fatalf("expected machine client dir \".client\"; received %s", filename) } + os.Setenv("MACHINE_DIR", "") } func TestCopyFile(t *testing.T) {