From 29753f2009822f806f48771272a920c128f1e8fb Mon Sep 17 00:00:00 2001 From: Chengyuan Zhang Date: Tue, 19 Jan 2021 10:16:06 -0800 Subject: [PATCH] xds: google_default should use TLS if address contains no cluster name (#7818) Fixes bug introduced by 4130c5a1b8016656081258f1ceff393b59b9c7b2. TLS should be selected for addresses without cluster name attributes, even if grpc-xds is in classpath. --- .../alts/internal/AltsProtocolNegotiator.java | 9 +- .../GoogleDefaultProtocolNegotiatorTest.java | 202 ++++++++++-------- 2 files changed, 122 insertions(+), 89 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index 0b4a90748d..7c4502e3d3 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -247,8 +247,13 @@ public final class AltsProtocolNegotiator { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler); ChannelHandler securityHandler; - boolean isXdsDirectPath = clusterNameAttrKey != null - && !"google_cfe".equals(grpcHandler.getEagAttributes().get(clusterNameAttrKey)); + boolean isXdsDirectPath = false; + if (clusterNameAttrKey != null) { + String clusterName = grpcHandler.getEagAttributes().get(clusterNameAttrKey); + if (clusterName != null && !clusterName.equals("google_cfe")) { + isXdsDirectPath = true; + } + } if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null || isXdsDirectPath) { diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index bc01f83102..5ac2669b3a 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -36,110 +36,138 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.ssl.SslContext; -import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameters; +import org.junit.runners.JUnit4; -@RunWith(Parameterized.class) +@RunWith(Enclosed.class) public final class GoogleDefaultProtocolNegotiatorTest { - @Parameterized.Parameter - public boolean withXds; - private ProtocolNegotiator googleProtocolNegotiator; + @RunWith(JUnit4.class) + public abstract static class HandlerSelectionTest { + private ProtocolNegotiator googleProtocolNegotiator; + private final ObjectPool handshakerChannelPool = new ObjectPool() { - // Same as io.grpc.xds.InternalXdsAttributes.ATTR_CLUSTER_NAME - private final Attributes.Key clusterNameAttrKey = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); - private final ObjectPool handshakerChannelPool = new ObjectPool() { - - @Override - public Channel getObject() { - return InProcessChannelBuilder.forName("test").build(); - } - - @Override - public Channel returnObject(Object object) { - ((ManagedChannel) object).shutdownNow(); - return null; - } - }; - - @Parameters(name = "Run with xDS : {0}") - public static Iterable data() { - return Arrays.asList(true, false); - } - - @Before - public void setUp() throws Exception { - SslContext sslContext = GrpcSslContexts.forClient().build(); - - googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory( - ImmutableList.of(), - handshakerChannelPool, - sslContext, - withXds ? clusterNameAttrKey : null) - .newNegotiator(); - } - - @After - public void tearDown() { - googleProtocolNegotiator.close(); - } - - @Test - public void altsHandler() { - Attributes eagAttributes; - if (withXds) { - eagAttributes = - Attributes.newBuilder().set(clusterNameAttrKey, "api.googleapis.com").build(); - } else { - eagAttributes = - Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); - } - GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); - when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); - - final AtomicReference failure = new AtomicReference<>(); - ChannelHandler exceptionCaught = new ChannelInboundHandlerAdapter() { @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - failure.set(cause); - super.exceptionCaught(ctx, cause); + public Channel getObject() { + return InProcessChannelBuilder.forName("test").build(); + } + + @Override + public Channel returnObject(Object object) { + ((ManagedChannel) object).shutdownNow(); + return null; } }; - ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); - EmbeddedChannel chan = new EmbeddedChannel(exceptionCaught); - // Add the negotiator handler last, but to the front. Putting this in ctor above would make it - // throw early. - chan.pipeline().addFirst(h); - chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - // Check that the message complained about the ALTS code, rather than SSL. ALTS throws on - // being added, so it's hard to catch it at the right time to make this assertion. - assertThat(failure.get()).hasMessageThat().contains("TsiHandshakeHandler"); + @Before + public void setUp() throws Exception { + SslContext sslContext = GrpcSslContexts.forClient().build(); + + googleProtocolNegotiator = new AltsProtocolNegotiator.GoogleDefaultProtocolNegotiatorFactory( + ImmutableList.of(), + handshakerChannelPool, + sslContext, + getClusterNameAttrKey()) + .newNegotiator(); + } + + @After + public void tearDown() { + googleProtocolNegotiator.close(); + } + + @Nullable + abstract Attributes.Key getClusterNameAttrKey(); + + @Test + public void altsHandler_lbProvidedBackend() { + Attributes attrs = + Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); + subtest_altsHandler(attrs); + } + + @Test + public void tlsHandler_emptyAttributes() { + subtest_tlsHandler(Attributes.EMPTY); + } + + void subtest_altsHandler(Attributes eagAttributes) { + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); + + final AtomicReference failure = new AtomicReference<>(); + ChannelHandler exceptionCaught = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + failure.set(cause); + super.exceptionCaught(ctx, cause); + } + }; + ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); + EmbeddedChannel chan = new EmbeddedChannel(exceptionCaught); + // Add the negotiator handler last, but to the front. Putting this in ctor above would make + // it throw early. + chan.pipeline().addFirst(h); + chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + + // Check that the message complained about the ALTS code, rather than SSL. ALTS throws on + // being added, so it's hard to catch it at the right time to make this assertion. + assertThat(failure.get()).hasMessageThat().contains("TsiHandshakeHandler"); + } + + void subtest_tlsHandler(Attributes eagAttributes) { + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); + when(mockHandler.getAuthority()).thenReturn("authority"); + + ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); + EmbeddedChannel chan = new EmbeddedChannel(h); + chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + + assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler"); + } } - @Test - public void tlsHandler() { - Attributes eagAttributes; - if (withXds) { - eagAttributes = Attributes.newBuilder().set(clusterNameAttrKey, "google_cfe").build(); - } else { - eagAttributes = Attributes.EMPTY; + @RunWith(JUnit4.class) + public static class WithoutXdsInClasspath extends HandlerSelectionTest { + + @Nullable + @Override + Attributes.Key getClusterNameAttrKey() { + return null; } - GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); - when(mockHandler.getEagAttributes()).thenReturn(eagAttributes); - when(mockHandler.getAuthority()).thenReturn("authority"); + } - ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); - EmbeddedChannel chan = new EmbeddedChannel(h); - chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + @RunWith(JUnit4.class) + public static class WithXdsInClasspath extends HandlerSelectionTest { + // Same as io.grpc.xds.InternalXdsAttributes.ATTR_CLUSTER_NAME + private static final Attributes.Key XDS_CLUSTER_NAME_ATTR_KEY = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); - assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler"); + @Nullable + @Override + Attributes.Key getClusterNameAttrKey() { + return XDS_CLUSTER_NAME_ATTR_KEY; + } + + @Test + public void altsHandler_xdsCluster() { + Attributes attrs = + Attributes.newBuilder().set(XDS_CLUSTER_NAME_ATTR_KEY, "api.googleapis.com").build(); + subtest_altsHandler(attrs); + } + + @Test + public void tlsHandler_googleCfe() { + Attributes attrs = + Attributes.newBuilder().set(XDS_CLUSTER_NAME_ATTR_KEY, "google_cfe").build(); + subtest_tlsHandler(attrs); + } } }