diff --git a/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java b/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java new file mode 100644 index 0000000000..55cbef1981 --- /dev/null +++ b/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java @@ -0,0 +1,142 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import java.util.HashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + * A {@link SubchannelPool} that keeps returned {@link Subchannel}s for a given time before it's + * shut down by the pool. + */ +final class CachedSubchannelPool implements SubchannelPool { + private final HashMap cache = + new HashMap(); + + private Helper helper; + private ScheduledExecutorService timerService; + + @VisibleForTesting + static final long SHUTDOWN_TIMEOUT_MS = 10000; + + @Override + public void init(Helper helper, ScheduledExecutorService timerService) { + this.helper = checkNotNull(helper, "helper"); + this.timerService = checkNotNull(timerService, "timerService"); + } + + @Override + public Subchannel takeOrCreateSubchannel( + EquivalentAddressGroup eag, Attributes defaultAttributes) { + CacheEntry entry = cache.remove(eag); + Subchannel subchannel; + if (entry == null) { + subchannel = helper.createSubchannel(eag, defaultAttributes); + } else { + subchannel = entry.subchannel; + entry.shutdownTimer.cancel(false); + } + return subchannel; + } + + @Override + public void returnSubchannel(Subchannel subchannel) { + CacheEntry prev = cache.get(subchannel.getAddresses()); + if (prev != null) { + // Returning the same Subchannel twice has no effect. + // Returning a different Subchannel for an already cached EAG will cause the + // latter Subchannel to be shutdown immediately. + if (prev.subchannel != subchannel) { + subchannel.shutdown(); + } + return; + } + final ShutdownSubchannelTask shutdownTask = new ShutdownSubchannelTask(subchannel); + ScheduledFuture shutdownTimer = + timerService.schedule( + new ShutdownSubchannelScheduledTask(shutdownTask), + SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS); + shutdownTask.timer = shutdownTimer; + CacheEntry entry = new CacheEntry(subchannel, shutdownTimer); + cache.put(subchannel.getAddresses(), entry); + } + + @Override + public void clear() { + for (CacheEntry entry : cache.values()) { + entry.shutdownTimer.cancel(false); + entry.subchannel.shutdown(); + } + cache.clear(); + } + + @VisibleForTesting + final class ShutdownSubchannelScheduledTask implements Runnable { + private final ShutdownSubchannelTask task; + + ShutdownSubchannelScheduledTask(ShutdownSubchannelTask task) { + this.task = checkNotNull(task, "task"); + } + + @Override + public void run() { + helper.runSerialized(task); + } + } + + @VisibleForTesting + final class ShutdownSubchannelTask implements Runnable { + private final Subchannel subchannel; + private ScheduledFuture timer; + + private ShutdownSubchannelTask(Subchannel subchannel) { + this.subchannel = checkNotNull(subchannel, "subchannel"); + } + + // This runs in channelExecutor + @Override + public void run() { + // getSubchannel() may have cancelled the timer after the timer has expired but before this + // task is actually run in the channelExecutor. + if (!timer.isCancelled()) { + CacheEntry entry = cache.remove(subchannel.getAddresses()); + checkState(entry.subchannel == subchannel, "Inconsistent state"); + subchannel.shutdown(); + } + } + } + + private static class CacheEntry { + final Subchannel subchannel; + final ScheduledFuture shutdownTimer; + + CacheEntry(Subchannel subchannel, ScheduledFuture shutdownTimer) { + this.subchannel = checkNotNull(subchannel, "subchannel"); + this.shutdownTimer = checkNotNull(shutdownTimer, "shutdownTimer"); + } + } +} diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 17e3a50301..5a5a89161a 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -49,6 +49,7 @@ class GrpclbLoadBalancer extends LoadBalancer implements WithLogId { private final LogId logId = LogId.allocate(getClass().getName()); private final Helper helper; + private final SubchannelPool subchannelPool; private final Factory pickFirstBalancerFactory; private final Factory roundRobinBalancerFactory; private final ObjectPool timerServicePool; @@ -67,7 +68,7 @@ class GrpclbLoadBalancer extends LoadBalancer implements WithLogId { @Nullable private GrpclbState grpclbState; - GrpclbLoadBalancer(Helper helper, Factory pickFirstBalancerFactory, + GrpclbLoadBalancer(Helper helper, SubchannelPool subchannelPool, Factory pickFirstBalancerFactory, Factory roundRobinBalancerFactory, ObjectPool timerServicePool, TimeProvider time) { this.helper = checkNotNull(helper, "helper"); @@ -78,6 +79,8 @@ class GrpclbLoadBalancer extends LoadBalancer implements WithLogId { this.timerServicePool = checkNotNull(timerServicePool, "timerServicePool"); this.timerService = checkNotNull(timerServicePool.getObject(), "timerService"); this.time = checkNotNull(time, "time provider"); + this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); + this.subchannelPool.init(helper, timerService); setLbPolicy(LbPolicy.GRPCLB); } @@ -159,7 +162,8 @@ class GrpclbLoadBalancer extends LoadBalancer implements WithLogId { "roundRobinBalancerFactory.newLoadBalancer()"); break; case GRPCLB: - grpclbState = new GrpclbState(helper, time, timerService, logId); + grpclbState = + new GrpclbState(helper, subchannelPool, time, timerService, logId); break; default: // Do nohting diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerFactory.java index f50343a264..a55cb2970e 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerFactory.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerFactory.java @@ -50,7 +50,7 @@ public class GrpclbLoadBalancerFactory extends LoadBalancer.Factory { @Override public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { return new GrpclbLoadBalancer( - helper, PickFirstBalancerFactory.getInstance(), + helper, new CachedSubchannelPool(), PickFirstBalancerFactory.getInstance(), RoundRobinLoadBalancerFactory.getInstance(), // TODO(zhangkun83): balancer sends load reporting RPCs from it, which also involves // channelExecutor thus may also run other tasks queued in the channelExecutor. If such diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index f295f916d2..9e75dd2e44 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -95,6 +95,7 @@ final class GrpclbState { private final LogId logId; private final String serviceName; private final Helper helper; + private final SubchannelPool subchannelPool; private final TimeProvider time; private final ScheduledExecutorService timerService; @@ -128,10 +129,12 @@ final class GrpclbState { GrpclbState( Helper helper, + SubchannelPool subchannelPool, TimeProvider time, ScheduledExecutorService timerService, LogId logId) { this.helper = checkNotNull(helper, "helper"); + this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); this.time = checkNotNull(time, "time provider"); this.timerService = checkNotNull(timerService, "timerService"); this.serviceName = checkNotNull(helper.getAuthority(), "helper returns null authority"); @@ -278,10 +281,13 @@ final class GrpclbState { void shutdown() { shutdownLbComm(); + // We close the subchannels through subchannelPool instead of helper just for convenience of + // testing. for (Subchannel subchannel : subchannels.values()) { - subchannel.shutdown(); + subchannelPool.returnSubchannel(subchannel); } subchannels = Collections.emptyMap(); + subchannelPool.clear(); cancelFallbackTimer(); } @@ -324,7 +330,7 @@ final class GrpclbState { new AtomicReference( ConnectivityStateInfo.forNonError(IDLE))) .build(); - subchannel = helper.createSubchannel(eag, subchannelAttrs); + subchannel = subchannelPool.takeOrCreateSubchannel(eag, subchannelAttrs); subchannel.requestConnection(); } newSubchannelMap.put(eag, subchannel); @@ -343,7 +349,7 @@ final class GrpclbState { for (Entry entry : subchannels.entrySet()) { EquivalentAddressGroup eag = entry.getKey(); if (!newSubchannelMap.containsKey(eag)) { - entry.getValue().shutdown(); + subchannelPool.returnSubchannel(entry.getValue()); } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java b/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java new file mode 100644 index 0000000000..88498e8900 --- /dev/null +++ b/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java @@ -0,0 +1,55 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Manages life-cycle of Subchannels for {@link GrpclbState}. + * + *

All methods are run from the ChannelExecutor that the helper uses. + */ +@NotThreadSafe +interface SubchannelPool { + /** + * Pass essential utilities. + */ + void init(Helper helper, ScheduledExecutorService timerService); + + /** + * Takes a {@link Subchannel} from the pool for the given {@code eag} if there is one available. + * Otherwise, creates and returns a new {@code Subchannel} with the given {@code eag} and {@code + * defaultAttributes}. + */ + Subchannel takeOrCreateSubchannel(EquivalentAddressGroup eag, Attributes defaultAttributes); + + /** + * Puts a {@link Subchannel} back to the pool. From this point the Subchannel is owned by the + * pool. + */ + void returnSubchannel(Subchannel subchannel); + + /** + * Shuts down all subchannels in the pool immediately. + */ + void clear(); +} diff --git a/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java b/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java new file mode 100644 index 0000000000..7cdb7f7814 --- /dev/null +++ b/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java @@ -0,0 +1,231 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.grpclb.CachedSubchannelPool.SHUTDOWN_TIMEOUT_MS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.grpclb.CachedSubchannelPool.ShutdownSubchannelScheduledTask; +import io.grpc.grpclb.CachedSubchannelPool.ShutdownSubchannelTask; +import io.grpc.internal.FakeClock; +import io.grpc.internal.SerializingExecutor; +import java.util.ArrayList; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** Unit tests for {@link CachedSubchannelPool}. */ +@RunWith(JUnit4.class) +public class CachedSubchannelPoolTest { + private static final EquivalentAddressGroup EAG1 = + new EquivalentAddressGroup(new FakeSocketAddress("fake-address-1"), Attributes.EMPTY); + private static final EquivalentAddressGroup EAG2 = + new EquivalentAddressGroup(new FakeSocketAddress("fake-address-2"), Attributes.EMPTY); + private static final Attributes.Key ATTR_KEY = Attributes.Key.of("test-attr"); + private static final Attributes ATTRS1 = Attributes.newBuilder().set(ATTR_KEY, "1").build(); + private static final Attributes ATTRS2 = Attributes.newBuilder().set(ATTR_KEY, "2").build(); + private static final FakeClock.TaskFilter SHUTDOWN_SCHEDULED_TASK_FILTER = + new FakeClock.TaskFilter() { + @Override + public boolean shouldAccept(Runnable command) { + return command instanceof ShutdownSubchannelScheduledTask; + } + }; + + private final SerializingExecutor channelExecutor = + new SerializingExecutor(MoreExecutors.directExecutor()); + private final Helper helper = mock(Helper.class); + private final FakeClock clock = new FakeClock(); + private final CachedSubchannelPool pool = new CachedSubchannelPool(); + private final ArrayList mockSubchannels = new ArrayList(); + + @Before + public void setUp() { + doAnswer(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + Subchannel subchannel = mock(Subchannel.class); + EquivalentAddressGroup eag = (EquivalentAddressGroup) invocation.getArguments()[0]; + Attributes attrs = (Attributes) invocation.getArguments()[1]; + when(subchannel.getAddresses()).thenReturn(eag); + when(subchannel.getAttributes()).thenReturn(attrs); + mockSubchannels.add(subchannel); + return subchannel; + } + }).when(helper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + Runnable task = (Runnable) invocation.getArguments()[0]; + channelExecutor.execute(task); + return null; + } + }).when(helper).runSerialized(any(Runnable.class)); + pool.init(helper, clock.getScheduledExecutorService()); + } + + @After + public void wrapUp() { + // Sanity checks + for (Subchannel subchannel : mockSubchannels) { + verify(subchannel, atMost(1)).shutdown(); + } + } + + @Test + public void subchannelExpireAfterReturned() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + assertThat(subchannel1).isNotNull(); + verify(helper).createSubchannel(same(EAG1), same(ATTRS1)); + + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + assertThat(subchannel2).isNotNull(); + assertThat(subchannel2).isNotSameAs(subchannel1); + verify(helper).createSubchannel(same(EAG2), same(ATTRS2)); + + pool.returnSubchannel(subchannel1); + + // subchannel1 is 1ms away from expiration. + clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); + verify(subchannel1, never()).shutdown(); + + pool.returnSubchannel(subchannel2); + + // subchannel1 expires. subchannel2 is (SHUTDOWN_TIMEOUT_MS - 1) away from expiration. + clock.forwardTime(1, MILLISECONDS); + verify(subchannel1).shutdown(); + + // subchanne2 expires. + clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); + verify(subchannel2).shutdown(); + + assertThat(clock.numPendingTasks()).isEqualTo(0); + verify(helper, times(2)).runSerialized(any(ShutdownSubchannelTask.class)); + } + + @Test + public void subchannelReused() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + assertThat(subchannel1).isNotNull(); + verify(helper).createSubchannel(same(EAG1), same(ATTRS1)); + + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + assertThat(subchannel2).isNotNull(); + assertThat(subchannel2).isNotSameAs(subchannel1); + verify(helper).createSubchannel(same(EAG2), same(ATTRS2)); + + pool.returnSubchannel(subchannel1); + + // subchannel1 is 1ms away from expiration. + clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); + + // This will cancel the shutdown timer for subchannel1 + Subchannel subchannel1a = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + assertThat(subchannel1a).isSameAs(subchannel1); + + pool.returnSubchannel(subchannel2); + + // subchannel2 expires SHUTDOWN_TIMEOUT_MS after being returned + clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); + verify(subchannel2, never()).shutdown(); + clock.forwardTime(1, MILLISECONDS); + verify(subchannel2).shutdown(); + + // pool will create a new channel for EAG2 when requested + Subchannel subchannel2a = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + assertThat(subchannel2a).isNotSameAs(subchannel2); + verify(helper, times(2)).createSubchannel(same(EAG2), same(ATTRS2)); + + // subchannel1 expires SHUTDOWN_TIMEOUT_MS after being returned + pool.returnSubchannel(subchannel1a); + clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); + verify(subchannel1a, never()).shutdown(); + clock.forwardTime(1, MILLISECONDS); + verify(subchannel1a).shutdown(); + + assertThat(clock.numPendingTasks()).isEqualTo(0); + verify(helper, times(2)).runSerialized(any(ShutdownSubchannelTask.class)); + } + + @Test + public void returnDuplicateAddressSubchannel() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG1, ATTRS2); + Subchannel subchannel3 = pool.takeOrCreateSubchannel(EAG2, ATTRS1); + assertThat(subchannel1).isNotSameAs(subchannel2); + + assertThat(clock.getPendingTasks(SHUTDOWN_SCHEDULED_TASK_FILTER)).isEmpty(); + pool.returnSubchannel(subchannel2); + assertThat(clock.getPendingTasks(SHUTDOWN_SCHEDULED_TASK_FILTER)).hasSize(1); + + // If the subchannel being returned has an address that is the same as a subchannel in the pool, + // the returned subchannel will be shut down. + verify(subchannel1, never()).shutdown(); + pool.returnSubchannel(subchannel1); + assertThat(clock.getPendingTasks(SHUTDOWN_SCHEDULED_TASK_FILTER)).hasSize(1); + verify(subchannel1).shutdown(); + + pool.returnSubchannel(subchannel3); + assertThat(clock.getPendingTasks(SHUTDOWN_SCHEDULED_TASK_FILTER)).hasSize(2); + // Returning the same subchannel twice has no effect. + pool.returnSubchannel(subchannel3); + assertThat(clock.getPendingTasks(SHUTDOWN_SCHEDULED_TASK_FILTER)).hasSize(2); + + verify(subchannel2, never()).shutdown(); + verify(subchannel3, never()).shutdown(); + verify(helper, never()).runSerialized(any(ShutdownSubchannelTask.class)); + } + + @Test + public void clear() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + Subchannel subchannel3 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + + pool.returnSubchannel(subchannel1); + pool.returnSubchannel(subchannel2); + + verify(subchannel1, never()).shutdown(); + verify(subchannel2, never()).shutdown(); + pool.clear(); + verify(subchannel1).shutdown(); + verify(subchannel2).shutdown(); + + verify(subchannel3, never()).shutdown(); + assertThat(clock.numPendingTasks()).isEqualTo(0); + verify(helper, never()).runSerialized(any(ShutdownSubchannelTask.class)); + } +} diff --git a/grpclb/src/test/java/io/grpc/grpclb/FakeSocketAddress.java b/grpclb/src/test/java/io/grpc/grpclb/FakeSocketAddress.java new file mode 100644 index 0000000000..4d6848a48a --- /dev/null +++ b/grpclb/src/test/java/io/grpc/grpclb/FakeSocketAddress.java @@ -0,0 +1,46 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.grpclb; + +import java.net.SocketAddress; + +final class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } + + @Override + public boolean equals(Object other) { + if (other instanceof FakeSocketAddress) { + FakeSocketAddress otherAddr = (FakeSocketAddress) other; + return name.equals(otherAddr.name); + } + return false; + } + + @Override + public int hashCode() { + return name.hashCode(); + } +} diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index dc18c1b31b..7379d66bfe 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -121,6 +121,8 @@ public class GrpclbLoadBalancerTest { @Mock private Helper helper; + @Mock + private SubchannelPool subchannelPool; private SubchannelPicker currentPicker; private LoadBalancerGrpc.LoadBalancerImplBase mockLbService; @Captor @@ -139,7 +141,6 @@ public class GrpclbLoadBalancerTest { return fakeClock.currentTimeMillis(); } }; - private io.grpc.Server fakeLbServer; @Captor private ArgumentCaptor pickerCaptor; @@ -215,7 +216,8 @@ public class GrpclbLoadBalancerTest { subchannelTracker.add(subchannel); return subchannel; } - }).when(helper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); + }).when(subchannelPool).takeOrCreateSubchannel( + any(EquivalentAddressGroup.class), any(Attributes.class)); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { @@ -233,11 +235,13 @@ public class GrpclbLoadBalancerTest { }).when(helper).updateBalancingState( any(ConnectivityState.class), any(SubchannelPicker.class)); when(helper.getAuthority()).thenReturn(SERVICE_AUTHORITY); - when(timerServicePool.getObject()).thenReturn(fakeClock.getScheduledExecutorService()); + ScheduledExecutorService timerService = fakeClock.getScheduledExecutorService(); + when(timerServicePool.getObject()).thenReturn(timerService); balancer = new GrpclbLoadBalancer( - helper, pickFirstBalancerFactory, roundRobinBalancerFactory, + helper, subchannelPool, pickFirstBalancerFactory, roundRobinBalancerFactory, timerServicePool, timeProvider); + verify(subchannelPool).init(same(helper), same(timerService)); } @After @@ -256,11 +260,17 @@ public class GrpclbLoadBalancerTest { // balancer should have closed the LB stream, terminating the OOB channel. assertTrue(channel + " is terminated", channel.isTerminated()); } + // GRPCLB manages subchannels only through subchannelPool for (Subchannel subchannel: subchannelTracker) { - verify(subchannel).shutdown(); + verify(subchannelPool).returnSubchannel(same(subchannel)); + // Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if + // LoadBalancer has called it expectedly. + verify(subchannel, never()).shutdown(); } + verify(helper, never()) + .createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); // No timer should linger after shutdown - assertEquals(0, fakeClock.numPendingTasks()); + assertThat(fakeClock.getPendingTasks()).isEmpty(); } finally { if (fakeLbServer != null) { fakeLbServer.shutdownNow(); @@ -382,7 +392,7 @@ public class GrpclbLoadBalancerTest { assertEquals(1, lbRequestObservers.size()); StreamObserver lbRequestObserver = lbRequestObservers.poll(); InOrder inOrder = inOrder(lbRequestObserver); - InOrder helperInOrder = inOrder(helper); + InOrder helperInOrder = inOrder(helper, subchannelPool); inOrder.verify(lbRequestObserver).onNext( eq(LoadBalanceRequest.newBuilder().setInitialRequest( @@ -565,7 +575,7 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildLbResponse(backends)); // Same backends, thus no new subchannels - helperInOrder.verify(helper, never()).createSubchannel( + helperInOrder.verify(subchannelPool, never()).takeOrCreateSubchannel( any(EquivalentAddressGroup.class), any(Attributes.class)); // But the new RoundRobinEntries have a new loadRecorder, thus considered different from // the previous list, thus a new picker is created @@ -784,7 +794,7 @@ public class GrpclbLoadBalancerTest { @Test public void grpclbThenNameResolutionFails() { - InOrder inOrder = inOrder(helper); + InOrder inOrder = inOrder(helper, subchannelPool); // Go to GRPCLB first List grpclbResolutionList = createResolvedServerAddresses(true); Attributes grpclbResolutionAttrs = Attributes.newBuilder() @@ -818,9 +828,9 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildLbResponse(backends)); verify(helper, times(2)).runSerialized(any(Runnable.class)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends.get(0).addr)), any(Attributes.class)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends.get(1).addr)), any(Attributes.class)); } @@ -863,6 +873,8 @@ public class GrpclbLoadBalancerTest { // GRPCLB connection is closed verify(lbRequestObserver).onCompleted(); assertTrue(oobChannel.isShutdown()); + // Switching away from GRPCLB will clear the subchannelPool + verify(subchannelPool).clear(); // Switch to ROUND_ROBIN List roundRobinResolutionList = @@ -911,11 +923,13 @@ public class GrpclbLoadBalancerTest { assertSame(pickFirstBalancer, balancer.getDelegate()); // GRPCLB connection is closed assertTrue(oobChannel.isShutdown()); + // Switching away from GRPCLB will clear the subchannelPool + verify(subchannelPool, times(2)).clear(); } @Test public void resetGrpclbWhenSwitchingAwayFromGrpclb() { - InOrder inOrder = inOrder(helper); + InOrder inOrder = inOrder(helper, subchannelPool); List grpclbResolutionList = createResolvedServerAddresses(true); Attributes grpclbResolutionAttrs = Attributes.newBuilder() .set(GrpclbConstants.ATTR_LB_POLICY, LbPolicy.GRPCLB).build(); @@ -942,7 +956,7 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends.get(0).addr)), any(Attributes.class)); assertEquals(1, mockSubchannels.size()); Subchannel subchannel = mockSubchannels.poll(); @@ -954,12 +968,12 @@ public class GrpclbLoadBalancerTest { Attributes roundRobinResolutionAttrs = Attributes.newBuilder() .set(GrpclbConstants.ATTR_LB_POLICY, LbPolicy.ROUND_ROBIN).build(); verify(lbRequestObserver, never()).onCompleted(); - verify(subchannel, never()).shutdown(); + verify(subchannelPool, never()).returnSubchannel(same(subchannel)); assertFalse(oobChannel.isShutdown()); deliverResolvedAddresses(roundRobinResolutionList, roundRobinResolutionAttrs); verify(lbRequestObserver).onCompleted(); - verify(subchannel).shutdown(); + verify(subchannelPool).returnSubchannel(same(subchannel)); assertTrue(oobChannel.isShutdown()); assertTrue(oobChannel.isTerminated()); assertSame(LbPolicy.ROUND_ROBIN, balancer.getLbPolicy()); @@ -1016,7 +1030,7 @@ public class GrpclbLoadBalancerTest { @Test public void grpclbWorking() { - InOrder inOrder = inOrder(helper); + InOrder inOrder = inOrder(helper, subchannelPool); List grpclbResolutionList = createResolvedServerAddresses(true); Attributes grpclbResolutionAttrs = Attributes.newBuilder() .set(GrpclbConstants.ATTR_LB_POLICY, LbPolicy.GRPCLB).build(); @@ -1048,9 +1062,9 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(0).addr)), any(Attributes.class)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(1).addr)), any(Attributes.class)); assertEquals(2, mockSubchannels.size()); Subchannel subchannel1 = mockSubchannels.poll(); @@ -1121,15 +1135,17 @@ public class GrpclbLoadBalancerTest { new ServerEntry("127.0.0.1", 2010, "token0004"), // Existing address with token changed new ServerEntry("127.0.0.1", 2030, "token0005"), // New address appearing second time new ServerEntry("token0006")); // drop - verify(subchannel1, never()).shutdown(); + verify(subchannelPool, never()).returnSubchannel(same(subchannel1)); lbResponseObserver.onNext(buildLbResponse(backends2)); - verify(subchannel1).shutdown(); // not in backends2, closed - verify(subchannel2, never()).shutdown(); // backends2[2], will be kept + // not in backends2, closed + verify(subchannelPool).returnSubchannel(same(subchannel1)); + // backends2[2], will be kept + verify(subchannelPool, never()).returnSubchannel(same(subchannel2)); - inOrder.verify(helper, never()).createSubchannel( + inOrder.verify(subchannelPool, never()).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends2.get(2).addr)), any(Attributes.class)); - inOrder.verify(helper).createSubchannel( + inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends2.get(0).addr)), any(Attributes.class)); assertEquals(1, mockSubchannels.size()); Subchannel subchannel3 = mockSubchannels.poll(); @@ -1179,12 +1195,12 @@ public class GrpclbLoadBalancerTest { new BackendEntry(subchannel3, getLoadRecorder(), "token0003"), new BackendEntry(subchannel2, getLoadRecorder(), "token0004"), new BackendEntry(subchannel3, getLoadRecorder(), "token0005")).inOrder(); - verify(subchannel3, never()).shutdown(); + verify(subchannelPool, never()).returnSubchannel(same(subchannel3)); // Update backends, with no entry lbResponseObserver.onNext(buildLbResponse(Collections.emptyList())); - verify(subchannel2).shutdown(); - verify(subchannel3).shutdown(); + verify(subchannelPool).returnSubchannel(same(subchannel2)); + verify(subchannelPool).returnSubchannel(same(subchannel3)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker10 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker10.dropList).isEmpty(); @@ -1197,6 +1213,10 @@ public class GrpclbLoadBalancerTest { // Load reporting was not requested, thus never scheduled assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER)); + + verify(subchannelPool, never()).clear(); + balancer.shutdown(); + verify(subchannelPool).clear(); } @Test @@ -1212,7 +1232,7 @@ public class GrpclbLoadBalancerTest { // Fallback or not within the period of the initial timeout. private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { long loadReportIntervalMillis = 1983; - InOrder helperInOrder = inOrder(helper); + InOrder inOrder = inOrder(helper, subchannelPool); // Create a resolution list with a mixture of balancer and backend addresses List resolutionList = @@ -1222,7 +1242,7 @@ public class GrpclbLoadBalancerTest { deliverResolvedAddresses(resolutionList, resolutionAttrs); assertSame(LbPolicy.GRPCLB, balancer.getLbPolicy()); - helperInOrder.verify(helper).createOobChannel( + inOrder.verify(helper).createOobChannel( addrsEq(resolutionList.get(1)), eq(lbAuthority(0))); // Attempted to connect to balancer @@ -1239,8 +1259,8 @@ public class GrpclbLoadBalancerTest { .build())); lbResponseObserver.onNext(buildInitialResponse(loadReportIntervalMillis)); // We don't care if runSerialized() has been run. - helperInOrder.verify(helper, atLeast(0)).runSerialized(any(Runnable.class)); - helperInOrder.verifyNoMoreInteractions(); + inOrder.verify(helper, atLeast(0)).runSerialized(any(Runnable.class)); + inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks(FALLBACK_MODE_TASK_FILTER)); fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS - 1, TimeUnit.MILLISECONDS); @@ -1277,7 +1297,7 @@ public class GrpclbLoadBalancerTest { assertEquals(0, fakeClock.numPendingTasks(FALLBACK_MODE_TASK_FILTER)); // Fall back to the backends from resolver fallbackTestVerifyUseOfFallbackBackendLists( - helperInOrder, helper, Arrays.asList(resolutionList.get(0), resolutionList.get(2))); + inOrder, Arrays.asList(resolutionList.get(0), resolutionList.get(2))); assertNull(balancer.getDelegate()); assertFalse(oobChannel.isShutdown()); @@ -1292,7 +1312,7 @@ public class GrpclbLoadBalancerTest { assertSame(LbPolicy.GRPCLB, balancer.getLbPolicy()); // New addresses are updated to the OobChannel - helperInOrder.verify(helper).updateOobChannelAddresses( + inOrder.verify(helper).updateOobChannelAddresses( same(oobChannel), eq(new EquivalentAddressGroup( Arrays.asList( @@ -1302,7 +1322,7 @@ public class GrpclbLoadBalancerTest { if (timerExpires) { // Still in fallback logic, except that the backend list is empty fallbackTestVerifyUseOfFallbackBackendLists( - helperInOrder, helper, Collections.emptyList()); + inOrder, Collections.emptyList()); } ////////////////////////////////////////////////// @@ -1313,14 +1333,14 @@ public class GrpclbLoadBalancerTest { assertSame(LbPolicy.GRPCLB, balancer.getLbPolicy()); // New LB address is updated to the OobChannel - helperInOrder.verify(helper).updateOobChannelAddresses( + inOrder.verify(helper).updateOobChannelAddresses( same(oobChannel), addrsEq(resolutionList.get(0))); if (timerExpires) { // New backend addresses are used for fallback fallbackTestVerifyUseOfFallbackBackendLists( - helperInOrder, helper, Arrays.asList(resolutionList.get(1), resolutionList.get(2))); + inOrder, Arrays.asList(resolutionList.get(1), resolutionList.get(2))); } //////////////////////////////////////////////// @@ -1330,7 +1350,7 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onError(streamError.asException()); // The error will NOT propagate to picker because fallback list is in use. - helperInOrder.verify(helper, never()) + inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); // A new stream is created verify(mockLbService, times(3)).balanceLoad(lbResponseObserverCaptor.capture()); @@ -1353,7 +1373,7 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildLbResponse(serverList)); // Balancer-provided server list now in effect - fallbackTestVerifyUseOfBalancerBackendLists(helperInOrder, helper, serverList); + fallbackTestVerifyUseOfBalancerBackendLists(inOrder, serverList); /////////////////////////////////////////////////////////////// // New backend addresses from resolver outside of fallback mode @@ -1362,7 +1382,7 @@ public class GrpclbLoadBalancerTest { deliverResolvedAddresses(resolutionList, resolutionAttrs); assertSame(LbPolicy.GRPCLB, balancer.getLbPolicy()); // Will not affect the round robin list at all - helperInOrder.verify(helper, never()) + inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); // No fallback timeout timer scheduled. @@ -1388,7 +1408,7 @@ public class GrpclbLoadBalancerTest { private void subtestGrpclbFallbackConnectionLost( boolean balancerBroken, boolean allSubchannelsBroken) { long loadReportIntervalMillis = 1983; - InOrder inOrder = inOrder(helper, mockLbService); + InOrder inOrder = inOrder(helper, mockLbService, subchannelPool); // Create a resolution list with a mixture of balancer and backend addresses List resolutionList = @@ -1425,8 +1445,7 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(serverList)); - List subchannels = - fallbackTestVerifyUseOfBalancerBackendLists(inOrder, helper, serverList); + List subchannels = fallbackTestVerifyUseOfBalancerBackendLists(inOrder, serverList); // Break connections if (balancerBroken) { @@ -1447,7 +1466,7 @@ public class GrpclbLoadBalancerTest { if (balancerBroken && allSubchannelsBroken) { // Going into fallback subchannels = fallbackTestVerifyUseOfFallbackBackendLists( - inOrder, helper, Arrays.asList(resolutionList.get(0), resolutionList.get(2))); + inOrder, Arrays.asList(resolutionList.get(0), resolutionList.get(2))); // When in fallback mode, fallback timer should not be scheduled when all backend // connections are lost @@ -1462,40 +1481,42 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(serverList2)); - fallbackTestVerifyUseOfBalancerBackendLists(inOrder, helper, serverList2); + fallbackTestVerifyUseOfBalancerBackendLists(inOrder, serverList2); } assertEquals(0, fakeClock.numPendingTasks(FALLBACK_MODE_TASK_FILTER)); if (!(balancerBroken && allSubchannelsBroken)) { - verify(helper, never()).createSubchannel(eq(resolutionList.get(0)), any(Attributes.class)); - verify(helper, never()).createSubchannel(eq(resolutionList.get(2)), any(Attributes.class)); + verify(subchannelPool, never()).takeOrCreateSubchannel( + eq(resolutionList.get(0)), any(Attributes.class)); + verify(subchannelPool, never()).takeOrCreateSubchannel( + eq(resolutionList.get(2)), any(Attributes.class)); } } private List fallbackTestVerifyUseOfFallbackBackendLists( - InOrder inOrder, Helper helper, List addrs) { - return fallbackTestVerifyUseOfBackendLists(inOrder, helper, addrs, null); + InOrder inOrder, List addrs) { + return fallbackTestVerifyUseOfBackendLists(inOrder, addrs, null); } private List fallbackTestVerifyUseOfBalancerBackendLists( - InOrder inOrder, Helper helper, List servers) { + InOrder inOrder, List servers) { ArrayList addrs = new ArrayList(); ArrayList tokens = new ArrayList(); for (ServerEntry server : servers) { addrs.add(new EquivalentAddressGroup(server.addr)); tokens.add(server.token); } - return fallbackTestVerifyUseOfBackendLists(inOrder, helper, addrs, tokens); + return fallbackTestVerifyUseOfBackendLists(inOrder, addrs, tokens); } private List fallbackTestVerifyUseOfBackendLists( - InOrder inOrder, Helper helper, List addrs, + InOrder inOrder, List addrs, @Nullable List tokens) { if (tokens != null) { assertEquals(addrs.size(), tokens.size()); } for (EquivalentAddressGroup addr : addrs) { - inOrder.verify(helper).createSubchannel(addrsEq(addr), any(Attributes.class)); + inOrder.verify(subchannelPool).takeOrCreateSubchannel(addrsEq(addr), any(Attributes.class)); } RoundRobinPicker picker = (RoundRobinPicker) currentPicker; assertThat(picker.dropList).containsExactlyElementsIn(Collections.nCopies(addrs.size(), null)); @@ -1670,31 +1691,4 @@ public class GrpclbLoadBalancerTest { this.token = token; } } - - private static class FakeSocketAddress extends SocketAddress { - final String name; - - FakeSocketAddress(String name) { - this.name = name; - } - - @Override - public String toString() { - return "FakeSocketAddress-" + name; - } - - @Override - public boolean equals(Object other) { - if (other instanceof FakeSocketAddress) { - FakeSocketAddress otherAddr = (FakeSocketAddress) other; - return name.equals(otherAddr.name); - } - return false; - } - - @Override - public int hashCode() { - return name.hashCode(); - } - } }