diff --git a/credentials/google/google.go b/credentials/google/google.go index 7f3e240e4..265d193c7 100644 --- a/credentials/google/google.go +++ b/credentials/google/google.go @@ -99,6 +99,15 @@ func (c *creds) PerRPCCredentials() credentials.PerRPCCredentials { return c.perRPCCreds } +var ( + newTLS = func() credentials.TransportCredentials { + return credentials.NewTLS(nil) + } + newALTS = func() credentials.TransportCredentials { + return alts.NewClientCreds(alts.DefaultClientOptions()) + } +) + // NewWithMode should make a copy of Bundle, and switch mode. Modifying the // existing Bundle may cause races. func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) { @@ -110,11 +119,11 @@ func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) { // Create transport credentials. switch mode { case internal.CredsBundleModeFallback: - newCreds.transportCreds = credentials.NewTLS(nil) + newCreds.transportCreds = newClusterTransportCreds(newTLS(), newALTS()) case internal.CredsBundleModeBackendFromBalancer, internal.CredsBundleModeBalancer: // Only the clients can use google default credentials, so we only need // to create new ALTS client creds here. - newCreds.transportCreds = alts.NewClientCreds(alts.DefaultClientOptions()) + newCreds.transportCreds = newALTS() default: return nil, fmt.Errorf("unsupported mode: %v", mode) } diff --git a/credentials/google/google_test.go b/credentials/google/google_test.go new file mode 100644 index 000000000..c20445811 --- /dev/null +++ b/credentials/google/google_test.go @@ -0,0 +1,132 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package google + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" + "google.golang.org/grpc/resolver" +) + +type testCreds struct { + credentials.TransportCredentials + typ string +} + +func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, &testAuthInfo{typ: c.typ}, nil +} + +func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, &testAuthInfo{typ: c.typ}, nil +} + +type testAuthInfo struct { + typ string +} + +func (t *testAuthInfo) AuthType() string { + return t.typ +} + +var ( + testTLS = &testCreds{typ: "tls"} + testALTS = &testCreds{typ: "alts"} + + contextWithHandshakeInfo = internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) +) + +func overrideNewCredsFuncs() func() { + oldNewTLS := newTLS + newTLS = func() credentials.TransportCredentials { + return testTLS + } + oldNewALTS := newALTS + newALTS = func() credentials.TransportCredentials { + return testALTS + } + return func() { + newTLS = oldNewTLS + newALTS = oldNewALTS + } +} + +// TestClientHandshakeBasedOnClusterName that by default (without switching +// modes), ClientHandshake does either tls or alts base on the cluster name in +// attributes. +func TestClientHandshakeBasedOnClusterName(t *testing.T) { + defer overrideNewCredsFuncs()() + for bundleTyp, tc := range map[string]credentials.Bundle{ + "defaultCreds": NewDefaultCredentials(), + "computeCreds": NewComputeEngineCredentials(), + } { + tests := []struct { + name string + ctx context.Context + wantTyp string + }{ + { + name: "no cluster name", + ctx: context.Background(), + wantTyp: "tls", + }, + { + name: "with non-CFE cluster name", + ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, "lalala").Attributes, + }), + // non-CFE backends should use alts. + wantTyp: "alts", + }, + { + name: "with CFE cluster name", + ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes, + }), + // CFE should use tls. + wantTyp: "tls", + }, + } + for _, tt := range tests { + t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { + _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + if gotType := info.AuthType(); gotType != tt.wantTyp { + t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) + } + + _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + // ServerHandshake should always do TLS. + if gotType := infoServer.AuthType(); gotType != "tls" { + t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls") + } + }) + } + } +} diff --git a/credentials/google/xds.go b/credentials/google/xds.go new file mode 100644 index 000000000..22997ce25 --- /dev/null +++ b/credentials/google/xds.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package google + +import ( + "context" + "net" + + "google.golang.org/grpc/credentials" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" +) + +const cfeClusterName = "google-cfe" + +// clusterTransportCreds is a combo of TLS + ALTS. +// +// On the client, ClientHandshake picks TLS or ALTS based on address attributes. +// - if attributes has cluster name +// - if cluster name is "google_cfe", use TLS +// - otherwise, use ALTS +// - else, do TLS +// +// On the server, ServerHandshake always does TLS. +type clusterTransportCreds struct { + tls credentials.TransportCredentials + alts credentials.TransportCredentials +} + +func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds { + return &clusterTransportCreds{ + tls: tls, + alts: alts, + } +} + +func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + chi := credentials.ClientHandshakeInfoFromContext(ctx) + if chi.Attributes == nil { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + cn, ok := xdsinternal.GetHandshakeClusterName(chi.Attributes) + if !ok || cn == cfeClusterName { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + // If attributes have cluster name, and cluster name is not cfe, it's a + // backend address, use ALTS. + return c.alts.ClientHandshake(ctx, authority, rawConn) +} + +func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return c.tls.ServerHandshake(conn) +} + +func (c *clusterTransportCreds) Info() credentials.ProtocolInfo { + // TODO: this always returns tls.Info now, because we don't have a cluster + // name to check when this method is called. This method doesn't affect + // anything important now. We may want to revisit this if it becomes more + // important later. + return c.tls.Info() +} + +func (c *clusterTransportCreds) Clone() credentials.TransportCredentials { + return &clusterTransportCreds{ + tls: c.tls.Clone(), + alts: c.alts.Clone(), + } +} + +func (c *clusterTransportCreds) OverrideServerName(s string) error { + if err := c.tls.OverrideServerName(s); err != nil { + return err + } + return c.alts.OverrideServerName(s) +} diff --git a/internal/credentials/xds/handshake_cluster.go b/internal/credentials/xds/handshake_cluster.go new file mode 100644 index 000000000..cb059bd66 --- /dev/null +++ b/internal/credentials/xds/handshake_cluster.go @@ -0,0 +1,42 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/resolver" +) + +// handshakeClusterNameKey is the type used as the key to store cluster name in +// the Attributes field of resolver.Address. +type handshakeClusterNameKey struct{} + +// SetHandshakeClusterName returns a copy of addr in which the Attributes field +// is updated with the cluster name. +func SetHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName) + return addr +} + +// GetHandshakeClusterName returns cluster name stored in attr. +func GetHandshakeClusterName(attr *attributes.Attributes) (string, bool) { + v := attr.Value(handshakeClusterNameKey{}) + name, ok := v.(string) + return name, ok +} diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 6d9b7a508..7fb31ab7a 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/client/load" @@ -369,3 +370,93 @@ func TestPickerUpdateAfterClose(t *testing.T) { case <-time.After(time.Millisecond * 10): } } + +// TestClusterNameInAddressAttributes covers the case that cluster name is +// attached to the subconn address attributes. +func TestClusterNameInAddressAttributes(t *testing.T) { + xdsC := fakeclient.NewClient() + oldNewXDSClient := newXDSClient + newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } + defer func() { newXDSClient = oldNewXDSClient }() + + builder := balancer.Get(clusterImplName) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: testBackendAddrs, + }, + BalancerConfig: &lbConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + cn, ok := xdsinternal.GetHandshakeClusterName(addrs1[0].Attributes) + if !ok || cn != testClusterName { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName) + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p1 := <-cc.NewPickerCh + const rpcCount = 20 + for i := 0; i < rpcCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if err != nil || !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + if gotSCSt.Done != nil { + gotSCSt.Done(balancer.DoneInfo{}) + } + } + + const testClusterName2 = "test-cluster-2" + var addr2 = resolver.Address{Addr: "2.2.2.2"} + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{addr2}, + }, + BalancerConfig: &lbConfig{ + Cluster: testClusterName2, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + addrs2 := <-cc.NewSubConnAddrsCh + if got, want := addrs2[0].Addr, addr2.Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + // New addresses should have the new cluster name. + cn2, ok := xdsinternal.GetHandshakeClusterName(addrs2[0].Attributes) + if !ok || cn2 != testClusterName2 { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2) + } +} diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 4435f9e65..0cc8d0d82 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -30,8 +30,10 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/internal/buffer" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/loadstore" xdsclient "google.golang.org/grpc/xds/internal/client" @@ -110,11 +112,13 @@ type clusterImplBalancer struct { config *lbConfig childLB balancer.Balancer cancelLoadReport func() - clusterName string edsServiceName string lrsServerName string loadWrapper *loadstore.Wrapper + clusterNameMu sync.Mutex + clusterName string + // childState/drops/requestCounter can only be accessed in run(). And run() // is the only goroutine that sends picker to the parent ClientConn. All // requests to update picker need to be sent to pickerUpdateCh. @@ -132,9 +136,11 @@ func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { // ClusterName is different, restart. ClusterName is from ClusterName and // EdsServiceName. - if cib.clusterName != newConfig.Cluster { + clusterName := cib.getClusterName() + if clusterName != newConfig.Cluster { updateLoadClusterAndService = true - cib.clusterName = newConfig.Cluster + cib.setClusterName(newConfig.Cluster) + clusterName = newConfig.Cluster } if cib.edsServiceName != newConfig.EDSServiceName { updateLoadClusterAndService = true @@ -149,7 +155,7 @@ func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { // On the other hand, this will almost never happen. Each LRS policy // shouldn't get updated config. The parent should do a graceful switch // when the clusterName or serviceName is changed. - cib.loadWrapper.UpdateClusterAndService(cib.clusterName, cib.edsServiceName) + cib.loadWrapper.UpdateClusterAndService(clusterName, cib.edsServiceName) } // Check if it's necessary to restart load report. @@ -305,6 +311,36 @@ func (cib *clusterImplBalancer) UpdateState(state balancer.State) { cib.pickerUpdateCh.Put(state) } +func (cib *clusterImplBalancer) setClusterName(n string) { + cib.clusterNameMu.Lock() + defer cib.clusterNameMu.Unlock() + cib.clusterName = n +} + +func (cib *clusterImplBalancer) getClusterName() string { + cib.clusterNameMu.Lock() + defer cib.clusterNameMu.Unlock() + return cib.clusterName +} + +func (cib *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + clusterName := cib.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + for i, addr := range addrs { + newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName) + } + return cib.ClientConn.NewSubConn(newAddrs, opts) +} + +func (cib *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + clusterName := cib.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + for i, addr := range addrs { + newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName) + } + cib.ClientConn.UpdateAddresses(sc, newAddrs) +} + type dropConfigs struct { drops []*dropper requestCounter *xdsclient.ServiceRequestsCounter diff --git a/xds/internal/balancer/edsbalancer/eds.go b/xds/internal/balancer/edsbalancer/eds.go index 423df7aed..de724701d 100644 --- a/xds/internal/balancer/edsbalancer/eds.go +++ b/xds/internal/balancer/edsbalancer/eds.go @@ -116,6 +116,9 @@ type edsBalancerImplInterface interface { // updateServiceRequestsConfig updates the service requests counter to the // one for the given service name. updateServiceRequestsConfig(serviceName string, max *uint32) + // updateClusterName updates the cluster name that will be attached to the + // address attributes. + updateClusterName(name string) // close closes the eds balancer. close() } @@ -250,6 +253,7 @@ func (x *edsBalancer) handleServiceConfigUpdate(config *EDSConfig) error { // This is OK for now, because we don't actually expect edsServiceName // to change. Fix this (a bigger change) will happen later. x.lsw.updateServiceName(x.edsServiceName) + x.edsImpl.updateClusterName(x.edsServiceName) } // Restart load reporting when the loadReportServer name has changed. diff --git a/xds/internal/balancer/edsbalancer/eds_impl.go b/xds/internal/balancer/edsbalancer/eds_impl.go index 5318a5342..94f643d33 100644 --- a/xds/internal/balancer/edsbalancer/eds_impl.go +++ b/xds/internal/balancer/edsbalancer/eds_impl.go @@ -23,6 +23,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" @@ -104,6 +105,9 @@ type edsBalancerImpl struct { innerState balancer.State // The state of the picker without drop support. serviceRequestsCounter *client.ServiceRequestsCounter serviceRequestCountMax uint32 + + clusterNameMu sync.Mutex + clusterName string } // newEDSBalancerImpl create a new edsBalancerImpl. @@ -444,6 +448,18 @@ func (edsImpl *edsBalancerImpl) updateServiceRequestsConfig(serviceName string, edsImpl.pickerMu.Unlock() } +func (edsImpl *edsBalancerImpl) updateClusterName(name string) { + edsImpl.clusterNameMu.Lock() + defer edsImpl.clusterNameMu.Unlock() + edsImpl.clusterName = name +} + +func (edsImpl *edsBalancerImpl) getClusterName() string { + edsImpl.clusterNameMu.Lock() + defer edsImpl.clusterNameMu.Unlock() + return edsImpl.clusterName +} + // updateState first handles priority, and then wraps picker in a drop picker // before forwarding the update. func (edsImpl *edsBalancerImpl) updateState(priority priorityType, s balancer.State) { @@ -479,8 +495,23 @@ type edsBalancerWrapperCC struct { } func (ebwcc *edsBalancerWrapperCC) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - return ebwcc.parent.newSubConn(ebwcc.priority, addrs, opts) + clusterName := ebwcc.parent.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + for i, addr := range addrs { + newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName) + } + return ebwcc.parent.newSubConn(ebwcc.priority, newAddrs, opts) } + +func (ebwcc *edsBalancerWrapperCC) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + clusterName := ebwcc.parent.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + for i, addr := range addrs { + newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName) + } + ebwcc.ClientConn.UpdateAddresses(sc, newAddrs) +} + func (ebwcc *edsBalancerWrapperCC) UpdateState(state balancer.State) { ebwcc.parent.enqueueChildBalancerStateUpdate(ebwcc.priority, state) } diff --git a/xds/internal/balancer/edsbalancer/eds_impl_test.go b/xds/internal/balancer/edsbalancer/eds_impl_test.go index ebaea13cc..79332dfe1 100644 --- a/xds/internal/balancer/edsbalancer/eds_impl_test.go +++ b/xds/internal/balancer/edsbalancer/eds_impl_test.go @@ -26,6 +26,7 @@ import ( corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" @@ -933,3 +934,73 @@ func (s) TestEDS_LoadReportDisabled(t *testing.T) { p1.Pick(balancer.PickInfo{}) } } + +// TestEDS_ClusterNameInAddressAttributes covers the case that cluster name is +// attached to the subconn address attributes. +func (s) TestEDS_ClusterNameInAddressAttributes(t *testing.T) { + cc := testutils.NewTestClientConn(t) + edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) + edsb.enqueueChildBalancerStateUpdate = edsb.updateState + + const clusterName1 = "cluster-name-1" + edsb.updateClusterName(clusterName1) + + // One locality with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + cn, ok := xdsinternal.GetHandshakeClusterName(addrs1[0].Attributes) + if !ok || cn != clusterName1 { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, clusterName1) + } + + sc1 := <-cc.NewSubConnCh + edsb.handleSubConnStateChange(sc1, connectivity.Connecting) + edsb.handleSubConnStateChange(sc1, connectivity.Ready) + + // Pick with only the first backend. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) + } + } + + // Change cluster name. + const clusterName2 = "cluster-name-2" + edsb.updateClusterName(clusterName2) + + // Change backend. + clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) + edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + + addrs2 := <-cc.NewSubConnAddrsCh + if got, want := addrs2[0].Addr, testEndpointAddrs[1]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + // New addresses should have the new cluster name. + cn2, ok := xdsinternal.GetHandshakeClusterName(addrs2[0].Attributes) + if !ok || cn2 != clusterName2 { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, clusterName1) + } + + sc2 := <-cc.NewSubConnCh + edsb.handleSubConnStateChange(sc2, connectivity.Connecting) + edsb.handleSubConnStateChange(sc2, connectivity.Ready) + + // Test roundrobin with two subconns. + p2 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p2.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) + } + } +} diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go index 5fe1f2ef6..65b74a1b8 100644 --- a/xds/internal/balancer/edsbalancer/eds_test.go +++ b/xds/internal/balancer/edsbalancer/eds_test.go @@ -117,6 +117,7 @@ type fakeEDSBalancer struct { edsUpdate *testutils.Channel serviceName *testutils.Channel serviceRequestMax *testutils.Channel + clusterName *testutils.Channel } func (f *fakeEDSBalancer) handleSubConnStateChange(sc balancer.SubConn, state connectivity.State) { @@ -138,6 +139,10 @@ func (f *fakeEDSBalancer) updateServiceRequestsConfig(serviceName string, max *u f.serviceRequestMax.Send(max) } +func (f *fakeEDSBalancer) updateClusterName(name string) { + f.clusterName.Send(name) +} + func (f *fakeEDSBalancer) close() {} func (f *fakeEDSBalancer) waitForChildPolicy(ctx context.Context, wantPolicy *loadBalancingConfig) error { @@ -207,6 +212,18 @@ func (f *fakeEDSBalancer) waitForCountMaxUpdate(ctx context.Context, want *uint3 return fmt.Errorf("got countMax %+v, want %+v", got, want) } +func (f *fakeEDSBalancer) waitForClusterNameUpdate(ctx context.Context, wantClusterName string) error { + val, err := f.clusterName.Receive(ctx) + if err != nil { + return err + } + gotServiceName := val.(string) + if gotServiceName != wantClusterName { + return fmt.Errorf("got clusterName %v, want %v", gotServiceName, wantClusterName) + } + return nil +} + func newFakeEDSBalancer(cc balancer.ClientConn) edsBalancerImplInterface { return &fakeEDSBalancer{ cc: cc, @@ -215,6 +232,7 @@ func newFakeEDSBalancer(cc balancer.ClientConn) edsBalancerImplInterface { edsUpdate: testutils.NewChannelWithSize(10), serviceName: testutils.NewChannelWithSize(10), serviceRequestMax: testutils.NewChannelWithSize(10), + clusterName: testutils.NewChannelWithSize(10), } } @@ -657,6 +675,59 @@ func (s) TestCounterUpdate(t *testing.T) { } } +// TestClusterNameUpdateInAddressAttributes verifies that cluster name update in +// edsImpl is triggered with the update from a new service config. +func (s) TestClusterNameUpdateInAddressAttributes(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(edsName) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", edsName) + } + defer edsB.Close() + + // Update should trigger counter update with provided service name. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &EDSConfig{ + EDSServiceName: "foobar-1", + }, + }); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotCluster, err := xdsC.WaitForWatchEDS(ctx) + if err != nil || gotCluster != "foobar-1" { + t.Fatalf("unexpected EDS watch: %v, %v", gotCluster, err) + } + edsI := edsB.(*edsBalancer).edsImpl.(*fakeEDSBalancer) + if err := edsI.waitForClusterNameUpdate(ctx, "foobar-1"); err != nil { + t.Fatal(err) + } + + // Update should trigger counter update with provided service name. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &EDSConfig{ + EDSServiceName: "foobar-2", + }, + }); err != nil { + t.Fatal(err) + } + if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { + t.Fatalf("failed to wait for EDS cancel: %v", err) + } + gotCluster2, err := xdsC.WaitForWatchEDS(ctx) + if err != nil || gotCluster2 != "foobar-2" { + t.Fatalf("unexpected EDS watch: %v, %v", gotCluster2, err) + } + if err := edsI.waitForClusterNameUpdate(ctx, "foobar-2"); err != nil { + t.Fatal(err) + } +} + func (s) TestBalancerConfigParsing(t *testing.T) { const testEDSName = "eds.service" var testLRSName = "lrs.server"