926 lines
26 KiB
Go
926 lines
26 KiB
Go
// Copyright The OpenTelemetry Authors
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package configtls
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"go.opentelemetry.io/collector/config/configopaque"
|
|
)
|
|
|
|
func TestNewDefaultConfig(t *testing.T) {
|
|
expectedConfig := Config{}
|
|
config := NewDefaultConfig()
|
|
require.Equal(t, expectedConfig, config)
|
|
}
|
|
|
|
func TestNewDefaultClientConfig(t *testing.T) {
|
|
expectedConfig := ClientConfig{
|
|
Config: NewDefaultConfig(),
|
|
}
|
|
config := NewDefaultClientConfig()
|
|
require.Equal(t, expectedConfig, config)
|
|
}
|
|
|
|
func TestNewDefaultServerConfig(t *testing.T) {
|
|
expectedConfig := ServerConfig{
|
|
Config: NewDefaultConfig(),
|
|
}
|
|
config := NewDefaultServerConfig()
|
|
require.Equal(t, expectedConfig, config)
|
|
}
|
|
|
|
func TestOptionsToConfig(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
options Config
|
|
expectError string
|
|
}{
|
|
{
|
|
name: "should load system CA",
|
|
options: Config{CAFile: ""},
|
|
},
|
|
{
|
|
name: "should load custom CA",
|
|
options: Config{CAFile: filepath.Join("testdata", "ca-1.crt")},
|
|
},
|
|
{
|
|
name: "should load system CA and custom CA",
|
|
options: Config{IncludeSystemCACertsPool: true, CAFile: filepath.Join("testdata", "ca-1.crt")},
|
|
},
|
|
{
|
|
name: "should fail with invalid CA file path",
|
|
options: Config{CAFile: filepath.Join("testdata", "not/valid")},
|
|
expectError: "failed to load CA",
|
|
},
|
|
{
|
|
name: "should fail with invalid CA file content",
|
|
options: Config{CAFile: filepath.Join("testdata", "testCA-bad.txt")},
|
|
expectError: "failed to parse cert",
|
|
},
|
|
{
|
|
name: "should load valid TLS settings",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
CertFile: filepath.Join("testdata", "server-1.crt"),
|
|
KeyFile: filepath.Join("testdata", "server-1.key"),
|
|
},
|
|
},
|
|
{
|
|
name: "should fail with missing TLS KeyFile",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
CertFile: filepath.Join("testdata", "server-1.crt"),
|
|
},
|
|
expectError: "provide both certificate and key, or neither",
|
|
},
|
|
{
|
|
name: "should fail with invalid TLS KeyFile",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
CertFile: filepath.Join("testdata", "server-1.crt"),
|
|
KeyFile: filepath.Join("testdata", "not/valid"),
|
|
},
|
|
expectError: "failed to load TLS cert and key",
|
|
},
|
|
{
|
|
name: "should fail with missing TLS Cert",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
KeyFile: filepath.Join("testdata", "server-1.key"),
|
|
},
|
|
expectError: "provide both certificate and key, or neither",
|
|
},
|
|
{
|
|
name: "should fail with invalid TLS Cert",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
CertFile: filepath.Join("testdata", "not/valid"),
|
|
KeyFile: filepath.Join("testdata", "server-1.key"),
|
|
},
|
|
expectError: "failed to load TLS cert and key",
|
|
},
|
|
{
|
|
name: "should fail with invalid TLS CA",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "not/valid"),
|
|
},
|
|
expectError: "failed to load CA",
|
|
},
|
|
{
|
|
name: "should fail with invalid CA pool",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "testCA-bad.txt"),
|
|
},
|
|
expectError: "failed to parse cert",
|
|
},
|
|
{
|
|
name: "should pass with valid CA pool",
|
|
options: Config{
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
},
|
|
},
|
|
{
|
|
name: "should pass with valid min and max version",
|
|
options: Config{
|
|
MinVersion: "1.1",
|
|
MaxVersion: "1.2",
|
|
},
|
|
},
|
|
{
|
|
name: "should pass with invalid min",
|
|
options: Config{
|
|
MinVersion: "1.7",
|
|
},
|
|
expectError: "invalid TLS min_",
|
|
},
|
|
{
|
|
name: "should pass with invalid max",
|
|
options: Config{
|
|
MaxVersion: "1.7",
|
|
},
|
|
expectError: "invalid TLS max_",
|
|
},
|
|
{
|
|
name: "should load custom CA PEM",
|
|
options: Config{CAPem: readFilePanics("testdata/ca-1.crt")},
|
|
},
|
|
{
|
|
name: "should load valid TLS settings with PEMs",
|
|
options: Config{
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
CertPem: readFilePanics("testdata/server-1.crt"),
|
|
KeyPem: readFilePanics("testdata/server-1.key"),
|
|
},
|
|
},
|
|
{
|
|
name: "mix Cert file and Key PEM provided",
|
|
options: Config{
|
|
CertFile: "testdata/server-1.crt",
|
|
KeyPem: readFilePanics("testdata/server-1.key"),
|
|
},
|
|
},
|
|
{
|
|
name: "mix Cert PEM and Key File provided",
|
|
options: Config{
|
|
CertPem: readFilePanics("testdata/server-1.crt"),
|
|
KeyFile: "testdata/server-1.key",
|
|
},
|
|
},
|
|
{
|
|
name: "should fail with invalid CA PEM",
|
|
options: Config{CAPem: readFilePanics("testdata/testCA-bad.txt")},
|
|
expectError: "failed to parse cert",
|
|
},
|
|
{
|
|
name: "should fail CA file and PEM both provided",
|
|
options: Config{
|
|
CAFile: "testdata/ca-1.crt",
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
},
|
|
expectError: "provide either a CA file or the PEM-encoded string, but not both",
|
|
},
|
|
{
|
|
name: "should fail Cert file and PEM both provided",
|
|
options: Config{
|
|
CertFile: "testdata/server-1.crt",
|
|
CertPem: readFilePanics("testdata/server-1.crt"),
|
|
KeyFile: "testdata/server-1.key",
|
|
},
|
|
expectError: "provide either a certificate or the PEM-encoded string, but not both",
|
|
},
|
|
{
|
|
name: "should fail Key file and PEM both provided",
|
|
options: Config{
|
|
CertFile: "testdata/server-1.crt",
|
|
KeyFile: "testdata/ca-1.crt",
|
|
KeyPem: readFilePanics("testdata/server-1.key"),
|
|
},
|
|
expectError: "provide either a key or the PEM-encoded string, but not both",
|
|
},
|
|
{
|
|
name: "should fail to load valid TLS settings with bad Cert PEM",
|
|
options: Config{
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
CertPem: readFilePanics("testdata/testCA-bad.txt"),
|
|
KeyPem: readFilePanics("testdata/server-1.key"),
|
|
},
|
|
expectError: "failed to load TLS cert and key PEMs",
|
|
},
|
|
{
|
|
name: "should fail to load valid TLS settings with bad Key PEM",
|
|
options: Config{
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
CertPem: readFilePanics("testdata/server-1.crt"),
|
|
KeyPem: readFilePanics("testdata/testCA-bad.txt"),
|
|
},
|
|
expectError: "failed to load TLS cert and key PEMs",
|
|
},
|
|
{
|
|
name: "should fail with missing TLS KeyPem",
|
|
options: Config{
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
CertPem: readFilePanics("testdata/server-1.crt"),
|
|
},
|
|
expectError: "provide both certificate and key, or neither",
|
|
},
|
|
{
|
|
name: "should fail with missing TLS Cert PEM",
|
|
options: Config{
|
|
CAPem: readFilePanics("testdata/ca-1.crt"),
|
|
KeyPem: readFilePanics("testdata/server-1.key"),
|
|
},
|
|
expectError: "provide both certificate and key, or neither",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
cfg, err := test.options.loadTLSConfig()
|
|
if test.expectError != "" {
|
|
assert.ErrorContains(t, err, test.expectError)
|
|
} else {
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cfg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func readFilePanics(filePath string) configopaque.String {
|
|
fileContents, err := os.ReadFile(filepath.Clean(filePath))
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to read file %s: %v", filePath, err))
|
|
}
|
|
|
|
return configopaque.String(fileContents)
|
|
}
|
|
|
|
func TestLoadTLSClientConfigError(t *testing.T) {
|
|
tlsSetting := ClientConfig{
|
|
Config: Config{
|
|
CertFile: "doesnt/exist",
|
|
KeyFile: "doesnt/exist",
|
|
},
|
|
}
|
|
_, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestLoadTLSClientConfig(t *testing.T) {
|
|
tlsSetting := ClientConfig{
|
|
Insecure: true,
|
|
}
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Nil(t, tlsCfg)
|
|
|
|
tlsSetting = ClientConfig{}
|
|
tlsCfg, err = tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
|
|
tlsSetting = ClientConfig{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
tlsCfg, err = tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
assert.True(t, tlsCfg.InsecureSkipVerify)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigError(t *testing.T) {
|
|
tlsSetting := ServerConfig{
|
|
Config: Config{
|
|
CertFile: "doesnt/exist",
|
|
KeyFile: "doesnt/exist",
|
|
},
|
|
}
|
|
_, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.Error(t, err)
|
|
|
|
tlsSetting = ServerConfig{
|
|
ClientCAFile: "doesnt/exist",
|
|
}
|
|
_, err = tlsSetting.LoadTLSConfig(context.Background())
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestLoadTLSServerConfig(t *testing.T) {
|
|
tlsSetting := ServerConfig{}
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigReload(t *testing.T) {
|
|
tmpCaPath := createTempClientCaFile(t)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "ca-1.crt")
|
|
|
|
tlsSetting := ServerConfig{
|
|
ClientCAFile: tmpCaPath,
|
|
ReloadClientCAFile: true,
|
|
}
|
|
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
|
|
firstClient, err := tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "ca-2.crt")
|
|
|
|
assert.Eventually(t, func() bool {
|
|
_, loadError := tlsCfg.GetConfigForClient(nil)
|
|
return loadError == nil
|
|
}, 5*time.Second, 10*time.Millisecond)
|
|
|
|
secondClient, err := tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
|
|
assert.NotEqual(t, firstClient.ClientCAs, secondClient.ClientCAs)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigFailingReload(t *testing.T) {
|
|
tmpCaPath := createTempClientCaFile(t)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "ca-1.crt")
|
|
|
|
tlsSetting := ServerConfig{
|
|
ClientCAFile: tmpCaPath,
|
|
ReloadClientCAFile: true,
|
|
}
|
|
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
|
|
firstClient, err := tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "testCA-bad.txt")
|
|
|
|
assert.Eventually(t, func() bool {
|
|
_, loadError := tlsCfg.GetConfigForClient(nil)
|
|
return loadError == nil
|
|
}, 5*time.Second, 10*time.Millisecond)
|
|
|
|
secondClient, err := tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, firstClient.ClientCAs, secondClient.ClientCAs)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigFailingInitialLoad(t *testing.T) {
|
|
tmpCaPath := createTempClientCaFile(t)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "testCA-bad.txt")
|
|
|
|
tlsSetting := ServerConfig{
|
|
ClientCAFile: tmpCaPath,
|
|
ReloadClientCAFile: true,
|
|
}
|
|
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.Error(t, err)
|
|
assert.Nil(t, tlsCfg)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigWrongPath(t *testing.T) {
|
|
tmpCaPath := createTempClientCaFile(t)
|
|
|
|
tlsSetting := ServerConfig{
|
|
ClientCAFile: tmpCaPath + "wrong-path",
|
|
ReloadClientCAFile: true,
|
|
}
|
|
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.Error(t, err)
|
|
assert.Nil(t, tlsCfg)
|
|
}
|
|
|
|
func TestLoadTLSServerConfigFailing(t *testing.T) {
|
|
tmpCaPath := createTempClientCaFile(t)
|
|
|
|
overwriteClientCA(t, tmpCaPath, "ca-1.crt")
|
|
|
|
tlsSetting := ServerConfig{
|
|
ClientCAFile: tmpCaPath,
|
|
ReloadClientCAFile: true,
|
|
}
|
|
|
|
tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, tlsCfg)
|
|
|
|
firstClient, err := tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, firstClient)
|
|
|
|
err = os.Remove(tmpCaPath)
|
|
require.NoError(t, err)
|
|
|
|
firstClient, err = tlsCfg.GetConfigForClient(nil)
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, firstClient)
|
|
}
|
|
|
|
func overwriteClientCA(t *testing.T, targetFilePath, testdataFileName string) {
|
|
targetFile, err := os.OpenFile(filepath.Clean(targetFilePath), os.O_RDWR, 0o600)
|
|
require.NoError(t, err)
|
|
|
|
testdataFilePath := filepath.Join("testdata", testdataFileName)
|
|
testdataFile, err := os.OpenFile(filepath.Clean(testdataFilePath), os.O_RDONLY, 0o200)
|
|
require.NoError(t, err)
|
|
|
|
_, err = io.Copy(targetFile, testdataFile)
|
|
assert.NoError(t, err)
|
|
|
|
assert.NoError(t, targetFile.Close())
|
|
assert.NoError(t, testdataFile.Close())
|
|
}
|
|
|
|
func createTempClientCaFile(t *testing.T) string {
|
|
tmpCa, err := os.CreateTemp(t.TempDir(), "ca-tmp.crt")
|
|
require.NoError(t, err)
|
|
tmpCaPath, err := filepath.Abs(tmpCa.Name())
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, tmpCa.Close())
|
|
return tmpCaPath
|
|
}
|
|
|
|
func TestEagerlyLoadCertificate(t *testing.T) {
|
|
options := Config{
|
|
CertFile: filepath.Join("testdata", "client-1.crt"),
|
|
KeyFile: filepath.Join("testdata", "client-1.key"),
|
|
}
|
|
cfg, err := options.loadTLSConfig()
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cfg)
|
|
cert, err := cfg.GetCertificate(&tls.ClientHelloInfo{})
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cert)
|
|
pCert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, pCert)
|
|
assert.ElementsMatch(t, []string{"example1"}, pCert.DNSNames)
|
|
}
|
|
|
|
func TestCertificateReload(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reloadInterval time.Duration
|
|
wait time.Duration
|
|
cert2 string
|
|
key2 string
|
|
dns1 string
|
|
dns2 string
|
|
errText string
|
|
}{
|
|
{
|
|
name: "Should reload the certificate after reload-interval",
|
|
reloadInterval: 100 * time.Microsecond,
|
|
wait: 100 * time.Microsecond,
|
|
cert2: "client-2.crt",
|
|
key2: "client-2.key",
|
|
dns1: "example1",
|
|
dns2: "example2",
|
|
},
|
|
{
|
|
name: "Should return same cert if called before reload-interval",
|
|
reloadInterval: 100 * time.Millisecond,
|
|
wait: 100 * time.Microsecond,
|
|
cert2: "client-2.crt",
|
|
key2: "client-2.key",
|
|
dns1: "example1",
|
|
dns2: "example1",
|
|
},
|
|
{
|
|
name: "Should always return same cert if reload-interval is 0",
|
|
reloadInterval: 0,
|
|
wait: 100 * time.Microsecond,
|
|
cert2: "client-2.crt",
|
|
key2: "client-2.key",
|
|
dns1: "example1",
|
|
dns2: "example1",
|
|
},
|
|
{
|
|
name: "Should return an error if reloading fails",
|
|
reloadInterval: 100 * time.Microsecond,
|
|
wait: 100 * time.Microsecond,
|
|
cert2: "testCA-bad.txt",
|
|
key2: "client-2.key",
|
|
dns1: "example1",
|
|
errText: "failed to load TLS cert and key: failed to load TLS cert and key PEMs: tls: failed to find any PEM data in certificate input",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
// Copy certs into a temp dir so we can safely modify them
|
|
tempDir := t.TempDir()
|
|
certFile, err := os.CreateTemp(tempDir, "cert")
|
|
require.NoError(t, err)
|
|
defer certFile.Close()
|
|
|
|
keyFile, err := os.CreateTemp(tempDir, "key")
|
|
require.NoError(t, err)
|
|
defer keyFile.Close()
|
|
|
|
fdc, err := os.Open(filepath.Join("testdata", "client-1.crt"))
|
|
require.NoError(t, err)
|
|
_, err = io.Copy(certFile, fdc)
|
|
require.NoError(t, err)
|
|
require.NoError(t, fdc.Close())
|
|
|
|
fdk, err := os.Open(filepath.Join("testdata", "client-1.key"))
|
|
require.NoError(t, err)
|
|
_, err = io.Copy(keyFile, fdk)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, fdk.Close())
|
|
|
|
options := Config{
|
|
CertFile: certFile.Name(),
|
|
KeyFile: keyFile.Name(),
|
|
ReloadInterval: test.reloadInterval,
|
|
}
|
|
cfg, err := options.loadTLSConfig()
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cfg)
|
|
|
|
// Assert that we loaded the original certificate
|
|
cert, err := cfg.GetCertificate(&tls.ClientHelloInfo{})
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cert)
|
|
pCert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, pCert)
|
|
assert.Equal(t, test.dns1, pCert.DNSNames[0])
|
|
|
|
// Change the certificate
|
|
assert.NoError(t, certFile.Truncate(0))
|
|
assert.NoError(t, keyFile.Truncate(0))
|
|
_, err = certFile.Seek(0, 0)
|
|
require.NoError(t, err)
|
|
_, err = keyFile.Seek(0, 0)
|
|
require.NoError(t, err)
|
|
|
|
fdc2, err := os.Open(filepath.Join("testdata", test.cert2))
|
|
require.NoError(t, err)
|
|
_, err = io.Copy(certFile, fdc2)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, fdc2.Close())
|
|
|
|
fdk2, err := os.Open(filepath.Join("testdata", test.key2))
|
|
require.NoError(t, err)
|
|
_, err = io.Copy(keyFile, fdk2)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, fdk2.Close())
|
|
|
|
// Wait ReloadInterval to ensure a reload will happen
|
|
time.Sleep(test.wait)
|
|
|
|
// Assert that we loaded the new certificate
|
|
cert, err = cfg.GetCertificate(&tls.ClientHelloInfo{})
|
|
if test.errText == "" {
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, cert)
|
|
pCert, err = x509.ParseCertificate(cert.Certificate[0])
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, pCert)
|
|
assert.Equal(t, test.dns2, pCert.DNSNames[0])
|
|
} else {
|
|
assert.EqualError(t, err, test.errText)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMinMaxTLSVersions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
minVersion string
|
|
maxVersion string
|
|
outMinVersion uint16
|
|
outMaxVersion uint16
|
|
errorTxt string
|
|
}{
|
|
{name: `TLS Config ["", ""] to give [TLS1.2, 0]`, minVersion: "", maxVersion: "", outMinVersion: tls.VersionTLS12, outMaxVersion: 0},
|
|
{name: `TLS Config ["", "1.3"] to give [TLS1.2, TLS1.3]`, minVersion: "", maxVersion: "1.3", outMinVersion: tls.VersionTLS12, outMaxVersion: tls.VersionTLS13},
|
|
{name: `TLS Config ["1.2", ""] to give [TLS1.2, 0]`, minVersion: "1.2", maxVersion: "", outMinVersion: tls.VersionTLS12, outMaxVersion: 0},
|
|
{name: `TLS Config ["1.3", "1.3"] to give [TLS1.3, TLS1.3]`, minVersion: "1.3", maxVersion: "1.3", outMinVersion: tls.VersionTLS13, outMaxVersion: tls.VersionTLS13},
|
|
{name: `TLS Config ["1.0", "1.1"] to give [TLS1.0, TLS1.1]`, minVersion: "1.0", maxVersion: "1.1", outMinVersion: tls.VersionTLS10, outMaxVersion: tls.VersionTLS11},
|
|
{name: `TLS Config ["asd", ""] to give [Error]`, minVersion: "asd", maxVersion: "", errorTxt: `invalid TLS min_version: unsupported TLS version: "asd"`},
|
|
{name: `TLS Config ["", "asd"] to give [Error]`, minVersion: "", maxVersion: "asd", errorTxt: `invalid TLS max_version: unsupported TLS version: "asd"`},
|
|
{name: `TLS Config ["0.4", ""] to give [Error]`, minVersion: "0.4", maxVersion: "", errorTxt: `invalid TLS min_version: unsupported TLS version: "0.4"`},
|
|
|
|
// Allowing this, however, expecting downstream TLS handshake will throw an error
|
|
{name: `TLS Config ["1.2", "1.1"] to give [TLS1.2, TLS1.1]`, minVersion: "1.2", maxVersion: "1.1", outMinVersion: tls.VersionTLS12, outMaxVersion: tls.VersionTLS11},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
setting := Config{
|
|
MinVersion: test.minVersion,
|
|
MaxVersion: test.maxVersion,
|
|
}
|
|
|
|
config, err := setting.loadTLSConfig()
|
|
|
|
if test.errorTxt == "" {
|
|
assert.Equal(t, config.MinVersion, test.outMinVersion)
|
|
assert.Equal(t, config.MaxVersion, test.outMaxVersion)
|
|
} else {
|
|
assert.EqualError(t, err, test.errorTxt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConfigValidate(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tlsConfig Config
|
|
errorTxt string
|
|
}{
|
|
{name: `TLS Config ["", ""] to be valid`, tlsConfig: Config{MinVersion: "", MaxVersion: ""}},
|
|
{name: `TLS Config ["", "1.3"] to be valid`, tlsConfig: Config{MinVersion: "", MaxVersion: "1.3"}},
|
|
{name: `TLS Config ["1.2", ""] to be valid`, tlsConfig: Config{MinVersion: "1.2", MaxVersion: ""}},
|
|
{name: `TLS Config ["1.3", "1.3"] to be valid`, tlsConfig: Config{MinVersion: "1.3", MaxVersion: "1.3"}},
|
|
{name: `TLS Config ["1.0", "1.1"] to be valid`, tlsConfig: Config{MinVersion: "1.0", MaxVersion: "1.1"}},
|
|
{name: `TLS Config ["asd", ""] to give [Error]`, tlsConfig: Config{MinVersion: "asd", MaxVersion: ""}, errorTxt: `invalid TLS min_version: unsupported TLS version: "asd"`},
|
|
{name: `TLS Config ["", "asd"] to give [Error]`, tlsConfig: Config{MinVersion: "", MaxVersion: "asd"}, errorTxt: `invalid TLS max_version: unsupported TLS version: "asd"`},
|
|
{name: `TLS Config ["0.4", ""] to give [Error]`, tlsConfig: Config{MinVersion: "0.4", MaxVersion: ""}, errorTxt: `invalid TLS min_version: unsupported TLS version: "0.4"`},
|
|
{name: `TLS Config ["1.2", "1.1"] to give [Error]`, tlsConfig: Config{MinVersion: "1.2", MaxVersion: "1.1"}, errorTxt: `invalid TLS configuration: min_version cannot be greater than max_version`},
|
|
{name: `TLS Config with both CA File and PEM`, tlsConfig: Config{CAFile: "test", CAPem: "test"}, errorTxt: `provide either a CA file or the PEM-encoded string, but not both`},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
err := test.tlsConfig.Validate()
|
|
|
|
if test.errorTxt == "" {
|
|
assert.NoError(t, err)
|
|
} else {
|
|
assert.EqualError(t, err, test.errorTxt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCipherSuites(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tlsSetting Config
|
|
wantErr string
|
|
result []uint16
|
|
}{
|
|
{
|
|
name: "no suites set",
|
|
tlsSetting: Config{},
|
|
result: nil,
|
|
},
|
|
{
|
|
name: "one cipher suite set",
|
|
tlsSetting: Config{
|
|
CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"},
|
|
},
|
|
result: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
|
|
},
|
|
{
|
|
name: "invalid cipher suite set",
|
|
tlsSetting: Config{
|
|
CipherSuites: []string{"FOO"},
|
|
},
|
|
wantErr: `invalid TLS cipher suite: "FOO"`,
|
|
},
|
|
{
|
|
name: "multiple invalid cipher suites set",
|
|
tlsSetting: Config{
|
|
CipherSuites: []string{"FOO", "BAR"},
|
|
},
|
|
wantErr: `invalid TLS cipher suite: "FOO"
|
|
invalid TLS cipher suite: "BAR"`,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
config, err := test.tlsSetting.loadTLSConfig()
|
|
if test.wantErr != "" {
|
|
assert.EqualError(t, err, test.wantErr)
|
|
} else {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, test.result, config.CipherSuites)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSystemCertPool(t *testing.T) {
|
|
anError := errors.New("my error")
|
|
tests := []struct {
|
|
name string
|
|
tlsConfig Config
|
|
wantErr error
|
|
systemCertFn func() (*x509.CertPool, error)
|
|
}{
|
|
{
|
|
name: "not using system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: false,
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: x509.SystemCertPool,
|
|
},
|
|
{
|
|
name: "using system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: x509.SystemCertPool,
|
|
},
|
|
{
|
|
name: "error loading system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
},
|
|
wantErr: anError,
|
|
systemCertFn: func() (*x509.CertPool, error) {
|
|
return nil, anError
|
|
},
|
|
},
|
|
{
|
|
name: "nil system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
CAFile: filepath.Join("testdata", "ca-1.crt"),
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: func() (*x509.CertPool, error) {
|
|
return nil, nil
|
|
},
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
oldSystemCertPool := systemCertPool
|
|
systemCertPool = test.systemCertFn
|
|
defer func() {
|
|
systemCertPool = oldSystemCertPool
|
|
}()
|
|
|
|
serverConfig := ServerConfig{
|
|
Config: test.tlsConfig,
|
|
}
|
|
c, err := serverConfig.LoadTLSConfig(context.Background())
|
|
if test.wantErr != nil {
|
|
require.ErrorContains(t, err, test.wantErr.Error())
|
|
} else {
|
|
assert.NotNil(t, c.RootCAs)
|
|
}
|
|
|
|
clientConfig := ClientConfig{
|
|
Config: test.tlsConfig,
|
|
}
|
|
c, err = clientConfig.LoadTLSConfig(context.Background())
|
|
if test.wantErr != nil {
|
|
assert.ErrorContains(t, err, test.wantErr.Error())
|
|
} else {
|
|
assert.NotNil(t, c.RootCAs)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSystemCertPool_loadCert(t *testing.T) {
|
|
anError := errors.New("my error")
|
|
tests := []struct {
|
|
name string
|
|
tlsConfig Config
|
|
wantErr error
|
|
systemCertFn func() (*x509.CertPool, error)
|
|
}{
|
|
{
|
|
name: "not using system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: false,
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: x509.SystemCertPool,
|
|
},
|
|
{
|
|
name: "using system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: x509.SystemCertPool,
|
|
},
|
|
{
|
|
name: "error loading system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
},
|
|
wantErr: anError,
|
|
systemCertFn: func() (*x509.CertPool, error) {
|
|
return nil, anError
|
|
},
|
|
},
|
|
{
|
|
name: "nil system cert pool",
|
|
tlsConfig: Config{
|
|
IncludeSystemCACertsPool: true,
|
|
},
|
|
wantErr: nil,
|
|
systemCertFn: func() (*x509.CertPool, error) {
|
|
return nil, nil
|
|
},
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
oldSystemCertPool := systemCertPool
|
|
systemCertPool = test.systemCertFn
|
|
defer func() {
|
|
systemCertPool = oldSystemCertPool
|
|
}()
|
|
certPool, err := test.tlsConfig.loadCert(filepath.Join("testdata", "ca-1.crt"))
|
|
if test.wantErr != nil {
|
|
assert.Equal(t, test.wantErr, err)
|
|
} else {
|
|
assert.NotNil(t, certPool)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCurvePreferences(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
preferences []string
|
|
expectedCurveIDs []tls.CurveID
|
|
expectedErr string
|
|
}{
|
|
{
|
|
name: "X25519",
|
|
preferences: []string{"X25519"},
|
|
expectedCurveIDs: []tls.CurveID{tls.X25519},
|
|
},
|
|
{
|
|
name: "P521",
|
|
preferences: []string{"P521"},
|
|
expectedCurveIDs: []tls.CurveID{tls.CurveP521},
|
|
},
|
|
{
|
|
name: "P-256",
|
|
preferences: []string{"P256"},
|
|
expectedCurveIDs: []tls.CurveID{tls.CurveP256},
|
|
},
|
|
{
|
|
name: "multiple",
|
|
preferences: []string{"P256", "P521", "X25519"},
|
|
expectedCurveIDs: []tls.CurveID{tls.CurveP256, tls.CurveP521, tls.X25519},
|
|
},
|
|
{
|
|
name: "invalid-curve",
|
|
preferences: []string{"P25223236"},
|
|
expectedCurveIDs: []tls.CurveID{},
|
|
expectedErr: "invalid curve type",
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
tlsSetting := ClientConfig{
|
|
Config: Config{
|
|
CurvePreferences: test.preferences,
|
|
},
|
|
}
|
|
config, err := tlsSetting.LoadTLSConfig(context.Background())
|
|
if test.expectedErr == "" {
|
|
require.NoError(t, err)
|
|
require.ElementsMatchf(t, test.expectedCurveIDs, config.CurvePreferences, "expected %v, got %v", test.expectedCurveIDs, config.CurvePreferences)
|
|
} else {
|
|
require.ErrorContains(t, err, test.expectedErr)
|
|
}
|
|
}
|
|
}
|