From 17e2cbe88713844032125f97f1174fcb2fe7faf0 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Mon, 14 Dec 2020 09:00:45 -0800 Subject: [PATCH] credentials/xds: ServerHandshake() implementation (#4089) --- credentials/xds/xds.go | 167 +++++++--- credentials/xds/xds_client_test.go | 87 +++-- credentials/xds/xds_server_test.go | 492 +++++++++++++++++++++++++++++ go.sum | 1 + 4 files changed, 680 insertions(+), 67 deletions(-) create mode 100644 credentials/xds/xds_server_test.go diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 2cc52ce6b..666da9918 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -33,6 +33,7 @@ import ( "fmt" "net" "sync" + "time" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials" @@ -50,9 +51,9 @@ func init() { // credentials implementation. type ClientOptions struct { // FallbackCreds specifies the fallback credentials to be used when either - // the `xds` scheme is not used in the user's dial target or when the xDS - // server does not return any security configuration. Attempts to create - // client credentials without a fallback credentials will fail. + // the `xds` scheme is not used in the user's dial target or when the + // management server does not return any security configuration. Attempts to + // create client credentials without fallback credentials will fail. FallbackCreds credentials.TransportCredentials } @@ -68,6 +69,27 @@ func NewClientCredentials(opts ClientOptions) (credentials.TransportCredentials, }, nil } +// ServerOptions contains parameters to configure a new server-side xDS +// credentials implementation. +type ServerOptions struct { + // FallbackCreds specifies the fallback credentials to be used when the + // management server does not return any security configuration. Attempts to + // create server credentials without fallback credentials will fail. + FallbackCreds credentials.TransportCredentials +} + +// NewServerCredentials returns a new server-side transport credentials +// implementation which uses xDS APIs to fetch its security configuration. +func NewServerCredentials(opts ServerOptions) (credentials.TransportCredentials, error) { + if opts.FallbackCreds == nil { + return nil, errors.New("missing fallback credentials") + } + return &credsImpl{ + isClient: false, + fallback: opts.FallbackCreds, + }, nil +} + // credsImpl is an implementation of the credentials.TransportCredentials // interface which uses xDS APIs to fetch its security configuration. type credsImpl struct { @@ -98,11 +120,15 @@ func getHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { // responsible for populating these fields. // // Safe for concurrent access. +// +// TODO(easwars): Move this type and any other non-user functionality to an +// internal package. type HandshakeInfo struct { - mu sync.Mutex - rootProvider certprovider.Provider - identityProvider certprovider.Provider - acceptedSANs map[string]bool // Only on the client side. + mu sync.Mutex + rootProvider certprovider.Provider + identityProvider certprovider.Provider + acceptedSANs map[string]bool // Only on the client side. + requireClientCert bool // Only on server side. } // SetRootCertProvider updates the root certificate provider. @@ -129,6 +155,14 @@ func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { hi.mu.Unlock() } +// SetRequireClientCert updates whether a client cert is required during the +// ServerHandshake(). A value of true indicates that we are performing mTLS. +func (hi *HandshakeInfo) SetRequireClientCert(require bool) { + hi.mu.Lock() + hi.requireClientCert = require + hi.mu.Unlock() +} + // UseFallbackCreds returns true when fallback credentials are to be used based // on the contents of the HandshakeInfo. func (hi *HandshakeInfo) UseFallbackCreds() bool { @@ -141,27 +175,13 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { return hi.identityProvider == nil && hi.rootProvider == nil } -func (hi *HandshakeInfo) validate(isClient bool) error { +func (hi *HandshakeInfo) makeClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { hi.mu.Lock() - defer hi.mu.Unlock() - // On the client side, rootProvider is mandatory. IdentityProvider is // optional based on whether the client is doing TLS or mTLS. - if isClient && hi.rootProvider == nil { - return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server") + if hi.rootProvider == nil { + return nil, errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server") } - - // On the server side, identityProvider is mandatory. RootProvider is - // optional based on whether the server is doing TLS or mTLS. - if !isClient && hi.identityProvider == nil { - return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server") - } - - return nil -} - -func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) { - hi.mu.Lock() // Since the call to KeyMaterial() can block, we read the providers under // the lock but call the actual function after releasing the lock. rootProv, idProv := hi.rootProvider, hi.identityProvider @@ -173,13 +193,13 @@ func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) // includes hostname verification) or none. We are forced to go with the // latter and perform the normal cert validation ourselves. cfg := &tls.Config{InsecureSkipVerify: true} - if rootProv != nil { - km, err := rootProv.KeyMaterial(ctx) - if err != nil { - return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err) - } - cfg.RootCAs = km.Roots + + km, err := rootProv.KeyMaterial(ctx) + if err != nil { + return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err) } + cfg.RootCAs = km.Roots + if idProv != nil { km, err := idProv.KeyMaterial(ctx) if err != nil { @@ -190,6 +210,39 @@ func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) return cfg, nil } +func (hi *HandshakeInfo) makeServerSideTLSConfig(ctx context.Context) (*tls.Config, error) { + cfg := &tls.Config{ClientAuth: tls.NoClientCert} + hi.mu.Lock() + // On the server side, identityProvider is mandatory. RootProvider is + // optional based on whether the server is doing TLS or mTLS. + if hi.identityProvider == nil { + return nil, errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server") + } + // Since the call to KeyMaterial() can block, we read the providers under + // the lock but call the actual function after releasing the lock. + rootProv, idProv := hi.rootProvider, hi.identityProvider + if hi.requireClientCert { + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + hi.mu.Unlock() + + // identityProvider is mandatory on the server side. + km, err := idProv.KeyMaterial(ctx) + if err != nil { + return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %v", err) + } + cfg.Certificates = km.Certs + + if rootProv != nil { + km, err := rootProv.KeyMaterial(ctx) + if err != nil { + return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err) + } + cfg.ClientCAs = km.Roots + } + return cfg, nil +} + func (hi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { if len(hi.acceptedSANs) == 0 { // An empty list of acceptedSANs means "accept everything". @@ -265,9 +318,6 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo if hi.UseFallbackCreds() { return c.fallback.ClientHandshake(ctx, authority, rawConn) } - if err := hi.validate(c.isClient); err != nil { - return nil, nil, err - } // We build the tls.Config with the following values // 1. Root certificate as returned by the root provider. @@ -281,7 +331,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo // 4. Key usage to match whether client/server usage. // 5. A `VerifyPeerCertificate` function which performs normal peer // cert verification using configured roots, and the custom SAN checks. - cfg, err := hi.makeTLSConfig(ctx) + cfg, err := hi.makeClientSideTLSConfig(ctx) if err != nil { return nil, nil, err } @@ -349,12 +399,55 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo } // ServerHandshake performs the TLS handshake on the server-side. -func (c *credsImpl) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { +func (c *credsImpl) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { if c.isClient { return nil, nil, errors.New("ServerHandshake is not supported for client credentials") } - // TODO(easwars): Implement along with server side xDS implementation. - return nil, nil, errors.New("not implemented") + + // An xds-enabled gRPC server wraps the underlying raw net.Conn in a type + // that provides a way to retrieve `HandshakeInfo`, which contains the + // certificate providers to be used during the handshake. If the net.Conn + // passed to this function does not implement this interface, or if the + // `HandshakeInfo` does not contain the information we are looking for, we + // delegate the handshake to the fallback credentials. + hiConn, ok := rawConn.(interface{ XDSHandshakeInfo() *HandshakeInfo }) + if !ok { + return c.fallback.ServerHandshake(rawConn) + } + hi := hiConn.XDSHandshakeInfo() + if hi.UseFallbackCreds() { + return c.fallback.ServerHandshake(rawConn) + } + + // An xds-enabled gRPC server is expected to wrap the underlying raw + // net.Conn in a type which provides a way to retrieve the deadline set on + // it. If we cannot retrieve the deadline here, we fail (by setting deadline + // to time.Now()), instead of using a default deadline and possibly taking + // longer to eventually fail. + deadline := time.Now() + if dConn, ok := rawConn.(interface{ GetDeadline() time.Time }); ok { + deadline = dConn.GetDeadline() + } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + cfg, err := hi.makeServerSideTLSConfig(ctx) + if err != nil { + return nil, nil, err + } + + conn := tls.Server(rawConn, cfg) + if err := conn.Handshake(); err != nil { + conn.Close() + return nil, nil, err + } + info := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + } + info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState()) + return credinternal.WrapSyscallConn(rawConn, conn), info, nil } // Info provides the ProtocolInfo of this TransportCredentials. diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 07bc48b3f..f22579c6b 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -40,9 +40,10 @@ import ( ) const ( - defaultTestTimeout = 1 * time.Second - defaultTestCertSAN = "*.test.example.com" - authority = "authority" + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + defaultTestCertSAN = "*.test.example.com" + authority = "authority" ) type s struct { @@ -133,17 +134,6 @@ func (ts *testServer) stop() { ts.lis.Close() } -// A handshake function which simulates a handshake timeout. Tests usually pass -// `defaultTestTimeout` to the ClientHandshake() method. This function just -// hangs around for twice that duration, thus making sure that the context -// passes to the credentials code times out. -func testServerTLSHandshakeTimeout(_ net.Conn) handshakeResult { - ctx, cancel := context.WithTimeout(context.Background(), 2*defaultTestTimeout) - <-ctx.Done() - cancel() - return handshakeResult{err: ctx.Err()} -} - // A handshake function which simulates a successful handshake without client // authentication (server does not request for client certificate during the // handshake here). @@ -239,7 +229,7 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert // compareAuthInfo compares the AuthInfo received on the client side after a // successful handshake with the authInfo available on the testServer. -func compareAuthInfo(ts *testServer, ai credentials.AuthInfo) error { +func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error { if ai.AuthType() != "tls" { return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls") } @@ -251,8 +241,6 @@ func compareAuthInfo(ts *testServer, ai credentials.AuthInfo) error { // Read the handshake result from the testServer which contains the TLS // connection state and compare it with the one received on the client-side. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { return fmt.Errorf("testServer failed to return handshake result: %v", err) @@ -341,7 +329,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) - if _, _, err := creds.ClientHandshake(ctx, authority, nil); !strings.Contains(err.Error(), test.wantErr) { + if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) } }) @@ -410,13 +398,59 @@ func (s) TestClientCredsSuccess(t *testing.T) { if err != nil { t.Fatalf("ClientHandshake() returned failed: %q", err) } - if err := compareAuthInfo(ts, ai); err != nil { + if err := compareAuthInfo(ctx, ts, ai); err != nil { t.Fatal(err) } }) } } +func (s) TestClientCredsHandshakeTimeout(t *testing.T) { + clientDone := make(chan struct{}) + // A handshake function which simulates a handshake timeout from the + // server-side by simply blocking on the client-side handshake to timeout + // and not writing any handshake data. + hErr := errors.New("server handshake error") + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + <-clientDone + return handshakeResult{err: hErr} + }) + defer ts.stop() + + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + conn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer conn.Close() + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN) + if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { + t.Fatal("ClientHandshake() succeeded when expected to timeout") + } + close(clientDone) + + // Read the handshake result from the testServer and make sure the expected + // error is returned. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != hErr { + t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr) + } +} + // TestClientCredsHandshakeFailure verifies different handshake failure cases. func (s) TestClientCredsHandshakeFailure(t *testing.T) { tests := []struct { @@ -433,13 +467,6 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) { san: defaultTestCertSAN, wantErr: "x509: certificate signed by unknown authority", }, - { - desc: "handshake times out", - handshakeFunc: testServerTLSHandshakeTimeout, - rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), - san: defaultTestCertSAN, - wantErr: "context deadline exceeded", - }, { desc: "SAN mismatch", handshakeFunc: testServerTLSHandshake, @@ -534,13 +561,13 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { if err != nil { t.Fatalf("ClientHandshake() returned failed: %q", err) } - if err := compareAuthInfo(ts, ai); err != nil { + if err := compareAuthInfo(ctx, ts, ai); err != nil { t.Fatal(err) } } -// TestClone verifies the Clone() method. -func (s) TestClone(t *testing.T) { +// TestClientClone verifies the Clone() method on client credentials. +func (s) TestClientClone(t *testing.T) { opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} orig, err := NewClientCredentials(opts) if err != nil { @@ -549,7 +576,7 @@ func (s) TestClone(t *testing.T) { // The credsImpl does not have any exported fields, and it does not make // sense to use any cmp options to look deep into. So, all we make sure here - // is that the cloned object points to a different locaiton in memory. + // is that the cloned object points to a different location in memory. if clone := orig.Clone(); clone == orig { t.Fatal("return value from Clone() doesn't point to new credentials instance") } diff --git a/credentials/xds/xds_server_test.go b/credentials/xds/xds_server_test.go new file mode 100644 index 000000000..b9c62dbd5 --- /dev/null +++ b/credentials/xds/xds_server_test.go @@ -0,0 +1,492 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/testdata" +) + +func makeClientTLSConfig(t *testing.T, mTLS bool) *tls.Config { + t.Helper() + + pemData, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(pemData) + + var certs []tls.Certificate + if mTLS { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) + if err != nil { + t.Fatal(err) + } + certs = append(certs, cert) + } + + return &tls.Config{ + Certificates: certs, + RootCAs: roots, + ServerName: "*.test.example.com", + // Setting this to true completely turns off the certificate validation + // on the client side. So, the client side handshake always seems to + // succeed. But if we want to turn this ON, we will need to generate + // certificates which work with localhost, or supply a custom + // verification function. So, the server credentials tests will rely + // solely on the success/failure of the server-side handshake. + InsecureSkipVerify: true, + } +} + +// Helper function to create a real TLS server credentials which is used as +// fallback credentials from multiple tests. +func makeFallbackServerCreds(t *testing.T) credentials.TransportCredentials { + t.Helper() + + creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatal(err) + } + return creds +} + +type errorCreds struct { + credentials.TransportCredentials +} + +// TestServerCredsWithoutFallback verifies that the call to +// NewServerCredentials() fails when no fallback is specified. +func (s) TestServerCredsWithoutFallback(t *testing.T) { + if _, err := NewServerCredentials(ServerOptions{}); err == nil { + t.Fatal("NewServerCredentials() succeeded without specifying fallback") + } +} + +type wrapperConn struct { + net.Conn + xdsHI *HandshakeInfo + deadline time.Time +} + +func (wc *wrapperConn) XDSHandshakeInfo() *HandshakeInfo { + return wc.xdsHI +} + +func (wc *wrapperConn) GetDeadline() time.Time { + return wc.deadline +} + +func newWrappedConn(conn net.Conn, xdsHI *HandshakeInfo, deadline time.Time) *wrapperConn { + return &wrapperConn{Conn: conn, xdsHI: xdsHI, deadline: deadline} +} + +// TestServerCredsInvalidHandshakeInfo verifies scenarios where the passed in +// HandshakeInfo is invalid because it does not contain the expected certificate +// providers. +func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + info := NewHandshakeInfo(&fakeProvider{}, nil) + conn := newWrappedConn(nil, info, time.Time{}) + if _, _, err := creds.ServerHandshake(conn); err == nil { + t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo") + } +} + +// TestServerCredsProviderFailure verifies the cases where an expected +// certificate provider is missing in the HandshakeInfo value in the context. +func (s) TestServerCredsProviderFailure(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + tests := []struct { + desc string + rootProvider certprovider.Provider + identityProvider certprovider.Provider + wantErr string + }{ + { + desc: "erroring identity provider", + identityProvider: &fakeProvider{err: errors.New("identity provider error")}, + wantErr: "identity provider error", + }, + { + desc: "erroring root provider", + identityProvider: &fakeProvider{km: &certprovider.KeyMaterial{}}, + rootProvider: &fakeProvider{err: errors.New("root provider error")}, + wantErr: "root provider error", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + info := NewHandshakeInfo(test.rootProvider, test.identityProvider) + conn := newWrappedConn(nil, info, time.Time{}) + if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr) + } + }) + } +} + +// TestServerCredsHandshakeTimeout verifies the case where the client does not +// send required handshake data before the deadline set on the net.Conn passed +// to ServerHandshake(). +func (s) TestServerCredsHandshakeTimeout(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // Create a test server which uses the xDS server credentials created above + // to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + hi := NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem")) + hi.SetRequireClientCert(true) + + // Create a wrapped conn which can return the HandshakeInfo created + // above with a very small deadline. + d := time.Now().Add(defaultTestShortTimeout) + rawConn.SetDeadline(d) + conn := newWrappedConn(rawConn, hi, d) + + // ServerHandshake() on the xDS credentials is expected to fail. + if _, _, err := creds.ServerHandshake(conn); err == nil { + return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to timeout")} + } + return handshakeResult{} + }) + defer ts.stop() + + // Dial the test server, but don't trigger the TLS handshake. This will + // cause ServerHandshake() to fail. + rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + + // Read handshake result from the testServer and expect a failure result. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } +} + +// TestServerCredsHandshakeFailure verifies the case where the server-side +// credentials uses a root certificate which does not match the certificate +// presented by the client, and hence the handshake must fail. +func (s) TestServerCredsHandshakeFailure(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // Create a test server which uses the xDS server credentials created above + // to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + // Create a HandshakeInfo which has a root provider which does not match + // the certificate sent by the client. + hi := NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem")) + hi.SetRequireClientCert(true) + + // Create a wrapped conn which can return the HandshakeInfo and + // configured deadline to the xDS credentials' ServerHandshake() + // method. + conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) + + // ServerHandshake() on the xDS credentials is expected to fail. + if _, _, err := creds.ServerHandshake(conn); err == nil { + return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} + } + return handshakeResult{} + }) + defer ts.stop() + + // Dial the test server, and trigger the TLS handshake. + rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) + tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) + if err := tlsConn.Handshake(); err != nil { + t.Fatal(err) + } + + // Read handshake result from the testServer which will return an error if + // the handshake succeeded. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } +} + +// TestServerCredsHandshakeSuccess verifies success handshake cases. +func (s) TestServerCredsHandshakeSuccess(t *testing.T) { + tests := []struct { + desc string + fallbackCreds credentials.TransportCredentials + rootProvider certprovider.Provider + identityProvider certprovider.Provider + requireClientCert bool + }{ + { + desc: "fallback", + fallbackCreds: makeFallbackServerCreds(t), + }, + { + desc: "TLS", + fallbackCreds: &errorCreds{}, + identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), + }, + { + desc: "mTLS", + fallbackCreds: &errorCreds{}, + identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), + rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"), + requireClientCert: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // Create an xDS server credentials. + opts := ServerOptions{FallbackCreds: test.fallbackCreds} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // Create a test server which uses the xDS server credentials + // created above to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + // Create a HandshakeInfo with information from the test table. + hi := NewHandshakeInfo(test.rootProvider, test.identityProvider) + hi.SetRequireClientCert(test.requireClientCert) + + // Create a wrapped conn which can return the HandshakeInfo and + // configured deadline to the xDS credentials' ServerHandshake() + // method. + conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) + + // Invoke the ServerHandshake() method on the xDS credentials + // and make some sanity checks before pushing the result for + // inspection by the main test body. + _, ai, err := creds.ServerHandshake(conn) + if err != nil { + return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} + } + if ai.AuthType() != "tls" { + return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} + } + info, ok := ai.(credentials.TLSInfo) + if !ok { + return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} + } + return handshakeResult{connState: info.State} + }) + defer ts.stop() + + // Dial the test server, and trigger the TLS handshake. + rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, test.requireClientCert)) + tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) + if err := tlsConn.Handshake(); err != nil { + t.Fatal(err) + } + + // Read the handshake result from the testServer which contains the + // TLS connection state on the server-side and compare it with the + // one received on the client-side. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } + + // AuthInfo contains a variety of information. We only verify a + // subset here. This is the same subset which is verified in TLS + // credentials tests. + if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { + t.Fatal(err) + } + }) + } +} + +func (s) TestServerCredsProviderSwitch(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // The first time the handshake function is invoked, it returns a + // HandshakeInfo which is expected to fail. Further invocations return a + // HandshakeInfo which is expected to succeed. + cnt := 0 + // Create a test server which uses the xDS server credentials created above + // to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + cnt++ + var hi *HandshakeInfo + if cnt == 1 { + // Create a HandshakeInfo which has a root provider which does not match + // the certificate sent by the client. + hi = NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem")) + hi.SetRequireClientCert(true) + + // Create a wrapped conn which can return the HandshakeInfo and + // configured deadline to the xDS credentials' ServerHandshake() + // method. + conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) + + // ServerHandshake() on the xDS credentials is expected to fail. + if _, _, err := creds.ServerHandshake(conn); err == nil { + return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} + } + return handshakeResult{} + } + + hi = NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) + hi.SetRequireClientCert(true) + + // Create a wrapped conn which can return the HandshakeInfo and + // configured deadline to the xDS credentials' ServerHandshake() + // method. + conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) + + // Invoke the ServerHandshake() method on the xDS credentials + // and make some sanity checks before pushing the result for + // inspection by the main test body. + _, ai, err := creds.ServerHandshake(conn) + if err != nil { + return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} + } + if ai.AuthType() != "tls" { + return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} + } + info, ok := ai.(credentials.TLSInfo) + if !ok { + return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} + } + return handshakeResult{connState: info.State} + }) + defer ts.stop() + + for i := 0; i < 5; i++ { + // Dial the test server, and trigger the TLS handshake. + rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) + tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) + if err := tlsConn.Handshake(); err != nil { + t.Fatal(err) + } + + // Read the handshake result from the testServer which contains the + // TLS connection state on the server-side and compare it with the + // one received on the client-side. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } + if i == 0 { + // We expect the first handshake to fail. So, we skip checks which + // compare connection state. + continue + } + // AuthInfo contains a variety of information. We only verify a + // subset here. This is the same subset which is verified in TLS + // credentials tests. + if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { + t.Fatal(err) + } + } +} + +// TestServerClone verifies the Clone() method on client credentials. +func (s) TestServerClone(t *testing.T) { + opts := ServerOptions{FallbackCreds: makeFallbackServerCreds(t)} + orig, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // The credsImpl does not have any exported fields, and it does not make + // sense to use any cmp options to look deep into. So, all we make sure here + // is that the cloned object points to a different location in memory. + if clone := orig.Clone(); clone == orig { + t.Fatal("return value from Clone() doesn't point to new credentials instance") + } +} diff --git a/go.sum b/go.sum index e437e1300..77ee70b44 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,7 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=