diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 534a3ed41..1892c5ed7 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -181,6 +181,9 @@ type ClientOptions struct { RootOptions RootCertificateOptions // VType is the verification type on the client side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } // ServerOptions contains the fields needed to be filled by the server. @@ -199,6 +202,9 @@ type ServerOptions struct { RequireClientCert bool // VType is the verification type on the server side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } func (o *ClientOptions) config() (*tls.Config, error) { @@ -356,11 +362,12 @@ func (o *ServerOptions) config() (*tls.Config, error) { // advancedTLSCreds is the credentials required for authenticating a connection // using TLS. type advancedTLSCreds struct { - config *tls.Config - verifyFunc CustomVerificationFunc - getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) - isClient bool - vType VerificationType + config *tls.Config + verifyFunc CustomVerificationFunc + getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) + isClient bool + vType VerificationType + revocationConfig *RevocationConfig } func (c advancedTLSCreds) Info() credentials.ProtocolInfo { @@ -451,6 +458,14 @@ func buildVerifyFunc(c *advancedTLSCreds, return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { chains := verifiedChains var leafCert *x509.Certificate + rawCertList := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return err + } + rawCertList[i] = cert + } if c.vType == CertAndHostVerification || c.vType == CertVerification { // perform possible trust credential reloading and certificate check rootCAs := c.config.RootCAs @@ -469,14 +484,6 @@ func buildVerifyFunc(c *advancedTLSCreds, rootCAs = results.TrustCerts } // Verify peers' certificates against RootCAs and get verifiedChains. - certs := make([]*x509.Certificate, len(rawCerts)) - for i, asn1Data := range rawCerts { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - return err - } - certs[i] = cert - } keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} if !c.isClient { keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} @@ -487,7 +494,7 @@ func buildVerifyFunc(c *advancedTLSCreds, Intermediates: x509.NewCertPool(), KeyUsages: keyUsages, } - for _, cert := range certs[1:] { + for _, cert := range rawCertList[1:] { opts.Intermediates.AddCert(cert) } // Perform default hostname check if specified. @@ -501,11 +508,21 @@ func buildVerifyFunc(c *advancedTLSCreds, opts.DNSName = parsedName } var err error - chains, err = certs[0].Verify(opts) + chains, err = rawCertList[0].Verify(opts) if err != nil { return err } - leafCert = certs[0] + leafCert = rawCertList[0] + } + // Perform certificate revocation check if specified. + if c.revocationConfig != nil { + verifiedChains := chains + if verifiedChains == nil { + verifiedChains = [][]*x509.Certificate{rawCertList} + } + if err := CheckChainRevocation(verifiedChains, *c.revocationConfig); err != nil { + return err + } } // Perform custom verification check if specified. if c.verifyFunc != nil { @@ -529,11 +546,12 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: true, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: true, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil @@ -547,11 +565,12 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: false, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: false, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 4bb9e645b..8cddfc234 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -380,7 +380,7 @@ func (s) TestEnd2End(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } // ------------------------Scenario 1------------------------------------ // stage = 0, initial connection should succeed @@ -796,7 +796,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } shouldFail := false if test.expectError { diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 64da81a17..7092d46e6 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -27,10 +27,12 @@ import ( "net" "testing" + lru "github.com/hashicorp/golang-lru" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/security/advancedtls/internal/testutils" + "google.golang.org/grpc/security/advancedtls/testdata" ) type s struct { @@ -339,6 +341,10 @@ func (s) TestClientServerHandshake(t *testing.T) { getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return nil, fmt.Errorf("bad root certificate reloading") } + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } for _, test := range []struct { desc string clientCert []tls.Certificate @@ -349,6 +355,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientVType VerificationType clientRootProvider certprovider.Provider clientIdentityProvider certprovider.Provider + clientRevocationConfig *RevocationConfig clientExpectHandshakeError bool serverMutualTLS bool serverCert []tls.Certificate @@ -359,6 +366,7 @@ func (s) TestClientServerHandshake(t *testing.T) { serverVType VerificationType serverRootProvider certprovider.Provider serverIdentityProvider certprovider.Provider + serverRevocationConfig *RevocationConfig serverExpectError bool }{ // Client: nil setting except verifyFuncGood @@ -642,6 +650,30 @@ func (s) TestClientServerHandshake(t *testing.T) { serverRootProvider: fakeProvider{isClient: false}, serverVType: CertVerification, }, + // Client: set valid credentials with the revocation config + // Server: set valid credentials with the revocation config + // Expected Behavior: success, because non of the certificate chains sent in the connection are revoked + { + desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", + clientCert: []tls.Certificate{cs.ClientCert1}, + clientGetRoot: getRootCAsForClient, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + clientRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + serverMutualTLS: true, + serverCert: []tls.Certificate{cs.ServerCert1}, + serverGetRoot: getRootCAsForServer, + serverVType: CertVerification, + serverRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + }, } { test := test t.Run(test.desc, func(t *testing.T) { @@ -665,6 +697,7 @@ func (s) TestClientServerHandshake(t *testing.T) { RequireClientCert: test.serverMutualTLS, VerifyPeer: test.serverVerifyFunc, VType: test.serverVType, + RevocationConfig: test.serverRevocationConfig, } go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { serverRawConn, err := lis.Accept() @@ -706,7 +739,8 @@ func (s) TestClientServerHandshake(t *testing.T) { GetRootCertificates: test.clientGetRoot, RootProvider: test.clientRootProvider, }, - VType: test.clientVType, + VType: test.clientVType, + RevocationConfig: test.clientRevocationConfig, } clientTLS, err := NewClientCreds(clientOptions) if err != nil {