mirror of https://github.com/grpc/grpc-go.git
advancedtls: populate verified chains when using custom buildVerifyFunc (#7181)
* populate verified chains when using custom buildVerifyFunc
This commit is contained in:
parent
1db6590e40
commit
5ffe0ef48c
|
|
@ -41,6 +41,8 @@ import (
|
|||
credinternal "google.golang.org/grpc/internal/credentials"
|
||||
)
|
||||
|
||||
type CertificateChains [][]*x509.Certificate
|
||||
|
||||
// HandshakeVerificationInfo contains information about a handshake needed for
|
||||
// verification for use when implementing the `PostHandshakeVerificationFunc`
|
||||
// The fields in this struct are read-only.
|
||||
|
|
@ -53,7 +55,7 @@ type HandshakeVerificationInfo struct {
|
|||
RawCerts [][]byte
|
||||
// The verification chain obtained by checking peer RawCerts against the
|
||||
// trust certificate bundle(s), if applicable.
|
||||
VerifiedChains [][]*x509.Certificate
|
||||
VerifiedChains CertificateChains
|
||||
// The leaf certificate sent from peer, if choosing to verify the peer
|
||||
// certificate(s) and that verification passed. This field would be nil if
|
||||
// either user chose not to verify or the verification failed.
|
||||
|
|
@ -552,7 +554,8 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
|
|||
if cfg.ServerName == "" {
|
||||
cfg.ServerName = authority
|
||||
}
|
||||
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn)
|
||||
peerVerifiedChains := CertificateChains{}
|
||||
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn, &peerVerifiedChains)
|
||||
conn := tls.Client(rawConn, cfg)
|
||||
errChannel := make(chan error, 1)
|
||||
go func() {
|
||||
|
|
@ -576,12 +579,14 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
|
|||
},
|
||||
}
|
||||
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
||||
info.State.VerifiedChains = peerVerifiedChains
|
||||
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
|
||||
}
|
||||
|
||||
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
cfg := credinternal.CloneTLSConfig(c.config)
|
||||
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
|
||||
peerVerifiedChains := CertificateChains{}
|
||||
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn, &peerVerifiedChains)
|
||||
conn := tls.Server(rawConn, cfg)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
|
|
@ -594,6 +599,7 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti
|
|||
},
|
||||
}
|
||||
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
||||
info.State.VerifiedChains = peerVerifiedChains
|
||||
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
|
||||
}
|
||||
|
||||
|
|
@ -618,9 +624,15 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
|
|||
// 1. does not have a good support on root cert reloading.
|
||||
// 2. will ignore basic certificate check when setting InsecureSkipVerify
|
||||
// to true.
|
||||
//
|
||||
// peerVerifiedChains(output param): verified chain of certs from leaf to the
|
||||
// trust cert that the peer trusts.
|
||||
// 1. For server it is, client certs + Root ca that the server trusts
|
||||
// 2. For client it is, server certs + Root ca that the client trusts
|
||||
func buildVerifyFunc(c *advancedTLSCreds,
|
||||
serverName string,
|
||||
rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
rawConn net.Conn,
|
||||
peerVerifiedChains *CertificateChains) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
chains := verifiedChains
|
||||
var leafCert *x509.Certificate
|
||||
|
|
@ -684,7 +696,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
|
|||
if c.revocationOptions != nil {
|
||||
verifiedChains := chains
|
||||
if verifiedChains == nil {
|
||||
verifiedChains = [][]*x509.Certificate{rawCertList}
|
||||
verifiedChains = CertificateChains{rawCertList}
|
||||
}
|
||||
if err := checkChainRevocation(verifiedChains, *c.revocationOptions); err != nil {
|
||||
return err
|
||||
|
|
@ -698,8 +710,11 @@ func buildVerifyFunc(c *advancedTLSCreds,
|
|||
VerifiedChains: chains,
|
||||
Leaf: leafCert,
|
||||
})
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*peerVerifiedChains = chains
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
package advancedtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
|
@ -949,6 +950,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
|
|||
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
|
||||
clientAuthInfo, serverAuthInfo)
|
||||
}
|
||||
serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains
|
||||
if test.serverMutualTLS && !test.serverExpectError {
|
||||
if len(serverVerifiedChains) == 0 {
|
||||
t.Fatalf("server verified chains is empty")
|
||||
}
|
||||
var clientCert *tls.Certificate
|
||||
if len(test.clientCert) > 0 {
|
||||
clientCert = &test.clientCert[0]
|
||||
} else if test.clientGetCert != nil {
|
||||
cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
|
||||
clientCert = cert
|
||||
} else if test.clientIdentityProvider != nil {
|
||||
km, _ := test.clientIdentityProvider.KeyMaterial(context.TODO())
|
||||
clientCert = &km.Certs[0]
|
||||
}
|
||||
if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
|
||||
t.Fatal("server verifiedChains leaf cert doesn't match client cert")
|
||||
}
|
||||
|
||||
var serverRoot *x509.CertPool
|
||||
if test.serverRoot != nil {
|
||||
serverRoot = test.serverRoot
|
||||
} else if test.serverGetRoot != nil {
|
||||
result, _ := test.serverGetRoot(&GetRootCAsParams{})
|
||||
serverRoot = result.TrustCerts
|
||||
} else if test.serverRootProvider != nil {
|
||||
km, _ := test.serverRootProvider.KeyMaterial(context.TODO())
|
||||
serverRoot = km.Roots
|
||||
}
|
||||
serverVerifiedChainsCp := x509.NewCertPool()
|
||||
serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1])
|
||||
if !serverVerifiedChainsCp.Equal(serverRoot) {
|
||||
t.Fatalf("server verified chain hierarchy doesn't match")
|
||||
}
|
||||
}
|
||||
clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains
|
||||
if test.serverMutualTLS && !test.clientExpectHandshakeError {
|
||||
if len(clientVerifiedChains) == 0 {
|
||||
t.Fatalf("client verified chains is empty")
|
||||
}
|
||||
var serverCert *tls.Certificate
|
||||
if len(test.serverCert) > 0 {
|
||||
serverCert = &test.serverCert[0]
|
||||
} else if test.serverGetCert != nil {
|
||||
cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
|
||||
serverCert = cert[0]
|
||||
} else if test.serverIdentityProvider != nil {
|
||||
km, _ := test.serverIdentityProvider.KeyMaterial(context.TODO())
|
||||
serverCert = &km.Certs[0]
|
||||
}
|
||||
if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
|
||||
t.Fatal("client verifiedChains leaf cert doesn't match server cert")
|
||||
}
|
||||
|
||||
var clientRoot *x509.CertPool
|
||||
if test.clientRoot != nil {
|
||||
clientRoot = test.clientRoot
|
||||
} else if test.clientGetRoot != nil {
|
||||
result, _ := test.clientGetRoot(&GetRootCAsParams{})
|
||||
clientRoot = result.TrustCerts
|
||||
} else if test.clientRootProvider != nil {
|
||||
km, _ := test.clientRootProvider.KeyMaterial(context.TODO())
|
||||
clientRoot = km.Roots
|
||||
}
|
||||
clientVerifiedChainsCp := x509.NewCertPool()
|
||||
clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1])
|
||||
if !clientVerifiedChainsCp.Equal(clientRoot) {
|
||||
t.Fatalf("client verified chain hierarchy doesn't match")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue