Fix shared subchannel state in RoundRobin LB (#2777)

handleResolvedAddresses constructs subchannels for each address in
an EquivalentAddressGroup, sharing a reference to an immutable
Atrributes. However, as noted, this Attributes instance contains
a mutable AtomicReference, eventually causing subchannel state
changes to be improperly reflected in *all* subchannels of an
EquivalentAddressGroup.
This commit is contained in:
Jack Amadeo 2017-03-03 20:48:05 -05:00 committed by Kun Zhang
parent acf093dc14
commit d21ee58341
2 changed files with 63 additions and 43 deletions

View File

@ -109,19 +109,19 @@ public class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory {
Set<EquivalentAddressGroup> addedAddrs = setsDifference(latestAddrs, currentAddrs);
Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs);
// NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel doesn't
// need them. They're describing the resolved server list but we're not taking any action
// based on this information.
Attributes subchannelAttrs = Attributes.newBuilder()
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverge that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
// Create new subchannels for new addresses.
for (EquivalentAddressGroup addressGroup : addedAddrs) {
// NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
// doesn't need them. They're describing the resolved server list but we're not taking
// any action based on this information.
Attributes subchannelAttrs = Attributes.newBuilder()
// NB(lukaszx0): because attributes are immutable we can't set new value for the key
// after creation but since we can mutate the values we leverge that and set
// AtomicReference which will allow mutating state info for given channel.
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE)))
.build();
Subchannel subchannel = checkNotNull(helper.createSubchannel(addressGroup, subchannelAttrs),
"subchannel");
subchannels.put(addressGroup, subchannel);

View File

@ -39,6 +39,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -64,6 +65,7 @@ import io.grpc.util.RoundRobinLoadBalancerFactory.Picker;
import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@ -108,7 +110,7 @@ public class RoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.put(ResolvedServerInfoGroup.builder().add(new ResolvedServerInfo(addr)).build(), eag);
subchannels.put(eag, createMockSubchannel());
subchannels.put(eag, mock(Subchannel.class));
}
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
@ -116,7 +118,9 @@ public class RoundRobinLoadBalancerTest {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
return subchannels.get(args[0]);
Subchannel subchannel = subchannels.get(args[0]);
when(subchannel.getAttributes()).thenReturn((Attributes) args[1]);
return subchannel;
}
});
@ -131,12 +135,9 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterResolved() throws Exception {
Subchannel readySubchannel = subchannels.get(servers.get(servers.keySet().iterator().next()));
when(readySubchannel.getAttributes()).thenReturn(Attributes.newBuilder()
.set(STATE_INFO, new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(READY)))
.build());
final Subchannel readySubchannel = subchannels.values().iterator().next();
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity);
loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(),
any(Attributes.class));
@ -147,7 +148,7 @@ public class RoundRobinLoadBalancerTest {
verify(subchannel, never()).shutdown();
}
verify(mockHelper, times(1)).updatePicker(pickerCaptor.capture());
verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getList()).containsExactly(readySubchannel);
@ -181,14 +182,13 @@ public class RoundRobinLoadBalancerTest {
.add(new ResolvedServerInfo(oldAddr))
.build());
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
return subchannels2.get(args[0]);
}
});
doAnswer(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
return subchannels2.get(args[0]);
}
}).when(mockHelper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
loadBalancer.handleResolvedAddresses(currentServers, affinity);
@ -253,13 +253,6 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper);
when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
return createMockSubchannel();
}
});
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY);
Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
AtomicReference<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
@ -335,14 +328,14 @@ public class RoundRobinLoadBalancerTest {
@Test
public void nameResolutionErrorWithActiveChannels() throws Exception {
Subchannel readySubchannel = subchannels.values().iterator().next();
readySubchannel.getAttributes().get(STATE_INFO).set(ConnectivityStateInfo.forNonError(READY));
final Subchannel readySubchannel = subchannels.values().iterator().next();
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), affinity);
loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
any(Attributes.class));
verify(mockHelper, times(2)).updatePicker(pickerCaptor.capture());
verify(mockHelper, times(3)).updatePicker(pickerCaptor.capture());
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertEquals(readySubchannel, pickResult.getSubchannel());
@ -353,12 +346,39 @@ public class RoundRobinLoadBalancerTest {
verifyNoMoreInteractions(mockHelper);
}
private Subchannel createMockSubchannel() {
Subchannel subchannel = mock(Subchannel.class);
when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO,
new AtomicReference<ConnectivityStateInfo>(
ConnectivityStateInfo.forNonError(IDLE))).build());
return subchannel;
@Test
public void subchannelStateIsolation() throws Exception {
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
loadBalancer.handleResolvedAddresses(Lists.newArrayList(servers.keySet()), Attributes.EMPTY);
verify(sc1, times(1)).requestConnection();
verify(sc2, times(1)).requestConnection();
verify(sc3, times(1)).requestConnection();
loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY));
loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE));
loadBalancer
.handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
verify(mockHelper, times(6)).updatePicker(pickerCaptor.capture());
Iterator<Picker> pickers = pickerCaptor.getAllValues().iterator();
// The picker is incrementally updated as subchannels become READY
assertThat(pickers.next().getList()).isEmpty();
assertThat(pickers.next().getList()).containsExactly(sc1);
assertThat(pickers.next().getList()).containsExactly(sc1, sc2);
assertThat(pickers.next().getList()).containsExactly(sc1, sc2, sc3);
// The IDLE subchannel is dropped from the picker, but a reconnection is requested
assertThat(pickers.next().getList()).containsExactly(sc1, sc3);
verify(sc2, times(2)).requestConnection();
// The failing subchannel is dropped from the picker, with no requested reconnect
assertThat(pickers.next().getList()).containsExactly(sc1);
verify(sc3, times(1)).requestConnection();
assertThat(pickers.hasNext()).isFalse();
}
private static class FakeSocketAddress extends SocketAddress {