diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java index b0a626d47c..6b93653414 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java @@ -17,12 +17,14 @@ package io.grpc.util; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; @@ -36,6 +38,7 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -119,12 +122,12 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { subchannel.shutdown(); } - updatePicker(getAggregatedError()); + updateBalancingState(getAggregatedState(), getAggregatedError()); } @Override public void handleNameResolutionError(Status error) { - updatePicker(error); + updateBalancingState(TRANSIENT_FAILURE, error); } @Override @@ -136,7 +139,7 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { subchannel.requestConnection(); } getSubchannelStateInfoRef(subchannel).set(stateInfo); - updatePicker(getAggregatedError()); + updateBalancingState(getAggregatedState(), getAggregatedError()); } @Override @@ -149,9 +152,9 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { /** * Updates picker with the list of active subchannels (state == READY). */ - private void updatePicker(@Nullable Status error) { + private void updateBalancingState(ConnectivityState state, Status error) { List activeList = filterNonFailingSubchannels(getSubchannels()); - helper.updatePicker(new Picker(activeList, error)); + helper.updateBalancingState(state, new Picker(activeList, error)); } /** @@ -197,6 +200,26 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { return status; } + private ConnectivityState getAggregatedState() { + Set states = EnumSet.noneOf(ConnectivityState.class); + for (Subchannel subchannel : getSubchannels()) { + states.add(getSubchannelStateInfoRef(subchannel).get().getState()); + } + if (states.contains(READY)) { + return READY; + } + if (states.contains(CONNECTING)) { + return CONNECTING; + } + if (states.contains(IDLE)) { + // This subchannel IDLE is not because of channel IDLE_TIMEOUT, in which case LB is already + // shutdown. + // RRLB will request connection immediately on subchannel IDLE. + return CONNECTING; + } + return TRANSIENT_FAILURE; + } + @VisibleForTesting Collection getSubchannels() { return subchannels.values(); diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index d22e914168..86e233670c 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -17,12 +17,15 @@ package io.grpc.util; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; @@ -77,6 +80,8 @@ public class RoundRobinLoadBalancerTest { @Captor private ArgumentCaptor pickerCaptor; @Captor + private ArgumentCaptor stateCaptor; + @Captor private ArgumentCaptor eagCaptor; @Mock private Helper mockHelper; @@ -131,8 +136,11 @@ public class RoundRobinLoadBalancerTest { verify(subchannel, never()).shutdown(); } - verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture()); + verify(mockHelper, times(2)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + assertEquals(CONNECTING, stateCaptor.getAllValues().get(0)); + assertEquals(READY, stateCaptor.getAllValues().get(1)); assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel); verifyNoMoreInteractions(mockHelper); @@ -176,7 +184,7 @@ public class RoundRobinLoadBalancerTest { InOrder inOrder = inOrder(mockHelper); - inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); Picker picker = pickerCaptor.getValue(); assertNull(picker.getStatus()); assertThat(picker.getList()).containsExactly(removedSubchannel, oldSubchannel); @@ -206,7 +214,7 @@ public class RoundRobinLoadBalancerTest { verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); - inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); assertNull(picker.getStatus()); @@ -239,12 +247,12 @@ public class RoundRobinLoadBalancerTest { AtomicReference subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); - inOrder.verify(mockHelper).updatePicker(isA(Picker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(Picker.class)); assertThat(subchannelStateInfo.get()).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(ConnectivityState.READY)); - inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); assertThat(subchannelStateInfo.get()).isEqualTo( ConnectivityStateInfo.forNonError(ConnectivityState.READY)); @@ -254,12 +262,12 @@ public class RoundRobinLoadBalancerTest { ConnectivityStateInfo.forTransientFailure(error)); assertThat(subchannelStateInfo.get()).isEqualTo( ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(ConnectivityState.IDLE)); - inOrder.verify(mockHelper).updatePicker(pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertNull(pickerCaptor.getValue().getStatus()); assertThat(subchannelStateInfo.get()).isEqualTo( ConnectivityStateInfo.forNonError(ConnectivityState.IDLE)); @@ -300,7 +308,7 @@ public class RoundRobinLoadBalancerTest { public void nameResolutionErrorWithNoChannels() throws Exception { Status error = Status.NOT_FOUND.withDescription("nameResolutionError"); loadBalancer.handleNameResolutionError(error); - verify(mockHelper).updatePicker(pickerCaptor.capture()); + verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); assertNull(pickResult.getSubchannel()); assertEquals(error, pickResult.getStatus()); @@ -316,7 +324,13 @@ public class RoundRobinLoadBalancerTest { verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); - verify(mockHelper, times(3)).updatePicker(pickerCaptor.capture()); + verify(mockHelper, times(3)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + + Iterator stateIterator = stateCaptor.getAllValues().iterator(); + assertEquals(CONNECTING, stateIterator.next()); + assertEquals(READY, stateIterator.next()); + assertEquals(TRANSIENT_FAILURE, stateIterator.next()); LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); assertEquals(readySubchannel, pickResult.getSubchannel()); @@ -346,19 +360,28 @@ public class RoundRobinLoadBalancerTest { loadBalancer .handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); - verify(mockHelper, times(6)).updatePicker(pickerCaptor.capture()); + verify(mockHelper, times(6)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Iterator stateIterator = stateCaptor.getAllValues().iterator(); Iterator pickers = pickerCaptor.getAllValues().iterator(); // The picker is incrementally updated as subchannels become READY + assertEquals(CONNECTING, stateIterator.next()); assertThat(pickers.next().getList()).isEmpty(); + assertEquals(READY, stateIterator.next()); assertThat(pickers.next().getList()).containsExactly(sc1); + assertEquals(READY, stateIterator.next()); assertThat(pickers.next().getList()).containsExactly(sc1, sc2); + assertEquals(READY, stateIterator.next()); assertThat(pickers.next().getList()).containsExactly(sc1, sc2, sc3); // The IDLE subchannel is dropped from the picker, but a reconnection is requested + assertEquals(READY, stateIterator.next()); assertThat(pickers.next().getList()).containsExactly(sc1, sc3); verify(sc2, times(2)).requestConnection(); // The failing subchannel is dropped from the picker, with no requested reconnect + assertEquals(READY, stateIterator.next()); assertThat(pickers.next().getList()).containsExactly(sc1); verify(sc3, times(1)).requestConnection(); + assertThat(stateIterator.hasNext()).isFalse(); assertThat(pickers.hasNext()).isFalse(); }