api: avoid infinite loop in handleResolvedAddresses

If a `LoadBalancer` implementation does not override `handleResolvedAddressGroups()`, or overrides `handleResolvedAddressGroups()` but calls `super.handleResolvedAddressGroups()` at the beginning or the end, it will be trapped in an infinite loop.
This commit is contained in:
ZHANG Dapeng 2019-10-03 15:58:18 -07:00 committed by GitHub
parent 28323e2fb6
commit 90b3c88fe2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 4 deletions

View File

@ -117,6 +117,8 @@ public abstract class LoadBalancer {
public static final Attributes.Key<Map<String, ?>> ATTR_LOAD_BALANCING_CONFIG = public static final Attributes.Key<Map<String, ?>> ATTR_LOAD_BALANCING_CONFIG =
Attributes.Key.create("io.grpc.LoadBalancer.loadBalancingConfig"); Attributes.Key.create("io.grpc.LoadBalancer.loadBalancingConfig");
private int recursionCount;
/** /**
* Handles newly resolved server groups and metadata attributes from name resolution system. * Handles newly resolved server groups and metadata attributes from name resolution system.
* {@code servers} contained in {@link EquivalentAddressGroup} should be considered equivalent * {@code servers} contained in {@link EquivalentAddressGroup} should be considered equivalent
@ -133,8 +135,11 @@ public abstract class LoadBalancer {
public void handleResolvedAddressGroups( public void handleResolvedAddressGroups(
List<EquivalentAddressGroup> servers, List<EquivalentAddressGroup> servers,
@NameResolver.ResolutionResultAttr Attributes attributes) { @NameResolver.ResolutionResultAttr Attributes attributes) {
handleResolvedAddresses( if (recursionCount++ == 0) {
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attributes).build()); handleResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attributes).build());
}
recursionCount = 0;
} }
/** /**
@ -149,8 +154,11 @@ public abstract class LoadBalancer {
*/ */
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
handleResolvedAddressGroups( if (recursionCount++ == 0) {
resolvedAddresses.getAddresses(), resolvedAddresses.getAttributes()); handleResolvedAddressGroups(
resolvedAddresses.getAddresses(), resolvedAddresses.getAttributes());
}
recursionCount = 0;
} }
/** /**

View File

@ -25,6 +25,7 @@ import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.Subchannel;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -360,6 +361,41 @@ public class LoadBalancerTest {
assertThat(attrsCapture.get()).isEqualTo(attrs); assertThat(attrsCapture.get()).isEqualTo(attrs);
} }
@Deprecated
@Test
public void handleResolvedAddresses_noInfiniteLoop() {
final List<List<EquivalentAddressGroup>> serversCapture = new ArrayList<>();
final List<Attributes> attrsCapture = new ArrayList<>();
LoadBalancer balancer = new LoadBalancer() {
@Override
public void handleResolvedAddressGroups(
List<EquivalentAddressGroup> servers, Attributes attrs) {
serversCapture.add(servers);
attrsCapture.add(attrs);
super.handleResolvedAddressGroups(servers, attrs);
}
@Override
public void handleNameResolutionError(Status error) {
}
@Override
public void shutdown() {
}
};
List<EquivalentAddressGroup> servers = Arrays.asList(
new EquivalentAddressGroup(new SocketAddress(){}),
new EquivalentAddressGroup(new SocketAddress(){}));
balancer.handleResolvedAddresses(
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attrs).build());
assertThat(serversCapture).hasSize(1);
assertThat(attrsCapture).hasSize(1);
assertThat(serversCapture.get(0)).isEqualTo(servers);
assertThat(attrsCapture.get(0)).isEqualTo(attrs);
}
private static class NoopHelper extends LoadBalancer.Helper { private static class NoopHelper extends LoadBalancer.Helper {
@Override @Override
public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) {