Fix shared subchannel state in RoundRobin LB (#2777)

handleResolvedAddresses constructs subchannels for each address in
an EquivalentAddressGroup, sharing a reference to an immutable
Atrributes. However, as noted, this Attributes instance contains
a mutable AtomicReference, eventually causing subchannel state
changes to be improperly reflected in *all* subchannels of an
EquivalentAddressGroup.
This commit is contained in:
Jack Amadeo 2017-03-03 20:48:05 -05:00 committed by Kun Zhang
parent acf093dc14
commit d21ee58341
2 changed files with 63 additions and 43 deletions

View File

@ -109,19 +109,19 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {
Set<EquivalentAddressGroup> addedAddrs = setsDifference(latestAddrs, currentAddrs); Set<EquivalentAddressGroup> addedAddrs = setsDifference(latestAddrs, currentAddrs);
Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs); Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs);
// NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel doesn't
// need them. They're describing the resolved server list but we're not taking any action
// based on this information.
Attributes subchannelAttrs = Attributes.newBuilder()
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverge that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
// Create new subchannels for new addresses. // Create new subchannels for new addresses.
for (EquivalentAddressGroup addressGroup : addedAddrs) { for (EquivalentAddressGroup addressGroup : addedAddrs) {
// NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
// doesn't need them. They're describing the resolved server list but we're not taking
// any action based on this information.
Attributes subchannelAttrs = Attributes.newBuilder()
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverge that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs), Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs),
"subchannel"); "subchannel");
subchannels.put(addressGroup, subchannel); subchannels.put(addressGroup, subchannel);

View File

@ -39,6 +39,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -64,6 +65,7 @@ import io.grpc.util.RoundRobinLoadBalancerFactory.Picker;
import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer; import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -108,7 +110,7 @@ public class RoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i); SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.put(ResolvedServerInfoGroup.builder().add(new ResolvedServerInfo(addr)).build(), eag); servers.put(ResolvedServerInfoGroup.builder().add(new ResolvedServerInfo(addr)).build(), eag);
subchannels.put(eag, createMockSubchannel()); subchannels.put(eag, mock(Subchannel.class));
} }
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
@ -116,7 +118,9 @@ public class RoundRobinLoadBalancerTest {
@Override @Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable { public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
return subchannels.get(args[0]); Subchannel subchannel = subchannels.get(args[0]);
when(subchannel.getAttributes()).thenReturn((Attributes) args[1]);
return subchannel;
} }
}); });
@ -131,12 +135,9 @@ public class RoundRobinLoadBalancerTest {
@Test @Test
public void pickAfterResolved() throws Exception { public void pickAfterResolved() throws Exception {
Subchannel readySubchannel = subchannels.get(servers.get(servers.keySet().iterator().next())); final Subchannel readySubchannel = subchannels.values().iterator().next();
when(readySubchannel.getAttributes()).thenReturn(Attributes.newBuilder()
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(READY)))
.build());
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity);
loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(), verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(),
any(Attributes.class)); any(Attributes.class));
@ -147,7 +148,7 @@ public class RoundRobinLoadBalancerTest {
verify(subchannel, never()).shutdown(); verify(subchannel, never()).shutdown();
} }
verify(mockHelper, times(1)).updatePicker(pickerCaptor.capture()); verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel); assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel);
@ -181,14 +182,13 @@ public class RoundRobinLoadBalancerTest {
.add(new ResolvedServerInfo(oldAddr)) .add(new ResolvedServerInfo(oldAddr))
.build()); .build());
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class))) doAnswer(new Answer<Subchannel>() {
.then(new Answer<Subchannel>() { @Override
@Override public Subchannel answer(InvocationOnMock invocation) throws Throwable {
public Subchannel answer(InvocationOnMock invocation) throws Throwable { Object[] args = invocation.getArguments();
Object[] args = invocation.getArguments(); return subchannels2.get(args[0]);
return subchannels2.get(args[0]); }
} }).when(mockHelper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
});
loadBalancer.handleResolvedAddresses(currentServers, affinity); loadBalancer.handleResolvedAddresses(currentServers, affinity);
@ -253,13 +253,6 @@ public class RoundRobinLoadBalancerTest {
@Test @Test
public void pickAfterStateChange() throws Exception { public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper); InOrder inOrder = inOrder(mockHelper);
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
return createMockSubchannel();
}
});
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY);
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
AtomicReference<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get( AtomicReference<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
@ -335,14 +328,14 @@ public class RoundRobinLoadBalancerTest {
@Test @Test
public void nameResolutionErrorWithActiveChannels() throws Exception { public void nameResolutionErrorWithActiveChannels() throws Exception {
Subchannel readySubchannel = subchannels.values().iterator().next(); final Subchannel readySubchannel = subchannels.values().iterator().next();
readySubchannel.getAttributes().get(STATE_INFO).set(ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity); loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity);
loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class), verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
any(Attributes.class)); any(Attributes.class));
verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture()); verify(mockHelper, times(3)).updatePicker(pickerCaptor.capture());
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertEquals(readySubchannel, pickResult.getSubchannel()); assertEquals(readySubchannel, pickResult.getSubchannel());
@ -353,12 +346,39 @@ public class RoundRobinLoadBalancerTest {
verifyNoMoreInteractions(mockHelper); verifyNoMoreInteractions(mockHelper);
} }
private Subchannel createMockSubchannel() { @Test
Subchannel subchannel = mock(Subchannel.class); public void subchannelStateIsolation() throws Exception {
when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO, Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
new AtomicReference<ConnectivityStateInfo>( Subchannel sc1 = subchannelIterator.next();
ConnectivityStateInfo.forNonError(IDLE))).build()); Subchannel sc2 = subchannelIterator.next();
return subchannel; Subchannel sc3 = subchannelIterator.next();
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY);
verify(sc1, times(1)).requestConnection();
verify(sc2, times(1)).requestConnection();
verify(sc3, times(1)).requestConnection();
loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE));
loadBalancer
.handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
verify(mockHelper, times(6)).updatePicker(pickerCaptor.capture());
Iterator<Picker> pickers = pickerCaptor.getAllValues().iterator();
// The picker is incrementally updated as subchannels become READY
assertThat(pickers.next().getList()).isEmpty();
assertThat(pickers.next().getList()).containsExactly(sc1);
assertThat(pickers.next().getList()).containsExactly(sc1, sc2);
assertThat(pickers.next().getList()).containsExactly(sc1, sc2, sc3);
// The IDLE subchannel is dropped from the picker, but a reconnection is requested
assertThat(pickers.next().getList()).containsExactly(sc1, sc3);
verify(sc2, times(2)).requestConnection();
// The failing subchannel is dropped from the picker, with no requested reconnect
assertThat(pickers.next().getList()).containsExactly(sc1);
verify(sc3, times(1)).requestConnection();
assertThat(pickers.hasNext()).isFalse();
} }
private static class FakeSocketAddress extends SocketAddress { private static class FakeSocketAddress extends SocketAddress {