diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java index 87e2b6bc61..65f4832f54 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java @@ -16,6 +16,8 @@ package io.grpc.grpclb; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.Metadata; @@ -32,5 +34,12 @@ public final class GrpclbConstants { public static final Metadata.Key TOKEN_METADATA_KEY = Metadata.Key.of("lb-token", Metadata.ASCII_STRING_MARSHALLER); + /** + * For passing LB tokens via the EAG attributes. + */ + @EquivalentAddressGroup.Attr + static final Attributes.Key TOKEN_ATTRIBUTE_KEY = + Attributes.Key.create("lb-token"); + private GrpclbConstants() { } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 09ed32e399..b0719f0625 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -17,6 +17,7 @@ package io.grpc.grpclb; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; @@ -51,7 +52,11 @@ class GrpclbLoadBalancer extends LoadBalancer { private static final Logger logger = Logger.getLogger(GrpclbLoadBalancer.class.getName()); private final Helper helper; + private final TimeProvider time; private final SubchannelPool subchannelPool; + private final BackoffPolicy.Provider backoffPolicyProvider; + + private Mode mode = Mode.ROUND_ROBIN; // All mutable states in this class are mutated ONLY from Channel Executor @Nullable @@ -63,12 +68,12 @@ class GrpclbLoadBalancer extends LoadBalancer { TimeProvider time, BackoffPolicy.Provider backoffPolicyProvider) { this.helper = checkNotNull(helper, "helper"); - checkNotNull(time, "time provider"); - checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); + this.time = checkNotNull(time, "time provider"); + this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); this.subchannelPool.init(helper); - grpclbState = - new GrpclbState(helper, subchannelPool, time, backoffPolicyProvider); + recreateStates(); + checkNotNull(grpclbState, "grpclbState"); } @Override @@ -97,7 +102,12 @@ class GrpclbLoadBalancer extends LoadBalancer { newBackendServers = Collections.unmodifiableList(newBackendServers); Map rawLbConfigValue = attributes.get(ATTR_LOAD_BALANCING_CONFIG); Mode newMode = retrieveModeFromLbConfig(rawLbConfigValue, helper.getChannelLogger()); - grpclbState.handleAddresses(newLbAddressGroups, newBackendServers, newMode); + if (!mode.equals(newMode)) { + mode = newMode; + helper.getChannelLogger().log(ChannelLogLevel.INFO, "Mode: " + newMode); + recreateStates(); + } + grpclbState.handleAddresses(newLbAddressGroups, newBackendServers); } @VisibleForTesting @@ -141,6 +151,12 @@ class GrpclbLoadBalancer extends LoadBalancer { } } + private void recreateStates() { + resetStates(); + checkState(grpclbState == null, "Should've been cleared"); + grpclbState = new GrpclbState(mode, helper, subchannelPool, time, backoffPolicyProvider); + } + @Override public void shutdown() { resetStates(); diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index e4203da778..4d87f32d24 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -138,8 +138,8 @@ final class GrpclbState { @Nullable private LbStream lbStream; - private Map subchannels = Collections.emptyMap(); - private Mode mode; + private Map, Subchannel> subchannels = Collections.emptyMap(); + private final Mode mode; // Has the same size as the round-robin list from the balancer. // A drop entry from the round-robin list becomes a DropEntry here. @@ -151,10 +151,12 @@ final class GrpclbState { new RoundRobinPicker(Collections.emptyList(), Arrays.asList(BUFFER_ENTRY)); GrpclbState( + Mode mode, Helper helper, SubchannelPool subchannelPool, TimeProvider time, BackoffPolicy.Provider backoffPolicyProvider) { + this.mode = checkNotNull(mode, "mode"); this.helper = checkNotNull(helper, "helper"); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); @@ -169,7 +171,7 @@ final class GrpclbState { if (newState.getState() == SHUTDOWN || !subchannels.values().contains(subchannel)) { return; } - if (newState.getState() == IDLE) { + if (mode == Mode.ROUND_ROBIN && newState.getState() == IDLE) { subchannel.requestConnection(); } subchannel.getAttributes().get(STATE_INFO).set(newState); @@ -182,8 +184,7 @@ final class GrpclbState { * not yet connected. */ void handleAddresses( - List newLbAddressGroups, List newBackendServers, - Mode mode) { + List newLbAddressGroups, List newBackendServers) { if (newLbAddressGroups.isEmpty()) { propagateError(Status.UNAVAILABLE.withDescription( "NameResolver returned no LB address while asking for GRPCLB")); @@ -305,10 +306,20 @@ final class GrpclbState { void shutdown() { shutdownLbComm(); - // We close the subchannels through subchannelPool instead of helper just for convenience of - // testing. - for (Subchannel subchannel : subchannels.values()) { - subchannelPool.returnSubchannel(subchannel); + switch (mode) { + case ROUND_ROBIN: + // We close the subchannels through subchannelPool instead of helper just for convenience of + // testing. + for (Subchannel subchannel : subchannels.values()) { + subchannelPool.returnSubchannel(subchannel); + } + break; + case PICK_FIRST: + checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels); + subchannels.values().iterator().next().shutdown(); + break; + default: + throw new AssertionError("Missing case for " + mode); } subchannels = Collections.emptyMap(); subchannelPool.clear(); @@ -341,45 +352,74 @@ final class GrpclbState { @Nullable GrpclbClientLoadRecorder loadRecorder) { logger.log( ChannelLogLevel.INFO, "Using RR list={0}, drop={1}", newBackendAddrList, newDropList); - HashMap newSubchannelMap = + HashMap, Subchannel> newSubchannelMap = new HashMap<>(); List newBackendList = new ArrayList<>(); - for (BackendAddressGroup backendAddr : newBackendAddrList) { - EquivalentAddressGroup eag = backendAddr.getAddresses(); - Subchannel subchannel = newSubchannelMap.get(eag); - if (subchannel == null) { - subchannel = subchannels.get(eag); - if (subchannel == null) { - Attributes subchannelAttrs = Attributes.newBuilder() - .set(STATE_INFO, - new AtomicReference<>( - ConnectivityStateInfo.forNonError(IDLE))) - .build(); - subchannel = subchannelPool.takeOrCreateSubchannel(eag, subchannelAttrs); - subchannel.requestConnection(); + switch (mode) { + case ROUND_ROBIN: + for (BackendAddressGroup backendAddr : newBackendAddrList) { + EquivalentAddressGroup eag = backendAddr.getAddresses(); + List eagAsList = Collections.singletonList(eag); + Subchannel subchannel = newSubchannelMap.get(eagAsList); + if (subchannel == null) { + subchannel = subchannels.get(eagAsList); + if (subchannel == null) { + subchannel = subchannelPool.takeOrCreateSubchannel(eag, createSubchannelAttrs()); + subchannel.requestConnection(); + } + newSubchannelMap.put(eagAsList, subchannel); + } + BackendEntry entry; + // Only picks with tokens are reported to LoadRecorder + if (backendAddr.getToken() == null) { + entry = new BackendEntry(subchannel); + } else { + entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken()); + } + newBackendList.add(entry); } - newSubchannelMap.put(eag, subchannel); - } - BackendEntry entry; - // Only picks with tokens are reported to LoadRecorder - if (backendAddr.getToken() == null) { - entry = new BackendEntry(subchannel); - } else { - entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken()); - } - newBackendList.add(entry); + // Close Subchannels whose addresses have been delisted + for (Entry, Subchannel> entry : subchannels.entrySet()) { + List eagList = entry.getKey(); + if (!newSubchannelMap.containsKey(eagList)) { + subchannelPool.returnSubchannel(entry.getValue()); + } + } + subchannels = Collections.unmodifiableMap(newSubchannelMap); + break; + case PICK_FIRST: + List eagList = new ArrayList<>(); + // Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to + // attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on + // headers. + // + // The PICK_FIRST code path doesn't cache Subchannels. + for (BackendAddressGroup bag : newBackendAddrList) { + EquivalentAddressGroup origEag = bag.getAddresses(); + Attributes eagAttrs = origEag.getAttributes(); + if (bag.getToken() != null) { + eagAttrs = eagAttrs.toBuilder() + .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, bag.getToken()).build(); + } + eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs)); + } + Subchannel subchannel; + if (subchannels.isEmpty()) { + subchannel = helper.createSubchannel(eagList, createSubchannelAttrs()); + } else { + checkState(subchannels.size() == 1, "Unexpected Subchannel count: %s", subchannels); + subchannel = subchannels.values().iterator().next(); + helper.updateSubchannelAddresses(subchannel, eagList); + } + subchannels = Collections.singletonMap(eagList, subchannel); + newBackendList.add( + new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder))); + break; + default: + throw new AssertionError("Missing case for " + mode); } - // Close Subchannels whose addresses have been delisted - for (Entry entry : subchannels.entrySet()) { - EquivalentAddressGroup eag = entry.getKey(); - if (!newSubchannelMap.containsKey(eag)) { - subchannelPool.returnSubchannel(entry.getValue()); - } - } - - subchannels = Collections.unmodifiableMap(newSubchannelMap); dropList = Collections.unmodifiableList(newDropList); backendList = Collections.unmodifiableList(newBackendList); } @@ -619,32 +659,67 @@ final class GrpclbState { * changed since the last picker created. */ private void maybeUpdatePicker() { - List pickList = new ArrayList<>(backendList.size()); - Status error = null; - boolean hasIdle = false; - for (BackendEntry entry : backendList) { - Subchannel subchannel = entry.result.getSubchannel(); - Attributes attrs = subchannel.getAttributes(); - ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get(); - if (stateInfo.getState() == READY) { - pickList.add(entry); - } else if (stateInfo.getState() == TRANSIENT_FAILURE) { - error = stateInfo.getStatus(); - } else if (stateInfo.getState() == IDLE) { - hasIdle = true; - } - } + List pickList; ConnectivityState state; - if (pickList.isEmpty()) { - if (error != null && !hasIdle) { - pickList.add(new ErrorEntry(error)); - state = TRANSIENT_FAILURE; - } else { - pickList.add(BUFFER_ENTRY); - state = CONNECTING; - } - } else { - state = READY; + switch (mode) { + case ROUND_ROBIN: + pickList = new ArrayList<>(backendList.size()); + Status error = null; + boolean hasIdle = false; + for (BackendEntry entry : backendList) { + Subchannel subchannel = entry.subchannel; + Attributes attrs = subchannel.getAttributes(); + ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get(); + if (stateInfo.getState() == READY) { + pickList.add(entry); + } else if (stateInfo.getState() == TRANSIENT_FAILURE) { + error = stateInfo.getStatus(); + } else if (stateInfo.getState() == IDLE) { + hasIdle = true; + } + } + if (pickList.isEmpty()) { + if (error != null && !hasIdle) { + pickList.add(new ErrorEntry(error)); + state = TRANSIENT_FAILURE; + } else { + pickList.add(BUFFER_ENTRY); + state = CONNECTING; + } + } else { + state = READY; + } + break; + case PICK_FIRST: + if (backendList.isEmpty()) { + pickList = Collections.singletonList(BUFFER_ENTRY); + // Have not received server addresses + state = CONNECTING; + } else { + checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList); + BackendEntry onlyEntry = backendList.get(0); + ConnectivityStateInfo stateInfo = + onlyEntry.subchannel.getAttributes().get(STATE_INFO).get(); + state = stateInfo.getState(); + switch (state) { + case READY: + pickList = Collections.singletonList(onlyEntry); + break; + case TRANSIENT_FAILURE: + pickList = + Collections.singletonList(new ErrorEntry(stateInfo.getStatus())); + break; + case CONNECTING: + pickList = Collections.singletonList(BUFFER_ENTRY); + break; + default: + pickList = Collections.singletonList( + new IdleSubchannelEntry(onlyEntry.subchannel)); + } + } + break; + default: + throw new AssertionError("Missing case for " + mode); } maybeUpdatePicker(state, new RoundRobinPicker(dropList, pickList)); } @@ -704,6 +779,14 @@ final class GrpclbState { return new EquivalentAddressGroup(addrs, attrs); } + private static Attributes createSubchannelAttrs() { + return Attributes.newBuilder() + .set(STATE_INFO, + new AtomicReference<>( + ConnectivityStateInfo.forNonError(IDLE))) + .build(); + } + @VisibleForTesting static final class DropEntry { private final GrpclbClientLoadRecorder loadRecorder; @@ -740,34 +823,45 @@ final class GrpclbState { } } - private interface RoundRobinEntry { + @VisibleForTesting + interface RoundRobinEntry { PickResult picked(Metadata headers); } @VisibleForTesting static final class BackendEntry implements RoundRobinEntry { + final Subchannel subchannel; @VisibleForTesting final PickResult result; @Nullable - private final GrpclbClientLoadRecorder loadRecorder; - @Nullable private final String token; /** - * Creates a BackendEntry whose usage will be reported to load recorder. + * For ROUND_ROBIN: creates a BackendEntry whose usage will be reported to load recorder. */ BackendEntry(Subchannel subchannel, GrpclbClientLoadRecorder loadRecorder, String token) { - this.result = PickResult.withSubchannel(subchannel, loadRecorder); - this.loadRecorder = checkNotNull(loadRecorder, "loadRecorder"); + this.subchannel = checkNotNull(subchannel, "subchannel"); + this.result = + PickResult.withSubchannel(subchannel, checkNotNull(loadRecorder, "loadRecorder")); this.token = checkNotNull(token, "token"); } /** - * Creates a BackendEntry whose usage will not be reported. + * For ROUND_ROBIN/PICK_FIRST: creates a BackendEntry whose usage will not be reported. */ BackendEntry(Subchannel subchannel) { + this.subchannel = checkNotNull(subchannel, "subchannel"); this.result = PickResult.withSubchannel(subchannel); - this.loadRecorder = null; + this.token = null; + } + + /** + * For PICK_FIRST: creates a BackendEntry that includes all addresses. + */ + BackendEntry(Subchannel subchannel, TokenAttachingTracerFactory tracerFactory) { + this.subchannel = checkNotNull(subchannel, "subchannel"); + this.result = + PickResult.withSubchannel(subchannel, checkNotNull(tracerFactory, "tracerFactory")); this.token = null; } @@ -783,12 +877,12 @@ final class GrpclbState { @Override public String toString() { // This is printed in logs. Only give out useful information. - return "[" + result.getSubchannel().getAllAddresses().toString() + "(" + token + ")]"; + return "[" + subchannel.getAllAddresses().toString() + "(" + token + ")]"; } @Override public int hashCode() { - return Objects.hashCode(loadRecorder, result, token); + return Objects.hashCode(result, token); } @Override @@ -797,8 +891,42 @@ final class GrpclbState { return false; } BackendEntry that = (BackendEntry) other; - return Objects.equal(result, that.result) && Objects.equal(token, that.token) - && Objects.equal(loadRecorder, that.loadRecorder); + return Objects.equal(result, that.result) && Objects.equal(token, that.token); + } + } + + @VisibleForTesting + static final class IdleSubchannelEntry implements RoundRobinEntry { + private final Subchannel subchannel; + + IdleSubchannelEntry(Subchannel subchannel) { + this.subchannel = checkNotNull(subchannel, "subchannel"); + } + + @Override + public PickResult picked(Metadata headers) { + subchannel.requestConnection(); + return PickResult.withNoResult(); + } + + @Override + public String toString() { + // This is printed in logs. Only give out useful information. + return "(idle)[" + subchannel.getAllAddresses().toString() + "]"; + } + + @Override + public int hashCode() { + return Objects.hashCode(subchannel); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof IdleSubchannelEntry)) { + return false; + } + IdleSubchannelEntry that = (IdleSubchannelEntry) other; + return Objects.equal(subchannel, that.subchannel); } } @@ -860,8 +988,7 @@ final class GrpclbState { // First round-robin on dropList. If a drop entry is selected, request will be dropped. If // a non-drop entry is selected, then round-robin on pickList. This makes sure requests are // dropped at the same proportion as the drop entries appear on the round-robin list from - // the balancer, while only READY backends (that make up pickList) are selected for the - // non-drop cases. + // the balancer, while only backends from pickList are selected for the non-drop cases. if (!dropList.isEmpty()) { DropEntry drop = dropList.get(dropIndex); dropIndex++; @@ -881,5 +1008,14 @@ final class GrpclbState { return pick.picked(args.getHeaders()); } } + + @Override + public void requestConnection() { + for (RoundRobinEntry entry : pickList) { + if (entry instanceof IdleSubchannelEntry) { + ((IdleSubchannelEntry) entry).subchannel.requestConnection(); + } + } + } } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java new file mode 100644 index 0000000000..03b9bdf7f1 --- /dev/null +++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Objects; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.internal.GrpcAttributes; +import javax.annotation.Nullable; + +/** + * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and + * attaches them to headers. This is only used in the PICK_FIRST mode. + */ +final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { + private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; + + @Nullable + private final ClientStreamTracer.Factory delegate; + + TokenAttachingTracerFactory(@Nullable ClientStreamTracer.Factory delegate) { + this.delegate = delegate; + } + + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs"); + Attributes eagAttrs = + checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); + String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); + headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); + if (token != null) { + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); + } + if (delegate != null) { + return delegate.newClientStreamTracer(info, headers); + } else { + return NOOP_TRACER; + } + } + + @Override + public int hashCode() { + return Objects.hashCode(delegate); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof TokenAttachingTracerFactory)) { + return false; + } + return Objects.equal(delegate, ((TokenAttachingTracerFactory) other).delegate); + } +} diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index e4bb5f3599..ee6e185003 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -56,6 +56,7 @@ import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -68,7 +69,9 @@ import io.grpc.SynchronizationContext; import io.grpc.grpclb.GrpclbState.BackendEntry; import io.grpc.grpclb.GrpclbState.DropEntry; import io.grpc.grpclb.GrpclbState.ErrorEntry; +import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry; import io.grpc.grpclb.GrpclbState.Mode; +import io.grpc.grpclb.GrpclbState.RoundRobinEntry; import io.grpc.grpclb.GrpclbState.RoundRobinPicker; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; @@ -166,7 +169,8 @@ public class GrpclbLoadBalancerTest { new LinkedList<>(); private final LinkedList mockSubchannels = new LinkedList<>(); private final LinkedList fakeOobChannels = new LinkedList<>(); - private final ArrayList subchannelTracker = new ArrayList<>(); + private final ArrayList pooledSubchannelTracker = new ArrayList<>(); + private final ArrayList unpooledSubchannelTracker = new ArrayList<>(); private final ArrayList oobChannelTracker = new ArrayList<>(); private final ArrayList failingLbAuthorities = new ArrayList<>(); private final SynchronizationContext syncContext = new SynchronizationContext( @@ -251,11 +255,25 @@ public class GrpclbLoadBalancerTest { when(subchannel.getAllAddresses()).thenReturn(Arrays.asList(eag)); when(subchannel.getAttributes()).thenReturn(attrs); mockSubchannels.add(subchannel); - subchannelTracker.add(subchannel); + pooledSubchannelTracker.add(subchannel); return subchannel; } }).when(subchannelPool).takeOrCreateSubchannel( any(EquivalentAddressGroup.class), any(Attributes.class)); + doAnswer(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + Subchannel subchannel = mock(Subchannel.class); + List eagList = + (List) invocation.getArguments()[0]; + Attributes attrs = (Attributes) invocation.getArguments()[1]; + when(subchannel.getAllAddresses()).thenReturn(eagList); + when(subchannel.getAttributes()).thenReturn(attrs); + mockSubchannels.add(subchannel); + unpooledSubchannelTracker.add(subchannel); + return subchannel; + } + }).when(helper).createSubchannel(any(List.class), any(Attributes.class)); when(helper.getSynchronizationContext()).thenReturn(syncContext); when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); when(helper.getChannelLogger()).thenReturn(channelLogger); @@ -294,14 +312,15 @@ public class GrpclbLoadBalancerTest { assertTrue(channel + " is terminated", channel.isTerminated()); } // GRPCLB manages subchannels only through subchannelPool - for (Subchannel subchannel: subchannelTracker) { + for (Subchannel subchannel : pooledSubchannelTracker) { verify(subchannelPool).returnSubchannel(same(subchannel)); // Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if // LoadBalancer has called it expectedly. verify(subchannel, never()).shutdown(); } - verify(helper, never()) - .createSubchannel(any(List.class), any(Attributes.class)); + for (Subchannel subchannel : unpooledSubchannelTracker) { + verify(subchannel).shutdown(); + } // No timer should linger after shutdown assertThat(fakeClock.getPendingTasks()).isEmpty(); } finally { @@ -406,6 +425,65 @@ public class GrpclbLoadBalancerTest { verify(subchannel, never()).getAttributes(); } + @Test + public void roundRobinPickerWithIdleEntry_noDrop() { + Subchannel subchannel = mock(Subchannel.class); + IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel); + + RoundRobinPicker picker = + new RoundRobinPicker(Collections.emptyList(), Collections.singletonList(entry)); + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + + verify(subchannel, never()).requestConnection(); + assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult()); + verify(subchannel).requestConnection(); + } + + @Test + public void roundRobinPickerWithIdleEntry_andDrop() { + GrpclbClientLoadRecorder loadRecorder = + new GrpclbClientLoadRecorder(fakeClock.getTimeProvider()); + // 1 out of 2 requests are to be dropped + DropEntry d = new DropEntry(loadRecorder, "LBTOKEN0003"); + List dropList = Arrays.asList(null, d); + + Subchannel subchannel = mock(Subchannel.class); + IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel); + + RoundRobinPicker picker = new RoundRobinPicker(dropList, Collections.singletonList(entry)); + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + + verify(subchannel, never()).requestConnection(); + assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult()); + verify(subchannel).requestConnection(); + + assertThat(picker.pickSubchannel(args)).isSameAs(DROP_PICK_RESULT); + + verify(subchannel).requestConnection(); + assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult()); + verify(subchannel, times(2)).requestConnection(); + } + + @Test + public void roundRobinPicker_requestConnection() { + // requestConnection() on RoundRobinPicker is only passed to IdleSubchannelEntry + + Subchannel subchannel1 = mock(Subchannel.class); + Subchannel subchannel2 = mock(Subchannel.class); + + RoundRobinPicker picker = new RoundRobinPicker( + Collections.emptyList(), + Arrays.asList( + new BackendEntry(subchannel1), new IdleSubchannelEntry(subchannel2), + new ErrorEntry(Status.UNAVAILABLE))); + + verify(subchannel2, never()).requestConnection(); + + picker.requestConnection(); + verify(subchannel2).requestConnection(); + verify(subchannel1, never()).requestConnection(); + } + @Test public void loadReporting() { Metadata headers = new Metadata(); @@ -1591,6 +1669,297 @@ public class GrpclbLoadBalancerTest { verify(helper, times(4)).refreshNameResolution(); } + @SuppressWarnings("unchecked") + @Test + public void grpclbWorking_pickFirstMode() throws Exception { + InOrder inOrder = inOrder(helper); + + String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}"; + List grpclbResolutionList = createResolvedServerAddresses(true); + Attributes grpclbResolutionAttrs = Attributes.newBuilder().set( + LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build(); + + deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs); + + assertEquals(1, fakeOobChannels.size()); + ManagedChannel oobChannel = fakeOobChannels.poll(); + verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture()); + StreamObserver lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + StreamObserver lbRequestObserver = lbRequestObservers.poll(); + verify(lbRequestObserver).onNext( + eq(LoadBalanceRequest.newBuilder().setInitialRequest( + InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) + .build())); + + // Simulate receiving LB response + List backends1 = Arrays.asList( + new ServerEntry("127.0.0.1", 2000, "token0001"), + new ServerEntry("127.0.0.1", 2010, "token0002")); + inOrder.verify(helper, never()) + .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); + lbResponseObserver.onNext(buildInitialResponse()); + lbResponseObserver.onNext(buildLbResponse(backends1)); + + inOrder.verify(helper).createSubchannel( + eq(Arrays.asList( + new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))), + any(Attributes.class)); + + // Initially IDLE + inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); + + // Only one subchannel is created + assertThat(mockSubchannels).hasSize(1); + Subchannel subchannel = mockSubchannels.poll(); + assertThat(picker0.dropList).containsExactly(null, null); + assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel)); + + // PICK_FIRST doesn't eagerly connect + verify(subchannel, never()).requestConnection(); + + // CONNECTING + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); + + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker1.dropList).containsExactly(null, null); + assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY); + + // TRANSIENT_FAILURE + Status error = Status.UNAVAILABLE.withDescription("Simulated connection error"); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker2.dropList).containsExactly(null, null); + assertThat(picker2.pickList).containsExactly(new ErrorEntry(error)); + + // READY + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker3.dropList).containsExactly(null, null); + assertThat(picker3.pickList).containsExactly( + new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + + + // New server list with drops + List backends2 = Arrays.asList( + new ServerEntry("127.0.0.1", 2000, "token0001"), + new ServerEntry("token0003"), // drop + new ServerEntry("127.0.0.1", 2020, "token0004")); + inOrder.verify(helper, never()) + .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); + lbResponseObserver.onNext(buildLbResponse(backends2)); + + // new addresses will be updated to the existing subchannel + // createSubchannel() has ever been called only once + verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class)); + assertThat(mockSubchannels).isEmpty(); + inOrder.verify(helper).updateSubchannelAddresses( + same(subchannel), + eq(Arrays.asList( + new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends2.get(2).addr, + eagAttrsWithToken("token0004"))))); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker4.dropList).containsExactly( + null, new DropEntry(getLoadRecorder(), "token0003"), null); + assertThat(picker4.pickList).containsExactly( + new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + + // Subchannel goes IDLE, but PICK_FIRST will not try to reconnect + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue(); + verify(subchannel, never()).requestConnection(); + + // ... until it's selected + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + PickResult pick = picker5.pickSubchannel(args); + assertThat(pick).isSameAs(PickResult.withNoResult()); + verify(subchannel).requestConnection(); + + // ... or requested by application + picker5.requestConnection(); + verify(subchannel, times(2)).requestConnection(); + + // PICK_FIRST doesn't use subchannelPool + verify(subchannelPool, never()) + .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void pickFirstMode_fallback() throws Exception { + InOrder inOrder = inOrder(helper); + + String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}"; + + // Name resolver returns a mix of balancer and backend addresses + List grpclbResolutionList = + createResolvedServerAddresses(false, true, false); + Attributes grpclbResolutionAttrs = Attributes.newBuilder().set( + LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build(); + deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs); + + // Attempted to connect to balancer + assertEquals(1, fakeOobChannels.size()); + ManagedChannel oobChannel = fakeOobChannels.poll(); + verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture()); + StreamObserver lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + + // Fallback timer expires with no response + fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + // Entering fallback mode + inOrder.verify(helper).createSubchannel( + eq(Arrays.asList(grpclbResolutionList.get(0), grpclbResolutionList.get(2))), + any(Attributes.class)); + + assertThat(mockSubchannels).hasSize(1); + Subchannel subchannel = mockSubchannels.poll(); + + // Initially IDLE + inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); + + // READY + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker1.dropList).containsExactly(null, null); + assertThat(picker1.pickList).containsExactly( + new BackendEntry(subchannel, new TokenAttachingTracerFactory(null))); + + assertThat(picker0.dropList).containsExactly(null, null); + assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel)); + + + // Finally, an LB response, which brings us out of fallback + List backends1 = Arrays.asList( + new ServerEntry("127.0.0.1", 2000, "token0001"), + new ServerEntry("127.0.0.1", 2010, "token0002")); + inOrder.verify(helper, never()) + .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); + lbResponseObserver.onNext(buildInitialResponse()); + lbResponseObserver.onNext(buildLbResponse(backends1)); + + // new addresses will be updated to the existing subchannel + // createSubchannel() has ever been called only once + verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class)); + assertThat(mockSubchannels).isEmpty(); + inOrder.verify(helper).updateSubchannelAddresses( + same(subchannel), + eq(Arrays.asList( + new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends1.get(1).addr, + eagAttrsWithToken("token0002"))))); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker2.dropList).containsExactly(null, null); + assertThat(picker2.pickList).containsExactly( + new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + + // PICK_FIRST doesn't use subchannelPool + verify(subchannelPool, never()) + .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); + } + + @Test + public void switchMode() throws Exception { + InOrder inOrder = inOrder(helper); + + String lbConfig = "{\"childPolicy\" : [ {\"round_robin\" : {}} ]}"; + List grpclbResolutionList = createResolvedServerAddresses(true); + Attributes grpclbResolutionAttrs = Attributes.newBuilder().set( + LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build(); + + deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs); + + assertEquals(1, fakeOobChannels.size()); + ManagedChannel oobChannel = fakeOobChannels.poll(); + verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture()); + StreamObserver lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + StreamObserver lbRequestObserver = lbRequestObservers.poll(); + verify(lbRequestObserver).onNext( + eq(LoadBalanceRequest.newBuilder().setInitialRequest( + InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) + .build())); + + // Simulate receiving LB response + List backends1 = Arrays.asList( + new ServerEntry("127.0.0.1", 2000, "token0001"), + new ServerEntry("127.0.0.1", 2010, "token0002")); + inOrder.verify(helper, never()) + .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); + lbResponseObserver.onNext(buildInitialResponse()); + lbResponseObserver.onNext(buildLbResponse(backends1)); + + // ROUND_ROBIN: create one subchannel per server + verify(subchannelPool).takeOrCreateSubchannel( + eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)), + any(Attributes.class)); + verify(subchannelPool).takeOrCreateSubchannel( + eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)), + any(Attributes.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + assertEquals(2, mockSubchannels.size()); + Subchannel subchannel1 = mockSubchannels.poll(); + Subchannel subchannel2 = mockSubchannels.poll(); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); + + // Switch to PICK_FIRST + lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}"; + grpclbResolutionAttrs = Attributes.newBuilder().set( + LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build(); + deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs); + + + // GrpclbState will be shutdown, and a new one will be created + assertThat(oobChannel.isShutdown()).isTrue(); + verify(subchannelPool).returnSubchannel(same(subchannel1)); + verify(subchannelPool).returnSubchannel(same(subchannel2)); + + // A new LB stream is created + assertEquals(1, fakeOobChannels.size()); + oobChannel = fakeOobChannels.poll(); + verify(mockLbService, times(2)).balanceLoad(lbResponseObserverCaptor.capture()); + lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + lbRequestObserver = lbRequestObservers.poll(); + verify(lbRequestObserver).onNext( + eq(LoadBalanceRequest.newBuilder().setInitialRequest( + InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) + .build())); + + // Simulate receiving LB response + inOrder.verify(helper, never()) + .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); + lbResponseObserver.onNext(buildInitialResponse()); + lbResponseObserver.onNext(buildLbResponse(backends1)); + + // PICK_FIRST Subchannel + inOrder.verify(helper).createSubchannel( + eq(Arrays.asList( + new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))), + any(Attributes.class)); + + inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + } + + private static Attributes eagAttrsWithToken(String token) { + return LB_BACKEND_ATTRS.toBuilder().set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, token).build(); + } + @Test public void retrieveModeFromLbConfig_pickFirst() throws Exception { String lbConfig = "{\"childPolicy\" : [{\"pick_first\" : {}}, {\"round_robin\" : {}}]}"; diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java new file mode 100644 index 0000000000..469372bc88 --- /dev/null +++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.internal.GrpcAttributes; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TokenAttachingTracerFactory}. */ +@RunWith(JUnit4.class) +public class TokenAttachingTracerFactoryTest { + private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {}; + + private final ClientStreamTracer.Factory delegate = mock( + ClientStreamTracer.Factory.class, + delegatesTo( + new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + return fakeTracer; + } + })); + + @Test + public void hasToken() { + TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); + ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() { + @Override + public Attributes getTransportAttrs() { + Attributes eagAttrs = Attributes.newBuilder() + .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build(); + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); + } + + @Override + public CallOptions getCallOptions() { + return CallOptions.DEFAULT; + } + }; + Metadata headers = new Metadata(); + // Preexisting token should be replaced + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token"); + + ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + verify(delegate).newClientStreamTracer(same(info), same(headers)); + assertThat(tracer).isSameAs(fakeTracer); + assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001"); + } + + @Test + public void noToken() { + TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); + ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() { + @Override + public Attributes getTransportAttrs() { + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(); + } + + @Override + public CallOptions getCallOptions() { + return CallOptions.DEFAULT; + } + }; + + Metadata headers = new Metadata(); + // Preexisting token should be removed + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token"); + + ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + verify(delegate).newClientStreamTracer(same(info), same(headers)); + assertThat(tracer).isSameAs(fakeTracer); + assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); + } + + @Test + public void nullDelegate() { + TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null); + ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() { + @Override + public Attributes getTransportAttrs() { + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(); + } + + @Override + public CallOptions getCallOptions() { + return CallOptions.DEFAULT; + } + }; + Metadata headers = new Metadata(); + + ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + assertThat(tracer).isNotNull(); + assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); + } +}