From 9dacc45447192555e445485f173299cc598ebd7f Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Mon, 13 May 2019 17:31:24 -0700 Subject: [PATCH] xds: implement ADS request and response handling in standard mode (#5532) Summary of PR: - XdsLbState now assumes standard mode only. - Will not send CDS request. A EDS request will be sent at the constructor of `AdsStream`. - Added a method to `LocalityStore` - `void updateLocalityStore(Map localityInfoMap);` - When a EDS response is received. `LocalityStore.updateLocalityStore()` will be called. - `LocalityStoreImpl` maintains a map `Map localityMap`. - `LocalityStoreImpl.updateLocalityStore()` will create a child balancer for each locality, with a `ChildHelper`. Then each child balancer will call `handleResolvedAddresses()`. - `LocalityStoreImpl.updateLocalityStore()` will update `childPickers`. - `ChildHelper.updateBalancingState()` will update `childPickers` and then delegate to parent `helper.updateBalancingState()`. - `XdsLbState.handleSubchannelState()` will delegate to `childBalancer.handleSubchannelState()` where the subchannel belongs to the childBalancer's locality. --- .../main/java/io/grpc/xds/LocalityStore.java | 356 ++++++++++++++++++ xds/src/main/java/io/grpc/xds/XdsComms.java | 243 ++++++++++-- xds/src/main/java/io/grpc/xds/XdsLbState.java | 77 +--- .../java/io/grpc/xds/XdsLoadBalancer.java | 42 +-- .../io/grpc/xds/XdsLoadBalancerProvider.java | 3 +- .../java/io/grpc/xds/FallbackManagerTest.java | 46 ++- .../java/io/grpc/xds/LocalityStoreTest.java | 286 ++++++++++++++ .../test/java/io/grpc/xds/XdsCommsTest.java | 200 +++++++++- .../test/java/io/grpc/xds/XdsLbStateTest.java | 226 +++++++++-- .../java/io/grpc/xds/XdsLoadBalancerTest.java | 37 +- 10 files changed, 1335 insertions(+), 181 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/LocalityStore.java create mode 100644 xds/src/test/java/io/grpc/xds/LocalityStoreTest.java diff --git a/xds/src/main/java/io/grpc/xds/LocalityStore.java b/xds/src/main/java/io/grpc/xds/LocalityStore.java new file mode 100644 index 0000000000..8f48d06706 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/LocalityStore.java @@ -0,0 +1,356 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Status; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; +import io.grpc.xds.XdsComms.Locality; +import io.grpc.xds.XdsComms.LocalityInfo; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Manages EAG and locality info for a collection of subchannels, not including subchannels + * created by the fallback balancer. + */ +// Must be accessed/run in SynchronizedContext. +interface LocalityStore { + + boolean hasReadyBackends(); + + boolean hasNonDropBackends(); + + void reset(); + + void updateLocalityStore(Map localityInfoMap); + + void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState); + + final class LocalityStoreImpl implements LocalityStore { + private static final String ROUND_ROBIN = "round_robin"; + + private final Helper helper; + private final PickerFactory pickerFactory; + + private Map localityMap = new HashMap<>(); + private LoadBalancerProvider loadBalancerProvider; + private ConnectivityState overallState; + + LocalityStoreImpl(Helper helper, LoadBalancerRegistry lbRegistry) { + this(helper, pickerFactoryImpl, lbRegistry); + } + + @VisibleForTesting + LocalityStoreImpl(Helper helper, PickerFactory pickerFactory, LoadBalancerRegistry lbRegistry) { + this.helper = helper; + this.pickerFactory = pickerFactory; + loadBalancerProvider = checkNotNull( + lbRegistry.getProvider(ROUND_ROBIN), + "Unable to find '%s' LoadBalancer", ROUND_ROBIN); + } + + @VisibleForTesting // Introduced for testing only. + interface PickerFactory { + SubchannelPicker picker(List childPickers); + } + + private static final PickerFactory pickerFactoryImpl = + new PickerFactory() { + @Override + public SubchannelPicker picker(List childPickers) { + return new InterLocalityPicker(childPickers); + } + }; + + @Override + public boolean hasReadyBackends() { + return overallState == READY; + } + + @Override + public boolean hasNonDropBackends() { + // TODO: impl + return false; + } + + // This is triggered by xdsLoadbalancer.handleSubchannelState + @Override + public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + // delegate to the childBalancer who manages this subchannel + for (LocalityLbInfo localityLbInfo : localityMap.values()) { + // This will probably trigger childHelper.updateBalancingState + localityLbInfo.childBalancer.handleSubchannelState(subchannel, newState); + } + } + + @Override + public void reset() { + for (LocalityLbInfo localityLbInfo : localityMap.values()) { + localityLbInfo.shutdown(); + } + localityMap = new HashMap<>(); + } + + // This is triggered by EDS response. + @Override + public void updateLocalityStore(Map localityInfoMap) { + Set oldLocalities = localityMap.keySet(); + Set newLocalities = localityInfoMap.keySet(); + + Iterator iterator = oldLocalities.iterator(); + while (iterator.hasNext()) { + Locality oldLocality = iterator.next(); + if (!newLocalities.contains(oldLocality)) { + // No graceful transition until a high-level lb graceful transition design is available. + localityMap.get(oldLocality).shutdown(); + iterator.remove(); + if (localityMap.isEmpty()) { + // down-size the map + localityMap = new HashMap<>(); + } + } + } + + ConnectivityState newState = null; + List childPickers = new ArrayList<>(newLocalities.size()); + for (Locality newLocality : newLocalities) { + + // Assuming standard mode only (EDS response with a list of endpoints) for now + List newEags = localityInfoMap.get(newLocality).eags; + LocalityLbInfo localityLbInfo; + ChildHelper childHelper; + if (oldLocalities.contains(newLocality)) { + LocalityLbInfo oldLocalityLbInfo + = localityMap.get(newLocality); + childHelper = oldLocalityLbInfo.childHelper; + localityLbInfo = new LocalityLbInfo( + localityInfoMap.get(newLocality).localityWeight, + oldLocalityLbInfo.childBalancer, + childHelper); + } else { + childHelper = new ChildHelper(newLocality); + localityLbInfo = + new LocalityLbInfo( + localityInfoMap.get(newLocality).localityWeight, + loadBalancerProvider.newLoadBalancer(childHelper), + childHelper); + localityMap.put(newLocality, localityLbInfo); + } + // TODO: put endPointWeights into attributes for WRR. + localityLbInfo.childBalancer + .handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(newEags).build()); + + if (localityLbInfo.childHelper.currentChildState == READY) { + childPickers.add( + new WeightedChildPicker( + localityInfoMap.get(newLocality).localityWeight, + localityLbInfo.childHelper.currentChildPicker)); + } + newState = aggregateState(newState, childHelper.currentChildState); + } + + updatePicker(newState, childPickers); + + } + + private static final class ErrorPicker extends SubchannelPicker { + + final Status error; + + ErrorPicker(Status error) { + this.error = checkNotNull(error, "error"); + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withError(error); + } + } + + private static final SubchannelPicker BUFFER_PICKER = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withNoResult(); + } + + @Override + public String toString() { + return "BUFFER_PICKER"; + } + }; + + private static ConnectivityState aggregateState( + ConnectivityState overallState, ConnectivityState childState) { + if (overallState == null) { + return childState; + } + if (overallState == READY || childState == READY) { + return READY; + } + if (overallState == CONNECTING || childState == CONNECTING) { + return CONNECTING; + } + if (overallState == IDLE || childState == IDLE) { + return IDLE; + } + return overallState; + } + + private void updateChildState( + Locality locality, ConnectivityState newChildState, SubchannelPicker newChildPicker) { + if (!localityMap.containsKey(locality)) { + return; + } + + List childPickers = new ArrayList<>(); + + ConnectivityState overallState = null; + for (Locality l : localityMap.keySet()) { + LocalityLbInfo localityLbInfo = localityMap.get(l); + ConnectivityState childState; + SubchannelPicker childPicker; + if (l.equals(locality)) { + childState = newChildState; + childPicker = newChildPicker; + } else { + childState = localityLbInfo.childHelper.currentChildState; + childPicker = localityLbInfo.childHelper.currentChildPicker; + } + overallState = aggregateState(overallState, childState); + + if (READY == childState) { + childPickers.add( + new WeightedChildPicker(localityLbInfo.localityWeight, childPicker)); + } + } + + updatePicker(overallState, childPickers); + this.overallState = overallState; + } + + private void updatePicker(ConnectivityState state, List childPickers) { + childPickers = Collections.unmodifiableList(childPickers); + SubchannelPicker picker; + if (childPickers.isEmpty()) { + if (state == TRANSIENT_FAILURE) { + picker = new ErrorPicker(Status.UNAVAILABLE); // TODO: more details in status + } else { + picker = BUFFER_PICKER; + } + } else { + picker = pickerFactory.picker(childPickers); + } + if (state != null) { + helper.getChannelLogger().log( + ChannelLogLevel.INFO, "Picker updated - state: {0}, picker: {1}", state, picker); + helper.updateBalancingState(state, picker); + } + } + + /** + * State of a single Locality. + */ + static final class LocalityLbInfo { + + final int localityWeight; + final LoadBalancer childBalancer; + final ChildHelper childHelper; + + LocalityLbInfo( + int localityWeight, LoadBalancer childBalancer, ChildHelper childHelper) { + checkArgument(localityWeight >= 0, "localityWeight must be non-negative"); + this.localityWeight = localityWeight; + this.childBalancer = checkNotNull(childBalancer, "childBalancer"); + this.childHelper = checkNotNull(childHelper, "childHelper"); + } + + void shutdown() { + childBalancer.shutdown(); + } + } + + class ChildHelper extends ForwardingLoadBalancerHelper { + + private final Locality locality; + + private SubchannelPicker currentChildPicker = BUFFER_PICKER; + private ConnectivityState currentChildState = null; + + ChildHelper(Locality locality) { + this.locality = checkNotNull(locality, "locality"); + } + + @Override + protected Helper delegate() { + return helper; + } + + // This is triggered by child balancer + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + checkNotNull(newState, "newState"); + checkNotNull(newPicker, "newPicker"); + + currentChildState = newState; + currentChildPicker = newPicker; + + // delegate to parent helper + updateChildState(locality, newState, newPicker); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("locality", locality).toString(); + } + + @Override + public String getAuthority() { + //FIXME: This should be a new proposed field of Locality, locality_name + return locality.subzone; + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsComms.java b/xds/src/main/java/io/grpc/xds/XdsComms.java index 0cb53fc870..18c3115ab3 100644 --- a/xds/src/main/java/io/grpc/xds/XdsComms.java +++ b/xds/src/main/java/io/grpc/xds/XdsComms.java @@ -19,13 +19,32 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.envoyproxy.envoy.api.v2.core.SocketAddress; +import io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; +import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer.Helper; import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.stub.StreamObserver; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; /** * ADS client implementation. @@ -37,7 +56,123 @@ final class XdsComms { // never null private AdsStream adsStream; + static final class Locality { + final String region; + final String zone; + final String subzone; + + Locality(io.envoyproxy.envoy.api.v2.core.Locality locality) { + this( + /* region = */ locality.getRegion(), + /* zone = */ locality.getZone(), + /* subzone = */ locality.getSubZone()); + } + + @VisibleForTesting + Locality(String region, String zone, String subzone) { + this.region = region; + this.zone = zone; + this.subzone = subzone; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Locality locality = (Locality) o; + return Objects.equal(region, locality.region) + && Objects.equal(zone, locality.zone) + && Objects.equal(subzone, locality.subzone); + } + + @Override + public int hashCode() { + return Objects.hashCode(region, zone, subzone); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("region", region) + .add("zone", zone) + .add("subzone", subzone) + .toString(); + } + } + + /** + * Information about the locality from EDS response. + */ + static final class LocalityInfo { + final List eags; + final List endPointWeights; + final int localityWeight; + + LocalityInfo(Collection lbEndPoints, int localityWeight) { + List eags = new ArrayList<>(lbEndPoints.size()); + List endPointWeights = new ArrayList<>(lbEndPoints.size()); + for (LbEndpoint lbEndPoint : lbEndPoints) { + eags.add(lbEndPoint.eag); + endPointWeights.add(lbEndPoint.endPointWeight); + } + this.eags = Collections.unmodifiableList(eags); + this.endPointWeights = Collections.unmodifiableList(new ArrayList<>(endPointWeights)); + this.localityWeight = localityWeight; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LocalityInfo that = (LocalityInfo) o; + return localityWeight == that.localityWeight + && Objects.equal(eags, that.eags) + && Objects.equal(endPointWeights, that.endPointWeights); + } + + @Override + public int hashCode() { + return Objects.hashCode(eags, endPointWeights, localityWeight); + } + } + + static final class LbEndpoint { + final EquivalentAddressGroup eag; + final int endPointWeight; + + LbEndpoint(io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint lbEndpointProto) { + + this( + new EquivalentAddressGroup(ImmutableList.of(fromEnvoyProtoAddress(lbEndpointProto))), + lbEndpointProto.getLoadBalancingWeight().getValue()); + } + + @VisibleForTesting + LbEndpoint(EquivalentAddressGroup eag, int endPointWeight) { + this.eag = eag; + this.endPointWeight = endPointWeight; + } + + private static java.net.SocketAddress fromEnvoyProtoAddress( + io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint lbEndpointProto) { + SocketAddress socketAddress = lbEndpointProto.getEndpoint().getAddress().getSocketAddress(); + return new InetSocketAddress(socketAddress.getAddress(), socketAddress.getPortValue()); + } + } + private final class AdsStream { + static final String EDS_TYPE_URL = + "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; + static final String TRAFFICDIRECTOR_GRPC_HOSTNAME = "TRAFFICDIRECTOR_GRPC_HOSTNAME"; + final LocalityStore localityStore; final AdsStreamCallback adsStreamCallback; @@ -49,18 +184,56 @@ final class XdsComms { boolean firstResponseReceived; @Override - public void onNext(DiscoveryResponse value) { - if (!firstResponseReceived) { - firstResponseReceived = true; - helper.getSynchronizationContext().execute( - new Runnable() { - @Override - public void run() { - adsStreamCallback.onWorking(); + public void onNext(final DiscoveryResponse value) { + + class HandleResponseRunnable implements Runnable { + + @Override + public void run() { + if (!firstResponseReceived) { + firstResponseReceived = true; + adsStreamCallback.onWorking(); + } + String typeUrl = value.getTypeUrl(); + if (EDS_TYPE_URL.equals(typeUrl)) { + // Assuming standard mode. + + ClusterLoadAssignment clusterLoadAssignment; + try { + // maybe better to run this deserialization task out of syncContext? + clusterLoadAssignment = + value.getResources(0).unpack(ClusterLoadAssignment.class); + } catch (InvalidProtocolBufferException | NullPointerException e) { + cancelRpc("Received invalid EDS response", e); + return; + } + + List localities = clusterLoadAssignment.getEndpointsList(); + Map localityEndpointsMapping = new LinkedHashMap<>(); + for (LocalityLbEndpoints localityLbEndpoints : localities) { + io.envoyproxy.envoy.api.v2.core.Locality localityProto = + localityLbEndpoints.getLocality(); + Locality locality = new Locality(localityProto); + List lbEndPoints = new ArrayList<>(); + for (io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint lbEndpoint + : localityLbEndpoints.getLbEndpointsList()) { + lbEndPoints.add(new LbEndpoint(lbEndpoint)); } - }); + int localityWeight = localityLbEndpoints.getLoadBalancingWeight().getValue(); + + localityEndpointsMapping.put( + locality, new LocalityInfo(lbEndPoints, localityWeight)); + } + + localityEndpointsMapping = Collections.unmodifiableMap(localityEndpointsMapping); + + // TODO: parse drop_percentage, and also updateLoacalistyStore with dropPercentage + localityStore.updateLocalityStore(localityEndpointsMapping); + } + } } - // TODO: more impl + + helper.getSynchronizationContext().execute(new HandleResponseRunnable()); } @Override @@ -81,17 +254,47 @@ final class XdsComms { @Override public void onCompleted() { - // TODO: impl + onError(Status.INTERNAL.withDescription("Server closed the ADS streaming RPC") + .asException()); } }; boolean cancelled; boolean closed; - AdsStream(AdsStreamCallback adsStreamCallback) { + AdsStream(AdsStreamCallback adsStreamCallback, LocalityStore localityStore) { this.adsStreamCallback = adsStreamCallback; this.xdsRequestWriter = AggregatedDiscoveryServiceGrpc.newStub(channel).withWaitForReady() .streamAggregatedResources(xdsResponseReader); + this.localityStore = localityStore; + + // Assuming standard mode, and send EDS request only + xdsRequestWriter.onNext( + DiscoveryRequest.newBuilder() + .setNode(Node.newBuilder() + .setMetadata(Struct.newBuilder() + .putFields( + TRAFFICDIRECTOR_GRPC_HOSTNAME, + Value.newBuilder().setStringValue(helper.getAuthority()) + .build()) + .putFields( + "endpoints_required", + Value.newBuilder().setBoolValue(true).build()))) + .addResourceNames(helper.getAuthority()) + .setTypeUrl(EDS_TYPE_URL).build()); + } + + AdsStream(AdsStream adsStream) { + this(adsStream.adsStreamCallback, adsStream.localityStore); + } + + void cancelRpc(String message, Throwable cause) { + if (cancelled) { + return; + } + cancelled = true; + xdsRequestWriter.onError( + Status.CANCELLED.withDescription(message).withCause(cause).asRuntimeException()); } } @@ -99,10 +302,13 @@ final class XdsComms { * Starts a new ADS streaming RPC. */ XdsComms( - ManagedChannel channel, Helper helper, AdsStreamCallback adsStreamCallback) { + ManagedChannel channel, Helper helper, AdsStreamCallback adsStreamCallback, + LocalityStore localityStore) { this.channel = checkNotNull(channel, "channel"); this.helper = checkNotNull(helper, "helper"); - this.adsStream = new AdsStream(checkNotNull(adsStreamCallback, "adsStreamCallback")); + this.adsStream = new AdsStream( + checkNotNull(adsStreamCallback, "adsStreamCallback"), + checkNotNull(localityStore, "localityStore")); } void shutdownChannel() { @@ -114,17 +320,12 @@ final class XdsComms { checkState(!channel.isShutdown(), "channel is alreday shutdown"); if (adsStream.closed || adsStream.cancelled) { - adsStream = new AdsStream(adsStream.adsStreamCallback); + adsStream = new AdsStream(adsStream); } } void shutdownLbRpc(String message) { - if (adsStream.cancelled) { - return; - } - adsStream.cancelled = true; - adsStream.xdsRequestWriter.onError( - Status.CANCELLED.withDescription(message).asRuntimeException()); + adsStream.cancelRpc(message, null); } /** diff --git a/xds/src/main/java/io/grpc/xds/XdsLbState.java b/xds/src/main/java/io/grpc/xds/XdsLbState.java index 0098acd7cf..e500834e05 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLbState.java +++ b/xds/src/main/java/io/grpc/xds/XdsLbState.java @@ -18,7 +18,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.collect.ImmutableList; import io.grpc.Attributes; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; @@ -28,9 +27,7 @@ import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.xds.XdsComms.AdsStreamCallback; -import java.net.SocketAddress; import java.util.List; -import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -49,33 +46,30 @@ import javax.annotation.Nullable; */ class XdsLbState { - private static final Attributes.Key> STATE_INFO = - Attributes.Key.create("io.grpc.xds.XdsLoadBalancer.stateInfo"); final String balancerName; @Nullable final LbConfig childPolicy; - private final SubchannelStore subchannelStore; + private final LocalityStore localityStore; private final Helper helper; private final AdsStreamCallback adsStreamCallback; @Nullable private XdsComms xdsComms; - XdsLbState( String balancerName, @Nullable LbConfig childPolicy, @Nullable XdsComms xdsComms, Helper helper, - SubchannelStore subchannelStore, + LocalityStore localityStore, AdsStreamCallback adsStreamCallback) { this.balancerName = checkNotNull(balancerName, "balancerName"); this.childPolicy = childPolicy; this.xdsComms = xdsComms; this.helper = checkNotNull(helper, "helper"); - this.subchannelStore = checkNotNull(subchannelStore, "subchannelStore"); + this.localityStore = checkNotNull(localityStore, "localityStore"); this.adsStreamCallback = checkNotNull(adsStreamCallback, "adsStreamCallback"); } @@ -86,30 +80,22 @@ class XdsLbState { if (xdsComms != null) { xdsComms.refreshAdsStream(); } else { - // ** This is wrong ** - // FIXME: use name resolver to resolve addresses for balancerName, and create xdsComms in - // name resolver listener callback. - // TODO: consider pass a fake EAG as a static final field visible to tests and verify - // createOobChannel() with this EAG in tests. - ManagedChannel oobChannel = helper.createOobChannel( - new EquivalentAddressGroup(ImmutableList.of(new SocketAddress() { - })), - balancerName); - xdsComms = new XdsComms(oobChannel, helper, adsStreamCallback); + ManagedChannel oobChannel = helper.createResolvingOobChannel(balancerName); + xdsComms = new XdsComms(oobChannel, helper, adsStreamCallback, localityStore); } // TODO: maybe update picker } - final void handleNameResolutionError(Status error) { - if (!subchannelStore.hasNonDropBackends()) { + if (!localityStore.hasNonDropBackends()) { // TODO: maybe update picker with transient failure } } final void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { // TODO: maybe update picker + localityStore.handleSubchannelState(subchannel, newState); } /** @@ -117,8 +103,7 @@ class XdsLbState { */ void shutdown() { // TODO: cancel retry timer - // TODO: shutdown child balancers - subchannelStore.shutdown(); + localityStore.reset(); } @Nullable @@ -129,50 +114,4 @@ class XdsLbState { return xdsComms; } - /** - * Manages EAG and locality info for a collection of subchannels, not including subchannels - * created by the fallback balancer. - */ - static final class SubchannelStoreImpl implements SubchannelStore { - - SubchannelStoreImpl() {} - - @Override - public boolean hasReadyBackends() { - // TODO: impl - return false; - } - - @Override - public boolean hasNonDropBackends() { - // TODO: impl - return false; - } - - - @Override - public boolean hasSubchannel(Subchannel subchannel) { - // TODO: impl - return false; - } - - @Override - public void shutdown() { - // TODO: impl - } - } - - /** - * The interface of {@link XdsLbState.SubchannelStoreImpl} that is convenient for testing. - */ - public interface SubchannelStore { - - boolean hasReadyBackends(); - - boolean hasNonDropBackends(); - - boolean hasSubchannel(Subchannel subchannel); - - void shutdown(); - } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java index 75f321fe18..f37e604b07 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java @@ -17,7 +17,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.SHUTDOWN; import com.google.common.annotations.VisibleForTesting; @@ -34,8 +33,8 @@ import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.xds.LocalityStore.LocalityStoreImpl; import io.grpc.xds.XdsComms.AdsStreamCallback; -import io.grpc.xds.XdsLbState.SubchannelStore; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -51,7 +50,7 @@ final class XdsLoadBalancer extends LoadBalancer { static final Attributes.Key> STATE_INFO = Attributes.Key.create("io.grpc.xds.XdsLoadBalancer.stateInfo"); - private final SubchannelStore subchannelStore; + private final LocalityStore localityStore; private final Helper helper; private final LoadBalancerRegistry lbRegistry; private final FallbackManager fallbackManager; @@ -66,6 +65,7 @@ final class XdsLoadBalancer extends LoadBalancer { @Override public void onError() { + // TODO: backoff and retry fallbackManager.balancerWorking = false; fallbackManager.maybeUseFallbackPolicy(); } @@ -76,11 +76,16 @@ final class XdsLoadBalancer extends LoadBalancer { private LbConfig fallbackPolicy; - XdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, SubchannelStore subchannelStore) { - this.helper = checkNotNull(helper, "helper"); + @VisibleForTesting + XdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, LocalityStore localityStore) { + this.helper = helper; this.lbRegistry = lbRegistry; - this.subchannelStore = subchannelStore; - fallbackManager = new FallbackManager(helper, subchannelStore, lbRegistry); + this.localityStore = localityStore; + fallbackManager = new FallbackManager(helper, localityStore, lbRegistry); + } + + XdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) { + this(helper, lbRegistry, new LocalityStoreImpl(helper, lbRegistry)); } @Override @@ -131,7 +136,7 @@ final class XdsLoadBalancer extends LoadBalancer { } } xdsLbState = new XdsLbState( - newBalancerName, childPolicy, xdsComms, helper, subchannelStore, adsStreamCallback); + newBalancerName, childPolicy, xdsComms, helper, localityStore, adsStreamCallback); } @Nullable @@ -168,14 +173,9 @@ final class XdsLoadBalancer extends LoadBalancer { if (fallbackManager.fallbackBalancer != null) { fallbackManager.fallbackBalancer.handleSubchannelState(subchannel, newState); } - if (subchannelStore.hasSubchannel(subchannel)) { - if (newState.getState() == IDLE) { - subchannel.requestConnection(); - } - subchannel.getAttributes().get(STATE_INFO).set(newState); - xdsLbState.handleSubchannelState(subchannel, newState); - fallbackManager.maybeUseFallbackPolicy(); - } + + xdsLbState.handleSubchannelState(subchannel, newState); + fallbackManager.maybeUseFallbackPolicy(); } @Override @@ -206,7 +206,7 @@ final class XdsLoadBalancer extends LoadBalancer { private static final long FALLBACK_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(10); // same as grpclb private final Helper helper; - private final SubchannelStore subchannelStore; + private final LocalityStore localityStore; private final LoadBalancerRegistry lbRegistry; private LbConfig fallbackPolicy; @@ -225,9 +225,9 @@ final class XdsLoadBalancer extends LoadBalancer { private boolean balancerWorking; FallbackManager( - Helper helper, SubchannelStore subchannelStore, LoadBalancerRegistry lbRegistry) { + Helper helper, LocalityStore localityStore, LoadBalancerRegistry lbRegistry) { this.helper = helper; - this.subchannelStore = subchannelStore; + this.localityStore = localityStore; this.lbRegistry = lbRegistry; } @@ -245,7 +245,7 @@ final class XdsLoadBalancer extends LoadBalancer { if (fallbackBalancer != null) { return; } - if (balancerWorking || subchannelStore.hasReadyBackends()) { + if (balancerWorking || localityStore.hasReadyBackends()) { return; } @@ -260,7 +260,7 @@ final class XdsLoadBalancer extends LoadBalancer { .setAttributes(fallbackAttributes) .build()); - // TODO: maybe update picker + // TODO: maybe update picker here if still use the old API but not SubchannelStateListener } void updateFallbackServers( diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java index 8177c3222c..b200111db7 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java @@ -27,7 +27,6 @@ import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.ServiceConfigUtil; import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.xds.XdsLbState.SubchannelStoreImpl; import io.grpc.xds.XdsLoadBalancer.XdsConfig; import java.util.List; import java.util.Map; @@ -62,7 +61,7 @@ public final class XdsLoadBalancerProvider extends LoadBalancerProvider { @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new XdsLoadBalancer(helper, registry, new SubchannelStoreImpl()); + return new XdsLoadBalancer(helper, registry); } @Override diff --git a/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java index a3f2a4f38b..ae8544b4ac 100644 --- a/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java @@ -32,7 +32,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.xds.XdsLbState.SubchannelStoreImpl; +import io.grpc.xds.LocalityStore.LocalityStoreImpl; import io.grpc.xds.XdsLoadBalancer.FallbackManager; import java.util.ArrayList; import java.util.HashMap; @@ -58,7 +58,7 @@ public class FallbackManagerTest { private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); - private final LoadBalancerProvider fakeLbProvider = new LoadBalancerProvider() { + private final LoadBalancerProvider fakeFallbackLbProvider = new LoadBalancerProvider() { @Override public boolean isAvailable() { return true; @@ -76,7 +76,29 @@ public class FallbackManagerTest { @Override public LoadBalancer newLoadBalancer(Helper helper) { - return fakeLb; + return fakeFallbackLb; + } + }; + + private final LoadBalancerProvider fakeRoundRonbinLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "round_robin"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return fakeRoundRobinLb; } }; @@ -91,7 +113,9 @@ public class FallbackManagerTest { @Mock private Helper helper; @Mock - private LoadBalancer fakeLb; + private LoadBalancer fakeRoundRobinLb; + @Mock + private LoadBalancer fakeFallbackLb; @Mock private ChannelLogger channelLogger; @@ -104,9 +128,11 @@ public class FallbackManagerTest { doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); doReturn(channelLogger).when(helper).getChannelLogger(); - fallbackManager = new FallbackManager(helper, new SubchannelStoreImpl(), lbRegistry); + lbRegistry.register(fakeRoundRonbinLbProvider); + lbRegistry.register(fakeFallbackLbProvider); + fallbackManager = new FallbackManager( + helper, new LocalityStoreImpl(helper, lbRegistry), lbRegistry); fallbackPolicy = new LbConfig("test_policy", new HashMap()); - lbRegistry.register(fakeLbProvider); } @After @@ -121,11 +147,12 @@ public class FallbackManagerTest { fallbackManager.updateFallbackServers( eags, Attributes.EMPTY, fallbackPolicy); - verify(fakeLb, never()).handleResolvedAddresses(ArgumentMatchers.any(ResolvedAddresses.class)); + verify(fakeFallbackLb, never()) + .handleResolvedAddresses(ArgumentMatchers.any(ResolvedAddresses.class)); fakeClock.forwardTime(FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); - verify(fakeLb).handleResolvedAddresses( + verify(fakeFallbackLb).handleResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(eags) .setAttributes( @@ -148,6 +175,7 @@ public class FallbackManagerTest { fakeClock.forwardTime(FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); - verify(fakeLb, never()).handleResolvedAddresses(ArgumentMatchers.any(ResolvedAddresses.class)); + verify(fakeFallbackLb, never()) + .handleResolvedAddresses(ArgumentMatchers.any(ResolvedAddresses.class)); } } diff --git a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java new file mode 100644 index 0000000000..71f1e14c85 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java @@ -0,0 +1,286 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.READY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; +import io.grpc.ChannelLogger; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; +import io.grpc.xds.LocalityStore.LocalityStoreImpl; +import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory; +import io.grpc.xds.XdsComms.LbEndpoint; +import io.grpc.xds.XdsComms.Locality; +import io.grpc.xds.XdsComms.LocalityInfo; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Tests for {@link LocalityStore}. + */ +@RunWith(JUnit4.class) +public class LocalityStoreTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + private static final class FakePickerFactory implements PickerFactory { + int totalReadyLocalities; + int nextIndex; + + @Override + public SubchannelPicker picker(final List childPickers) { + totalReadyLocalities = childPickers.size(); + + return new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return childPickers.get(nextIndex).getPicker().pickSubchannel(args); + } + }; + } + + void setNextIndex(int nextIndex) { + this.nextIndex = nextIndex; + } + } + + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + private final List loadBalancers = new ArrayList<>(); + private final List helpers = new ArrayList<>(); + + private final LoadBalancerProvider lbProvider = new LoadBalancerProvider() { + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 0; + } + + @Override + public String getPolicyName() { + return "round_robin"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + LoadBalancer fakeLb = mock(LoadBalancer.class); + loadBalancers.add(fakeLb); + helpers.add(helper); + return fakeLb; + } + }; + + private final FakePickerFactory pickerFactory = new FakePickerFactory(); + + private final Locality locality1 = new Locality("r1", "z1", "sz1"); + private final Locality locality2 = new Locality("r2", "z2", "sz2"); + private final Locality locality3 = new Locality("r3", "z3", "sz3"); + private final Locality locality4 = new Locality("r4", "z4", "sz4"); + + private final EquivalentAddressGroup eag11 = + new EquivalentAddressGroup(new InetSocketAddress("addr11", 11)); + private final EquivalentAddressGroup eag12 = + new EquivalentAddressGroup(new InetSocketAddress("addr12", 12)); + private final EquivalentAddressGroup eag21 = + new EquivalentAddressGroup(new InetSocketAddress("addr21", 21)); + private final EquivalentAddressGroup eag22 = + new EquivalentAddressGroup(new InetSocketAddress("addr22", 22)); + private final EquivalentAddressGroup eag31 = + new EquivalentAddressGroup(new InetSocketAddress("addr31", 31)); + private final EquivalentAddressGroup eag32 = + new EquivalentAddressGroup(new InetSocketAddress("addr32", 32)); + private final EquivalentAddressGroup eag41 = + new EquivalentAddressGroup(new InetSocketAddress("addr41", 41)); + private final EquivalentAddressGroup eag42 = + new EquivalentAddressGroup(new InetSocketAddress("addr42", 42)); + + private final LbEndpoint lbEndpoint11 = new LbEndpoint(eag11, 11); + private final LbEndpoint lbEndpoint12 = new LbEndpoint(eag12, 12); + private final LbEndpoint lbEndpoint21 = new LbEndpoint(eag21, 21); + private final LbEndpoint lbEndpoint22 = new LbEndpoint(eag22, 22); + private final LbEndpoint lbEndpoint31 = new LbEndpoint(eag31, 31); + private final LbEndpoint lbEndpoint32 = new LbEndpoint(eag32, 32); + private final LbEndpoint lbEndpoint41 = new LbEndpoint(eag41, 41); + private final LbEndpoint lbEndpoint42 = new LbEndpoint(eag42, 42); + + @Mock + private Helper helper; + @Mock + private PickSubchannelArgs pickSubchannelArgs; + + private LocalityStore localityStore; + + @Before + public void setUp() { + doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); + doReturn(mock(Subchannel.class)).when(helper).createSubchannel( + ArgumentMatchers.anyList(), any(Attributes.class)); + lbRegistry.register(lbProvider); + localityStore = new LocalityStoreImpl(helper, pickerFactory, lbRegistry); + } + + @Test + public void updateLoaclityStore() { + LocalityInfo localityInfo1 = + new LocalityInfo(ImmutableList.of(lbEndpoint11, lbEndpoint12), 1); + LocalityInfo localityInfo2 = + new LocalityInfo(ImmutableList.of(lbEndpoint21, lbEndpoint22), 2); + LocalityInfo localityInfo3 = + new LocalityInfo(ImmutableList.of(lbEndpoint31, lbEndpoint32), 3); + Map localityInfoMap = ImmutableMap.of( + locality1, localityInfo1, locality2, localityInfo2, locality3, localityInfo3); + localityStore.updateLocalityStore(localityInfoMap); + + assertThat(loadBalancers).hasSize(3); + ArgumentCaptor resolvedAddressesCaptor1 = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(loadBalancers.get(0)).handleResolvedAddresses(resolvedAddressesCaptor1.capture()); + assertThat(resolvedAddressesCaptor1.getValue().getAddresses()).containsExactly(eag11, eag12); + ArgumentCaptor resolvedAddressesCaptor2 = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(loadBalancers.get(1)).handleResolvedAddresses(resolvedAddressesCaptor2.capture()); + assertThat(resolvedAddressesCaptor2.getValue().getAddresses()).containsExactly(eag21, eag22); + ArgumentCaptor resolvedAddressesCaptor3 = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(loadBalancers.get(2)).handleResolvedAddresses(resolvedAddressesCaptor3.capture()); + assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32); + assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); + + // subchannel12 goes to CONNECTING + final Subchannel subchannel12 = + helpers.get(0).createSubchannel(ImmutableList.of(eag12), Attributes.EMPTY); + verify(helper).createSubchannel(ImmutableList.of(eag12), Attributes.EMPTY); + SubchannelPicker subchannelPicker12 = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(subchannel12); + } + }; + helpers.get(0).updateBalancingState(CONNECTING, subchannelPicker12); + ArgumentCaptor subchannelPickerCaptor12 = + ArgumentCaptor.forClass(SubchannelPicker.class); + verify(helper).updateBalancingState(same(CONNECTING), subchannelPickerCaptor12.capture()); + assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); + assertThat(subchannelPickerCaptor12.getValue().pickSubchannel(pickSubchannelArgs)) + .isEqualTo(PickResult.withNoResult()); + + // subchannel31 goes to READY + final Subchannel subchannel31 = + helpers.get(2).createSubchannel(ImmutableList.of(eag31), Attributes.EMPTY); + verify(helper).createSubchannel(ImmutableList.of(eag31), Attributes.EMPTY); + SubchannelPicker subchannelPicker31 = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(subchannel31); + } + }; + helpers.get(2).updateBalancingState(READY, subchannelPicker31); + ArgumentCaptor subchannelPickerCaptor31 = + ArgumentCaptor.forClass(SubchannelPicker.class); + verify(helper).updateBalancingState(same(READY), subchannelPickerCaptor31.capture()); + assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); + assertThat( + subchannelPickerCaptor31.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + .isEqualTo(subchannel31); + + // subchannel12 goes to READY + helpers.get(0).updateBalancingState(READY, subchannelPicker12); + verify(helper, times(2)).updateBalancingState(same(READY), subchannelPickerCaptor12.capture()); + assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); + pickerFactory.nextIndex = 0; + assertThat( + subchannelPickerCaptor12.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + .isEqualTo(subchannel12); + + // update with new addressed + localityInfo1 = + new LocalityInfo(ImmutableList.of(lbEndpoint11), 1); + LocalityInfo localityInfo4 = + new LocalityInfo(ImmutableList.of(lbEndpoint41, lbEndpoint42), 4); + localityInfoMap = ImmutableMap.of( + locality2, localityInfo2, locality4, localityInfo4, locality1, localityInfo1); + localityStore.updateLocalityStore(localityInfoMap); + + assertThat(loadBalancers).hasSize(4); + verify(loadBalancers.get(2)).shutdown(); + verify(loadBalancers.get(1), times(2)) + .handleResolvedAddresses(resolvedAddressesCaptor2.capture()); + assertThat(resolvedAddressesCaptor2.getValue().getAddresses()).containsExactly(eag21, eag22); + ArgumentCaptor resolvedAddressesCaptor4 = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(loadBalancers.get(3)).handleResolvedAddresses(resolvedAddressesCaptor4.capture()); + assertThat(resolvedAddressesCaptor4.getValue().getAddresses()).containsExactly(eag41, eag42); + verify(loadBalancers.get(0), times(2)) + .handleResolvedAddresses(resolvedAddressesCaptor1.capture()); + assertThat(resolvedAddressesCaptor1.getValue().getAddresses()).containsExactly(eag11); + assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); + } + + @Test + public void reset() { + LocalityInfo localityInfo1 = + new LocalityInfo(ImmutableList.of(lbEndpoint11, lbEndpoint12), 1); + LocalityInfo localityInfo2 = + new LocalityInfo(ImmutableList.of(lbEndpoint21, lbEndpoint22), 2); + Map localityInfoMap = ImmutableMap.of( + locality1, localityInfo1, locality2, localityInfo2); + localityStore.updateLocalityStore(localityInfoMap); + + assertThat(loadBalancers).hasSize(2); + + localityStore.reset(); + + verify(loadBalancers.get(0)).shutdown(); + verify(loadBalancers.get(1)).shutdown(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsCommsTest.java b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java index d5dbb52044..b1c05d3e55 100644 --- a/xds/src/test/java/io/grpc/xds/XdsCommsTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java @@ -16,27 +16,50 @@ package io.grpc.xds; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.UInt32Value; +import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.api.v2.core.Address; +import io.envoyproxy.envoy.api.v2.core.Locality; +import io.envoyproxy.envoy.api.v2.core.SocketAddress; +import io.envoyproxy.envoy.api.v2.endpoint.Endpoint; +import io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint; +import io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.XdsComms.AdsStreamCallback; +import io.grpc.xds.XdsComms.LocalityInfo; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -48,12 +71,27 @@ public class XdsCommsTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); @Mock - Helper helper; + private Helper helper; @Mock - AdsStreamCallback adsStreamCallback; + private AdsStreamCallback adsStreamCallback; + @Mock + private LocalityStore localityStore; + @Captor + private ArgumentCaptor> localityEndpointsMappingCaptor; + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final StreamRecorder streamRecorder = StreamRecorder.create(); + private StreamObserver responseWriter; + private ManagedChannel channel; private XdsComms xdsComms; @Before @@ -66,6 +104,8 @@ public class XdsCommsTest { @Override public StreamObserver streamAggregatedResources( final StreamObserver responseObserver) { + responseWriter = responseObserver; + return new StreamObserver() { @Override @@ -91,11 +131,51 @@ public class XdsCommsTest { InProcessServerBuilder .forName(serverName) .addService(serviceImpl) + .directExecutor() .build() .start()); - ManagedChannel channel = - cleanupRule.register(InProcessChannelBuilder.forName(serverName).build()); - xdsComms = new XdsComms(channel, helper, adsStreamCallback); + channel = + cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); + doReturn("fake_authority").when(helper).getAuthority(); + doReturn(syncContext).when(helper).getSynchronizationContext(); + lbRegistry.register(new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 0; + } + + @Override + public String getPolicyName() { + return "round_robin"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return null; + } + }); + xdsComms = new XdsComms(channel, helper, adsStreamCallback, localityStore); + } + + @Test + public void shutdownLbComm() throws Exception { + xdsComms.shutdownChannel(); + assertTrue(channel.isShutdown()); + assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); + assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); + } + + @Test + public void shutdownLbRpc_verifyChannelNotShutdown() throws Exception { + xdsComms.shutdownLbRpc("shutdown msg1"); + assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); + assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); + assertFalse(channel.isShutdown()); } @Test @@ -104,4 +184,114 @@ public class XdsCommsTest { assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); } + + @Test + public void standardMode_sendEdsRequest_getEdsResponse() { + assertThat(streamRecorder.getValues()).hasSize(1); + DiscoveryRequest request = streamRecorder.getValues().get(0); + assertThat(request.getTypeUrl()) + .isEqualTo("type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"); + assertThat( + request.getNode().getMetadata().getFieldsOrThrow("endpoints_required").getBoolValue()) + .isTrue(); + + Locality localityProto1 = Locality.newBuilder() + .setRegion("region1").setZone("zone1").setSubZone("subzone1").build(); + LbEndpoint endpoint11 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr11").setPortValue(11)))) + .setLoadBalancingWeight(UInt32Value.of(11)) + .build(); + LbEndpoint endpoint12 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr12").setPortValue(12)))) + .setLoadBalancingWeight(UInt32Value.of(12)) + .build(); + Locality localityProto2 = Locality.newBuilder() + .setRegion("region2").setZone("zone2").setSubZone("subzone2").build(); + LbEndpoint endpoint21 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr21").setPortValue(21)))) + .setLoadBalancingWeight(UInt32Value.of(21)) + .build(); + LbEndpoint endpoint22 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr22").setPortValue(22)))) + .setLoadBalancingWeight(UInt32Value.of(22)) + .build(); + DiscoveryResponse edsResponse = DiscoveryResponse.newBuilder() + .addResources(Any.pack(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLocality(localityProto1) + .addLbEndpoints(endpoint11) + .addLbEndpoints(endpoint12) + .setLoadBalancingWeight(UInt32Value.of(1))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLocality(localityProto2) + .addLbEndpoints(endpoint21) + .addLbEndpoints(endpoint22) + .setLoadBalancingWeight(UInt32Value.of(2))) + .build())) + .setTypeUrl("type.googleapis.com/envoy.api.v2.ClusterLoadAssignment") + .build(); + responseWriter.onNext(edsResponse); + + XdsComms.Locality locality1 = new XdsComms.Locality(localityProto1); + LocalityInfo localityInfo1 = new LocalityInfo( + ImmutableList.of( + new XdsComms.LbEndpoint(endpoint11), + new XdsComms.LbEndpoint(endpoint12)), + 1); + LocalityInfo localityInfo2 = new LocalityInfo( + ImmutableList.of( + new XdsComms.LbEndpoint(endpoint21), + new XdsComms.LbEndpoint(endpoint22)), + 2); + XdsComms.Locality locality2 = new XdsComms.Locality(localityProto2); + + verify(localityStore).updateLocalityStore(localityEndpointsMappingCaptor.capture()); + assertThat(localityEndpointsMappingCaptor.getValue()).containsExactly( + locality1, localityInfo1, locality2, localityInfo2).inOrder(); + + + edsResponse = DiscoveryResponse.newBuilder() + .addResources(Any.pack(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLocality(localityProto2) + .addLbEndpoints(endpoint21) + .addLbEndpoints(endpoint22) + .setLoadBalancingWeight(UInt32Value.of(2))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLocality(localityProto1) + .addLbEndpoints(endpoint11) + .addLbEndpoints(endpoint12) + .setLoadBalancingWeight(UInt32Value.of(1))) + .build())) + .setTypeUrl("type.googleapis.com/envoy.api.v2.ClusterLoadAssignment") + .build(); + responseWriter.onNext(edsResponse); + + verify(localityStore, times(2)).updateLocalityStore(localityEndpointsMappingCaptor.capture()); + assertThat(localityEndpointsMappingCaptor.getValue()).containsExactly( + locality2, localityInfo2, locality1, localityInfo1).inOrder(); + + xdsComms.shutdownChannel(); + } + + @Test + public void serverOnCompleteShouldFailClient() { + responseWriter.onCompleted(); + + verify(adsStreamCallback).onError(); + + xdsComms.shutdownChannel(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java index 6cd3dd29b5..864b5c6e92 100644 --- a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java @@ -16,19 +16,35 @@ package io.grpc.xds; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.READY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; +import io.grpc.Attributes; +import io.grpc.ChannelLogger; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; -import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; @@ -36,14 +52,26 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; +import io.grpc.xds.LocalityStore.LocalityStoreImpl; +import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory; import io.grpc.xds.XdsComms.AdsStreamCallback; -import java.util.concurrent.TimeUnit; -import org.junit.After; +import io.grpc.xds.XdsComms.LbEndpoint; +import io.grpc.xds.XdsComms.Locality; +import io.grpc.xds.XdsComms.LocalityInfo; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -52,15 +80,55 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class XdsLbStateTest { + private static final String BALANCER_NAME = "balancerName"; @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); @Mock private Helper helper; @Mock private AdsStreamCallback adsStreamCallback; + @Mock + private PickSubchannelArgs pickSubchannelArgs; + @Captor + private ArgumentCaptor subchannelPickerCaptor; + @Captor + private ArgumentCaptor resolvedAddressesCaptor; + + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + private final Map loadBalancers = new HashMap<>(); + private final Map childHelpers = new HashMap<>(); + + private final LoadBalancerProvider childLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "round_robin"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + if (loadBalancers.containsKey(helper.getAuthority())) { + return loadBalancers.get(helper.getAuthority()); + } + LoadBalancer loadBalancer = mock(LoadBalancer.class); + loadBalancers.put(helper.getAuthority(), loadBalancer); + childHelpers.put(helper.getAuthority(), helper); + return loadBalancer; + } + }; + + private LocalityStore localityStore; private final FakeClock fakeClock = new FakeClock(); - private final StreamRecorder streamRecorder = StreamRecorder.create(); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -70,15 +138,43 @@ public class XdsLbStateTest { } }); - private XdsComms xdsComms; - + private final StreamRecorder streamRecorder = StreamRecorder.create(); + private StreamObserver responseWriter; private ManagedChannel channel; + private static final class FakeInterLocalityPickerFactory implements PickerFactory { + int totalReadyLocalities; + int nextIndex; + + @Override + public SubchannelPicker picker(final List childPickers) { + totalReadyLocalities = childPickers.size(); + + return new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return childPickers.get(nextIndex).getPicker().pickSubchannel(args); + } + }; + } + + void setNextIndex(int nextIndex) { + this.nextIndex = nextIndex; + } + } + + private final FakeInterLocalityPickerFactory interLocalityPickerFactory + = new FakeInterLocalityPickerFactory(); + @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); + doReturn("fake_authority").when(helper).getAuthority(); + doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); + lbRegistry.register(childLbProvider); + localityStore = new LocalityStoreImpl(helper, interLocalityPickerFactory, lbRegistry); String serverName = InProcessServerBuilder.generateName(); @@ -86,6 +182,8 @@ public class XdsLbStateTest { @Override public StreamObserver streamAggregatedResources( final StreamObserver responseObserver) { + responseWriter = responseObserver; + return new StreamObserver() { @Override @@ -111,34 +209,12 @@ public class XdsLbStateTest { InProcessServerBuilder .forName(serverName) .addService(serviceImpl) + .directExecutor() .build() .start()); channel = - cleanupRule.register(InProcessChannelBuilder.forName(serverName).build()); - xdsComms = new XdsComms(channel, helper, adsStreamCallback); - } - - @After - public void tearDown() { - if (!channel.isShutdown()) { - channel.shutdownNow(); - } - } - - @Test - public void shutdownLbComm() throws Exception { - xdsComms.shutdownChannel(); - assertTrue(channel.isShutdown()); - assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); - assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); - } - - @Test - public void shutdownLbRpc_verifyChannelNotShutdown() throws Exception { - xdsComms.shutdownLbRpc("shutdown msg1"); - assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); - assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); - assertFalse(channel.isShutdown()); + cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); + doReturn(channel).when(helper).createResolvingOobChannel(BALANCER_NAME); } @Test @@ -147,4 +223,86 @@ public class XdsLbStateTest { xdsLbState.shutdownAndReleaseXdsComms(); verify(xdsLbState).shutdown(); } + + @Test + public void handleSubchannelState() { + assertThat(loadBalancers).isEmpty(); + + Locality locality1 = new Locality("r1", "z1", "sz1"); + EquivalentAddressGroup eag11 = new EquivalentAddressGroup(new InetSocketAddress("addr11", 11)); + EquivalentAddressGroup eag12 = new EquivalentAddressGroup(new InetSocketAddress("addr12", 12)); + + LbEndpoint lbEndpoint11 = new LbEndpoint(eag11, 11); + LbEndpoint lbEndpoint12 = new LbEndpoint(eag12, 12); + LocalityInfo localityInfo1 = new LocalityInfo(ImmutableList.of(lbEndpoint11, lbEndpoint12), 1); + + Locality locality2 = new Locality("r2", "z2", "sz2"); + EquivalentAddressGroup eag21 = new EquivalentAddressGroup(new InetSocketAddress("addr21", 21)); + EquivalentAddressGroup eag22 = new EquivalentAddressGroup(new InetSocketAddress("addr22", 22)); + + LbEndpoint lbEndpoint21 = new LbEndpoint(eag21, 21); + LbEndpoint lbEndpoint22 = new LbEndpoint(eag22, 22); + LocalityInfo localityInfo2 = new LocalityInfo(ImmutableList.of(lbEndpoint21, lbEndpoint22), 2); + + Map localityInfoMap = new LinkedHashMap<>(); + localityInfoMap.put(locality1, localityInfo1); + localityInfoMap.put(locality2, localityInfo2); + + verify(helper, never()).updateBalancingState( + any(ConnectivityState.class), any(SubchannelPicker.class)); + + localityStore.updateLocalityStore(localityInfoMap); + + assertThat(loadBalancers).hasSize(2); + assertThat(loadBalancers.keySet()).containsExactly("sz1", "sz2"); + assertThat(childHelpers).hasSize(2); + assertThat(childHelpers.keySet()).containsExactly("sz1", "sz2"); + + verify(loadBalancers.get("sz1")).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + assertThat(resolvedAddressesCaptor.getValue().getAddresses()) + .containsExactly(eag11, eag12).inOrder(); + verify(loadBalancers.get("sz2")).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + assertThat(resolvedAddressesCaptor.getValue().getAddresses()) + .containsExactly(eag21, eag22).inOrder(); + verify(helper, never()).updateBalancingState( + any(ConnectivityState.class), any(SubchannelPicker.class)); + + SubchannelPicker childPicker1 = mock(SubchannelPicker.class); + PickResult pickResult1 = PickResult.withSubchannel(mock(Subchannel.class)); + doReturn(pickResult1).when(childPicker1).pickSubchannel(any(PickSubchannelArgs.class)); + childHelpers.get("sz1").updateBalancingState(READY, childPicker1); + verify(helper).updateBalancingState(eq(READY), subchannelPickerCaptor.capture()); + + assertThat(interLocalityPickerFactory.totalReadyLocalities).isEqualTo(1); + interLocalityPickerFactory.setNextIndex(0); + assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs)) + .isSameInstanceAs(pickResult1); + + SubchannelPicker childPicker2 = mock(SubchannelPicker.class); + PickResult pickResult2 = PickResult.withSubchannel(mock(Subchannel.class)); + doReturn(pickResult2).when(childPicker2).pickSubchannel(any(PickSubchannelArgs.class)); + childHelpers.get("sz2").updateBalancingState(CONNECTING, childPicker2); + verify(helper, times(2)).updateBalancingState(eq(READY), subchannelPickerCaptor.capture()); + + assertThat(interLocalityPickerFactory.totalReadyLocalities).isEqualTo(1); + assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs)) + .isSameInstanceAs(pickResult1); + } + + @Test + public void handleResolvedAddressGroupsThenShutdown() throws Exception { + localityStore = mock(LocalityStore.class); + XdsLbState xdsLbState = + new XdsLbState(BALANCER_NAME, null, null, helper, localityStore, adsStreamCallback); + xdsLbState.handleResolvedAddressGroups( + Collections.emptyList(), Attributes.EMPTY); + + assertThat(streamRecorder.firstValue().get().getTypeUrl()) + .isEqualTo("type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"); + + xdsLbState.shutdown(); + verify(localityStore).reset(); + + channel.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java index 195edcb9df..80a67aa67f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java @@ -55,8 +55,6 @@ import io.grpc.internal.JsonParser; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.XdsLbState.SubchannelStore; -import io.grpc.xds.XdsLbState.SubchannelStoreImpl; import java.util.Collections; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -166,8 +164,8 @@ public class XdsLoadBalancerTest { } }); - private final SubchannelStore fakeSubchannelStore = - mock(SubchannelStore.class, delegatesTo(new SubchannelStoreImpl())); + @Mock + private LocalityStore fakeLocalityStore; private ManagedChannel oobChannel1; private ManagedChannel oobChannel2; @@ -181,10 +179,11 @@ public class XdsLoadBalancerTest { lbRegistry.register(lbProvider1); lbRegistry.register(lbProvider2); lbRegistry.register(roundRobin); - lb = new XdsLoadBalancer(helper, lbRegistry, fakeSubchannelStore); + lb = new XdsLoadBalancer(helper, lbRegistry, fakeLocalityStore); doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); + doReturn("fake_authority").when(helper).getAuthority(); String serverName = InProcessServerBuilder.generateName(); @@ -236,7 +235,7 @@ public class XdsLoadBalancerTest { delegatesTo(cleanupRule.register(channelBuilder.build()))); doReturn(oobChannel1).doReturn(oobChannel2).doReturn(oobChannel3) - .when(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + .when(helper).createResolvingOobChannel(anyString()); } @After @@ -268,7 +267,7 @@ public class XdsLoadBalancerTest { XdsLbState xdsLbState1 = lb.getXdsLbStateForTest(); assertThat(xdsLbState1.childPolicy).isNull(); - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -293,7 +292,7 @@ public class XdsLoadBalancerTest { assertThat(xdsLbState2).isSameInstanceAs(xdsLbState1); // verify oobChannel is unchanged - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); // verify ADS stream is unchanged verify(oobChannel1) .newCall(ArgumentMatchers.>any(), @@ -316,7 +315,7 @@ public class XdsLoadBalancerTest { .setAddresses(Collections.emptyList()) .setAttributes(attrs) .build()); - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -339,7 +338,7 @@ public class XdsLoadBalancerTest { assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); // verify oobChannel is unchanged - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); // verify ADS stream is reset verify(oobChannel1, times(2)) .newCall(ArgumentMatchers.>any(), @@ -362,7 +361,7 @@ public class XdsLoadBalancerTest { .setAddresses(Collections.emptyList()) .setAttributes(attrs) .build()); - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -387,7 +386,7 @@ public class XdsLoadBalancerTest { assertThat(lb.getXdsLbStateForTest().childPolicy).isNull(); // verify oobChannel is unchanged - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); // verify ADS stream is reset verify(oobChannel1, times(2)) .newCall(ArgumentMatchers.>any(), @@ -411,7 +410,7 @@ public class XdsLoadBalancerTest { .build()); assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -433,7 +432,7 @@ public class XdsLoadBalancerTest { assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); // verify oobChannel is unchanged - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); // verify ADS stream is reset verify(oobChannel1, times(2)) .newCall(ArgumentMatchers.>any(), @@ -456,7 +455,7 @@ public class XdsLoadBalancerTest { .setAddresses(Collections.emptyList()) .setAttributes(attrs) .build()); - verify(helper).createOobChannel(ArgumentMatchers.any(), anyString()); + verify(helper).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -479,8 +478,7 @@ public class XdsLoadBalancerTest { assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); // verify oobChannel is unchanged - verify(helper, times(2)).createOobChannel(ArgumentMatchers.any(), - anyString()); + verify(helper, times(2)).createResolvingOobChannel(anyString()); verify(oobChannel1) .newCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); @@ -552,7 +550,7 @@ public class XdsLoadBalancerTest { .build()); serverResponseWriter.onNext(DiscoveryResponse.getDefaultInstance()); - doReturn(true).when(fakeSubchannelStore).hasReadyBackends(); + doReturn(true).when(fakeLocalityStore).hasReadyBackends(); serverResponseWriter.onError(new Exception("fake error")); verify(fakeBalancer1, never()).handleResolvedAddresses( @@ -575,8 +573,7 @@ public class XdsLoadBalancerTest { } }; - doReturn(true).when(fakeSubchannelStore).hasSubchannel(subchannel); - doReturn(false).when(fakeSubchannelStore).hasReadyBackends(); + doReturn(false).when(fakeLocalityStore).hasReadyBackends(); lb.handleSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE));