// opentelemetry-collector/config/configtls/configtls_test.go
//
// 926 lines
// 26 KiB
// Go

// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0
package configtls
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/collector/config/configopaque"
)
func TestNewDefaultConfig(t *testing.T) {
expectedConfig := Config{}
config := NewDefaultConfig()
require.Equal(t, expectedConfig, config)
}
// TestNewDefaultClientConfig checks that NewDefaultClientConfig returns a
// ClientConfig whose only populated field is the default Config.
func TestNewDefaultClientConfig(t *testing.T) {
	var want ClientConfig
	want.Config = NewDefaultConfig()

	require.Equal(t, want, NewDefaultClientConfig())
}
// TestNewDefaultServerConfig checks that NewDefaultServerConfig returns a
// ServerConfig whose only populated field is the default Config.
func TestNewDefaultServerConfig(t *testing.T) {
	var want ServerConfig
	want.Config = NewDefaultConfig()

	require.Equal(t, want, NewDefaultServerConfig())
}
// TestOptionsToConfig drives Config.loadTLSConfig through a table of option
// combinations (CA, certificate and key given as files and/or PEM strings,
// plus min/max TLS versions).
//
// A case with a non-empty expectError must fail with an error containing that
// substring; every other case must succeed and yield a non-nil TLS config.
func TestOptionsToConfig(t *testing.T) {
	tests := []struct {
		name        string
		options     Config
		expectError string // substring expected in the error; empty means success
	}{
		{
			name:    "should load system CA",
			options: Config{CAFile: ""},
		},
		{
			name:    "should load custom CA",
			options: Config{CAFile: filepath.Join("testdata", "ca-1.crt")},
		},
		{
			name:    "should load system CA and custom CA",
			options: Config{IncludeSystemCACertsPool: true, CAFile: filepath.Join("testdata", "ca-1.crt")},
		},
		{
			name:        "should fail with invalid CA file path",
			options:     Config{CAFile: filepath.Join("testdata", "not/valid")},
			expectError: "failed to load CA",
		},
		{
			name:        "should fail with invalid CA file content",
			options:     Config{CAFile: filepath.Join("testdata", "testCA-bad.txt")},
			expectError: "failed to parse cert",
		},
		{
			name: "should load valid TLS settings",
			options: Config{
				CAFile:   filepath.Join("testdata", "ca-1.crt"),
				CertFile: filepath.Join("testdata", "server-1.crt"),
				KeyFile:  filepath.Join("testdata", "server-1.key"),
			},
		},
		{
			name: "should fail with missing TLS KeyFile",
			options: Config{
				CAFile:   filepath.Join("testdata", "ca-1.crt"),
				CertFile: filepath.Join("testdata", "server-1.crt"),
			},
			expectError: "provide both certificate and key, or neither",
		},
		{
			name: "should fail with invalid TLS KeyFile",
			options: Config{
				CAFile:   filepath.Join("testdata", "ca-1.crt"),
				CertFile: filepath.Join("testdata", "server-1.crt"),
				KeyFile:  filepath.Join("testdata", "not/valid"),
			},
			expectError: "failed to load TLS cert and key",
		},
		{
			name: "should fail with missing TLS Cert",
			options: Config{
				CAFile:  filepath.Join("testdata", "ca-1.crt"),
				KeyFile: filepath.Join("testdata", "server-1.key"),
			},
			expectError: "provide both certificate and key, or neither",
		},
		{
			name: "should fail with invalid TLS Cert",
			options: Config{
				CAFile:   filepath.Join("testdata", "ca-1.crt"),
				CertFile: filepath.Join("testdata", "not/valid"),
				KeyFile:  filepath.Join("testdata", "server-1.key"),
			},
			expectError: "failed to load TLS cert and key",
		},
		{
			name: "should fail with invalid TLS CA",
			options: Config{
				CAFile: filepath.Join("testdata", "not/valid"),
			},
			expectError: "failed to load CA",
		},
		{
			name: "should fail with invalid CA pool",
			options: Config{
				CAFile: filepath.Join("testdata", "testCA-bad.txt"),
			},
			expectError: "failed to parse cert",
		},
		{
			name: "should pass with valid CA pool",
			options: Config{
				CAFile: filepath.Join("testdata", "ca-1.crt"),
			},
		},
		{
			name: "should pass with valid min and max version",
			options: Config{
				MinVersion: "1.1",
				MaxVersion: "1.2",
			},
		},
		{
			// Renamed from "should pass ...": this case expects an error.
			name: "should fail with invalid min",
			options: Config{
				MinVersion: "1.7",
			},
			expectError: "invalid TLS min_",
		},
		{
			// Renamed from "should pass ...": this case expects an error.
			name: "should fail with invalid max",
			options: Config{
				MaxVersion: "1.7",
			},
			expectError: "invalid TLS max_",
		},
		{
			name:    "should load custom CA PEM",
			options: Config{CAPem: readFilePanics("testdata/ca-1.crt")},
		},
		{
			name: "should load valid TLS settings with PEMs",
			options: Config{
				CAPem:   readFilePanics("testdata/ca-1.crt"),
				CertPem: readFilePanics("testdata/server-1.crt"),
				KeyPem:  readFilePanics("testdata/server-1.key"),
			},
		},
		{
			name: "mix Cert file and Key PEM provided",
			options: Config{
				CertFile: "testdata/server-1.crt",
				KeyPem:   readFilePanics("testdata/server-1.key"),
			},
		},
		{
			name: "mix Cert PEM and Key File provided",
			options: Config{
				CertPem: readFilePanics("testdata/server-1.crt"),
				KeyFile: "testdata/server-1.key",
			},
		},
		{
			name:        "should fail with invalid CA PEM",
			options:     Config{CAPem: readFilePanics("testdata/testCA-bad.txt")},
			expectError: "failed to parse cert",
		},
		{
			name: "should fail CA file and PEM both provided",
			options: Config{
				CAFile: "testdata/ca-1.crt",
				CAPem:  readFilePanics("testdata/ca-1.crt"),
			},
			expectError: "provide either a CA file or the PEM-encoded string, but not both",
		},
		{
			name: "should fail Cert file and PEM both provided",
			options: Config{
				CertFile: "testdata/server-1.crt",
				CertPem:  readFilePanics("testdata/server-1.crt"),
				KeyFile:  "testdata/server-1.key",
			},
			expectError: "provide either a certificate or the PEM-encoded string, but not both",
		},
		{
			name: "should fail Key file and PEM both provided",
			options: Config{
				CertFile: "testdata/server-1.crt",
				KeyFile:  "testdata/ca-1.crt",
				KeyPem:   readFilePanics("testdata/server-1.key"),
			},
			expectError: "provide either a key or the PEM-encoded string, but not both",
		},
		{
			name: "should fail to load valid TLS settings with bad Cert PEM",
			options: Config{
				CAPem:   readFilePanics("testdata/ca-1.crt"),
				CertPem: readFilePanics("testdata/testCA-bad.txt"),
				KeyPem:  readFilePanics("testdata/server-1.key"),
			},
			expectError: "failed to load TLS cert and key PEMs",
		},
		{
			name: "should fail to load valid TLS settings with bad Key PEM",
			options: Config{
				CAPem:   readFilePanics("testdata/ca-1.crt"),
				CertPem: readFilePanics("testdata/server-1.crt"),
				KeyPem:  readFilePanics("testdata/testCA-bad.txt"),
			},
			expectError: "failed to load TLS cert and key PEMs",
		},
		{
			name: "should fail with missing TLS KeyPem",
			options: Config{
				CAPem:   readFilePanics("testdata/ca-1.crt"),
				CertPem: readFilePanics("testdata/server-1.crt"),
			},
			expectError: "provide both certificate and key, or neither",
		},
		{
			name: "should fail with missing TLS Cert PEM",
			options: Config{
				CAPem:  readFilePanics("testdata/ca-1.crt"),
				KeyPem: readFilePanics("testdata/server-1.key"),
			},
			expectError: "provide both certificate and key, or neither",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			cfg, err := test.options.loadTLSConfig()
			if test.expectError != "" {
				assert.ErrorContains(t, err, test.expectError)
			} else {
				require.NoError(t, err)
				assert.NotNil(t, cfg)
			}
		})
	}
}
// readFilePanics reads the file at filePath (after filepath.Clean) and returns
// its contents as a configopaque.String. It panics with a descriptive message
// if the file cannot be read, which lets it be used inside table literals.
func readFilePanics(filePath string) configopaque.String {
	data, readErr := os.ReadFile(filepath.Clean(filePath))
	if readErr == nil {
		return configopaque.String(data)
	}
	panic(fmt.Sprintf("failed to read file %s: %v", filePath, readErr))
}
// TestLoadTLSClientConfigError expects ClientConfig.LoadTLSConfig to return an
// error when the certificate and key paths do not exist.
func TestLoadTLSClientConfigError(t *testing.T) {
	const missing = "doesnt/exist"

	clientCfg := ClientConfig{
		Config: Config{CertFile: missing, KeyFile: missing},
	}

	_, err := clientCfg.LoadTLSConfig(context.Background())
	assert.Error(t, err)
}
// TestLoadTLSClientConfig covers three client configurations that must load
// without error:
//   - Insecure: a nil TLS config is expected;
//   - the zero-value ClientConfig: a non-nil TLS config is expected;
//   - InsecureSkipVerify: the flag must be carried over to the TLS config.
func TestLoadTLSClientConfig(t *testing.T) {
	ctx := context.Background()

	insecure := ClientConfig{Insecure: true}
	loaded, err := insecure.LoadTLSConfig(ctx)
	require.NoError(t, err)
	assert.Nil(t, loaded)

	var zero ClientConfig
	loaded, err = zero.LoadTLSConfig(ctx)
	require.NoError(t, err)
	assert.NotNil(t, loaded)

	skipVerify := ClientConfig{InsecureSkipVerify: true}
	loaded, err = skipVerify.LoadTLSConfig(ctx)
	require.NoError(t, err)
	assert.NotNil(t, loaded)
	assert.True(t, loaded.InsecureSkipVerify)
}
// TestLoadTLSServerConfigError expects ServerConfig.LoadTLSConfig to fail both
// when the certificate/key paths do not exist and when the client CA file path
// does not exist.
func TestLoadTLSServerConfigError(t *testing.T) {
	const missing = "doesnt/exist"
	ctx := context.Background()

	badCertAndKey := ServerConfig{
		Config: Config{CertFile: missing, KeyFile: missing},
	}
	_, err := badCertAndKey.LoadTLSConfig(ctx)
	require.Error(t, err)

	badClientCA := ServerConfig{ClientCAFile: missing}
	_, err = badClientCA.LoadTLSConfig(ctx)
	assert.Error(t, err)
}
// TestLoadTLSServerConfig expects the zero-value ServerConfig to load without
// error and to produce a non-nil TLS config.
func TestLoadTLSServerConfig(t *testing.T) {
	var serverCfg ServerConfig

	loaded, err := serverCfg.LoadTLSConfig(context.Background())
	require.NoError(t, err)
	assert.NotNil(t, loaded)
}
// TestLoadTLSServerConfigReload checks that, with ReloadClientCAFile enabled,
// replacing the client CA file on disk is eventually reflected in the
// ClientCAs returned by GetConfigForClient.
func TestLoadTLSServerConfigReload(t *testing.T) {
	tmpCaPath := createTempClientCaFile(t)

	overwriteClientCA(t, tmpCaPath, "ca-1.crt")

	tlsSetting := ServerConfig{
		ClientCAFile:       tmpCaPath,
		ReloadClientCAFile: true,
	}

	tlsCfg, err := tlsSetting.LoadTLSConfig(context.Background())
	require.NoError(t, err)
	// require (not assert): tlsCfg is dereferenced immediately below.
	require.NotNil(t, tlsCfg)

	firstClient, err := tlsCfg.GetConfigForClient(nil)
	require.NoError(t, err)

	overwriteClientCA(t, tmpCaPath, "ca-2.crt")

	// Wait until the served CA pool actually differs from the first one.
	// Checking only for a nil error is not enough: that already holds before
	// any reload, so the final assertion could run against the old pool.
	// NOTE(review): assumes the reload happens asynchronously after the file
	// changes, as the use of Eventually suggests — confirm against the reloader.
	assert.Eventually(t, func() bool {
		current, loadError := tlsCfg.GetConfigForClient(nil)
		return loadError == nil && !firstClient.ClientCAs.Equal(current.ClientCAs)
	}, 5*time.Second, 10*time.Millisecond)

	secondClient, err := tlsCfg.GetConfigForClient(nil)
	require.NoError(t, err)

	assert.NotEqual(t, firstClient.ClientCAs, secondClient.ClientCAs)
}
// TestLoadTLSServerConfigFailingReload loads a valid client CA with reloading
// enabled, then overwrites the file with invalid content. GetConfigForClient
// is expected to keep succeeding and to return ClientCAs equal to the ones
// obtained before the overwrite.
func TestLoadTLSServerConfigFailingReload(t *testing.T) {
	caPath := createTempClientCaFile(t)
	overwriteClientCA(t, caPath, "ca-1.crt")

	serverCfg := ServerConfig{ClientCAFile: caPath, ReloadClientCAFile: true}
	loaded, err := serverCfg.LoadTLSConfig(context.Background())
	require.NoError(t, err)
	assert.NotNil(t, loaded)

	before, err := loaded.GetConfigForClient(nil)
	require.NoError(t, err)

	overwriteClientCA(t, caPath, "testCA-bad.txt")

	lookupSucceeds := func() bool {
		_, lookupErr := loaded.GetConfigForClient(nil)
		return lookupErr == nil
	}
	assert.Eventually(t, lookupSucceeds, 5*time.Second, 10*time.Millisecond)

	after, err := loaded.GetConfigForClient(nil)
	require.NoError(t, err)
	assert.Equal(t, before.ClientCAs, after.ClientCAs)
}
// TestLoadTLSServerConfigFailingInitialLoad expects LoadTLSConfig to return an
// error and a nil TLS config when the client CA file holds invalid content at
// load time, even with reloading enabled.
func TestLoadTLSServerConfigFailingInitialLoad(t *testing.T) {
	caPath := createTempClientCaFile(t)
	overwriteClientCA(t, caPath, "testCA-bad.txt")

	serverCfg := ServerConfig{ClientCAFile: caPath, ReloadClientCAFile: true}

	loaded, err := serverCfg.LoadTLSConfig(context.Background())
	require.Error(t, err)
	assert.Nil(t, loaded)
}
// TestLoadTLSServerConfigWrongPath expects LoadTLSConfig to return an error
// and a nil TLS config when the client CA path (a real temp file path with a
// suffix appended) is used with reloading enabled.
func TestLoadTLSServerConfigWrongPath(t *testing.T) {
	bogusPath := createTempClientCaFile(t) + "wrong-path"

	serverCfg := ServerConfig{ClientCAFile: bogusPath, ReloadClientCAFile: true}

	loaded, err := serverCfg.LoadTLSConfig(context.Background())
	require.Error(t, err)
	assert.Nil(t, loaded)
}
// TestLoadTLSServerConfigFailing loads a valid client CA with reloading
// enabled, deletes the CA file, and expects GetConfigForClient to still
// return a non-nil config without error.
func TestLoadTLSServerConfigFailing(t *testing.T) {
	caPath := createTempClientCaFile(t)
	overwriteClientCA(t, caPath, "ca-1.crt")

	serverCfg := ServerConfig{ClientCAFile: caPath, ReloadClientCAFile: true}
	loaded, err := serverCfg.LoadTLSConfig(context.Background())
	require.NoError(t, err)
	assert.NotNil(t, loaded)

	clientCfg, err := loaded.GetConfigForClient(nil)
	require.NoError(t, err)
	assert.NotNil(t, clientCfg)

	require.NoError(t, os.Remove(caPath))

	clientCfg, err = loaded.GetConfigForClient(nil)
	require.NoError(t, err)
	assert.NotNil(t, clientCfg)
}
// overwriteClientCA copies testdata/<testdataFileName> over the start of the
// existing file at targetFilePath.
//
// It is a test helper: failures are reported at the caller's line. Both file
// handles are closed via defer, so they are released even when a require
// assertion aborts the test part-way through.
//
// NOTE(review): the target is opened without O_TRUNC, so a source shorter than
// the current target content leaves stale trailing bytes — confirm this is
// acceptable for the callers.
func overwriteClientCA(t *testing.T, targetFilePath, testdataFileName string) {
	t.Helper()

	targetFile, err := os.OpenFile(filepath.Clean(targetFilePath), os.O_RDWR, 0o600)
	require.NoError(t, err)
	defer func() { assert.NoError(t, targetFile.Close()) }()

	testdataFilePath := filepath.Join("testdata", testdataFileName)
	// os.Open is the read-only form; a permission argument only matters when
	// a file is being created.
	testdataFile, err := os.Open(filepath.Clean(testdataFilePath))
	require.NoError(t, err)
	defer func() { assert.NoError(t, testdataFile.Close()) }()

	_, err = io.Copy(targetFile, testdataFile)
	assert.NoError(t, err)
}
// createTempClientCaFile creates an empty temporary file inside t.TempDir()
// and returns its absolute path. The file is closed before returning.
//
// It is a test helper: failures are reported at the caller's line, and any
// failure aborts the test instead of returning an unusable path.
func createTempClientCaFile(t *testing.T) string {
	t.Helper()

	tmpCa, err := os.CreateTemp(t.TempDir(), "ca-tmp.crt")
	require.NoError(t, err)
	// Close first so the handle is not left open if resolving the path fails.
	require.NoError(t, tmpCa.Close())

	tmpCaPath, err := filepath.Abs(tmpCa.Name())
	require.NoError(t, err)
	return tmpCaPath
}
// TestEagerlyLoadCertificate loads a TLS config from the client-1 certificate
// and key files, then checks that GetCertificate returns a certificate whose
// DNS names are exactly ["example1"].
func TestEagerlyLoadCertificate(t *testing.T) {
	settings := Config{
		CertFile: filepath.Join("testdata", "client-1.crt"),
		KeyFile:  filepath.Join("testdata", "client-1.key"),
	}

	tlsCfg, err := settings.loadTLSConfig()
	require.NoError(t, err)
	assert.NotNil(t, tlsCfg)

	served, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{})
	require.NoError(t, err)
	assert.NotNil(t, served)

	// The first entry of the chain is parsed to inspect its DNS names.
	parsed, err := x509.ParseCertificate(served.Certificate[0])
	require.NoError(t, err)
	assert.NotNil(t, parsed)

	wantDNS := []string{"example1"}
	assert.ElementsMatch(t, wantDNS, parsed.DNSNames)
}
// TestCertificateReload checks how GetCertificate behaves after the
// certificate and key files on disk are replaced, for several ReloadInterval
// settings.
//
// Each case starts from the client-1 cert/key copied into a temp dir, swaps in
// cert2/key2, sleeps for `wait`, and then either expects the served
// certificate's first DNS name to equal dns2 or expects GetCertificate to fail
// with exactly errText.
func TestCertificateReload(t *testing.T) {
tests := []struct {
name string
// reloadInterval is passed to Config.ReloadInterval.
reloadInterval time.Duration
// wait is how long the test sleeps after replacing the files.
wait time.Duration
// cert2 and key2 are testdata file names written over the originals.
cert2 string
key2 string
// dns1 is the expected first DNS name before the swap, dns2 after it.
dns1 string
dns2 string
// errText, when non-empty, is the exact error expected after the swap.
errText string
}{
{
name: "Should reload the certificate after reload-interval",
reloadInterval: 100 * time.Microsecond,
wait: 100 * time.Microsecond,
cert2: "client-2.crt",
key2: "client-2.key",
dns1: "example1",
dns2: "example2",
},
{
name: "Should return same cert if called before reload-interval",
reloadInterval: 100 * time.Millisecond,
wait: 100 * time.Microsecond,
cert2: "client-2.crt",
key2: "client-2.key",
dns1: "example1",
dns2: "example1",
},
{
name: "Should always return same cert if reload-interval is 0",
reloadInterval: 0,
wait: 100 * time.Microsecond,
cert2: "client-2.crt",
key2: "client-2.key",
dns1: "example1",
dns2: "example1",
},
{
name: "Should return an error if reloading fails",
reloadInterval: 100 * time.Microsecond,
wait: 100 * time.Microsecond,
cert2: "testCA-bad.txt",
key2: "client-2.key",
dns1: "example1",
errText: "failed to load TLS cert and key: failed to load TLS cert and key PEMs: tls: failed to find any PEM data in certificate input",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Copy certs into a temp dir so we can safely modify them
tempDir := t.TempDir()
certFile, err := os.CreateTemp(tempDir, "cert")
require.NoError(t, err)
defer certFile.Close()
keyFile, err := os.CreateTemp(tempDir, "key")
require.NoError(t, err)
defer keyFile.Close()
// Seed the temp files with the client-1 certificate and key.
fdc, err := os.Open(filepath.Join("testdata", "client-1.crt"))
require.NoError(t, err)
_, err = io.Copy(certFile, fdc)
require.NoError(t, err)
require.NoError(t, fdc.Close())
fdk, err := os.Open(filepath.Join("testdata", "client-1.key"))
require.NoError(t, err)
_, err = io.Copy(keyFile, fdk)
assert.NoError(t, err)
assert.NoError(t, fdk.Close())
options := Config{
CertFile: certFile.Name(),
KeyFile: keyFile.Name(),
ReloadInterval: test.reloadInterval,
}
cfg, err := options.loadTLSConfig()
require.NoError(t, err)
assert.NotNil(t, cfg)
// Assert that we loaded the original certificate
cert, err := cfg.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(t, err)
assert.NotNil(t, cert)
pCert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
assert.NotNil(t, pCert)
assert.Equal(t, test.dns1, pCert.DNSNames[0])
// Change the certificate
// The still-open handles are truncated and rewound so the new content
// replaces the old content in the same files.
assert.NoError(t, certFile.Truncate(0))
assert.NoError(t, keyFile.Truncate(0))
_, err = certFile.Seek(0, 0)
require.NoError(t, err)
_, err = keyFile.Seek(0, 0)
require.NoError(t, err)
fdc2, err := os.Open(filepath.Join("testdata", test.cert2))
require.NoError(t, err)
_, err = io.Copy(certFile, fdc2)
assert.NoError(t, err)
assert.NoError(t, fdc2.Close())
fdk2, err := os.Open(filepath.Join("testdata", test.key2))
require.NoError(t, err)
_, err = io.Copy(keyFile, fdk2)
assert.NoError(t, err)
assert.NoError(t, fdk2.Close())
// Wait ReloadInterval to ensure a reload will happen
time.Sleep(test.wait)
// Assert that we loaded the new certificate
cert, err = cfg.GetCertificate(&tls.ClientHelloInfo{})
if test.errText == "" {
require.NoError(t, err)
assert.NotNil(t, cert)
pCert, err = x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
assert.NotNil(t, pCert)
assert.Equal(t, test.dns2, pCert.DNSNames[0])
} else {
// The reload is expected to fail with this exact message.
assert.EqualError(t, err, test.errText)
}
})
}
}
// TestMinMaxTLSVersions checks how Config.loadTLSConfig maps the MinVersion
// and MaxVersion strings to the numeric tls.Config fields.
//
// Cases with an empty errorTxt must load without error and yield exactly
// outMinVersion/outMaxVersion; the others must fail with exactly errorTxt.
func TestMinMaxTLSVersions(t *testing.T) {
	tests := []struct {
		name          string
		minVersion    string
		maxVersion    string
		outMinVersion uint16
		outMaxVersion uint16
		errorTxt      string
	}{
		{name: `TLS Config ["", ""] to give [TLS1.2, 0]`, minVersion: "", maxVersion: "", outMinVersion: tls.VersionTLS12, outMaxVersion: 0},
		{name: `TLS Config ["", "1.3"] to give [TLS1.2, TLS1.3]`, minVersion: "", maxVersion: "1.3", outMinVersion: tls.VersionTLS12, outMaxVersion: tls.VersionTLS13},
		{name: `TLS Config ["1.2", ""] to give [TLS1.2, 0]`, minVersion: "1.2", maxVersion: "", outMinVersion: tls.VersionTLS12, outMaxVersion: 0},
		{name: `TLS Config ["1.3", "1.3"] to give [TLS1.3, TLS1.3]`, minVersion: "1.3", maxVersion: "1.3", outMinVersion: tls.VersionTLS13, outMaxVersion: tls.VersionTLS13},
		{name: `TLS Config ["1.0", "1.1"] to give [TLS1.0, TLS1.1]`, minVersion: "1.0", maxVersion: "1.1", outMinVersion: tls.VersionTLS10, outMaxVersion: tls.VersionTLS11},
		{name: `TLS Config ["asd", ""] to give [Error]`, minVersion: "asd", maxVersion: "", errorTxt: `invalid TLS min_version: unsupported TLS version: "asd"`},
		{name: `TLS Config ["", "asd"] to give [Error]`, minVersion: "", maxVersion: "asd", errorTxt: `invalid TLS max_version: unsupported TLS version: "asd"`},
		{name: `TLS Config ["0.4", ""] to give [Error]`, minVersion: "0.4", maxVersion: "", errorTxt: `invalid TLS min_version: unsupported TLS version: "0.4"`},
		// min > max is accepted here; the downstream TLS handshake is expected to reject it.
		{name: `TLS Config ["1.2", "1.1"] to give [TLS1.2, TLS1.1]`, minVersion: "1.2", maxVersion: "1.1", outMinVersion: tls.VersionTLS12, outMaxVersion: tls.VersionTLS11},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			setting := Config{
				MinVersion: test.minVersion,
				MaxVersion: test.maxVersion,
			}
			config, err := setting.loadTLSConfig()
			if test.errorTxt != "" {
				assert.EqualError(t, err, test.errorTxt)
				return
			}
			// Stop on an unexpected error instead of dereferencing config.
			require.NoError(t, err)
			// Expected value first, so failure output labels the values correctly.
			assert.Equal(t, test.outMinVersion, config.MinVersion)
			assert.Equal(t, test.outMaxVersion, config.MaxVersion)
		})
	}
}
// TestConfigValidate runs Config.Validate over a table of min/max version
// pairs plus a CA file/PEM conflict. A case with an empty wantErr must
// validate cleanly; otherwise Validate must return exactly wantErr.
func TestConfigValidate(t *testing.T) {
	cases := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: `TLS Config ["", ""] to be valid`,
			cfg:  Config{MinVersion: "", MaxVersion: ""},
		},
		{
			name: `TLS Config ["", "1.3"] to be valid`,
			cfg:  Config{MinVersion: "", MaxVersion: "1.3"},
		},
		{
			name: `TLS Config ["1.2", ""] to be valid`,
			cfg:  Config{MinVersion: "1.2", MaxVersion: ""},
		},
		{
			name: `TLS Config ["1.3", "1.3"] to be valid`,
			cfg:  Config{MinVersion: "1.3", MaxVersion: "1.3"},
		},
		{
			name: `TLS Config ["1.0", "1.1"] to be valid`,
			cfg:  Config{MinVersion: "1.0", MaxVersion: "1.1"},
		},
		{
			name:    `TLS Config ["asd", ""] to give [Error]`,
			cfg:     Config{MinVersion: "asd", MaxVersion: ""},
			wantErr: `invalid TLS min_version: unsupported TLS version: "asd"`,
		},
		{
			name:    `TLS Config ["", "asd"] to give [Error]`,
			cfg:     Config{MinVersion: "", MaxVersion: "asd"},
			wantErr: `invalid TLS max_version: unsupported TLS version: "asd"`,
		},
		{
			name:    `TLS Config ["0.4", ""] to give [Error]`,
			cfg:     Config{MinVersion: "0.4", MaxVersion: ""},
			wantErr: `invalid TLS min_version: unsupported TLS version: "0.4"`,
		},
		{
			name:    `TLS Config ["1.2", "1.1"] to give [Error]`,
			cfg:     Config{MinVersion: "1.2", MaxVersion: "1.1"},
			wantErr: `invalid TLS configuration: min_version cannot be greater than max_version`,
		},
		{
			name:    `TLS Config with both CA File and PEM`,
			cfg:     Config{CAFile: "test", CAPem: "test"},
			wantErr: `provide either a CA file or the PEM-encoded string, but not both`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.cfg.Validate()
			if tc.wantErr != "" {
				assert.EqualError(t, err, tc.wantErr)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestCipherSuites checks how Config.loadTLSConfig translates the configured
// cipher suite names: valid names must appear as their numeric IDs in the TLS
// config, and invalid names must produce the exact error text in wantErr.
func TestCipherSuites(t *testing.T) {
	cases := []struct {
		name    string
		cfg     Config
		wantErr string
		want    []uint16
	}{
		{
			name: "no suites set",
			cfg:  Config{},
			want: nil,
		},
		{
			name: "one cipher suite set",
			cfg:  Config{CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"}},
			want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
		},
		{
			name:    "invalid cipher suite set",
			cfg:     Config{CipherSuites: []string{"FOO"}},
			wantErr: `invalid TLS cipher suite: "FOO"`,
		},
		{
			name: "multiple invalid cipher suites set",
			cfg:  Config{CipherSuites: []string{"FOO", "BAR"}},
			// One message per invalid suite, separated by a newline.
			wantErr: `invalid TLS cipher suite: "FOO"
invalid TLS cipher suite: "BAR"`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			loaded, err := tc.cfg.loadTLSConfig()
			if tc.wantErr == "" {
				require.NoError(t, err)
				assert.Equal(t, tc.want, loaded.CipherSuites)
				return
			}
			assert.EqualError(t, err, tc.wantErr)
		})
	}
}
// TestSystemCertPool checks IncludeSystemCACertsPool for both ServerConfig and
// ClientConfig. The package-level systemCertPool function is replaced per case
// and restored when the subtest ends.
//
// Cases with a non-nil wantErr must fail with an error containing its text;
// all other cases must load without error and produce non-nil RootCAs.
func TestSystemCertPool(t *testing.T) {
	anError := errors.New("my error")
	tests := []struct {
		name         string
		tlsConfig    Config
		wantErr      error
		systemCertFn func() (*x509.CertPool, error)
	}{
		{
			name: "not using system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: false,
				CAFile:                   filepath.Join("testdata", "ca-1.crt"),
			},
			wantErr:      nil,
			systemCertFn: x509.SystemCertPool,
		},
		{
			name: "using system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
				CAFile:                   filepath.Join("testdata", "ca-1.crt"),
			},
			wantErr:      nil,
			systemCertFn: x509.SystemCertPool,
		},
		{
			name: "error loading system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
				CAFile:                   filepath.Join("testdata", "ca-1.crt"),
			},
			wantErr: anError,
			systemCertFn: func() (*x509.CertPool, error) {
				return nil, anError
			},
		},
		{
			name: "nil system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
				CAFile:                   filepath.Join("testdata", "ca-1.crt"),
			},
			wantErr: nil,
			systemCertFn: func() (*x509.CertPool, error) {
				return nil, nil
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			oldSystemCertPool := systemCertPool
			systemCertPool = test.systemCertFn
			defer func() {
				systemCertPool = oldSystemCertPool
			}()

			serverConfig := ServerConfig{
				Config: test.tlsConfig,
			}
			c, err := serverConfig.LoadTLSConfig(context.Background())
			if test.wantErr != nil {
				require.ErrorContains(t, err, test.wantErr.Error())
			} else {
				// Stop on an unexpected error instead of dereferencing c.
				require.NoError(t, err)
				assert.NotNil(t, c.RootCAs)
			}

			clientConfig := ClientConfig{
				Config: test.tlsConfig,
			}
			c, err = clientConfig.LoadTLSConfig(context.Background())
			if test.wantErr != nil {
				assert.ErrorContains(t, err, test.wantErr.Error())
			} else {
				// Stop on an unexpected error instead of dereferencing c.
				require.NoError(t, err)
				assert.NotNil(t, c.RootCAs)
			}
		})
	}
}
// TestSystemCertPool_loadCert checks Config.loadCert against the ca-1 test
// certificate while the package-level systemCertPool function is replaced per
// case (and restored when the subtest ends).
//
// Cases with a non-nil wantErr must return exactly that error; all other
// cases must return a non-nil pool and no error.
func TestSystemCertPool_loadCert(t *testing.T) {
	anError := errors.New("my error")
	tests := []struct {
		name         string
		tlsConfig    Config
		wantErr      error
		systemCertFn func() (*x509.CertPool, error)
	}{
		{
			name: "not using system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: false,
			},
			wantErr:      nil,
			systemCertFn: x509.SystemCertPool,
		},
		{
			name: "using system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
			},
			wantErr:      nil,
			systemCertFn: x509.SystemCertPool,
		},
		{
			name: "error loading system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
			},
			wantErr: anError,
			systemCertFn: func() (*x509.CertPool, error) {
				return nil, anError
			},
		},
		{
			name: "nil system cert pool",
			tlsConfig: Config{
				IncludeSystemCACertsPool: true,
			},
			wantErr: nil,
			systemCertFn: func() (*x509.CertPool, error) {
				return nil, nil
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			oldSystemCertPool := systemCertPool
			systemCertPool = test.systemCertFn
			defer func() {
				systemCertPool = oldSystemCertPool
			}()

			certPool, err := test.tlsConfig.loadCert(filepath.Join("testdata", "ca-1.crt"))
			if test.wantErr != nil {
				assert.Equal(t, test.wantErr, err)
				return
			}
			// A success case must not return an error alongside the pool.
			require.NoError(t, err)
			assert.NotNil(t, certPool)
		})
	}
}
// TestCurvePreferences checks that ClientConfig.LoadTLSConfig converts the
// configured curve names into the matching tls.CurveID values, and that an
// unknown curve name produces an error containing expectedErr.
//
// Each table entry runs as its own subtest so that a failure is attributed to
// the named case and does not stop the remaining cases.
func TestCurvePreferences(t *testing.T) {
	tests := []struct {
		name             string
		preferences      []string
		expectedCurveIDs []tls.CurveID
		expectedErr      string
	}{
		{
			name:             "X25519",
			preferences:      []string{"X25519"},
			expectedCurveIDs: []tls.CurveID{tls.X25519},
		},
		{
			name:             "P521",
			preferences:      []string{"P521"},
			expectedCurveIDs: []tls.CurveID{tls.CurveP521},
		},
		{
			name:             "P-256",
			preferences:      []string{"P256"},
			expectedCurveIDs: []tls.CurveID{tls.CurveP256},
		},
		{
			name:             "multiple",
			preferences:      []string{"P256", "P521", "X25519"},
			expectedCurveIDs: []tls.CurveID{tls.CurveP256, tls.CurveP521, tls.X25519},
		},
		{
			name:             "invalid-curve",
			preferences:      []string{"P25223236"},
			expectedCurveIDs: []tls.CurveID{},
			expectedErr:      "invalid curve type",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			tlsSetting := ClientConfig{
				Config: Config{
					CurvePreferences: test.preferences,
				},
			}
			config, err := tlsSetting.LoadTLSConfig(context.Background())
			if test.expectedErr != "" {
				require.ErrorContains(t, err, test.expectedErr)
				return
			}
			require.NoError(t, err)
			require.ElementsMatchf(t, test.expectedCurveIDs, config.CurvePreferences, "expected %v, got %v", test.expectedCurveIDs, config.CurvePreferences)
		})
	}
}