From d21ee583411b8a9c54d16317a384c1cb0677df74 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 3 Mar 2017 20:48:05 -0500 Subject: [PATCH] Fix shared subchannel state in RoundRobin LB (#2777) handleResolvedAddresses constructs subchannels for each address in an EquivalentAddressGroup, sharing a reference to an immutable Atrributes. However, as noted, this Attributes instance contains a mutable AtomicReference, eventually causing subchannel state changes to be improperly reflected in *all* subchannels of an EquivalentAddressGroup. --- .../util/RoundRobinLoadBalancerFactory.java | 22 ++--- .../grpc/util/RoundRobinLoadBalancerTest.java | 84 ++++++++++++------- 2 files changed, 63 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java index 3bec9c6cd1..6f0890983e 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java @@ -109,19 +109,19 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { Set addedAddrs = setsDifference(latestAddrs, currentAddrs); Set removedAddrs = setsDifference(currentAddrs, latestAddrs); - // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel doesn't - // need them. They're describing the resolved server list but we're not taking any action - // based on this information. - Attributes subchannelAttrs = Attributes.newBuilder() - // NB(lukaszx0): because attributes are immutable we can't set new value for the key - // after creation but since we can mutate the values we leverge that and set - // AtomicReference which will allow mutating state info for given channel. - .set(STATE_INFO, new AtomicReference( - ConnectivityStateInfo.forNonError(IDLE))) - .build(); - // Create new subchannels for new addresses. for (EquivalentAddressGroup addressGroup : addedAddrs) { + // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel + // doesn't need them. They're describing the resolved server list but we're not taking + // any action based on this information. + Attributes subchannelAttrs = Attributes.newBuilder() + // NB(lukaszx0): because attributes are immutable we can't set new value for the key + // after creation but since we can mutate the values we leverge that and set + // AtomicReference which will allow mutating state info for given channel. + .set(STATE_INFO, new AtomicReference( + ConnectivityStateInfo.forNonError(IDLE))) + .build(); + Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), "subchannel"); subchannels.put(addressGroup, subchannel); diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 58c5fc110b..2414ab20ac 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -39,6 +39,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.mockito.Matchers.any; import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -64,6 +65,7 @@ import io.grpc.util.RoundRobinLoadBalancerFactory.Picker; import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer; import java.net.SocketAddress; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -108,7 +110,7 @@ public class RoundRobinLoadBalancerTest { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.put(ResolvedServerInfoGroup.builder().add(new ResolvedServerInfo(addr)).build(), eag); - subchannels.put(eag, createMockSubchannel()); + subchannels.put(eag, mock(Subchannel.class)); } when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) @@ -116,7 +118,9 @@ public class RoundRobinLoadBalancerTest { @Override public Subchannel answer(InvocationOnMock invocation) throws Throwable { Object[] args = invocation.getArguments(); - return subchannels.get(args[0]); + Subchannel subchannel = subchannels.get(args[0]); + when(subchannel.getAttributes()).thenReturn((Attributes) args[1]); + return subchannel; } }); @@ -131,12 +135,9 @@ public class RoundRobinLoadBalancerTest { @Test public void pickAfterResolved() throws Exception { - Subchannel readySubchannel = subchannels.get(servers.get(servers.keySet().iterator().next())); - when(readySubchannel.getAttributes()).thenReturn(Attributes.newBuilder() - .set(STATE_INFO, new AtomicReference( - ConnectivityStateInfo.forNonError(READY))) - .build()); + final Subchannel readySubchannel = subchannels.values().iterator().next(); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity); + loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(), any(Attributes.class)); @@ -147,7 +148,7 @@ public class RoundRobinLoadBalancerTest { verify(subchannel, never()).shutdown(); } - verify(mockHelper, times(1)).updatePicker(pickerCaptor.capture()); + verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture()); assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel); @@ -181,14 +182,13 @@ public class RoundRobinLoadBalancerTest { .add(new ResolvedServerInfo(oldAddr)) .build()); - when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) - .then(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - Object[] args = invocation.getArguments(); - return subchannels2.get(args[0]); - } - }); + doAnswer(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + return subchannels2.get(args[0]); + } + }).when(mockHelper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); loadBalancer.handleResolvedAddresses(currentServers, affinity); @@ -253,13 +253,6 @@ public class RoundRobinLoadBalancerTest { @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) - .then(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - return createMockSubchannel(); - } - }); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); AtomicReference subchannelStateInfo = subchannel.getAttributes().get( @@ -335,14 +328,14 @@ public class RoundRobinLoadBalancerTest { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { - Subchannel readySubchannel = subchannels.values().iterator().next(); - readySubchannel.getAttributes().get(STATE_INFO).set(ConnectivityStateInfo.forNonError(READY)); + final Subchannel readySubchannel = subchannels.values().iterator().next(); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity); + loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)); - verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture()); + verify(mockHelper, times(3)).updatePicker(pickerCaptor.capture()); LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); assertEquals(readySubchannel, pickResult.getSubchannel()); @@ -353,12 +346,39 @@ public class RoundRobinLoadBalancerTest { verifyNoMoreInteractions(mockHelper); } - private Subchannel createMockSubchannel() { - Subchannel subchannel = mock(Subchannel.class); - when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO, - new AtomicReference( - ConnectivityStateInfo.forNonError(IDLE))).build()); - return subchannel; + @Test + public void subchannelStateIsolation() throws Exception { + Iterator subchannelIterator = subchannels.values().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + Subchannel sc3 = subchannelIterator.next(); + + loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY); + verify(sc1, times(1)).requestConnection(); + verify(sc2, times(1)).requestConnection(); + verify(sc3, times(1)).requestConnection(); + + loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE)); + loadBalancer + .handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + + verify(mockHelper, times(6)).updatePicker(pickerCaptor.capture()); + Iterator pickers = pickerCaptor.getAllValues().iterator(); + // The picker is incrementally updated as subchannels become READY + assertThat(pickers.next().getList()).isEmpty(); + assertThat(pickers.next().getList()).containsExactly(sc1); + assertThat(pickers.next().getList()).containsExactly(sc1, sc2); + assertThat(pickers.next().getList()).containsExactly(sc1, sc2, sc3); + // The IDLE subchannel is dropped from the picker, but a reconnection is requested + assertThat(pickers.next().getList()).containsExactly(sc1, sc3); + verify(sc2, times(2)).requestConnection(); + // The failing subchannel is dropped from the picker, with no requested reconnect + assertThat(pickers.next().getList()).containsExactly(sc1); + verify(sc3, times(1)).requestConnection(); + assertThat(pickers.hasNext()).isFalse(); } private static class FakeSocketAddress extends SocketAddress {