advancedtls: Add system default CAs to config function (#3663)

* Add system default CAs to config function
This commit is contained in:
cindyxue 2020-06-27 16:05:33 -07:00 committed by GitHub
parent c95dc4da23
commit 68098483a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 3 deletions

View File

@ -81,6 +81,8 @@ type GetRootCAsResults struct {
// RootCertificateOptions contains a field and a function for obtaining root // RootCertificateOptions contains a field and a function for obtaining root
// trust certificates. // trust certificates.
// It is used by both ClientOptions and ServerOptions. // It is used by both ClientOptions and ServerOptions.
// If users want to use default verification, but did not provide a valid
// RootCertificateOptions, we use the system default trust certificates.
type RootCertificateOptions struct { type RootCertificateOptions struct {
// If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts // If field RootCACerts is set, field GetRootCAs will be ignored. RootCACerts
// will be used every time when verifying the peer certificates, without // will be used every time when verifying the peer certificates, without
@ -184,15 +186,26 @@ func (o *ClientOptions) config() (*tls.Config, error) {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"client needs to provide custom verification mechanism if choose to skip default verification") "client needs to provide custom verification mechanism if choose to skip default verification")
} }
rootCAs := o.RootCACerts
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil {
// Set rootCAs to system default.
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
rootCAs = systemRootCAs
}
// We have to set InsecureSkipVerify to true to skip the default checks and // We have to set InsecureSkipVerify to true to skip the default checks and
// use the verification function we built from buildVerifyFunc. // use the verification function we built from buildVerifyFunc.
config := &tls.Config{ config := &tls.Config{
ServerName: o.ServerNameOverride, ServerName: o.ServerNameOverride,
Certificates: o.Certificates, Certificates: o.Certificates,
GetClientCertificate: o.GetClientCertificate, GetClientCertificate: o.GetClientCertificate,
RootCAs: o.RootCACerts,
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
if rootCAs != nil {
config.RootCAs = rootCAs
}
return config, nil return config, nil
} }
@ -204,6 +217,15 @@ func (o *ServerOptions) config() (*tls.Config, error) {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)") "server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
} }
clientCAs := o.RootCACerts
if o.VType != SkipVerification && o.RootCACerts == nil && o.GetRootCAs == nil && o.RequireClientCert {
// Set clientCAs to system default.
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
clientCAs = systemRootCAs
}
clientAuth := tls.NoClientCert clientAuth := tls.NoClientCert
if o.RequireClientCert { if o.RequireClientCert {
// We have to set clientAuth to RequireAnyClientCert to force underlying // We have to set clientAuth to RequireAnyClientCert to force underlying
@ -216,8 +238,8 @@ func (o *ServerOptions) config() (*tls.Config, error) {
Certificates: o.Certificates, Certificates: o.Certificates,
GetCertificate: o.GetCertificate, GetCertificate: o.GetCertificate,
} }
if o.RootCACerts != nil { if clientCAs != nil {
config.ClientCAs = o.RootCACerts config.ClientCAs = clientCAs
} }
return config, nil return config, nil
} }

View File

@ -623,3 +623,59 @@ func TestWrapSyscallConn(t *testing.T) {
wrapConn) wrapConn)
} }
} }
func TestOptionsConfig(t *testing.T) {
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
tests := []struct {
desc string
clientVType VerificationType
serverMutualTLS bool
serverCert []tls.Certificate
serverVType VerificationType
}{
{
desc: "Client uses system-provided RootCAs; server uses system-provided ClientCAs",
clientVType: CertVerification,
serverMutualTLS: true,
serverCert: []tls.Certificate{serverPeerCert},
serverVType: CertAndHostVerification,
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
Certificates: test.serverCert,
RequireClientCert: test.serverMutualTLS,
VType: test.serverVType,
}
serverConfig, err := serverOptions.config()
if err != nil {
t.Fatalf("Unable to generate serverConfig. Error: %v", err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in serverOptions.
if serverOptions.RootCACerts == nil && serverOptions.GetRootCAs == nil &&
serverOptions.RequireClientCert && serverConfig.ClientCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the server side.")
}
clientOptions := &ClientOptions{
VType: test.clientVType,
}
clientConfig, err := clientOptions.config()
if err != nil {
t.Fatalf("Unable to generate clientConfig. Error: %v", err)
}
// Verify that the system-provided certificates would be used
// when no verification method was set in clientOptions.
if clientOptions.RootCACerts == nil && clientOptions.GetRootCAs == nil &&
clientConfig.RootCAs == nil {
t.Fatalf("Failed to assign system-provided certificates on the client side.")
}
})
}
}