diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 40a9fd5f7..745646326 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -27,10 +27,12 @@ import ( "crypto/x509" "fmt" "net" + "reflect" "syscall" "time" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" credinternal "google.golang.org/grpc/internal/credentials" ) @@ -79,21 +81,67 @@ type GetRootCAsResults struct { TrustCerts *x509.CertPool } -// RootCertificateOptions contains a field and a function for obtaining root -// trust certificates. -// It is used by both ClientOptions and ServerOptions. -// If users want to use default verification, but did not provide a valid -// RootCertificateOptions, we use the system default trust certificates. +// RootCertificateOptions contains options to obtain root trust certificates +// for both the client and the server. +// At most one option could be set. If none of them are set, we +// use the system default trust certificates. type RootCertificateOptions struct { - // If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts - // will be used every time when verifying the peer certificates, without - // performing root certificate reloading. + // If RootCACerts is set, it will be used every time when verifying + // the peer certificates, without performing root certificate reloading. RootCACerts *x509.CertPool - // If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked - // every time asked to check certificates sent from the server when a new - // connection is established. - // This is known as root CA certificate reloading. - GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) + // If GetRootCertificates is set, it will be invoked to obtain root certs for + // every new connection. + GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error) + // If RootProvider is set, we will use the root certs from the Provider's + // KeyMaterial() call in the new connections. The Provider must have initial + // credentials if specified. Otherwise, KeyMaterial() will block forever. + RootProvider certprovider.Provider +} + +// nonNilFieldCount returns the number of set fields in RootCertificateOptions. +func (o RootCertificateOptions) nonNilFieldCount() int { + cnt := 0 + rv := reflect.ValueOf(o) + for i := 0; i < rv.NumField(); i++ { + if !rv.Field(i).IsNil() { + cnt++ + } + } + return cnt +} + +// IdentityCertificateOptions contains options to obtain identity certificates +// for both the client and the server. +// At most one option could be set. +type IdentityCertificateOptions struct { + // If Certificates is set, it will be used every time when needed to present + //identity certificates, without performing identity certificate reloading. + Certificates []tls.Certificate + // If GetIdentityCertificatesForClient is set, it will be invoked to obtain + // identity certs for every new connection. + // This field MUST be set on client side. + GetIdentityCertificatesForClient func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + // If GetIdentityCertificatesForServer is set, it will be invoked to obtain + // identity certs for every new connection. + // This field MUST be set on server side. + GetIdentityCertificatesForServer func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) + // If IdentityProvider is set, we will use the identity certs from the + // Provider's KeyMaterial() call in the new connections. The Provider must + // have initial credentials if specified. Otherwise, KeyMaterial() will block + // forever. + IdentityProvider certprovider.Provider +} + +// nonNilFieldCount returns the number of set fields in IdentityCertificateOptions. +func (o IdentityCertificateOptions) nonNilFieldCount() int { + cnt := 0 + rv := reflect.ValueOf(o) + for i := 0; i < rv.NumField(); i++ { + if !rv.Field(i).IsNil() { + cnt++ + } + } + return cnt } // VerificationType is the enum type that represents different levels of @@ -115,27 +163,11 @@ const ( SkipVerification ) -// ClientOptions contains all the fields and functions needed to be filled by -// the client. -// General rules for certificate setting on client side: -// Certificates or GetClientCertificate indicates the certificates sent from -// the client to the server to prove client's identities. The rules for setting -// these two fields are: -// If requiring mutual authentication on server side: -// Either Certificates or GetClientCertificate must be set; the other will -// be ignored. -// Otherwise: -// Nothing needed(the two fields will be ignored). +// ClientOptions contains the fields needed to be filled by the client. type ClientOptions struct { - // If field Certificates is set, field GetClientCertificate will be ignored. - // The client will use Certificates every time when asked for a certificate, - // without performing certificate reloading. - Certificates []tls.Certificate - // If GetClientCertificate is set and Certificates is nil, the client will - // invoke this function every time asked to present certificates to the - // server when a new connection is established. This is known as peer - // certificate reloading. - GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + // IdentityOptions is OPTIONAL on client side. This field only needs to be + // set if mutual authentication is required on server side. + IdentityOptions IdentityCertificateOptions // VerifyPeer is a custom verification check after certificate signature // check. // If this is set, we will perform this customized check after doing the @@ -145,37 +177,25 @@ type ClientOptions struct { // it will override the virtual host name of authority (e.g. :authority // header field) in requests. ServerNameOverride string - // RootCertificateOptions is REQUIRED to be correctly set on client side. - RootCertificateOptions + // RootOptions is OPTIONAL on client side. If not set, we will try to use the + // default trust certificates in users' OS system. + RootOptions RootCertificateOptions // VType is the verification type on the client side. VType VerificationType } -// ServerOptions contains all the fields and functions needed to be filled by -// the client. -// General rules for certificate setting on server side: -// Certificates or GetClientCertificate indicates the certificates sent from -// the server to the client to prove server's identities. The rules for setting -// these two fields are: -// Either Certificates or GetCertificates must be set; the other will be ignored. +// ServerOptions contains the fields needed to be filled by the server. type ServerOptions struct { - // If field Certificates is set, field GetClientCertificate will be ignored. - // The server will use Certificates every time when asked for a certificate, - // without performing certificate reloading. - Certificates []tls.Certificate - // If GetClientCertificate is set and Certificates is nil, the server will - // invoke this function every time asked to present certificates to the - // client when a new connection is established. This is known as peer - // certificate reloading. - GetCertificates func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) + // IdentityOptions is REQUIRED on server side. + IdentityOptions IdentityCertificateOptions // VerifyPeer is a custom verification check after certificate signature // check. // If this is set, we will perform this customized check after doing the // normal check(s) indicated by setting VType. VerifyPeer CustomVerificationFunc - // RootCertificateOptions is only required when mutual TLS is - // enabled(RequireClientCert is true). - RootCertificateOptions + // RootOptions is OPTIONAL on server side. This field only needs to be set if + // mutual authentication is required(RequireClientCert is true). + RootOptions RootCertificateOptions // If the server want the client to send certificates. RequireClientCert bool // VType is the verification type on the server side. @@ -184,48 +204,89 @@ type ServerOptions struct { func (o *ClientOptions) config() (*tls.Config, error) { if o.VType == SkipVerification && o.VerifyPeer == nil { - return nil, fmt.Errorf( - "client needs to provide custom verification mechanism if choose to skip default verification") + return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification") } - rootCAs := o.RootCACerts - if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil { - // Set rootCAs to system default. - systemRootCAs, err := x509.SystemCertPool() - if err != nil { - return nil, err - } - rootCAs = systemRootCAs + // Make sure users didn't specify more than one fields in + // RootCertificateOptions and IdentityCertificateOptions. + if num := o.RootOptions.nonNilFieldCount(); num > 1 { + return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified") + } + if num := o.IdentityOptions.nonNilFieldCount(); num > 1 { + return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified") + } + if o.IdentityOptions.GetIdentityCertificatesForServer != nil { + return nil, fmt.Errorf("GetIdentityCertificatesForServer cannot be specified on the client side") } - // We have to set InsecureSkipVerify to true to skip the default checks and - // use the verification function we built from buildVerifyFunc. config := &tls.Config{ - ServerName: o.ServerNameOverride, - Certificates: o.Certificates, - GetClientCertificate: o.GetClientCertificate, - InsecureSkipVerify: true, + ServerName: o.ServerNameOverride, + // We have to set InsecureSkipVerify to true to skip the default checks and + // use the verification function we built from buildVerifyFunc. + InsecureSkipVerify: true, } - if rootCAs != nil { - config.RootCAs = rootCAs + // Propagate root-certificate-related fields in tls.Config. + switch { + case o.RootOptions.RootCACerts != nil: + config.RootCAs = o.RootOptions.RootCACerts + case o.RootOptions.GetRootCertificates != nil: + // In cases when users provide GetRootCertificates callback, since this + // callback is not contained in tls.Config, we have nothing to set here. + // We will invoke the callback in ClientHandshake. + case o.RootOptions.RootProvider != nil: + o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) { + km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background()) + if err != nil { + return nil, err + } + return &GetRootCAsResults{TrustCerts: km.Roots}, nil + } + default: + // No root certificate options specified by user. Use the certificates + // stored in system default path as the last resort. + if o.VType != SkipVerification { + systemRootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + config.RootCAs = systemRootCAs + } + } + // Propagate identity-certificate-related fields in tls.Config. + switch { + case o.IdentityOptions.Certificates != nil: + config.Certificates = o.IdentityOptions.Certificates + case o.IdentityOptions.GetIdentityCertificatesForClient != nil: + config.GetClientCertificate = o.IdentityOptions.GetIdentityCertificatesForClient + case o.IdentityOptions.IdentityProvider != nil: + config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background()) + if err != nil { + return nil, err + } + if len(km.Certs) != 1 { + return nil, fmt.Errorf("there should always be only one identity cert chain on the client side in IdentityProvider") + } + return &km.Certs[0], nil + } + default: + // It's fine for users to not specify identity certificate options here. } return config, nil } func (o *ServerOptions) config() (*tls.Config, error) { - if o.Certificates == nil && o.GetCertificates == nil { - return nil, fmt.Errorf("either Certificates or GetCertificates must be specified") - } if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil { - return nil, fmt.Errorf( - "server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)") + return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)") } - clientCAs := o.RootCACerts - if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil && o.RequireClientCert { - // Set clientCAs to system default. - systemRootCAs, err := x509.SystemCertPool() - if err != nil { - return nil, err - } - clientCAs = systemRootCAs + // Make sure users didn't specify more than one fields in + // RootCertificateOptions and IdentityCertificateOptions. + if num := o.RootOptions.nonNilFieldCount(); num > 1 { + return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified") + } + if num := o.IdentityOptions.nonNilFieldCount(); num > 1 { + return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified") + } + if o.IdentityOptions.GetIdentityCertificatesForClient != nil { + return nil, fmt.Errorf("GetIdentityCertificatesForClient cannot be specified on the server side") } clientAuth := tls.NoClientCert if o.RequireClientCert { @@ -235,18 +296,60 @@ func (o *ServerOptions) config() (*tls.Config, error) { clientAuth = tls.RequireAnyClientCert } config := &tls.Config{ - ClientAuth: clientAuth, - Certificates: o.Certificates, + ClientAuth: clientAuth, } - if o.GetCertificates != nil { - // GetCertificate is only able to perform SNI logic for go1.10 and above. - // It will return the first certificate in o.GetCertificates for go1.9. + // Propagate root-certificate-related fields in tls.Config. + switch { + case o.RootOptions.RootCACerts != nil: + config.ClientCAs = o.RootOptions.RootCACerts + case o.RootOptions.GetRootCertificates != nil: + // In cases when users provide GetRootCertificates callback, since this + // callback is not contained in tls.Config, we have nothing to set here. + // We will invoke the callback in ServerHandshake. + case o.RootOptions.RootProvider != nil: + o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) { + km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background()) + if err != nil { + return nil, err + } + return &GetRootCAsResults{TrustCerts: km.Roots}, nil + } + default: + // No root certificate options specified by user. Use the certificates + // stored in system default path as the last resort. + if o.VType != SkipVerification && o.RequireClientCert { + systemRootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + config.ClientCAs = systemRootCAs + } + } + // Propagate identity-certificate-related fields in tls.Config. + switch { + case o.IdentityOptions.Certificates != nil: + config.Certificates = o.IdentityOptions.Certificates + case o.IdentityOptions.GetIdentityCertificatesForServer != nil: config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { return buildGetCertificates(clientHello, o) } - } - if clientCAs != nil { - config.ClientCAs = clientCAs + case o.IdentityOptions.IdentityProvider != nil: + o.IdentityOptions.GetIdentityCertificatesForServer = func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { + km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background()) + if err != nil { + return nil, err + } + var certChains []*tls.Certificate + for i := 0; i < len(km.Certs); i++ { + certChains = append(certChains, &km.Certs[i]) + } + return certChains, nil + } + config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return buildGetCertificates(clientHello, o) + } + default: + return nil, fmt.Errorf("needs to specify at least one field in IdentityCertificateOptions") } return config, nil } @@ -423,7 +526,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) tc := &advancedTLSCreds{ config: conf, isClient: true, - getRootCAs: o.GetRootCAs, + getRootCAs: o.RootOptions.GetRootCertificates, verifyFunc: o.VerifyPeer, vType: o.VType, } @@ -441,7 +544,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) tc := &advancedTLSCreds{ config: conf, isClient: false, - getRootCAs: o.GetRootCAs, + getRootCAs: o.RootOptions.GetRootCertificates, verifyFunc: o.VerifyPeer, vType: o.VType, } diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index dc4f513f2..339864447 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -385,11 +385,13 @@ func (s) TestEnd2End(t *testing.T) { t.Run(test.desc, func(t *testing.T) { // Start a server using ServerOptions in another goroutine. serverOptions := &ServerOptions{ - Certificates: test.serverCert, - GetCertificates: test.serverGetCert, - RootCertificateOptions: RootCertificateOptions{ - RootCACerts: test.serverRoot, - GetRootCAs: test.serverGetRoot, + IdentityOptions: IdentityCertificateOptions{ + Certificates: test.serverCert, + GetIdentityCertificatesForServer: test.serverGetCert, + }, + RootOptions: RootCertificateOptions{ + RootCACerts: test.serverRoot, + GetRootCertificates: test.serverGetRoot, }, RequireClientCert: true, VerifyPeer: test.serverVerifyFunc, @@ -409,12 +411,14 @@ func (s) TestEnd2End(t *testing.T) { pb.RegisterGreeterService(s, &pb.GreeterService{SayHello: sayHello}) go s.Serve(lis) clientOptions := &ClientOptions{ - Certificates: test.clientCert, - GetClientCertificate: test.clientGetCert, - VerifyPeer: test.clientVerifyFunc, - RootCertificateOptions: RootCertificateOptions{ - RootCACerts: test.clientRoot, - GetRootCAs: test.clientGetRoot, + IdentityOptions: IdentityCertificateOptions{ + Certificates: test.clientCert, + GetIdentityCertificatesForClient: test.clientGetCert, + }, + VerifyPeer: test.clientVerifyFunc, + RootOptions: RootCertificateOptions{ + RootCACerts: test.clientRoot, + GetRootCertificates: test.clientGetRoot, }, VType: test.clientVType, } diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 8800a51a9..a631ee465 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -34,6 +34,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/security/advancedtls/testdata" ) @@ -46,14 +47,274 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -func (s) TestClientServerHandshake(t *testing.T) { - // ------------------Load Client Trust Cert and Peer Cert------------------- - clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) +type provType int + +const ( + provTypeRoot provType = iota + provTypeIdentity +) + +type fakeProvider struct { + pt provType + isClient bool + wantMultiCert bool + wantError bool +} + +func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { + if f.wantError { + return nil, fmt.Errorf("bad fakeProvider") + } + cs := &certStore{} + err := cs.loadCerts() if err != nil { - t.Fatalf("Client is unable to load trust certs. Error: %v", err) + return nil, fmt.Errorf("failed to load certs: %v", err) + } + if f.pt == provTypeRoot && f.isClient { + return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil + } + if f.pt == provTypeRoot && !f.isClient { + return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil + } + if f.pt == provTypeIdentity && f.isClient { + if f.wantMultiCert { + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil + } + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil + } + if f.wantMultiCert { + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil + } + return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil +} + +func (f fakeProvider) Close() {} + +func (s) TestClientOptionsConfigErrorCases(t *testing.T) { + tests := []struct { + desc string + clientVType VerificationType + IdentityOptions IdentityCertificateOptions + RootOptions RootCertificateOptions + }{ + { + desc: "Skip default verification and provide no root credentials", + clientVType: SkipVerification, + }, + { + desc: "More than one fields in RootCertificateOptions is specified", + clientVType: CertVerification, + RootOptions: RootCertificateOptions{ + RootCACerts: x509.NewCertPool(), + RootProvider: fakeProvider{}, + }, + }, + { + desc: "More than one fields in IdentityCertificateOptions is specified", + clientVType: CertVerification, + IdentityOptions: IdentityCertificateOptions{ + GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return nil, nil + }, + IdentityProvider: fakeProvider{pt: provTypeIdentity}, + }, + }, + { + desc: "Specify GetIdentityCertificatesForServer", + IdentityOptions: IdentityCertificateOptions{ + GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { + return nil, nil + }, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + clientOptions := &ClientOptions{ + VType: test.clientVType, + IdentityOptions: test.IdentityOptions, + RootOptions: test.RootOptions, + } + _, err := clientOptions.config() + if err == nil { + t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", clientOptions) + } + }) + } +} + +func (s) TestClientOptionsConfigSuccessCases(t *testing.T) { + tests := []struct { + desc string + clientVType VerificationType + IdentityOptions IdentityCertificateOptions + RootOptions RootCertificateOptions + }{ + { + desc: "Use system default if no fields in RootCertificateOptions is specified", + clientVType: CertVerification, + }, + { + desc: "Good case with mutual TLS", + clientVType: CertVerification, + RootOptions: RootCertificateOptions{ + RootProvider: fakeProvider{}, + }, + IdentityOptions: IdentityCertificateOptions{ + IdentityProvider: fakeProvider{pt: provTypeIdentity}, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + clientOptions := &ClientOptions{ + VType: test.clientVType, + IdentityOptions: test.IdentityOptions, + RootOptions: test.RootOptions, + } + clientConfig, err := clientOptions.config() + if err != nil { + t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", clientOptions, err) + } + // Verify that the system-provided certificates would be used + // when no verification method was set in clientOptions. + if clientOptions.RootOptions.RootCACerts == nil && + clientOptions.RootOptions.GetRootCertificates == nil && clientOptions.RootOptions.RootProvider == nil { + if clientConfig.RootCAs == nil { + t.Fatalf("Failed to assign system-provided certificates on the client side.") + } + } + }) + } +} + +func (s) TestServerOptionsConfigErrorCases(t *testing.T) { + tests := []struct { + desc string + requireClientCert bool + serverVType VerificationType + IdentityOptions IdentityCertificateOptions + RootOptions RootCertificateOptions + }{ + { + desc: "Skip default verification and provide no root credentials", + requireClientCert: true, + serverVType: SkipVerification, + }, + { + desc: "More than one fields in RootCertificateOptions is specified", + requireClientCert: true, + serverVType: CertVerification, + RootOptions: RootCertificateOptions{ + RootCACerts: x509.NewCertPool(), + GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) { + return nil, nil + }, + }, + }, + { + desc: "More than one fields in IdentityCertificateOptions is specified", + serverVType: CertVerification, + IdentityOptions: IdentityCertificateOptions{ + Certificates: []tls.Certificate{}, + IdentityProvider: fakeProvider{pt: provTypeIdentity}, + }, + }, + { + desc: "no field in IdentityCertificateOptions is specified", + serverVType: CertVerification, + }, + { + desc: "Specify GetIdentityCertificatesForClient", + IdentityOptions: IdentityCertificateOptions{ + GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return nil, nil + }, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + serverOptions := &ServerOptions{ + VType: test.serverVType, + RequireClientCert: test.requireClientCert, + IdentityOptions: test.IdentityOptions, + RootOptions: test.RootOptions, + } + _, err := serverOptions.config() + if err == nil { + t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions) + } + }) + } +} + +func (s) TestServerOptionsConfigSuccessCases(t *testing.T) { + tests := []struct { + desc string + requireClientCert bool + serverVType VerificationType + IdentityOptions IdentityCertificateOptions + RootOptions RootCertificateOptions + }{ + { + desc: "Use system default if no fields in RootCertificateOptions is specified", + requireClientCert: true, + serverVType: CertVerification, + IdentityOptions: IdentityCertificateOptions{ + Certificates: []tls.Certificate{}, + }, + }, + { + desc: "Good case with mutual TLS", + requireClientCert: true, + serverVType: CertVerification, + RootOptions: RootCertificateOptions{ + RootProvider: fakeProvider{}, + }, + IdentityOptions: IdentityCertificateOptions{ + GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { + return nil, nil + }, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + serverOptions := &ServerOptions{ + VType: test.serverVType, + RequireClientCert: test.requireClientCert, + IdentityOptions: test.IdentityOptions, + RootOptions: test.RootOptions, + } + serverConfig, err := serverOptions.config() + if err != nil { + t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err) + } + // Verify that the system-provided certificates would be used + // when no verification method was set in serverOptions. + if serverOptions.RootOptions.RootCACerts == nil && + serverOptions.RootOptions.GetRootCertificates == nil && serverOptions.RootOptions.RootProvider == nil { + if serverConfig.ClientCAs == nil { + t.Fatalf("Failed to assign system-provided certificates on the server side.") + } + } + }) + } +} + +func (s) TestClientServerHandshake(t *testing.T) { + cs := &certStore{} + err := cs.loadCerts() + if err != nil { + t.Fatalf("Failed to load certs: %v", err) } getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { - return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil + return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil } clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { if params.ServerName == "" { @@ -69,18 +330,8 @@ func (s) TestClientServerHandshake(t *testing.T) { verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) { return nil, fmt.Errorf("custom verification function failed") } - clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), - testdata.Path("client_key_1.pem")) - if err != nil { - t.Fatalf("Client is unable to parse peer certificates. Error: %v", err) - } - // ------------------Load Server Trust Cert and Peer Cert------------------- - serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem")) - if err != nil { - t.Fatalf("Server is unable to load trust certs. Error: %v", err) - } getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { - return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil + return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil } serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) { if params.ServerName != "" { @@ -93,11 +344,6 @@ func (s) TestClientServerHandshake(t *testing.T) { return &VerificationResults{}, nil } - serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), - testdata.Path("server_key_1.pem")) - if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) - } getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return nil, fmt.Errorf("bad root certificate reloading") } @@ -109,7 +355,8 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) clientVerifyFunc CustomVerificationFunc clientVType VerificationType - clientExpectCreateError bool + clientRootProvider certprovider.Provider + clientIdentityProvider certprovider.Provider clientExpectHandshakeError bool serverMutualTLS bool serverCert []tls.Certificate @@ -118,23 +365,10 @@ func (s) TestClientServerHandshake(t *testing.T) { serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) serverVerifyFunc CustomVerificationFunc serverVType VerificationType + serverRootProvider certprovider.Provider + serverIdentityProvider certprovider.Provider serverExpectError bool }{ - // Client: nil setting - // Server: only set serverCert with mutual TLS off - // Expected Behavior: server side failure - // Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client - // side doesn't provide any verification mechanism. We don't allow this - // even setting vType to SkipVerification. Clients should at least provide - // their own verification logic. - { - desc: "Client has no trust cert; server sends peer cert", - clientVType: SkipVerification, - clientExpectCreateError: true, - serverCert: []tls.Certificate{serverPeerCert}, - serverVType: CertAndHostVerification, - serverExpectError: true, - }, // Client: nil setting except verifyFuncGood // Server: only set serverCert with mutual TLS off // Expected Behavior: success @@ -144,7 +378,7 @@ func (s) TestClientServerHandshake(t *testing.T) { desc: "Client has no trust cert with verifyFuncGood; server sends peer cert", clientVerifyFunc: clientVerifyFuncGood, clientVType: SkipVerification, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverVType: CertAndHostVerification, }, // Client: only set clientRoot @@ -155,10 +389,10 @@ func (s) TestClientServerHandshake(t *testing.T) { // this test suites. { desc: "Client has root cert; server sends peer cert", - clientRoot: clientTrustPool, + clientRoot: cs.clientTrust1, clientVType: CertAndHostVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverVType: CertAndHostVerification, serverExpectError: true, }, @@ -173,7 +407,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetRoot: getRootCAsForClient, clientVType: CertAndHostVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverVType: CertAndHostVerification, serverExpectError: true, }, @@ -185,7 +419,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverVType: CertAndHostVerification, }, // Client: set clientGetRoot and bad clientVerifyFunc function @@ -198,66 +432,35 @@ func (s) TestClientServerHandshake(t *testing.T) { clientVerifyFunc: verifyFuncBad, clientVType: CertVerification, clientExpectHandshakeError: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverVType: CertVerification, serverExpectError: true, }, - // Client: set clientGetRoot and clientVerifyFunc - // Server: nil setting - // Expected Behavior: server side failure - // Reason: server side must either set serverCert or serverGetCert - { - desc: "Client sets reload root function with verifyFuncGood; server sets nil", - clientGetRoot: getRootCAsForClient, - clientVerifyFunc: clientVerifyFuncGood, - clientVType: CertVerification, - serverVType: CertVerification, - serverExpectError: true, - }, // Client: set clientGetRoot, clientVerifyFunc and clientCert // Server: set serverRoot and serverCert with mutual TLS on // Expected Behavior: success { desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS", - clientCert: []tls.Certificate{clientPeerCert}, + clientCert: []tls.Certificate{cs.clientPeer1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, - serverRoot: serverTrustPool, + serverCert: []tls.Certificate{cs.serverPeer1}, + serverRoot: cs.serverTrust1, serverVType: CertVerification, }, // Client: set clientGetRoot, clientVerifyFunc and clientCert - // Server: set serverCert, but not setting any of serverRoot, serverGetRoot - // or serverVerifyFunc, with mutual TLS on - // Expected Behavior: server side failure - // Reason: server side needs to provide any verification mechanism when - // mTLS in on, even setting vType to SkipVerification. Servers should at - // least provide their own verification logic. - { - desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS", - clientCert: []tls.Certificate{clientPeerCert}, - clientGetRoot: getRootCAsForClient, - clientVerifyFunc: clientVerifyFuncGood, - clientVType: CertVerification, - clientExpectHandshakeError: true, - serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, - serverVType: SkipVerification, - serverExpectError: true, - }, - // Client: set clientGetRoot, clientVerifyFunc and clientCert // Server: set serverGetRoot and serverCert with mutual TLS on // Expected Behavior: success { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", - clientCert: []tls.Certificate{clientPeerCert}, + clientCert: []tls.Certificate{cs.clientPeer1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverGetRoot: getRootCAsForServer, serverVType: CertVerification, }, @@ -268,12 +471,12 @@ func (s) TestClientServerHandshake(t *testing.T) { // Reason: server side reloading returns failure { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS", - clientCert: []tls.Certificate{clientPeerCert}, + clientCert: []tls.Certificate{cs.clientPeer1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverGetRoot: getRootCAsForServerBad, serverVType: CertVerification, serverExpectError: true, @@ -284,14 +487,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &clientPeerCert, nil + return &cs.clientPeer1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverPeerCert}, nil + return []*tls.Certificate{&cs.serverPeer1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -305,14 +508,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &serverPeerCert, nil + return &cs.serverPeer1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverPeerCert}, nil + return []*tls.Certificate{&cs.serverPeer1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -326,7 +529,7 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &clientPeerCert, nil + return &cs.clientPeer1, nil }, clientGetRoot: getRootCAsForServer, clientVerifyFunc: clientVerifyFuncGood, @@ -334,7 +537,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientExpectHandshakeError: true, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverPeerCert}, nil + return []*tls.Certificate{&cs.serverPeer1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -349,14 +552,14 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &clientPeerCert, nil + return &cs.clientPeer1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&clientPeerCert}, nil + return []*tls.Certificate{&cs.clientPeer1}, nil }, serverGetRoot: getRootCAsForServer, serverVerifyFunc: serverVerifyFunc, @@ -370,7 +573,7 @@ func (s) TestClientServerHandshake(t *testing.T) { { desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS", clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return &clientPeerCert, nil + return &cs.clientPeer1, nil }, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, @@ -378,7 +581,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientExpectHandshakeError: true, serverMutualTLS: true, serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverPeerCert}, nil + return []*tls.Certificate{&cs.serverPeer1}, nil }, serverGetRoot: getRootCAsForClient, serverVerifyFunc: serverVerifyFunc, @@ -391,18 +594,92 @@ func (s) TestClientServerHandshake(t *testing.T) { // server custom check fails { desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS", - clientCert: []tls.Certificate{clientPeerCert}, + clientCert: []tls.Certificate{cs.clientPeer1}, clientGetRoot: getRootCAsForClient, clientVerifyFunc: clientVerifyFuncGood, clientVType: CertVerification, clientExpectHandshakeError: true, serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, + serverCert: []tls.Certificate{cs.serverPeer1}, serverGetRoot: getRootCAsForServer, serverVerifyFunc: verifyFuncBad, serverVType: CertVerification, serverExpectError: true, }, + // Client: set a clientIdentityProvider which will get multiple cert chains + // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on + // Expected Behavior: server side failure due to multiple cert chains in + // clientIdentityProvider + { + desc: "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS", + clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true}, + clientRootProvider: fakeProvider{isClient: true}, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + serverMutualTLS: true, + serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, + serverRootProvider: fakeProvider{isClient: false}, + serverVType: CertVerification, + serverExpectError: true, + }, + // Client: set a bad clientIdentityProvider + // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on + // Expected Behavior: server side failure due to bad clientIdentityProvider + { + desc: "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS", + clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true}, + clientRootProvider: fakeProvider{isClient: true}, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + serverMutualTLS: true, + serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, + serverRootProvider: fakeProvider{isClient: false}, + serverVType: CertVerification, + serverExpectError: true, + }, + // Client: set clientIdentityProvider and clientRootProvider + // Server: set bad serverRootProvider with mutual TLS on + // Expected Behavior: server side failure due to bad serverRootProvider + { + desc: "Client sets root and identity provider; Server sets bad root provider; mutualTLS", + clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, + clientRootProvider: fakeProvider{isClient: true}, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + serverMutualTLS: true, + serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, + serverRootProvider: fakeProvider{isClient: false, wantError: true}, + serverVType: CertVerification, + serverExpectError: true, + }, + // Client: set clientIdentityProvider and clientRootProvider + // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on + // Expected Behavior: success + { + desc: "Client sets root and identity provider; Server sets root and identity provider; mutualTLS", + clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, + clientRootProvider: fakeProvider{isClient: true}, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + serverMutualTLS: true, + serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, + serverRootProvider: fakeProvider{isClient: false}, + serverVType: CertVerification, + }, + // Client: set clientIdentityProvider and clientRootProvider + // Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on + // Expected Behavior: success, because server side has SNI + { + desc: "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS", + clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, + clientRootProvider: fakeProvider{isClient: true}, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + serverMutualTLS: true, + serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true}, + serverRootProvider: fakeProvider{isClient: false}, + serverVType: CertVerification, + }, } { test := test t.Run(test.desc, func(t *testing.T) { @@ -413,11 +690,15 @@ func (s) TestClientServerHandshake(t *testing.T) { } // Start a server using ServerOptions in another goroutine. serverOptions := &ServerOptions{ - Certificates: test.serverCert, - GetCertificates: test.serverGetCert, - RootCertificateOptions: RootCertificateOptions{ - RootCACerts: test.serverRoot, - GetRootCAs: test.serverGetRoot, + IdentityOptions: IdentityCertificateOptions{ + Certificates: test.serverCert, + GetIdentityCertificatesForServer: test.serverGetCert, + IdentityProvider: test.serverIdentityProvider, + }, + RootOptions: RootCertificateOptions{ + RootCACerts: test.serverRoot, + GetRootCertificates: test.serverGetRoot, + RootProvider: test.serverRootProvider, }, RequireClientCert: test.serverMutualTLS, VerifyPeer: test.serverVerifyFunc, @@ -452,23 +733,22 @@ func (s) TestClientServerHandshake(t *testing.T) { } defer conn.Close() clientOptions := &ClientOptions{ - Certificates: test.clientCert, - GetClientCertificate: test.clientGetCert, - VerifyPeer: test.clientVerifyFunc, - RootCertificateOptions: RootCertificateOptions{ - RootCACerts: test.clientRoot, - GetRootCAs: test.clientGetRoot, + IdentityOptions: IdentityCertificateOptions{ + Certificates: test.clientCert, + GetIdentityCertificatesForClient: test.clientGetCert, + IdentityProvider: test.clientIdentityProvider, + }, + VerifyPeer: test.clientVerifyFunc, + RootOptions: RootCertificateOptions{ + RootCACerts: test.clientRoot, + GetRootCertificates: test.clientGetRoot, + RootProvider: test.clientRootProvider, }, VType: test.clientVType, } - clientTLS, newClientErr := NewClientCreds(clientOptions) - if newClientErr != nil && test.clientExpectCreateError { - return - } - if newClientErr != nil && !test.clientExpectCreateError || - newClientErr == nil && test.clientExpectCreateError { - t.Fatalf("Expect error: %v, but err is %v", - test.clientExpectCreateError, newClientErr) + clientTLS, err := NewClientCreds(clientOptions) + if err != nil { + t.Fatalf("NewClientCreds failed: %v", err) } _, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) @@ -541,7 +821,7 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) { t.Fatalf("Client is unable to load trust certs. Error: %v", err) } clientOptions := &ClientOptions{ - RootCertificateOptions: RootCertificateOptions{ + RootOptions: RootCertificateOptions{ RootCACerts: clientTrustPool, }, ServerNameOverride: expectedServerName, @@ -563,7 +843,7 @@ func (s) TestTLSClone(t *testing.T) { t.Fatalf("Client is unable to load trust certs. Error: %v", err) } clientOptions := &ClientOptions{ - RootCertificateOptions: RootCertificateOptions{ + RootOptions: RootCertificateOptions{ RootCACerts: clientTrustPool, }, ServerNameOverride: expectedServerName, @@ -635,62 +915,6 @@ func (s) TestWrapSyscallConn(t *testing.T) { } } -func (s) TestOptionsConfig(t *testing.T) { - serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), - testdata.Path("server_key_1.pem")) - if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) - } - tests := []struct { - desc string - clientVType VerificationType - serverMutualTLS bool - serverCert []tls.Certificate - serverVType VerificationType - }{ - { - desc: "Client uses system-provided RootCAs; server uses system-provided ClientCAs", - clientVType: CertVerification, - serverMutualTLS: true, - serverCert: []tls.Certificate{serverPeerCert}, - serverVType: CertAndHostVerification, - }, - } - for _, test := range tests { - test := test - t.Run(test.desc, func(t *testing.T) { - serverOptions := &ServerOptions{ - Certificates: test.serverCert, - RequireClientCert: test.serverMutualTLS, - VType: test.serverVType, - } - serverConfig, err := serverOptions.config() - if err != nil { - t.Fatalf("Unable to generate serverConfig. Error: %v", err) - } - // Verify that the system-provided certificates would be used - // when no verification method was set in serverOptions. - if serverOptions.RootCACerts == nil && serverOptions.GetRootCAs == nil && - serverOptions.RequireClientCert && serverConfig.ClientCAs == nil { - t.Fatalf("Failed to assign system-provided certificates on the server side.") - } - clientOptions := &ClientOptions{ - VType: test.clientVType, - } - clientConfig, err := clientOptions.config() - if err != nil { - t.Fatalf("Unable to generate clientConfig. Error: %v", err) - } - // Verify that the system-provided certificates would be used - // when no verification method was set in clientOptions. - if clientOptions.RootCACerts == nil && clientOptions.GetRootCAs == nil && - clientConfig.RootCAs == nil { - t.Fatalf("Failed to assign system-provided certificates on the client side.") - } - }) - } -} - func (s) TestGetCertificatesSNI(t *testing.T) { // Load server certificates for setting the serverGetCert callback function. serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) @@ -734,8 +958,10 @@ func (s) TestGetCertificatesSNI(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { serverOptions := &ServerOptions{ - GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { - return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil + IdentityOptions: IdentityCertificateOptions{ + GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { + return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil + }, }, } serverConfig, err := serverOptions.config() diff --git a/security/advancedtls/pemfile_provider.go b/security/advancedtls/pemfile_provider.go index 3dbb52440..7e03bb835 100644 --- a/security/advancedtls/pemfile_provider.go +++ b/security/advancedtls/pemfile_provider.go @@ -97,7 +97,6 @@ func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) { return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified") } if o.IdentityInterval == 0 { - logger.Warningf("heyheyhey") o.IdentityInterval = defaultIdentityInterval } if o.RootInterval == 0 { diff --git a/security/advancedtls/sni.go b/security/advancedtls/sni.go index 6c8623283..7fef1990c 100644 --- a/security/advancedtls/sni.go +++ b/security/advancedtls/sni.go @@ -28,10 +28,10 @@ import ( // buildGetCertificates returns the certificate that matches the SNI field // for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates. func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - if o.GetCertificates == nil { + if o.IdentityOptions.GetIdentityCertificatesForServer == nil { return nil, fmt.Errorf("function GetCertificates must be specified") } - certificates, err := o.GetCertificates(clientHello) + certificates, err := o.IdentityOptions.GetIdentityCertificatesForServer(clientHello) if err != nil { return nil, err }