mirror of https://github.com/grpc/grpc-go.git
advancedtls: add fields for root and identity providers in API (#3863)
* add provider in advancedtls API for pem file reloading
This commit is contained in:
parent
4270c3cfce
commit
0f7e218c2c
|
@ -27,10 +27,12 @@ import (
|
|||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
credinternal "google.golang.org/grpc/internal/credentials"
|
||||
)
|
||||
|
||||
|
@ -79,21 +81,67 @@ type GetRootCAsResults struct {
|
|||
TrustCerts *x509.CertPool
|
||||
}
|
||||
|
||||
// RootCertificateOptions contains a field and a function for obtaining root
|
||||
// trust certificates.
|
||||
// It is used by both ClientOptions and ServerOptions.
|
||||
// If users want to use default verification, but did not provide a valid
|
||||
// RootCertificateOptions, we use the system default trust certificates.
|
||||
// RootCertificateOptions contains options to obtain root trust certificates
|
||||
// for both the client and the server.
|
||||
// At most one option could be set. If none of them are set, we
|
||||
// use the system default trust certificates.
|
||||
type RootCertificateOptions struct {
|
||||
// If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts
|
||||
// will be used every time when verifying the peer certificates, without
|
||||
// performing root certificate reloading.
|
||||
// If RootCACerts is set, it will be used every time when verifying
|
||||
// the peer certificates, without performing root certificate reloading.
|
||||
RootCACerts *x509.CertPool
|
||||
// If GetRootCAs is set and RootCACerts is nil, GetRootCAs will be invoked
|
||||
// every time asked to check certificates sent from the server when a new
|
||||
// connection is established.
|
||||
// This is known as root CA certificate reloading.
|
||||
GetRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
// If GetRootCertificates is set, it will be invoked to obtain root certs for
|
||||
// every new connection.
|
||||
GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
// If RootProvider is set, we will use the root certs from the Provider's
|
||||
// KeyMaterial() call in the new connections. The Provider must have initial
|
||||
// credentials if specified. Otherwise, KeyMaterial() will block forever.
|
||||
RootProvider certprovider.Provider
|
||||
}
|
||||
|
||||
// nonNilFieldCount returns the number of set fields in RootCertificateOptions.
|
||||
func (o RootCertificateOptions) nonNilFieldCount() int {
|
||||
cnt := 0
|
||||
rv := reflect.ValueOf(o)
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
if !rv.Field(i).IsNil() {
|
||||
cnt++
|
||||
}
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
|
||||
// IdentityCertificateOptions contains options to obtain identity certificates
|
||||
// for both the client and the server.
|
||||
// At most one option could be set.
|
||||
type IdentityCertificateOptions struct {
|
||||
// If Certificates is set, it will be used every time when needed to present
|
||||
//identity certificates, without performing identity certificate reloading.
|
||||
Certificates []tls.Certificate
|
||||
// If GetIdentityCertificatesForClient is set, it will be invoked to obtain
|
||||
// identity certs for every new connection.
|
||||
// This field MUST be set on client side.
|
||||
GetIdentityCertificatesForClient func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
|
||||
// If GetIdentityCertificatesForServer is set, it will be invoked to obtain
|
||||
// identity certs for every new connection.
|
||||
// This field MUST be set on server side.
|
||||
GetIdentityCertificatesForServer func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
|
||||
// If IdentityProvider is set, we will use the identity certs from the
|
||||
// Provider's KeyMaterial() call in the new connections. The Provider must
|
||||
// have initial credentials if specified. Otherwise, KeyMaterial() will block
|
||||
// forever.
|
||||
IdentityProvider certprovider.Provider
|
||||
}
|
||||
|
||||
// nonNilFieldCount returns the number of set fields in IdentityCertificateOptions.
|
||||
func (o IdentityCertificateOptions) nonNilFieldCount() int {
|
||||
cnt := 0
|
||||
rv := reflect.ValueOf(o)
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
if !rv.Field(i).IsNil() {
|
||||
cnt++
|
||||
}
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
|
||||
// VerificationType is the enum type that represents different levels of
|
||||
|
@ -115,27 +163,11 @@ const (
|
|||
SkipVerification
|
||||
)
|
||||
|
||||
// ClientOptions contains all the fields and functions needed to be filled by
|
||||
// the client.
|
||||
// General rules for certificate setting on client side:
|
||||
// Certificates or GetClientCertificate indicates the certificates sent from
|
||||
// the client to the server to prove client's identities. The rules for setting
|
||||
// these two fields are:
|
||||
// If requiring mutual authentication on server side:
|
||||
// Either Certificates or GetClientCertificate must be set; the other will
|
||||
// be ignored.
|
||||
// Otherwise:
|
||||
// Nothing needed(the two fields will be ignored).
|
||||
// ClientOptions contains the fields needed to be filled by the client.
|
||||
type ClientOptions struct {
|
||||
// If field Certificates is set, field GetClientCertificate will be ignored.
|
||||
// The client will use Certificates every time when asked for a certificate,
|
||||
// without performing certificate reloading.
|
||||
Certificates []tls.Certificate
|
||||
// If GetClientCertificate is set and Certificates is nil, the client will
|
||||
// invoke this function every time asked to present certificates to the
|
||||
// server when a new connection is established. This is known as peer
|
||||
// certificate reloading.
|
||||
GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
|
||||
// IdentityOptions is OPTIONAL on client side. This field only needs to be
|
||||
// set if mutual authentication is required on server side.
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
// VerifyPeer is a custom verification check after certificate signature
|
||||
// check.
|
||||
// If this is set, we will perform this customized check after doing the
|
||||
|
@ -145,37 +177,25 @@ type ClientOptions struct {
|
|||
// it will override the virtual host name of authority (e.g. :authority
|
||||
// header field) in requests.
|
||||
ServerNameOverride string
|
||||
// RootCertificateOptions is REQUIRED to be correctly set on client side.
|
||||
RootCertificateOptions
|
||||
// RootOptions is OPTIONAL on client side. If not set, we will try to use the
|
||||
// default trust certificates in users' OS system.
|
||||
RootOptions RootCertificateOptions
|
||||
// VType is the verification type on the client side.
|
||||
VType VerificationType
|
||||
}
|
||||
|
||||
// ServerOptions contains all the fields and functions needed to be filled by
|
||||
// the client.
|
||||
// General rules for certificate setting on server side:
|
||||
// Certificates or GetClientCertificate indicates the certificates sent from
|
||||
// the server to the client to prove server's identities. The rules for setting
|
||||
// these two fields are:
|
||||
// Either Certificates or GetCertificates must be set; the other will be ignored.
|
||||
// ServerOptions contains the fields needed to be filled by the server.
|
||||
type ServerOptions struct {
|
||||
// If field Certificates is set, field GetClientCertificate will be ignored.
|
||||
// The server will use Certificates every time when asked for a certificate,
|
||||
// without performing certificate reloading.
|
||||
Certificates []tls.Certificate
|
||||
// If GetClientCertificate is set and Certificates is nil, the server will
|
||||
// invoke this function every time asked to present certificates to the
|
||||
// client when a new connection is established. This is known as peer
|
||||
// certificate reloading.
|
||||
GetCertificates func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
|
||||
// IdentityOptions is REQUIRED on server side.
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
// VerifyPeer is a custom verification check after certificate signature
|
||||
// check.
|
||||
// If this is set, we will perform this customized check after doing the
|
||||
// normal check(s) indicated by setting VType.
|
||||
VerifyPeer CustomVerificationFunc
|
||||
// RootCertificateOptions is only required when mutual TLS is
|
||||
// enabled(RequireClientCert is true).
|
||||
RootCertificateOptions
|
||||
// RootOptions is OPTIONAL on server side. This field only needs to be set if
|
||||
// mutual authentication is required(RequireClientCert is true).
|
||||
RootOptions RootCertificateOptions
|
||||
// If the server want the client to send certificates.
|
||||
RequireClientCert bool
|
||||
// VType is the verification type on the server side.
|
||||
|
@ -184,48 +204,89 @@ type ServerOptions struct {
|
|||
|
||||
func (o *ClientOptions) config() (*tls.Config, error) {
|
||||
if o.VType == SkipVerification && o.VerifyPeer == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"client needs to provide custom verification mechanism if choose to skip default verification")
|
||||
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
|
||||
}
|
||||
rootCAs := o.RootCACerts
|
||||
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil {
|
||||
// Set rootCAs to system default.
|
||||
systemRootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rootCAs = systemRootCAs
|
||||
// Make sure users didn't specify more than one fields in
|
||||
// RootCertificateOptions and IdentityCertificateOptions.
|
||||
if num := o.RootOptions.nonNilFieldCount(); num > 1 {
|
||||
return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
|
||||
}
|
||||
if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
|
||||
return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
|
||||
}
|
||||
if o.IdentityOptions.GetIdentityCertificatesForServer != nil {
|
||||
return nil, fmt.Errorf("GetIdentityCertificatesForServer cannot be specified on the client side")
|
||||
}
|
||||
// We have to set InsecureSkipVerify to true to skip the default checks and
|
||||
// use the verification function we built from buildVerifyFunc.
|
||||
config := &tls.Config{
|
||||
ServerName: o.ServerNameOverride,
|
||||
Certificates: o.Certificates,
|
||||
GetClientCertificate: o.GetClientCertificate,
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: o.ServerNameOverride,
|
||||
// We have to set InsecureSkipVerify to true to skip the default checks and
|
||||
// use the verification function we built from buildVerifyFunc.
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
if rootCAs != nil {
|
||||
config.RootCAs = rootCAs
|
||||
// Propagate root-certificate-related fields in tls.Config.
|
||||
switch {
|
||||
case o.RootOptions.RootCACerts != nil:
|
||||
config.RootCAs = o.RootOptions.RootCACerts
|
||||
case o.RootOptions.GetRootCertificates != nil:
|
||||
// In cases when users provide GetRootCertificates callback, since this
|
||||
// callback is not contained in tls.Config, we have nothing to set here.
|
||||
// We will invoke the callback in ClientHandshake.
|
||||
case o.RootOptions.RootProvider != nil:
|
||||
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
|
||||
}
|
||||
default:
|
||||
// No root certificate options specified by user. Use the certificates
|
||||
// stored in system default path as the last resort.
|
||||
if o.VType != SkipVerification {
|
||||
systemRootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.RootCAs = systemRootCAs
|
||||
}
|
||||
}
|
||||
// Propagate identity-certificate-related fields in tls.Config.
|
||||
switch {
|
||||
case o.IdentityOptions.Certificates != nil:
|
||||
config.Certificates = o.IdentityOptions.Certificates
|
||||
case o.IdentityOptions.GetIdentityCertificatesForClient != nil:
|
||||
config.GetClientCertificate = o.IdentityOptions.GetIdentityCertificatesForClient
|
||||
case o.IdentityOptions.IdentityProvider != nil:
|
||||
config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(km.Certs) != 1 {
|
||||
return nil, fmt.Errorf("there should always be only one identity cert chain on the client side in IdentityProvider")
|
||||
}
|
||||
return &km.Certs[0], nil
|
||||
}
|
||||
default:
|
||||
// It's fine for users to not specify identity certificate options here.
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (o *ServerOptions) config() (*tls.Config, error) {
|
||||
if o.Certificates == nil && o.GetCertificates == nil {
|
||||
return nil, fmt.Errorf("either Certificates or GetCertificates must be specified")
|
||||
}
|
||||
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
|
||||
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
|
||||
}
|
||||
clientCAs := o.RootCACerts
|
||||
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil && o.RequireClientCert {
|
||||
// Set clientCAs to system default.
|
||||
systemRootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientCAs = systemRootCAs
|
||||
// Make sure users didn't specify more than one fields in
|
||||
// RootCertificateOptions and IdentityCertificateOptions.
|
||||
if num := o.RootOptions.nonNilFieldCount(); num > 1 {
|
||||
return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
|
||||
}
|
||||
if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
|
||||
return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
|
||||
}
|
||||
if o.IdentityOptions.GetIdentityCertificatesForClient != nil {
|
||||
return nil, fmt.Errorf("GetIdentityCertificatesForClient cannot be specified on the server side")
|
||||
}
|
||||
clientAuth := tls.NoClientCert
|
||||
if o.RequireClientCert {
|
||||
|
@ -235,18 +296,60 @@ func (o *ServerOptions) config() (*tls.Config, error) {
|
|||
clientAuth = tls.RequireAnyClientCert
|
||||
}
|
||||
config := &tls.Config{
|
||||
ClientAuth: clientAuth,
|
||||
Certificates: o.Certificates,
|
||||
ClientAuth: clientAuth,
|
||||
}
|
||||
if o.GetCertificates != nil {
|
||||
// GetCertificate is only able to perform SNI logic for go1.10 and above.
|
||||
// It will return the first certificate in o.GetCertificates for go1.9.
|
||||
// Propagate root-certificate-related fields in tls.Config.
|
||||
switch {
|
||||
case o.RootOptions.RootCACerts != nil:
|
||||
config.ClientCAs = o.RootOptions.RootCACerts
|
||||
case o.RootOptions.GetRootCertificates != nil:
|
||||
// In cases when users provide GetRootCertificates callback, since this
|
||||
// callback is not contained in tls.Config, we have nothing to set here.
|
||||
// We will invoke the callback in ServerHandshake.
|
||||
case o.RootOptions.RootProvider != nil:
|
||||
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
|
||||
}
|
||||
default:
|
||||
// No root certificate options specified by user. Use the certificates
|
||||
// stored in system default path as the last resort.
|
||||
if o.VType != SkipVerification && o.RequireClientCert {
|
||||
systemRootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ClientCAs = systemRootCAs
|
||||
}
|
||||
}
|
||||
// Propagate identity-certificate-related fields in tls.Config.
|
||||
switch {
|
||||
case o.IdentityOptions.Certificates != nil:
|
||||
config.Certificates = o.IdentityOptions.Certificates
|
||||
case o.IdentityOptions.GetIdentityCertificatesForServer != nil:
|
||||
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return buildGetCertificates(clientHello, o)
|
||||
}
|
||||
}
|
||||
if clientCAs != nil {
|
||||
config.ClientCAs = clientCAs
|
||||
case o.IdentityOptions.IdentityProvider != nil:
|
||||
o.IdentityOptions.GetIdentityCertificatesForServer = func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var certChains []*tls.Certificate
|
||||
for i := 0; i < len(km.Certs); i++ {
|
||||
certChains = append(certChains, &km.Certs[i])
|
||||
}
|
||||
return certChains, nil
|
||||
}
|
||||
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return buildGetCertificates(clientHello, o)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("needs to specify at least one field in IdentityCertificateOptions")
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
@ -423,7 +526,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
|
|||
tc := &advancedTLSCreds{
|
||||
config: conf,
|
||||
isClient: true,
|
||||
getRootCAs: o.GetRootCAs,
|
||||
getRootCAs: o.RootOptions.GetRootCertificates,
|
||||
verifyFunc: o.VerifyPeer,
|
||||
vType: o.VType,
|
||||
}
|
||||
|
@ -441,7 +544,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
|
|||
tc := &advancedTLSCreds{
|
||||
config: conf,
|
||||
isClient: false,
|
||||
getRootCAs: o.GetRootCAs,
|
||||
getRootCAs: o.RootOptions.GetRootCertificates,
|
||||
verifyFunc: o.VerifyPeer,
|
||||
vType: o.VType,
|
||||
}
|
||||
|
|
|
@ -385,11 +385,13 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
// Start a server using ServerOptions in another goroutine.
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetCertificates: test.serverGetCert,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCAs: test.serverGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetIdentityCertificatesForServer: test.serverGetCert,
|
||||
},
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCertificates: test.serverGetRoot,
|
||||
},
|
||||
RequireClientCert: true,
|
||||
VerifyPeer: test.serverVerifyFunc,
|
||||
|
@ -409,12 +411,14 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
pb.RegisterGreeterService(s, &pb.GreeterService{SayHello: sayHello})
|
||||
go s.Serve(lis)
|
||||
clientOptions := &ClientOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetClientCertificate: test.clientGetCert,
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCAs: test.clientGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetIdentityCertificatesForClient: test.clientGetCert,
|
||||
},
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCertificates: test.clientGetRoot,
|
||||
},
|
||||
VType: test.clientVType,
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
)
|
||||
|
@ -46,14 +47,274 @@ func Test(t *testing.T) {
|
|||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
func (s) TestClientServerHandshake(t *testing.T) {
|
||||
// ------------------Load Client Trust Cert and Peer Cert-------------------
|
||||
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
|
||||
type provType int
|
||||
|
||||
const (
|
||||
provTypeRoot provType = iota
|
||||
provTypeIdentity
|
||||
)
|
||||
|
||||
type fakeProvider struct {
|
||||
pt provType
|
||||
isClient bool
|
||||
wantMultiCert bool
|
||||
wantError bool
|
||||
}
|
||||
|
||||
func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
|
||||
if f.wantError {
|
||||
return nil, fmt.Errorf("bad fakeProvider")
|
||||
}
|
||||
cs := &certStore{}
|
||||
err := cs.loadCerts()
|
||||
if err != nil {
|
||||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
return nil, fmt.Errorf("failed to load certs: %v", err)
|
||||
}
|
||||
if f.pt == provTypeRoot && f.isClient {
|
||||
return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil
|
||||
}
|
||||
if f.pt == provTypeRoot && !f.isClient {
|
||||
return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil
|
||||
}
|
||||
if f.pt == provTypeIdentity && f.isClient {
|
||||
if f.wantMultiCert {
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil
|
||||
}
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil
|
||||
}
|
||||
if f.wantMultiCert {
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil
|
||||
}
|
||||
return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil
|
||||
}
|
||||
|
||||
func (f fakeProvider) Close() {}
|
||||
|
||||
func (s) TestClientOptionsConfigErrorCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Skip default verification and provide no root credentials",
|
||||
clientVType: SkipVerification,
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in RootCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: x509.NewCertPool(),
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in IdentityCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Specify GetIdentityCertificatesForServer",
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
_, err := clientOptions.config()
|
||||
if err == nil {
|
||||
t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", clientOptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientOptionsConfigSuccessCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Use system default if no fields in RootCertificateOptions is specified",
|
||||
clientVType: CertVerification,
|
||||
},
|
||||
{
|
||||
desc: "Good case with mutual TLS",
|
||||
clientVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
clientConfig, err := clientOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", clientOptions, err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in clientOptions.
|
||||
if clientOptions.RootOptions.RootCACerts == nil &&
|
||||
clientOptions.RootOptions.GetRootCertificates == nil && clientOptions.RootOptions.RootProvider == nil {
|
||||
if clientConfig.RootCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the client side.")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
requireClientCert bool
|
||||
serverVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Skip default verification and provide no root credentials",
|
||||
requireClientCert: true,
|
||||
serverVType: SkipVerification,
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in RootCertificateOptions is specified",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: x509.NewCertPool(),
|
||||
GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than one fields in IdentityCertificateOptions is specified",
|
||||
serverVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: []tls.Certificate{},
|
||||
IdentityProvider: fakeProvider{pt: provTypeIdentity},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no field in IdentityCertificateOptions is specified",
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
{
|
||||
desc: "Specify GetIdentityCertificatesForClient",
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
VType: test.serverVType,
|
||||
RequireClientCert: test.requireClientCert,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
_, err := serverOptions.config()
|
||||
if err == nil {
|
||||
t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
requireClientCert bool
|
||||
serverVType VerificationType
|
||||
IdentityOptions IdentityCertificateOptions
|
||||
RootOptions RootCertificateOptions
|
||||
}{
|
||||
{
|
||||
desc: "Use system default if no fields in RootCertificateOptions is specified",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: []tls.Certificate{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Good case with mutual TLS",
|
||||
requireClientCert: true,
|
||||
serverVType: CertVerification,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: fakeProvider{},
|
||||
},
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
VType: test.serverVType,
|
||||
RequireClientCert: test.requireClientCert,
|
||||
IdentityOptions: test.IdentityOptions,
|
||||
RootOptions: test.RootOptions,
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in serverOptions.
|
||||
if serverOptions.RootOptions.RootCACerts == nil &&
|
||||
serverOptions.RootOptions.GetRootCertificates == nil && serverOptions.RootOptions.RootProvider == nil {
|
||||
if serverConfig.ClientCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the server side.")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientServerHandshake(t *testing.T) {
|
||||
cs := &certStore{}
|
||||
err := cs.loadCerts()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load certs: %v", err)
|
||||
}
|
||||
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil
|
||||
return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
|
||||
}
|
||||
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
if params.ServerName == "" {
|
||||
|
@ -69,18 +330,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return nil, fmt.Errorf("custom verification function failed")
|
||||
}
|
||||
clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
|
||||
testdata.Path("client_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
// ------------------Load Server Trust Cert and Peer Cert-------------------
|
||||
serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
|
||||
return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
|
||||
}
|
||||
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
if params.ServerName != "" {
|
||||
|
@ -93,11 +344,6 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
|
||||
return &VerificationResults{}, nil
|
||||
}
|
||||
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
||||
testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
return nil, fmt.Errorf("bad root certificate reloading")
|
||||
}
|
||||
|
@ -109,7 +355,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
clientVerifyFunc CustomVerificationFunc
|
||||
clientVType VerificationType
|
||||
clientExpectCreateError bool
|
||||
clientRootProvider certprovider.Provider
|
||||
clientIdentityProvider certprovider.Provider
|
||||
clientExpectHandshakeError bool
|
||||
serverMutualTLS bool
|
||||
serverCert []tls.Certificate
|
||||
|
@ -118,23 +365,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
|
||||
serverVerifyFunc CustomVerificationFunc
|
||||
serverVType VerificationType
|
||||
serverRootProvider certprovider.Provider
|
||||
serverIdentityProvider certprovider.Provider
|
||||
serverExpectError bool
|
||||
}{
|
||||
// Client: nil setting
|
||||
// Server: only set serverCert with mutual TLS off
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: if clientRoot, clientGetRoot and verifyFunc is not set, client
|
||||
// side doesn't provide any verification mechanism. We don't allow this
|
||||
// even setting vType to SkipVerification. Clients should at least provide
|
||||
// their own verification logic.
|
||||
{
|
||||
desc: "Client has no trust cert; server sends peer cert",
|
||||
clientVType: SkipVerification,
|
||||
clientExpectCreateError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: nil setting except verifyFuncGood
|
||||
// Server: only set serverCert with mutual TLS off
|
||||
// Expected Behavior: success
|
||||
|
@ -144,7 +378,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: SkipVerification,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
// Client: only set clientRoot
|
||||
|
@ -155,10 +389,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
// this test suites.
|
||||
{
|
||||
desc: "Client has root cert; server sends peer cert",
|
||||
clientRoot: clientTrustPool,
|
||||
clientRoot: cs.clientTrust1,
|
||||
clientVType: CertAndHostVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
|
@ -173,7 +407,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientGetRoot: getRootCAsForClient,
|
||||
clientVType: CertAndHostVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
|
@ -185,7 +419,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
// Client: set clientGetRoot and bad clientVerifyFunc function
|
||||
|
@ -198,66 +432,35 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientVerifyFunc: verifyFuncBad,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot and clientVerifyFunc
|
||||
// Server: nil setting
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: server side must either set serverCert or serverGetCert
|
||||
{
|
||||
desc: "Client sets reload root function with verifyFuncGood; server sets nil",
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverRoot and serverCert with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverRoot: serverTrustPool,
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverRoot: cs.serverTrust1,
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverCert, but not setting any of serverRoot, serverGetRoot
|
||||
// or serverVerifyFunc, with mutual TLS on
|
||||
// Expected Behavior: server side failure
|
||||
// Reason: server side needs to provide any verification mechanism when
|
||||
// mTLS in on, even setting vType to SkipVerification. Servers should at
|
||||
// least provide their own verification logic.
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: SkipVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientGetRoot, clientVerifyFunc and clientCert
|
||||
// Server: set serverGetRoot and serverCert with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
|
@ -268,12 +471,12 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
// Reason: server side reloading returns failure
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServerBad,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
|
@ -284,14 +487,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
|
@ -305,14 +508,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
{
|
||||
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &serverPeerCert, nil
|
||||
return &cs.serverPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
|
@ -326,7 +529,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
{
|
||||
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForServer,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
|
@ -334,7 +537,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
|
@ -349,14 +552,14 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&clientPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.clientPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
|
@ -370,7 +573,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
{
|
||||
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
|
||||
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientPeerCert, nil
|
||||
return &cs.clientPeer1, nil
|
||||
},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
|
@ -378,7 +581,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverPeerCert}, nil
|
||||
return []*tls.Certificate{&cs.serverPeer1}, nil
|
||||
},
|
||||
serverGetRoot: getRootCAsForClient,
|
||||
serverVerifyFunc: serverVerifyFunc,
|
||||
|
@ -391,18 +594,92 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
// server custom check fails
|
||||
{
|
||||
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
|
||||
clientCert: []tls.Certificate{clientPeerCert},
|
||||
clientCert: []tls.Certificate{cs.clientPeer1},
|
||||
clientGetRoot: getRootCAsForClient,
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
clientExpectHandshakeError: true,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverCert: []tls.Certificate{cs.serverPeer1},
|
||||
serverGetRoot: getRootCAsForServer,
|
||||
serverVerifyFunc: verifyFuncBad,
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set a clientIdentityProvider which will get multiple cert chains
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to multiple cert chains in
|
||||
// clientIdentityProvider
|
||||
{
|
||||
desc: "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set a bad clientIdentityProvider
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to bad clientIdentityProvider
|
||||
{
|
||||
desc: "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set bad serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: server side failure due to bad serverRootProvider
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets bad root provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false, wantError: true},
|
||||
serverVType: CertVerification,
|
||||
serverExpectError: true,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: success
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets root and identity provider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
// Client: set clientIdentityProvider and clientRootProvider
|
||||
// Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on
|
||||
// Expected Behavior: success, because server side has SNI
|
||||
{
|
||||
desc: "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS",
|
||||
clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
|
||||
clientRootProvider: fakeProvider{isClient: true},
|
||||
clientVerifyFunc: clientVerifyFuncGood,
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true},
|
||||
serverRootProvider: fakeProvider{isClient: false},
|
||||
serverVType: CertVerification,
|
||||
},
|
||||
} {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
@ -413,11 +690,15 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
}
|
||||
// Start a server using ServerOptions in another goroutine.
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetCertificates: test.serverGetCert,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCAs: test.serverGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.serverCert,
|
||||
GetIdentityCertificatesForServer: test.serverGetCert,
|
||||
IdentityProvider: test.serverIdentityProvider,
|
||||
},
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.serverRoot,
|
||||
GetRootCertificates: test.serverGetRoot,
|
||||
RootProvider: test.serverRootProvider,
|
||||
},
|
||||
RequireClientCert: test.serverMutualTLS,
|
||||
VerifyPeer: test.serverVerifyFunc,
|
||||
|
@ -452,23 +733,22 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
}
|
||||
defer conn.Close()
|
||||
clientOptions := &ClientOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetClientCertificate: test.clientGetCert,
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCAs: test.clientGetRoot,
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
Certificates: test.clientCert,
|
||||
GetIdentityCertificatesForClient: test.clientGetCert,
|
||||
IdentityProvider: test.clientIdentityProvider,
|
||||
},
|
||||
VerifyPeer: test.clientVerifyFunc,
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: test.clientRoot,
|
||||
GetRootCertificates: test.clientGetRoot,
|
||||
RootProvider: test.clientRootProvider,
|
||||
},
|
||||
VType: test.clientVType,
|
||||
}
|
||||
clientTLS, newClientErr := NewClientCreds(clientOptions)
|
||||
if newClientErr != nil && test.clientExpectCreateError {
|
||||
return
|
||||
}
|
||||
if newClientErr != nil && !test.clientExpectCreateError ||
|
||||
newClientErr == nil && test.clientExpectCreateError {
|
||||
t.Fatalf("Expect error: %v, but err is %v",
|
||||
test.clientExpectCreateError, newClientErr)
|
||||
clientTLS, err := NewClientCreds(clientOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientCreds failed: %v", err)
|
||||
}
|
||||
_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
|
||||
lisAddr, conn)
|
||||
|
@ -541,7 +821,7 @@ func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
|
|||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: clientTrustPool,
|
||||
},
|
||||
ServerNameOverride: expectedServerName,
|
||||
|
@ -563,7 +843,7 @@ func (s) TestTLSClone(t *testing.T) {
|
|||
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
RootCertificateOptions: RootCertificateOptions{
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootCACerts: clientTrustPool,
|
||||
},
|
||||
ServerNameOverride: expectedServerName,
|
||||
|
@ -635,62 +915,6 @@ func (s) TestWrapSyscallConn(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s) TestOptionsConfig(t *testing.T) {
|
||||
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
|
||||
testdata.Path("server_key_1.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
|
||||
}
|
||||
tests := []struct {
|
||||
desc string
|
||||
clientVType VerificationType
|
||||
serverMutualTLS bool
|
||||
serverCert []tls.Certificate
|
||||
serverVType VerificationType
|
||||
}{
|
||||
{
|
||||
desc: "Client uses system-provided RootCAs; server uses system-provided ClientCAs",
|
||||
clientVType: CertVerification,
|
||||
serverMutualTLS: true,
|
||||
serverCert: []tls.Certificate{serverPeerCert},
|
||||
serverVType: CertAndHostVerification,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
Certificates: test.serverCert,
|
||||
RequireClientCert: test.serverMutualTLS,
|
||||
VType: test.serverVType,
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate serverConfig. Error: %v", err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in serverOptions.
|
||||
if serverOptions.RootCACerts == nil && serverOptions.GetRootCAs == nil &&
|
||||
serverOptions.RequireClientCert && serverConfig.ClientCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the server side.")
|
||||
}
|
||||
clientOptions := &ClientOptions{
|
||||
VType: test.clientVType,
|
||||
}
|
||||
clientConfig, err := clientOptions.config()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate clientConfig. Error: %v", err)
|
||||
}
|
||||
// Verify that the system-provided certificates would be used
|
||||
// when no verification method was set in clientOptions.
|
||||
if clientOptions.RootCACerts == nil && clientOptions.GetRootCAs == nil &&
|
||||
clientConfig.RootCAs == nil {
|
||||
t.Fatalf("Failed to assign system-provided certificates on the client side.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestGetCertificatesSNI(t *testing.T) {
|
||||
// Load server certificates for setting the serverGetCert callback function.
|
||||
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
|
||||
|
@ -734,8 +958,10 @@ func (s) TestGetCertificatesSNI(t *testing.T) {
|
|||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
serverOptions := &ServerOptions{
|
||||
GetCertificates: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
|
||||
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
serverConfig, err := serverOptions.config()
|
||||
|
|
|
@ -97,7 +97,6 @@ func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) {
|
|||
return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified")
|
||||
}
|
||||
if o.IdentityInterval == 0 {
|
||||
logger.Warningf("heyheyhey")
|
||||
o.IdentityInterval = defaultIdentityInterval
|
||||
}
|
||||
if o.RootInterval == 0 {
|
||||
|
|
|
@ -28,10 +28,10 @@ import (
|
|||
// buildGetCertificates returns the certificate that matches the SNI field
|
||||
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
|
||||
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
|
||||
if o.GetCertificates == nil {
|
||||
if o.IdentityOptions.GetIdentityCertificatesForServer == nil {
|
||||
return nil, fmt.Errorf("function GetCertificates must be specified")
|
||||
}
|
||||
certificates, err := o.GetCertificates(clientHello)
|
||||
certificates, err := o.IdentityOptions.GetIdentityCertificatesForServer(clientHello)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue