From a649737e3a238f0bf84b0f9e73d5b06e5ba016e0 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Mon, 20 Apr 2020 17:04:38 -0700 Subject: [PATCH] xds: implement requireClientCertificate semantics (#6948) --- .../internal/sds/SdsSslContextProvider.java | 4 +- .../sds/SecretVolumeSslContextProvider.java | 4 +- .../xds/internal/sds/SslContextProvider.java | 33 ++++++++++- .../io/grpc/xds/XdsSdsClientServerTest.java | 58 ++++++++++++++++++- .../sds/CommonTlsContextTestsUtil.java | 39 ++++++++++--- .../sds/SdsSslContextProviderTest.java | 3 +- .../SecretVolumeSslContextProviderTest.java | 6 +- .../ServerSslContextProviderFactoryTest.java | 6 +- 8 files changed, 133 insertions(+), 20 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java index 88a0a1dae1..4af26f6ee8 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsSslContextProvider.java @@ -234,9 +234,7 @@ final class SdsSslContextProvider extends SslContextProvider tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null); - if (localCertValidationContext != null) { - sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext)); - } + setClientAuthValues(sslContextBuilder, localCertValidationContext); } else { logger.log(Level.FINEST, "for client"); sslContextBuilder = diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java index d843b3bf58..c545c81979 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProvider.java @@ -205,9 +205,7 @@ final class SecretVolumeSslContextProvider extends SslContextProvider { sslContextBuilder = GrpcSslContexts.forServer( new File(certificateChain), new File(privateKey), privateKeyPassword); - if (certContext != null) { - sslContextBuilder.trustManager(new SdsTrustManagerFactory(certContext)); - } + setClientAuthValues(sslContextBuilder, certContext); } else { sslContextBuilder = GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext)); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java index e8b4f22252..4ec4e6cefa 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java @@ -16,12 +16,21 @@ package io.grpc.xds.internal.sds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.IOException; +import java.security.cert.CertStoreException; +import java.security.cert.CertificateException; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -32,6 +41,8 @@ import java.util.logging.Logger; * stream that is receiving the requested secret(s) or it could represent file-system based * secret(s) that are dynamic. */ +// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in +// separate client&server classes public abstract class SslContextProvider { private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); @@ -48,7 +59,11 @@ public abstract class SslContextProvider { } protected SslContextProvider(K source, boolean server) { - checkNotNull(source, "source"); + if (server) { + checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext"); + } else { + checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext"); + } this.source = source; this.server = server; } @@ -66,6 +81,22 @@ public abstract class SslContextProvider { return null; } + protected void setClientAuthValues( + SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext) + throws CertificateException, IOException, CertStoreException { + checkState(server, "server side SslContextProvider expected"); + if (localCertValidationContext != null) { + sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext)); + DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource(); + sslContextBuilder.clientAuth( + downstreamTlsContext.hasRequireClientCertificate() + ? ClientAuth.REQUIRE + : ClientAuth.OPTIONAL); + } else { + sslContextBuilder.clientAuth(ClientAuth.NONE); + } + } + /** Closes this provider and releases any resources. */ void close() {} diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index d8de48790e..8156e04954 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -18,6 +18,8 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; @@ -103,6 +105,60 @@ public class XdsSdsClientServerTest { assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); } + @Test + public void requireClientAuth_noClientCert_expectException() + throws IOException, URISyntaxException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); + buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only uses trustCa + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + try { + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class); + assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE"); + } + } + + @Test + public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + @Test + public void mtls_badClientCert_expectException() throws IOException, URISyntaxException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); + try { + XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class); + assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE"); + } + } + /** mTLS - client auth enabled. */ @Test public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException { @@ -178,7 +234,7 @@ public class XdsSdsClientServerTest { private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher( UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException { DownstreamTlsContext downstreamTlsContext = - CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 0f1bb235d3..e545ea5f43 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -17,6 +17,7 @@ package io.grpc.xds.internal.sds; import com.google.common.base.Strings; +import com.google.protobuf.BoolValue; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext; @@ -134,12 +135,14 @@ public class CommonTlsContextTestsUtil { return builder.build(); } - /** - * Helper method to build DownstreamTlsContext for multiple test classes. - */ - static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) { + /** Helper method to build DownstreamTlsContext for multiple test classes. */ + static DownstreamTlsContext buildDownstreamTlsContext( + CommonTlsContext commonTlsContext, boolean requireClientCert) { DownstreamTlsContext downstreamTlsContext = - DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + DownstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setRequireClientCertificate(BoolValue.of(requireClientCert)) + .build(); return downstreamTlsContext; } @@ -159,7 +162,8 @@ public class CommonTlsContextTestsUtil { "unix:/var/run/sds/uds_path", Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"), Arrays.asList("managed-tls"), - null)); + null), + /* requireClientCert= */ false); } static String getTempFileNameForResourcesFile(String resFile) throws IOException { @@ -171,6 +175,27 @@ public class CommonTlsContextTestsUtil { */ public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( @Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) { + return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa, + false); + } + + /** + * Helper method to build DownstreamTlsContext for above tests. Called from other classes as well. + */ + public static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientCertRequired( + @Nullable String privateKey, + @Nullable String certChain, + @Nullable String trustCa) { + + return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa, + true); + } + + private static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientAuth( + @Nullable String privateKey, + @Nullable String certChain, + @Nullable String trustCa, + boolean requireClientCert) { // get temp file for each file try { if (certChain != null) { @@ -186,7 +211,7 @@ public class CommonTlsContextTestsUtil { throw new RuntimeException(ioe); } return buildDownstreamTlsContext( - buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); + buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa), requireClientCert); } /** diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java index db8e13d525..18759eb282 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java @@ -79,7 +79,8 @@ public class SdsSslContextProviderTest { return server ? SdsSslContextProvider.getProviderForServer( - CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext), + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + commonTlsContext, /* requireClientCert= */ false), node, MoreExecutors.directExecutor(), MoreExecutors.directExecutor()) diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java index 6d5e5bdc83..86c4b816e5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java @@ -274,7 +274,8 @@ public class SecretVolumeSslContextProviderTest { try { SecretVolumeSslContextProvider.getProviderForServer( CommonTlsContextTestsUtil.buildDownstreamTlsContext( - CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null))); + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null), + /* requireClientCert= */ false)); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -295,7 +296,8 @@ public class SecretVolumeSslContextProviderTest { try { SecretVolumeSslContextProvider.getProviderForServer( CommonTlsContextTestsUtil.buildDownstreamTlsContext( - CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext), + /* requireClientCert= */ false)); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected.getMessage()).isEqualTo("filename expected"); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index e21bc55f2c..ad8542f9d7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -52,7 +52,8 @@ public class ServerSslContextProviderFactoryTest { CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate( "name", "unix:/tmp/sds/path", CA_PEM_FILE); DownstreamTlsContext downstreamTlsContext = - CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext); + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + commonTlsContext, /* requireClientCert= */ false); try { SslContextProvider unused = @@ -71,7 +72,8 @@ public class ServerSslContextProviderFactoryTest { CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext( "name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); DownstreamTlsContext downstreamTlsContext = - CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext); + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + commonTlsContext, /* requireClientCert= */ false); try { SslContextProvider unused =