grpclb: support "pick_first" child policy (#5438)

The PICK_FIRST mode puts all backend addresses in a single Subchannel. There are a few points where it's different from the default ROUND_ROBIN mode:

1. PICK_FIRST doesn't eagerly connect to backends like ROUND_ROBIN does. Instead, it requests for connections when the Subchannel is picked.

2. PICK_FIRST adds tokens to the headers via a different code path (`TokenAttachingTracerFactory`) than ROUND_ROBIN

3. For simple implementation, when the mode is changed by service config when the LoadBalancer is working, we will shut down `GrpclbState` and starts a new one with the new mode. All connections will be closed during the transition. We don't expect this to happen in practice given the specific use case of PICK_FIRST.
This commit is contained in:
Kun Zhang 2019-03-06 13:02:32 -08:00 committed by GitHub
parent 128409000a
commit 2f50d88678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 817 additions and 91 deletions

View File

@ -16,6 +16,8 @@
package io.grpc.grpclb;
import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.Metadata;
@ -32,5 +34,12 @@ public final class GrpclbConstants {
public static final Metadata.Key<String> TOKEN_METADATA_KEY =
Metadata.Key.of("lb-token", Metadata.ASCII_STRING_MARSHALLER);
/**
* For passing LB tokens via the EAG attributes.
*/
@EquivalentAddressGroup.Attr
static final Attributes.Key<String> TOKEN_ATTRIBUTE_KEY =
Attributes.Key.create("lb-token");
private GrpclbConstants() { }
}

View File

@ -17,6 +17,7 @@
package io.grpc.grpclb;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Attributes;
@ -51,7 +52,11 @@ class GrpclbLoadBalancer extends LoadBalancer {
private static final Logger logger = Logger.getLogger(GrpclbLoadBalancer.class.getName());
private final Helper helper;
private final TimeProvider time;
private final SubchannelPool subchannelPool;
private final BackoffPolicy.Provider backoffPolicyProvider;
private Mode mode = Mode.ROUND_ROBIN;
// All mutable states in this class are mutated ONLY from Channel Executor
@Nullable
@ -63,12 +68,12 @@ class GrpclbLoadBalancer extends LoadBalancer {
TimeProvider time,
BackoffPolicy.Provider backoffPolicyProvider) {
this.helper = checkNotNull(helper, "helper");
checkNotNull(time, "time provider");
checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.time = checkNotNull(time, "time provider");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
this.subchannelPool.init(helper);
grpclbState =
new GrpclbState(helper, subchannelPool, time, backoffPolicyProvider);
recreateStates();
checkNotNull(grpclbState, "grpclbState");
}
@Override
@ -97,7 +102,12 @@ class GrpclbLoadBalancer extends LoadBalancer {
newBackendServers = Collections.unmodifiableList(newBackendServers);
Map<String, Object> rawLbConfigValue = attributes.get(ATTR_LOAD_BALANCING_CONFIG);
Mode newMode = retrieveModeFromLbConfig(rawLbConfigValue, helper.getChannelLogger());
grpclbState.handleAddresses(newLbAddressGroups, newBackendServers, newMode);
if (!mode.equals(newMode)) {
mode = newMode;
helper.getChannelLogger().log(ChannelLogLevel.INFO, "Mode: " + newMode);
recreateStates();
}
grpclbState.handleAddresses(newLbAddressGroups, newBackendServers);
}
@VisibleForTesting
@ -141,6 +151,12 @@ class GrpclbLoadBalancer extends LoadBalancer {
}
}
private void recreateStates() {
resetStates();
checkState(grpclbState == null, "Should've been cleared");
grpclbState = new GrpclbState(mode, helper, subchannelPool, time, backoffPolicyProvider);
}
@Override
public void shutdown() {
resetStates();

View File

@ -138,8 +138,8 @@ final class GrpclbState {
@Nullable
private LbStream lbStream;
private Map<EquivalentAddressGroup, Subchannel> subchannels = Collections.emptyMap();
private Mode mode;
private Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Collections.emptyMap();
private final Mode mode;
// Has the same size as the round-robin list from the balancer.
// A drop entry from the round-robin list becomes a DropEntry here.
@ -151,10 +151,12 @@ final class GrpclbState {
new RoundRobinPicker(Collections.<DropEntry>emptyList(), Arrays.asList(BUFFER_ENTRY));
GrpclbState(
Mode mode,
Helper helper,
SubchannelPool subchannelPool,
TimeProvider time,
BackoffPolicy.Provider backoffPolicyProvider) {
this.mode = checkNotNull(mode, "mode");
this.helper = checkNotNull(helper, "helper");
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
@ -169,7 +171,7 @@ final class GrpclbState {
if (newState.getState() == SHUTDOWN || !subchannels.values().contains(subchannel)) {
return;
}
if (newState.getState() == IDLE) {
if (mode == Mode.ROUND_ROBIN && newState.getState() == IDLE) {
subchannel.requestConnection();
}
subchannel.getAttributes().get(STATE_INFO).set(newState);
@ -182,8 +184,7 @@ final class GrpclbState {
* not yet connected.
*/
void handleAddresses(
List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers,
Mode mode) {
List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers) {
if (newLbAddressGroups.isEmpty()) {
propagateError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no LB address while asking for GRPCLB"));
@ -305,10 +306,20 @@ final class GrpclbState {
void shutdown() {
shutdownLbComm();
// We close the subchannels through subchannelPool instead of helper just for convenience of
// testing.
for (Subchannel subchannel : subchannels.values()) {
subchannelPool.returnSubchannel(subchannel);
switch (mode) {
case ROUND_ROBIN:
// We close the subchannels through subchannelPool instead of helper just for convenience of
// testing.
for (Subchannel subchannel : subchannels.values()) {
subchannelPool.returnSubchannel(subchannel);
}
break;
case PICK_FIRST:
checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels);
subchannels.values().iterator().next().shutdown();
break;
default:
throw new AssertionError("Missing case for " + mode);
}
subchannels = Collections.emptyMap();
subchannelPool.clear();
@ -341,45 +352,74 @@ final class GrpclbState {
@Nullable GrpclbClientLoadRecorder loadRecorder) {
logger.log(
ChannelLogLevel.INFO, "Using RR list={0}, drop={1}", newBackendAddrList, newDropList);
HashMap<EquivalentAddressGroup, Subchannel> newSubchannelMap =
HashMap<List<EquivalentAddressGroup>, Subchannel> newSubchannelMap =
new HashMap<>();
List<BackendEntry> newBackendList = new ArrayList<>();
for (BackendAddressGroup backendAddr : newBackendAddrList) {
EquivalentAddressGroup eag = backendAddr.getAddresses();
Subchannel subchannel = newSubchannelMap.get(eag);
if (subchannel == null) {
subchannel = subchannels.get(eag);
if (subchannel == null) {
Attributes subchannelAttrs = Attributes.newBuilder()
.set(STATE_INFO,
new AtomicReference<>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
subchannel = subchannelPool.takeOrCreateSubchannel(eag, subchannelAttrs);
subchannel.requestConnection();
switch (mode) {
case ROUND_ROBIN:
for (BackendAddressGroup backendAddr : newBackendAddrList) {
EquivalentAddressGroup eag = backendAddr.getAddresses();
List<EquivalentAddressGroup> eagAsList = Collections.singletonList(eag);
Subchannel subchannel = newSubchannelMap.get(eagAsList);
if (subchannel == null) {
subchannel = subchannels.get(eagAsList);
if (subchannel == null) {
subchannel = subchannelPool.takeOrCreateSubchannel(eag, createSubchannelAttrs());
subchannel.requestConnection();
}
newSubchannelMap.put(eagAsList, subchannel);
}
BackendEntry entry;
// Only picks with tokens are reported to LoadRecorder
if (backendAddr.getToken() == null) {
entry = new BackendEntry(subchannel);
} else {
entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
}
newBackendList.add(entry);
}
newSubchannelMap.put(eag, subchannel);
}
BackendEntry entry;
// Only picks with tokens are reported to LoadRecorder
if (backendAddr.getToken() == null) {
entry = new BackendEntry(subchannel);
} else {
entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
}
newBackendList.add(entry);
// Close Subchannels whose addresses have been delisted
for (Entry<List<EquivalentAddressGroup>, Subchannel> entry : subchannels.entrySet()) {
List<EquivalentAddressGroup> eagList = entry.getKey();
if (!newSubchannelMap.containsKey(eagList)) {
subchannelPool.returnSubchannel(entry.getValue());
}
}
subchannels = Collections.unmodifiableMap(newSubchannelMap);
break;
case PICK_FIRST:
List<EquivalentAddressGroup> eagList = new ArrayList<>();
// Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to
// attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on
// headers.
//
// The PICK_FIRST code path doesn't cache Subchannels.
for (BackendAddressGroup bag : newBackendAddrList) {
EquivalentAddressGroup origEag = bag.getAddresses();
Attributes eagAttrs = origEag.getAttributes();
if (bag.getToken() != null) {
eagAttrs = eagAttrs.toBuilder()
.set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, bag.getToken()).build();
}
eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs));
}
Subchannel subchannel;
if (subchannels.isEmpty()) {
subchannel = helper.createSubchannel(eagList, createSubchannelAttrs());
} else {
checkState(subchannels.size() == 1, "Unexpected Subchannel count: %s", subchannels);
subchannel = subchannels.values().iterator().next();
helper.updateSubchannelAddresses(subchannel, eagList);
}
subchannels = Collections.singletonMap(eagList, subchannel);
newBackendList.add(
new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder)));
break;
default:
throw new AssertionError("Missing case for " + mode);
}
// Close Subchannels whose addresses have been delisted
for (Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) {
EquivalentAddressGroup eag = entry.getKey();
if (!newSubchannelMap.containsKey(eag)) {
subchannelPool.returnSubchannel(entry.getValue());
}
}
subchannels = Collections.unmodifiableMap(newSubchannelMap);
dropList = Collections.unmodifiableList(newDropList);
backendList = Collections.unmodifiableList(newBackendList);
}
@ -619,32 +659,67 @@ final class GrpclbState {
* changed since the last picker created.
*/
private void maybeUpdatePicker() {
List<RoundRobinEntry> pickList = new ArrayList<>(backendList.size());
Status error = null;
boolean hasIdle = false;
for (BackendEntry entry : backendList) {
Subchannel subchannel = entry.result.getSubchannel();
Attributes attrs = subchannel.getAttributes();
ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
if (stateInfo.getState() == READY) {
pickList.add(entry);
} else if (stateInfo.getState() == TRANSIENT_FAILURE) {
error = stateInfo.getStatus();
} else if (stateInfo.getState() == IDLE) {
hasIdle = true;
}
}
List<RoundRobinEntry> pickList;
ConnectivityState state;
if (pickList.isEmpty()) {
if (error != null && !hasIdle) {
pickList.add(new ErrorEntry(error));
state = TRANSIENT_FAILURE;
} else {
pickList.add(BUFFER_ENTRY);
state = CONNECTING;
}
} else {
state = READY;
switch (mode) {
case ROUND_ROBIN:
pickList = new ArrayList<>(backendList.size());
Status error = null;
boolean hasIdle = false;
for (BackendEntry entry : backendList) {
Subchannel subchannel = entry.subchannel;
Attributes attrs = subchannel.getAttributes();
ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
if (stateInfo.getState() == READY) {
pickList.add(entry);
} else if (stateInfo.getState() == TRANSIENT_FAILURE) {
error = stateInfo.getStatus();
} else if (stateInfo.getState() == IDLE) {
hasIdle = true;
}
}
if (pickList.isEmpty()) {
if (error != null && !hasIdle) {
pickList.add(new ErrorEntry(error));
state = TRANSIENT_FAILURE;
} else {
pickList.add(BUFFER_ENTRY);
state = CONNECTING;
}
} else {
state = READY;
}
break;
case PICK_FIRST:
if (backendList.isEmpty()) {
pickList = Collections.singletonList(BUFFER_ENTRY);
// Have not received server addresses
state = CONNECTING;
} else {
checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList);
BackendEntry onlyEntry = backendList.get(0);
ConnectivityStateInfo stateInfo =
onlyEntry.subchannel.getAttributes().get(STATE_INFO).get();
state = stateInfo.getState();
switch (state) {
case READY:
pickList = Collections.<RoundRobinEntry>singletonList(onlyEntry);
break;
case TRANSIENT_FAILURE:
pickList =
Collections.<RoundRobinEntry>singletonList(new ErrorEntry(stateInfo.getStatus()));
break;
case CONNECTING:
pickList = Collections.singletonList(BUFFER_ENTRY);
break;
default:
pickList = Collections.<RoundRobinEntry>singletonList(
new IdleSubchannelEntry(onlyEntry.subchannel));
}
}
break;
default:
throw new AssertionError("Missing case for " + mode);
}
maybeUpdatePicker(state, new RoundRobinPicker(dropList, pickList));
}
@ -704,6 +779,14 @@ final class GrpclbState {
return new EquivalentAddressGroup(addrs, attrs);
}
private static Attributes createSubchannelAttrs() {
return Attributes.newBuilder()
.set(STATE_INFO,
new AtomicReference<>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
}
@VisibleForTesting
static final class DropEntry {
private final GrpclbClientLoadRecorder loadRecorder;
@ -740,34 +823,45 @@ final class GrpclbState {
}
}
private interface RoundRobinEntry {
@VisibleForTesting
interface RoundRobinEntry {
PickResult picked(Metadata headers);
}
@VisibleForTesting
static final class BackendEntry implements RoundRobinEntry {
final Subchannel subchannel;
@VisibleForTesting
final PickResult result;
@Nullable
private final GrpclbClientLoadRecorder loadRecorder;
@Nullable
private final String token;
/**
* Creates a BackendEntry whose usage will be reported to load recorder.
* For ROUND_ROBIN: creates a BackendEntry whose usage will be reported to load recorder.
*/
BackendEntry(Subchannel subchannel, GrpclbClientLoadRecorder loadRecorder, String token) {
this.result = PickResult.withSubchannel(subchannel, loadRecorder);
this.loadRecorder = checkNotNull(loadRecorder, "loadRecorder");
this.subchannel = checkNotNull(subchannel, "subchannel");
this.result =
PickResult.withSubchannel(subchannel, checkNotNull(loadRecorder, "loadRecorder"));
this.token = checkNotNull(token, "token");
}
/**
* Creates a BackendEntry whose usage will not be reported.
* For ROUND_ROBIN/PICK_FIRST: creates a BackendEntry whose usage will not be reported.
*/
BackendEntry(Subchannel subchannel) {
this.subchannel = checkNotNull(subchannel, "subchannel");
this.result = PickResult.withSubchannel(subchannel);
this.loadRecorder = null;
this.token = null;
}
/**
* For PICK_FIRST: creates a BackendEntry that includes all addresses.
*/
BackendEntry(Subchannel subchannel, TokenAttachingTracerFactory tracerFactory) {
this.subchannel = checkNotNull(subchannel, "subchannel");
this.result =
PickResult.withSubchannel(subchannel, checkNotNull(tracerFactory, "tracerFactory"));
this.token = null;
}
@ -783,12 +877,12 @@ final class GrpclbState {
@Override
public String toString() {
// This is printed in logs. Only give out useful information.
return "[" + result.getSubchannel().getAllAddresses().toString() + "(" + token + ")]";
return "[" + subchannel.getAllAddresses().toString() + "(" + token + ")]";
}
@Override
public int hashCode() {
return Objects.hashCode(loadRecorder, result, token);
return Objects.hashCode(result, token);
}
@Override
@ -797,8 +891,42 @@ final class GrpclbState {
return false;
}
BackendEntry that = (BackendEntry) other;
return Objects.equal(result, that.result) && Objects.equal(token, that.token)
&& Objects.equal(loadRecorder, that.loadRecorder);
return Objects.equal(result, that.result) && Objects.equal(token, that.token);
}
}
@VisibleForTesting
static final class IdleSubchannelEntry implements RoundRobinEntry {
private final Subchannel subchannel;
IdleSubchannelEntry(Subchannel subchannel) {
this.subchannel = checkNotNull(subchannel, "subchannel");
}
@Override
public PickResult picked(Metadata headers) {
subchannel.requestConnection();
return PickResult.withNoResult();
}
@Override
public String toString() {
// This is printed in logs. Only give out useful information.
return "(idle)[" + subchannel.getAllAddresses().toString() + "]";
}
@Override
public int hashCode() {
return Objects.hashCode(subchannel);
}
@Override
public boolean equals(Object other) {
if (!(other instanceof IdleSubchannelEntry)) {
return false;
}
IdleSubchannelEntry that = (IdleSubchannelEntry) other;
return Objects.equal(subchannel, that.subchannel);
}
}
@ -860,8 +988,7 @@ final class GrpclbState {
// First round-robin on dropList. If a drop entry is selected, request will be dropped. If
// a non-drop entry is selected, then round-robin on pickList. This makes sure requests are
// dropped at the same proportion as the drop entries appear on the round-robin list from
// the balancer, while only READY backends (that make up pickList) are selected for the
// non-drop cases.
// the balancer, while only backends from pickList are selected for the non-drop cases.
if (!dropList.isEmpty()) {
DropEntry drop = dropList.get(dropIndex);
dropIndex++;
@ -881,5 +1008,14 @@ final class GrpclbState {
return pick.picked(args.getHeaders());
}
}
@Override
public void requestConnection() {
for (RoundRobinEntry entry : pickList) {
if (entry instanceof IdleSubchannelEntry) {
((IdleSubchannelEntry) entry).subchannel.requestConnection();
}
}
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.grpclb;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Objects;
import io.grpc.Attributes;
import io.grpc.ClientStreamTracer;
import io.grpc.Metadata;
import io.grpc.internal.GrpcAttributes;
import javax.annotation.Nullable;
/**
* Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and
* attaches them to headers. This is only used in the PICK_FIRST mode.
*/
final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory {
private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {};
@Nullable
private final ClientStreamTracer.Factory delegate;
TokenAttachingTracerFactory(@Nullable ClientStreamTracer.Factory delegate) {
this.delegate = delegate;
}
@Override
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs");
Attributes eagAttrs =
checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs");
String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY);
headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY);
if (token != null) {
headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token);
}
if (delegate != null) {
return delegate.newClientStreamTracer(info, headers);
} else {
return NOOP_TRACER;
}
}
@Override
public int hashCode() {
return Objects.hashCode(delegate);
}
@Override
public boolean equals(Object other) {
if (!(other instanceof TokenAttachingTracerFactory)) {
return false;
}
return Objects.equal(delegate, ((TokenAttachingTracerFactory) other).delegate);
}
}

View File

@ -56,6 +56,7 @@ import io.grpc.ClientStreamTracer;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
@ -68,7 +69,9 @@ import io.grpc.SynchronizationContext;
import io.grpc.grpclb.GrpclbState.BackendEntry;
import io.grpc.grpclb.GrpclbState.DropEntry;
import io.grpc.grpclb.GrpclbState.ErrorEntry;
import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry;
import io.grpc.grpclb.GrpclbState.Mode;
import io.grpc.grpclb.GrpclbState.RoundRobinEntry;
import io.grpc.grpclb.GrpclbState.RoundRobinPicker;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
@ -166,7 +169,8 @@ public class GrpclbLoadBalancerTest {
new LinkedList<>();
private final LinkedList<Subchannel> mockSubchannels = new LinkedList<>();
private final LinkedList<ManagedChannel> fakeOobChannels = new LinkedList<>();
private final ArrayList<Subchannel> subchannelTracker = new ArrayList<>();
private final ArrayList<Subchannel> pooledSubchannelTracker = new ArrayList<>();
private final ArrayList<Subchannel> unpooledSubchannelTracker = new ArrayList<>();
private final ArrayList<ManagedChannel> oobChannelTracker = new ArrayList<>();
private final ArrayList<String> failingLbAuthorities = new ArrayList<>();
private final SynchronizationContext syncContext = new SynchronizationContext(
@ -251,11 +255,25 @@ public class GrpclbLoadBalancerTest {
when(subchannel.getAllAddresses()).thenReturn(Arrays.asList(eag));
when(subchannel.getAttributes()).thenReturn(attrs);
mockSubchannels.add(subchannel);
subchannelTracker.add(subchannel);
pooledSubchannelTracker.add(subchannel);
return subchannel;
}
}).when(subchannelPool).takeOrCreateSubchannel(
any(EquivalentAddressGroup.class), any(Attributes.class));
doAnswer(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Subchannel subchannel = mock(Subchannel.class);
List<EquivalentAddressGroup> eagList =
(List<EquivalentAddressGroup>) invocation.getArguments()[0];
Attributes attrs = (Attributes) invocation.getArguments()[1];
when(subchannel.getAllAddresses()).thenReturn(eagList);
when(subchannel.getAttributes()).thenReturn(attrs);
mockSubchannels.add(subchannel);
unpooledSubchannelTracker.add(subchannel);
return subchannel;
}
}).when(helper).createSubchannel(any(List.class), any(Attributes.class));
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
when(helper.getChannelLogger()).thenReturn(channelLogger);
@ -294,14 +312,15 @@ public class GrpclbLoadBalancerTest {
assertTrue(channel + " is terminated", channel.isTerminated());
}
// GRPCLB manages subchannels only through subchannelPool
for (Subchannel subchannel: subchannelTracker) {
for (Subchannel subchannel : pooledSubchannelTracker) {
verify(subchannelPool).returnSubchannel(same(subchannel));
// Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if
// LoadBalancer has called it expectedly.
verify(subchannel, never()).shutdown();
}
verify(helper, never())
.createSubchannel(any(List.class), any(Attributes.class));
for (Subchannel subchannel : unpooledSubchannelTracker) {
verify(subchannel).shutdown();
}
// No timer should linger after shutdown
assertThat(fakeClock.getPendingTasks()).isEmpty();
} finally {
@ -406,6 +425,65 @@ public class GrpclbLoadBalancerTest {
verify(subchannel, never()).getAttributes();
}
@Test
public void roundRobinPickerWithIdleEntry_noDrop() {
Subchannel subchannel = mock(Subchannel.class);
IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
RoundRobinPicker picker =
new RoundRobinPicker(Collections.<DropEntry>emptyList(), Collections.singletonList(entry));
PickSubchannelArgs args = mock(PickSubchannelArgs.class);
verify(subchannel, never()).requestConnection();
assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
verify(subchannel).requestConnection();
}
@Test
public void roundRobinPickerWithIdleEntry_andDrop() {
GrpclbClientLoadRecorder loadRecorder =
new GrpclbClientLoadRecorder(fakeClock.getTimeProvider());
// 1 out of 2 requests are to be dropped
DropEntry d = new DropEntry(loadRecorder, "LBTOKEN0003");
List<DropEntry> dropList = Arrays.asList(null, d);
Subchannel subchannel = mock(Subchannel.class);
IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
RoundRobinPicker picker = new RoundRobinPicker(dropList, Collections.singletonList(entry));
PickSubchannelArgs args = mock(PickSubchannelArgs.class);
verify(subchannel, never()).requestConnection();
assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
verify(subchannel).requestConnection();
assertThat(picker.pickSubchannel(args)).isSameAs(DROP_PICK_RESULT);
verify(subchannel).requestConnection();
assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
verify(subchannel, times(2)).requestConnection();
}
@Test
public void roundRobinPicker_requestConnection() {
// requestConnection() on RoundRobinPicker is only passed to IdleSubchannelEntry
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);
RoundRobinPicker picker = new RoundRobinPicker(
Collections.<DropEntry>emptyList(),
Arrays.<RoundRobinEntry>asList(
new BackendEntry(subchannel1), new IdleSubchannelEntry(subchannel2),
new ErrorEntry(Status.UNAVAILABLE)));
verify(subchannel2, never()).requestConnection();
picker.requestConnection();
verify(subchannel2).requestConnection();
verify(subchannel1, never()).requestConnection();
}
@Test
public void loadReporting() {
Metadata headers = new Metadata();
@ -1591,6 +1669,297 @@ public class GrpclbLoadBalancerTest {
verify(helper, times(4)).refreshNameResolution();
}
@SuppressWarnings("unchecked")
@Test
public void grpclbWorking_pickFirstMode() throws Exception {
InOrder inOrder = inOrder(helper);
String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
assertEquals(1, fakeOobChannels.size());
ManagedChannel oobChannel = fakeOobChannels.poll();
verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
assertEquals(1, lbRequestObservers.size());
StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
verify(lbRequestObserver).onNext(
eq(LoadBalanceRequest.newBuilder().setInitialRequest(
InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
.build()));
// Simulate receiving LB response
List<ServerEntry> backends1 = Arrays.asList(
new ServerEntry("127.0.0.1", 2000, "token0001"),
new ServerEntry("127.0.0.1", 2010, "token0002"));
inOrder.verify(helper, never())
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
// Initially IDLE
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
// Only one subchannel is created
assertThat(mockSubchannels).hasSize(1);
Subchannel subchannel = mockSubchannels.poll();
assertThat(picker0.dropList).containsExactly(null, null);
assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
// PICK_FIRST doesn't eagerly connect
verify(subchannel, never()).requestConnection();
// CONNECTING
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker1.dropList).containsExactly(null, null);
assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY);
// TRANSIENT_FAILURE
Status error = Status.UNAVAILABLE.withDescription("Simulated connection error");
deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker2.dropList).containsExactly(null, null);
assertThat(picker2.pickList).containsExactly(new ErrorEntry(error));
// READY
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker3.dropList).containsExactly(null, null);
assertThat(picker3.pickList).containsExactly(
new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
// New server list with drops
List<ServerEntry> backends2 = Arrays.asList(
new ServerEntry("127.0.0.1", 2000, "token0001"),
new ServerEntry("token0003"), // drop
new ServerEntry("127.0.0.1", 2020, "token0004"));
inOrder.verify(helper, never())
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
lbResponseObserver.onNext(buildLbResponse(backends2));
// new addresses will be updated to the existing subchannel
// createSubchannel() has ever been called only once
verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
assertThat(mockSubchannels).isEmpty();
inOrder.verify(helper).updateSubchannelAddresses(
same(subchannel),
eq(Arrays.asList(
new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends2.get(2).addr,
eagAttrsWithToken("token0004")))));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker4.dropList).containsExactly(
null, new DropEntry(getLoadRecorder(), "token0003"), null);
assertThat(picker4.pickList).containsExactly(
new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
// Subchannel goes IDLE, but PICK_FIRST will not try to reconnect
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue();
verify(subchannel, never()).requestConnection();
// ... until it's selected
PickSubchannelArgs args = mock(PickSubchannelArgs.class);
PickResult pick = picker5.pickSubchannel(args);
assertThat(pick).isSameAs(PickResult.withNoResult());
verify(subchannel).requestConnection();
// ... or requested by application
picker5.requestConnection();
verify(subchannel, times(2)).requestConnection();
// PICK_FIRST doesn't use subchannelPool
verify(subchannelPool, never())
.takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
}
@SuppressWarnings("unchecked")
@Test
public void pickFirstMode_fallback() throws Exception {
InOrder inOrder = inOrder(helper);
String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
// Name resolver returns a mix of balancer and backend addresses
List<EquivalentAddressGroup> grpclbResolutionList =
createResolvedServerAddresses(false, true, false);
Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
// Attempted to connect to balancer
assertEquals(1, fakeOobChannels.size());
ManagedChannel oobChannel = fakeOobChannels.poll();
verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
assertEquals(1, lbRequestObservers.size());
// Fallback timer expires with no response
fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
// Entering fallback mode
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(grpclbResolutionList.get(0), grpclbResolutionList.get(2))),
any(Attributes.class));
assertThat(mockSubchannels).hasSize(1);
Subchannel subchannel = mockSubchannels.poll();
// Initially IDLE
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
// READY
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker1.dropList).containsExactly(null, null);
assertThat(picker1.pickList).containsExactly(
new BackendEntry(subchannel, new TokenAttachingTracerFactory(null)));
assertThat(picker0.dropList).containsExactly(null, null);
assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
// Finally, an LB response, which brings us out of fallback
List<ServerEntry> backends1 = Arrays.asList(
new ServerEntry("127.0.0.1", 2000, "token0001"),
new ServerEntry("127.0.0.1", 2010, "token0002"));
inOrder.verify(helper, never())
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
// new addresses will be updated to the existing subchannel
// createSubchannel() has ever been called only once
verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
assertThat(mockSubchannels).isEmpty();
inOrder.verify(helper).updateSubchannelAddresses(
same(subchannel),
eq(Arrays.asList(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr,
eagAttrsWithToken("token0002")))));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker2.dropList).containsExactly(null, null);
assertThat(picker2.pickList).containsExactly(
new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
// PICK_FIRST doesn't use subchannelPool
verify(subchannelPool, never())
.takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
}
@Test
public void switchMode() throws Exception {
InOrder inOrder = inOrder(helper);
String lbConfig = "{\"childPolicy\" : [ {\"round_robin\" : {}} ]}";
List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
assertEquals(1, fakeOobChannels.size());
ManagedChannel oobChannel = fakeOobChannels.poll();
verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
assertEquals(1, lbRequestObservers.size());
StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
verify(lbRequestObserver).onNext(
eq(LoadBalanceRequest.newBuilder().setInitialRequest(
InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
.build()));
// Simulate receiving LB response
List<ServerEntry> backends1 = Arrays.asList(
new ServerEntry("127.0.0.1", 2000, "token0001"),
new ServerEntry("127.0.0.1", 2010, "token0002"));
inOrder.verify(helper, never())
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
// ROUND_ROBIN: create one subchannel per server
verify(subchannelPool).takeOrCreateSubchannel(
eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)),
any(Attributes.class));
verify(subchannelPool).takeOrCreateSubchannel(
eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)),
any(Attributes.class));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
assertEquals(2, mockSubchannels.size());
Subchannel subchannel1 = mockSubchannels.poll();
Subchannel subchannel2 = mockSubchannels.poll();
verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
// Switch to PICK_FIRST
lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
grpclbResolutionAttrs = Attributes.newBuilder().set(
LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
// GrpclbState will be shutdown, and a new one will be created
assertThat(oobChannel.isShutdown()).isTrue();
verify(subchannelPool).returnSubchannel(same(subchannel1));
verify(subchannelPool).returnSubchannel(same(subchannel2));
// A new LB stream is created
assertEquals(1, fakeOobChannels.size());
oobChannel = fakeOobChannels.poll();
verify(mockLbService, times(2)).balanceLoad(lbResponseObserverCaptor.capture());
lbResponseObserver = lbResponseObserverCaptor.getValue();
assertEquals(1, lbRequestObservers.size());
lbRequestObserver = lbRequestObservers.poll();
verify(lbRequestObserver).onNext(
eq(LoadBalanceRequest.newBuilder().setInitialRequest(
InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
.build()));
// Simulate receiving LB response
inOrder.verify(helper, never())
.updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
// PICK_FIRST Subchannel
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
}
private static Attributes eagAttrsWithToken(String token) {
return LB_BACKEND_ATTRS.toBuilder().set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, token).build();
}
@Test
public void retrieveModeFromLbConfig_pickFirst() throws Exception {
String lbConfig = "{\"childPolicy\" : [{\"pick_first\" : {}}, {\"round_robin\" : {}}]}";

View File

@ -0,0 +1,124 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.grpclb;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.Metadata;
import io.grpc.internal.GrpcAttributes;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link TokenAttachingTracerFactory}. */
@RunWith(JUnit4.class)
public class TokenAttachingTracerFactoryTest {
private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {};
private final ClientStreamTracer.Factory delegate = mock(
ClientStreamTracer.Factory.class,
delegatesTo(
new ClientStreamTracer.Factory() {
@Override
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
return fakeTracer;
}
}));
@Test
public void hasToken() {
TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
Attributes eagAttrs = Attributes.newBuilder()
.set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build();
return Attributes.newBuilder()
.set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build();
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
Metadata headers = new Metadata();
// Preexisting token should be replaced
headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
verify(delegate).newClientStreamTracer(same(info), same(headers));
assertThat(tracer).isSameAs(fakeTracer);
assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001");
}
@Test
public void noToken() {
TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return Attributes.newBuilder()
.set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
Metadata headers = new Metadata();
// Preexisting token should be removed
headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
verify(delegate).newClientStreamTracer(same(info), same(headers));
assertThat(tracer).isSameAs(fakeTracer);
assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
}
@Test
public void nullDelegate() {
TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null);
ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return Attributes.newBuilder()
.set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
Metadata headers = new Metadata();
ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
assertThat(tracer).isNotNull();
assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
}
}