mirror of https://github.com/grpc/grpc-go.git
xds: add support for mTLS Credentials in xDS bootstrap (#6757)
This commit is contained in:
parent
71cc0f1675
commit
6bc19068a7
|
|
@ -29,7 +29,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
pluginName = "file_watcher"
|
||||
PluginName = "file_watcher"
|
||||
defaultRefreshInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
|
|
@ -48,13 +48,13 @@ func (p *pluginBuilder) ParseConfig(c any) (*certprovider.BuildableConfig, error
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
|
||||
return certprovider.NewBuildableConfig(PluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
|
||||
return newProvider(opts)
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (p *pluginBuilder) Name() string {
|
||||
return pluginName
|
||||
return PluginName
|
||||
}
|
||||
|
||||
func pluginConfigFromJSON(jd json.RawMessage) (Options, error) {
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ func CreateClientTLSCredentials(t *testing.T) credentials.TransportCredentials {
|
|||
|
||||
// CreateServerTLSCredentials creates server-side TLS transport credentials
|
||||
// using certificate and key files from testdata/x509 directory.
|
||||
func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
|
||||
func CreateServerTLSCredentials(t *testing.T, clientAuth tls.ClientAuthType) credentials.TransportCredentials {
|
||||
t.Helper()
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
|
||||
|
|
@ -114,7 +114,7 @@ func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
|
|||
t.Fatal("Failed to append certificates")
|
||||
}
|
||||
return credentials.NewTLS(&tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientAuth: clientAuth,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientCAs: ca,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ package xds_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -226,7 +227,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T
|
|||
// backend1 configured with TLS creds, represents cluster1
|
||||
// backend2 configured with insecure creds, represents cluster2
|
||||
// backend3 configured with insecure creds, represents cluster3
|
||||
creds := e2e.CreateServerTLSCredentials(t)
|
||||
creds := e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)
|
||||
server1 := stubserver.StartTestService(t, nil, grpc.Creds(creds))
|
||||
defer server1.Stop()
|
||||
server2 := stubserver.StartTestService(t, nil)
|
||||
|
|
|
|||
|
|
@ -37,8 +37,10 @@ var registry = make(map[string]Credentials)
|
|||
// Credentials interface encapsulates a credentials.Bundle builder
|
||||
// that can be used for communicating with the xDS Management server.
|
||||
type Credentials interface {
|
||||
// Build returns a credential bundle associated with this credential.
|
||||
Build(config json.RawMessage) (credentials.Bundle, error)
|
||||
// Build returns a credential bundle associated with this credential, and
|
||||
// a function to cleans up additional resources associated with this bundle
|
||||
// when it is no longer needed.
|
||||
Build(config json.RawMessage) (credentials.Bundle, func(), error)
|
||||
// Name returns the credential name associated with this credential.
|
||||
Name() string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ type testCredsBuilder struct {
|
|||
config json.RawMessage
|
||||
}
|
||||
|
||||
func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) {
|
||||
func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
|
||||
t.config = config
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (t *testCredsBuilder) Name() string {
|
||||
|
|
@ -53,7 +53,7 @@ func TestRegisterNew(t *testing.T) {
|
|||
|
||||
const sampleConfig = "sample_config"
|
||||
rawMessage := json.RawMessage(sampleConfig)
|
||||
if _, err := c.Build(rawMessage); err != nil {
|
||||
if _, _, err := c.Build(rawMessage); err != nil {
|
||||
t.Errorf("Build(%v) error = %v, want nil", rawMessage, err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ import (
|
|||
"google.golang.org/grpc/internal/envconfig"
|
||||
"google.golang.org/grpc/internal/pretty"
|
||||
"google.golang.org/grpc/xds/bootstrap"
|
||||
"google.golang.org/grpc/xds/internal/xdsclient/tlscreds"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -60,6 +61,7 @@ const (
|
|||
func init() {
|
||||
bootstrap.RegisterCredentials(&insecureCredsBuilder{})
|
||||
bootstrap.RegisterCredentials(&googleDefaultCredsBuilder{})
|
||||
bootstrap.RegisterCredentials(&tlsCredsBuilder{})
|
||||
}
|
||||
|
||||
// For overriding in unit tests.
|
||||
|
|
@ -69,20 +71,32 @@ var bootstrapFileReadFunc = os.ReadFile
|
|||
// package `xds/bootstrap` and encapsulates an insecure credential.
|
||||
type insecureCredsBuilder struct{}
|
||||
|
||||
func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
|
||||
return insecure.NewBundle(), nil
|
||||
func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
|
||||
return insecure.NewBundle(), func() {}, nil
|
||||
}
|
||||
|
||||
func (i *insecureCredsBuilder) Name() string {
|
||||
return "insecure"
|
||||
}
|
||||
|
||||
// tlsCredsBuilder implements the `Credentials` interface defined in
|
||||
// package `xds/bootstrap` and encapsulates a TLS credential.
|
||||
type tlsCredsBuilder struct{}
|
||||
|
||||
func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
|
||||
return tlscreds.NewBundle(config)
|
||||
}
|
||||
|
||||
func (t *tlsCredsBuilder) Name() string {
|
||||
return "tls"
|
||||
}
|
||||
|
||||
// googleDefaultCredsBuilder implements the `Credentials` interface defined in
|
||||
// package `xds/boostrap` and encapsulates a Google Default credential.
|
||||
type googleDefaultCredsBuilder struct{}
|
||||
|
||||
func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
|
||||
return google.NewDefaultCredentials(), nil
|
||||
func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
|
||||
return google.NewDefaultCredentials(), func() {}, nil
|
||||
}
|
||||
|
||||
func (d *googleDefaultCredsBuilder) Name() string {
|
||||
|
|
@ -151,6 +165,10 @@ type ServerConfig struct {
|
|||
// when a resource is deleted, nor will it remove the existing resource value
|
||||
// from its cache.
|
||||
IgnoreResourceDeletion bool
|
||||
|
||||
// Cleanups are called when the xDS client for this server is closed. Allows
|
||||
// cleaning up resources created specifically for this ServerConfig.
|
||||
Cleanups []func()
|
||||
}
|
||||
|
||||
// CredsDialOption returns the configured credentials as a grpc dial option.
|
||||
|
|
@ -206,12 +224,13 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
|
|||
if c == nil {
|
||||
continue
|
||||
}
|
||||
bundle, err := c.Build(cc.Config)
|
||||
bundle, cancel, err := c.Build(cc.Config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err)
|
||||
}
|
||||
sc.Creds = ChannelCreds(cc)
|
||||
sc.credsDialOption = grpc.WithCredentialsBundle(bundle)
|
||||
sc.Cleanups = append(sc.Cleanups, cancel)
|
||||
break
|
||||
}
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -1008,30 +1008,53 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDefaultBundles(t *testing.T) {
|
||||
if c := bootstrap.GetCredentials("google_default"); c == nil {
|
||||
t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`)
|
||||
}
|
||||
tests := []string{"google_default", "insecure", "tls"}
|
||||
|
||||
if c := bootstrap.GetCredentials("insecure"); c == nil {
|
||||
t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`)
|
||||
for _, typename := range tests {
|
||||
t.Run(typename, func(t *testing.T) {
|
||||
if c := bootstrap.GetCredentials(typename); c == nil {
|
||||
t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredsBuilders(t *testing.T) {
|
||||
b := &googleDefaultCredsBuilder{}
|
||||
if _, err := b.Build(nil); err != nil {
|
||||
t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err)
|
||||
}
|
||||
if got, want := b.Name(), "google_default"; got != want {
|
||||
t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want)
|
||||
tests := []struct {
|
||||
typename string
|
||||
builder bootstrap.Credentials
|
||||
}{
|
||||
{"google_default", &googleDefaultCredsBuilder{}},
|
||||
{"insecure", &insecureCredsBuilder{}},
|
||||
{"tls", &tlsCredsBuilder{}},
|
||||
}
|
||||
|
||||
i := &insecureCredsBuilder{}
|
||||
if _, err := i.Build(nil); err != nil {
|
||||
t.Errorf("insecureCredsBuilder.Build failed: %v", err)
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.typename, func(t *testing.T) {
|
||||
if got, want := test.builder.Name(), test.typename; got != want {
|
||||
t.Errorf("%T.Name = %v, want %v", test.builder, got, want)
|
||||
}
|
||||
|
||||
if got, want := i.Name(), "insecure"; got != want {
|
||||
t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want)
|
||||
_, stop, err := test.builder.Build(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%T.Build failed: %v", test.builder, err)
|
||||
}
|
||||
stop()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTlsCredsBuilder(t *testing.T) {
|
||||
tls := &tlsCredsBuilder{}
|
||||
_, stop, err := tls.Build(json.RawMessage(`{}`))
|
||||
if err != nil {
|
||||
t.Fatalf("tls.Build() failed with error %s when expected to succeed", err)
|
||||
}
|
||||
stop()
|
||||
|
||||
if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil {
|
||||
t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail")
|
||||
stop()
|
||||
}
|
||||
// package internal/xdsclient/tlscreds has tests for config validity.
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,5 +85,17 @@ func (c *clientImpl) close() {
|
|||
c.authorityMu.Unlock()
|
||||
c.serializerClose()
|
||||
|
||||
for _, f := range c.config.XDSServer.Cleanups {
|
||||
f()
|
||||
}
|
||||
for _, a := range c.config.Authorities {
|
||||
if a.XDSServer == nil {
|
||||
// The server for this authority is the top-level one, cleaned up above.
|
||||
continue
|
||||
}
|
||||
for _, f := range a.XDSServer.Cleanups {
|
||||
f()
|
||||
}
|
||||
}
|
||||
c.logger.Infof("Shutdown")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2023 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package tlscreds implements mTLS Credentials in xDS Bootstrap File.
|
||||
// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md.
|
||||
package tlscreds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
)
|
||||
|
||||
// bundle is an implementation of credentials.Bundle which implements mTLS
|
||||
// Credentials in xDS Bootstrap File.
|
||||
type bundle struct {
|
||||
transportCredentials credentials.TransportCredentials
|
||||
}
|
||||
|
||||
// NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS
|
||||
// Bootstrap File. It delegates certificate loading to a file_watcher provider
|
||||
// if either client certificates or server root CA is specified. The second
|
||||
// return value is a close func that should be called when the caller no longer
|
||||
// needs this bundle.
|
||||
// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md
|
||||
func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) {
|
||||
cfg := &struct {
|
||||
CertificateFile string `json:"certificate_file"`
|
||||
CACertificateFile string `json:"ca_certificate_file"`
|
||||
PrivateKeyFile string `json:"private_key_file"`
|
||||
}{}
|
||||
|
||||
if jd != nil {
|
||||
if err := json.Unmarshal(jd, cfg); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
} // Else the config field is absent. Treat it as an empty config.
|
||||
|
||||
if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" {
|
||||
// We cannot use (and do not need) a file_watcher provider in this case,
|
||||
// and can simply directly use the TLS transport credentials.
|
||||
// Quoting A65:
|
||||
//
|
||||
// > The only difference between the file-watcher certificate provider
|
||||
// > config and this one is that in the file-watcher certificate
|
||||
// > provider, at least one of the "certificate_file" or
|
||||
// > "ca_certificate_file" fields must be specified, whereas in this
|
||||
// > configuration, it is acceptable to specify neither one.
|
||||
return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, func() {}, nil
|
||||
}
|
||||
// Otherwise we need to use a file_watcher provider to watch the CA,
|
||||
// private and public keys.
|
||||
|
||||
// The pemfile plugin (file_watcher) currently ignores BuildOptions.
|
||||
provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &bundle{
|
||||
transportCredentials: &reloadingCreds{provider: provider},
|
||||
}, grpcsync.OnceFunc(func() { provider.Close() }), nil
|
||||
}
|
||||
|
||||
func (t *bundle) TransportCredentials() credentials.TransportCredentials {
|
||||
return t.transportCredentials
|
||||
}
|
||||
|
||||
func (t *bundle) PerRPCCredentials() credentials.PerRPCCredentials {
|
||||
// mTLS provides transport credentials only. There are no per-RPC
|
||||
// credentials.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *bundle) NewWithMode(string) (credentials.Bundle, error) {
|
||||
// This bundle has a single mode which only uses TLS transport credentials,
|
||||
// so there is no legitimate case where callers would call NewWithMode.
|
||||
return nil, fmt.Errorf("xDS TLS credentials only support one mode")
|
||||
}
|
||||
|
||||
// reloadingCreds is a credentials.TransportCredentials for client
|
||||
// side mTLS that reloads the server root CA certificate and the client
|
||||
// certificates from the provider on every client handshake. This is necessary
|
||||
// because the standard TLS credentials do not support reloading CA
|
||||
// certificates.
|
||||
type reloadingCreds struct {
|
||||
provider certprovider.Provider
|
||||
}
|
||||
|
||||
func (c *reloadingCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
km, err := c.provider.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
config := &tls.Config{
|
||||
RootCAs: km.Roots,
|
||||
Certificates: km.Certs,
|
||||
}
|
||||
return credentials.NewTLS(config).ClientHandshake(ctx, authority, rawConn)
|
||||
}
|
||||
|
||||
func (c *reloadingCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{SecurityProtocol: "tls"}
|
||||
}
|
||||
|
||||
func (c *reloadingCreds) Clone() credentials.TransportCredentials {
|
||||
return &reloadingCreds{provider: c.provider}
|
||||
}
|
||||
|
||||
func (c *reloadingCreds) OverrideServerName(string) error {
|
||||
return errors.New("overriding server name is not supported by xDS client TLS credentials")
|
||||
}
|
||||
|
||||
func (c *reloadingCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return nil, nil, errors.New("server handshake is not supported by xDS client TLS credentials")
|
||||
}
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2023 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package tlscreds_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
"google.golang.org/grpc/internal/testutils/xds/e2e"
|
||||
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/testdata"
|
||||
"google.golang.org/grpc/xds/internal/xdsclient/tlscreds"
|
||||
)
|
||||
|
||||
const defaultTestTimeout = 5 * time.Second
|
||||
|
||||
type s struct {
|
||||
grpctest.Tester
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
type Closable interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
func (s) TestValidTlsBuilder(t *testing.T) {
|
||||
caCert := testdata.Path("x509/server_ca_cert.pem")
|
||||
clientCert := testdata.Path("x509/client1_cert.pem")
|
||||
clientKey := testdata.Path("x509/client1_key.pem")
|
||||
tests := []struct {
|
||||
name string
|
||||
jd string
|
||||
}{
|
||||
{
|
||||
name: "Absent configuration",
|
||||
jd: `null`,
|
||||
},
|
||||
{
|
||||
name: "Empty configuration",
|
||||
jd: `{}`,
|
||||
},
|
||||
{
|
||||
name: "Only CA certificate chain",
|
||||
jd: fmt.Sprintf(`{"ca_certificate_file": "%s"}`, caCert),
|
||||
},
|
||||
{
|
||||
name: "Only private key and certificate chain",
|
||||
jd: fmt.Sprintf(`{"certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
|
||||
},
|
||||
{
|
||||
name: "CA chain, private key and certificate chain",
|
||||
jd: fmt.Sprintf(`{"ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
|
||||
},
|
||||
{
|
||||
name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`,
|
||||
},
|
||||
{
|
||||
name: "Refresh interval and CA certificate chain",
|
||||
jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file": "%s"}`, caCert),
|
||||
},
|
||||
{
|
||||
name: "Refresh interval, private key and certificate chain",
|
||||
jd: fmt.Sprintf(`{"refresh_interval": "1s","certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
|
||||
},
|
||||
{
|
||||
name: "Refresh interval, CA chain, private key and certificate chain",
|
||||
jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
|
||||
},
|
||||
{
|
||||
name: "Unknown field",
|
||||
jd: `{"unknown_field": "foo"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
msg := json.RawMessage(test.jd)
|
||||
_, stop, err := tlscreds.NewBundle(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err)
|
||||
}
|
||||
stop()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestInvalidTlsBuilder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name, jd, wantErrPrefix string
|
||||
}{
|
||||
{
|
||||
name: "Wrong type in json",
|
||||
jd: `{"ca_certificate_file": 1}`,
|
||||
wantErrPrefix: "failed to unmarshal config:"},
|
||||
{
|
||||
name: "Missing private key",
|
||||
jd: fmt.Sprintf(`{"certificate_file":"%s"}`, testdata.Path("x509/server_cert.pem")),
|
||||
wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
msg := json.RawMessage(test.jd)
|
||||
_, stop, err := tlscreds.NewBundle(msg)
|
||||
if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestCaReloading(t *testing.T) {
|
||||
serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read test CA cert: %s", err)
|
||||
}
|
||||
|
||||
// Write CA certs to a temporary file so that we can modify it later.
|
||||
caPath := t.TempDir() + "/ca.pem"
|
||||
if err = os.WriteFile(caPath, serverCa, 0644); err != nil {
|
||||
t.Fatalf("Failed to write test CA cert: %v", err)
|
||||
}
|
||||
cfg := fmt.Sprintf(`{
|
||||
"ca_certificate_file": "%s",
|
||||
"refresh_interval": ".01s"
|
||||
}`, caPath)
|
||||
tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create TLS bundle: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert))
|
||||
server := stubserver.StartTestService(t, nil, serverCredentials)
|
||||
|
||||
conn, err := grpc.Dial(
|
||||
server.Address,
|
||||
grpc.WithCredentialsBundle(tlsBundle),
|
||||
grpc.WithAuthority("x.test.example.com"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Error dialing: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
client := testgrpc.NewTestServiceClient(conn)
|
||||
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||
t.Errorf("Error calling EmptyCall: %v", err)
|
||||
}
|
||||
// close the server and create a new one to force client to do a new
|
||||
// handshake.
|
||||
server.Stop()
|
||||
|
||||
invalidCa, err := os.ReadFile(testdata.Path("ca.pem"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read test CA cert: %v", err)
|
||||
}
|
||||
// unload root cert
|
||||
err = os.WriteFile(caPath, invalidCa, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test CA cert: %v", err)
|
||||
}
|
||||
|
||||
for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) {
|
||||
ss := stubserver.StubServer{
|
||||
Address: server.Address,
|
||||
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
|
||||
}
|
||||
server = stubserver.StartTestService(t, &ss, serverCredentials)
|
||||
|
||||
// Client handshake should eventually fail because the client CA was
|
||||
// reloaded, and thus the server cert is signed by an unknown CA.
|
||||
t.Log(server)
|
||||
_, err = client.EmptyCall(ctx, &testpb.Empty{})
|
||||
const wantErr = "certificate signed by unknown authority"
|
||||
if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) {
|
||||
// Certs have reloaded.
|
||||
server.Stop()
|
||||
break
|
||||
}
|
||||
t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr)
|
||||
server.Stop()
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
t.Errorf("Timed out waiting for CA certs reloading")
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestMTLS(t *testing.T) {
|
||||
s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)))
|
||||
defer s.Stop()
|
||||
|
||||
cfg := fmt.Sprintf(`{
|
||||
"ca_certificate_file": "%s",
|
||||
"certificate_file": "%s",
|
||||
"private_key_file": "%s"
|
||||
}`,
|
||||
testdata.Path("x509/server_ca_cert.pem"),
|
||||
testdata.Path("x509/client1_cert.pem"),
|
||||
testdata.Path("x509/client1_key.pem"))
|
||||
tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create TLS bundle: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error dialing: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
client := testgrpc.NewTestServiceClient(conn)
|
||||
if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
|
||||
t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2023 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package tlscreds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
"google.golang.org/grpc/internal/testutils/xds/e2e"
|
||||
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||
"google.golang.org/grpc/testdata"
|
||||
)
|
||||
|
||||
type s struct {
|
||||
grpctest.Tester
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
type failingProvider struct{}
|
||||
|
||||
func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) {
|
||||
return nil, errors.New("test error")
|
||||
}
|
||||
|
||||
func (f failingProvider) Close() {}
|
||||
|
||||
func (s) TestFailingProvider(t *testing.T) {
|
||||
s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)))
|
||||
defer s.Stop()
|
||||
|
||||
cfg := fmt.Sprintf(`{
|
||||
"ca_certificate_file": "%s",
|
||||
"certificate_file": "%s",
|
||||
"private_key_file": "%s"
|
||||
}`,
|
||||
testdata.Path("x509/server_ca_cert.pem"),
|
||||
testdata.Path("x509/client1_cert.pem"),
|
||||
testdata.Path("x509/client1_key.pem"))
|
||||
tlsBundle, stop, err := NewBundle([]byte(cfg))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create TLS bundle: %v", err)
|
||||
}
|
||||
stop()
|
||||
|
||||
// Force a provider that returns an error, and make sure the client fails
|
||||
// the handshake.
|
||||
creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds)
|
||||
if !ok {
|
||||
t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials())
|
||||
}
|
||||
creds.provider = &failingProvider{}
|
||||
|
||||
conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error dialing: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := testgrpc.NewTestServiceClient(conn)
|
||||
_, err = client.EmptyCall(context.Background(), &testpb.Empty{})
|
||||
if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) {
|
||||
t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue