xds: implement requireClientCertificate semantics (#6948)

This commit is contained in:
sanjaypujare 2020-04-20 17:04:38 -07:00 committed by GitHub
parent 54cac75d47
commit a649737e3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 133 additions and 20 deletions

View File

@ -234,9 +234,7 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
}
setClientAuthValues(sslContextBuilder, localCertValidationContext);
} else {
logger.log(Level.FINEST, "for client");
sslContextBuilder =

View File

@ -205,9 +205,7 @@ final class SecretVolumeSslContextProvider<K> extends SslContextProvider<K> {
sslContextBuilder =
GrpcSslContexts.forServer(
new File(certificateChain), new File(privateKey), privateKeyPassword);
if (certContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(certContext));
}
setClientAuthValues(sslContextBuilder, certContext);
} else {
sslContextBuilder =
GrpcSslContexts.forClient().trustManager(new SdsTrustManagerFactory(certContext));

View File

@ -16,12 +16,21 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -32,6 +41,8 @@ import java.util.logging.Logger;
* stream that is receiving the requested secret(s) or it could represent file-system based
* secret(s) that are dynamic.
*/
// TODO(sanjaypujare): replace generic K with DownstreamTlsContext & UpstreamTlsContext in
// separate client&server classes
public abstract class SslContextProvider<K> {
private static final Logger logger = Logger.getLogger(SslContextProvider.class.getName());
@ -48,7 +59,11 @@ public abstract class SslContextProvider<K> {
}
protected SslContextProvider(K source, boolean server) {
checkNotNull(source, "source");
if (server) {
checkArgument(source instanceof DownstreamTlsContext, "expecting DownstreamTlsContext");
} else {
checkArgument(source instanceof UpstreamTlsContext, "expecting UpstreamTlsContext");
}
this.source = source;
this.server = server;
}
@ -66,6 +81,22 @@ public abstract class SslContextProvider<K> {
return null;
}
protected void setClientAuthValues(
SslContextBuilder sslContextBuilder, CertificateValidationContext localCertValidationContext)
throws CertificateException, IOException, CertStoreException {
checkState(server, "server side SslContextProvider expected");
if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
DownstreamTlsContext downstreamTlsContext = (DownstreamTlsContext)getSource();
sslContextBuilder.clientAuth(
downstreamTlsContext.hasRequireClientCertificate()
? ClientAuth.REQUIRE
: ClientAuth.OPTIONAL);
} else {
sslContextBuilder.clientAuth(ClientAuth.NONE);
}
}
/** Closes this provider and releases any resources. */
void close() {}

View File

@ -18,6 +18,8 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
@ -103,6 +105,60 @@ public class XdsSdsClientServerTest {
assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy");
}
@Test
public void requireClientAuth_noClientCert_expectException()
throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
buildServerWithTlsContext(downstreamTlsContext);
// for TLS, client only uses trustCa
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
try {
unaryRpc(/* requestMessage= */ "buddy", blockingStub);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE");
}
}
@Test
public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
buildServerWithTlsContext(downstreamTlsContext);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}
@Test
public void mtls_badClientCert_expectException() throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
try {
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
assertThat(sre).hasCauseThat().hasMessageThat().contains("HANDSHAKE_FAILURE");
}
}
/** mTLS - client auth enabled. */
@Test
public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException {
@ -178,7 +234,7 @@ public class XdsSdsClientServerTest {
private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
final XdsClientWrapperForServerSds xdsClientWrapperForServerSds =

View File

@ -17,6 +17,7 @@
package io.grpc.xds.internal.sds;
import com.google.common.base.Strings;
import com.google.protobuf.BoolValue;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
@ -134,12 +135,14 @@ public class CommonTlsContextTestsUtil {
return builder.build();
}
/**
* Helper method to build DownstreamTlsContext for multiple test classes.
*/
static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
/** Helper method to build DownstreamTlsContext for multiple test classes. */
static DownstreamTlsContext buildDownstreamTlsContext(
CommonTlsContext commonTlsContext, boolean requireClientCert) {
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(requireClientCert))
.build();
return downstreamTlsContext;
}
@ -159,7 +162,8 @@ public class CommonTlsContextTestsUtil {
"unix:/var/run/sds/uds_path",
Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
Arrays.asList("managed-tls"),
null));
null),
/* requireClientCert= */ false);
}
static String getTempFileNameForResourcesFile(String resFile) throws IOException {
@ -171,6 +175,27 @@ public class CommonTlsContextTestsUtil {
*/
public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
false);
}
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
public static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
@Nullable String privateKey,
@Nullable String certChain,
@Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
true);
}
private static DownstreamTlsContext buildDownstreamTlsContextFromFilenamesWithClientAuth(
@Nullable String privateKey,
@Nullable String certChain,
@Nullable String trustCa,
boolean requireClientCert) {
// get temp file for each file
try {
if (certChain != null) {
@ -186,7 +211,7 @@ public class CommonTlsContextTestsUtil {
throw new RuntimeException(ioe);
}
return buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa), requireClientCert);
}
/**

View File

@ -79,7 +79,8 @@ public class SdsSslContextProviderTest {
return server
? SdsSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext),
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor())

View File

@ -274,7 +274,8 @@ public class SecretVolumeSslContextProviderTest {
try {
SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null)));
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null),
/* requireClientCert= */ false));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -295,7 +296,8 @@ public class SecretVolumeSslContextProviderTest {
try {
SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext),
/* requireClientCert= */ false));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected.getMessage()).isEqualTo("filename expected");

View File

@ -52,7 +52,8 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate(
"name", "unix:/tmp/sds/path", CA_PEM_FILE);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext);
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider<DownstreamTlsContext> unused =
@ -71,7 +72,8 @@ public class ServerSslContextProviderFactoryTest {
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
"name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext);
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider<DownstreamTlsContext> unused =