advancedtls: Rename custom verification function APIs (#7140)

* Rename custom verification function APIs
This commit is contained in:
Gregory Cooke 2024-04-23 14:20:28 -04:00 committed by GitHub
parent 34de5cf483
commit d75b5e2f5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 103 additions and 60 deletions

View File

@ -35,10 +35,10 @@ import (
credinternal "google.golang.org/grpc/internal/credentials"
)
// VerificationFuncParams contains parameters available to users when
// implementing CustomVerificationFunc.
// HandshakeVerificationInfo contains information about a handshake needed for
// verification for use when implementing the `PostHandshakeVerificationFunc`
// The fields in this struct are read-only.
type VerificationFuncParams struct {
type HandshakeVerificationInfo struct {
// The target server name that the client connects to when establishing the
// connection. This field is only meaningful for client side. On server side,
// this field would be an empty string.
@ -54,17 +54,36 @@ type VerificationFuncParams struct {
Leaf *x509.Certificate
}
// VerificationResults contains the information about results of
// CustomVerificationFunc.
// VerificationResults is an empty struct for now. It may be extended in the
// VerificationFuncParams contains parameters available to users when
// implementing CustomVerificationFunc.
// The fields in this struct are read-only.
//
// Deprecated: use HandshakeVerificationInfo instead.
type VerificationFuncParams = HandshakeVerificationInfo
// PostHandshakeVerificationResults contains the information about results of
// PostHandshakeVerificationFunc.
// PostHandshakeVerificationResults is an empty struct for now. It may be extended in the
// future to include more information.
type VerificationResults struct{}
type PostHandshakeVerificationResults struct{}
// Deprecated: use PostHandshakeVerificationResults instead.
type VerificationResults = PostHandshakeVerificationResults
// PostHandshakeVerificationFunc is the function defined by users to perform
// custom verification checks after chain building and regular handshake
// verification has been completed.
// PostHandshakeVerificationFunc should return (nil, error) if the authorization
// should fail, with the error containing information on why it failed.
type PostHandshakeVerificationFunc func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error)
// CustomVerificationFunc is the function defined by users to perform custom
// verification check.
// CustomVerificationFunc returns nil if the authorization fails; otherwise
// returns an empty struct.
type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error)
//
// Deprecated: use PostHandshakeVerificationFunc instead.
type CustomVerificationFunc = PostHandshakeVerificationFunc
// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
@ -167,11 +186,18 @@ type ClientOptions struct {
// IdentityOptions is OPTIONAL on client side. This field only needs to be
// set if mutual authentication is required on server side.
IdentityOptions IdentityCertificateOptions
// AdditionalPeerVerification is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VerificationType.
AdditionalPeerVerification PostHandshakeVerificationFunc
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VType.
VerifyPeer CustomVerificationFunc
// normal check(s) indicated by setting VerificationType.
//
// Deprecated: use AdditionalPeerVerification instead.
VerifyPeer PostHandshakeVerificationFunc
// RootOptions is OPTIONAL on client side. If not set, we will try to use the
// default trust certificates in users' OS system.
RootOptions RootCertificateOptions
@ -206,11 +232,18 @@ type ClientOptions struct {
type ServerOptions struct {
// IdentityOptions is REQUIRED on server side.
IdentityOptions IdentityCertificateOptions
// AdditionalPeerVerification is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VerificationType.
AdditionalPeerVerification PostHandshakeVerificationFunc
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
// normal check(s) indicated by setting VType.
VerifyPeer CustomVerificationFunc
// normal check(s) indicated by setting VerificationType.
//
// Deprecated: use AdditionalPeerVerification instead.
VerifyPeer PostHandshakeVerificationFunc
// RootOptions is OPTIONAL on server side. This field only needs to be set if
// mutual authentication is required(RequireClientCert is true).
RootOptions RootCertificateOptions
@ -239,13 +272,18 @@ type ServerOptions struct {
}
func (o *ClientOptions) config() (*tls.Config, error) {
// TODO(gtcooke94) Remove this block when o.VerifyPeer is remoed.
// VerifyPeer is deprecated, but do this to aid the transitory migration time.
if o.AdditionalPeerVerification == nil {
o.AdditionalPeerVerification = o.VerifyPeer
}
// TODO(gtcooke94). VType is deprecated, eventually remove this block. This
// will ensure that users still explicitly setting `VType` will get the
// setting to the right place.
if o.VType != CertAndHostVerification {
o.VerificationType = o.VType
}
if o.VerificationType == SkipVerification && o.VerifyPeer == nil {
if o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
}
// Make sure users didn't specify more than one fields in
@ -321,13 +359,18 @@ func (o *ClientOptions) config() (*tls.Config, error) {
}
func (o *ServerOptions) config() (*tls.Config, error) {
// TODO(gtcooke94) Remove this block when o.VerifyPeer is remoed.
// VerifyPeer is deprecated, but do this to aid the transitory migration time.
if o.AdditionalPeerVerification == nil {
o.AdditionalPeerVerification = o.VerifyPeer
}
// TODO(gtcooke94). VType is deprecated, eventually remove this block. This
// will ensure that users still explicitly setting `VType` will get the
// setting to the right place.
if o.VType != CertAndHostVerification {
o.VerificationType = o.VType
}
if o.RequireClientCert && o.VerificationType == SkipVerification && o.VerifyPeer == nil {
if o.RequireClientCert && o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
// Make sure users didn't specify more than one fields in
@ -416,7 +459,7 @@ func (o *ServerOptions) config() (*tls.Config, error) {
// using TLS.
type advancedTLSCreds struct {
config *tls.Config
verifyFunc CustomVerificationFunc
verifyFunc PostHandshakeVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
verificationType VerificationType
@ -579,7 +622,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
}
// Perform custom verification check if specified.
if c.verifyFunc != nil {
_, err := c.verifyFunc(&VerificationFuncParams{
_, err := c.verifyFunc(&HandshakeVerificationInfo{
ServerName: serverName,
RawCerts: rawCerts,
VerifiedChains: chains,
@ -602,7 +645,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
config: conf,
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
verifyFunc: o.AdditionalPeerVerification,
verificationType: o.VerificationType,
revocationConfig: o.RevocationConfig,
}
@ -621,7 +664,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
config: conf,
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
verifyFunc: o.AdditionalPeerVerification,
verificationType: o.VerificationType,
revocationConfig: o.RevocationConfig,
}

View File

@ -143,13 +143,13 @@ func (s) TestEnd2End(t *testing.T) {
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientVerifyFunc PostHandshakeVerificationFunc
clientVerificationType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVerifyFunc PostHandshakeVerificationFunc
serverVerificationType VerificationType
}{
// Test Scenarios:
@ -175,8 +175,8 @@ func (s) TestEnd2End(t *testing.T) {
}
},
clientRoot: cs.ClientTrust1,
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
@ -188,8 +188,8 @@ func (s) TestEnd2End(t *testing.T) {
return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
}
},
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
serverVerificationType: CertVerification,
},
@ -216,8 +216,8 @@ func (s) TestEnd2End(t *testing.T) {
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
clientVerificationType: CertVerification,
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
@ -229,8 +229,8 @@ func (s) TestEnd2End(t *testing.T) {
}
},
serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
serverVerificationType: CertVerification,
},
@ -258,7 +258,7 @@ func (s) TestEnd2End(t *testing.T) {
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
if len(params.RawCerts) == 0 {
return nil, fmt.Errorf("no peer certs")
}
@ -280,7 +280,7 @@ func (s) TestEnd2End(t *testing.T) {
}
}
if authzCheck {
return &VerificationResults{}, nil
return &PostHandshakeVerificationResults{}, nil
}
return nil, fmt.Errorf("custom authz check fails")
},
@ -294,8 +294,8 @@ func (s) TestEnd2End(t *testing.T) {
}
},
serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
serverVerificationType: CertVerification,
},
@ -314,16 +314,16 @@ func (s) TestEnd2End(t *testing.T) {
desc: "TestServerCustomVerification",
clientCert: []tls.Certificate{cs.ClientCert1},
clientRoot: cs.ClientTrust1,
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
switch stage.read() {
case 0, 2:
return &VerificationResults{}, nil
return &PostHandshakeVerificationResults{}, nil
case 1:
return nil, fmt.Errorf("custom authz check fails")
default:
@ -345,9 +345,9 @@ func (s) TestEnd2End(t *testing.T) {
RootCACerts: test.serverRoot,
GetRootCertificates: test.serverGetRoot,
},
RequireClientCert: true,
VerifyPeer: test.serverVerifyFunc,
VerificationType: test.serverVerificationType,
RequireClientCert: true,
AdditionalPeerVerification: test.serverVerifyFunc,
VerificationType: test.serverVerificationType,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
if err != nil {
@ -368,7 +368,7 @@ func (s) TestEnd2End(t *testing.T) {
Certificates: test.clientCert,
GetIdentityCertificatesForClient: test.clientGetCert,
},
VerifyPeer: test.clientVerifyFunc,
AdditionalPeerVerification: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
@ -635,8 +635,8 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) {
RootProvider: serverRootProvider,
},
RequireClientCert: true,
VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
AdditionalPeerVerification: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
VerificationType: CertVerification,
}
@ -658,8 +658,8 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) {
IdentityOptions: IdentityCertificateOptions{
IdentityProvider: clientIdentityProvider,
},
VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
AdditionalPeerVerification: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return &PostHandshakeVerificationResults{}, nil
},
RootOptions: RootCertificateOptions{
RootProvider: clientRootProvider,

View File

@ -369,7 +369,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
}
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
clientVerifyFuncGood := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
if params.ServerName == "" {
return nil, errors.New("client side server name should have a value")
}
@ -378,15 +378,15 @@ func (s) TestClientServerHandshake(t *testing.T) {
return nil, errors.New("client side params parsing error")
}
return &VerificationResults{}, nil
return &PostHandshakeVerificationResults{}, nil
}
verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
verifyFuncBad := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
return nil, fmt.Errorf("custom verification function failed")
}
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
}
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
serverVerifyFunc := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
if params.ServerName != "" {
return nil, errors.New("server side server name should not have a value")
}
@ -395,7 +395,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
return nil, errors.New("server side params parsing error")
}
return &VerificationResults{}, nil
return &PostHandshakeVerificationResults{}, nil
}
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, fmt.Errorf("bad root certificate reloading")
@ -431,7 +431,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientVerifyFunc PostHandshakeVerificationFunc
clientVerificationType VerificationType
clientRootProvider certprovider.Provider
clientIdentityProvider certprovider.Provider
@ -442,7 +442,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVerifyFunc PostHandshakeVerificationFunc
serverVerificationType VerificationType
serverRootProvider certprovider.Provider
serverIdentityProvider certprovider.Provider
@ -822,10 +822,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
GetRootCertificates: test.serverGetRoot,
RootProvider: test.serverRootProvider,
},
RequireClientCert: test.serverMutualTLS,
VerifyPeer: test.serverVerifyFunc,
VerificationType: test.serverVerificationType,
RevocationConfig: test.serverRevocationConfig,
RequireClientCert: test.serverMutualTLS,
AdditionalPeerVerification: test.serverVerifyFunc,
VerificationType: test.serverVerificationType,
RevocationConfig: test.serverRevocationConfig,
}
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept()
@ -861,7 +861,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
GetIdentityCertificatesForClient: test.clientGetCert,
IdentityProvider: test.clientIdentityProvider,
},
VerifyPeer: test.clientVerifyFunc,
AdditionalPeerVerification: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCertificates: test.clientGetRoot,

View File

@ -76,8 +76,8 @@ func main() {
IdentityOptions: advancedtls.IdentityCertificateOptions{
IdentityProvider: identityProvider,
},
VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) {
return &advancedtls.VerificationResults{}, nil
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
return &advancedtls.PostHandshakeVerificationResults{}, nil
},
RootOptions: advancedtls.RootCertificateOptions{
RootProvider: rootProvider,

View File

@ -84,10 +84,10 @@ func main() {
RootProvider: rootProvider,
},
RequireClientCert: true,
VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) {
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
// This message is to show the certificate under the hood is actually reloaded.
fmt.Printf("Client common name: %s.\n", params.Leaf.Subject.CommonName)
return &advancedtls.VerificationResults{}, nil
return &advancedtls.PostHandshakeVerificationResults{}, nil
},
VerificationType: advancedtls.CertVerification,
}