xds: implement requireClientCertificate semantics (#6948)

This commit is contained in:
sanjaypujare 2020-04-20 17:04:38 -07:00 committed by GitHub
parent 54cac75d47
commit a649737e3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 133 additions and 20 deletions

View File

@ -234,9 +234,7 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
tlsCertificate.hasPassword() tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString() ? tlsCertificate.getPassword().getInlineString()
: null); : null);
if (localCertValidationContext != null) { setClientAuthValues(sslContextBuilder, localCertValidationContext);
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
}
} else { } else {
logger.log(Level.FINEST, "for client"); logger.log(Level.FINEST, "for client");
sslContextBuilder = sslContextBuilder =

View File

@ -205,9 +205,7 @@ final class SecretVolumeSslContextProvider<K> extends SslContextProvider<K> {
sslContextBuilder = sslContextBuilder =
GrpcSslContexts.forServer( GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword); new File(certificateChain), new File(privateKey), privateKeyPassword);
if (certContext != null) { setClientAuthValues(sslContextBuilder, certContext);
sslContextBuilder.trustManager(new SdsTrustManagerFactory(certContext));
}
} else { } else {
sslContextBuilder = sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext)); GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));

View File

@ -16,12 +16,21 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -32,6 +41,8 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based * stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic. * secret(s) that are dynamic.
*/ */
// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in
// separate client&server classes
public abstract class SslContextProvider<K> { public abstract class SslContextProvider<K> {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName()); private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
@ -48,7 +59,11 @@ public abstract class SslContextProvider<K> {
} }
protected SslContextProvider(K source, boolean server) { protected SslContextProvider(K source, boolean server) {
checkNotNull(source, "source"); if (server) {
checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext");
} else {
checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext");
}
this.source = source; this.source = source;
this.server = server; this.server = server;
} }
@ -66,6 +81,22 @@ public abstract class SslContextProvider<K> {
return null; return null;
} }
protected void setClientAuthValues(
SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
checkState(server, "server side SslContextProvider expected");
if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource();
sslContextBuilder.clientAuth(
downstreamTlsContext.hasRequireClientCertificate()
? ClientAuth.REQUIRE
: ClientAuth.OPTIONAL);
} else {
sslContextBuilder.clientAuth(ClientAuth.NONE);
}
}
/** Closes this provider and releases any resources. */ /** Closes this provider and releases any resources. */
void close() {} void close() {}

View File

@ -18,6 +18,8 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch; import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
@ -103,6 +105,60 @@ public class XdsSdsClientServerTest {
assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy");
} }
@Test
public void requireClientAuth_noClientCert_expectException()
throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
buildServerWithTlsContext(downstreamTlsContext);
// for TLS, client only uses trustCa
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
try {
unaryRpc(/* requestMessage= */ "buddy", blockingStub);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE");
}
}
@Test
public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
buildServerWithTlsContext(downstreamTlsContext);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}
@Test
public void mtls_badClientCert_expectException() throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
try {
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE");
}
}
/** mTLS - client auth enabled. */ /** mTLS - client auth enabled. */
@Test @Test
public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException { public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException {
@ -178,7 +234,7 @@ public class XdsSdsClientServerTest {
private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher( private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException { UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = final XdsClientWrapperForServerSds xdsClientWrapperForServerSds =

View File

@ -17,6 +17,7 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.protobuf.BoolValue;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
@ -134,12 +135,14 @@ public class CommonTlsContextTestsUtil {
return builder.build(); return builder.build();
} }
/** /** Helper method to build DownstreamTlsContext for multiple test classes. */
* Helper method to build DownstreamTlsContext for multiple test classes. static DownstreamTlsContext buildDownstreamTlsContext(
*/ CommonTlsContext commonTlsContext, boolean requireClientCert) {
static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(requireClientCert))
.build();
return downstreamTlsContext; return downstreamTlsContext;
} }
@ -159,7 +162,8 @@ public class CommonTlsContextTestsUtil {
"unix:/var/run/sds/uds_path", "unix:/var/run/sds/uds_path",
Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"), Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
Arrays.asList("managed-tls"), Arrays.asList("managed-tls"),
null)); null),
/* requireClientCert= */ false);
} }
static String getTempFileNameForResourcesFile(String resFile) throws IOException { static String getTempFileNameForResourcesFile(String resFile) throws IOException {
@ -171,6 +175,27 @@ public class CommonTlsContextTestsUtil {
*/ */
public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) { @Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
false);
}
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
public static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
@Nullable String privateKey,
@Nullable String certChain,
@Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
true);
}
private static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientAuth(
@Nullable String privateKey,
@Nullable String certChain,
@Nullable String trustCa,
boolean requireClientCert) {
// get temp file for each file // get temp file for each file
try { try {
if (certChain != null) { if (certChain != null) {
@ -186,7 +211,7 @@ public class CommonTlsContextTestsUtil {
throw new RuntimeException(ioe); throw new RuntimeException(ioe);
} }
return buildDownstreamTlsContext( return buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa), requireClientCert);
} }
/** /**

View File

@ -79,7 +79,8 @@ public class SdsSslContextProviderTest {
return server return server
? SdsSslContextProvider.getProviderForServer( ? SdsSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext), CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node, node,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
MoreExecutors.directExecutor()) MoreExecutors.directExecutor())

View File

@ -274,7 +274,8 @@ public class SecretVolumeSslContextProviderTest {
try { try {
SecretVolumeSslContextProvider.getProviderForServer( SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null))); CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null),
/* requireClientCert= */ false));
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected"); assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -295,7 +296,8 @@ public class SecretVolumeSslContextProviderTest {
try { try {
SecretVolumeSslContextProvider.getProviderForServer( SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext),
/* requireClientCert= */ false));
Assert.fail("no exception thrown"); Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) { } catch (IllegalArgumentException expected) {
assertThat(expected.getMessage()).isEqualTo("filename expected"); assertThat(expected.getMessage()).isEqualTo("filename expected");

View File

@ -52,7 +52,8 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate( CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate(
"name", "unix:/tmp/sds/path", CA_PEM_FILE); "name", "unix:/tmp/sds/path", CA_PEM_FILE);
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext); CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try { try {
SslContextProvider<DownstreamTlsContext> unused = SslContextProvider<DownstreamTlsContext> unused =
@ -71,7 +72,8 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext( CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
"name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); "name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext); CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try { try {
SslContextProvider<DownstreamTlsContext> unused = SslContextProvider<DownstreamTlsContext> unused =