mirror of https://github.com/grpc/grpc-go.git
pemfile: Move file watcher plugin from advancedtls to gRPC (#3981)
This commit is contained in:
parent
fe9c99ff4c
commit
4e179b8d3e
|
@ -0,0 +1,252 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2020 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package pemfile provides a file watching certificate provider plugin
|
||||||
|
// implementation which works for files with PEM contents.
|
||||||
|
//
|
||||||
|
// Experimental
|
||||||
|
//
|
||||||
|
// Notice: All APIs in this package are experimental and may be removed in a
|
||||||
|
// later release.
|
||||||
|
package pemfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||||
|
"google.golang.org/grpc/grpclog"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCertRefreshDuration = 1 * time.Hour
|
||||||
|
defaultRootRefreshDuration = 2 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// For overriding from unit tests.
|
||||||
|
newDistributor = func() distributor { return certprovider.NewDistributor() }
|
||||||
|
|
||||||
|
logger = grpclog.Component("pemfile")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options configures a certificate provider plugin that watches a specified set
|
||||||
|
// of files that contain certificates and keys in PEM format.
|
||||||
|
type Options struct {
|
||||||
|
// CertFile is the file that holds the identity certificate.
|
||||||
|
// Optional. If this is set, KeyFile must also be set.
|
||||||
|
CertFile string
|
||||||
|
// KeyFile is the file that holds identity private key.
|
||||||
|
// Optional. If this is set, CertFile must also be set.
|
||||||
|
KeyFile string
|
||||||
|
// RootFile is the file that holds trusted root certificate(s).
|
||||||
|
// Optional.
|
||||||
|
RootFile string
|
||||||
|
// CertRefreshDuration is the amount of time the plugin waits before
|
||||||
|
// checking for updates in the specified identity certificate and key file.
|
||||||
|
// Optional. If not set, a default value (1 hour) will be used.
|
||||||
|
CertRefreshDuration time.Duration
|
||||||
|
// RootRefreshDuration is the amount of time the plugin waits before
|
||||||
|
// checking for updates in the specified root file.
|
||||||
|
// Optional. If not set, a default value (2 hour) will be used.
|
||||||
|
RootRefreshDuration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider returns a new certificate provider plugin that is configured to
|
||||||
|
// watch the PEM files specified in the passed in options.
|
||||||
|
func NewProvider(o Options) (certprovider.Provider, error) {
|
||||||
|
if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" {
|
||||||
|
return nil, fmt.Errorf("pemfile: at least one credential file needs to be specified")
|
||||||
|
}
|
||||||
|
if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
|
||||||
|
return nil, fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified")
|
||||||
|
}
|
||||||
|
if o.CertRefreshDuration == 0 {
|
||||||
|
o.CertRefreshDuration = defaultCertRefreshDuration
|
||||||
|
}
|
||||||
|
if o.RootRefreshDuration == 0 {
|
||||||
|
o.RootRefreshDuration = defaultRootRefreshDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := &watcher{opts: o}
|
||||||
|
if o.CertFile != "" && o.KeyFile != "" {
|
||||||
|
provider.identityDistributor = newDistributor()
|
||||||
|
}
|
||||||
|
if o.RootFile != "" {
|
||||||
|
provider.rootDistributor = newDistributor()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
provider.cancel = cancel
|
||||||
|
go provider.run(ctx)
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// watcher is a certificate provider plugin that implements the
|
||||||
|
// certprovider.Provider interface. It watches a set of certificate and key
|
||||||
|
// files and provides the most up-to-date key material for consumption by
|
||||||
|
// credentials implementation.
|
||||||
|
type watcher struct {
|
||||||
|
identityDistributor distributor
|
||||||
|
rootDistributor distributor
|
||||||
|
opts Options
|
||||||
|
certFileContents []byte
|
||||||
|
keyFileContents []byte
|
||||||
|
rootFileContents []byte
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// distributor wraps the methods on certprovider.Distributor which are used by
|
||||||
|
// the plugin. This is very useful in tests which need to know exactly when the
|
||||||
|
// plugin updates its key material.
|
||||||
|
type distributor interface {
|
||||||
|
KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error)
|
||||||
|
Set(km *certprovider.KeyMaterial, err error)
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateIdentityDistributor checks if the cert/key files that the plugin is
|
||||||
|
// watching have changed, and if so, reads the new contents and updates the
|
||||||
|
// identityDistributor with the new key material.
|
||||||
|
//
|
||||||
|
// Skips updates when file reading or parsing fails.
|
||||||
|
// TODO(easwars): Retry with limit (on the number of retries or the amount of
|
||||||
|
// time) upon failures.
|
||||||
|
func (w *watcher) updateIdentityDistributor() {
|
||||||
|
if w.identityDistributor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certFileContents, err := ioutil.ReadFile(w.opts.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warningf("certFile (%s) read failed: %v", w.opts.CertFile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keyFileContents, err := ioutil.ReadFile(w.opts.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warningf("keyFile (%s) read failed: %v", w.opts.KeyFile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// If the file contents have not changed, skip updating the distributor.
|
||||||
|
if bytes.Equal(w.certFileContents, certFileContents) && bytes.Equal(w.keyFileContents, keyFileContents) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certFileContents, keyFileContents)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warningf("tls.X509KeyPair(%q, %q) failed: %v", certFileContents, keyFileContents, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.certFileContents = certFileContents
|
||||||
|
w.keyFileContents = keyFileContents
|
||||||
|
w.identityDistributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRootDistributor checks if the root cert file that the plugin is
|
||||||
|
// watching hs changed, and if so, updates the rootDistributor with the new key
|
||||||
|
// material.
|
||||||
|
//
|
||||||
|
// Skips updates when root cert reading or parsing fails.
|
||||||
|
// TODO(easwars): Retry with limit (on the number of retries or the amount of
|
||||||
|
// time) upon failures.
|
||||||
|
func (w *watcher) updateRootDistributor() {
|
||||||
|
if w.rootDistributor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rootFileContents, err := ioutil.ReadFile(w.opts.RootFile)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warningf("rootFile (%s) read failed: %v", w.opts.RootFile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trustPool := x509.NewCertPool()
|
||||||
|
if !trustPool.AppendCertsFromPEM(rootFileContents) {
|
||||||
|
logger.Warning("failed to parse root certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// If the file contents have not changed, skip updating the distributor.
|
||||||
|
if bytes.Equal(w.rootFileContents, rootFileContents) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.rootFileContents = rootFileContents
|
||||||
|
w.rootDistributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// run is a long running goroutine which watches the configured files for
|
||||||
|
// changes, and pushes new key material into the appropriate distributors which
|
||||||
|
// is returned from calls to KeyMaterial().
|
||||||
|
func (w *watcher) run(ctx context.Context) {
|
||||||
|
// Update both root and identity certs at the beginning. Subsequently,
|
||||||
|
// update only the appropriate file whose ticker has fired.
|
||||||
|
w.updateIdentityDistributor()
|
||||||
|
w.updateRootDistributor()
|
||||||
|
|
||||||
|
identityTicker := time.NewTicker(w.opts.CertRefreshDuration)
|
||||||
|
rootTicker := time.NewTicker(w.opts.RootRefreshDuration)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
identityTicker.Stop()
|
||||||
|
rootTicker.Stop()
|
||||||
|
if w.identityDistributor != nil {
|
||||||
|
w.identityDistributor.Stop()
|
||||||
|
}
|
||||||
|
if w.rootDistributor != nil {
|
||||||
|
w.rootDistributor.Stop()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-identityTicker.C:
|
||||||
|
w.updateIdentityDistributor()
|
||||||
|
case <-rootTicker.C:
|
||||||
|
w.updateRootDistributor()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyMaterial returns the key material sourced by the watcher.
|
||||||
|
// Callers are expected to use the returned value as read-only.
|
||||||
|
func (w *watcher) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
|
||||||
|
km := &certprovider.KeyMaterial{}
|
||||||
|
if w.identityDistributor != nil {
|
||||||
|
identityKM, err := w.identityDistributor.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
km.Certs = identityKM.Certs
|
||||||
|
}
|
||||||
|
if w.rootDistributor != nil {
|
||||||
|
rootKM, err := w.rootDistributor.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
km.Roots = rootKM.Roots
|
||||||
|
}
|
||||||
|
return km, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up resources allocated by the watcher.
|
||||||
|
func (w *watcher) Close() {
|
||||||
|
w.cancel()
|
||||||
|
}
|
|
@ -0,0 +1,426 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2020 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package pemfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||||
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
|
"google.golang.org/grpc/internal/testutils"
|
||||||
|
"google.golang.org/grpc/testdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// These are the names of files inside temporary directories, which the
|
||||||
|
// plugin is asked to watch.
|
||||||
|
certFile = "cert.pem"
|
||||||
|
keyFile = "key.pem"
|
||||||
|
rootFile = "ca.pem"
|
||||||
|
|
||||||
|
defaultTestRefreshDuration = 100 * time.Millisecond
|
||||||
|
defaultTestTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type s struct {
|
||||||
|
grpctest.Tester
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
grpctest.RunSubTests(t, s{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewProvider tests the NewProvider() function with different inputs.
|
||||||
|
func (s) TestNewProvider(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
options Options
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "No credential files specified",
|
||||||
|
options: Options{},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Only identity cert is specified",
|
||||||
|
options: Options{
|
||||||
|
CertFile: testdata.Path("x509/client1_cert.pem"),
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Only identity key is specified",
|
||||||
|
options: Options{
|
||||||
|
KeyFile: testdata.Path("x509/client1_key.pem"),
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Identity cert/key pair is specified",
|
||||||
|
options: Options{
|
||||||
|
KeyFile: testdata.Path("x509/client1_key.pem"),
|
||||||
|
CertFile: testdata.Path("x509/client1_cert.pem"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Only root certs are specified",
|
||||||
|
options: Options{
|
||||||
|
RootFile: testdata.Path("x509/client_ca_cert.pem"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Everything is specified",
|
||||||
|
options: Options{
|
||||||
|
KeyFile: testdata.Path("x509/client1_key.pem"),
|
||||||
|
CertFile: testdata.Path("x509/client1_cert.pem"),
|
||||||
|
RootFile: testdata.Path("x509/client_ca_cert.pem"),
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
provider, err := NewProvider(test.options)
|
||||||
|
if (err != nil) != test.wantError {
|
||||||
|
t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
provider.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrappedDistributor wraps a distributor and pushes on a channel whenever new
|
||||||
|
// key material is pushed to the distributor.
|
||||||
|
type wrappedDistributor struct {
|
||||||
|
*certprovider.Distributor
|
||||||
|
distCh *testutils.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
|
||||||
|
return &wrappedDistributor{
|
||||||
|
distCh: distCh,
|
||||||
|
Distributor: certprovider.NewDistributor(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
|
||||||
|
wd.Distributor.Set(km, err)
|
||||||
|
wd.distCh.Send(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTmpFile(t *testing.T, src, dst string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
data, err := ioutil.ReadFile(src)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err)
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil {
|
||||||
|
t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err)
|
||||||
|
}
|
||||||
|
t.Logf("Wrote file at: %s", dst)
|
||||||
|
t.Logf("%s", string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTempDirWithFiles creates a temporary directory under the system default
|
||||||
|
// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and
|
||||||
|
// rootSrc files are creates appropriate files under the newly create tempDir.
|
||||||
|
// Returns the name of the created tempDir.
|
||||||
|
func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Create a temp directory. Passing an empty string for the first argument
|
||||||
|
// uses the system temp directory.
|
||||||
|
dir, err := ioutil.TempDir("", dirSuffix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.TempDir() failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Using tmpdir: %s", dir)
|
||||||
|
|
||||||
|
createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
|
||||||
|
createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
|
||||||
|
createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeProvider performs setup steps common to all tests (except the one
|
||||||
|
// which uses symlinks).
|
||||||
|
func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Override the newDistributor to one which pushes on a channel that we
|
||||||
|
// can block on.
|
||||||
|
origDistributorFunc := newDistributor
|
||||||
|
distCh := testutils.NewChannel()
|
||||||
|
d := newWrappedDistributor(distCh)
|
||||||
|
newDistributor = func() distributor { return d }
|
||||||
|
|
||||||
|
// Create a new provider to watch the files in tmpdir.
|
||||||
|
dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
|
||||||
|
opts := Options{
|
||||||
|
CertFile: path.Join(dir, certFile),
|
||||||
|
KeyFile: path.Join(dir, keyFile),
|
||||||
|
RootFile: path.Join(dir, rootFile),
|
||||||
|
CertRefreshDuration: defaultTestRefreshDuration,
|
||||||
|
RootRefreshDuration: defaultTestRefreshDuration,
|
||||||
|
}
|
||||||
|
prov, err := NewProvider(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the provider picks up the files and pushes the key material on
|
||||||
|
// to the distributors.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
// Since we have root and identity certs, we need to make sure the
|
||||||
|
// update is pushed on both of them.
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dir, prov, distCh, func() {
|
||||||
|
newDistributor = origDistributorFunc
|
||||||
|
prov.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProvider_NoUpdate tests the case where a file watcher plugin is created
|
||||||
|
// successfully, and the underlying files do not change. Verifies that the
|
||||||
|
// plugin does not push new updates to the distributor in this case.
|
||||||
|
func (s) TestProvider_NoUpdate(t *testing.T) {
|
||||||
|
_, prov, distCh, cancel := initializeProvider(t, "no_update")
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Make sure the provider is healthy and returns key material.
|
||||||
|
ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cc()
|
||||||
|
if _, err := prov.KeyMaterial(ctx); err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Files haven't change. Make sure no updates are pushed by the provider.
|
||||||
|
sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
|
||||||
|
defer sc()
|
||||||
|
if _, err := distCh.Receive(sCtx); err == nil {
|
||||||
|
t.Fatal("new key material pushed to distributor when underlying files did not change")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProvider_UpdateSuccess tests the case where a file watcher plugin is
|
||||||
|
// created successfully and the underlying files change. Verifies that the
|
||||||
|
// changes are picked up by the provider.
|
||||||
|
func (s) TestProvider_UpdateSuccess(t *testing.T) {
|
||||||
|
dir, prov, distCh, cancel := initializeProvider(t, "update_success")
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Make sure the provider is healthy and returns key material.
|
||||||
|
ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cc()
|
||||||
|
km1, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change only the root file.
|
||||||
|
createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatal("timeout waiting for new key material to be pushed to the distributor")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure update is picked up.
|
||||||
|
km2, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
||||||
|
t.Fatal("expected provider to return new key material after update to underlying file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change only cert/key files.
|
||||||
|
createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
|
||||||
|
createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatal("timeout waiting for new key material to be pushed to the distributor")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure update is picked up.
|
||||||
|
km3, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
||||||
|
t.Fatal("expected provider to return new key material after update to underlying file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher
|
||||||
|
// plugin is created successfully to watch files through a symlink and the
|
||||||
|
// symlink is updates to point to new files. Verifies that the changes are
|
||||||
|
// picked up by the provider.
|
||||||
|
func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
|
||||||
|
// Override the newDistributor to one which pushes on a channel that we
|
||||||
|
// can block on.
|
||||||
|
origDistributorFunc := newDistributor
|
||||||
|
distCh := testutils.NewChannel()
|
||||||
|
d := newWrappedDistributor(distCh)
|
||||||
|
newDistributor = func() distributor { return d }
|
||||||
|
defer func() { newDistributor = origDistributorFunc }()
|
||||||
|
|
||||||
|
// Create two tempDirs with different files.
|
||||||
|
dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
|
||||||
|
dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
|
||||||
|
|
||||||
|
// Create a symlink under a new tempdir, and make it point to dir1.
|
||||||
|
tmpdir, err := ioutil.TempDir("", "test_symlink_*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.TempDir() failed: %v", err)
|
||||||
|
}
|
||||||
|
symLinkName := path.Join(tmpdir, "test_symlink")
|
||||||
|
if err := os.Symlink(dir1, symLinkName); err != nil {
|
||||||
|
t.Fatalf("failed to create symlink to %q: %v", dir1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a provider which watches the files pointed to by the symlink.
|
||||||
|
opts := Options{
|
||||||
|
CertFile: path.Join(symLinkName, certFile),
|
||||||
|
KeyFile: path.Join(symLinkName, keyFile),
|
||||||
|
RootFile: path.Join(symLinkName, rootFile),
|
||||||
|
CertRefreshDuration: defaultTestRefreshDuration,
|
||||||
|
RootRefreshDuration: defaultTestRefreshDuration,
|
||||||
|
}
|
||||||
|
prov, err := NewProvider(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
|
||||||
|
}
|
||||||
|
defer prov.Close()
|
||||||
|
|
||||||
|
// Make sure the provider picks up the files and pushes the key material on
|
||||||
|
// to the distributors.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
// Since we have root and identity certs, we need to make sure the
|
||||||
|
// update is pushed on both of them.
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
km1, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the symlink to point to dir2.
|
||||||
|
symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
|
||||||
|
if err := os.Symlink(dir2, symLinkTmpName); err != nil {
|
||||||
|
t.Fatalf("failed to create symlink to %q: %v", dir2, err)
|
||||||
|
}
|
||||||
|
if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
|
||||||
|
t.Fatalf("failed to update symlink: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the provider picks up the new files and pushes the key material
|
||||||
|
// on to the distributors.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
// Since we have root and identity certs, we need to make sure the
|
||||||
|
// update is pushed on both of them.
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
km2, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
||||||
|
t.Fatal("expected provider to return new key material after symlink update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key
|
||||||
|
// files fail. Verifies that the failed update does not push anything on the
|
||||||
|
// distributor. Then the update succeeds, and the test verifies that the key
|
||||||
|
// material is updated.
|
||||||
|
func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
|
||||||
|
dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Make sure the provider is healthy and returns key material.
|
||||||
|
ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cc()
|
||||||
|
km1, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update only the cert file. The key file is left unchanged. This should
|
||||||
|
// lead to these two files being not compatible with each other. This
|
||||||
|
// simulates the case where the watching goroutine might catch the files in
|
||||||
|
// the midst of an update.
|
||||||
|
createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
|
||||||
|
|
||||||
|
// Since the last update left the files in an incompatible state, the update
|
||||||
|
// should not be picked up by our provider.
|
||||||
|
sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
|
||||||
|
defer sc()
|
||||||
|
if _, err := distCh.Receive(sCtx); err == nil {
|
||||||
|
t.Fatal("new key material pushed to distributor when underlying files did not change")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The provider should return key material corresponding to the old state.
|
||||||
|
km2, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
if !cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
||||||
|
t.Fatal("expected provider to not update key material")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the key file to match the cert file.
|
||||||
|
createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
|
||||||
|
|
||||||
|
// Make sure update is picked up.
|
||||||
|
if _, err := distCh.Receive(ctx); err != nil {
|
||||||
|
t.Fatal("timeout waiting for new key material to be pushed to the distributor")
|
||||||
|
}
|
||||||
|
km3, err := prov.KeyMaterial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("provider.KeyMaterial() failed: %v", err)
|
||||||
|
}
|
||||||
|
if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
||||||
|
t.Fatal("expected provider to return new key material after update to underlying file")
|
||||||
|
}
|
||||||
|
}
|
|
@ -32,6 +32,8 @@ import (
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||||
|
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
|
||||||
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
||||||
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||||
|
@ -511,38 +513,38 @@ func copyFileContents(sourceFile, destinationFile string) error {
|
||||||
|
|
||||||
// Create PEMFileProvider(s) watching the content changes of temporary
|
// Create PEMFileProvider(s) watching the content changes of temporary
|
||||||
// files.
|
// files.
|
||||||
func createProviders(tmpFiles *tmpCredsFiles) (*PEMFileProvider, *PEMFileProvider, *PEMFileProvider, *PEMFileProvider, error) {
|
func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) {
|
||||||
clientIdentityOptions := PEMFileProviderOptions{
|
clientIdentityOptions := pemfile.Options{
|
||||||
CertFile: tmpFiles.clientCertTmp.Name(),
|
CertFile: tmpFiles.clientCertTmp.Name(),
|
||||||
KeyFile: tmpFiles.clientKeyTmp.Name(),
|
KeyFile: tmpFiles.clientKeyTmp.Name(),
|
||||||
IdentityInterval: credRefreshingInterval,
|
CertRefreshDuration: credRefreshingInterval,
|
||||||
}
|
}
|
||||||
clientIdentityProvider, err := NewPEMFileProvider(clientIdentityOptions)
|
clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
clientRootOptions := PEMFileProviderOptions{
|
clientRootOptions := pemfile.Options{
|
||||||
TrustFile: tmpFiles.clientTrustTmp.Name(),
|
RootFile: tmpFiles.clientTrustTmp.Name(),
|
||||||
RootInterval: credRefreshingInterval,
|
RootRefreshDuration: credRefreshingInterval,
|
||||||
}
|
}
|
||||||
clientRootProvider, err := NewPEMFileProvider(clientRootOptions)
|
clientRootProvider, err := pemfile.NewProvider(clientRootOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
serverIdentityOptions := PEMFileProviderOptions{
|
serverIdentityOptions := pemfile.Options{
|
||||||
CertFile: tmpFiles.serverCertTmp.Name(),
|
CertFile: tmpFiles.serverCertTmp.Name(),
|
||||||
KeyFile: tmpFiles.serverKeyTmp.Name(),
|
KeyFile: tmpFiles.serverKeyTmp.Name(),
|
||||||
IdentityInterval: credRefreshingInterval,
|
CertRefreshDuration: credRefreshingInterval,
|
||||||
}
|
}
|
||||||
serverIdentityProvider, err := NewPEMFileProvider(serverIdentityOptions)
|
serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
serverRootOptions := PEMFileProviderOptions{
|
serverRootOptions := pemfile.Options{
|
||||||
TrustFile: tmpFiles.serverTrustTmp.Name(),
|
RootFile: tmpFiles.serverTrustTmp.Name(),
|
||||||
RootInterval: credRefreshingInterval,
|
RootRefreshDuration: credRefreshingInterval,
|
||||||
}
|
}
|
||||||
serverRootProvider, err := NewPEMFileProvider(serverRootOptions)
|
serverRootProvider, err := pemfile.NewProvider(serverRootOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,197 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright 2020 gRPC authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package advancedtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
|
||||||
"google.golang.org/grpc/grpclog"
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultIdentityInterval = 1 * time.Hour
|
|
||||||
const defaultRootInterval = 2 * time.Hour
|
|
||||||
|
|
||||||
// readKeyCertPairFunc will be overridden from unit tests.
|
|
||||||
var readKeyCertPairFunc = tls.LoadX509KeyPair
|
|
||||||
|
|
||||||
// readTrustCertFunc will be overridden from unit tests.
|
|
||||||
var readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
|
|
||||||
trustData, err := ioutil.ReadFile(trustFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
trustPool := x509.NewCertPool()
|
|
||||||
if !trustPool.AppendCertsFromPEM(trustData) {
|
|
||||||
return nil, fmt.Errorf("AppendCertsFromPEM failed to parse certificates")
|
|
||||||
}
|
|
||||||
return trustPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var logger = grpclog.Component("advancedtls")
|
|
||||||
|
|
||||||
// PEMFileProviderOptions contains options to configure a PEMFileProvider.
|
|
||||||
// Note that these fields will only take effect during construction. Once the
|
|
||||||
// PEMFileProvider starts, changing fields in PEMFileProviderOptions will have
|
|
||||||
// no effect.
|
|
||||||
type PEMFileProviderOptions struct {
|
|
||||||
// CertFile is the file path that holds identity certificate whose updates
|
|
||||||
// will be captured by a watching goroutine.
|
|
||||||
// Optional. If this is set, KeyFile must also be set.
|
|
||||||
CertFile string
|
|
||||||
// KeyFile is the file path that holds identity private key whose updates
|
|
||||||
// will be captured by a watching goroutine.
|
|
||||||
// Optional. If this is set, CertFile must also be set.
|
|
||||||
KeyFile string
|
|
||||||
// TrustFile is the file path that holds trust certificate whose updates will
|
|
||||||
// be captured by a watching goroutine.
|
|
||||||
// Optional.
|
|
||||||
TrustFile string
|
|
||||||
// IdentityInterval is the time duration between two credential update checks
|
|
||||||
// for identity certs.
|
|
||||||
// Optional. If not set, we will use the default interval(1 hour).
|
|
||||||
IdentityInterval time.Duration
|
|
||||||
// RootInterval is the time duration between two credential update checks
|
|
||||||
// for root certs.
|
|
||||||
// Optional. If not set, we will use the default interval(2 hours).
|
|
||||||
RootInterval time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// PEMFileProvider implements certprovider.Provider.
|
|
||||||
// It provides the most up-to-date identity private key-cert pairs and/or
|
|
||||||
// root certificates.
|
|
||||||
type PEMFileProvider struct {
|
|
||||||
identityDistributor *certprovider.Distributor
|
|
||||||
rootDistributor *certprovider.Distributor
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateIdentityDistributor(distributor *certprovider.Distributor, certFile, keyFile string) {
|
|
||||||
if distributor == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Read identity certs from PEM files.
|
|
||||||
identityCert, err := readKeyCertPairFunc(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
// If the reading produces an error, we will skip the update for this
|
|
||||||
// round and log the error.
|
|
||||||
logger.Warningf("tls.LoadX509KeyPair reads %s and %s failed: %v", certFile, keyFile, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
distributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{identityCert}}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateRootDistributor(distributor *certprovider.Distributor, trustFile string) {
|
|
||||||
if distributor == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Read root certs from PEM files.
|
|
||||||
trustPool, err := readTrustCertFunc(trustFile)
|
|
||||||
if err != nil {
|
|
||||||
// If the reading produces an error, we will skip the update for this
|
|
||||||
// round and log the error.
|
|
||||||
logger.Warningf("readTrustCertFunc reads %v failed: %v", trustFile, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
distributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPEMFileProvider returns a new PEMFileProvider constructed using the
|
|
||||||
// provided options.
|
|
||||||
func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) {
|
|
||||||
if o.CertFile == "" && o.KeyFile == "" && o.TrustFile == "" {
|
|
||||||
return nil, fmt.Errorf("at least one credential file needs to be specified")
|
|
||||||
}
|
|
||||||
if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
|
|
||||||
return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified")
|
|
||||||
}
|
|
||||||
if o.IdentityInterval == 0 {
|
|
||||||
o.IdentityInterval = defaultIdentityInterval
|
|
||||||
}
|
|
||||||
if o.RootInterval == 0 {
|
|
||||||
o.RootInterval = defaultRootInterval
|
|
||||||
}
|
|
||||||
provider := &PEMFileProvider{}
|
|
||||||
if o.CertFile != "" && o.KeyFile != "" {
|
|
||||||
provider.identityDistributor = certprovider.NewDistributor()
|
|
||||||
}
|
|
||||||
if o.TrustFile != "" {
|
|
||||||
provider.rootDistributor = certprovider.NewDistributor()
|
|
||||||
}
|
|
||||||
// A goroutine to pull file changes.
|
|
||||||
identityTicker := time.NewTicker(o.IdentityInterval)
|
|
||||||
rootTicker := time.NewTicker(o.RootInterval)
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
updateIdentityDistributor(provider.identityDistributor, o.CertFile, o.KeyFile)
|
|
||||||
updateRootDistributor(provider.rootDistributor, o.TrustFile)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
identityTicker.Stop()
|
|
||||||
rootTicker.Stop()
|
|
||||||
return
|
|
||||||
case <-identityTicker.C:
|
|
||||||
break
|
|
||||||
case <-rootTicker.C:
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
provider.cancel = cancel
|
|
||||||
return provider, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyMaterial returns the key material sourced by the PEMFileProvider.
|
|
||||||
// Callers are expected to use the returned value as read-only.
|
|
||||||
func (p *PEMFileProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
|
|
||||||
km := &certprovider.KeyMaterial{}
|
|
||||||
if p.identityDistributor != nil {
|
|
||||||
identityKM, err := p.identityDistributor.KeyMaterial(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
km.Certs = identityKM.Certs
|
|
||||||
}
|
|
||||||
if p.rootDistributor != nil {
|
|
||||||
rootKM, err := p.rootDistributor.KeyMaterial(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
km.Roots = rootKM.Roots
|
|
||||||
}
|
|
||||||
return km, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleans up resources allocated by the PEMFileProvider.
|
|
||||||
func (p *PEMFileProvider) Close() {
|
|
||||||
p.cancel()
|
|
||||||
if p.identityDistributor != nil {
|
|
||||||
p.identityDistributor.Stop()
|
|
||||||
}
|
|
||||||
if p.rootDistributor != nil {
|
|
||||||
p.rootDistributor.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,220 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright 2020 gRPC authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package advancedtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
|
||||||
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s) TestNewPEMFileProvider(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
desc string
|
|
||||||
options PEMFileProviderOptions
|
|
||||||
certFile string
|
|
||||||
keyFile string
|
|
||||||
trustFile string
|
|
||||||
wantError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "Expect error if no credential files specified",
|
|
||||||
options: PEMFileProviderOptions{},
|
|
||||||
wantError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Expect error if only certFile is specified",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
CertFile: testdata.Path("client_cert_1.pem"),
|
|
||||||
},
|
|
||||||
wantError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Should be good if only identity key cert pairs are specified",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
KeyFile: testdata.Path("client_key_1.pem"),
|
|
||||||
CertFile: testdata.Path("client_cert_1.pem"),
|
|
||||||
},
|
|
||||||
wantError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Should be good if only root certs are specified",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
TrustFile: testdata.Path("client_trust_cert_1.pem"),
|
|
||||||
},
|
|
||||||
wantError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Should be good if both identity pairs and root certs are specified",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
KeyFile: testdata.Path("client_key_1.pem"),
|
|
||||||
CertFile: testdata.Path("client_cert_1.pem"),
|
|
||||||
TrustFile: testdata.Path("client_trust_cert_1.pem"),
|
|
||||||
},
|
|
||||||
wantError: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
provider, err := NewPEMFileProvider(test.options)
|
|
||||||
if (err != nil) != test.wantError {
|
|
||||||
t.Fatalf("NewPEMFileProvider(%v) = %v, want %v", test.options, err, test.wantError)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
provider.Close()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// This test overwrites the credential reading function used by the watching
|
|
||||||
// goroutine. It is tested under different stages:
|
|
||||||
// At stage 0, we force reading function to load ClientCert1 and ServerTrust1,
|
|
||||||
// and see if the credentials are picked up by the watching go routine.
|
|
||||||
// At stage 1, we force reading function to cause an error. The watching go
|
|
||||||
// routine should log the error while leaving the credentials unchanged.
|
|
||||||
// At stage 2, we force reading function to load ClientCert2 and ServerTrust2,
|
|
||||||
// and see if the new credentials are picked up.
|
|
||||||
func (s) TestWatchingRoutineUpdates(t *testing.T) {
|
|
||||||
// Load certificates.
|
|
||||||
cs := &testutils.CertStore{}
|
|
||||||
if err := cs.LoadCerts(); err != nil {
|
|
||||||
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
desc string
|
|
||||||
options PEMFileProviderOptions
|
|
||||||
wantKmStage0 certprovider.KeyMaterial
|
|
||||||
wantKmStage1 certprovider.KeyMaterial
|
|
||||||
wantKmStage2 certprovider.KeyMaterial
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "use identity certs and root certs",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
CertFile: "not_empty_cert_file",
|
|
||||||
KeyFile: "not_empty_key_file",
|
|
||||||
TrustFile: "not_empty_trust_file",
|
|
||||||
},
|
|
||||||
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
|
|
||||||
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
|
|
||||||
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "use identity certs only",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
CertFile: "not_empty_cert_file",
|
|
||||||
KeyFile: "not_empty_key_file",
|
|
||||||
},
|
|
||||||
wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
|
|
||||||
wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
|
|
||||||
wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "use trust certs only",
|
|
||||||
options: PEMFileProviderOptions{
|
|
||||||
TrustFile: "not_empty_trust_file",
|
|
||||||
},
|
|
||||||
wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
|
|
||||||
wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
|
|
||||||
wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
testInterval := 200 * time.Millisecond
|
|
||||||
test.options.IdentityInterval = testInterval
|
|
||||||
test.options.RootInterval = testInterval
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
stage := &stageInfo{}
|
|
||||||
oldReadKeyCertPairFunc := readKeyCertPairFunc
|
|
||||||
readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
|
|
||||||
switch stage.read() {
|
|
||||||
case 0:
|
|
||||||
return cs.ClientCert1, nil
|
|
||||||
case 1:
|
|
||||||
return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
|
|
||||||
case 2:
|
|
||||||
return cs.ClientCert2, nil
|
|
||||||
default:
|
|
||||||
return tls.Certificate{}, fmt.Errorf("test stage not supported")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
readKeyCertPairFunc = oldReadKeyCertPairFunc
|
|
||||||
}()
|
|
||||||
oldReadTrustCertFunc := readTrustCertFunc
|
|
||||||
readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
|
|
||||||
switch stage.read() {
|
|
||||||
case 0:
|
|
||||||
return cs.ServerTrust1, nil
|
|
||||||
case 1:
|
|
||||||
return nil, fmt.Errorf("error occurred while reloading")
|
|
||||||
case 2:
|
|
||||||
return cs.ServerTrust2, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("test stage not supported")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
readTrustCertFunc = oldReadTrustCertFunc
|
|
||||||
}()
|
|
||||||
provider, err := NewPEMFileProvider(test.options)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewPEMFileProvider failed: %v", err)
|
|
||||||
}
|
|
||||||
defer provider.Close()
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
//// ------------------------Stage 0------------------------------------
|
|
||||||
// Wait for the refreshing go-routine to pick up the changes.
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
gotKM, err := provider.KeyMaterial(ctx)
|
|
||||||
if !cmp.Equal(*gotKM, test.wantKmStage0, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
|
||||||
t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage0)
|
|
||||||
}
|
|
||||||
// ------------------------Stage 1------------------------------------
|
|
||||||
stage.increase()
|
|
||||||
// Wait for the refreshing go-routine to pick up the changes.
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
gotKM, err = provider.KeyMaterial(ctx)
|
|
||||||
if !cmp.Equal(*gotKM, test.wantKmStage1, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
|
||||||
t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage1)
|
|
||||||
}
|
|
||||||
//// ------------------------Stage 2------------------------------------
|
|
||||||
// Wait for the refreshing go-routine to pick up the changes.
|
|
||||||
stage.increase()
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
gotKM, err = provider.KeyMaterial(ctx)
|
|
||||||
if !cmp.Equal(*gotKM, test.wantKmStage2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
|
|
||||||
t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage2)
|
|
||||||
}
|
|
||||||
stage.reset()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue