grpclb: CachedSubchannelPool use new create subchannel (#6831)

This commit is contained in:
Jihun Cho 2020-03-31 13:31:04 -07:00 committed by GitHub
parent ae211a1ba8
commit 6dbdfcdbbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 299 additions and 283 deletions

View File

@ -23,9 +23,10 @@ import com.google.common.annotations.VisibleForTesting;
import io.grpc.Attributes;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.SynchronizationContext.ScheduledHandle;
import java.util.HashMap;
import java.util.concurrent.TimeUnit;
@ -38,28 +39,40 @@ final class CachedSubchannelPool implements SubchannelPool {
private final HashMap<EquivalentAddressGroup, CacheEntry> cache =
new HashMap<>();
private Helper helper;
private LoadBalancer lb;
private final Helper helper;
private PooledSubchannelStateListener listener;
@VisibleForTesting
static final long SHUTDOWN_TIMEOUT_MS = 10000;
@Override
public void init(Helper helper, LoadBalancer lb) {
public CachedSubchannelPool(Helper helper) {
this.helper = checkNotNull(helper, "helper");
this.lb = checkNotNull(lb, "lb");
}
@Override
@SuppressWarnings("deprecation")
public void registerListener(PooledSubchannelStateListener listener) {
this.listener = checkNotNull(listener, "listener");
}
@Override
public Subchannel takeOrCreateSubchannel(
EquivalentAddressGroup eag, Attributes defaultAttributes) {
final CacheEntry entry = cache.remove(eag);
final Subchannel subchannel;
if (entry == null) {
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to the
// new createSubchannel().
subchannel = helper.createSubchannel(eag, defaultAttributes);
subchannel =
helper.createSubchannel(
CreateSubchannelArgs.newBuilder()
.setAddresses(eag)
.setAttributes(defaultAttributes)
.build());
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
updateCachedSubchannelState(subchannel, newState);
listener.onSubchannelState(subchannel, newState);
}
});
} else {
subchannel = entry.subchannel;
entry.shutdownTimer.cancel();
@ -68,15 +81,15 @@ final class CachedSubchannelPool implements SubchannelPool {
helper.getSynchronizationContext().execute(new Runnable() {
@Override
public void run() {
lb.handleSubchannelState(subchannel, entry.state);
listener.onSubchannelState(subchannel, entry.state);
}
});
}
return subchannel;
}
@Override
public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newStateInfo) {
private void updateCachedSubchannelState(
Subchannel subchannel, ConnectivityStateInfo newStateInfo) {
CacheEntry cached = cache.get(subchannel.getAddresses());
if (cached == null || cached.subchannel != subchannel) {
// Given subchannel is not cached. Not our responsibility.

View File

@ -23,7 +23,6 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch;
import io.grpc.Attributes;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.Status;
@ -68,19 +67,10 @@ class GrpclbLoadBalancer extends LoadBalancer {
this.stopwatch = checkNotNull(stopwatch, "stopwatch");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
this.subchannelPool.init(helper, this);
recreateStates();
checkNotNull(grpclbState, "grpclbState");
}
@Deprecated
@Override
public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
// grpclbState should never be null here since handleSubchannelState cannot be called while the
// lb is shutdown.
grpclbState.handleSubchannelState(subchannel, newState);
}
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
Attributes attributes = resolvedAddresses.getAttributes();
@ -137,7 +127,8 @@ class GrpclbLoadBalancer extends LoadBalancer {
resetStates();
checkState(grpclbState == null, "Should've been cleared");
grpclbState =
new GrpclbState(config, helper, subchannelPool, time, stopwatch, backoffPolicyProvider);
new GrpclbState(
config, helper, subchannelPool, time, stopwatch, backoffPolicyProvider);
}
@Override

View File

@ -59,10 +59,13 @@ public final class GrpclbLoadBalancerProvider extends LoadBalancerProvider {
@Override
public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
return new GrpclbLoadBalancer(
helper, new CachedSubchannelPool(), TimeProvider.SYSTEM_TIME_PROVIDER,
Stopwatch.createUnstarted(),
new ExponentialBackoffPolicy.Provider());
return
new GrpclbLoadBalancer(
helper,
new CachedSubchannelPool(helper),
TimeProvider.SYSTEM_TIME_PROVIDER,
Stopwatch.createUnstarted(),
new ExponentialBackoffPolicy.Provider());
}
@Override

View File

@ -36,16 +36,19 @@ import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.grpclb.SubchannelPool.PooledSubchannelStateListener;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.lb.v1.ClientStats;
@ -107,7 +110,7 @@ final class GrpclbState {
}
};
static enum Mode {
enum Mode {
ROUND_ROBIN,
PICK_FIRST,
}
@ -115,6 +118,7 @@ final class GrpclbState {
private final String serviceName;
private final Helper helper;
private final SynchronizationContext syncContext;
@Nullable
private final SubchannelPool subchannelPool;
private final TimeProvider time;
private final Stopwatch stopwatch;
@ -166,9 +170,19 @@ final class GrpclbState {
this.config = checkNotNull(config, "config");
this.helper = checkNotNull(helper, "helper");
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.subchannelPool =
config.getMode() == Mode.ROUND_ROBIN
? checkNotNull(subchannelPool, "subchannelPool") : null;
if (config.getMode() == Mode.ROUND_ROBIN) {
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
subchannelPool.registerListener(
new PooledSubchannelStateListener() {
@Override
public void onSubchannelState(
Subchannel subchannel, ConnectivityStateInfo newState) {
handleSubchannelState(subchannel, newState);
}
});
} else {
this.subchannelPool = null;
}
this.time = checkNotNull(time, "time provider");
this.stopwatch = checkNotNull(stopwatch, "stopwatch");
this.timerService = checkNotNull(helper.getScheduledExecutorService(), "timerService");
@ -182,13 +196,7 @@ final class GrpclbState {
}
void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
if (newState.getState() == SHUTDOWN) {
return;
}
if (!subchannels.values().contains(subchannel)) {
if (subchannelPool != null ) {
subchannelPool.handleSubchannelState(subchannel, newState);
}
if (newState.getState() == SHUTDOWN || !subchannels.values().contains(subchannel)) {
return;
}
if (config.getMode() == Mode.ROUND_ROBIN && newState.getState() == IDLE) {
@ -254,7 +262,7 @@ final class GrpclbState {
return;
}
}
// Fallback contiditions met
// Fallback conditions met
useFallbackBackends();
}
@ -383,7 +391,6 @@ final class GrpclbState {
/**
* Populate the round-robin lists with the given values.
*/
@SuppressWarnings("deprecation")
private void useRoundRobinLists(
List<DropEntry> newDropList, List<BackendAddressGroup> newBackendAddrList,
@Nullable GrpclbClientLoadRecorder loadRecorder) {
@ -427,7 +434,7 @@ final class GrpclbState {
break;
case PICK_FIRST:
checkState(subchannels.size() <= 1, "Unexpected Subchannel count: %s", subchannels);
Subchannel subchannel;
final Subchannel subchannel;
if (newBackendAddrList.isEmpty()) {
if (subchannels.size() == 1) {
cancelFallbackTimer();
@ -453,9 +460,18 @@ final class GrpclbState {
eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs));
}
if (subchannels.isEmpty()) {
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
subchannel = helper.createSubchannel(eagList, createSubchannelAttrs());
subchannel =
helper.createSubchannel(
CreateSubchannelArgs.newBuilder()
.setAddresses(eagList)
.setAttributes(createSubchannelAttrs())
.build());
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
handleSubchannelState(subchannel, newState);
}
});
} else {
subchannel = subchannels.values().iterator().next();
subchannel.updateAddresses(eagList);

View File

@ -19,8 +19,6 @@ package io.grpc.grpclb;
import io.grpc.Attributes;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import javax.annotation.concurrent.NotThreadSafe;
@ -31,10 +29,11 @@ import javax.annotation.concurrent.NotThreadSafe;
*/
@NotThreadSafe
interface SubchannelPool {
/**
* Pass essential utilities and the balancer that's using this pool.
* Registers a listener to received Subchannel status updates.
*/
void init(Helper helper, LoadBalancer lb);
void registerListener(PooledSubchannelStateListener listener);
/**
* Takes a {@link Subchannel} from the pool for the given {@code eag} if there is one available.
@ -43,12 +42,6 @@ interface SubchannelPool {
*/
Subchannel takeOrCreateSubchannel(EquivalentAddressGroup eag, Attributes defaultAttributes);
/**
* Gets notified about a state change of Subchannel that is possibly cached in this pool. Do
* nothing if this pool doesn't own this Subchannel.
*/
void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newStateInfo);
/**
* Puts a {@link Subchannel} back to the pool. From this point the Subchannel is owned by the
* pool, and the caller should stop referencing to this Subchannel.
@ -59,4 +52,20 @@ interface SubchannelPool {
* Shuts down all subchannels in the pool immediately.
*/
void clear();
/**
* Receives state changes for a pooled {@link Subchannel}.
*/
interface PooledSubchannelStateListener {
/**
* Handles a state change on a Subchannel. The behavior is similar to {@link
* io.grpc.LoadBalancer.SubchannelStateListener}.
*
* <p>When a subchannel is reused, subchannel state change event will be triggered even if the
* underlying status remains same.
*/
void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState);
}
}

View File

@ -20,14 +20,11 @@ import static com.google.common.truth.Truth.assertThat;
import static io.grpc.grpclb.CachedSubchannelPool.SHUTDOWN_TIMEOUT_MS;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.same;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
@ -36,21 +33,23 @@ import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.grpclb.CachedSubchannelPool.ShutdownSubchannelTask;
import io.grpc.grpclb.SubchannelPool.PooledSubchannelStateListener;
import io.grpc.internal.FakeClock;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.AdditionalAnswers;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@ -67,8 +66,6 @@ public class CachedSubchannelPoolTest {
private static final ConnectivityStateInfo READY_STATE =
ConnectivityStateInfo.forNonError(ConnectivityState.READY);
private static final ConnectivityStateInfo TRANSIENT_FAILURE_STATE =
ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("Simulated"));
private static final FakeClock.TaskFilter SHUTDOWN_TASK_FILTER =
new FakeClock.TaskFilter() {
@Override
@ -80,7 +77,15 @@ public class CachedSubchannelPoolTest {
};
private final Helper helper = mock(Helper.class);
private final LoadBalancer balancer = mock(LoadBalancer.class);
private final PooledSubchannelStateListener listener = mock(
PooledSubchannelStateListener.class,
AdditionalAnswers.delegatesTo(new PooledSubchannelStateListener() {
@Override
public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
syncContext.throwIfNotInThisSynchronizationContext();
}
}));
private final FakeClock clock = new FakeClock();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@ -89,67 +94,60 @@ public class CachedSubchannelPoolTest {
throw new AssertionError(e);
}
});
private final CachedSubchannelPool pool = new CachedSubchannelPool();
private final SubchannelPool pool = new CachedSubchannelPool(helper);
private final ArrayList<Subchannel> mockSubchannels = new ArrayList<>();
private final ArgumentCaptor<CreateSubchannelArgs> createSubchannelArgsCaptor
= ArgumentCaptor.forClass(CreateSubchannelArgs.class);
@Before
@SuppressWarnings({"unchecked", "deprecation"})
public void setUp() {
doAnswer(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Subchannel subchannel = mock(Subchannel.class);
List<EquivalentAddressGroup> eagList =
(List<EquivalentAddressGroup>) invocation.getArguments()[0];
Attributes attrs = (Attributes) invocation.getArguments()[1];
when(subchannel.getAllAddresses()).thenReturn(eagList);
when(subchannel.getAttributes()).thenReturn(attrs);
CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
when(subchannel.getAttributes()).thenReturn(args.getAttributes());
mockSubchannels.add(subchannel);
return subchannel;
}
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
}).when(helper).createSubchannel(any(List.class), any(Attributes.class));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
syncContext.throwIfNotInThisSynchronizationContext();
return null;
}
}).when(balancer).handleSubchannelState(
any(Subchannel.class), any(ConnectivityStateInfo.class));
}).when(helper).createSubchannel(any(CreateSubchannelArgs.class));
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(clock.getScheduledExecutorService());
pool.init(helper, balancer);
pool.registerListener(listener);
}
@After
@SuppressWarnings("deprecation")
public void wrapUp() {
if (mockSubchannels.isEmpty()) {
return;
}
// Sanity checks
for (Subchannel subchannel : mockSubchannels) {
verify(subchannel, atMost(1)).shutdown();
}
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new API.
verify(balancer, atLeast(0))
.handleSubchannelState(any(Subchannel.class), any(ConnectivityStateInfo.class));
verifyNoMoreInteractions(balancer);
verify(listener, atLeast(0))
.onSubchannelState(any(Subchannel.class), any(ConnectivityStateInfo.class));
verifyNoMoreInteractions(listener);
}
@SuppressWarnings("deprecation")
@Test
public void subchannelExpireAfterReturned() {
Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1);
assertThat(subchannel1).isNotNull();
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to the new
// createSubchannel().
verify(helper).createSubchannel(eq(Arrays.asList(EAG1)), same(ATTRS1));
InOrder inOrder = Mockito.inOrder(helper);
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses()).containsExactly(EAG1);
assertThat(createSubchannelArgs.getAttributes()).isEqualTo(ATTRS1);
Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2);
assertThat(subchannel2).isNotNull();
assertThat(subchannel2).isNotSameInstanceAs(subchannel1);
verify(helper).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2));
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses()).containsExactly(EAG2);
assertThat(createSubchannelArgs.getAttributes()).isEqualTo(ATTRS2);
pool.returnSubchannel(subchannel1, READY_STATE);
@ -170,19 +168,23 @@ public class CachedSubchannelPoolTest {
assertThat(clock.numPendingTasks()).isEqualTo(0);
}
@SuppressWarnings("deprecation")
@Test
public void subchannelReused() {
Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1);
assertThat(subchannel1).isNotNull();
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to the new
// createSubchannel().
verify(helper).createSubchannel(eq(Arrays.asList(EAG1)), same(ATTRS1));
InOrder inOrder = Mockito.inOrder(helper);
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses()).containsExactly(EAG1);
assertThat(createSubchannelArgs.getAttributes()).isEqualTo(ATTRS1);
Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2);
assertThat(subchannel2).isNotNull();
assertThat(subchannel2).isNotSameInstanceAs(subchannel1);
verify(helper).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2));
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses()).containsExactly(EAG2);
assertThat(createSubchannelArgs.getAttributes()).isEqualTo(ATTRS2);
pool.returnSubchannel(subchannel1, READY_STATE);
@ -204,7 +206,10 @@ public class CachedSubchannelPoolTest {
// pool will create a new channel for EAG2 when requested
Subchannel subchannel2a = pool.takeOrCreateSubchannel(EAG2, ATTRS2);
assertThat(subchannel2a).isNotSameInstanceAs(subchannel2);
verify(helper, times(2)).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2));
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses()).containsExactly(EAG2);
assertThat(createSubchannelArgs.getAttributes()).isEqualTo(ATTRS2);
// subchannel1 expires SHUTDOWN_TIMEOUT_MS after being returned
pool.returnSubchannel(subchannel1a, READY_STATE);
@ -216,57 +221,6 @@ public class CachedSubchannelPoolTest {
assertThat(clock.numPendingTasks()).isEqualTo(0);
}
@SuppressWarnings("deprecation")
@Test
public void updateStateWhileInPool() {
Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1);
Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2);
pool.returnSubchannel(subchannel1, READY_STATE);
pool.returnSubchannel(subchannel2, TRANSIENT_FAILURE_STATE);
ConnectivityStateInfo anotherFailureState =
ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("Another"));
pool.handleSubchannelState(subchannel1, anotherFailureState);
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to the new
// createSubchannel().
verify(balancer, never())
.handleSubchannelState(any(Subchannel.class), any(ConnectivityStateInfo.class));
assertThat(pool.takeOrCreateSubchannel(EAG1, ATTRS1)).isSameInstanceAs(subchannel1);
verify(balancer).handleSubchannelState(same(subchannel1), same(anotherFailureState));
verifyNoMoreInteractions(balancer);
assertThat(pool.takeOrCreateSubchannel(EAG2, ATTRS2)).isSameInstanceAs(subchannel2);
verify(balancer).handleSubchannelState(same(subchannel2), same(TRANSIENT_FAILURE_STATE));
verifyNoMoreInteractions(balancer);
}
@SuppressWarnings("deprecation")
@Test
public void updateStateWhileInPool_notSameObject() {
Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1);
pool.returnSubchannel(subchannel1, READY_STATE);
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to the new
// createSubchannel().
Subchannel subchannel2 = helper.createSubchannel(EAG1, ATTRS1);
Subchannel subchannel3 = helper.createSubchannel(EAG2, ATTRS2);
// subchannel2 is not in the pool, although with the same address
pool.handleSubchannelState(subchannel2, TRANSIENT_FAILURE_STATE);
// subchannel3 is not in the pool. In fact its address is not in the pool
pool.handleSubchannelState(subchannel3, TRANSIENT_FAILURE_STATE);
assertThat(pool.takeOrCreateSubchannel(EAG1, ATTRS1)).isSameInstanceAs(subchannel1);
// subchannel1's state is unchanged
verify(balancer).handleSubchannelState(same(subchannel1), same(READY_STATE));
verifyNoMoreInteractions(balancer);
}
@Test
public void returnDuplicateAddressSubchannel() {
Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1);

View File

@ -16,6 +16,7 @@
package io.grpc.grpclb;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
@ -54,14 +55,17 @@ import io.grpc.ClientStreamTracer;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.NameResolver.Factory;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.SynchronizationContext;
@ -94,13 +98,16 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.AdditionalAnswers;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
@ -140,10 +147,11 @@ public class GrpclbLoadBalancerTest {
private static final Attributes LB_BACKEND_ATTRS =
Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build();
@Mock
private Helper helper;
@Mock
private SubchannelPool subchannelPool;
private Helper helper = mock(Helper.class, delegatesTo(new FakeHelper()));
private SubchannelPool subchannelPool =
mock(
SubchannelPool.class,
delegatesTo(new CachedSubchannelPool(helper)));
private final ArrayList<String> logs = new ArrayList<>();
private final ChannelLogger channelLogger = new ChannelLogger() {
@Override
@ -165,10 +173,8 @@ public class GrpclbLoadBalancerTest {
new LinkedList<>();
private final LinkedList<Subchannel> mockSubchannels = new LinkedList<>();
private final LinkedList<ManagedChannel> fakeOobChannels = new LinkedList<>();
private final ArrayList<Subchannel> pooledSubchannelTracker = new ArrayList<>();
private final ArrayList<Subchannel> unpooledSubchannelTracker = new ArrayList<>();
private final ArrayList<ManagedChannel> oobChannelTracker = new ArrayList<>();
private final ArrayList<String> failingLbAuthorities = new ArrayList<>();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
@ -189,14 +195,16 @@ public class GrpclbLoadBalancerTest {
@Mock
private BackoffPolicy backoffPolicy2;
private GrpclbLoadBalancer balancer;
private ArgumentCaptor<CreateSubchannelArgs> createSubchannelArgsCaptor =
ArgumentCaptor.forClass(CreateSubchannelArgs.class);
@SuppressWarnings({"unchecked", "deprecation"})
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
mockLbService = mock(LoadBalancerGrpc.LoadBalancerImplBase.class, delegatesTo(
new LoadBalancerGrpc.LoadBalancerImplBase() {
@Override
@SuppressWarnings("unchecked")
public StreamObserver<LoadBalanceRequest> balanceLoad(
final StreamObserver<LoadBalanceResponse> responseObserver) {
StreamObserver<LoadBalanceRequest> requestObserver =
@ -215,72 +223,15 @@ public class GrpclbLoadBalancerTest {
}));
fakeLbServer = InProcessServerBuilder.forName("fakeLb")
.directExecutor().addService(mockLbService).build().start();
doAnswer(new Answer<ManagedChannel>() {
@Override
public ManagedChannel answer(InvocationOnMock invocation) throws Throwable {
String authority = (String) invocation.getArguments()[1];
ManagedChannel channel;
if (failingLbAuthorities.contains(authority)) {
channel = InProcessChannelBuilder.forName("nonExistFakeLb").directExecutor()
.overrideAuthority(authority).build();
} else {
channel = InProcessChannelBuilder.forName("fakeLb").directExecutor()
.overrideAuthority(authority).build();
}
fakeOobChannels.add(channel);
oobChannelTracker.add(channel);
return channel;
}
}).when(helper).createOobChannel(any(EquivalentAddressGroup.class), any(String.class));
doAnswer(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Subchannel subchannel = mock(Subchannel.class);
EquivalentAddressGroup eag = (EquivalentAddressGroup) invocation.getArguments()[0];
Attributes attrs = (Attributes) invocation.getArguments()[1];
when(subchannel.getAllAddresses()).thenReturn(Arrays.asList(eag));
when(subchannel.getAttributes()).thenReturn(attrs);
mockSubchannels.add(subchannel);
pooledSubchannelTracker.add(subchannel);
return subchannel;
}
}).when(subchannelPool).takeOrCreateSubchannel(
any(EquivalentAddressGroup.class), any(Attributes.class));
doAnswer(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Subchannel subchannel = mock(Subchannel.class);
List<EquivalentAddressGroup> eagList =
(List<EquivalentAddressGroup>) invocation.getArguments()[0];
Attributes attrs = (Attributes) invocation.getArguments()[1];
when(subchannel.getAllAddresses()).thenReturn(eagList);
when(subchannel.getAttributes()).thenReturn(attrs);
mockSubchannels.add(subchannel);
unpooledSubchannelTracker.add(subchannel);
return subchannel;
}
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
}).when(helper).createSubchannel(any(List.class), any(Attributes.class));
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
when(helper.getChannelLogger()).thenReturn(channelLogger);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
currentPicker = (SubchannelPicker) invocation.getArguments()[1];
return null;
}
}).when(helper).updateBalancingState(
any(ConnectivityState.class), any(SubchannelPicker.class));
when(helper.getAuthority()).thenReturn(SERVICE_AUTHORITY);
when(backoffPolicy1.nextBackoffNanos()).thenReturn(10L, 100L);
when(backoffPolicy2.nextBackoffNanos()).thenReturn(10L, 100L);
when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2);
balancer = new GrpclbLoadBalancer(helper, subchannelPool, fakeClock.getTimeProvider(),
balancer = new GrpclbLoadBalancer(
helper,
subchannelPool,
fakeClock.getTimeProvider(),
fakeClock.getStopwatchSupplier().get(),
backoffPolicyProvider);
verify(subchannelPool).init(same(helper), same(balancer));
}
@After
@ -299,13 +250,6 @@ public class GrpclbLoadBalancerTest {
// balancer should have closed the LB stream, terminating the OOB channel.
assertTrue(channel + " is terminated", channel.isTerminated());
}
// GRPCLB manages subchannels only through subchannelPool
for (Subchannel subchannel : pooledSubchannelTracker) {
verify(subchannelPool).returnSubchannel(same(subchannel), any(ConnectivityStateInfo.class));
// Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if
// LoadBalancer has called it expectedly.
verify(subchannel, never()).shutdown();
}
for (Subchannel subchannel : unpooledSubchannelTracker) {
verify(subchannel).shutdown();
}
@ -355,7 +299,6 @@ public class GrpclbLoadBalancerTest {
verify(subchannel, never()).getAttributes();
}
@Test
public void roundRobinPickerWithDrop() {
assertTrue(DROP_PICK_RESULT.isDrop());
@ -1077,7 +1020,6 @@ public class GrpclbLoadBalancerTest {
ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("You can get this error even if you are cached"));
deliverSubchannelState(subchannel1, errorOnCachedSubchannel1);
verify(subchannelPool).handleSubchannelState(same(subchannel1), same(errorOnCachedSubchannel1));
assertEquals(1, mockSubchannels.size());
Subchannel subchannel3 = mockSubchannels.poll();
@ -1100,11 +1042,6 @@ public class GrpclbLoadBalancerTest {
deliverSubchannelState(
subchannel1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
deliverSubchannelState(subchannel1, ConnectivityStateInfo.forNonError(SHUTDOWN));
inOrder.verify(subchannelPool)
.handleSubchannelState(same(subchannel1), eq(ConnectivityStateInfo.forNonError(READY)));
inOrder.verify(subchannelPool).handleSubchannelState(
same(subchannel1), eq(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)));
inOrder.verifyNoMoreInteractions();
deliverSubchannelState(subchannel3, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
@ -1253,6 +1190,8 @@ public class GrpclbLoadBalancerTest {
////////////////////////////////////////////////////////////////
// Name resolver sends new resolution results with backend addrs
////////////////////////////////////////////////////////////////
// prevents the cached subchannel to be used
subchannelPool.clear();
backendList = createResolvedBackendAddresses(2);
grpclbBalancerList = createResolvedBalancerAddresses(1);
deliverResolvedAddresses(backendList, grpclbBalancerList);
@ -1685,7 +1624,6 @@ public class GrpclbLoadBalancerTest {
verify(helper, times(4)).refreshNameResolution();
}
@SuppressWarnings({"unchecked", "deprecation"})
@Test
public void grpclbWorking_pickFirstMode() throws Exception {
InOrder inOrder = inOrder(helper);
@ -1716,13 +1654,12 @@ public class GrpclbLoadBalancerTest {
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses())
.containsExactly(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")));
// Initially IDLE
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
@ -1739,7 +1676,6 @@ public class GrpclbLoadBalancerTest {
// CONNECTING
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker1.dropList).containsExactly(null, null);
@ -1773,7 +1709,7 @@ public class GrpclbLoadBalancerTest {
// new addresses will be updated to the existing subchannel
// createSubchannel() has ever been called only once
verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class));
assertThat(mockSubchannels).isEmpty();
verify(subchannel).updateAddresses(
eq(Arrays.asList(
@ -1810,7 +1746,6 @@ public class GrpclbLoadBalancerTest {
.returnSubchannel(any(Subchannel.class), any(ConnectivityStateInfo.class));
}
@SuppressWarnings({"unchecked", "deprecation"})
@Test
public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception {
InOrder inOrder = inOrder(helper);
@ -1840,13 +1775,12 @@ public class GrpclbLoadBalancerTest {
lbResponseObserver.onNext(buildInitialResponse());
lbResponseObserver.onNext(buildLbResponse(backends1));
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses())
.containsExactly(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")));
// Initially IDLE
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
@ -1863,7 +1797,6 @@ public class GrpclbLoadBalancerTest {
// CONNECTING
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
assertThat(picker1.dropList).containsExactly(null, null);
@ -1893,7 +1826,7 @@ public class GrpclbLoadBalancerTest {
// new addresses will be updated to the existing subchannel
// createSubchannel() has ever been called only once
inOrder.verify(helper, never()).createSubchannel(any(List.class), any(Attributes.class));
inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class));
assertThat(mockSubchannels).isEmpty();
verify(subchannel).shutdown();
@ -1915,7 +1848,7 @@ public class GrpclbLoadBalancerTest {
lbResponseObserver.onNext(buildLbResponse(backends2));
// new addresses will be updated to the existing subchannel
inOrder.verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
inOrder.verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
subchannel = mockSubchannels.poll();
@ -1956,7 +1889,6 @@ public class GrpclbLoadBalancerTest {
.isEqualTo(Code.CANCELLED);
}
@SuppressWarnings({"unchecked", "deprecation"})
@Test
public void pickFirstMode_fallback() throws Exception {
InOrder inOrder = inOrder(helper);
@ -1979,11 +1911,10 @@ public class GrpclbLoadBalancerTest {
fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
// Entering fallback mode
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(backendList.get(0), backendList.get(1))),
any(Attributes.class));
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses())
.containsExactly(backendList.get(0), backendList.get(1));
assertThat(mockSubchannels).hasSize(1);
Subchannel subchannel = mockSubchannels.poll();
@ -2015,7 +1946,7 @@ public class GrpclbLoadBalancerTest {
// new addresses will be updated to the existing subchannel
// createSubchannel() has ever been called only once
verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class));
assertThat(mockSubchannels).isEmpty();
verify(subchannel).updateAddresses(
eq(Arrays.asList(
@ -2035,7 +1966,6 @@ public class GrpclbLoadBalancerTest {
.returnSubchannel(any(Subchannel.class), any(ConnectivityStateInfo.class));
}
@SuppressWarnings("deprecation")
@Test
public void switchMode() throws Exception {
InOrder inOrder = inOrder(helper);
@ -2111,13 +2041,12 @@ public class GrpclbLoadBalancerTest {
lbResponseObserver.onNext(buildLbResponse(backends1));
// PICK_FIRST Subchannel
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses())
.containsExactly(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")));
inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
}
@ -2127,7 +2056,6 @@ public class GrpclbLoadBalancerTest {
}
@Test
@SuppressWarnings("deprecation")
public void switchMode_nullLbPolicy() throws Exception {
InOrder inOrder = inOrder(helper);
@ -2201,13 +2129,12 @@ public class GrpclbLoadBalancerTest {
lbResponseObserver.onNext(buildLbResponse(backends1));
// PICK_FIRST Subchannel
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new createSubchannel().
inOrder.verify(helper).createSubchannel(
eq(Arrays.asList(
inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture());
CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue();
assertThat(createSubchannelArgs.getAddresses())
.containsExactly(
new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
any(Attributes.class));
new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")));
inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
}
@ -2485,17 +2412,9 @@ public class GrpclbLoadBalancerTest {
.inOrder();
}
@SuppressWarnings("deprecation")
private void deliverSubchannelState(
final Subchannel subchannel, final ConnectivityStateInfo newState) {
syncContext.execute(new Runnable() {
@Override
public void run() {
// TODO(zhangkun83): remove the deprecation suppression on this method once migrated to
// the new API.
balancer.handleSubchannelState(subchannel, newState);
}
});
((FakeSubchannel) subchannel).updateState(newState);
}
private void deliverNameResolutionError(final Status error) {
@ -2619,4 +2538,115 @@ public class GrpclbLoadBalancerTest {
this.token = token;
}
}
private static class FakeSubchannel extends Subchannel {
private List<EquivalentAddressGroup> eags;
private Attributes attributes;
private SubchannelStateListener listener;
public FakeSubchannel(List<EquivalentAddressGroup> eags, Attributes attributes) {
this.eags = Collections.unmodifiableList(eags);
this.attributes = attributes;
}
@Override
public List<EquivalentAddressGroup> getAllAddresses() {
return eags;
}
@Override
public Attributes getAttributes() {
return attributes;
}
@Override
public void start(SubchannelStateListener listener) {
this.listener = checkNotNull(listener, "listener");
}
@Override
public void updateAddresses(List<EquivalentAddressGroup> addrs) {
this.eags = Collections.unmodifiableList(addrs);
}
@Override
public void shutdown() {
}
@Override
public void requestConnection() {
}
public void updateState(ConnectivityStateInfo newState) {
listener.onSubchannelState(newState);
}
}
private class FakeHelper extends Helper {
@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
}
@Override
public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) {
ManagedChannel channel =
InProcessChannelBuilder
.forName("fakeLb")
.directExecutor()
.overrideAuthority(authority)
.build();
fakeOobChannels.add(channel);
oobChannelTracker.add(channel);
return channel;
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
FakeSubchannel subchannel =
mock(
FakeSubchannel.class,
AdditionalAnswers
.delegatesTo(new FakeSubchannel(args.getAddresses(), args.getAttributes())));
mockSubchannels.add(subchannel);
unpooledSubchannelTracker.add(subchannel);
return subchannel;
}
@Override
public ScheduledExecutorService getScheduledExecutorService() {
return fakeClock.getScheduledExecutorService();
}
@Override
public ChannelLogger getChannelLogger() {
return channelLogger;
}
@Override
public void updateBalancingState(
@Nonnull ConnectivityState newState, @Nonnull SubchannelPicker newPicker) {
currentPicker = newPicker;
}
@Override
@SuppressWarnings("deprecation")
public Factory getNameResolverFactory() {
return mock(Factory.class);
}
@Override
public void refreshNameResolution() {
}
@Override
public String getAuthority() {
return SERVICE_AUTHORITY;
}
@Override
public void updateOobChannelAddresses(ManagedChannel channel, EquivalentAddressGroup eag) {
}
}
}