advancedtls: add revocation support to client/server options (#4781)

This commit is contained in:
ZhenLian 2021-09-27 16:42:32 -07:00 committed by GitHub
parent 4555155af2
commit 710419d32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 29 deletions

View File

@ -181,6 +181,9 @@ type ClientOptions struct {
RootOptions RootCertificateOptions RootOptions RootCertificateOptions
// VType is the verification type on the client side. // VType is the verification type on the client side.
VType VerificationType VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
RevocationConfig *RevocationConfig
} }
// ServerOptions contains the fields needed to be filled by the server. // ServerOptions contains the fields needed to be filled by the server.
@ -199,6 +202,9 @@ type ServerOptions struct {
RequireClientCert bool RequireClientCert bool
// VType is the verification type on the server side. // VType is the verification type on the server side.
VType VerificationType VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
RevocationConfig *RevocationConfig
} }
func (o *ClientOptions) config() (*tls.Config, error) { func (o *ClientOptions) config() (*tls.Config, error) {
@ -356,11 +362,12 @@ func (o *ServerOptions) config() (*tls.Config, error) {
// advancedTLSCreds is the credentials required for authenticating a connection // advancedTLSCreds is the credentials required for authenticating a connection
// using TLS. // using TLS.
type advancedTLSCreds struct { type advancedTLSCreds struct {
config *tls.Config config *tls.Config
verifyFunc CustomVerificationFunc verifyFunc CustomVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool isClient bool
vType VerificationType vType VerificationType
revocationConfig *RevocationConfig
} }
func (c advancedTLSCreds) Info() credentials.ProtocolInfo { func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
@ -451,6 +458,14 @@ func buildVerifyFunc(c *advancedTLSCreds,
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
chains := verifiedChains chains := verifiedChains
var leafCert *x509.Certificate var leafCert *x509.Certificate
rawCertList := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
rawCertList[i] = cert
}
if c.vType == CertAndHostVerification || c.vType == CertVerification { if c.vType == CertAndHostVerification || c.vType == CertVerification {
// perform possible trust credential reloading and certificate check // perform possible trust credential reloading and certificate check
rootCAs := c.config.RootCAs rootCAs := c.config.RootCAs
@ -469,14 +484,6 @@ func buildVerifyFunc(c *advancedTLSCreds,
rootCAs = results.TrustCerts rootCAs = results.TrustCerts
} }
// Verify peers' certificates against RootCAs and get verifiedChains. // Verify peers' certificates against RootCAs and get verifiedChains.
certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
certs[i] = cert
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
if !c.isClient { if !c.isClient {
keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
@ -487,7 +494,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
Intermediates: x509.NewCertPool(), Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages, KeyUsages: keyUsages,
} }
for _, cert := range certs[1:] { for _, cert := range rawCertList[1:] {
opts.Intermediates.AddCert(cert) opts.Intermediates.AddCert(cert)
} }
// Perform default hostname check if specified. // Perform default hostname check if specified.
@ -501,11 +508,21 @@ func buildVerifyFunc(c *advancedTLSCreds,
opts.DNSName = parsedName opts.DNSName = parsedName
} }
var err error var err error
chains, err = certs[0].Verify(opts) chains, err = rawCertList[0].Verify(opts)
if err != nil { if err != nil {
return err return err
} }
leafCert = certs[0] leafCert = rawCertList[0]
}
// Perform certificate revocation check if specified.
if c.revocationConfig != nil {
verifiedChains := chains
if verifiedChains == nil {
verifiedChains = [][]*x509.Certificate{rawCertList}
}
if err := CheckChainRevocation(verifiedChains, *c.revocationConfig); err != nil {
return err
}
} }
// Perform custom verification check if specified. // Perform custom verification check if specified.
if c.verifyFunc != nil { if c.verifyFunc != nil {
@ -529,11 +546,12 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
return nil, err return nil, err
} }
tc := &advancedTLSCreds{ tc := &advancedTLSCreds{
config: conf, config: conf,
isClient: true, isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates, getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer, verifyFunc: o.VerifyPeer,
vType: o.VType, vType: o.VType,
revocationConfig: o.RevocationConfig,
} }
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil return tc, nil
@ -547,11 +565,12 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
return nil, err return nil, err
} }
tc := &advancedTLSCreds{ tc := &advancedTLSCreds{
config: conf, config: conf,
isClient: false, isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates, getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer, verifyFunc: o.VerifyPeer,
vType: o.VType, vType: o.VType,
revocationConfig: o.RevocationConfig,
} }
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil return tc, nil

View File

@ -380,7 +380,7 @@ func (s) TestEnd2End(t *testing.T) {
} }
clientTLSCreds, err := NewClientCreds(clientOptions) clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil { if err != nil {
t.Fatalf("clientTLSCreds failed to create") t.Fatalf("clientTLSCreds failed to create: %v", err)
} }
// ------------------------Scenario 1------------------------------------ // ------------------------Scenario 1------------------------------------
// stage = 0, initial connection should succeed // stage = 0, initial connection should succeed
@ -796,7 +796,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
} }
clientTLSCreds, err := NewClientCreds(clientOptions) clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil { if err != nil {
t.Fatalf("clientTLSCreds failed to create") t.Fatalf("clientTLSCreds failed to create: %v", err)
} }
shouldFail := false shouldFail := false
if test.expectError { if test.expectError {

View File

@ -27,10 +27,12 @@ import (
"net" "net"
"testing" "testing"
lru "github.com/hashicorp/golang-lru"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/security/advancedtls/internal/testutils" "google.golang.org/grpc/security/advancedtls/internal/testutils"
"google.golang.org/grpc/security/advancedtls/testdata"
) )
type s struct { type s struct {
@ -339,6 +341,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, fmt.Errorf("bad root certificate reloading") return nil, fmt.Errorf("bad root certificate reloading")
} }
cache, err := lru.New(5)
if err != nil {
t.Fatalf("lru.New: err = %v", err)
}
for _, test := range []struct { for _, test := range []struct {
desc string desc string
clientCert []tls.Certificate clientCert []tls.Certificate
@ -349,6 +355,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientVType VerificationType clientVType VerificationType
clientRootProvider certprovider.Provider clientRootProvider certprovider.Provider
clientIdentityProvider certprovider.Provider clientIdentityProvider certprovider.Provider
clientRevocationConfig *RevocationConfig
clientExpectHandshakeError bool clientExpectHandshakeError bool
serverMutualTLS bool serverMutualTLS bool
serverCert []tls.Certificate serverCert []tls.Certificate
@ -359,6 +366,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverVType VerificationType serverVType VerificationType
serverRootProvider certprovider.Provider serverRootProvider certprovider.Provider
serverIdentityProvider certprovider.Provider serverIdentityProvider certprovider.Provider
serverRevocationConfig *RevocationConfig
serverExpectError bool serverExpectError bool
}{ }{
// Client: nil setting except verifyFuncGood // Client: nil setting except verifyFuncGood
@ -642,6 +650,30 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverRootProvider: fakeProvider{isClient: false}, serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification, serverVType: CertVerification,
}, },
// Client: set valid credentials with the revocation config
// Server: set valid credentials with the revocation config
// Expected Behavior: success, because non of the certificate chains sent in the connection are revoked
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
clientRevocationConfig: &RevocationConfig{
RootDir: testdata.Path("crl"),
AllowUndetermined: true,
Cache: cache,
},
serverMutualTLS: true,
serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: getRootCAsForServer,
serverVType: CertVerification,
serverRevocationConfig: &RevocationConfig{
RootDir: testdata.Path("crl"),
AllowUndetermined: true,
Cache: cache,
},
},
} { } {
test := test test := test
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
@ -665,6 +697,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
RequireClientCert: test.serverMutualTLS, RequireClientCert: test.serverMutualTLS,
VerifyPeer: test.serverVerifyFunc, VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType, VType: test.serverVType,
RevocationConfig: test.serverRevocationConfig,
} }
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept() serverRawConn, err := lis.Accept()
@ -706,7 +739,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
GetRootCertificates: test.clientGetRoot, GetRootCertificates: test.clientGetRoot,
RootProvider: test.clientRootProvider, RootProvider: test.clientRootProvider,
}, },
VType: test.clientVType, VType: test.clientVType,
RevocationConfig: test.clientRevocationConfig,
} }
clientTLS, err := NewClientCreds(clientOptions) clientTLS, err := NewClientCreds(clientOptions)
if err != nil { if err != nil {