diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index e71a69559b..9f3e5991d6 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -67,6 +67,82 @@ public final class InternalProtocolNegotiators { return new TlsNegotiator(); } + /** + * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be + * negotiated, the server TLS {@code handler} is added and writes to the {@link + * io.netty.channel.Channel} may happen immediately, even before the TLS Handshake is complete. + */ + public static InternalProtocolNegotiator.ProtocolNegotiator serverTls(SslContext sslContext) { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(sslContext); + final class ServerTlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return negotiator.scheme(); + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + return negotiator.newHandler(grpcHandler); + } + + @Override + public void close() { + negotiator.close(); + } + } + + return new ServerTlsNegotiator(); + } + + /** Returns a {@link ProtocolNegotiator} for plaintext client channel. */ + public static InternalProtocolNegotiator.ProtocolNegotiator plaintext() { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.plaintext(); + final class PlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return negotiator.scheme(); + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + return negotiator.newHandler(grpcHandler); + } + + @Override + public void close() { + negotiator.close(); + } + } + + return new PlaintextNegotiator(); + } + + /** Returns a {@link ProtocolNegotiator} for plaintext server channel. */ + public static InternalProtocolNegotiator.ProtocolNegotiator serverPlaintext() { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverPlaintext(); + final class ServerPlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return negotiator.scheme(); + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + return negotiator.newHandler(grpcHandler); + } + + @Override + public void close() { + negotiator.close(); + } + } + + return new ServerPlaintextNegotiator(); + } + /** * Internal version of {@link WaitUntilActiveHandler}. */ diff --git a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java index 682b401246..2e4e5d2c05 100644 --- a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java +++ b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java @@ -16,14 +16,17 @@ package io.grpc.xds.sds; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder; import io.grpc.xds.sds.internal.SdsProtocolNegotiators; import java.net.SocketAddress; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; /** * A version of {@link ManagedChannelBuilder} to create xDS managed channels that will use SDS to @@ -34,9 +37,11 @@ public final class XdsChannelBuilder extends ForwardingChannelBuilder delegate() { return delegate; @@ -73,6 +87,8 @@ public final class XdsChannelBuilder extends ForwardingChannelBuilder { private final NettyServerBuilder delegate; + // TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS + @Nullable private DownstreamTlsContext downstreamTlsContext; + private XdsServerBuilder(NettyServerBuilder nettyDelegate) { this.delegate = nettyDelegate; } @@ -119,6 +123,15 @@ public final class XdsServerBuilder extends ServerBuilder { return this; } + /** + * Set the DownstreamTlsContext for the server. This is a temporary workaround until integration + * with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext. + */ + public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) { + this.downstreamTlsContext = downstreamTlsContext; + return this; + } + /** Creates a gRPC server builder for the given port. */ public static XdsServerBuilder forPort(int port) { NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port)); @@ -128,7 +141,8 @@ public final class XdsServerBuilder extends ServerBuilder { @Override public Server build() { // note: doing it in build() will overwrite any previously set ProtocolNegotiator - delegate.protocolNegotiator(SdsProtocolNegotiators.serverProtocolNegotiator()); + delegate.protocolNegotiator( + SdsProtocolNegotiators.serverProtocolNegotiator(this.downstreamTlsContext)); return delegate.build(); } } diff --git a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java index 9bc5129ad1..4a72ea20f0 100644 --- a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java @@ -16,15 +16,30 @@ package io.grpc.xds.sds.internal; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.grpc.Internal; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.NettyChannelBuilder; +import io.grpc.xds.sds.SecretProvider; +import io.grpc.xds.sds.TlsContextManager; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; /** * Provides client and server side gRPC {@link ProtocolNegotiator}s that use SDS to provide the SSL @@ -35,12 +50,40 @@ public final class SdsProtocolNegotiators { private static final AsciiString SCHEME = AsciiString.of("https"); + /** + * Returns a {@link ProtocolNegotiatorFactory} to be used on {@link NettyChannelBuilder}. Passing + * {@code null} for upstreamTlsContext will fall back to plaintext. + */ + // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS + public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory( + @Nullable UpstreamTlsContext upstreamTlsContext) { + return new ClientSdsProtocolNegotiatorFactory(upstreamTlsContext); + } + + /** + * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}. + * Passing {@code null} for downstreamTlsContext will fall back to plaintext. + */ + // TODO (sanjaypujare) integrate with xDS client to get LDS + public static ProtocolNegotiator serverProtocolNegotiator( + @Nullable DownstreamTlsContext downstreamTlsContext) { + return new ServerSdsProtocolNegotiator(downstreamTlsContext); + } + private static final class ClientSdsProtocolNegotiatorFactory implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { + // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS + private final UpstreamTlsContext upstreamTlsContext; + + ClientSdsProtocolNegotiatorFactory(UpstreamTlsContext upstreamTlsContext) { + this.upstreamTlsContext = upstreamTlsContext; + } + @Override public InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator() { - final ClientSdsProtocolNegotiator negotiator = new ClientSdsProtocolNegotiator(); + final ClientSdsProtocolNegotiator negotiator = + new ClientSdsProtocolNegotiator(upstreamTlsContext); final class LocalSdsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -63,7 +106,15 @@ public final class SdsProtocolNegotiators { } } - private static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator { + @VisibleForTesting + static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator { + + // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS + UpstreamTlsContext upstreamTlsContext; + + ClientSdsProtocolNegotiator(UpstreamTlsContext upstreamTlsContext) { + this.upstreamTlsContext = upstreamTlsContext; + } @Override public AsciiString scheme() { @@ -72,16 +123,111 @@ public final class SdsProtocolNegotiators { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - // TODO(sanjaypujare): once implemented return ClientSdsHandler - throw new UnsupportedOperationException("Not implemented yet"); + // once CDS is implemented we will retrieve upstreamTlsContext as follows: + // grpcHandler.getEagAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT); + if (isTlsContextEmpty(upstreamTlsContext)) { + return InternalProtocolNegotiators.plaintext().newHandler(grpcHandler); + } + return new ClientSdsHandler(grpcHandler, upstreamTlsContext); + } + + private static boolean isTlsContextEmpty(UpstreamTlsContext upstreamTlsContext) { + return upstreamTlsContext == null || !upstreamTlsContext.hasCommonTlsContext(); } @Override public void close() {} } + private static class BufferReadsHandler extends ChannelInboundHandlerAdapter { + private final List reads = new ArrayList<>(); + private boolean readComplete; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + reads.add(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + readComplete = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + for (Object msg : reads) { + super.channelRead(ctx, msg); + } + if (readComplete) { + super.channelReadComplete(ctx); + } + } + } + + @VisibleForTesting + static final class ClientSdsHandler + extends InternalProtocolNegotiators.ProtocolNegotiationHandler { + private final GrpcHttp2ConnectionHandler grpcHandler; + private final UpstreamTlsContext upstreamTlsContext; + + ClientSdsHandler( + GrpcHttp2ConnectionHandler grpcHandler, UpstreamTlsContext upstreamTlsContext) { + super( + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() + new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().remove(this); + } + }); + checkNotNull(grpcHandler, "grpcHandler"); + this.grpcHandler = grpcHandler; + this.upstreamTlsContext = upstreamTlsContext; + } + + @Override + protected void handlerAdded0(final ChannelHandlerContext ctx) { + final BufferReadsHandler bufferReads = new BufferReadsHandler(); + ctx.pipeline().addBefore(ctx.name(), null, bufferReads); + + SecretProvider sslContextProvider = + TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext); + + sslContextProvider.addCallback( + new SecretProvider.Callback() { + + @Override + public void updateSecret(SslContext sslContext) { + ChannelHandler handler = + InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + + // Delegate rest of handshake to TLS handler + ctx.pipeline().addAfter(ctx.name(), null, handler); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(bufferReads); + } + + @Override + public void onException(Throwable throwable) { + ctx.fireExceptionCaught(throwable); + } + }, + ctx.executor()); + } + } + private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator { + // TODO (sanjaypujare) integrate with xDS client to get LDS. LDS watcher will + // inject/update the downstreamTlsContext from LDS + private DownstreamTlsContext downstreamTlsContext; + + ServerSdsProtocolNegotiator(DownstreamTlsContext downstreamTlsContext) { + this.downstreamTlsContext = downstreamTlsContext; + } + @Override public AsciiString scheme() { return SCHEME; @@ -89,22 +235,72 @@ public final class SdsProtocolNegotiators { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - // TODO(sanjaypujare): once implemented return ServerSdsHandler - throw new UnsupportedOperationException("Not implemented yet"); + if (isTlsContextEmpty(downstreamTlsContext)) { + return InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHandler); + } + return new ServerSdsHandler(grpcHandler, downstreamTlsContext); + } + + private static boolean isTlsContextEmpty(DownstreamTlsContext downstreamTlsContext) { + return downstreamTlsContext == null || !downstreamTlsContext.hasCommonTlsContext(); } @Override public void close() {} } - /** Sets the {@link ProtocolNegotiatorFactory} on a NettyChannelBuilder. */ - public static void setProtocolNegotiatorFactory(NettyChannelBuilder builder) { - InternalNettyChannelBuilder.setProtocolNegotiatorFactory( - builder, new ClientSdsProtocolNegotiatorFactory()); - } + @VisibleForTesting + static final class ServerSdsHandler + extends InternalProtocolNegotiators.ProtocolNegotiationHandler { + private final GrpcHttp2ConnectionHandler grpcHandler; + private final DownstreamTlsContext downstreamTlsContext; - /** Creates an SDS based {@link ProtocolNegotiator} for a server. */ - public static ProtocolNegotiator serverProtocolNegotiator() { - return new ServerSdsProtocolNegotiator(); + ServerSdsHandler( + GrpcHttp2ConnectionHandler grpcHandler, DownstreamTlsContext downstreamTlsContext) { + super( + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() + new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().remove(this); + } + }); + checkNotNull(grpcHandler, "grpcHandler"); + this.grpcHandler = grpcHandler; + this.downstreamTlsContext = downstreamTlsContext; + } + + @Override + protected void handlerAdded0(final ChannelHandlerContext ctx) { + final BufferReadsHandler bufferReads = new BufferReadsHandler(); + ctx.pipeline().addBefore(ctx.name(), null, bufferReads); + + SecretProvider sslContextProvider = + TlsContextManager.getInstance() + .findOrCreateServerSslContextProvider(downstreamTlsContext); + + sslContextProvider.addCallback( + new SecretProvider.Callback() { + + @Override + public void updateSecret(SslContext sslContext) { + ChannelHandler handler = + InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); + + // Delegate rest of handshake to TLS handler + ctx.pipeline().addAfter(ctx.name(), null, handler); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(bufferReads); + } + + @Override + public void onException(Throwable throwable) { + ctx.fireExceptionCaught(throwable); + } + }, + ctx.executor()); + } } } diff --git a/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java new file mode 100644 index 0000000000..5a2687ba85 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java @@ -0,0 +1,193 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.envoyproxy.envoy.api.v2.core.DataSource; +import io.grpc.Server; +import io.grpc.internal.testing.TestUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.IOException; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS + * modes. + */ +@RunWith(JUnit4.class) +public class XdsSdsClientServerTest { + + @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + @Test + public void plaintextClientServer() throws IOException { + Server server = getXdsServer(/* downstreamTlsContext= */ null); + buildClientAndTest( + /* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort()); + } + + /** TLS channel - no mTLS. */ + @Test + public void tlsClientServer_noClientAuthentication() throws IOException { + String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath(); + String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath(); + + TlsCertificate tlsCert = + TlsCertificate.newBuilder() + .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build()) + .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build()) + .build(); + + CommonTlsContext commonTlsContext = + CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build(); + + DownstreamTlsContext downstreamTlsContext = + DownstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setRequireClientCertificate(BoolValue.of(false)) + .build(); + + Server server = getXdsServer(downstreamTlsContext); + + // for TLS client doesn't need cert but needs trustCa + String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath(); + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build()) + .build(); + + CommonTlsContext commonTlsContext1 = + CommonTlsContext.newBuilder().setValidationContext(certContext).build(); + + UpstreamTlsContext upstreamTlsContext = + UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build(); + buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort()); + } + + /** mTLS - client auth enabled. */ + @Test + public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException { + String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath(); + String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath(); + String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath(); + + TlsCertificate tlsCert = + TlsCertificate.newBuilder() + .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build()) + .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build()) + .build(); + + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build()) + .build(); + + CommonTlsContext commonTlsContext = + CommonTlsContext.newBuilder() + .addTlsCertificates(tlsCert) + .setValidationContext(certContext) + .build(); + + DownstreamTlsContext downstreamTlsContext = + DownstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setRequireClientCertificate(BoolValue.of(false)) + .build(); + + Server server = getXdsServer(downstreamTlsContext); + + String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath(); + String clientKey = TestUtils.loadCert("client.key").getAbsolutePath(); + + TlsCertificate tlsCert1 = + TlsCertificate.newBuilder() + .setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build()) + .setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build()) + .build(); + + CommonTlsContext commonTlsContext1 = + CommonTlsContext.newBuilder() + .addTlsCertificates(tlsCert1) + .setValidationContext(certContext) + .build(); + + UpstreamTlsContext upstreamTlsContext = + UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build(); + + buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort()); + } + + private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException { + XdsServerBuilder serverBuilder = + XdsServerBuilder.forPort(0) // get unused port + .addService(new SimpleServiceImpl()) + .tlsContext(downstreamTlsContext); + return cleanupRule.register(serverBuilder.build()).start(); + } + + private void buildClientAndTest( + UpstreamTlsContext upstreamTlsContext, + String overrideAuthority, + String requestMessage, + int serverPort) { + + XdsChannelBuilder builder = + XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext); + if (overrideAuthority != null) { + builder = builder.overrideAuthority(overrideAuthority); + } + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build())); + String resp = unaryRpc(requestMessage, blockingStub); + assertThat(resp).isEqualTo("Hello " + requestMessage); + } + + /** Say hello to server. */ + private static String unaryRpc( + String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { + SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); + SimpleResponse response = blockingStub.unaryRpc(request); + return response.getResponseMessage(); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + + @Override + public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder() + .setResponseMessage("Hello " + req.getRequestMessage()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java new file mode 100644 index 0000000000..ea3c65364e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java @@ -0,0 +1,263 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Strings; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.envoyproxy.envoy.api.v2.core.DataSource; +import io.grpc.internal.testing.TestUtils; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsHandler; +import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsProtocolNegotiator; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link SdsProtocolNegotiators}. */ +@RunWith(JUnit4.class) +public class SdsProtocolNegotiatorsTest { + + private static final String SERVER_1_PEM_FILE = "server1.pem"; + private static final String SERVER_1_KEY_FILE = "server1.key"; + private static final String CLIENT_PEM_FILE = "client.pem"; + private static final String CLIENT_KEY_FILE = "client.key"; + private static final String CA_PEM_FILE = "ca.pem"; + + private final GrpcHttp2ConnectionHandler grpcHandler = + FakeGrpcHttp2ConnectionHandler.newHandler(); + + private EmbeddedChannel channel = new EmbeddedChannel(); + private ChannelPipeline pipeline = channel.pipeline(); + private ChannelHandlerContext channelHandlerCtx; + + private static String getTempFileNameForResourcesFile(String resFile) throws IOException { + return Strings.isNullOrEmpty(resFile) ? null : TestUtils.loadCert(resFile).getAbsolutePath(); + } + + /** Builds DownstreamTlsContext from file-names. */ + private static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( + String privateKey, String certChain, String trustCa) throws IOException { + return buildDownstreamTlsContext( + buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); + } + + /** Builds UpstreamTlsContext from file-names. */ + private static UpstreamTlsContext buildUpstreamTlsContextFromFilenames( + String privateKey, String certChain, String trustCa) throws IOException { + return buildUpstreamTlsContext( + buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); + } + + /** Builds UpstreamTlsContext from commonTlsContext. */ + private static UpstreamTlsContext buildUpstreamTlsContext(CommonTlsContext commonTlsContext) { + UpstreamTlsContext upstreamTlsContext = + UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + return upstreamTlsContext; + } + + /** Builds DownstreamTlsContext from commonTlsContext. */ + private static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) { + DownstreamTlsContext downstreamTlsContext = + DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + return downstreamTlsContext; + } + + private static CommonTlsContext buildCommonTlsContextFromFilenames( + String privateKey, String certChain, String trustCa) throws IOException { + TlsCertificate tlsCert = null; + privateKey = getTempFileNameForResourcesFile(privateKey); + certChain = getTempFileNameForResourcesFile(certChain); + trustCa = getTempFileNameForResourcesFile(trustCa); + if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) { + tlsCert = + TlsCertificate.newBuilder() + .setCertificateChain(DataSource.newBuilder().setFilename(certChain)) + .setPrivateKey(DataSource.newBuilder().setFilename(privateKey)) + .build(); + } + CertificateValidationContext certContext = null; + if (!Strings.isNullOrEmpty(trustCa)) { + certContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) + .build(); + } + return getCommonTlsContext(tlsCert, certContext); + } + + private static CommonTlsContext getCommonTlsContext( + TlsCertificate tlsCertificate, CertificateValidationContext certContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (tlsCertificate != null) { + builder = builder.addTlsCertificates(tlsCertificate); + } + if (certContext != null) { + builder = builder.setValidationContext(certContext); + } + return builder.build(); + } + + @Test + public void clientSdsProtocolNegotiatorNewHandler_nullTlsContext() { + ClientSdsProtocolNegotiator pn = + new ClientSdsProtocolNegotiator(/* upstreamTlsContext= */ null); + ChannelHandler newHandler = pn.newHandler(grpcHandler); + assertThat(newHandler).isNotNull(); + // ProtocolNegotiators.WaitUntilActiveHandler not accessible, get canonical name + assertThat(newHandler.getClass().getCanonicalName()) + .contains("io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler"); + } + + @Test + public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() { + UpstreamTlsContext upstreamTlsContext = + buildUpstreamTlsContext(getCommonTlsContext(null, null)); + ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext); + ChannelHandler newHandler = pn.newHandler(grpcHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSdsHandler.class); + } + + @Test + public void clientSdsHandler_addLast() throws IOException { + UpstreamTlsContext upstreamTlsContext = + buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + + SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = + new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext); + pipeline.addLast(clientSdsHandler); + channelHandlerCtx = pipeline.context(clientSdsHandler); + assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it + + // kick off protocol negotiation. + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + channelHandlerCtx = pipeline.context(clientSdsHandler); + assertThat(channelHandlerCtx).isNull(); + + // pipeline should have SslHandler and ClientTlsHandler + Iterator> iterator = pipeline.iterator(); + assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class); + // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name + assertThat(iterator.next().getValue().getClass().getCanonicalName()) + .contains("ProtocolNegotiators.ClientTlsHandler"); + } + + @Test + public void serverSdsHandler_addLast() throws IOException { + DownstreamTlsContext downstreamTlsContext = + buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); + + SdsProtocolNegotiators.ServerSdsHandler serverSdsHandler = + new SdsProtocolNegotiators.ServerSdsHandler(grpcHandler, downstreamTlsContext); + pipeline.addLast(serverSdsHandler); + channelHandlerCtx = pipeline.context(serverSdsHandler); + assertNotNull(channelHandlerCtx); // serverSdsHandler ctx is non-null since we just added it + + // kick off protocol negotiation + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + channelHandlerCtx = pipeline.context(serverSdsHandler); + assertThat(channelHandlerCtx).isNull(); + + // pipeline should have SslHandler and ServerTlsHandler + Iterator> iterator = pipeline.iterator(); + assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class); + // ProtocolNegotiators.ServerTlsHandler.class is not accessible, get canonical name + assertThat(iterator.next().getValue().getClass().getCanonicalName()) + .contains("ProtocolNegotiators.ServerTlsHandler"); + } + + @Test + public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() + throws IOException, InterruptedException { + UpstreamTlsContext upstreamTlsContext = + buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + + SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = + new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext); + + pipeline.addLast(clientSdsHandler); + channelHandlerCtx = pipeline.context(clientSdsHandler); + assertNotNull(channelHandlerCtx); // non-null since we just added it + + // kick off protocol negotiation. + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + channelHandlerCtx = pipeline.context(clientSdsHandler); + assertThat(channelHandlerCtx).isNull(); + Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; + + pipeline.fireUserEventTriggered(sslEvent); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + assertTrue(channel.isOpen()); + } + + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { + + FakeGrpcHttp2ConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(channelUnused, decoder, encoder, initialSettings); + } + + static FakeGrpcHttp2ConnectionHandler newHandler() { + DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); + DefaultHttp2ConnectionEncoder encoder = + new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); + Http2Settings settings = new Http2Settings(); + return new FakeGrpcHttp2ConnectionHandler( + /*channelUnused=*/ null, decoder, encoder, settings); + } + + @Override + public String getAuthority() { + return "authority"; + } + } +}