diff --git a/core/src/main/java/io/grpc/LoadBalancer.java b/core/src/main/java/io/grpc/LoadBalancer.java index e9f0b61844..786322698e 100644 --- a/core/src/main/java/io/grpc/LoadBalancer.java +++ b/core/src/main/java/io/grpc/LoadBalancer.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; @@ -332,9 +333,17 @@ public abstract class LoadBalancer { * @param subchannel the involved Subchannel * @param stateInfo the new state * @since 1.2.0 + * @deprecated This method will be removed. Stop overriding it. Instead, pass {@link + * SubchannelStateListener} to {@link Helper#createSubchannel(List, Attributes, + * SubchannelStateListener)} or {@link Helper#createSubchannel(EquivalentAddressGroup, + * Attributes, SubchannelStateListener)} to receive Subchannel state updates */ - public abstract void handleSubchannelState( - Subchannel subchannel, ConnectivityStateInfo stateInfo); + @Deprecated + public void handleSubchannelState( + Subchannel subchannel, ConnectivityStateInfo stateInfo) { + // Do nothing. If the implemetation doesn't implement this, it will get subchannel states from + // the new API. We don't throw because there may be forwarding LoadBalancers still plumb this. + } /** * The channel asks the load-balancer to shutdown. No more callbacks will be called after this @@ -648,6 +657,149 @@ public abstract class LoadBalancer { } } + /** + * Arguments for {@link Helper#createSubchannel(CreateSubchannelArgs)}. + * + * @since 1.21.0 + */ + public static final class CreateSubchannelArgs { + private final List addrs; + private final Attributes attrs; + private final SubchannelStateListener stateListener; + + private CreateSubchannelArgs( + List addrs, Attributes attrs, + SubchannelStateListener stateListener) { + this.addrs = checkNotNull(addrs, "addresses are not set"); + this.attrs = checkNotNull(attrs, "attrs"); + this.stateListener = checkNotNull(stateListener, "SubchannelStateListener is not set"); + } + + /** + * Returns the addresses, which is an unmodifiable list. + */ + public List getAddresses() { + return addrs; + } + + /** + * Returns the attributes. + */ + public Attributes getAttributes() { + return attrs; + } + + /** + * Returns the state listener. + */ + public SubchannelStateListener getStateListener() { + return stateListener; + } + + /** + * Returns a builder with the same initial values as this object. + */ + public Builder toBuilder() { + return newBuilder().setAddresses(addrs).setAttributes(attrs).setStateListener(stateListener); + } + + /** + * Creates a new builder. + */ + public static Builder newBuilder() { + return new Builder(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("addrs", addrs) + .add("attrs", attrs) + .add("listener", stateListener) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hashCode(addrs, attrs, stateListener); + } + + /** + * Returns true if the {@link Subchannel}, {@link Status}, and + * {@link ClientStreamTracer.Factory} all match. + */ + @Override + public boolean equals(Object other) { + if (!(other instanceof CreateSubchannelArgs)) { + return false; + } + CreateSubchannelArgs that = (CreateSubchannelArgs) other; + return Objects.equal(addrs, that.addrs) && Objects.equal(attrs, that.attrs) + && Objects.equal(stateListener, that.stateListener); + } + + public static final class Builder { + private List addrs; + private Attributes attrs = Attributes.EMPTY; + private SubchannelStateListener stateListener; + + Builder() { + } + + /** + * The addresses to connect to. All addresses are considered equivalent and will be tried + * in the order they are provided. + */ + public Builder setAddresses(EquivalentAddressGroup addrs) { + this.addrs = Collections.singletonList(addrs); + return this; + } + + /** + * The addresses to connect to. All addresses are considered equivalent and will + * be tried in the order they are provided. + * + *

This is a required property. + * + * @throws IllegalArgumentException if {@code addrs} is empty + */ + public Builder setAddresses(List addrs) { + checkArgument(!addrs.isEmpty(), "addrs is empty"); + this.addrs = Collections.unmodifiableList(new ArrayList<>(addrs)); + return this; + } + + /** + * Attributes provided here will be included in {@link Subchannel#getAttributes}. + * + *

This is an optional property. Default is empty if not set. + */ + public Builder setAttributes(Attributes attrs) { + this.attrs = checkNotNull(attrs, "attrs"); + return this; + } + + /** + * Receives state changes of the created Subchannel. The listener is called from + * the {@link #getSynchronizationContext Synchronization Context}. It's safe to share the + * listener among multiple Subchannels. + * + *

This is a required property. + */ + public Builder setStateListener(SubchannelStateListener listener) { + this.stateListener = checkNotNull(listener, "listener"); + return this; + } + + /** + * Creates a new args object. + */ + public CreateSubchannelArgs build() { + return new CreateSubchannelArgs(addrs, attrs, stateListener); + } + } + } + /** * Provides essentials for LoadBalancer implementations. * @@ -661,7 +813,11 @@ public abstract class LoadBalancer { * EquivalentAddressGroup}. * * @since 1.2.0 + * @deprecated Use {@link #createSubchannel(CreateSubchannelArgs)} instead. Note the new API + * must be called from {@link #getSynchronizationContext the Synchronization + * Context}. */ + @Deprecated public final Subchannel createSubchannel(EquivalentAddressGroup addrs, Attributes attrs) { checkNotNull(addrs, "addrs"); return createSubchannel(Collections.singletonList(addrs), attrs); @@ -682,11 +838,35 @@ public abstract class LoadBalancer { * * @throws IllegalArgumentException if {@code addrs} is empty * @since 1.14.0 + * @deprecated Use {@link #createSubchannel(CreateSubchannelArgs)} instead. Note the new API + * must be called from {@link #getSynchronizationContext the Synchronization + * Context}. */ + @Deprecated public Subchannel createSubchannel(List addrs, Attributes attrs) { throw new UnsupportedOperationException(); } + /** + * Creates a Subchannel, which is a logical connection to the given group of addresses which are + * considered equivalent. The {@code attrs} are custom attributes associated with this + * Subchannel, and can be accessed later through {@link Subchannel#getAttributes + * Subchannel.getAttributes()}. + * + *

This method must be called from the {@link #getSynchronizationContext + * Synchronization Context}, otherwise it may throw. This is to avoid the race between + * the caller and {@link SubchannelStateListener#onSubchannelState}. See #5015 for more discussions. + * + *

The LoadBalancer is responsible for closing unused Subchannels, and closing all + * Subchannels within {@link #shutdown}. + * + * @since 1.21.0 + */ + public Subchannel createSubchannel(CreateSubchannelArgs args) { + throw new UnsupportedOperationException(); + } + /** * Equivalent to {@link #updateSubchannelAddresses(io.grpc.LoadBalancer.Subchannel, List)} with * the given single {@code EquivalentAddressGroup}. @@ -903,7 +1083,7 @@ public abstract class LoadBalancer { */ public final EquivalentAddressGroup getAddresses() { List groups = getAllAddresses(); - Preconditions.checkState(groups.size() == 1, "Does not have exactly one group"); + Preconditions.checkState(groups.size() == 1, "%s does not have exactly one group", groups); return groups.get(0); } @@ -964,6 +1144,39 @@ public abstract class LoadBalancer { } } + /** + * Receives state changes for one or more {@link Subchannel}s. All methods are run under {@link + * Helper#getSynchronizationContext}. + * + * @since 1.21.0 + */ + public interface SubchannelStateListener { + + /** + * Handles a state change on a Subchannel. + * + *

The initial state of a Subchannel is IDLE. You won't get a notification for the initial + * IDLE state. + * + *

If the new state is not SHUTDOWN, this method should create a new picker and call {@link + * Helper#updateBalancingState Helper.updateBalancingState()}. Failing to do so may result in + * unnecessary delays of RPCs. Please refer to {@link PickResult#withSubchannel + * PickResult.withSubchannel()}'s javadoc for more information. + * + *

SHUTDOWN can only happen in two cases. One is that LoadBalancer called {@link + * Subchannel#shutdown} earlier, thus it should have already discarded this Subchannel. The + * other is that Channel is doing a {@link ManagedChannel#shutdownNow forced shutdown} or has + * already terminated, thus there won't be further requests to LoadBalancer. Therefore, + * SHUTDOWN can be safely ignored. + * + * @param subchannel the involved Subchannel + * @param newState the new state + * + * @since 1.21.0 + */ + void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState); + } + /** * Factory to create {@link LoadBalancer} instance. * diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java index c0872e823b..52e0ace586 100644 --- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java @@ -77,9 +77,6 @@ public final class AutoConfiguredLoadBalancerFactory extends LoadBalancer.Factor @Override public void handleNameResolutionError(Status error) {} - @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {} - @Override public void shutdown() {} } @@ -165,6 +162,7 @@ public final class AutoConfiguredLoadBalancerFactory extends LoadBalancer.Factor getDelegate().handleNameResolutionError(error); } + @Deprecated @Override public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { getDelegate().handleSubchannelState(subchannel, stateInfo); diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a1109f1a63..3bdefd3938 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -53,6 +53,7 @@ import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; @@ -1050,6 +1051,7 @@ final class ManagedChannelImpl extends ManagedChannel implements } } + @Deprecated @Override public AbstractSubchannel createSubchannel( List addressGroups, Attributes attrs) { @@ -1061,17 +1063,36 @@ final class ManagedChannelImpl extends ManagedChannel implements + " Otherwise, it may race with handleSubchannelState()." + " See https://github.com/grpc/grpc-java/issues/5015", e); } - checkNotNull(addressGroups, "addressGroups"); - checkNotNull(attrs, "attrs"); + return createSubchannelInternal( + CreateSubchannelArgs.newBuilder() + .setAddresses(addressGroups) + .setAttributes(attrs) + .setStateListener(new LoadBalancer.SubchannelStateListener() { + @Override + public void onSubchannelState( + LoadBalancer.Subchannel subchannel, ConnectivityStateInfo newState) { + lb.handleSubchannelState(subchannel, newState); + } + }) + .build()); + } + + @Override + public AbstractSubchannel createSubchannel(CreateSubchannelArgs args) { + syncContext.throwIfNotInThisSynchronizationContext(); + return createSubchannelInternal(args); + } + + private AbstractSubchannel createSubchannelInternal(final CreateSubchannelArgs args) { // TODO(ejona): can we be even stricter? Like loadBalancer == null? checkState(!terminated, "Channel is terminated"); - final SubchannelImpl subchannel = new SubchannelImpl(attrs); + final SubchannelImpl subchannel = new SubchannelImpl(args.getAttributes()); long subchannelCreationTime = timeProvider.currentTimeNanos(); InternalLogId subchannelLogId = InternalLogId.allocate("Subchannel", /*details=*/ null); ChannelTracer subchannelTracer = new ChannelTracer( subchannelLogId, maxTraceEvents, subchannelCreationTime, - "Subchannel for " + addressGroups); + "Subchannel for " + args.getAddresses()); final class ManagedInternalSubchannelCallback extends InternalSubchannel.Callback { // All callbacks are run in syncContext @@ -1087,7 +1108,7 @@ final class ManagedChannelImpl extends ManagedChannel implements handleInternalSubchannelState(newState); // Call LB only if it's not shutdown. If LB is shutdown, lbHelper won't match. if (LbHelperImpl.this == ManagedChannelImpl.this.lbHelper) { - lb.handleSubchannelState(subchannel, newState); + args.getStateListener().onSubchannelState(subchannel, newState); } } @@ -1103,7 +1124,7 @@ final class ManagedChannelImpl extends ManagedChannel implements } final InternalSubchannel internalSubchannel = new InternalSubchannel( - addressGroups, + args.getAddresses(), authority(), userAgent, backoffPolicyProvider, diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index 1afcc42815..901126f36f 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; @@ -31,6 +30,7 @@ import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; import java.util.List; @@ -39,7 +39,7 @@ import java.util.List; * NameResolver}. The channel's default behavior is used, which is walking down the address list * and sticking to the first that works. */ -final class PickFirstLoadBalancer extends LoadBalancer { +final class PickFirstLoadBalancer extends LoadBalancer implements SubchannelStateListener { private final Helper helper; private Subchannel subchannel; @@ -51,7 +51,11 @@ final class PickFirstLoadBalancer extends LoadBalancer { public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { List servers = resolvedAddresses.getServers(); if (subchannel == null) { - subchannel = helper.createSubchannel(servers, Attributes.EMPTY); + subchannel = helper.createSubchannel( + CreateSubchannelArgs.newBuilder() + .setAddresses(servers) + .setStateListener(this) + .build()); // The channel state does not get updated when doing name resolving today, so for the moment // let LB report CONNECTION and call subchannel.requestConnection() immediately. @@ -74,9 +78,9 @@ final class PickFirstLoadBalancer extends LoadBalancer { } @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { ConnectivityState currentState = stateInfo.getState(); - if (subchannel != this.subchannel || currentState == SHUTDOWN) { + if (subchannel != PickFirstLoadBalancer.this.subchannel || currentState == SHUTDOWN) { return; } @@ -99,7 +103,6 @@ final class PickFirstLoadBalancer extends LoadBalancer { default: throw new IllegalArgumentException("Unsupported state:" + currentState); } - helper.updateBalancingState(currentState, picker); } diff --git a/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java b/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java index a5c086fe13..44ca8a68b7 100644 --- a/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java @@ -51,6 +51,7 @@ public abstract class ForwardingLoadBalancer extends LoadBalancer { delegate().handleNameResolutionError(error); } + @Deprecated @Override public void handleSubchannelState( Subchannel subchannel, ConnectivityStateInfo stateInfo) { diff --git a/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java b/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java index aeac77db00..15a77a007d 100644 --- a/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java +++ b/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java @@ -22,6 +22,7 @@ import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer; @@ -38,11 +39,17 @@ public abstract class ForwardingLoadBalancerHelper extends LoadBalancer.Helper { */ protected abstract LoadBalancer.Helper delegate(); + @Deprecated @Override public Subchannel createSubchannel(List addrs, Attributes attrs) { return delegate().createSubchannel(addrs, attrs); } + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + return delegate().createSubchannel(args); + } + @Override public void updateSubchannelAddresses( Subchannel subchannel, List addrs) { diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 76743cbc4a..6db8012d45 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -37,6 +37,7 @@ import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Metadata.Key; import io.grpc.NameResolver; @@ -63,7 +64,7 @@ import javax.annotation.Nullable; * A {@link LoadBalancer} that provides round-robin load-balancing over the {@link * EquivalentAddressGroup}s from the {@link NameResolver}. */ -final class RoundRobinLoadBalancer extends LoadBalancer { +final class RoundRobinLoadBalancer extends LoadBalancer implements SubchannelStateListener { @VisibleForTesting static final Attributes.Key> STATE_INFO = Attributes.Key.create("state-info"); @@ -130,7 +131,12 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } Subchannel subchannel = checkNotNull( - helper.createSubchannel(addressGroup, subchannelAttrs.build()), "subchannel"); + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(addressGroup) + .setAttributes(subchannelAttrs.build()) + .setStateListener(this) + .build()), + "subchannel"); if (stickyRef != null) { stickyRef.value = subchannel; } @@ -161,7 +167,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { if (subchannels.get(subchannel.getAddresses()) != subchannel) { return; } diff --git a/core/src/test/java/io/grpc/LoadBalancerTest.java b/core/src/test/java/io/grpc/LoadBalancerTest.java index 78e6ec764c..794cbea1c0 100644 --- a/core/src/test/java/io/grpc/LoadBalancerTest.java +++ b/core/src/test/java/io/grpc/LoadBalancerTest.java @@ -17,11 +17,14 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; import java.net.SocketAddress; import java.util.Arrays; import java.util.List; @@ -35,6 +38,8 @@ import org.junit.runners.JUnit4; public class LoadBalancerTest { private final Subchannel subchannel = mock(Subchannel.class); private final Subchannel subchannel2 = mock(Subchannel.class); + private final SubchannelStateListener subchannelStateListener = + mock(SubchannelStateListener.class); private final ClientStreamTracer.Factory tracerFactory = mock(ClientStreamTracer.Factory.class); private final Status status = Status.UNAVAILABLE.withDescription("for test"); private final Status status2 = Status.UNAVAILABLE.withDescription("for test 2"); @@ -120,8 +125,9 @@ public class LoadBalancerTest { assertThat(error1).isNotEqualTo(drop1); } + @Deprecated @Test - public void helper_createSubchannel_delegates() { + public void helper_createSubchannel_old_delegates() { class OverrideCreateSubchannel extends NoopHelper { boolean ran; @@ -140,9 +146,29 @@ public class LoadBalancerTest { assertThat(helper.ran).isTrue(); } - @Test(expected = UnsupportedOperationException.class) + @Test + @SuppressWarnings("deprecation") + public void helper_createSubchannelList_oldApi_throws() { + try { + new NoopHelper().createSubchannel(Arrays.asList(eag), attrs); + fail("Should throw"); + } catch (UnsupportedOperationException e) { + // exepcted + } + } + + @Test public void helper_createSubchannelList_throws() { - new NoopHelper().createSubchannel(Arrays.asList(eag), attrs); + try { + new NoopHelper().createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(eag) + .setAttributes(attrs) + .setStateListener(subchannelStateListener) + .build()); + fail("Should throw"); + } catch (UnsupportedOperationException e) { + // expected + } } @Test diff --git a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java index b5e6f804d8..1e3b3e7354 100644 --- a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java +++ b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java @@ -41,10 +41,12 @@ import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; @@ -134,6 +136,7 @@ public class AutoConfiguredLoadBalancerFactoryTest { assertThat(lb.getDelegate()).isSameAs(testLbBalancer); } + @SuppressWarnings("deprecation") @Test public void forwardsCalls() { AutoConfiguredLoadBalancer lb = @@ -176,9 +179,9 @@ public class AutoConfiguredLoadBalancerFactoryTest { Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); Helper helper = new TestHelper() { @Override - public Subchannel createSubchannel(List addrs, Attributes attrs) { - assertThat(addrs).isEqualTo(servers); - return new TestSubchannel(addrs, attrs); + public Subchannel createSubchannel(CreateSubchannelArgs args) { + assertThat(args.getAddresses()).isEqualTo(servers); + return new TestSubchannel(args); } }; AutoConfiguredLoadBalancer lb = @@ -206,9 +209,9 @@ public class AutoConfiguredLoadBalancerFactoryTest { Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); Helper helper = new TestHelper() { @Override - public Subchannel createSubchannel(List addrs, Attributes attrs) { - assertThat(addrs).isEqualTo(servers); - return new TestSubchannel(addrs, attrs); + public Subchannel createSubchannel(CreateSubchannelArgs args) { + assertThat(args.getAddresses()).isEqualTo(servers); + return new TestSubchannel(args); } }; AutoConfiguredLoadBalancer lb = @@ -221,11 +224,6 @@ public class AutoConfiguredLoadBalancerFactoryTest { // noop } - @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - // noop - } - @Override public void shutdown() { shutdown.set(true); @@ -704,8 +702,18 @@ public class AutoConfiguredLoadBalancerFactoryTest { Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); Helper helper = new TestHelper() { @Override + @Deprecated public Subchannel createSubchannel(List addrs, Attributes attrs) { - return new TestSubchannel(addrs, attrs); + return new TestSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(addrs) + .setAttributes(attrs) + .setStateListener(mock(SubchannelStateListener.class)) + .build()); + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + return new TestSubchannel(args); } @Override @@ -822,11 +830,6 @@ public class AutoConfiguredLoadBalancerFactoryTest { delegate().handleNameResolutionError(error); } - @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - delegate().handleSubchannelState(subchannel, stateInfo); - } - @Override public void shutdown() { delegate().shutdown(); @@ -862,9 +865,9 @@ public class AutoConfiguredLoadBalancerFactoryTest { } private static class TestSubchannel extends Subchannel { - TestSubchannel(List addrs, Attributes attrs) { - this.addrs = addrs; - this.attrs = attrs; + TestSubchannel(CreateSubchannelArgs args) { + this.addrs = args.getAddresses(); + this.attrs = args.getAttributes(); } final List addrs; diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index 1d735aa574..9c25ed7a09 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -41,12 +41,14 @@ import io.grpc.ClientInterceptor; import io.grpc.EquivalentAddressGroup; import io.grpc.IntegerMarshaller; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; @@ -113,6 +115,7 @@ public class ManagedChannelImplIdlenessTest { @Mock private ClientTransportFactory mockTransportFactory; @Mock private LoadBalancer mockLoadBalancer; + @Mock private SubchannelStateListener subchannelStateListener; private final LoadBalancerProvider mockLoadBalancerProvider = mock(LoadBalancerProvider.class, delegatesTo(new LoadBalancerProvider() { @Override @@ -499,14 +502,19 @@ public class ManagedChannelImplIdlenessTest { } // We need this because createSubchannel() should be called from the SynchronizationContext - private static Subchannel createSubchannelSafely( + private Subchannel createSubchannelSafely( final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { final AtomicReference resultCapture = new AtomicReference<>(); helper.getSynchronizationContext().execute( new Runnable() { @Override public void run() { - resultCapture.set(helper.createSubchannel(addressGroup, attrs)); + resultCapture.set( + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(addressGroup) + .setAttributes(attrs) + .setStateListener(subchannelStateListener) + .build())); } }); return resultCapture.get(); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 89ce41881c..1836ab36e1 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -78,12 +78,14 @@ import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; import io.grpc.InternalInstrumented; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; @@ -207,6 +209,8 @@ public class ManagedChannelImplTest { private ArgumentCaptor callOptionsCaptor; @Mock private LoadBalancer mockLoadBalancer; + @Mock + private SubchannelStateListener subchannelStateListener; private final LoadBalancerProvider mockLoadBalancerProvider = mock(LoadBalancerProvider.class, delegatesTo(new LoadBalancerProvider() { @Override @@ -334,8 +338,9 @@ public class ManagedChannelImplTest { LoadBalancerRegistry.getDefaultRegistry().deregister(mockLoadBalancerProvider); } + @Deprecated @Test - public void createSubchannelOutsideSynchronizationContextShouldLogWarning() { + public void createSubchannel_old_outsideSynchronizationContextShouldLogWarning() { createChannel(); final AtomicReference logRef = new AtomicReference<>(); Handler handler = new Handler() { @@ -366,6 +371,48 @@ public class ManagedChannelImplTest { } } + @Deprecated + @Test + public void createSubchannel_old_propagateSubchannelStatesToOldApi() { + createChannel(); + final AtomicReference subchannelCapture = new AtomicReference<>(); + helper.getSynchronizationContext().execute(new Runnable() { + @Override + public void run() { + subchannelCapture.set(helper.createSubchannel(addressGroup, Attributes.EMPTY)); + } + }); + + Subchannel subchannel = subchannelCapture.get(); + subchannel.requestConnection(); + + verify(mockTransportFactory) + .newClientTransport( + any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); + verify(mockLoadBalancer).handleSubchannelState( + same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); + + MockClientTransportInfo transportInfo = transports.poll(); + transportInfo.listener.transportReady(); + + verify(mockLoadBalancer).handleSubchannelState( + same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); + } + + @Test + public void createSubchannel_outsideSynchronizationContextShouldThrow() { + createChannel(); + try { + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(addressGroup) + .setStateListener(subchannelStateListener) + .build()); + fail("Should throw"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Not called from the SynchronizationContext"); + } + } + @Test @SuppressWarnings("unchecked") public void idleModeDisabled() { @@ -426,7 +473,8 @@ public class ManagedChannelImplTest { assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); AbstractSubchannel subchannel = - (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely( + helper, addressGroup, Attributes.EMPTY, subchannelStateListener); // subchannels are not root channels assertNull(channelz.getRootChannel(subchannel.getInternalSubchannel().getLogId().getId())); assertTrue(channelz.containsSubchannel(subchannel.getInternalSubchannel().getLogId())); @@ -518,7 +566,8 @@ public class ManagedChannelImplTest { // Configure the picker so that first RPC goes to delayed transport, and second RPC goes to // real transport. - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); verify(mockTransportFactory) .newClientTransport( @@ -657,8 +706,12 @@ public class ManagedChannelImplTest { .setAttributes(Attributes.EMPTY) .build()); - Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + SubchannelStateListener stateListener1 = mock(SubchannelStateListener.class); + SubchannelStateListener stateListener2 = mock(SubchannelStateListener.class); + Subchannel subchannel1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, stateListener1); + Subchannel subchannel2 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, stateListener2); subchannel1.requestConnection(); subchannel2.requestConnection(); verify(mockTransportFactory, times(2)) @@ -669,13 +722,14 @@ public class ManagedChannelImplTest { // LoadBalancer receives all sorts of callbacks transportInfo1.listener.transportReady(); - verify(mockLoadBalancer, times(2)) - .handleSubchannelState(same(subchannel1), stateInfoCaptor.capture()); + + verify(stateListener1, times(2)) + .onSubchannelState(same(subchannel1), stateInfoCaptor.capture()); assertSame(CONNECTING, stateInfoCaptor.getAllValues().get(0).getState()); assertSame(READY, stateInfoCaptor.getAllValues().get(1).getState()); - verify(mockLoadBalancer) - .handleSubchannelState(same(subchannel2), stateInfoCaptor.capture()); + verify(stateListener2) + .onSubchannelState(same(subchannel2), stateInfoCaptor.capture()); assertSame(CONNECTING, stateInfoCaptor.getValue().getState()); resolver.observer.onError(resolutionError); @@ -724,7 +778,8 @@ public class ManagedChannelImplTest { call.start(mockCallListener, headers); // Make the transport available - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); verify(mockTransportFactory, never()) .newClientTransport( any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); @@ -1029,7 +1084,7 @@ public class ManagedChannelImplTest { return "badAddress"; } }; - InOrder inOrder = inOrder(mockLoadBalancer); + InOrder inOrder = inOrder(mockLoadBalancer, subchannelStateListener); List resolvedAddrs = Arrays.asList(badAddress, goodAddress); FakeNameResolverFactory nameResolverFactory = @@ -1050,14 +1105,14 @@ public class ManagedChannelImplTest { inOrder.verify(mockLoadBalancer).handleResolvedAddresses( ResolvedAddresses.newBuilder() .setServers(Arrays.asList(addressGroup)) - .setAttributes(Attributes.EMPTY) .build()); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); - inOrder.verify(mockLoadBalancer).handleSubchannelState( - same(subchannel), stateInfoCaptor.capture()); + inOrder.verify(subchannelStateListener) + .onSubchannelState(same(subchannel), stateInfoCaptor.capture()); assertEquals(CONNECTING, stateInfoCaptor.getValue().getState()); // The channel will starts with the first address (badAddress) @@ -1083,8 +1138,8 @@ public class ManagedChannelImplTest { .thenReturn(mock(ClientStream.class)); goodTransportInfo.listener.transportReady(); - inOrder.verify(mockLoadBalancer).handleSubchannelState( - same(subchannel), stateInfoCaptor.capture()); + inOrder.verify(subchannelStateListener) + .onSubchannelState(same(subchannel), stateInfoCaptor.capture()); assertEquals(READY, stateInfoCaptor.getValue().getState()); // A typical LoadBalancer will call this once the subchannel becomes READY @@ -1174,7 +1229,7 @@ public class ManagedChannelImplTest { return "addr2"; } }; - InOrder inOrder = inOrder(mockLoadBalancer); + InOrder inOrder = inOrder(mockLoadBalancer, subchannelStateListener); List resolvedAddrs = Arrays.asList(addr1, addr2); @@ -1201,15 +1256,15 @@ public class ManagedChannelImplTest { inOrder.verify(mockLoadBalancer).handleResolvedAddresses( ResolvedAddresses.newBuilder() .setServers(Arrays.asList(addressGroup)) - .setAttributes(Attributes.EMPTY) .build()); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); - inOrder.verify(mockLoadBalancer).handleSubchannelState( - same(subchannel), stateInfoCaptor.capture()); + inOrder.verify(subchannelStateListener) + .onSubchannelState(same(subchannel), stateInfoCaptor.capture()); assertEquals(CONNECTING, stateInfoCaptor.getValue().getState()); // Connecting to server1, which will fail @@ -1232,8 +1287,8 @@ public class ManagedChannelImplTest { // ... which makes the subchannel enter TRANSIENT_FAILURE. The last error Status is propagated // to LoadBalancer. - inOrder.verify(mockLoadBalancer).handleSubchannelState( - same(subchannel), stateInfoCaptor.capture()); + inOrder.verify(subchannelStateListener) + .onSubchannelState(same(subchannel), stateInfoCaptor.capture()); assertEquals(TRANSIENT_FAILURE, stateInfoCaptor.getValue().getState()); assertSame(server2Error, stateInfoCaptor.getValue().getStatus()); @@ -1262,8 +1317,10 @@ public class ManagedChannelImplTest { // createSubchannel() always return a new Subchannel Attributes attrs1 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr1").build(); Attributes attrs2 = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, "attr2").build(); - Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1); - Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2); + SubchannelStateListener listener1 = mock(SubchannelStateListener.class); + SubchannelStateListener listener2 = mock(SubchannelStateListener.class); + Subchannel sub1 = createSubchannelSafely(helper, addressGroup, attrs1, listener1); + Subchannel sub2 = createSubchannelSafely(helper, addressGroup, attrs2, listener2); assertNotSame(sub1, sub2); assertNotSame(attrs1, attrs2); assertSame(attrs1, sub1.getAttributes()); @@ -1330,8 +1387,10 @@ public class ManagedChannelImplTest { @Test public void subchannelsWhenChannelShutdownNow() { createChannel(); - Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel sub1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + Subchannel sub2 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); sub1.requestConnection(); sub2.requestConnection(); @@ -1358,8 +1417,10 @@ public class ManagedChannelImplTest { @Test public void subchannelsNoConnectionShutdown() { createChannel(); - Subchannel sub1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - Subchannel sub2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel sub1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + Subchannel sub2 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); channel.shutdown(); verify(mockLoadBalancer).shutdown(); @@ -1375,8 +1436,8 @@ public class ManagedChannelImplTest { @Test public void subchannelsNoConnectionShutdownNow() { createChannel(); - createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); channel.shutdownNow(); verify(mockLoadBalancer).shutdown(); @@ -1558,7 +1619,8 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_normalUsage() { createChannel(); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); verify(balancerRpcExecutorPool, never()).getObject(); Channel sChannel = subchannel.asChannel(); @@ -1589,7 +1651,8 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_failWhenNotReady() { createChannel(); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); Channel sChannel = subchannel.asChannel(); Metadata headers = new Metadata(); @@ -1617,7 +1680,8 @@ public class ManagedChannelImplTest { @Test public void subchannelChannel_failWaitForReady() { createChannel(); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); Channel sChannel = subchannel.asChannel(); Metadata headers = new Metadata(); @@ -1704,7 +1768,8 @@ public class ManagedChannelImplTest { OobChannel oobChannel = (OobChannel) helper.createOobChannel(addressGroup, "oobAuthority"); oobChannel.getSubchannel().requestConnection(); } else { - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); } @@ -1791,7 +1856,8 @@ public class ManagedChannelImplTest { // Simulate name resolution results EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(socketAddress); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); verify(mockTransportFactory) .newClientTransport( @@ -1864,7 +1930,8 @@ public class ManagedChannelImplTest { ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); createChannel(); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -1902,7 +1969,8 @@ public class ManagedChannelImplTest { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, new Metadata()); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -2277,7 +2345,8 @@ public class ManagedChannelImplTest { Helper helper2 = helperCaptor.getValue(); // Establish a connection - Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; @@ -2345,7 +2414,8 @@ public class ManagedChannelImplTest { Helper helper2 = helperCaptor.getValue(); // Establish a connection - Subchannel subchannel = createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper2, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); ClientStream mockStream = mock(ClientStream.class); MockClientTransportInfo transportInfo = transports.poll(); @@ -2374,8 +2444,10 @@ public class ManagedChannelImplTest { call.start(mockCallListener, new Metadata()); // Make the transport available with subchannel2 - Subchannel subchannel1 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); - Subchannel subchannel2 = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + Subchannel subchannel2 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel2.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); @@ -2514,7 +2586,8 @@ public class ManagedChannelImplTest { createChannel(); assertEquals(TARGET, getStats(channel).target); - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); assertEquals(Collections.singletonList(addressGroup).toString(), getStats((AbstractSubchannel) subchannel).target); } @@ -2537,7 +2610,8 @@ public class ManagedChannelImplTest { createChannel(); timer.forwardNanos(1234); AbstractSubchannel subchannel = - (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely( + helper, addressGroup, Attributes.EMPTY, subchannelStateListener); assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() .setDescription("Child Subchannel created") .setSeverity(ChannelTrace.Event.Severity.CT_INFO) @@ -2710,7 +2784,8 @@ public class ManagedChannelImplTest { channelBuilder.maxTraceEvents(10); createChannel(); AbstractSubchannel subchannel = - (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely( + helper, addressGroup, Attributes.EMPTY, subchannelStateListener); timer.forwardNanos(1234); subchannel.obtainActiveTransport(); assertThat(getStats(subchannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() @@ -2773,7 +2848,8 @@ public class ManagedChannelImplTest { assertEquals(CONNECTING, getStats(channel).state); AbstractSubchannel subchannel = - (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely( + helper, addressGroup, Attributes.EMPTY, subchannelStateListener); assertEquals(IDLE, getStats(subchannel).state); subchannel.requestConnection(); @@ -2827,7 +2903,8 @@ public class ManagedChannelImplTest { ClientStream mockStream = mock(ClientStream.class); ClientStreamTracer.Factory factory = mock(ClientStreamTracer.Factory.class); AbstractSubchannel subchannel = - (AbstractSubchannel) createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + (AbstractSubchannel) createSubchannelSafely( + helper, addressGroup, Attributes.EMPTY, subchannelStateListener); subchannel.requestConnection(); MockClientTransportInfo transportInfo = transports.poll(); transportInfo.listener.transportReady(); @@ -3064,7 +3141,8 @@ public class ManagedChannelImplTest { .build()); // simulating request connection and then transport ready after resolved address - Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); @@ -3163,7 +3241,8 @@ public class ManagedChannelImplTest { .build()); // simulating request connection and then transport ready after resolved address - Subchannel subchannel = helper.createSubchannel(addressGroup, Attributes.EMPTY); + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) .thenReturn(PickResult.withSubchannel(subchannel)); subchannel.requestConnection(); @@ -4060,13 +4139,18 @@ public class ManagedChannelImplTest { // We need this because createSubchannel() should be called from the SynchronizationContext private static Subchannel createSubchannelSafely( - final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs) { + final Helper helper, final EquivalentAddressGroup addressGroup, final Attributes attrs, + final SubchannelStateListener stateListener) { final AtomicReference resultCapture = new AtomicReference<>(); helper.getSynchronizationContext().execute( new Runnable() { @Override public void run() { - resultCapture.set(helper.createSubchannel(addressGroup, attrs)); + resultCapture.set(helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(addressGroup) + .setAttributes(attrs) + .setStateListener(stateListener) + .build())); } }); return resultCapture.get(); diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index ea9427f606..bff5f778d1 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; @@ -26,7 +27,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -37,12 +37,14 @@ import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; import java.net.SocketAddress; import java.util.List; @@ -75,7 +77,7 @@ public class PickFirstLoadBalancerTest { @Captor private ArgumentCaptor pickerCaptor; @Captor - private ArgumentCaptor attrsCaptor; + private ArgumentCaptor createArgsCaptor; @Mock private Helper mockHelper; @Mock @@ -92,16 +94,17 @@ public class PickFirstLoadBalancerTest { } when(mockSubchannel.getAllAddresses()).thenThrow(new UnsupportedOperationException()); - when(mockHelper.createSubchannel( - ArgumentMatchers.anyList(), any(Attributes.class))) - .thenReturn(mockSubchannel); + when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))).thenReturn(mockSubchannel); loadBalancer = new PickFirstLoadBalancer(mockHelper); } @After + @SuppressWarnings("deprecation") public void tearDown() throws Exception { verifyNoMoreInteractions(mockArgs); + verify(mockHelper, never()).createSubchannel( + ArgumentMatchers.anyList(), any(Attributes.class)); } @Test @@ -109,7 +112,9 @@ public class PickFirstLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - verify(mockHelper).createSubchannel(eq(servers), attrsCaptor.capture()); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); @@ -128,8 +133,8 @@ public class PickFirstLoadBalancerTest { ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); verifyNoMoreInteractions(mockSubchannel); - verify(mockHelper).createSubchannel(ArgumentMatchers.anyList(), - any(Attributes.class)); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue()).isNotNull(); verify(mockHelper) .updateBalancingState(isA(ConnectivityState.class), isA(SubchannelPicker.class)); // Updating the subchannel addresses is unnecessary, but doesn't hurt anything @@ -149,7 +154,9 @@ public class PickFirstLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - inOrder.verify(mockHelper).createSubchannel(eq(servers), any(Attributes.class)); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); @@ -162,34 +169,30 @@ public class PickFirstLoadBalancerTest { verifyNoMoreInteractions(mockHelper); } - @Test - public void stateChangeBeforeResolution() throws Exception { - loadBalancer.handleSubchannelState(mockSubchannel, ConnectivityStateInfo.forNonError(READY)); - - verifyNoMoreInteractions(mockHelper); - } - @Test public void pickAfterStateChangeAfterResolution() throws Exception { - loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - Subchannel subchannel = pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel(); - reset(mockHelper); - InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); + SubchannelStateListener stateListener = args.getStateListener(); + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + Subchannel subchannel = pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel(); + Status error = Status.UNAVAILABLE.withDescription("boom!"); - loadBalancer.handleSubchannelState(subchannel, + stateListener.onSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); + stateListener.onSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); assertEquals(Status.OK, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + stateListener.onSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertEquals(subchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); @@ -219,7 +222,10 @@ public class PickFirstLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - inOrder.verify(mockHelper).createSubchannel(eq(servers), eq(Attributes.EMPTY)); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); + assertThat(args.getAttributes()).isEqualTo(Attributes.EMPTY); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); @@ -235,9 +241,21 @@ public class PickFirstLoadBalancerTest { @Test public void nameResolutionErrorWithStateChanges() throws Exception { InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); - loadBalancer.handleSubchannelState(mockSubchannel, + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + + SubchannelStateListener stateListener = args.getStateListener(); + + stateListener.onSubchannelState(mockSubchannel, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + inOrder.verify(mockHelper).updateBalancingState( + eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + Status error = Status.NOT_FOUND.withDescription("nameResolutionError"); loadBalancer.handleNameResolutionError(error); inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -246,7 +264,7 @@ public class PickFirstLoadBalancerTest { assertEquals(null, pickResult.getSubchannel()); assertEquals(error, pickResult.getStatus()); - loadBalancer.handleSubchannelState(mockSubchannel, ConnectivityStateInfo.forNonError(READY)); + stateListener.onSubchannelState(mockSubchannel, ConnectivityStateInfo.forNonError(READY)); Status error2 = Status.NOT_FOUND.withDescription("nameResolutionError2"); loadBalancer.handleNameResolutionError(error2); inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -262,7 +280,12 @@ public class PickFirstLoadBalancerTest { public void requestConnection() { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - loadBalancer.handleSubchannelState(mockSubchannel, ConnectivityStateInfo.forNonError(IDLE)); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + CreateSubchannelArgs args = createArgsCaptor.getValue(); + assertThat(args.getAddresses()).isEqualTo(servers); + SubchannelStateListener stateListener = args.getStateListener(); + + stateListener.onSubchannelState(mockSubchannel, ConnectivityStateInfo.forNonError(IDLE)); verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 4255fc4e14..f2a52a949c 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -34,7 +34,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -51,12 +50,14 @@ import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Metadata.Key; import io.grpc.Status; @@ -79,6 +80,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -94,6 +96,8 @@ public class RoundRobinLoadBalancerTest { private RoundRobinLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + private final Map subchannelStateListeners = + Maps.newLinkedHashMap(); private final Attributes affinity = Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build(); @@ -102,7 +106,7 @@ public class RoundRobinLoadBalancerTest { @Captor private ArgumentCaptor stateCaptor; @Captor - private ArgumentCaptor> eagListCaptor; + private ArgumentCaptor createArgsCaptor; @Mock private Helper mockHelper; @@ -119,17 +123,18 @@ public class RoundRobinLoadBalancerTest { EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); Subchannel sc = mock(Subchannel.class); - when(sc.getAllAddresses()).thenReturn(Arrays.asList(eag)); subchannels.put(Arrays.asList(eag), sc); } - when(mockHelper.createSubchannel(any(List.class), any(Attributes.class))) + when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) .then(new Answer() { @Override public Subchannel answer(InvocationOnMock invocation) throws Throwable { - Object[] args = invocation.getArguments(); - Subchannel subchannel = subchannels.get(args[0]); - when(subchannel.getAttributes()).thenReturn((Attributes) args[1]); + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + subchannelStateListeners.put(subchannel, args.getStateListener()); return subchannel; } }); @@ -138,8 +143,11 @@ public class RoundRobinLoadBalancerTest { } @After + @SuppressWarnings("deprecation") public void tearDown() throws Exception { verifyNoMoreInteractions(mockArgs); + verify(mockHelper, never()).createSubchannel( + ArgumentMatchers.>any(), any(Attributes.class)); } @Test @@ -147,12 +155,15 @@ public class RoundRobinLoadBalancerTest { final Subchannel readySubchannel = subchannels.values().iterator().next(); loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); - verify(mockHelper, times(3)).createSubchannel(eagListCaptor.capture(), - any(Attributes.class)); + verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); + List> capturedAddrs = new ArrayList<>(); + for (CreateSubchannelArgs arg : createArgsCaptor.getAllValues()) { + capturedAddrs.add(arg.getAddresses()); + } - assertThat(eagListCaptor.getAllValues()).containsAllIn(subchannels.keySet()); + assertThat(capturedAddrs).containsAllIn(subchannels.keySet()); for (Subchannel subchannel : subchannels.values()) { verify(subchannel).requestConnection(); verify(subchannel, never()).shutdown(); @@ -187,35 +198,25 @@ public class RoundRobinLoadBalancerTest { Subchannel subchannel = allSubchannels.get(i); List eagList = Arrays.asList(new EquivalentAddressGroup(allAddrs.get(i))); - when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO, - new Ref<>( - ConnectivityStateInfo.forNonError(READY))).build()); - when(subchannel.getAllAddresses()).thenReturn(eagList); + subchannels.put(eagList, subchannel); } - final Map, Subchannel> subchannels2 = Maps.newHashMap(); - subchannels2.put(Arrays.asList(new EquivalentAddressGroup(removedAddr)), removedSubchannel); - subchannels2.put(Arrays.asList(new EquivalentAddressGroup(oldAddr)), oldSubchannel); - List currentServers = Lists.newArrayList( new EquivalentAddressGroup(removedAddr), new EquivalentAddressGroup(oldAddr)); - doAnswer(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - Object[] args = invocation.getArguments(); - return subchannels2.get(args[0]); - } - }).when(mockHelper).createSubchannel(any(List.class), any(Attributes.class)); + InOrder inOrder = inOrder(mockHelper); loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(currentServers).setAttributes(affinity).build()); - InOrder inOrder = inOrder(mockHelper); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(oldSubchannel, ConnectivityStateInfo.forNonError(READY)); + + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel); @@ -225,10 +226,6 @@ public class RoundRobinLoadBalancerTest { assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, oldSubchannel); - subchannels2.clear(); - subchannels2.put(Arrays.asList(new EquivalentAddressGroup(oldAddr)), oldSubchannel); - subchannels2.put(Arrays.asList(new EquivalentAddressGroup(newAddr)), newSubchannel); - List latestServers = Lists.newArrayList( new EquivalentAddressGroup(oldAddr), @@ -240,14 +237,14 @@ public class RoundRobinLoadBalancerTest { verify(newSubchannel, times(1)).requestConnection(); verify(removedSubchannel, times(1)).shutdown(); - loadBalancer.handleSubchannelState(removedSubchannel, - ConnectivityStateInfo.forNonError(SHUTDOWN)); + deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); + deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, newSubchannel); - verify(mockHelper, times(3)).createSubchannel(any(List.class), any(Attributes.class)); - inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); @@ -278,7 +275,7 @@ public class RoundRobinLoadBalancerTest { inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); - loadBalancer.handleSubchannelState(subchannel, + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); @@ -286,20 +283,20 @@ public class RoundRobinLoadBalancerTest { ConnectivityStateInfo.forNonError(READY)); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); - loadBalancer.handleSubchannelState(subchannel, + deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); assertThat(subchannelStateInfo.value).isEqualTo( ConnectivityStateInfo.forTransientFailure(error)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); - loadBalancer.handleSubchannelState(subchannel, + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); assertThat(subchannelStateInfo.value).isEqualTo( ConnectivityStateInfo.forNonError(IDLE)); verify(subchannel, times(2)).requestConnection(); - verify(mockHelper, times(3)).createSubchannel(any(List.class), any(Attributes.class)); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verifyNoMoreInteractions(mockHelper); } @@ -351,10 +348,10 @@ public class RoundRobinLoadBalancerTest { final Subchannel readySubchannel = subchannels.values().iterator().next(); loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(affinity).build()); - loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); - verify(mockHelper, times(3)).createSubchannel(any(List.class), any(Attributes.class)); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -385,12 +382,11 @@ public class RoundRobinLoadBalancerTest { verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); - loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); - loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY)); - loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY)); - loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE)); - loadBalancer - .handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); verify(mockHelper, times(6)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -422,7 +418,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(Attributes.EMPTY).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); @@ -456,7 +452,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -489,7 +485,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -520,7 +516,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -566,7 +562,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -581,8 +577,7 @@ public class RoundRobinLoadBalancerTest { Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); // go to transient failure - loadBalancer - .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); verify(mockHelper, times(5)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -592,7 +587,7 @@ public class RoundRobinLoadBalancerTest { Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); // go back to ready - loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(6)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -615,7 +610,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -630,8 +625,7 @@ public class RoundRobinLoadBalancerTest { Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); // go to transient failure - loadBalancer - .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); Metadata headerWithStickinessValue2 = new Metadata(); headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2"); @@ -645,7 +639,7 @@ public class RoundRobinLoadBalancerTest { Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel(); // go back to ready - loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); verify(mockHelper, times(6)) @@ -670,7 +664,7 @@ public class RoundRobinLoadBalancerTest { loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder().setServers(servers).setAttributes(attributes).build()); for (Subchannel subchannel : subchannels.values()) { - loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); } verify(mockHelper, times(4)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); @@ -686,8 +680,7 @@ public class RoundRobinLoadBalancerTest { Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel(); // shutdown channel directly - loadBalancer - .handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + deliverSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value); @@ -709,7 +702,7 @@ public class RoundRobinLoadBalancerTest { verify(sc2, times(1)).shutdown(); - loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(SHUTDOWN)); + deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(SHUTDOWN)); assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value); @@ -795,6 +788,10 @@ public class RoundRobinLoadBalancerTest { Collections.emptyList(); } + private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + subchannelStateListeners.get(subchannel).onSubchannelState(subchannel, newState); + } + private static class FakeSocketAddress extends SocketAddress { final String name; diff --git a/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java b/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java index b70b5d5907..a63ce4089b 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java +++ b/grpclb/src/main/java/io/grpc/grpclb/CachedSubchannelPool.java @@ -16,6 +16,7 @@ package io.grpc.grpclb; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -23,9 +24,10 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.SynchronizationContext.ScheduledHandle; import java.util.HashMap; import java.util.concurrent.TimeUnit; @@ -39,33 +41,41 @@ final class CachedSubchannelPool implements SubchannelPool { new HashMap<>(); private Helper helper; - private LoadBalancer lb; @VisibleForTesting static final long SHUTDOWN_TIMEOUT_MS = 10000; @Override - public void init(Helper helper, LoadBalancer lb) { + public void init(Helper helper) { this.helper = checkNotNull(helper, "helper"); - this.lb = checkNotNull(lb, "lb"); } @Override public Subchannel takeOrCreateSubchannel( - EquivalentAddressGroup eag, Attributes defaultAttributes) { - final CacheEntry entry = cache.remove(eag); + EquivalentAddressGroup eag, Attributes defaultAttributes, SubchannelStateListener listener) { + final CacheEntry entry = cache.get(eag); final Subchannel subchannel; if (entry == null) { - subchannel = helper.createSubchannel(eag, defaultAttributes); + final CacheEntry newEntry = new CacheEntry(); + subchannel = helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(eag) + .setAttributes(defaultAttributes) + .setStateListener(new StateListener(newEntry)) + .build()); + newEntry.init(subchannel); + cache.put(eag, newEntry); + newEntry.taken(listener); } else { subchannel = entry.subchannel; - entry.shutdownTimer.cancel(); - // Make the balancer up-to-date with the latest state in case it has changed while it's - // in the cache. + checkState(eag.equals(subchannel.getAddresses()), + "Unexpected address change from %s to %s", eag, subchannel.getAddresses()); + entry.taken(listener); + // Make the listener up-to-date with the latest state in case it has changed while it's in the + // cache. helper.getSynchronizationContext().execute(new Runnable() { @Override public void run() { - lb.handleSubchannelState(subchannel, entry.state); + entry.maybeNotifyStateListener(); } }); } @@ -73,40 +83,19 @@ final class CachedSubchannelPool implements SubchannelPool { } @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newStateInfo) { - CacheEntry cached = cache.get(subchannel.getAddresses()); - if (cached == null || cached.subchannel != subchannel) { - // Given subchannel is not cached. Not our responsibility. - return; - } - cached.state = newStateInfo; - } - - @Override - public void returnSubchannel(Subchannel subchannel, ConnectivityStateInfo lastKnownState) { - CacheEntry prev = cache.get(subchannel.getAddresses()); - if (prev != null) { - // Returning the same Subchannel twice has no effect. - // Returning a different Subchannel for an already cached EAG will cause the - // latter Subchannel to be shutdown immediately. - if (prev.subchannel != subchannel) { - subchannel.shutdown(); - } - return; - } - final ShutdownSubchannelTask shutdownTask = new ShutdownSubchannelTask(subchannel); - ScheduledHandle shutdownTimer = - helper.getSynchronizationContext().schedule( - shutdownTask, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS, - helper.getScheduledExecutorService()); - CacheEntry entry = new CacheEntry(subchannel, shutdownTimer, lastKnownState); - cache.put(subchannel.getAddresses(), entry); + public void returnSubchannel(Subchannel subchannel) { + CacheEntry entry = cache.get(subchannel.getAddresses()); + checkArgument(entry != null, "Cache record for %s not found", subchannel); + checkArgument(entry.subchannel == subchannel, + "Subchannel being returned (%s) doesn't match the cache (%s)", + subchannel, entry.subchannel); + entry.returned(); } @Override public void clear() { for (CacheEntry entry : cache.values()) { - entry.shutdownTimer.cancel(); + entry.cancelShutdownTimer(); entry.subchannel.shutdown(); } cache.clear(); @@ -125,19 +114,65 @@ final class CachedSubchannelPool implements SubchannelPool { public void run() { CacheEntry entry = cache.remove(subchannel.getAddresses()); checkState(entry.subchannel == subchannel, "Inconsistent state"); + entry.cancelShutdownTimer(); subchannel.shutdown(); } } - private static class CacheEntry { - final Subchannel subchannel; - final ScheduledHandle shutdownTimer; + private class CacheEntry { + Subchannel subchannel; + ScheduledHandle shutdownTimer; ConnectivityStateInfo state; + // Not null if outside of pool + SubchannelStateListener stateListener; - CacheEntry(Subchannel subchannel, ScheduledHandle shutdownTimer, ConnectivityStateInfo state) { + void init(Subchannel subchannel) { this.subchannel = checkNotNull(subchannel, "subchannel"); - this.shutdownTimer = checkNotNull(shutdownTimer, "shutdownTimer"); - this.state = checkNotNull(state, "state"); + } + + void cancelShutdownTimer() { + if (shutdownTimer != null) { + shutdownTimer.cancel(); + shutdownTimer = null; + } + } + + void taken(SubchannelStateListener listener) { + checkState(stateListener == null, "Already out of pool"); + stateListener = checkNotNull(listener, "listener"); + cancelShutdownTimer(); + } + + void returned() { + checkState(stateListener != null, "Already in pool"); + if (shutdownTimer == null) { + shutdownTimer = helper.getSynchronizationContext().schedule( + new ShutdownSubchannelTask(subchannel), SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS, + helper.getScheduledExecutorService()); + } else { + checkState(shutdownTimer.isPending()); + } + stateListener = null; + } + + void maybeNotifyStateListener() { + if (stateListener != null && state != null) { + stateListener.onSubchannelState(subchannel, state); + } + } + } + + private static final class StateListener implements SubchannelStateListener { + private final CacheEntry entry; + + StateListener(CacheEntry entry) { + this.entry = checkNotNull(entry, "entry"); + } + + @Override + public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + entry.state = newState; + entry.maybeNotifyStateListener(); } } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 6a93afc401..e4e2bc86ab 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -24,7 +24,6 @@ import com.google.common.base.Stopwatch; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; -import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; @@ -75,18 +74,11 @@ class GrpclbLoadBalancer extends LoadBalancer { this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); - this.subchannelPool.init(helper, this); + this.subchannelPool.init(helper); recreateStates(); checkNotNull(grpclbState, "grpclbState"); } - @Override - public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - // grpclbState should never be null here since handleSubchannelState cannot be called while the - // lb is shutdown. - grpclbState.handleSubchannelState(subchannel, newState); - } - @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { List updatedServers = resolvedAddresses.getServers(); diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 14295aca0b..a936009931 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -35,11 +35,13 @@ import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Status; @@ -81,7 +83,7 @@ import javax.annotation.concurrent.NotThreadSafe; * switches away from GRPCLB mode. */ @NotThreadSafe -final class GrpclbState { +final class GrpclbState implements SubchannelStateListener { static final long FALLBACK_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(10); private static final Attributes LB_PROVIDED_BACKEND_ATTRS = Attributes.newBuilder().set(GrpcAttributes.ATTR_LB_PROVIDED_BACKEND, true).build(); @@ -171,14 +173,12 @@ final class GrpclbState { this.logger = checkNotNull(helper.getChannelLogger(), "logger"); } - void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + @Override + public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { if (newState.getState() == SHUTDOWN) { return; } if (!subchannels.values().contains(subchannel)) { - if (subchannelPool != null ) { - subchannelPool.handleSubchannelState(subchannel, newState); - } return; } if (mode == Mode.ROUND_ROBIN && newState.getState() == IDLE) { @@ -323,7 +323,7 @@ final class GrpclbState { // We close the subchannels through subchannelPool instead of helper just for convenience of // testing. for (Subchannel subchannel : subchannels.values()) { - returnSubchannelToPool(subchannel); + subchannelPool.returnSubchannel(subchannel); } subchannelPool.clear(); break; @@ -347,10 +347,6 @@ final class GrpclbState { } } - private void returnSubchannelToPool(Subchannel subchannel) { - subchannelPool.returnSubchannel(subchannel, subchannel.getAttributes().get(STATE_INFO).get()); - } - @VisibleForTesting @Nullable GrpclbClientLoadRecorder getLoadRecorder() { @@ -381,7 +377,8 @@ final class GrpclbState { if (subchannel == null) { subchannel = subchannels.get(eagAsList); if (subchannel == null) { - subchannel = subchannelPool.takeOrCreateSubchannel(eag, createSubchannelAttrs()); + subchannel = subchannelPool.takeOrCreateSubchannel( + eag, createSubchannelAttrs(), this); subchannel.requestConnection(); } newSubchannelMap.put(eagAsList, subchannel); @@ -399,7 +396,7 @@ final class GrpclbState { for (Entry, Subchannel> entry : subchannels.entrySet()) { List eagList = entry.getKey(); if (!newSubchannelMap.containsKey(eagList)) { - returnSubchannelToPool(entry.getValue()); + subchannelPool.returnSubchannel(entry.getValue()); } } subchannels = Collections.unmodifiableMap(newSubchannelMap); @@ -422,7 +419,12 @@ final class GrpclbState { } Subchannel subchannel; if (subchannels.isEmpty()) { - subchannel = helper.createSubchannel(eagList, createSubchannelAttrs()); + subchannel = + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(eagList) + .setAttributes(createSubchannelAttrs()) + .setStateListener(this) + .build()); } else { checkState(subchannels.size() == 1, "Unexpected Subchannel count: %s", subchannels); subchannel = subchannels.values().iterator().next(); diff --git a/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java b/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java index 0d328fdb09..f210dfdf5a 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SubchannelPool.java @@ -17,11 +17,10 @@ package io.grpc.grpclb; import io.grpc.Attributes; -import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; import javax.annotation.concurrent.NotThreadSafe; /** @@ -34,26 +33,33 @@ interface SubchannelPool { /** * Pass essential utilities and the balancer that's using this pool. */ - void init(Helper helper, LoadBalancer lb); + void init(Helper helper); /** * Takes a {@link Subchannel} from the pool for the given {@code eag} if there is one available. * Otherwise, creates and returns a new {@code Subchannel} with the given {@code eag} and {@code * defaultAttributes}. + * + *

There can be at most one Subchannel for each EAG. After a Subchannel is taken out of the + * pool, it must be returned before the same EAG can be used to call this method. + * + * @param defaultAttributes the attributes used to create the Subchannel. Not used if a pooled + * subchannel is returned. + * @param stateListener receives state updates from now on */ - Subchannel takeOrCreateSubchannel(EquivalentAddressGroup eag, Attributes defaultAttributes); - - /** - * Gets notified about a state change of Subchannel that is possibly cached in this pool. Do - * nothing if this pool doesn't own this Subchannel. - */ - void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newStateInfo); + Subchannel takeOrCreateSubchannel( + EquivalentAddressGroup eag, Attributes defaultAttributes, + SubchannelStateListener stateListener); /** * Puts a {@link Subchannel} back to the pool. From this point the Subchannel is owned by the - * pool, and the caller should stop referencing to this Subchannel. + * pool, and the caller should stop referencing to this Subchannel. The {@link + * SubchannelStateListener} will not receive any more updates. + * + *

Can only be called with a Subchannel created by this pool. Must not be called if the + * Subchannel is already in the pool. */ - void returnSubchannel(Subchannel subchannel, ConnectivityStateInfo lastKnownState); + void returnSubchannel(Subchannel subchannel); /** * Shuts down all subchannels in the pool immediately. diff --git a/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java b/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java index ad2b683f58..4de04ecb83 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/CachedSubchannelPoolTest.java @@ -19,9 +19,8 @@ package io.grpc.grpclb; import static com.google.common.truth.Truth.assertThat; import static io.grpc.grpclb.CachedSubchannelPool.SHUTDOWN_TIMEOUT_MS; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.mockito.ArgumentMatchers.eq; +import static org.junit.Assert.fail; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -36,21 +35,25 @@ import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.grpclb.CachedSubchannelPool.ShutdownSubchannelTask; import io.grpc.internal.FakeClock; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.hamcrest.MockitoHamcrest; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -79,7 +82,7 @@ public class CachedSubchannelPoolTest { }; private final Helper helper = mock(Helper.class); - private final LoadBalancer balancer = mock(LoadBalancer.class); + private final SubchannelStateListener mockListener = mock(SubchannelStateListener.class); private final FakeClock clock = new FakeClock(); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -90,6 +93,8 @@ public class CachedSubchannelPoolTest { }); private final CachedSubchannelPool pool = new CachedSubchannelPool(); private final ArrayList mockSubchannels = new ArrayList<>(); + // Listeners seen by the Helper + private final Map stateListeners = new HashMap<>(); @Before @SuppressWarnings("unchecked") @@ -98,26 +103,17 @@ public class CachedSubchannelPoolTest { @Override public Subchannel answer(InvocationOnMock invocation) throws Throwable { Subchannel subchannel = mock(Subchannel.class); - List eagList = - (List) invocation.getArguments()[0]; - Attributes attrs = (Attributes) invocation.getArguments()[1]; - when(subchannel.getAllAddresses()).thenReturn(eagList); - when(subchannel.getAttributes()).thenReturn(attrs); + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); mockSubchannels.add(subchannel); + stateListeners.put(subchannel, args.getStateListener()); return subchannel; } - }).when(helper).createSubchannel(any(List.class), any(Attributes.class)); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - syncContext.throwIfNotInThisSynchronizationContext(); - return null; - } - }).when(balancer).handleSubchannelState( - any(Subchannel.class), any(ConnectivityStateInfo.class)); + }).when(helper).createSubchannel(any(CreateSubchannelArgs.class)); when(helper.getSynchronizationContext()).thenReturn(syncContext); when(helper.getScheduledExecutorService()).thenReturn(clock.getScheduledExecutorService()); - pool.init(helper, balancer); + pool.init(helper); } @After @@ -126,29 +122,26 @@ public class CachedSubchannelPoolTest { for (Subchannel subchannel : mockSubchannels) { verify(subchannel, atMost(1)).shutdown(); } - verify(balancer, atLeast(0)) - .handleSubchannelState(any(Subchannel.class), any(ConnectivityStateInfo.class)); - verifyNoMoreInteractions(balancer); } @Test public void subchannelExpireAfterReturned() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); assertThat(subchannel1).isNotNull(); - verify(helper).createSubchannel(eq(Arrays.asList(EAG1)), same(ATTRS1)); + verify(helper).createSubchannel(argsWith(EAG1, "1")); - Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener); assertThat(subchannel2).isNotNull(); assertThat(subchannel2).isNotSameAs(subchannel1); - verify(helper).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2)); + verify(helper).createSubchannel(argsWith(EAG2, "2")); - pool.returnSubchannel(subchannel1, READY_STATE); + pool.returnSubchannel(subchannel1); // subchannel1 is 1ms away from expiration. clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); verify(subchannel1, never()).shutdown(); - pool.returnSubchannel(subchannel2, READY_STATE); + pool.returnSubchannel(subchannel2); // subchannel1 expires. subchannel2 is (SHUTDOWN_TIMEOUT_MS - 1) away from expiration. clock.forwardTime(1, MILLISECONDS); @@ -163,25 +156,25 @@ public class CachedSubchannelPoolTest { @Test public void subchannelReused() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); assertThat(subchannel1).isNotNull(); - verify(helper).createSubchannel(eq(Arrays.asList(EAG1)), same(ATTRS1)); + verify(helper).createSubchannel(argsWith(EAG1, "1")); - Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener); assertThat(subchannel2).isNotNull(); assertThat(subchannel2).isNotSameAs(subchannel1); - verify(helper).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2)); + verify(helper).createSubchannel(argsWith(EAG2, "2")); - pool.returnSubchannel(subchannel1, READY_STATE); + pool.returnSubchannel(subchannel1); // subchannel1 is 1ms away from expiration. clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); // This will cancel the shutdown timer for subchannel1 - Subchannel subchannel1a = pool.takeOrCreateSubchannel(EAG1, ATTRS1); + Subchannel subchannel1a = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); assertThat(subchannel1a).isSameAs(subchannel1); - pool.returnSubchannel(subchannel2, READY_STATE); + pool.returnSubchannel(subchannel2); // subchannel2 expires SHUTDOWN_TIMEOUT_MS after being returned clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); @@ -190,12 +183,12 @@ public class CachedSubchannelPoolTest { verify(subchannel2).shutdown(); // pool will create a new channel for EAG2 when requested - Subchannel subchannel2a = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + Subchannel subchannel2a = pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener); assertThat(subchannel2a).isNotSameAs(subchannel2); - verify(helper, times(2)).createSubchannel(eq(Arrays.asList(EAG2)), same(ATTRS2)); + verify(helper, times(2)).createSubchannel(argsWith(EAG2, "2")); // subchannel1 expires SHUTDOWN_TIMEOUT_MS after being returned - pool.returnSubchannel(subchannel1a, READY_STATE); + pool.returnSubchannel(subchannel1a); clock.forwardTime(SHUTDOWN_TIMEOUT_MS - 1, MILLISECONDS); verify(subchannel1a, never()).shutdown(); clock.forwardTime(1, MILLISECONDS); @@ -206,93 +199,136 @@ public class CachedSubchannelPoolTest { @Test public void updateStateWhileInPool() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); - Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); - pool.returnSubchannel(subchannel1, READY_STATE); - pool.returnSubchannel(subchannel2, TRANSIENT_FAILURE_STATE); + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener); + + // Simulate state updates while they are in the pool + stateListeners.get(subchannel1).onSubchannelState(subchannel1, TRANSIENT_FAILURE_STATE); + stateListeners.get(subchannel2).onSubchannelState(subchannel1, TRANSIENT_FAILURE_STATE); + + verify(mockListener).onSubchannelState(same(subchannel1), same(TRANSIENT_FAILURE_STATE)); + verify(mockListener).onSubchannelState(same(subchannel2), same(TRANSIENT_FAILURE_STATE)); + + pool.returnSubchannel(subchannel1); + pool.returnSubchannel(subchannel2); ConnectivityStateInfo anotherFailureState = ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("Another")); - pool.handleSubchannelState(subchannel1, anotherFailureState); + // Simulate a subchannel state update while it's in the pool + stateListeners.get(subchannel1).onSubchannelState(subchannel1, anotherFailureState); - verify(balancer, never()) - .handleSubchannelState(any(Subchannel.class), any(ConnectivityStateInfo.class)); + SubchannelStateListener mockListener1 = mock(SubchannelStateListener.class); + SubchannelStateListener mockListener2 = mock(SubchannelStateListener.class); - assertThat(pool.takeOrCreateSubchannel(EAG1, ATTRS1)).isSameAs(subchannel1); - verify(balancer).handleSubchannelState(same(subchannel1), same(anotherFailureState)); - verifyNoMoreInteractions(balancer); + // Saved state is populated to new mockListeners + assertThat(pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener1)).isSameAs(subchannel1); + verify(mockListener1).onSubchannelState(same(subchannel1), same(anotherFailureState)); + verifyNoMoreInteractions(mockListener1); - assertThat(pool.takeOrCreateSubchannel(EAG2, ATTRS2)).isSameAs(subchannel2); - verify(balancer).handleSubchannelState(same(subchannel2), same(TRANSIENT_FAILURE_STATE)); - verifyNoMoreInteractions(balancer); + assertThat(pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener2)).isSameAs(subchannel2); + verify(mockListener2).onSubchannelState(same(subchannel2), same(TRANSIENT_FAILURE_STATE)); + verifyNoMoreInteractions(mockListener2); + + // The old mockListener doesn't receive more updates + verifyNoMoreInteractions(mockListener); } @Test - public void updateStateWhileInPool_notSameObject() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); - pool.returnSubchannel(subchannel1, READY_STATE); - - Subchannel subchannel2 = helper.createSubchannel(EAG1, ATTRS1); - Subchannel subchannel3 = helper.createSubchannel(EAG2, ATTRS2); - - // subchannel2 is not in the pool, although with the same address - pool.handleSubchannelState(subchannel2, TRANSIENT_FAILURE_STATE); - - // subchannel3 is not in the pool. In fact its address is not in the pool - pool.handleSubchannelState(subchannel3, TRANSIENT_FAILURE_STATE); - - assertThat(pool.takeOrCreateSubchannel(EAG1, ATTRS1)).isSameAs(subchannel1); - - // subchannel1's state is unchanged - verify(balancer).handleSubchannelState(same(subchannel1), same(READY_STATE)); - verifyNoMoreInteractions(balancer); + public void takeTwice_willThrow() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + try { + pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + fail("Should throw"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains("Already out of pool"); + } } @Test - public void returnDuplicateAddressSubchannel() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); - Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG1, ATTRS2); - Subchannel subchannel3 = pool.takeOrCreateSubchannel(EAG2, ATTRS1); - assertThat(subchannel1).isNotSameAs(subchannel2); + public void returnTwice_willThrow() { + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + pool.returnSubchannel(subchannel1); + try { + pool.returnSubchannel(subchannel1); + fail("Should throw"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains("Already in pool"); + } + } - assertThat(clock.getPendingTasks(SHUTDOWN_TASK_FILTER)).isEmpty(); - pool.returnSubchannel(subchannel2, READY_STATE); - assertThat(clock.getPendingTasks(SHUTDOWN_TASK_FILTER)).hasSize(1); + @Test + public void returnNonPoolSubchannelWillThrow_noSuchAddress() { + Subchannel subchannel1 = helper.createSubchannel( + CreateSubchannelArgs.newBuilder() + .setAddresses(EAG1).setStateListener(mockListener) + .build()); + try { + pool.returnSubchannel(subchannel1); + fail("Should throw"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("not found"); + } + } - // If the subchannel being returned has an address that is the same as a subchannel in the pool, - // the returned subchannel will be shut down. - verify(subchannel1, never()).shutdown(); - pool.returnSubchannel(subchannel1, READY_STATE); - assertThat(clock.getPendingTasks(SHUTDOWN_TASK_FILTER)).hasSize(1); - verify(subchannel1).shutdown(); - - pool.returnSubchannel(subchannel3, READY_STATE); - assertThat(clock.getPendingTasks(SHUTDOWN_TASK_FILTER)).hasSize(2); - // Returning the same subchannel twice has no effect. - pool.returnSubchannel(subchannel3, READY_STATE); - assertThat(clock.getPendingTasks(SHUTDOWN_TASK_FILTER)).hasSize(2); - - verify(subchannel2, never()).shutdown(); - verify(subchannel3, never()).shutdown(); + @Test + public void returnNonPoolSubchannelWillThrow_unmatchedSubchannel() { + Subchannel subchannel1 = helper.createSubchannel( + CreateSubchannelArgs.newBuilder() + .setAddresses(EAG1).setStateListener(mockListener) + .build()); + Subchannel subchannel1c = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + assertThat(subchannel1).isNotSameAs(subchannel1c); + try { + pool.returnSubchannel(subchannel1); + fail("Should throw"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("doesn't match the cache"); + } } @Test public void clear() { - Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1); - Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); - Subchannel subchannel3 = pool.takeOrCreateSubchannel(EAG2, ATTRS2); + Subchannel subchannel1 = pool.takeOrCreateSubchannel(EAG1, ATTRS1, mockListener); + Subchannel subchannel2 = pool.takeOrCreateSubchannel(EAG2, ATTRS2, mockListener); - pool.returnSubchannel(subchannel1, READY_STATE); - pool.returnSubchannel(subchannel2, READY_STATE); + pool.returnSubchannel(subchannel1); verify(subchannel1, never()).shutdown(); verify(subchannel2, never()).shutdown(); pool.clear(); verify(subchannel1).shutdown(); verify(subchannel2).shutdown(); - - verify(subchannel3, never()).shutdown(); assertThat(clock.numPendingTasks()).isEqualTo(0); } + + private CreateSubchannelArgs argsWith( + final EquivalentAddressGroup expectedEag, final Object expectedValue) { + return MockitoHamcrest.argThat( + new org.hamcrest.BaseMatcher() { + @Override + public boolean matches(Object item) { + if (!(item instanceof CreateSubchannelArgs)) { + return false; + } + CreateSubchannelArgs that = (CreateSubchannelArgs) item; + List expectedEagList = Collections.singletonList(expectedEag); + if (!expectedEagList.equals(that.getAddresses())) { + return false; + } + if (!expectedValue.equals(that.getAttributes().get(ATTR_KEY))) { + return false; + } + return true; + } + + @Override + public void describeTo(org.hamcrest.Description desc) { + desc.appendText( + "Matches Attributes that includes " + expectedEag + " and " + + ATTR_KEY + "=" + expectedValue); + } + }); + } + } diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index f7424d0ca9..a425b855b9 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -20,7 +20,6 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.grpclb.GrpclbLoadBalancer.retrieveModeFromLbConfig; import static io.grpc.grpclb.GrpclbState.BUFFER_ENTRY; @@ -57,12 +56,14 @@ import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Status; @@ -96,6 +97,7 @@ import java.text.MessageFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -165,10 +167,13 @@ public class GrpclbLoadBalancerTest { private LoadBalancerGrpc.LoadBalancerImplBase mockLbService; @Captor private ArgumentCaptor> lbResponseObserverCaptor; + @Captor + private ArgumentCaptor createArgsCaptor; private final FakeClock fakeClock = new FakeClock(); private final LinkedList> lbRequestObservers = new LinkedList<>(); private final LinkedList mockSubchannels = new LinkedList<>(); + private final Map subchannelStateListeners = new HashMap<>(); private final LinkedList fakeOobChannels = new LinkedList<>(); private final ArrayList pooledSubchannelTracker = new ArrayList<>(); private final ArrayList unpooledSubchannelTracker = new ArrayList<>(); @@ -257,24 +262,26 @@ public class GrpclbLoadBalancerTest { when(subchannel.getAttributes()).thenReturn(attrs); mockSubchannels.add(subchannel); pooledSubchannelTracker.add(subchannel); + subchannelStateListeners.put( + subchannel, (SubchannelStateListener) invocation.getArguments()[2]); return subchannel; } }).when(subchannelPool).takeOrCreateSubchannel( - any(EquivalentAddressGroup.class), any(Attributes.class)); + any(EquivalentAddressGroup.class), any(Attributes.class), + any(SubchannelStateListener.class)); doAnswer(new Answer() { @Override public Subchannel answer(InvocationOnMock invocation) throws Throwable { Subchannel subchannel = mock(Subchannel.class); - List eagList = - (List) invocation.getArguments()[0]; - Attributes attrs = (Attributes) invocation.getArguments()[1]; - when(subchannel.getAllAddresses()).thenReturn(eagList); - when(subchannel.getAttributes()).thenReturn(attrs); + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); mockSubchannels.add(subchannel); unpooledSubchannelTracker.add(subchannel); + subchannelStateListeners.put(subchannel, args.getStateListener()); return subchannel; } - }).when(helper).createSubchannel(any(List.class), any(Attributes.class)); + }).when(helper).createSubchannel(any(CreateSubchannelArgs.class)); when(helper.getSynchronizationContext()).thenReturn(syncContext); when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); when(helper.getChannelLogger()).thenReturn(channelLogger); @@ -293,7 +300,7 @@ public class GrpclbLoadBalancerTest { balancer = new GrpclbLoadBalancer(helper, subchannelPool, fakeClock.getTimeProvider(), fakeClock.getStopwatchSupplier().get(), backoffPolicyProvider); - verify(subchannelPool).init(same(helper), same(balancer)); + verify(subchannelPool).init(same(helper)); } @After @@ -315,7 +322,7 @@ public class GrpclbLoadBalancerTest { } // GRPCLB manages subchannels only through subchannelPool for (Subchannel subchannel : pooledSubchannelTracker) { - verify(subchannelPool).returnSubchannel(same(subchannel), any(ConnectivityStateInfo.class)); + verify(subchannelPool).returnSubchannel(same(subchannel)); // Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if // LoadBalancer has called it expectedly. verify(subchannel, never()).shutdown(); @@ -690,7 +697,8 @@ public class GrpclbLoadBalancerTest { // Same backends, thus no new subchannels helperInOrder.verify(subchannelPool, never()).takeOrCreateSubchannel( - any(EquivalentAddressGroup.class), any(Attributes.class)); + any(EquivalentAddressGroup.class), any(Attributes.class), + any(SubchannelStateListener.class)); // But the new RoundRobinEntries have a new loadRecorder, thus considered different from // the previous list, thus a new picker is created helperInOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); @@ -871,10 +879,12 @@ public class GrpclbLoadBalancerTest { inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends.get(0).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends.get(1).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); } @Test @@ -957,10 +967,12 @@ public class GrpclbLoadBalancerTest { inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); assertEquals(2, mockSubchannels.size()); Subchannel subchannel1 = mockSubchannels.poll(); Subchannel subchannel2 = mockSubchannels.poll(); @@ -1068,8 +1080,7 @@ public class GrpclbLoadBalancerTest { new ServerEntry("127.0.0.1", 2010, "token0004"), // Existing address with token changed new ServerEntry("127.0.0.1", 2030, "token0005"), // New address appearing second time new ServerEntry("token0006")); // drop - verify(subchannelPool, never()) - .returnSubchannel(same(subchannel1), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).returnSubchannel(same(subchannel1)); lbResponseObserver.onNext(buildLbResponse(backends2)); assertThat(logs).containsExactly( @@ -1085,23 +1096,23 @@ public class GrpclbLoadBalancerTest { logs.clear(); // not in backends2, closed - verify(subchannelPool).returnSubchannel(same(subchannel1), same(errorState1)); + verify(subchannelPool).returnSubchannel(same(subchannel1)); // backends2[2], will be kept - verify(subchannelPool, never()) - .returnSubchannel(same(subchannel2), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).returnSubchannel(same(subchannel2)); inOrder.verify(subchannelPool, never()).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends2.get(2).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); inOrder.verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends2.get(0).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), + any(SubchannelStateListener.class)); ConnectivityStateInfo errorOnCachedSubchannel1 = ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("You can get this error even if you are cached")); deliverSubchannelState(subchannel1, errorOnCachedSubchannel1); - verify(subchannelPool).handleSubchannelState(same(subchannel1), same(errorOnCachedSubchannel1)); assertEquals(1, mockSubchannels.size()); Subchannel subchannel3 = mockSubchannels.poll(); @@ -1119,17 +1130,6 @@ public class GrpclbLoadBalancerTest { new DropEntry(getLoadRecorder(), "token0006")).inOrder(); assertThat(picker7.pickList).containsExactly(BUFFER_ENTRY); - // State updates on obsolete subchannel1 will only be passed to the pool - deliverSubchannelState(subchannel1, ConnectivityStateInfo.forNonError(READY)); - deliverSubchannelState( - subchannel1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); - deliverSubchannelState(subchannel1, ConnectivityStateInfo.forNonError(SHUTDOWN)); - inOrder.verify(subchannelPool) - .handleSubchannelState(same(subchannel1), eq(ConnectivityStateInfo.forNonError(READY))); - inOrder.verify(subchannelPool).handleSubchannelState( - same(subchannel1), eq(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE))); - inOrder.verifyNoMoreInteractions(); - deliverSubchannelState(subchannel3, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); RoundRobinPicker picker8 = (RoundRobinPicker) pickerCaptor.getValue(); @@ -1157,15 +1157,12 @@ public class GrpclbLoadBalancerTest { new BackendEntry(subchannel3, getLoadRecorder(), "token0003"), new BackendEntry(subchannel2, getLoadRecorder(), "token0004"), new BackendEntry(subchannel3, getLoadRecorder(), "token0005")).inOrder(); - verify(subchannelPool, never()) - .returnSubchannel(same(subchannel3), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).returnSubchannel(same(subchannel3)); // Update backends, with no entry lbResponseObserver.onNext(buildLbResponse(Collections.emptyList())); - verify(subchannelPool) - .returnSubchannel(same(subchannel2), eq(ConnectivityStateInfo.forNonError(READY))); - verify(subchannelPool) - .returnSubchannel(same(subchannel3), eq(ConnectivityStateInfo.forNonError(READY))); + verify(subchannelPool).returnSubchannel(same(subchannel2)); + verify(subchannelPool).returnSubchannel(same(subchannel3)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker10 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker10.dropList).isEmpty(); @@ -1511,9 +1508,9 @@ public class GrpclbLoadBalancerTest { if (!(balancerBroken && allSubchannelsBroken)) { verify(subchannelPool, never()).takeOrCreateSubchannel( - eq(resolutionList.get(0)), any(Attributes.class)); + eq(resolutionList.get(0)), any(Attributes.class), any(SubchannelStateListener.class)); verify(subchannelPool, never()).takeOrCreateSubchannel( - eq(resolutionList.get(2)), any(Attributes.class)); + eq(resolutionList.get(2)), any(Attributes.class), any(SubchannelStateListener.class)); } } @@ -1540,7 +1537,8 @@ public class GrpclbLoadBalancerTest { assertEquals(addrs.size(), tokens.size()); } for (EquivalentAddressGroup addr : addrs) { - inOrder.verify(subchannelPool).takeOrCreateSubchannel(eq(addr), any(Attributes.class)); + inOrder.verify(subchannelPool).takeOrCreateSubchannel( + eq(addr), any(Attributes.class), any(SubchannelStateListener.class)); } RoundRobinPicker picker = (RoundRobinPicker) currentPicker; assertThat(picker.dropList).containsExactlyElementsIn(Collections.nCopies(addrs.size(), null)); @@ -1743,11 +1741,11 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - inOrder.verify(helper).createSubchannel( - eq(Arrays.asList( - new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), - new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))), - any(Attributes.class)); + inOrder.verify(helper).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).containsExactly( + new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))) + .inOrder(); // Initially IDLE inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -1798,7 +1796,7 @@ public class GrpclbLoadBalancerTest { // new addresses will be updated to the existing subchannel // createSubchannel() has ever been called only once - verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class)); + verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); inOrder.verify(helper).updateSubchannelAddresses( same(subchannel), @@ -1830,10 +1828,10 @@ public class GrpclbLoadBalancerTest { verify(subchannel, times(2)).requestConnection(); // PICK_FIRST doesn't use subchannelPool - verify(subchannelPool, never()) - .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); - verify(subchannelPool, never()) - .returnSubchannel(any(Subchannel.class), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).takeOrCreateSubchannel( + any(EquivalentAddressGroup.class), any(Attributes.class), + any(SubchannelStateListener.class)); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); } @SuppressWarnings("unchecked") @@ -1860,9 +1858,9 @@ public class GrpclbLoadBalancerTest { fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); // Entering fallback mode - inOrder.verify(helper).createSubchannel( - eq(Arrays.asList(grpclbResolutionList.get(0), grpclbResolutionList.get(2))), - any(Attributes.class)); + inOrder.verify(helper).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).containsExactly( + grpclbResolutionList.get(0), grpclbResolutionList.get(2)).inOrder(); assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); @@ -1894,7 +1892,7 @@ public class GrpclbLoadBalancerTest { // new addresses will be updated to the existing subchannel // createSubchannel() has ever been called only once - verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class)); + verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); inOrder.verify(helper).updateSubchannelAddresses( same(subchannel), @@ -1909,10 +1907,10 @@ public class GrpclbLoadBalancerTest { new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); // PICK_FIRST doesn't use subchannelPool - verify(subchannelPool, never()) - .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); - verify(subchannelPool, never()) - .returnSubchannel(any(Subchannel.class), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).takeOrCreateSubchannel( + any(EquivalentAddressGroup.class), any(Attributes.class), + any(SubchannelStateListener.class)); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); } @Test @@ -1949,16 +1947,15 @@ public class GrpclbLoadBalancerTest { // ROUND_ROBIN: create one subchannel per server verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), any(SubchannelStateListener.class)); verify(subchannelPool).takeOrCreateSubchannel( eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)), - any(Attributes.class)); + any(Attributes.class), any(SubchannelStateListener.class)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); assertEquals(2, mockSubchannels.size()); Subchannel subchannel1 = mockSubchannels.poll(); Subchannel subchannel2 = mockSubchannels.poll(); - verify(subchannelPool, never()) - .returnSubchannel(any(Subchannel.class), any(ConnectivityStateInfo.class)); + verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class)); // Switch to PICK_FIRST lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}"; @@ -1969,10 +1966,8 @@ public class GrpclbLoadBalancerTest { // GrpclbState will be shutdown, and a new one will be created assertThat(oobChannel.isShutdown()).isTrue(); - verify(subchannelPool) - .returnSubchannel(same(subchannel1), eq(ConnectivityStateInfo.forNonError(IDLE))); - verify(subchannelPool) - .returnSubchannel(same(subchannel2), eq(ConnectivityStateInfo.forNonError(IDLE))); + verify(subchannelPool).returnSubchannel(same(subchannel1)); + verify(subchannelPool).returnSubchannel(same(subchannel2)); // A new LB stream is created assertEquals(1, fakeOobChannels.size()); @@ -1992,11 +1987,11 @@ public class GrpclbLoadBalancerTest { lbResponseObserver.onNext(buildLbResponse(backends1)); // PICK_FIRST Subchannel - inOrder.verify(helper).createSubchannel( - eq(Arrays.asList( - new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), - new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))), - any(Attributes.class)); + inOrder.verify(helper).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).containsExactly( + new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), + new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))) + .inOrder(); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); } @@ -2081,7 +2076,7 @@ public class GrpclbLoadBalancerTest { syncContext.execute(new Runnable() { @Override public void run() { - balancer.handleSubchannelState(subchannel, newState); + subchannelStateListeners.get(subchannel).onSubchannelState(subchannel, newState); } }); } diff --git a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java index d199002a56..1d680d73f6 100644 --- a/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/services/HealthCheckingLoadBalancerFactory.java @@ -16,6 +16,7 @@ package io.grpc.services; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.ConnectivityState.CONNECTING; @@ -27,17 +28,17 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; -import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; import io.grpc.ConnectivityStateInfo; -import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Factory; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.Status.Code; @@ -53,7 +54,6 @@ import io.grpc.internal.ServiceConfigUtil; import io.grpc.util.ForwardingLoadBalancer; import io.grpc.util.ForwardingLoadBalancerHelper; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -70,8 +70,6 @@ import javax.annotation.Nullable; * SynchronizationContext, or it will throw. */ final class HealthCheckingLoadBalancerFactory extends Factory { - private static final Attributes.Key KEY_HEALTH_CHECK_STATE = - Attributes.Key.create("io.grpc.services.HealthCheckingLoadBalancerFactory.healthCheckState"); private static final Logger logger = Logger.getLogger(HealthCheckingLoadBalancerFactory.class.getName()); @@ -91,15 +89,13 @@ final class HealthCheckingLoadBalancerFactory extends Factory { public LoadBalancer newLoadBalancer(Helper helper) { HelperImpl wrappedHelper = new HelperImpl(helper); LoadBalancer delegateBalancer = delegateFactory.newLoadBalancer(wrappedHelper); - wrappedHelper.init(delegateBalancer); return new HealthCheckingLoadBalancer(wrappedHelper, delegateBalancer); } private final class HelperImpl extends ForwardingLoadBalancerHelper { private final Helper delegate; private final SynchronizationContext syncContext; - - private LoadBalancer delegateBalancer; + @Nullable String healthCheckedService; private boolean balancerShutdown; @@ -110,26 +106,21 @@ final class HealthCheckingLoadBalancerFactory extends Factory { this.syncContext = checkNotNull(delegate.getSynchronizationContext(), "syncContext"); } - void init(LoadBalancer delegateBalancer) { - checkState(this.delegateBalancer == null, "init() already called"); - this.delegateBalancer = checkNotNull(delegateBalancer, "delegateBalancer"); - } - @Override protected Helper delegate() { return delegate; } @Override - public Subchannel createSubchannel(List addrs, Attributes attrs) { + public Subchannel createSubchannel(CreateSubchannelArgs args) { // HealthCheckState is not thread-safe, we are requiring the original LoadBalancer calls // createSubchannel() from the SynchronizationContext. syncContext.throwIfNotInThisSynchronizationContext(); HealthCheckState hcState = new HealthCheckState( - this, delegateBalancer, syncContext, delegate.getScheduledExecutorService()); + this, args.getStateListener(), syncContext, delegate.getScheduledExecutorService()); hcStates.add(hcState); - Subchannel subchannel = super.createSubchannel( - addrs, attrs.toBuilder().set(KEY_HEALTH_CHECK_STATE, hcState).build()); + Subchannel subchannel = + super.createSubchannel(args.toBuilder().setStateListener(hcState).build()); hcState.init(subchannel); if (healthCheckedService != null) { hcState.setServiceName(healthCheckedService); @@ -177,27 +168,15 @@ final class HealthCheckingLoadBalancerFactory extends Factory { super.handleResolvedAddresses(resolvedAddresses); } - @Override - public void handleSubchannelState( - Subchannel subchannel, ConnectivityStateInfo stateInfo) { - HealthCheckState hcState = - checkNotNull(subchannel.getAttributes().get(KEY_HEALTH_CHECK_STATE), "hcState"); - hcState.updateRawState(stateInfo); - - if (Objects.equal(stateInfo.getState(), SHUTDOWN)) { - helper.hcStates.remove(hcState); - } - } - @Override public void shutdown() { super.shutdown(); helper.balancerShutdown = true; for (HealthCheckState hcState : helper.hcStates) { - // ManagedChannel will stop calling handleSubchannelState() after shutdown() is called, + // ManagedChannel will stop calling onSubchannelState() after shutdown() is called, // which is required by LoadBalancer API semantics. We need to deliver the final SHUTDOWN // signal to health checkers so that they can cancel the streams. - hcState.updateRawState(ConnectivityStateInfo.forNonError(SHUTDOWN)); + hcState.onSubchannelState(hcState.subchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); } helper.hcStates.clear(); } @@ -210,7 +189,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory { // All methods are run from syncContext - private final class HealthCheckState { + private final class HealthCheckState implements SubchannelStateListener { private final Runnable retryTask = new Runnable() { @Override public void run() { @@ -218,7 +197,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory { } }; - private final LoadBalancer delegate; + private final SubchannelStateListener stateListener; private final SynchronizationContext syncContext; private final ScheduledExecutorService timerService; private final HelperImpl helperImpl; @@ -246,10 +225,10 @@ final class HealthCheckingLoadBalancerFactory extends Factory { HealthCheckState( HelperImpl helperImpl, - LoadBalancer delegate, SynchronizationContext syncContext, + SubchannelStateListener stateListener, SynchronizationContext syncContext, ScheduledExecutorService timerService) { this.helperImpl = checkNotNull(helperImpl, "helperImpl"); - this.delegate = checkNotNull(delegate, "delegate"); + this.stateListener = checkNotNull(stateListener, "stateListener"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.timerService = checkNotNull(timerService, "timerService"); } @@ -274,13 +253,19 @@ final class HealthCheckingLoadBalancerFactory extends Factory { adjustHealthCheck(); } - void updateRawState(ConnectivityStateInfo rawState) { + @Override + public void onSubchannelState(Subchannel subchannel, ConnectivityStateInfo rawState) { + checkArgument(subchannel == this.subchannel, + "Subchannel mismatch: %s vs %s", subchannel, this.subchannel); if (Objects.equal(this.rawState.getState(), READY) && !Objects.equal(rawState.getState(), READY)) { // A connection was lost. We will reset disabled flag because health check // may be available on the new connection. disabled = false; } + if (Objects.equal(rawState.getState(), SHUTDOWN)) { + helperImpl.hcStates.remove(this); + } this.rawState = rawState; adjustHealthCheck(); } @@ -339,7 +324,7 @@ final class HealthCheckingLoadBalancerFactory extends Factory { checkState(subchannel != null, "init() not called"); if (!helperImpl.balancerShutdown && !Objects.equal(concludedState, newState)) { concludedState = newState; - delegate.handleSubchannelState(subchannel, concludedState); + stateListener.onSubchannelState(subchannel, concludedState); } } diff --git a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java index 3576f10c08..331722cf1a 100644 --- a/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java +++ b/services/src/test/java/io/grpc/services/HealthCheckingLoadBalancerFactoryTest.java @@ -24,7 +24,6 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -32,6 +31,7 @@ import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyZeroInteractions; @@ -47,11 +47,13 @@ import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Factory; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.ManagedChannel; import io.grpc.NameResolver; import io.grpc.Server; @@ -74,7 +76,6 @@ import java.text.MessageFormat; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Queue; @@ -108,8 +109,13 @@ public class HealthCheckingLoadBalancerFactoryTest { private final EquivalentAddressGroup[] eags = new EquivalentAddressGroup[NUM_SUBCHANNELS]; @SuppressWarnings({"rawtypes", "unchecked"}) private final List[] eagLists = new List[NUM_SUBCHANNELS]; + private final SubchannelStateListener[] mockStateListeners = + new SubchannelStateListener[NUM_SUBCHANNELS]; private List resolvedAddressList; private final FakeSubchannel[] subchannels = new FakeSubchannel[NUM_SUBCHANNELS]; + // State listeners seen by the real Helper. Use them to simulate raw Subchannel updates. + private final SubchannelStateListener[] stateListeners = + new SubchannelStateListener[NUM_SUBCHANNELS]; private final ManagedChannel[] channels = new ManagedChannel[NUM_SUBCHANNELS]; private final Server[] servers = new Server[NUM_SUBCHANNELS]; private final HealthImpl[] healthImpls = new HealthImpl[NUM_SUBCHANNELS]; @@ -139,7 +145,7 @@ public class HealthCheckingLoadBalancerFactoryTest { private LoadBalancer origLb; private LoadBalancer hcLb; @Captor - ArgumentCaptor attrsCaptor; + ArgumentCaptor createArgsCaptor; @Mock private BackoffPolicy.Provider backoffPolicyProvider; @Mock @@ -171,6 +177,7 @@ public class HealthCheckingLoadBalancerFactoryTest { eags[i] = eag; List eagList = Arrays.asList(eag); eagLists[i] = eagList; + mockStateListeners[i] = mock(SubchannelStateListener.class); } resolvedAddressList = Arrays.asList(eags); @@ -199,19 +206,6 @@ public class HealthCheckingLoadBalancerFactoryTest { }); } - @Override - public void handleSubchannelState( - final Subchannel subchannel, final ConnectivityStateInfo stateInfo) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!shutdown) { - hcLb.handleSubchannelState(subchannel, stateInfo); - } - } - }); - } - @Override public void handleNameResolutionError(Status error) { throw new AssertionError("Not supposed to be called"); @@ -237,7 +231,7 @@ public class HealthCheckingLoadBalancerFactoryTest { public void teardown() throws Exception { // All scheduled tasks have been accounted for assertThat(clock.getPendingTasks()).isEmpty(); - // Health-check streams are usually not closed in the tests because handleSubchannelState() is + // Health-check streams are usually not closed in the tests because onSubchannelState() is // faked. Force closing for clean up. for (Server server : servers) { server.shutdownNow(); @@ -252,16 +246,6 @@ public class HealthCheckingLoadBalancerFactoryTest { } } - @Test - public void createSubchannelThrowsIfCalledOutsideSynchronizationContext() { - try { - wrappedHelper.createSubchannel(eagLists[0], Attributes.EMPTY); - fail("Should throw"); - } catch (IllegalStateException e) { - assertThat(e.getMessage()).isEqualTo("Not called from the SynchronizationContext"); - } - } - @Test public void typicalWorkflow() { Attributes resolutionAttrs = attrsWithHealthCheckService("FooService"); @@ -285,41 +269,41 @@ public class HealthCheckingLoadBalancerFactoryTest { .set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); // We don't wrap Subchannels, thus origLb gets the original Subchannels. assertThat(createSubchannel(i, attrs)).isSameAs(subchannels[i]); - verify(origHelper).createSubchannel(same(eagLists[i]), attrsCaptor.capture()); - assertThat(attrsCaptor.getValue().get(SUBCHANNEL_ATTR_KEY)).isEqualTo(subchannelAttrValue); + verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); + assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) + .isEqualTo(subchannelAttrValue); } for (int i = NUM_SUBCHANNELS - 1; i >= 0; i--) { // Not starting health check until underlying Subchannel is READY FakeSubchannel subchannel = subchannels[i]; HealthImpl healthImpl = healthImpls[i]; - InOrder inOrder = inOrder(origLb); - hcLbEventDelivery.handleSubchannelState( - subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - hcLbEventDelivery.handleSubchannelState( - subchannel, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); - hcLbEventDelivery.handleSubchannelState( - subchannel, ConnectivityStateInfo.forNonError(IDLE)); + SubchannelStateListener mockStateListener = mockStateListeners[i]; + InOrder inOrder = inOrder(mockStateListener); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(i, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(IDLE)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE))); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(IDLE))); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListener); assertThat(subchannel.logs).isEmpty(); assertThat(healthImpl.calls).isEmpty(); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpl.calls).hasSize(1); ServerSideCall serverCall = healthImpl.calls.peek(); assertThat(serverCall.request).isEqualTo(makeRequest("FooService")); // Starting the health check will make the Subchannel appear CONNECTING to the origLb. - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListener); assertThat(subchannel.logs).containsExactly( "INFO: CONNECTING: Starting health-check for \"FooService\""); @@ -333,35 +317,36 @@ public class HealthCheckingLoadBalancerFactoryTest { serverCall.responseObserver.onNext(makeResponse(servingStatus)); // SERVING is mapped to READY, while other statuses are mapped to TRANSIENT_FAILURE if (servingStatus == ServingStatus.SERVING) { - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); assertThat(subchannel.logs).containsExactly( "INFO: READY: health-check responded SERVING"); } else { - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel),unavailableStateWithMsg( "Health-check service responded " + servingStatus + " for 'FooService'")); assertThat(subchannel.logs).containsExactly( "INFO: TRANSIENT_FAILURE: health-check responded " + servingStatus); } subchannel.logs.clear(); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListener); } } // origLb shuts down Subchannels for (int i = 0; i < NUM_SUBCHANNELS; i++) { FakeSubchannel subchannel = subchannels[i]; + SubchannelStateListener mockStateListener = mockStateListeners[i]; ServerSideCall serverCall = healthImpls[i].calls.peek(); assertThat(serverCall.cancelled).isFalse(); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListener); // Subchannel enters SHUTDOWN state as a response to shutdown(), and that will cancel the // health check RPC subchannel.shutdown(); assertThat(serverCall.cancelled).isTrue(); - verify(origLb).handleSubchannelState( + verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); assertThat(subchannel.logs).isEmpty(); } @@ -390,13 +375,12 @@ public class HealthCheckingLoadBalancerFactoryTest { createSubchannel(i, Attributes.EMPTY); } - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(mockStateListeners[0], mockStateListeners[1]); for (int i = 0; i < 2; i++) { - hcLbEventDelivery.handleSubchannelState( - subchannels[i], ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpls[i].calls).hasSize(1); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[i]).onSubchannelState( same(subchannels[i]), eq(ConnectivityStateInfo.forNonError(CONNECTING))); } @@ -409,7 +393,7 @@ public class HealthCheckingLoadBalancerFactoryTest { // In reality UNIMPLEMENTED is generated by GRPC server library, but the client can't tell // whether it's the server library or the service implementation that returned this status. serverCall0.responseObserver.onError(Status.UNIMPLEMENTED.asException()); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannels[0]), eq(ConnectivityStateInfo.forNonError(READY))); assertThat(subchannels[0].logs).containsExactly( "ERROR: Health-check disabled: " + Status.UNIMPLEMENTED, @@ -417,32 +401,31 @@ public class HealthCheckingLoadBalancerFactoryTest { // subchannels[1] has normal health checking serverCall1.responseObserver.onNext(makeResponse(ServingStatus.NOT_SERVING)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[1]).onSubchannelState( same(subchannels[1]), unavailableStateWithMsg("Health-check service responded NOT_SERVING for 'BarService'")); - // Without health checking, states from underlying Subchannel are delivered directly to origLb - hcLbEventDelivery.handleSubchannelState( - subchannels[0], ConnectivityStateInfo.forNonError(IDLE)); - inOrder.verify(origLb).handleSubchannelState( + // Without health checking, states from underlying Subchannel are delivered directly to the mock + // listeners. + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannels[0]), eq(ConnectivityStateInfo.forNonError(IDLE))); // Re-connecting on a Subchannel will reset the "disabled" flag. assertThat(healthImpls[0].calls).hasSize(0); - hcLbEventDelivery.handleSubchannelState( - subchannels[0], ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpls[0].calls).hasSize(1); serverCall0 = healthImpls[0].calls.poll(); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannels[0]), eq(ConnectivityStateInfo.forNonError(CONNECTING))); // Health check now works as normal serverCall0.responseObserver.onNext(makeResponse(ServingStatus.SERVICE_UNKNOWN)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannels[0]), unavailableStateWithMsg("Health-check service responded SERVICE_UNKNOWN for 'BarService'")); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockStateListeners[0], mockStateListeners[1]); verifyZeroInteractions(backoffPolicyProvider); } @@ -460,10 +443,11 @@ public class HealthCheckingLoadBalancerFactoryTest { FakeSubchannel subchannel = (FakeSubchannel) createSubchannel(0, Attributes.EMPTY); assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb, backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + SubchannelStateListener mockListener = mockStateListeners[0]; + InOrder inOrder = inOrder(mockListener, backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); HealthImpl healthImpl = healthImpls[0]; assertThat(healthImpl.calls).hasSize(1); @@ -474,7 +458,7 @@ public class HealthCheckingLoadBalancerFactoryTest { healthImpl.calls.poll().responseObserver.onCompleted(); // which results in TRANSIENT_FAILURE - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " + Status.OK + " for 'TeeService'")); @@ -487,7 +471,7 @@ public class HealthCheckingLoadBalancerFactoryTest { inOrder.verify(backoffPolicy1).nextBackoffNanos(); assertThat(clock.getPendingTasks()).hasSize(1); - verifyRetryAfterNanos(inOrder, subchannel, healthImpl, 11); + verifyRetryAfterNanos(inOrder, mockListener, subchannel, healthImpl, 11); assertThat(clock.getPendingTasks()).isEmpty(); subchannel.logs.clear(); @@ -495,7 +479,7 @@ public class HealthCheckingLoadBalancerFactoryTest { healthImpl.calls.poll().responseObserver.onError(Status.CANCELLED.asException()); // which also results in TRANSIENT_FAILURE, with a different description - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " @@ -507,15 +491,15 @@ public class HealthCheckingLoadBalancerFactoryTest { // Retry with backoff inOrder.verify(backoffPolicy1).nextBackoffNanos(); - verifyRetryAfterNanos(inOrder, subchannel, healthImpl, 21); + verifyRetryAfterNanos(inOrder, mockListener, subchannel, healthImpl, 21); // Server responds this time healthImpl.calls.poll().responseObserver.onNext(makeResponse(ServingStatus.SERVING)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); - verifyNoMoreInteractions(origLb, backoffPolicyProvider, backoffPolicy1); + verifyNoMoreInteractions(origLb, mockListener, backoffPolicyProvider, backoffPolicy1); } @Test @@ -530,12 +514,14 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origLb).handleResolvedAddresses(result); verifyNoMoreInteractions(origLb); + SubchannelStateListener mockStateListener = mockStateListeners[0]; Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb, backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + InOrder inOrder = + inOrder(mockStateListener, backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); HealthImpl healthImpl = healthImpls[0]; assertThat(healthImpl.calls).hasSize(1); @@ -545,7 +531,7 @@ public class HealthCheckingLoadBalancerFactoryTest { healthImpl.calls.poll().responseObserver.onError(Status.CANCELLED.asException()); // which results in TRANSIENT_FAILURE - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " @@ -556,19 +542,19 @@ public class HealthCheckingLoadBalancerFactoryTest { inOrder.verify(backoffPolicy1).nextBackoffNanos(); assertThat(clock.getPendingTasks()).hasSize(1); - verifyRetryAfterNanos(inOrder, subchannel, healthImpl, 11); + verifyRetryAfterNanos(inOrder, mockStateListener, subchannel, healthImpl, 11); assertThat(clock.getPendingTasks()).isEmpty(); // Server responds healthImpl.calls.peek().responseObserver.onNext(makeResponse(ServingStatus.SERVING)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListener); // then closes the stream healthImpl.calls.poll().responseObserver.onError(Status.UNAVAILABLE.asException()); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " @@ -577,7 +563,7 @@ public class HealthCheckingLoadBalancerFactoryTest { // Because server has responded, the first retry is not subject to backoff. // But the backoff policy has been reset. A new backoff policy will be used for // the next backed-off retry. - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); assertThat(healthImpl.calls).hasSize(1); assertThat(clock.getPendingTasks()).isEmpty(); @@ -585,7 +571,7 @@ public class HealthCheckingLoadBalancerFactoryTest { // then closes the stream for this retry healthImpl.calls.poll().responseObserver.onError(Status.UNAVAILABLE.asException()); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " @@ -596,18 +582,19 @@ public class HealthCheckingLoadBalancerFactoryTest { // Retry with a new backoff policy inOrder.verify(backoffPolicy2).nextBackoffNanos(); - verifyRetryAfterNanos(inOrder, subchannel, healthImpl, 12); + verifyRetryAfterNanos(inOrder, mockStateListener, subchannel, healthImpl, 12); } private void verifyRetryAfterNanos( - InOrder inOrder, Subchannel subchannel, HealthImpl impl, long nanos) { + InOrder inOrder, SubchannelStateListener listener, Subchannel subchannel, HealthImpl impl, + long nanos) { assertThat(impl.calls).isEmpty(); clock.forwardNanos(nanos - 1); assertThat(impl.calls).isEmpty(); inOrder.verifyNoMoreInteractions(); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(listener); clock.forwardNanos(1); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(listener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); assertThat(impl.calls).hasSize(1); } @@ -628,13 +615,12 @@ public class HealthCheckingLoadBalancerFactoryTest { createSubchannel(0, Attributes.EMPTY); // No health check activity. Underlying Subchannel states are directly propagated - hcLbEventDelivery.handleSubchannelState( - subchannels[0], ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpls[0].calls).isEmpty(); - verify(origLb).handleSubchannelState( + verify(mockStateListeners[0]).onSubchannelState( same(subchannels[0]), eq(ConnectivityStateInfo.forNonError(READY))); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListeners[0]); // Service config enables health check Attributes resolutionAttrs = attrsWithHealthCheckService("FooService"); @@ -649,13 +635,12 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(healthImpls[0].calls).hasSize(1); // State stays in READY, instead of switching to CONNECTING. - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(mockStateListeners[0]); // Start Subchannel 1, which will have health check createSubchannel(1, Attributes.EMPTY); assertThat(healthImpls[1].calls).isEmpty(); - hcLbEventDelivery.handleSubchannelState( - subchannels[1], ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(1, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpls[1].calls).hasSize(1); } @@ -673,10 +658,10 @@ public class HealthCheckingLoadBalancerFactoryTest { Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockStateListeners[0]); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); inOrder.verifyNoMoreInteractions(); HealthImpl healthImpl = healthImpls[0]; @@ -694,12 +679,12 @@ public class HealthCheckingLoadBalancerFactoryTest { // Health check RPC cancelled. assertThat(serverCall.cancelled).isTrue(); // Subchannel uses original state - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); inOrder.verify(origLb).handleResolvedAddresses(result2); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockStateListeners[0]); assertThat(healthImpl.calls).isEmpty(); } @@ -717,10 +702,10 @@ public class HealthCheckingLoadBalancerFactoryTest { Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockStateListeners[0]); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); inOrder.verifyNoMoreInteractions(); HealthImpl healthImpl = healthImpls[0]; @@ -730,7 +715,7 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(clock.getPendingTasks()).isEmpty(); healthImpl.calls.poll().responseObserver.onCompleted(); assertThat(clock.getPendingTasks()).hasSize(1); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " + Status.OK + " for 'TeeService'")); @@ -749,12 +734,12 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(healthImpl.calls).isEmpty(); // Subchannel uses original state - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); inOrder.verify(origLb).handleResolvedAddresses(result2); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockStateListeners[0]); } @Test @@ -771,14 +756,15 @@ public class HealthCheckingLoadBalancerFactoryTest { Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockStateListeners[0]); // Underlying subchannel is not READY initially ConnectivityStateInfo underlyingErrorState = ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("connection refused")); - hcLbEventDelivery.handleSubchannelState(subchannel, underlyingErrorState); - inOrder.verify(origLb).handleSubchannelState(same(subchannel), same(underlyingErrorState)); + deliverSubchannelState(0, underlyingErrorState); + inOrder.verify(mockStateListeners[0]) + .onSubchannelState(same(subchannel), same(underlyingErrorState)); inOrder.verifyNoMoreInteractions(); // NameResolver gives an update without service config, thus health check will be disabled @@ -791,16 +777,16 @@ public class HealthCheckingLoadBalancerFactoryTest { inOrder.verify(origLb).handleResolvedAddresses(result2); // Underlying subchannel is now ready - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); // Since health check is disabled, READY state is propagated directly. - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockStateListeners[0]).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); // and there is no health check activity. assertThat(healthImpls[0].calls).isEmpty(); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockStateListeners[0]); } @Test @@ -816,11 +802,12 @@ public class HealthCheckingLoadBalancerFactoryTest { verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); + SubchannelStateListener mockListener = mockStateListeners[0]; assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockListener); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); HealthImpl healthImpl = healthImpls[0]; @@ -831,14 +818,14 @@ public class HealthCheckingLoadBalancerFactoryTest { // Health check responded serverCall.responseObserver.onNext(makeResponse(ServingStatus.SERVING)); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(READY))); // Service config returns with the same health check name. hcLbEventDelivery.handleResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens inOrder.verify(origLb).handleResolvedAddresses(result1); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); // Service config returns a different health check name. resolutionAttrs = attrsWithHealthCheckService("FooService"); @@ -859,7 +846,7 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(serverCall.request).isEqualTo(makeRequest("FooService")); // State stays in READY, instead of switching to CONNECTING. - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); } @Test @@ -875,11 +862,12 @@ public class HealthCheckingLoadBalancerFactoryTest { verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); + SubchannelStateListener mockListener = mockStateListeners[0]; assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockListener); - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(origLb).handleSubchannelState( + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); HealthImpl healthImpl = healthImpls[0]; @@ -893,7 +881,7 @@ public class HealthCheckingLoadBalancerFactoryTest { serverCall.responseObserver.onCompleted(); assertThat(clock.getPendingTasks()).hasSize(1); assertThat(healthImpl.calls).isEmpty(); - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), unavailableStateWithMsg( "Health-check stream unexpectedly closed with " + Status.OK + " for 'TeeService'")); @@ -903,7 +891,7 @@ public class HealthCheckingLoadBalancerFactoryTest { hcLbEventDelivery.handleResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens inOrder.verify(origLb).handleResolvedAddresses(result1); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); assertThat(clock.getPendingTasks()).hasSize(1); assertThat(healthImpl.calls).isEmpty(); @@ -916,7 +904,7 @@ public class HealthCheckingLoadBalancerFactoryTest { hcLbEventDelivery.handleResolvedAddresses(result2); // Concluded CONNECTING state - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); inOrder.verify(origLb).handleResolvedAddresses(result2); @@ -930,7 +918,7 @@ public class HealthCheckingLoadBalancerFactoryTest { // with the new service name assertThat(serverCall.request).isEqualTo(makeRequest("FooService")); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); } @Test @@ -946,16 +934,17 @@ public class HealthCheckingLoadBalancerFactoryTest { verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); + SubchannelStateListener mockListener = mockStateListeners[0]; assertThat(subchannel).isSameAs(subchannels[0]); - InOrder inOrder = inOrder(origLb); + InOrder inOrder = inOrder(origLb, mockListener); HealthImpl healthImpl = healthImpls[0]; // Underlying subchannel is not READY initially ConnectivityStateInfo underlyingErrorState = ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("connection refused")); - hcLbEventDelivery.handleSubchannelState(subchannel, underlyingErrorState); - inOrder.verify(origLb).handleSubchannelState(same(subchannel), same(underlyingErrorState)); + deliverSubchannelState(0, underlyingErrorState); + inOrder.verify(mockListener).onSubchannelState(same(subchannel), same(underlyingErrorState)); inOrder.verifyNoMoreInteractions(); // Service config returns with the same health check name. @@ -976,10 +965,10 @@ public class HealthCheckingLoadBalancerFactoryTest { inOrder.verify(origLb).handleResolvedAddresses(result2); // Underlying subchannel is now ready - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); // Concluded CONNECTING state - inOrder.verify(origLb).handleSubchannelState( + inOrder.verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); // Health check RPC is started @@ -987,7 +976,7 @@ public class HealthCheckingLoadBalancerFactoryTest { // with the new service name assertThat(healthImpl.calls.poll().request).isEqualTo(makeRequest("FooService")); - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); } @Test @@ -1031,17 +1020,18 @@ public class HealthCheckingLoadBalancerFactoryTest { verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY); + SubchannelStateListener mockListener = mockStateListeners[0]; assertThat(subchannel).isSameAs(subchannels[0]); // Trigger the health check - hcLbEventDelivery.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); HealthImpl healthImpl = healthImpls[0]; assertThat(healthImpl.calls).hasSize(1); ServerSideCall serverCall = healthImpl.calls.poll(); assertThat(serverCall.cancelled).isFalse(); - verify(origLb).handleSubchannelState( + verify(mockListener).onSubchannelState( same(subchannel), eq(ConnectivityStateInfo.forNonError(CONNECTING))); // Shut down the balancer @@ -1052,7 +1042,7 @@ public class HealthCheckingLoadBalancerFactoryTest { assertThat(serverCall.cancelled).isTrue(); // LoadBalancer API requires no more callbacks on LoadBalancer after shutdown() is called. - verifyNoMoreInteractions(origLb); + verifyNoMoreInteractions(origLb, mockListener); // No more health check call is made or scheduled assertThat(healthImpl.calls).isEmpty(); @@ -1085,8 +1075,7 @@ public class HealthCheckingLoadBalancerFactoryTest { verify(origLb).handleResolvedAddresses(result); createSubchannel(0, Attributes.EMPTY); assertThat(healthImpls[0].calls).isEmpty(); - hcLbEventDelivery.handleSubchannelState( - subchannels[0], ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); assertThat(healthImpls[0].calls).hasSize(1); } @@ -1179,6 +1168,7 @@ public class HealthCheckingLoadBalancerFactoryTest { final Attributes attrs; final Channel channel; final ArrayList logs = new ArrayList<>(); + final int index; private final ChannelLogger logger = new ChannelLogger() { @Override public void log(ChannelLogLevel level, String msg) { @@ -1191,15 +1181,16 @@ public class HealthCheckingLoadBalancerFactoryTest { } }; - FakeSubchannel(List eagList, Attributes attrs, Channel channel) { - this.eagList = Collections.unmodifiableList(eagList); - this.attrs = checkNotNull(attrs); + FakeSubchannel(int index, CreateSubchannelArgs args, Channel channel) { + this.index = index; + this.eagList = args.getAddresses(); + this.attrs = args.getAttributes(); this.channel = checkNotNull(channel); } @Override public void shutdown() { - hcLbEventDelivery.handleSubchannelState(this, ConnectivityStateInfo.forNonError(SHUTDOWN)); + deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN)); } @Override @@ -1230,18 +1221,19 @@ public class HealthCheckingLoadBalancerFactoryTest { private class FakeHelper extends Helper { @Override - public Subchannel createSubchannel(List addrs, Attributes attrs) { + public Subchannel createSubchannel(CreateSubchannelArgs args) { int index = -1; for (int i = 0; i < NUM_SUBCHANNELS; i++) { - if (eagLists[i] == addrs) { + if (eagLists[i].equals(args.getAddresses())) { index = i; break; } } - checkState(index >= 0, "addrs " + addrs + " not found"); - FakeSubchannel subchannel = new FakeSubchannel(addrs, attrs, channels[index]); + checkState(index >= 0, "addrs " + args.getAddresses() + " not found"); + FakeSubchannel subchannel = new FakeSubchannel(index, args, channels[index]); checkState(subchannels[index] == null, "subchannels[" + index + "] already created"); subchannels[index] = subchannel; + stateListeners[index] = args.getStateListener(); return subchannel; } @@ -1297,9 +1289,23 @@ public class HealthCheckingLoadBalancerFactoryTest { syncContext.execute(new Runnable() { @Override public void run() { - returnedSubchannel.set(wrappedHelper.createSubchannel(eagLists[index], attrs)); + returnedSubchannel.set( + wrappedHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(eagLists[index]) + .setAttributes(attrs) + .setStateListener(mockStateListeners[index]) + .build())); } }); return returnedSubchannel.get(); } + + private void deliverSubchannelState(final int index, final ConnectivityStateInfo newState) { + syncContext.execute(new Runnable() { + @Override + public void run() { + stateListeners[index].onSubchannelState(subchannels[index], newState); + } + }); + } }