advancedtls: add fields for root and identity providers in API (#3863)

* add provider in advancedtls API for pem file reloading
This commit is contained in:
ZhenLian 2020-09-17 12:08:03 -07:00 committed by GitHub
parent 4270c3cfce
commit 0f7e218c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 613 additions and 281 deletions

View File

@ -27,10 +27,12 @@ import (
"crypto/x509"
"fmt"
"net"
"reflect"
"syscall"
"time"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
credinternal "google.golang.org/grpc/internal/credentials"
)
@ -79,21 +81,67 @@ type GetRootCAsResults struct {
TrustCerts *x509.CertPool
}
// RootCertificateOptions contains a field and a function for obtaining root
// trust certificates.
// It is used by both ClientOptions and ServerOptions.
// If users want to use default verification, but did not provide a valid
// RootCertificateOptions, we use the system default trust certificates.
// RootCertificateOptions contains options to obtain root trust certificates
// for both the client and the server.
// At most one option could be set. If none of them are set, we
// use the system default trust certificates.
type RootCertificateOptions struct {
// If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts
// will be used every time when verifying the peer certificates, without
// performing root certificate reloading.
// If RootCACerts is set, it will be used every time when verifying
// the peer certificates, without performing root certificate reloading.
RootCACerts *x509.CertPool
// If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked
// every time asked to check certificates sent from the server when a new
// connection is established.
// This is known as root CA certificate reloading.
GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
// If GetRootCertificates is set, it will be invoked to obtain root certs for
// every new connection.
GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error)
// If RootProvider is set, we will use the root certs from the Provider's
// KeyMaterial() call in the new connections. The Provider must have initial
// credentials if specified. Otherwise, KeyMaterial() will block forever.
RootProvider certprovider.Provider
}
// nonNilFieldCount returns the number of set fields in RootCertificateOptions.
func (o RootCertificateOptions) nonNilFieldCount() int {
cnt := 0
rv := reflect.ValueOf(o)
for i := 0; i < rv.NumField(); i++ {
if !rv.Field(i).IsNil() {
cnt++
}
}
return cnt
}
// IdentityCertificateOptions contains options to obtain identity certificates
// for both the client and the server.
// At most one option could be set.
type IdentityCertificateOptions struct {
// If Certificates is set, it will be used every time when needed to present
//identity certificates, without performing identity certificate reloading.
Certificates []tls.Certificate
// If GetIdentityCertificatesForClient is set, it will be invoked to obtain
// identity certs for every new connection.
// This field MUST be set on client side.
GetIdentityCertificatesForClient func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
// If GetIdentityCertificatesForServer is set, it will be invoked to obtain
// identity certs for every new connection.
// This field MUST be set on server side.
GetIdentityCertificatesForServer func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
// If IdentityProvider is set, we will use the identity certs from the
// Provider's KeyMaterial() call in the new connections. The Provider must
// have initial credentials if specified. Otherwise, KeyMaterial() will block
// forever.
IdentityProvider certprovider.Provider
}
// nonNilFieldCount returns the number of set fields in IdentityCertificateOptions.
func (o IdentityCertificateOptions) nonNilFieldCount() int {
cnt := 0
rv := reflect.ValueOf(o)
for i := 0; i < rv.NumField(); i++ {
if !rv.Field(i).IsNil() {
cnt++
}
}
return cnt
}
// VerificationType is the enum type that represents different levels of
@ -115,27 +163,11 @@ const (
SkipVerification
)
// ClientOptions contains all the fields and functions needed to be filled by
// the client.
// General rules for certificate setting on client side:
// Certificates or GetClientCertificate indicates the certificates sent from
// the client to the server to prove client's identities. The rules for setting
// these two fields are:
// If requiring mutual authentication on server side:
// Either Certificates or GetClientCertificate must be set; the other will
// be ignored.
// Otherwise:
// Nothing needed(the two fields will be ignored).
// ClientOptions contains the fields needed to be filled by the client.
type ClientOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored.
// The client will use Certificates every time when asked for a certificate,
// without performing certificate reloading.
Certificates []tls.Certificate
// If GetClientCertificate is set and Certificates is nil, the client will
// invoke this function every time asked to present certificates to the
// server when a new connection is established. This is known as peer
// certificate reloading.
GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
// IdentityOptions is OPTIONAL on client side. This field only needs to be
// set if mutual authentication is required on server side.
IdentityOptions IdentityCertificateOptions
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
@ -145,37 +177,25 @@ type ClientOptions struct {
// it will override the virtual host name of authority (e.g. :authority
// header field) in requests.
ServerNameOverride string
// RootCertificateOptions is REQUIRED to be correctly set on client side.
RootCertificateOptions
// RootOptions is OPTIONAL on client side. If not set, we will try to use the
// default trust certificates in users' OS system.
RootOptions RootCertificateOptions
// VType is the verification type on the client side.
VType VerificationType
}
// ServerOptions contains all the fields and functions needed to be filled by
// the client.
// General rules for certificate setting on server side:
// Certificates or GetClientCertificate indicates the certificates sent from
// the server to the client to prove server's identities. The rules for setting
// these two fields are:
// Either Certificates or GetCertificates must be set; the other will be ignored.
// ServerOptions contains the fields needed to be filled by the server.
type ServerOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored.
// The server will use Certificates every time when asked for a certificate,
// without performing certificate reloading.
Certificates []tls.Certificate
// If GetClientCertificate is set and Certificates is nil, the server will
// invoke this function every time asked to present certificates to the
// client when a new connection is established. This is known as peer
// certificate reloading.
GetCertificates func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
// IdentityOptions is REQUIRED on server side.
IdentityOptions IdentityCertificateOptions
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VType.
VerifyPeer CustomVerificationFunc
// RootCertificateOptions is only required when mutual TLS is
// enabled(RequireClientCert is true).
RootCertificateOptions
// RootOptions is OPTIONAL on server side. This field only needs to be set if
// mutual authentication is required(RequireClientCert is true).
RootOptions RootCertificateOptions
// If the server want the client to send certificates.
RequireClientCert bool
// VType is the verification type on the server side.
@ -184,48 +204,89 @@ type ServerOptions struct {
func (o *ClientOptions) config() (*tls.Config, error) {
if o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf(
"client needs to provide custom verification mechanism if choose to skip default verification")
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
}
rootCAs := o.RootCACerts
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil {
// Set rootCAs to system default.
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
rootCAs = systemRootCAs
// Make sure users didn't specify more than one fields in
// RootCertificateOptions and IdentityCertificateOptions.
if num := o.RootOptions.nonNilFieldCount(); num > 1 {
return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
}
if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
}
if o.IdentityOptions.GetIdentityCertificatesForServer != nil {
return nil, fmt.Errorf("GetIdentityCertificatesForServer cannot be specified on the client side")
}
// We have to set InsecureSkipVerify to true to skip the default checks and
// use the verification function we built from buildVerifyFunc.
config := &tls.Config{
ServerName: o.ServerNameOverride,
Certificates: o.Certificates,
GetClientCertificate: o.GetClientCertificate,
InsecureSkipVerify: true,
ServerName: o.ServerNameOverride,
// We have to set InsecureSkipVerify to true to skip the default checks and
// use the verification function we built from buildVerifyFunc.
InsecureSkipVerify: true,
}
if rootCAs != nil {
config.RootCAs = rootCAs
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.RootCAs = o.RootOptions.RootCACerts
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ClientHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
// stored in system default path as the last resort.
if o.VType != SkipVerification {
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
config.RootCAs = systemRootCAs
}
}
// Propagate identity-certificate-related fields in tls.Config.
switch {
case o.IdentityOptions.Certificates != nil:
config.Certificates = o.IdentityOptions.Certificates
case o.IdentityOptions.GetIdentityCertificatesForClient != nil:
config.GetClientCertificate = o.IdentityOptions.GetIdentityCertificatesForClient
case o.IdentityOptions.IdentityProvider != nil:
config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
if len(km.Certs) != 1 {
return nil, fmt.Errorf("there should always be only one identity cert chain on the client side in IdentityProvider")
}
return &km.Certs[0], nil
}
default:
// It's fine for users to not specify identity certificate options here.
}
return config, nil
}
func (o *ServerOptions) config() (*tls.Config, error) {
if o.Certificates == nil && o.GetCertificates == nil {
return nil, fmt.Errorf("either Certificates or GetCertificates must be specified")
}
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf(
"server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
clientCAs := o.RootCACerts
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil && o.RequireClientCert {
// Set clientCAs to system default.
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
clientCAs = systemRootCAs
// Make sure users didn't specify more than one fields in
// RootCertificateOptions and IdentityCertificateOptions.
if num := o.RootOptions.nonNilFieldCount(); num > 1 {
return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
}
if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
}
if o.IdentityOptions.GetIdentityCertificatesForClient != nil {
return nil, fmt.Errorf("GetIdentityCertificatesForClient cannot be specified on the server side")
}
clientAuth := tls.NoClientCert
if o.RequireClientCert {
@ -235,18 +296,60 @@ func (o *ServerOptions) config() (*tls.Config, error) {
clientAuth = tls.RequireAnyClientCert
}
config := &tls.Config{
ClientAuth: clientAuth,
Certificates: o.Certificates,
ClientAuth: clientAuth,
}
if o.GetCertificates != nil {
// GetCertificate is only able to perform SNI logic for go1.10 and above.
// It will return the first certificate in o.GetCertificates for go1.9.
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.ClientCAs = o.RootOptions.RootCACerts
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ServerHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
// stored in system default path as the last resort.
if o.VType != SkipVerification && o.RequireClientCert {
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
config.ClientCAs = systemRootCAs
}
}
// Propagate identity-certificate-related fields in tls.Config.
switch {
case o.IdentityOptions.Certificates != nil:
config.Certificates = o.IdentityOptions.Certificates
case o.IdentityOptions.GetIdentityCertificatesForServer != nil:
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return buildGetCertificates(clientHello, o)
}
}
if clientCAs != nil {
config.ClientCAs = clientCAs
case o.IdentityOptions.IdentityProvider != nil:
o.IdentityOptions.GetIdentityCertificatesForServer = func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
var certChains []*tls.Certificate
for i := 0; i < len(km.Certs); i++ {
certChains = append(certChains, &km.Certs[i])
}
return certChains, nil
}
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return buildGetCertificates(clientHello, o)
}
default:
return nil, fmt.Errorf("needs to specify at least one field in IdentityCertificateOptions")
}
return config, nil
}
@ -423,7 +526,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
tc := &advancedTLSCreds{
config: conf,
isClient: true,
getRootCAs: o.GetRootCAs,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
}
@ -441,7 +544,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
tc := &advancedTLSCreds{
config: conf,
isClient: false,
getRootCAs: o.GetRootCAs,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
}

View File

@ -385,11 +385,13 @@ func (s) TestEnd2End(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
IdentityOptions: IdentityCertificateOptions{
Certificates: test.serverCert,
GetIdentityCertificatesForServer: test.serverGetCert,
},
RootOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCertificates: test.serverGetRoot,
},
RequireClientCert: true,
VerifyPeer: test.serverVerifyFunc,
@ -409,12 +411,14 @@ func (s) TestEnd2End(t *testing.T) {
pb.RegisterGreeterService(s, &pb.GreeterService{SayHello: sayHello})
go s.Serve(lis)
clientOptions := &ClientOptions{
Certificates: test.clientCert,
GetClientCertificate: test.clientGetCert,
VerifyPeer: test.clientVerifyFunc,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot,
IdentityOptions: IdentityCertificateOptions{
Certificates: test.clientCert,
GetIdentityCertificatesForClient: test.clientGetCert,
},
VerifyPeer: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
},
VType: test.clientVType,
}

View File

@ -34,6 +34,7 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/security/advancedtls/testdata"
)
@ -46,14 +47,274 @@ func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
func (s) TestClientServerHandshake(t *testing.T) {
// ------------------Load Client Trust Cert and Peer Cert-------------------
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
type provType int
const (
provTypeRoot provType = iota
provTypeIdentity
)
type fakeProvider struct {
pt provType
isClient bool
wantMultiCert bool
wantError bool
}
func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
if f.wantError {
return nil, fmt.Errorf("bad fakeProvider")
}
cs := &certStore{}
err := cs.loadCerts()
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
return nil, fmt.Errorf("failed to load certs: %v", err)
}
if f.pt == provTypeRoot && f.isClient {
return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil
}
if f.pt == provTypeRoot && !f.isClient {
return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil
}
if f.pt == provTypeIdentity && f.isClient {
if f.wantMultiCert {
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil
}
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil
}
if f.wantMultiCert {
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil
}
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil
}
func (f fakeProvider) Close() {}
func (s) TestClientOptionsConfigErrorCases(t *testing.T) {
tests := []struct {
desc string
clientVType VerificationType
IdentityOptions IdentityCertificateOptions
RootOptions RootCertificateOptions
}{
{
desc: "Skip default verification and provide no root credentials",
clientVType: SkipVerification,
},
{
desc: "More than one fields in RootCertificateOptions is specified",
clientVType: CertVerification,
RootOptions: RootCertificateOptions{
RootCACerts: x509.NewCertPool(),
RootProvider: fakeProvider{},
},
},
{
desc: "More than one fields in IdentityCertificateOptions is specified",
clientVType: CertVerification,
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, nil
},
IdentityProvider: fakeProvider{pt: provTypeIdentity},
},
},
{
desc: "Specify GetIdentityCertificatesForServer",
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return nil, nil
},
},
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
clientOptions := &ClientOptions{
VType: test.clientVType,
IdentityOptions: test.IdentityOptions,
RootOptions: test.RootOptions,
}
_, err := clientOptions.config()
if err == nil {
t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", clientOptions)
}
})
}
}
func (s) TestClientOptionsConfigSuccessCases(t *testing.T) {
tests := []struct {
desc string
clientVType VerificationType
IdentityOptions IdentityCertificateOptions
RootOptions RootCertificateOptions
}{
{
desc: "Use system default if no fields in RootCertificateOptions is specified",
clientVType: CertVerification,
},
{
desc: "Good case with mutual TLS",
clientVType: CertVerification,
RootOptions: RootCertificateOptions{
RootProvider: fakeProvider{},
},
IdentityOptions: IdentityCertificateOptions{
IdentityProvider: fakeProvider{pt: provTypeIdentity},
},
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
clientOptions := &ClientOptions{
VType: test.clientVType,
IdentityOptions: test.IdentityOptions,
RootOptions: test.RootOptions,
}
clientConfig, err := clientOptions.config()
if err != nil {
t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", clientOptions, err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in clientOptions.
if clientOptions.RootOptions.RootCACerts == nil &&
clientOptions.RootOptions.GetRootCertificates == nil && clientOptions.RootOptions.RootProvider == nil {
if clientConfig.RootCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the client side.")
}
}
})
}
}
func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
tests := []struct {
desc string
requireClientCert bool
serverVType VerificationType
IdentityOptions IdentityCertificateOptions
RootOptions RootCertificateOptions
}{
{
desc: "Skip default verification and provide no root credentials",
requireClientCert: true,
serverVType: SkipVerification,
},
{
desc: "More than one fields in RootCertificateOptions is specified",
requireClientCert: true,
serverVType: CertVerification,
RootOptions: RootCertificateOptions{
RootCACerts: x509.NewCertPool(),
GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, nil
},
},
},
{
desc: "More than one fields in IdentityCertificateOptions is specified",
serverVType: CertVerification,
IdentityOptions: IdentityCertificateOptions{
Certificates: []tls.Certificate{},
IdentityProvider: fakeProvider{pt: provTypeIdentity},
},
},
{
desc: "no field in IdentityCertificateOptions is specified",
serverVType: CertVerification,
},
{
desc: "Specify GetIdentityCertificatesForClient",
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, nil
},
},
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
VType: test.serverVType,
RequireClientCert: test.requireClientCert,
IdentityOptions: test.IdentityOptions,
RootOptions: test.RootOptions,
}
_, err := serverOptions.config()
if err == nil {
t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions)
}
})
}
}
func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
tests := []struct {
desc string
requireClientCert bool
serverVType VerificationType
IdentityOptions IdentityCertificateOptions
RootOptions RootCertificateOptions
}{
{
desc: "Use system default if no fields in RootCertificateOptions is specified",
requireClientCert: true,
serverVType: CertVerification,
IdentityOptions: IdentityCertificateOptions{
Certificates: []tls.Certificate{},
},
},
{
desc: "Good case with mutual TLS",
requireClientCert: true,
serverVType: CertVerification,
RootOptions: RootCertificateOptions{
RootProvider: fakeProvider{},
},
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return nil, nil
},
},
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
VType: test.serverVType,
RequireClientCert: test.requireClientCert,
IdentityOptions: test.IdentityOptions,
RootOptions: test.RootOptions,
}
serverConfig, err := serverOptions.config()
if err != nil {
t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in serverOptions.
if serverOptions.RootOptions.RootCACerts == nil &&
serverOptions.RootOptions.GetRootCertificates == nil && serverOptions.RootOptions.RootProvider == nil {
if serverConfig.ClientCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the server side.")
}
}
})
}
}
func (s) TestClientServerHandshake(t *testing.T) {
cs := &certStore{}
err := cs.loadCerts()
if err != nil {
t.Fatalf("Failed to load certs: %v", err)
}
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
}
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
if params.ServerName == "" {
@ -69,18 +330,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
return nil, fmt.Errorf("custom verification function failed")
}
clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
testdata.Path("client_key_1.pem"))
if err != nil {
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
}
// ------------------Load Server Trust Cert and Peer Cert-------------------
serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
}
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
}
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
if params.ServerName != "" {
@ -93,11 +344,6 @@ func (s) TestClientServerHandshake(t *testing.T) {
return &VerificationResults{}, nil
}
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, fmt.Errorf("bad root certificate reloading")
}
@ -109,7 +355,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientVType VerificationType
clientExpectCreateError bool
clientRootProvider certprovider.Provider
clientIdentityProvider certprovider.Provider
clientExpectHandshakeError bool
serverMutualTLS bool
serverCert []tls.Certificate
@ -118,23 +365,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVType VerificationType
serverRootProvider certprovider.Provider
serverIdentityProvider certprovider.Provider
serverExpectError bool
}{
// Client: nil setting
// Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure
// Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client
// side doesn't provide any verification mechanism. We don't allow this
// even setting vType to SkipVerification. Clients should at least provide
// their own verification logic.
{
desc: "Client has no trust cert; server sends peer cert",
clientVType: SkipVerification,
clientExpectCreateError: true,
serverCert: []tls.Certificate{serverPeerCert},
serverVType: CertAndHostVerification,
serverExpectError: true,
},
// Client: nil setting except verifyFuncGood
// Server: only set serverCert with mutual TLS off
// Expected Behavior: success
@ -144,7 +378,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
clientVerifyFunc: clientVerifyFuncGood,
clientVType: SkipVerification,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverVType: CertAndHostVerification,
},
// Client: only set clientRoot
@ -155,10 +389,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
// this test suites.
{
desc: "Client has root cert; server sends peer cert",
clientRoot: clientTrustPool,
clientRoot: cs.clientTrust1,
clientVType: CertAndHostVerification,
clientExpectHandshakeError: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverVType: CertAndHostVerification,
serverExpectError: true,
},
@ -173,7 +407,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientGetRoot: getRootCAsForClient,
clientVType: CertAndHostVerification,
clientExpectHandshakeError: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverVType: CertAndHostVerification,
serverExpectError: true,
},
@ -185,7 +419,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverVType: CertAndHostVerification,
},
// Client: set clientGetRoot and bad clientVerifyFunc function
@ -198,66 +432,35 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientVerifyFunc: verifyFuncBad,
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set clientGetRoot and clientVerifyFunc
// Server: nil setting
// Expected Behavior: server side failure
// Reason: server side must either set serverCert or serverGetCert
{
desc: "Client sets reload root function with verifyFuncGood; server sets nil",
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverRoot and serverCert with mutual TLS on
// Expected Behavior: success
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
clientCert: []tls.Certificate{clientPeerCert},
clientCert: []tls.Certificate{cs.clientPeer1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverRoot: serverTrustPool,
serverCert: []tls.Certificate{cs.serverPeer1},
serverRoot: cs.serverTrust1,
serverVType: CertVerification,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverCert, but not setting any of serverRoot, serverGetRoot
// or serverVerifyFunc, with mutual TLS on
// Expected Behavior: server side failure
// Reason: server side needs to provide any verification mechanism when
// mTLS in on, even setting vType to SkipVerification. Servers should at
// least provide their own verification logic.
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS",
clientCert: []tls.Certificate{clientPeerCert},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverVType: SkipVerification,
serverExpectError: true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot and serverCert with mutual TLS on
// Expected Behavior: success
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
clientCert: []tls.Certificate{clientPeerCert},
clientCert: []tls.Certificate{cs.clientPeer1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverGetRoot: getRootCAsForServer,
serverVType: CertVerification,
},
@ -268,12 +471,12 @@ func (s) TestClientServerHandshake(t *testing.T) {
// Reason: server side reloading returns failure
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
clientCert: []tls.Certificate{clientPeerCert},
clientCert: []tls.Certificate{cs.clientPeer1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverGetRoot: getRootCAsForServerBad,
serverVType: CertVerification,
serverExpectError: true,
@ -284,14 +487,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
return &cs.clientPeer1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@ -305,14 +508,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
{
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
return &cs.serverPeer1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@ -326,7 +529,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
{
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
return &cs.clientPeer1, nil
},
clientGetRoot: getRootCAsForServer,
clientVerifyFunc: clientVerifyFuncGood,
@ -334,7 +537,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@ -349,14 +552,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
return &cs.clientPeer1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&clientPeerCert}, nil
return []*tls.Certificate{&cs.clientPeer1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@ -370,7 +573,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
return &cs.clientPeer1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
@ -378,7 +581,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
},
serverGetRoot: getRootCAsForClient,
serverVerifyFunc: serverVerifyFunc,
@ -391,18 +594,92 @@ func (s) TestClientServerHandshake(t *testing.T) {
// server custom check fails
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
clientCert: []tls.Certificate{clientPeerCert},
clientCert: []tls.Certificate{cs.clientPeer1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverCert: []tls.Certificate{cs.serverPeer1},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: verifyFuncBad,
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set a clientIdentityProvider which will get multiple cert chains
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
// Expected Behavior: server side failure due to multiple cert chains in
// clientIdentityProvider
{
desc: "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS",
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true},
clientRootProvider: fakeProvider{isClient: true},
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set a bad clientIdentityProvider
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
// Expected Behavior: server side failure due to bad clientIdentityProvider
{
desc: "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS",
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true},
clientRootProvider: fakeProvider{isClient: true},
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set clientIdentityProvider and clientRootProvider
// Server: set bad serverRootProvider with mutual TLS on
// Expected Behavior: server side failure due to bad serverRootProvider
{
desc: "Client sets root and identity provider; Server sets bad root provider; mutualTLS",
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
clientRootProvider: fakeProvider{isClient: true},
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
serverRootProvider: fakeProvider{isClient: false, wantError: true},
serverVType: CertVerification,
serverExpectError: true,
},
// Client: set clientIdentityProvider and clientRootProvider
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
// Expected Behavior: success
{
desc: "Client sets root and identity provider; Server sets root and identity provider; mutualTLS",
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
clientRootProvider: fakeProvider{isClient: true},
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification,
},
// Client: set clientIdentityProvider and clientRootProvider
// Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on
// Expected Behavior: success, because server side has SNI
{
desc: "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS",
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
clientRootProvider: fakeProvider{isClient: true},
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true},
serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification,
},
} {
test := test
t.Run(test.desc, func(t *testing.T) {
@ -413,11 +690,15 @@ func (s) TestClientServerHandshake(t *testing.T) {
}
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
IdentityOptions: IdentityCertificateOptions{
Certificates: test.serverCert,
GetIdentityCertificatesForServer: test.serverGetCert,
IdentityProvider: test.serverIdentityProvider,
},
RootOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCertificates: test.serverGetRoot,
RootProvider: test.serverRootProvider,
},
RequireClientCert: test.serverMutualTLS,
VerifyPeer: test.serverVerifyFunc,
@ -452,23 +733,22 @@ func (s) TestClientServerHandshake(t *testing.T) {
}
defer conn.Close()
clientOptions := &ClientOptions{
Certificates: test.clientCert,
GetClientCertificate: test.clientGetCert,
VerifyPeer: test.clientVerifyFunc,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot,
IdentityOptions: IdentityCertificateOptions{
Certificates: test.clientCert,
GetIdentityCertificatesForClient: test.clientGetCert,
IdentityProvider: test.clientIdentityProvider,
},
VerifyPeer: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
RootProvider: test.clientRootProvider,
},
VType: test.clientVType,
}
clientTLS, newClientErr := NewClientCreds(clientOptions)
if newClientErr != nil && test.clientExpectCreateError {
return
}
if newClientErr != nil && !test.clientExpectCreateError ||
newClientErr == nil && test.clientExpectCreateError {
t.Fatalf("Expect error: %v, but err is %v",
test.clientExpectCreateError, newClientErr)
clientTLS, err := NewClientCreds(clientOptions)
if err != nil {
t.Fatalf("NewClientCreds failed: %v", err)
}
_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
lisAddr, conn)
@ -541,7 +821,7 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
clientOptions := &ClientOptions{
RootCertificateOptions: RootCertificateOptions{
RootOptions: RootCertificateOptions{
RootCACerts: clientTrustPool,
},
ServerNameOverride: expectedServerName,
@ -563,7 +843,7 @@ func (s) TestTLSClone(t *testing.T) {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
clientOptions := &ClientOptions{
RootCertificateOptions: RootCertificateOptions{
RootOptions: RootCertificateOptions{
RootCACerts: clientTrustPool,
},
ServerNameOverride: expectedServerName,
@ -635,62 +915,6 @@ func (s) TestWrapSyscallConn(t *testing.T) {
}
}
func (s) TestOptionsConfig(t *testing.T) {
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
tests := []struct {
desc string
clientVType VerificationType
serverMutualTLS bool
serverCert []tls.Certificate
serverVType VerificationType
}{
{
desc: "Client uses system-provided RootCAs; server uses system-provided ClientCAs",
clientVType: CertVerification,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverVType: CertAndHostVerification,
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
Certificates: test.serverCert,
RequireClientCert: test.serverMutualTLS,
VType: test.serverVType,
}
serverConfig, err := serverOptions.config()
if err != nil {
t.Fatalf("Unable to generate serverConfig. Error: %v", err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in serverOptions.
if serverOptions.RootCACerts == nil && serverOptions.GetRootCAs == nil &&
serverOptions.RequireClientCert && serverConfig.ClientCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the server side.")
}
clientOptions := &ClientOptions{
VType: test.clientVType,
}
clientConfig, err := clientOptions.config()
if err != nil {
t.Fatalf("Unable to generate clientConfig. Error: %v", err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in clientOptions.
if clientOptions.RootCACerts == nil && clientOptions.GetRootCAs == nil &&
clientConfig.RootCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the client side.")
}
})
}
}
func (s) TestGetCertificatesSNI(t *testing.T) {
// Load server certificates for setting the serverGetCert callback function.
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
@ -734,8 +958,10 @@ func (s) TestGetCertificatesSNI(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
},
},
}
serverConfig, err := serverOptions.config()

View File

@ -97,7 +97,6 @@ func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) {
return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified")
}
if o.IdentityInterval == 0 {
logger.Warningf("heyheyhey")
o.IdentityInterval = defaultIdentityInterval
}
if o.RootInterval == 0 {

View File

@ -28,10 +28,10 @@ import (
// buildGetCertificates returns the certificate that matches the SNI field
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificates == nil {
if o.IdentityOptions.GetIdentityCertificatesForServer == nil {
return nil, fmt.Errorf("function GetCertificates must be specified")
}
certificates, err := o.GetCertificates(clientHello)
certificates, err := o.IdentityOptions.GetIdentityCertificatesForServer(clientHello)
if err != nil {
return nil, err
}