From ce5e36655676f214220d2fb6ce831f7ba91897fe Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Mon, 19 Oct 2020 13:54:02 -0700 Subject: [PATCH] advancedtls: add PemFileProvider integration tests (#3934) * advancedtls: add PemFileProvider integration tests --- .../advancedtls_integration_test.go | 342 +++++++++++++++++- 1 file changed, 330 insertions(+), 12 deletions(-) diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 3f4e7059a..20a9a5857 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -23,7 +23,9 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" "net" + "os" "sync" "testing" "time" @@ -32,6 +34,7 @@ import ( "google.golang.org/grpc/credentials" pb "google.golang.org/grpc/examples/helloworld/helloworld" "google.golang.org/grpc/security/advancedtls/internal/testutils" + "google.golang.org/grpc/security/advancedtls/testdata" ) var ( @@ -39,6 +42,17 @@ var ( port = ":50051" ) +const ( + // Default timeout for normal connections. + defaultTestTimeout = 5 * time.Second + // Default timeout for failed connections. + defaultTestShortTimeout = 10 * time.Millisecond + // Intervals that set to monitor the credential updates. + credRefreshingInterval = 200 * time.Millisecond + // Time we wait for the credential updates to be picked up. + sleepInterval = 400 * time.Millisecond +) + // stageInfo contains a stage number indicating the current phase of each // integration test, and a mutex. // Based on the stage number of current test, we will use different @@ -76,6 +90,8 @@ func (greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.Hel return &pb.HelloReply{Message: "Hello " + in.Name}, nil } +// TODO(ZhenLian): remove shouldFail to the function signature to provider +// tests. func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -86,6 +102,8 @@ func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error { return nil } +// TODO(ZhenLian): remove shouldFail and add ...DialOption to the function +// signature to provider cleaner tests. func callAndVerifyWithClientConn(connCtx context.Context, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) { var conn *grpc.ClientConn var err error @@ -153,7 +171,7 @@ func (s) TestEnd2End(t *testing.T) { // should see it again accepts the connection, since ClientCert2 is trusted // by ServerTrust2. { - desc: "TestClientPeerCertReloadServerTrustCertReload", + desc: "test the reloading feature for client identity callback and server trust callback", clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { switch stage.read() { case 0: @@ -194,7 +212,7 @@ func (s) TestEnd2End(t *testing.T) { // should see it again accepts the connection, since ServerCert2 is trusted // by ClientTrust2. { - desc: "TestServerPeerCertReloadClientTrustCertReload", + desc: "test the reloading feature for server identity callback and client trust callback", clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { switch stage.read() { @@ -236,7 +254,7 @@ func (s) TestEnd2End(t *testing.T) { // At stage 2, the client changes authorization check to only accept // ServerCert2. Now we should see the connection becomes normal again. { - desc: "TestClientCustomVerification", + desc: "test client custom verification", clientCert: []tls.Certificate{cs.ClientCert1}, clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { switch stage.read() { @@ -368,9 +386,9 @@ func (s) TestEnd2End(t *testing.T) { } // ------------------------Scenario 1------------------------------------ // stage = 0, initial connection should succeed - ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel1() - conn, greetClient, err := callAndVerifyWithClientConn(ctx1, "rpc call 1", clientTLSCreds, false) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 1", clientTLSCreds, false) if err != nil { t.Fatal(err) } @@ -385,9 +403,9 @@ func (s) TestEnd2End(t *testing.T) { } // ------------------------Scenario 3------------------------------------ // stage = 1, new connection should fail - ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel2() - conn2, greetClient, err := callAndVerifyWithClientConn(ctx2, "rpc call 3", clientTLSCreds, true) + shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shortCancel() + conn2, greetClient, err := callAndVerifyWithClientConn(shortCtx, "rpc call 3", clientTLSCreds, true) if err != nil { t.Fatal(err) } @@ -396,9 +414,7 @@ func (s) TestEnd2End(t *testing.T) { stage.increase() // ------------------------Scenario 4------------------------------------ // stage = 2, new connection should succeed - ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel3() - conn3, greetClient, err := callAndVerifyWithClientConn(ctx3, "rpc call 4", clientTLSCreds, false) + conn3, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 4", clientTLSCreds, false) if err != nil { t.Fatal(err) } @@ -408,3 +424,305 @@ func (s) TestEnd2End(t *testing.T) { }) } } + +type tmpCredsFiles struct { + clientCertTmp *os.File + clientKeyTmp *os.File + clientTrustTmp *os.File + serverCertTmp *os.File + serverKeyTmp *os.File + serverTrustTmp *os.File +} + +// Create temp files that are used to hold credentials. +func createTmpFiles() (*tmpCredsFiles, error) { + tmpFiles := &tmpCredsFiles{} + var err error + tmpFiles.clientCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + tmpFiles.clientKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + tmpFiles.clientTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + tmpFiles.serverCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + tmpFiles.serverKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + tmpFiles.serverTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-") + if err != nil { + return nil, err + } + return tmpFiles, nil +} + +// Copy the credential contents to the temporary files. +func (tmpFiles *tmpCredsFiles) copyCredsToTmpFiles() error { + if err := copyFileContents(testdata.Path("client_cert_1.pem"), tmpFiles.clientCertTmp.Name()); err != nil { + return err + } + if err := copyFileContents(testdata.Path("client_key_1.pem"), tmpFiles.clientKeyTmp.Name()); err != nil { + return err + } + if err := copyFileContents(testdata.Path("client_trust_cert_1.pem"), tmpFiles.clientTrustTmp.Name()); err != nil { + return err + } + if err := copyFileContents(testdata.Path("server_cert_1.pem"), tmpFiles.serverCertTmp.Name()); err != nil { + return err + } + if err := copyFileContents(testdata.Path("server_key_1.pem"), tmpFiles.serverKeyTmp.Name()); err != nil { + return err + } + if err := copyFileContents(testdata.Path("server_trust_cert_1.pem"), tmpFiles.serverTrustTmp.Name()); err != nil { + return err + } + return nil +} + +func (tmpFiles *tmpCredsFiles) removeFiles() { + os.Remove(tmpFiles.clientCertTmp.Name()) + os.Remove(tmpFiles.clientKeyTmp.Name()) + os.Remove(tmpFiles.clientTrustTmp.Name()) + os.Remove(tmpFiles.serverCertTmp.Name()) + os.Remove(tmpFiles.serverKeyTmp.Name()) + os.Remove(tmpFiles.serverTrustTmp.Name()) +} + +func copyFileContents(sourceFile, destinationFile string) error { + input, err := ioutil.ReadFile(sourceFile) + if err != nil { + return err + } + err = ioutil.WriteFile(destinationFile, input, 0644) + if err != nil { + return err + } + return nil +} + +// Create PEMFileProvider(s) watching the content changes of temporary +// files. +func createProviders(tmpFiles *tmpCredsFiles) (*PEMFileProvider, *PEMFileProvider, *PEMFileProvider, *PEMFileProvider, error) { + clientIdentityOptions := PEMFileProviderOptions{ + CertFile: tmpFiles.clientCertTmp.Name(), + KeyFile: tmpFiles.clientKeyTmp.Name(), + IdentityInterval: credRefreshingInterval, + } + clientIdentityProvider, err := NewPEMFileProvider(clientIdentityOptions) + if err != nil { + return nil, nil, nil, nil, err + } + clientRootOptions := PEMFileProviderOptions{ + TrustFile: tmpFiles.clientTrustTmp.Name(), + RootInterval: credRefreshingInterval, + } + clientRootProvider, err := NewPEMFileProvider(clientRootOptions) + if err != nil { + return nil, nil, nil, nil, err + } + serverIdentityOptions := PEMFileProviderOptions{ + CertFile: tmpFiles.serverCertTmp.Name(), + KeyFile: tmpFiles.serverKeyTmp.Name(), + IdentityInterval: credRefreshingInterval, + } + serverIdentityProvider, err := NewPEMFileProvider(serverIdentityOptions) + if err != nil { + return nil, nil, nil, nil, err + } + serverRootOptions := PEMFileProviderOptions{ + TrustFile: tmpFiles.serverTrustTmp.Name(), + RootInterval: credRefreshingInterval, + } + serverRootProvider, err := NewPEMFileProvider(serverRootOptions) + if err != nil { + return nil, nil, nil, nil, err + } + return clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, nil +} + +// In order to test advanced TLS provider features, we used temporary files to +// hold credential data, and copy the contents under testdata/ to these tmp +// files. +// Initially, we establish a good connection with providers watching contents +// from tmp files. +// Next, we change the identity certs that IdentityProvider is watching. Since +// the identity key is not changed, the IdentityProvider should ignore the +// update, and the connection should still be good. +// Then the the identity key is changed. This time IdentityProvider should pick +// up the update, and the connection should fail, due to the trust certs on the +// other side is not changed. +// Finally, the trust certs that other-side's RootProvider is watching get +// changed. The connection should go back to normal again. +func (s) TestPEMFileProviderEnd2End(t *testing.T) { + tmpFiles, err := createTmpFiles() + if err != nil { + t.Fatalf("createTmpFiles() failed, error: %v", err) + } + defer tmpFiles.removeFiles() + for _, test := range []struct { + desc string + certUpdateFunc func() + keyUpdateFunc func() + trustCertUpdateFunc func() + }{ + { + desc: "test the reloading feature for clientIdentityProvider and serverTrustProvider", + certUpdateFunc: func() { + err = copyFileContents(testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name(), err) + } + }, + keyUpdateFunc: func() { + err = copyFileContents(testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name(), err) + } + }, + trustCertUpdateFunc: func() { + err = copyFileContents(testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name(), err) + } + }, + }, + { + desc: "test the reloading feature for serverIdentityProvider and clientTrustProvider", + certUpdateFunc: func() { + err = copyFileContents(testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name(), err) + } + }, + keyUpdateFunc: func() { + err = copyFileContents(testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name(), err) + } + }, + trustCertUpdateFunc: func() { + err = copyFileContents(testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name()) + if err != nil { + t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name(), err) + } + }, + }, + } { + test := test + t.Run(test.desc, func(t *testing.T) { + if err := tmpFiles.copyCredsToTmpFiles(); err != nil { + t.Fatalf("tmpFiles.copyCredsToTmpFiles() failed, error: %v", err) + } + clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, err := createProviders(tmpFiles) + if err != nil { + t.Fatalf("createProviders(%v) failed, error: %v", tmpFiles, err) + } + defer clientIdentityProvider.Close() + defer clientRootProvider.Close() + defer serverIdentityProvider.Close() + defer serverRootProvider.Close() + // Start a server and create a client using advancedtls API with Provider. + serverOptions := &ServerOptions{ + IdentityOptions: IdentityCertificateOptions{ + IdentityProvider: serverIdentityProvider, + }, + RootOptions: RootCertificateOptions{ + RootProvider: serverRootProvider, + }, + RequireClientCert: true, + VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + VType: CertVerification, + } + serverTLSCreds, err := NewServerCreds(serverOptions) + if err != nil { + t.Fatalf("failed to create server creds: %v", err) + } + s := grpc.NewServer(grpc.Creds(serverTLSCreds)) + defer s.Stop() + lis, err := net.Listen("tcp", port) + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer lis.Close() + pb.RegisterGreeterServer(s, greeterServer{}) + go s.Serve(lis) + clientOptions := &ClientOptions{ + IdentityOptions: IdentityCertificateOptions{ + IdentityProvider: clientIdentityProvider, + }, + VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { + return &VerificationResults{}, nil + }, + RootOptions: RootCertificateOptions{ + RootProvider: clientRootProvider, + }, + VType: CertVerification, + } + clientTLSCreds, err := NewClientCreds(clientOptions) + if err != nil { + t.Fatalf("clientTLSCreds failed to create, error: %v", err) + } + + // At initialization, the connection should be good. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 1", clientTLSCreds, false) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + // Make the identity cert change, and wait 1 second for the provider to + // pick up the change. + test.certUpdateFunc() + time.Sleep(sleepInterval) + // The already-established connection should not be affected. + err = callAndVerify("rpc call 2", greetClient, false) + if err != nil { + t.Fatal(err) + } + // New connections should still be good, because the Provider didn't pick + // up the changes due to key-cert mismatch. + conn2, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 3", clientTLSCreds, false) + if err != nil { + t.Fatal(err) + } + defer conn2.Close() + // Make the identity key change, and wait 1 second for the provider to + // pick up the change. + test.keyUpdateFunc() + time.Sleep(sleepInterval) + // New connections should fail now, because the Provider picked the + // change, and *_cert_2.pem is not trusted by *_trust_cert_1.pem on the + // other side. + shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shortCancel() + conn3, greetClient, err := callAndVerifyWithClientConn(shortCtx, "rpc call 4", clientTLSCreds, true) + if err != nil { + t.Fatal(err) + } + defer conn3.Close() + // Make the trust cert change on the other side, and wait 1 second for + // the provider to pick up the change. + test.trustCertUpdateFunc() + time.Sleep(sleepInterval) + // New connections should be good, because the other side is using + // *_trust_cert_2.pem now. + conn4, greetClient, err := callAndVerifyWithClientConn(ctx, "rpc call 5", clientTLSCreds, false) + if err != nil { + t.Fatal(err) + } + defer conn4.Close() + }) + } +}