mirror of https://github.com/grpc/grpc-go.git
advancedtls: add PemFileProvider integration tests (#3934)
* advancedtls: add PemFileProvider integration tests
This commit is contained in:
parent
4be647f7f6
commit
ce5e366556
|
|
@ -23,7 +23,9 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -32,6 +34,7 @@ import (
|
|||
"google.golang.org/grpc/credentials"
|
||||
pb "google.golang.org/grpc/examples/helloworld/helloworld"
|
||||
"google.golang.org/grpc/security/advancedtls/internal/testutils"
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -39,6 +42,17 @@ var (
|
|||
port = ":50051"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default timeout for normal connections.
|
||||
defaultTestTimeout = 5 * time.Second
|
||||
// Default timeout for failed connections.
|
||||
defaultTestShortTimeout = 10 * time.Millisecond
|
||||
// Intervals that set to monitor the credential updates.
|
||||
credRefreshingInterval = 200 * time.Millisecond
|
||||
// Time we wait for the credential updates to be picked up.
|
||||
sleepInterval = 400 * time.Millisecond
|
||||
)
|
||||
|
||||
// stageInfo contains a stage number indicating the current phase of each
|
||||
// integration test, and a mutex.
|
||||
// Based on the stage number of current test, we will use different
|
||||
|
|
@ -76,6 +90,8 @@ func (greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.Hel
|
|||
return &pb.HelloReply{Message: "Hello " + in.Name}, nil
|
||||
}
|
||||
|
||||
// TODO(ZhenLian): remove shouldFail to the function signature to provider
|
||||
// tests.
|
||||
func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
|
@ -86,6 +102,8 @@ func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// TODO(ZhenLian): remove shouldFail and add ...DialOption to the function
|
||||
// signature to provider cleaner tests.
|
||||
func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
var err error
|
||||
|
|
@ -153,7 +171,7 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
// should see it again accepts the connection, since ClientCert2 is trusted
|
||||
// by ServerTrust2.
|
||||
{
|
||||
desc: "TestClientPeerCertReloadServerTrustCertReload",
|
||||
desc: "test the reloading feature for client identity callback and server trust callback",
|
||||
clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
switch stage.read() {
|
||||
case 0:
|
||||
|
|
@ -194,7 +212,7 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
// should see it again accepts the connection, since ServerCert2 is trusted
|
||||
// by ClientTrust2.
|
||||
{
|
||||
desc: "TestServerPeerCertReloadClientTrustCertReload",
|
||||
desc: "test the reloading feature for server identity callback and client trust callback",
|
||||
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
switch stage.read() {
|
||||
|
|
@ -236,7 +254,7 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
// At stage 2, the client changes authorization check to only accept
|
||||
// ServerCert2. Now we should see the connection becomes normal again.
|
||||
{
|
||||
desc: "TestClientCustomVerification",
|
||||
desc: "test client custom verification",
|
||||
clientCert: []tls.Certificate{cs.ClientCert1},
|
||||
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
|
||||
switch stage.read() {
|
||||
|
|
@ -368,9 +386,9 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
}
|
||||
// ------------------------Scenario 1------------------------------------
|
||||
// stage = 0, initial connection should succeed
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
conn, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 1", clientTLSCreds, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -385,9 +403,9 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
}
|
||||
// ------------------------Scenario 3------------------------------------
|
||||
// stage = 1, new connection should fail
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true)
|
||||
shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
|
||||
defer shortCancel()
|
||||
conn2, greetClient, err := callAndVerifyWithClientConn(shortCtx, "rpc call 3", clientTLSCreds, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -396,9 +414,7 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
stage.increase()
|
||||
// ------------------------Scenario 4------------------------------------
|
||||
// stage = 2, new connection should succeed
|
||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel3()
|
||||
conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false)
|
||||
conn3, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 4", clientTLSCreds, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -408,3 +424,305 @@ func (s) TestEnd2End(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
type tmpCredsFiles struct {
|
||||
clientCertTmp *os.File
|
||||
clientKeyTmp *os.File
|
||||
clientTrustTmp *os.File
|
||||
serverCertTmp *os.File
|
||||
serverKeyTmp *os.File
|
||||
serverTrustTmp *os.File
|
||||
}
|
||||
|
||||
// Create temp files that are used to hold credentials.
|
||||
func createTmpFiles() (*tmpCredsFiles, error) {
|
||||
tmpFiles := &tmpCredsFiles{}
|
||||
var err error
|
||||
tmpFiles.clientCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpFiles.clientKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpFiles.clientTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpFiles.serverCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpFiles.serverKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpFiles.serverTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tmpFiles, nil
|
||||
}
|
||||
|
||||
// Copy the credential contents to the temporary files.
|
||||
func (tmpFiles *tmpCredsFiles) copyCredsToTmpFiles() error {
|
||||
if err := copyFileContents(testdata.Path("client_cert_1.pem"), tmpFiles.clientCertTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := copyFileContents(testdata.Path("client_key_1.pem"), tmpFiles.clientKeyTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := copyFileContents(testdata.Path("client_trust_cert_1.pem"), tmpFiles.clientTrustTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := copyFileContents(testdata.Path("server_cert_1.pem"), tmpFiles.serverCertTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := copyFileContents(testdata.Path("server_key_1.pem"), tmpFiles.serverKeyTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := copyFileContents(testdata.Path("server_trust_cert_1.pem"), tmpFiles.serverTrustTmp.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tmpFiles *tmpCredsFiles) removeFiles() {
|
||||
os.Remove(tmpFiles.clientCertTmp.Name())
|
||||
os.Remove(tmpFiles.clientKeyTmp.Name())
|
||||
os.Remove(tmpFiles.clientTrustTmp.Name())
|
||||
os.Remove(tmpFiles.serverCertTmp.Name())
|
||||
os.Remove(tmpFiles.serverKeyTmp.Name())
|
||||
os.Remove(tmpFiles.serverTrustTmp.Name())
|
||||
}
|
||||
|
||||
func copyFileContents(sourceFile, destinationFile string) error {
|
||||
input, err := ioutil.ReadFile(sourceFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(destinationFile, input, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create PEMFileProvider(s) watching the content changes of temporary
|
||||
// files.
|
||||
func createProviders(tmpFiles *tmpCredsFiles) (*PEMFileProvider, *PEMFileProvider, *PEMFileProvider, *PEMFileProvider, error) {
|
||||
clientIdentityOptions := PEMFileProviderOptions{
|
||||
CertFile: tmpFiles.clientCertTmp.Name(),
|
||||
KeyFile: tmpFiles.clientKeyTmp.Name(),
|
||||
IdentityInterval: credRefreshingInterval,
|
||||
}
|
||||
clientIdentityProvider, err := NewPEMFileProvider(clientIdentityOptions)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
clientRootOptions := PEMFileProviderOptions{
|
||||
TrustFile: tmpFiles.clientTrustTmp.Name(),
|
||||
RootInterval: credRefreshingInterval,
|
||||
}
|
||||
clientRootProvider, err := NewPEMFileProvider(clientRootOptions)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
serverIdentityOptions := PEMFileProviderOptions{
|
||||
CertFile: tmpFiles.serverCertTmp.Name(),
|
||||
KeyFile: tmpFiles.serverKeyTmp.Name(),
|
||||
IdentityInterval: credRefreshingInterval,
|
||||
}
|
||||
serverIdentityProvider, err := NewPEMFileProvider(serverIdentityOptions)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
serverRootOptions := PEMFileProviderOptions{
|
||||
TrustFile: tmpFiles.serverTrustTmp.Name(),
|
||||
RootInterval: credRefreshingInterval,
|
||||
}
|
||||
serverRootProvider, err := NewPEMFileProvider(serverRootOptions)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
return clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, nil
|
||||
}
|
||||
|
||||
// In order to test advanced TLS provider features, we used temporary files to
|
||||
// hold credential data, and copy the contents under testdata/ to these tmp
|
||||
// files.
|
||||
// Initially, we establish a good connection with providers watching contents
|
||||
// from tmp files.
|
||||
// Next, we change the identity certs that IdentityProvider is watching. Since
|
||||
// the identity key is not changed, the IdentityProvider should ignore the
|
||||
// update, and the connection should still be good.
|
||||
// Then the the identity key is changed. This time IdentityProvider should pick
|
||||
// up the update, and the connection should fail, due to the trust certs on the
|
||||
// other side is not changed.
|
||||
// Finally, the trust certs that other-side's RootProvider is watching get
|
||||
// changed. The connection should go back to normal again.
|
||||
func (s) TestPEMFileProviderEnd2End(t *testing.T) {
|
||||
tmpFiles, err := createTmpFiles()
|
||||
if err != nil {
|
||||
t.Fatalf("createTmpFiles() failed, error: %v", err)
|
||||
}
|
||||
defer tmpFiles.removeFiles()
|
||||
for _, test := range []struct {
|
||||
desc string
|
||||
certUpdateFunc func()
|
||||
keyUpdateFunc func()
|
||||
trustCertUpdateFunc func()
|
||||
}{
|
||||
{
|
||||
desc: "test the reloading feature for clientIdentityProvider and serverTrustProvider",
|
||||
certUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
keyUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
trustCertUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "test the reloading feature for serverIdentityProvider and clientTrustProvider",
|
||||
certUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
keyUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
trustCertUpdateFunc: func() {
|
||||
err = copyFileContents(testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name(), err)
|
||||
}
|
||||
},
|
||||
},
|
||||
} {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
if err := tmpFiles.copyCredsToTmpFiles(); err != nil {
|
||||
t.Fatalf("tmpFiles.copyCredsToTmpFiles() failed, error: %v", err)
|
||||
}
|
||||
clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, err := createProviders(tmpFiles)
|
||||
if err != nil {
|
||||
t.Fatalf("createProviders(%v) failed, error: %v", tmpFiles, err)
|
||||
}
|
||||
defer clientIdentityProvider.Close()
|
||||
defer clientRootProvider.Close()
|
||||
defer serverIdentityProvider.Close()
|
||||
defer serverRootProvider.Close()
|
||||
// Start a server and create a client using advancedtls API with Provider.
|
||||
serverOptions := &ServerOptions{
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
IdentityProvider: serverIdentityProvider,
|
||||
},
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: serverRootProvider,
|
||||
},
|
||||
RequireClientCert: true,
|
||||
VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return &VerificationResults{}, nil
|
||||
},
|
||||
VType: CertVerification,
|
||||
}
|
||||
serverTLSCreds, err := NewServerCreds(serverOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server creds: %v", err)
|
||||
}
|
||||
s := grpc.NewServer(grpc.Creds(serverTLSCreds))
|
||||
defer s.Stop()
|
||||
lis, err := net.Listen("tcp", port)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
pb.RegisterGreeterServer(s, greeterServer{})
|
||||
go s.Serve(lis)
|
||||
clientOptions := &ClientOptions{
|
||||
IdentityOptions: IdentityCertificateOptions{
|
||||
IdentityProvider: clientIdentityProvider,
|
||||
},
|
||||
VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
|
||||
return &VerificationResults{}, nil
|
||||
},
|
||||
RootOptions: RootCertificateOptions{
|
||||
RootProvider: clientRootProvider,
|
||||
},
|
||||
VType: CertVerification,
|
||||
}
|
||||
clientTLSCreds, err := NewClientCreds(clientOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("clientTLSCreds failed to create, error: %v", err)
|
||||
}
|
||||
|
||||
// At initialization, the connection should be good.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
conn, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 1", clientTLSCreds, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
// Make the identity cert change, and wait 1 second for the provider to
|
||||
// pick up the change.
|
||||
test.certUpdateFunc()
|
||||
time.Sleep(sleepInterval)
|
||||
// The already-established connection should not be affected.
|
||||
err = callAndVerify("rpc call 2", greetClient, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// New connections should still be good, because the Provider didn't pick
|
||||
// up the changes due to key-cert mismatch.
|
||||
conn2, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 3", clientTLSCreds, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn2.Close()
|
||||
// Make the identity key change, and wait 1 second for the provider to
|
||||
// pick up the change.
|
||||
test.keyUpdateFunc()
|
||||
time.Sleep(sleepInterval)
|
||||
// New connections should fail now, because the Provider picked the
|
||||
// change, and *_cert_2.pem is not trusted by *_trust_cert_1.pem on the
|
||||
// other side.
|
||||
shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
|
||||
defer shortCancel()
|
||||
conn3, greetClient, err := callAndVerifyWithClientConn(shortCtx, "rpc call 4", clientTLSCreds, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn3.Close()
|
||||
// Make the trust cert change on the other side, and wait 1 second for
|
||||
// the provider to pick up the change.
|
||||
test.trustCertUpdateFunc()
|
||||
time.Sleep(sleepInterval)
|
||||
// New connections should be good, because the other side is using
|
||||
// *_trust_cert_2.pem now.
|
||||
conn4, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 5", clientTLSCreds, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue