diff --git a/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java b/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java index e784bace33..7e8f4610b0 100644 --- a/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java +++ b/core/src/main/java/io/grpc/internal/ServiceConfigUtil.java @@ -36,6 +36,7 @@ public final class ServiceConfigUtil { private static final String SERVICE_CONFIG_METHOD_CONFIG_KEY = "methodConfig"; private static final String SERVICE_CONFIG_LOAD_BALANCING_POLICY_KEY = "loadBalancingPolicy"; + private static final String SERVICE_CONFIG_STICKINESS_METADATA_KEY = "stickinessMetadataKey"; private static final String METHOD_CONFIG_NAME_KEY = "name"; private static final String METHOD_CONFIG_TIMEOUT_KEY = "timeout"; private static final String METHOD_CONFIG_WAIT_FOR_READY_KEY = "waitForReady"; @@ -237,6 +238,18 @@ public final class ServiceConfigUtil { return getString(serviceConfig, SERVICE_CONFIG_LOAD_BALANCING_POLICY_KEY); } + /** + * Extracts the stickiness metadata key from a service config, or {@code null}. + */ + @Nullable + public static String getStickinessMetadataKeyFromServiceConfig( + Map serviceConfig) { + if (!serviceConfig.containsKey(SERVICE_CONFIG_STICKINESS_METADATA_KEY)) { + return null; + } + return getString(serviceConfig, SERVICE_CONFIG_STICKINESS_METADATA_KEY); + } + /** * Gets a list from an object for the given key. */ diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java index b0c7cd5b2b..19e608fc03 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancerFactory.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; @@ -33,18 +34,26 @@ import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; import io.grpc.NameResolver; import io.grpc.Status; +import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.ServiceConfigUtil; import java.util.ArrayList; import java.util.Collection; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nonnull; import javax.annotation.Nullable; /** @@ -91,10 +100,15 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { static final Attributes.Key> STATE_INFO = Attributes.Key.of("state-info"); + private static final Logger logger = Logger.getLogger(RoundRobinLoadBalancer.class.getName()); + private final Helper helper; private final Map subchannels = new HashMap(); + @Nullable + private StickinessState stickinessState; + RoundRobinLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } @@ -107,6 +121,24 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { Set addedAddrs = setsDifference(latestAddrs, currentAddrs); Set removedAddrs = setsDifference(currentAddrs, latestAddrs); + Map serviceConfig = + attributes.get(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG); + if (serviceConfig != null) { + String stickinessMetadataKey = + ServiceConfigUtil.getStickinessMetadataKeyFromServiceConfig(serviceConfig); + if (stickinessMetadataKey != null) { + if (stickinessMetadataKey.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + logger.log( + Level.FINE, + "Binary stickiness header is not supported. The header '{0}' will be ignored", + stickinessMetadataKey); + } else if (stickinessState == null + || !stickinessState.key.name().equals(stickinessMetadataKey)) { + stickinessState = new StickinessState(stickinessMetadataKey); + } + } + } + // Create new subchannels for new addresses. for (EquivalentAddressGroup addressGroup : addedAddrs) { // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel @@ -142,6 +174,9 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { @Override public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + if (stateInfo.getState() == SHUTDOWN && stickinessState != null) { + stickinessState.remove(subchannel); + } if (subchannels.get(subchannel.getAddresses()) != subchannel) { return; } @@ -164,7 +199,7 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { */ private void updateBalancingState(ConnectivityState state, Status error) { List activeList = filterNonFailingSubchannels(getSubchannels()); - helper.updateBalancingState(state, new Picker(activeList, error)); + helper.updateBalancingState(state, new Picker(activeList, error, stickinessState)); } /** @@ -245,6 +280,81 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { aCopy.removeAll(b); return aCopy; } + + Map> getStickinessMapForTest() { + if (stickinessState == null) { + return null; + } + return stickinessState.stickinessMap; + } + + /** + * Holds stickiness related states: The stickiness key, a registry mapping stickiness values to + * the associated Subchannel Ref, and a map from Subchannel to Subchannel Ref. + */ + private static final class StickinessState { + static final int MAX_ENTRIES = 1000; + + final Key key; + final Map> stickinessMap = + new LinkedHashMap>() { + @Override + protected boolean removeEldestEntry(Map.Entry> eldest) { + return size() > MAX_ENTRIES; + } + }; + + final Map> subchannelRefs = + new HashMap>(); + + StickinessState(@Nonnull String stickinessKey) { + this.key = Key.of(stickinessKey, Metadata.ASCII_STRING_MARSHALLER); + } + + /** + * Returns the subchannel asscoicated to the stickiness value if available in both the + * registry and the round robin list, otherwise associates the given subchannel with the + * stickiness key in the registry and returns the given subchannel. + */ + @Nonnull + synchronized Subchannel maybeRegister( + String stickinessValue, @Nonnull Subchannel subchannel, List rrList) { + Subchannel existingSubchannel = getSubchannel(stickinessValue); + if (existingSubchannel != null && rrList.contains(existingSubchannel)) { + return existingSubchannel; + } + + Ref subchannelRef = subchannelRefs.get(subchannel); + if (subchannelRef == null) { + subchannelRef = new Ref(subchannel); + subchannelRefs.put(subchannel, subchannelRef); + } + stickinessMap.put(stickinessValue, subchannelRef); + return subchannel; + } + + /** + * Unregister the subchannel from StickinessState. + */ + synchronized void remove(Subchannel subchannel) { + if (subchannelRefs.containsKey(subchannel)) { + subchannelRefs.get(subchannel).value = null; + subchannelRefs.remove(subchannel); + } + } + + /** + * Gets the subchannel associated with the stickiness value if there is. + */ + @Nullable + synchronized Subchannel getSubchannel(String stickinessValue) { + Ref subchannelRef = stickinessMap.get(stickinessValue); + if (subchannelRef != null) { + return subchannelRef.value; + } + return null; + } + } } @VisibleForTesting @@ -255,17 +365,31 @@ public final class RoundRobinLoadBalancerFactory extends LoadBalancer.Factory { @Nullable private final Status status; private final List list; + @Nullable + private final RoundRobinLoadBalancer.StickinessState stickinessState; @SuppressWarnings("unused") private volatile int index = -1; // start off at -1 so the address on first use is 0. - Picker(List list, @Nullable Status status) { + Picker( + List list, @Nullable Status status, + @Nullable RoundRobinLoadBalancer.StickinessState stickinessState) { this.list = list; this.status = status; + this.stickinessState = stickinessState; } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { if (list.size() > 0) { + if (stickinessState != null && args.getHeaders().containsKey(stickinessState.key)) { + String stickinessValue = args.getHeaders().get(stickinessState.key); + Subchannel subchannel = stickinessState.getSubchannel(stickinessValue); + if (subchannel == null || !list.contains(subchannel)) { + subchannel = stickinessState.maybeRegister(stickinessValue, nextSubchannel(), list); + } + return PickResult.withSubchannel(subchannel); + } + return PickResult.withSubchannel(nextSubchannel()); } diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 0f26c76363..6f3e5d3ed3 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -23,11 +23,16 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -46,12 +51,16 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; import io.grpc.Status; +import io.grpc.internal.GrpcAttributes; import io.grpc.util.RoundRobinLoadBalancerFactory.Picker; import io.grpc.util.RoundRobinLoadBalancerFactory.Ref; import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer; import java.net.SocketAddress; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -85,8 +94,7 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor eagCaptor; @Mock private Helper mockHelper; - @Mock - private Subchannel mockSubchannel; + @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -270,8 +278,8 @@ public class RoundRobinLoadBalancerTest { Subchannel subchannel1 = mock(Subchannel.class); Subchannel subchannel2 = mock(Subchannel.class); - Picker picker = new Picker(Collections.unmodifiableList( - Lists.newArrayList(subchannel, subchannel1, subchannel2)), null); + Picker picker = new Picker(Collections.unmodifiableList(Lists.newArrayList( + subchannel, subchannel1, subchannel2)), null /* status */, null /* stickinessState */); assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2); @@ -283,7 +291,8 @@ public class RoundRobinLoadBalancerTest { @Test public void pickerEmptyList() throws Exception { - Picker picker = new Picker(Lists.newArrayList(), Status.UNKNOWN); + Picker picker = + new Picker(Lists.newArrayList(), Status.UNKNOWN, null /* stickinessState */); assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, @@ -371,6 +380,308 @@ public class RoundRobinLoadBalancerTest { assertThat(pickers.hasNext()).isFalse(); } + @Test + public void noStickinessEnabled_withStickyHeader() { + loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue = new Metadata(); + headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); + doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + Subchannel sc3 = subchannelIterator.next(); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + assertNull(loadBalancer.getStickinessMapForTest()); + } + + @Test + public void stickinessEnabled_withoutStickyHeader() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + doReturn(new Metadata()).when(mockArgs).getHeaders(); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + Subchannel sc3 = subchannelIterator.next(); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc3, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + verify(mockArgs, times(4)).getHeaders(); + assertNotNull(loadBalancer.getStickinessMapForTest()); + assertThat(loadBalancer.getStickinessMapForTest()).isEmpty(); + } + + @Test + public void stickinessEnabled_withStickyHeader() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue = new Metadata(); + headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); + doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); + + Subchannel sc1 = loadBalancer.getSubchannels().iterator().next(); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + verify(mockArgs, atLeast(4)).getHeaders(); + assertNotNull(loadBalancer.getStickinessMapForTest()); + assertThat(loadBalancer.getStickinessMapForTest().size()).isEqualTo(1); + } + + @Test + public void stickinessEnabled_withDifferentStickyHeaders() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue1 = new Metadata(); + headerWithStickinessValue1.put(stickinessKey, "my-sticky-value"); + + Metadata headerWithStickinessValue2 = new Metadata(); + headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2"); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + + doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders(); + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + + doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders(); + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + + verify(mockArgs, atLeast(4)).getHeaders(); + assertNotNull(loadBalancer.getStickinessMapForTest()); + assertThat(loadBalancer.getStickinessMapForTest().size()).isEqualTo(2); + } + + @Test + public void stickiness_goToTransientFailure_pick_backToReady() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue = new Metadata(); + headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); + doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + + // first pick + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + // go to transient failure + loadBalancer + .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + + verify(mockHelper, times(5)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + // second pick + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + + // go back to ready + loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(6)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + // third pick + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + + verify(mockArgs, atLeast(3)).getHeaders(); + assertNotNull(loadBalancer.getStickinessMapForTest()); + assertThat(loadBalancer.getStickinessMapForTest().size()).isEqualTo(1); + } + + @Test + public void stickiness_goToTransientFailure_backToReady_pick() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue1 = new Metadata(); + headerWithStickinessValue1.put(stickinessKey, "my-sticky-value"); + doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + + // first pick + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + // go to transient failure + loadBalancer + .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + + Metadata headerWithStickinessValue2 = new Metadata(); + headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2"); + doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders(); + verify(mockHelper, times(5)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + // second pick with a different stickiness value + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + + // go back to ready + loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + + doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders(); + verify(mockHelper, times(6)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + // third pick with my-sticky-value1 + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + verify(mockArgs, atLeast(3)).getHeaders(); + assertNotNull(loadBalancer.getStickinessMapForTest()); + assertThat(loadBalancer.getStickinessMapForTest().size()).isEqualTo(2); + } + + @Test + public void stickiness_oneSubchannelShutdown() { + Map serviceConfig = new HashMap(); + serviceConfig.put("stickinessMetadataKey", "my-sticky-key"); + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes); + for (Subchannel subchannel : subchannels.values()) { + loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + verify(mockHelper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Picker picker = pickerCaptor.getValue(); + + Key stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata headerWithStickinessValue = new Metadata(); + headerWithStickinessValue.put(stickinessKey, "my-sticky-value"); + doReturn(headerWithStickinessValue).when(mockArgs).getHeaders(); + + Iterator subchannelIterator = loadBalancer.getSubchannels().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + + assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel()); + + loadBalancer + .handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + + assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value); + + assertEquals(sc2, picker.pickSubchannel(mockArgs).getSubchannel()); + assertThat(loadBalancer.getStickinessMapForTest().size()).isEqualTo(1); + verify(mockArgs, atLeast(2)).getHeaders(); + } + + @Test + public void stickiness_resolveTwice_metadataKeyChanged() { + Map serviceConfig1 = new HashMap(); + serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1"); + Attributes attributes1 = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes1); + Map stickinessMap1 = loadBalancer.getStickinessMapForTest(); + + Map serviceConfig2 = new HashMap(); + serviceConfig2.put("stickinessMetadataKey", "my-sticky-key2"); + Attributes attributes2 = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig2).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes2); + Map stickinessMap2 = loadBalancer.getStickinessMapForTest(); + + assertNotSame(stickinessMap1, stickinessMap2); + } + + @Test + public void stickiness_resolveTwice_metadataKeyUnChanged() { + Map serviceConfig1 = new HashMap(); + serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1"); + Attributes attributes1 = Attributes.newBuilder() + .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build(); + loadBalancer.handleResolvedAddressGroups(servers, attributes1); + Map stickinessMap1 = loadBalancer.getStickinessMapForTest(); + + loadBalancer.handleResolvedAddressGroups(servers, attributes1); + Map stickinessMap2 = loadBalancer.getStickinessMapForTest(); + + assertSame(stickinessMap1, stickinessMap2); + } + private static class FakeSocketAddress extends SocketAddress { final String name;