xds: get rid of legacy SDS and file watching code (#8276)

This commit is contained in:
sanjaypujare 2021-06-23 11:13:19 -07:00 committed by GitHub
parent c540229d79
commit e4ab8287d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 313 additions and 1382 deletions

View File

@ -77,10 +77,7 @@ public class CdsLoadBalancer2Test {
private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443";
private static final String LRS_SERVER_NAME = "lrs.googleapis.com"; private static final String LRS_SERVER_NAME = "lrs.googleapis.com";
private final UpstreamTlsContext upstreamTlsContext = private final UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
CommonTlsContextTestsUtil.CLIENT_KEY_FILE,
CommonTlsContextTestsUtil.CLIENT_PEM_FILE,
CommonTlsContextTestsUtil.CA_PEM_FILE);
private final SynchronizationContext syncContext = new SynchronizationContext( private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() { new Thread.UncaughtExceptionHandler() {

View File

@ -494,10 +494,7 @@ public class ClusterImplLoadBalancerTest {
private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) { private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
CommonTlsContextTestsUtil.CLIENT_KEY_FILE,
CommonTlsContextTestsUtil.CLIENT_PEM_FILE,
CommonTlsContextTestsUtil.CA_PEM_FILE);
LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider();
WeightedTargetConfig weightedTargetConfig = WeightedTargetConfig weightedTargetConfig =
buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); buildWeightedTargetConfig(ImmutableMap.of(locality, 10));
@ -541,10 +538,7 @@ public class ClusterImplLoadBalancerTest {
// Config with a new UpstreamTlsContext. // Config with a new UpstreamTlsContext.
upstreamTlsContext = upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true);
CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE,
CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE,
CommonTlsContextTestsUtil.CA_PEM_FILE);
config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME, config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME,
null, Collections.<DropOverload>emptyList(), null, Collections.<DropOverload>emptyList(),
new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext); new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext);

View File

@ -111,10 +111,7 @@ public class ClusterResolverLoadBalancerTest {
private final Locality locality3 = private final Locality locality3 =
Locality.create("test-region-3", "test-zone-3", "test-subzone-3"); Locality.create("test-region-3", "test-zone-3", "test-subzone-3");
private final UpstreamTlsContext tlsContext = private final UpstreamTlsContext tlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
CommonTlsContextTestsUtil.CLIENT_KEY_FILE,
CommonTlsContextTestsUtil.CLIENT_PEM_FILE,
CommonTlsContextTestsUtil.CA_PEM_FILE);
private final DiscoveryMechanism edsDiscoveryMechanism1 = private final DiscoveryMechanism edsDiscoveryMechanism1 =
DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_NAME, 100L, tlsContext); DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_NAME, 100L, tlsContext);
private final DiscoveryMechanism edsDiscoveryMechanism2 = private final DiscoveryMechanism edsDiscoveryMechanism2 =

View File

@ -19,8 +19,11 @@ package io.grpc.xds;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import io.grpc.internal.JsonParser; import io.grpc.internal.JsonParser;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import javax.annotation.Nullable;
public class CommonBootstrapperTestUtils { public class CommonBootstrapperTestUtils {
private static final String FILE_WATCHER_CONFIG = "{\"path\": \"/etc/secret/certs\"}"; private static final String FILE_WATCHER_CONFIG = "{\"path\": \"/etc/secret/certs\"}";
@ -72,4 +75,58 @@ public class CommonBootstrapperTestUtils {
throw new AssertionError(e); throw new AssertionError(e);
} }
} }
/**
* Build {@link Bootstrapper.BootstrapInfo} for certProviderInstance tests.
* Populates with temp file paths.
*/
public static Bootstrapper.BootstrapInfo buildBootstrapInfo(
String certInstanceName1, @Nullable String privateKey1,
@Nullable String cert1,
@Nullable String trustCa1, String certInstanceName2, String privateKey2, String cert2,
String trustCa2) {
// get temp file for each file
try {
if (privateKey1 != null) {
privateKey1 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(privateKey1);
}
if (cert1 != null) {
cert1 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(cert1);
}
if (trustCa1 != null) {
trustCa1 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(trustCa1);
}
if (privateKey2 != null) {
privateKey2 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(privateKey2);
}
if (cert2 != null) {
cert2 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(cert2);
}
if (trustCa2 != null) {
trustCa2 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(trustCa2);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
HashMap<String, String> config = new HashMap<>();
config.put("certificate_file", cert1);
config.put("private_key_file", privateKey1);
config.put("ca_certificate_file", trustCa1);
Bootstrapper.CertificateProviderInfo certificateProviderInfo =
new Bootstrapper.CertificateProviderInfo("file_watcher", config);
HashMap<String, Bootstrapper.CertificateProviderInfo> certProviders =
new HashMap<>();
certProviders.put(certInstanceName1, certificateProviderInfo);
if (certInstanceName2 != null) {
config = new HashMap<>();
config.put("certificate_file", cert2);
config.put("private_key_file", privateKey2);
config.put("ca_certificate_file", trustCa2);
certificateProviderInfo =
new Bootstrapper.CertificateProviderInfo("file_watcher", config);
certProviders.put(certInstanceName2, certificateProviderInfo);
}
return new Bootstrapper.BootstrapInfo(null, EnvoyProtoData.Node.newBuilder().build(),
certProviders, null);
}
} }

View File

@ -927,8 +927,8 @@ public class FilterChainMatchTest {
registeredWatcher.onChanged(listenerUpdate); registeredWatcher.onChanged(listenerUpdate);
EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext();
// assert defaultFilterChain match // assert defaultFilterChain match
assertThat(tlsContextPicked.getCommonTlsContext().getTlsCertificateSdsSecretConfigsList() assertThat(tlsContextPicked.getCommonTlsContext().getTlsCertificateCertificateProviderInstance()
.get(0).getName()).isEqualTo("CERT3"); .getCertificateName()).isEqualTo("CERT3");
} }
private void setupChannel(String localIp, String remoteIp, int remotePort) private void setupChannel(String localIp, String remoteIp, int remotePort)

View File

@ -82,7 +82,10 @@ public class XdsSdsClientServerTest {
@Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private int port; private int port;
private FakeNameResolverFactory fakeNameResolverFactory; private FakeNameResolverFactory fakeNameResolverFactory;
private final TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(null); private Bootstrapper.BootstrapInfo bootstrapInfoForClient = null;
private Bootstrapper.BootstrapInfo bootstrapInfoForServer = null;
private TlsContextManagerImpl tlsContextManagerForClient;
private TlsContextManagerImpl tlsContextManagerForServer;
@Before @Before
public void setUp() throws IOException { public void setUp() throws IOException {
@ -119,14 +122,13 @@ public class XdsSdsClientServerTest {
@Test @Test
public void tlsClientServer_noClientAuthentication() throws IOException, URISyntaxException { public void tlsClientServer_noClientAuthentication() throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false);
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, null);
buildServerWithTlsContext(downstreamTlsContext); buildServerWithTlsContext(downstreamTlsContext);
// for TLS, client only needs trustCa // for TLS, client only needs trustCa
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); CLIENT_PEM_FILE, false);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
@ -137,14 +139,13 @@ public class XdsSdsClientServerTest {
public void requireClientAuth_noClientCert_expectException() public void requireClientAuth_noClientCert_expectException()
throws IOException, URISyntaxException { throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired( setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, true, true);
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
buildServerWithTlsContext(downstreamTlsContext); buildServerWithTlsContext(downstreamTlsContext);
// for TLS, client only uses trustCa // for TLS, client only uses trustCa
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); CLIENT_PEM_FILE, false);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
@ -166,13 +167,12 @@ public class XdsSdsClientServerTest {
@Test @Test
public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException { public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false);
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
buildServerWithTlsContext(downstreamTlsContext); buildServerWithTlsContext(downstreamTlsContext);
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( BAD_CLIENT_KEY_FILE,
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); BAD_CLIENT_PEM_FILE, true);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
@ -181,11 +181,11 @@ public class XdsSdsClientServerTest {
@Test @Test
public void mtls_badClientCert_expectException() throws IOException, URISyntaxException { public void mtls_badClientCert_expectException() throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( BAD_CLIENT_KEY_FILE,
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); BAD_CLIENT_PEM_FILE, true);
try { try {
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null);
fail("exception expected"); fail("exception expected");
} catch (StatusRuntimeException sre) { } catch (StatusRuntimeException sre) {
if (sre.getCause() instanceof SSLHandshakeException) { if (sre.getCause() instanceof SSLHandshakeException) {
@ -202,27 +202,26 @@ public class XdsSdsClientServerTest {
/** mTLS - client auth enabled. */ /** mTLS - client auth enabled. */
@Test @Test
public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException { public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_PEM_FILE, true);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null);
} }
/** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */
@Test @Test
public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds()
throws IOException, URISyntaxException { throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_PEM_FILE, true);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true, null, null, null, null);
} }
@Test @Test
public void tlsServer_plaintextClient_expectException() throws IOException, URISyntaxException { public void tlsServer_plaintextClient_expectException() throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false);
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, null);
buildServerWithTlsContext(downstreamTlsContext); buildServerWithTlsContext(downstreamTlsContext);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
@ -241,9 +240,9 @@ public class XdsSdsClientServerTest {
buildServerWithTlsContext(/* downstreamTlsContext= */ null); buildServerWithTlsContext(/* downstreamTlsContext= */ null);
// for TLS, client only needs trustCa // for TLS, client only needs trustCa
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); CLIENT_PEM_FILE, false);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
@ -260,15 +259,18 @@ public class XdsSdsClientServerTest {
@Test @Test
public void mtlsClientServer_changeServerContext_expectException() public void mtlsClientServer_changeServerContext_expectException()
throws IOException, URISyntaxException { throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE,
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CLIENT_PEM_FILE, true);
XdsClient.LdsResourceWatcher listenerWatcher = XdsClient.LdsResourceWatcher listenerWatcher =
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, "cert-instance-name2",
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE);
generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, tlsContextManager); DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
"cert-instance-name2", true, true);
generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher,
tlsContextManagerForServer);
try { try {
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
@ -281,11 +283,12 @@ public class XdsSdsClientServerTest {
} }
private XdsClient.LdsResourceWatcher performMtlsTestAndGetListenerWatcher( private XdsClient.LdsResourceWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext, boolean newApi) UpstreamTlsContext upstreamTlsContext, boolean newApi, String certInstanceName2,
String privateKey2, String cert2, String trustCa2)
throws IOException, URISyntaxException { throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired( setBootstrapInfoAndBuildDownstreamTlsContext(certInstanceName2, privateKey2, cert2,
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); trustCa2, true, true);
final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = final XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
createXdsClientWrapperForServerSds(port); createXdsClientWrapperForServerSds(port);
@ -302,6 +305,27 @@ public class XdsSdsClientServerTest {
return listenerWatcher; return listenerWatcher;
} }
private DownstreamTlsContext setBootstrapInfoAndBuildDownstreamTlsContext(
String certInstanceName2,
String privateKey2,
String cert2, String trustCa2, boolean hasRootCert, boolean requireClientCertificate) {
bootstrapInfoForServer = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE,
SERVER_1_PEM_FILE, CA_PEM_FILE, certInstanceName2, privateKey2, cert2, trustCa2);
return CommonTlsContextTestsUtil.buildDownstreamTlsContext(
"google_cloud_private_spiffe-server", hasRootCert, requireClientCertificate);
}
private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String clientKeyFile,
String clientPemFile,
boolean hasIdentityCert) {
bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile,
CA_PEM_FILE, null, null, null, null);
return CommonTlsContextTestsUtil
.buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert);
}
private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext)
throws IOException { throws IOException {
buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create());
@ -328,8 +352,9 @@ public class XdsSdsClientServerTest {
/** Creates XdsClientWrapperForServerSds. */ /** Creates XdsClientWrapperForServerSds. */
private XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { private XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) {
tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManagerForServer);
xdsClientWrapperForServerSds.start(); xdsClientWrapperForServerSds.start();
return xdsClientWrapperForServerSds; return xdsClientWrapperForServerSds;
} }
@ -351,8 +376,10 @@ public class XdsSdsClientServerTest {
throws IOException { throws IOException {
XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials)
.addService(new SimpleServiceImpl()); .addService(new SimpleServiceImpl());
tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer);
XdsServerTestHelper.generateListenerUpdate( XdsServerTestHelper.generateListenerUpdate(
xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, tlsContextManager); xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext,
tlsContextManagerForServer);
cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start();
} }
@ -396,12 +423,13 @@ public class XdsSdsClientServerTest {
} }
InetSocketAddress socketAddress = InetSocketAddress socketAddress =
new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); new InetSocketAddress(Inet4Address.getLoopbackAddress(), port);
tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient);
Attributes attrs = Attributes attrs =
(upstreamTlsContext != null) (upstreamTlsContext != null)
? Attributes.newBuilder() ? Attributes.newBuilder()
.set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
new SslContextProviderSupplier( new SslContextProviderSupplier(
upstreamTlsContext, tlsContextManager)) upstreamTlsContext, tlsContextManagerForClient))
.build() .build()
: Attributes.EMPTY; : Attributes.EMPTY;
fakeNameResolverFactory.setServers( fakeNameResolverFactory.setServers(

View File

@ -17,9 +17,6 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -67,62 +64,6 @@ public class ClientSslContextProviderFactoryTest {
new CertProviderClientSslContextProvider.Factory(certificateProviderStore); new CertProviderClientSslContextProvider.Factory(certificateProviderStore);
} }
@Test
public void createSslContextProvider_allFilenames() {
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
null, certProviderClientSslContextProviderFactory);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@Test
public void createSslContextProvider_sdsConfigForTlsCert_expectException() {
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
null, certProviderClientSslContextProviderFactory);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate(
/* name= */ "name", /* targetUri= */ "unix:/tmp/sds/path", CA_PEM_FILE);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
}
}
@Test
public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() {
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
null, certProviderClientSslContextProviderFactory);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
/* name= */ "name",
/* targetUri= */ "unix:/tmp/sds/path",
CLIENT_KEY_FILE,
CLIENT_PEM_FILE);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext);
try {
SslContextProvider unused =
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test @Test
public void createCertProviderClientSslContextProvider() throws XdsInitializationException { public void createCertProviderClientSslContextProvider() throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor = final CertificateProvider.DistributorWatcher[] watcherCaptor =
@ -267,23 +208,6 @@ public class ClientSslContextProviderFactoryTest {
verifyWatcher(sslContextProvider, watcherCaptor[1]); verifyWatcher(sslContextProvider, watcherCaptor[1]);
} }
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
null, certProviderClientSslContextProviderFactory);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(null, null, null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("Unsupported configurations in UpstreamTlsContext!");
}
}
@Test @Test
public void createNullCommonTlsContext_exception() throws IOException { public void createNullCommonTlsContext_exception() throws IOException {
clientSslContextProviderFactory = clientSslContextProviderFactory =

View File

@ -19,25 +19,14 @@ package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Strings;
import com.google.common.io.CharStreams; import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.BoolValue; import com.google.protobuf.BoolValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.envoyproxy.envoy.api.v2.core.ApiConfigSource.ApiType;
import io.envoyproxy.envoy.api.v2.core.GrpcService.GoogleGrpc;
import io.envoyproxy.envoy.config.core.v3.ApiConfigSource;
import io.envoyproxy.envoy.config.core.v3.ConfigSource;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.config.core.v3.GrpcService;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils; import io.grpc.internal.testing.TestUtils;
@ -72,193 +61,39 @@ public class CommonTlsContextTestsUtil {
public static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; public static final String BAD_CLIENT_PEM_FILE = "badclient.pem";
public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; public static final String BAD_CLIENT_KEY_FILE = "badclient.key";
static io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig buildSdsSecretConfigV2(
String name, String targetUri, String channelType) {
io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig sdsSecretConfig = null;
if (!Strings.isNullOrEmpty(name) && !Strings.isNullOrEmpty(targetUri)) {
sdsSecretConfig =
io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig.newBuilder()
.setName(name)
.setSdsConfig(buildConfigSourceV2(targetUri, channelType))
.build();
}
return sdsSecretConfig;
}
private static SdsSecretConfig
buildSdsSecretConfig(String name, String targetUri, String channelType) {
SdsSecretConfig sdsSecretConfig = null;
if (!Strings.isNullOrEmpty(name) && !Strings.isNullOrEmpty(targetUri)) {
sdsSecretConfig =
SdsSecretConfig.newBuilder()
.setName(name)
.setSdsConfig(buildConfigSource(targetUri, channelType))
.build();
}
return sdsSecretConfig;
}
/**
* Builds a {@link io.envoyproxy.envoy.api.v2.core.ConfigSource} for the given targetUri.
*
* @param channelType specifying "inproc" creates an Inprocess channel for testing.
*/
private static io.envoyproxy.envoy.api.v2.core.ConfigSource buildConfigSourceV2(
String targetUri, String channelType) {
GoogleGrpc.Builder googleGrpcBuilder = GoogleGrpc.newBuilder().setTargetUri(targetUri);
if (channelType != null) {
Struct.Builder structBuilder = Struct.newBuilder()
.putFields("channelType", Value.newBuilder().setStringValue(channelType).build());
googleGrpcBuilder.setConfig(structBuilder.build());
}
return io.envoyproxy.envoy.api.v2.core.ConfigSource.newBuilder()
.setApiConfigSource(
io.envoyproxy.envoy.api.v2.core.ApiConfigSource.newBuilder()
.setApiType(ApiType.GRPC)
.addGrpcServices(
io.envoyproxy.envoy.api.v2.core.GrpcService.newBuilder()
.setGoogleGrpc(googleGrpcBuilder.build())
.build())
.build())
.build();
}
/**
* Builds a {@link ConfigSource} for the given targetUri.
*
* @param channelType specifying "inproc" creates an Inprocess channel for testing.
*/
private static ConfigSource buildConfigSource(String targetUri, String channelType) {
GrpcService.GoogleGrpc.Builder googleGrpcBuilder =
GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri);
if (channelType != null) {
Struct.Builder structBuilder = Struct.newBuilder()
.putFields("channelType", Value.newBuilder().setStringValue(channelType).build());
googleGrpcBuilder.setConfig(structBuilder.build());
}
return ConfigSource.newBuilder()
.setApiConfigSource(
ApiConfigSource.newBuilder()
.setApiType(ApiConfigSource.ApiType.GRPC)
.addGrpcServices(GrpcService.newBuilder().setGoogleGrpc(googleGrpcBuilder))
.build())
.build();
}
static CommonTlsContext buildCommonTlsContextFromSdsConfigForValidationContext(
String name, String targetUri, String privateKey, String certChain) {
SdsSecretConfig sdsSecretConfig =
buildSdsSecretConfig(name, targetUri, /* channelType= */ null);
CommonTlsContext.Builder builder =
CommonTlsContext.newBuilder().setValidationContextSdsSecretConfig(sdsSecretConfig);
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
builder.addTlsCertificates(
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build());
}
return builder.build();
}
static CommonTlsContext buildCommonTlsContextFromSdsConfigForTlsCertificate(
String name, String targetUri, String trustCa) {
SdsSecretConfig sdsSecretConfig =
buildSdsSecretConfig(name, targetUri, /* channelType= */ null);
CommonTlsContext.Builder builder =
CommonTlsContext.newBuilder().addTlsCertificateSdsSecretConfigs(sdsSecretConfig);
if (!Strings.isNullOrEmpty(trustCa)) {
builder.setValidationContext(
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build());
}
return builder.build();
}
/** takes additional values and creates CombinedCertificateValidationContext as needed. */
@SuppressWarnings("deprecation")
static io.envoyproxy.envoy.api.v2.auth.CommonTlsContext
buildCommonTlsContextWithAdditionalValuesV2(
String certName,
String certTargetUri,
String validationContextName,
String validationContextTargetUri,
Iterable<String> verifySubjectAltNames,
Iterable<String> alpnNames,
String channelType) {
io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.Builder builder =
io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.newBuilder();
io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig sdsSecretConfig =
buildSdsSecretConfigV2(certName, certTargetUri, channelType);
if (sdsSecretConfig != null) {
builder.addTlsCertificateSdsSecretConfigs(sdsSecretConfig);
}
sdsSecretConfig =
buildSdsSecretConfigV2(validationContextName, validationContextTargetUri, channelType);
io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext certValidationContext =
verifySubjectAltNames == null ? null
: io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext.newBuilder()
.addAllVerifySubjectAltName(verifySubjectAltNames).build();
if (sdsSecretConfig != null && certValidationContext != null) {
io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext
combined =
io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext
.newBuilder()
.setDefaultValidationContext(certValidationContext)
.setValidationContextSdsSecretConfig(sdsSecretConfig)
.build();
builder.setCombinedValidationContext(combined);
} else if (sdsSecretConfig != null) {
builder.setValidationContextSdsSecretConfig(sdsSecretConfig);
} else if (certValidationContext != null) {
builder.setValidationContext(certValidationContext);
}
if (alpnNames != null) {
builder.addAllAlpnProtocols(alpnNames);
}
return builder.build();
}
/** takes additional values and creates CombinedCertificateValidationContext as needed. */ /** takes additional values and creates CombinedCertificateValidationContext as needed. */
static CommonTlsContext buildCommonTlsContextWithAdditionalValues( static CommonTlsContext buildCommonTlsContextWithAdditionalValues(
String certName, String certInstanceName, String certName,
String certTargetUri, String validationContextCertInstanceName, String validationContextCertName,
String validationContextName,
String validationContextTargetUri,
Iterable<StringMatcher> matchSubjectAltNames, Iterable<StringMatcher> matchSubjectAltNames,
Iterable<String> alpnNames, Iterable<String> alpnNames) {
String channelType) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
SdsSecretConfig sdsSecretConfig = buildSdsSecretConfig(certName, certTargetUri, channelType); CertificateProviderInstance certificateProviderInstance = CertificateProviderInstance
if (sdsSecretConfig != null) { .newBuilder().setInstanceName(certInstanceName).setCertificateName(certName).build();
builder.addTlsCertificateSdsSecretConfigs(sdsSecretConfig); if (certificateProviderInstance != null) {
builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance);
} }
sdsSecretConfig = CertificateProviderInstance validationCertificateProviderInstance =
buildSdsSecretConfig(validationContextName, validationContextTargetUri, channelType); CertificateProviderInstance.newBuilder().setInstanceName(validationContextCertInstanceName)
.setCertificateName(validationContextCertName).build();
CertificateValidationContext certValidationContext = CertificateValidationContext certValidationContext =
matchSubjectAltNames == null matchSubjectAltNames == null
? null ? null
: CertificateValidationContext.newBuilder() : CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(matchSubjectAltNames) .addAllMatchSubjectAltNames(matchSubjectAltNames)
.build(); .build();
if (sdsSecretConfig != null && certValidationContext != null) { if (validationCertificateProviderInstance != null && certValidationContext != null) {
CombinedCertificateValidationContext.Builder combinedBuilder = CombinedCertificateValidationContext.Builder combinedBuilder =
CombinedCertificateValidationContext.newBuilder() CombinedCertificateValidationContext.newBuilder()
.setDefaultValidationContext(certValidationContext) .setDefaultValidationContext(certValidationContext)
.setValidationContextSdsSecretConfig(sdsSecretConfig); .setValidationContextCertificateProviderInstance(
validationCertificateProviderInstance);
builder.setCombinedValidationContext(combinedBuilder); builder.setCombinedValidationContext(combinedBuilder);
} else if (sdsSecretConfig != null) { } else if (validationCertificateProviderInstance != null) {
builder.setValidationContextSdsSecretConfig(sdsSecretConfig); builder
.setValidationContextCertificateProviderInstance(validationCertificateProviderInstance);
} else if (certValidationContext != null) { } else if (certValidationContext != null) {
builder.setValidationContext(certValidationContext); builder.setValidationContext(certValidationContext);
} }
@ -268,18 +103,6 @@ public class CommonTlsContextTestsUtil {
return builder.build(); return builder.build();
} }
/** Helper method to build DownstreamTlsContext for multiple test classes. */
static io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext buildDownstreamTlsContextV2(
io.envoyproxy.envoy.api.v2.auth.CommonTlsContext commonTlsContext,
boolean requireClientCert) {
io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext downstreamTlsContext =
io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(requireClientCert))
.build();
return downstreamTlsContext;
}
/** Helper method to build DownstreamTlsContext for multiple test classes. */ /** Helper method to build DownstreamTlsContext for multiple test classes. */
static DownstreamTlsContext buildDownstreamTlsContext( static DownstreamTlsContext buildDownstreamTlsContext(
CommonTlsContext commonTlsContext, boolean requireClientCert) { CommonTlsContext commonTlsContext, boolean requireClientCert) {
@ -291,6 +114,20 @@ public class CommonTlsContextTestsUtil {
return downstreamTlsContext; return downstreamTlsContext;
} }
/** Helper method to build DownstreamTlsContext for multiple test classes. */
public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContext(
String commonInstanceName, boolean hasRootCert,
boolean requireClientCertificate) {
return buildDownstreamTlsContextForCertProviderInstance(
commonInstanceName,
"default",
hasRootCert ? commonInstanceName : null,
hasRootCert ? "ROOT" : null,
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ requireClientCertificate);
}
/** Helper method to build internal DownstreamTlsContext for multiple test classes. */ /** Helper method to build internal DownstreamTlsContext for multiple test classes. */
static EnvoyServerProtoData.DownstreamTlsContext buildInternalDownstreamTlsContext( static EnvoyServerProtoData.DownstreamTlsContext buildInternalDownstreamTlsContext(
CommonTlsContext commonTlsContext, boolean requireClientCert) { CommonTlsContext commonTlsContext, boolean requireClientCert) {
@ -298,36 +135,18 @@ public class CommonTlsContextTestsUtil {
buildDownstreamTlsContext(commonTlsContext, requireClientCert)); buildDownstreamTlsContext(commonTlsContext, requireClientCert));
} }
/** Helper method for creating DownstreamTlsContext values with names. */
public static io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext
buildTestDownstreamTlsContextV2(String certName, String validationContextName) {
return buildDownstreamTlsContextV2(
buildCommonTlsContextWithAdditionalValuesV2(
certName,
"unix:/var/run/sds/uds_path",
validationContextName,
"unix:/var/run/sds/uds_path",
Arrays.asList("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
Arrays.asList("managed-tls"),
null),
/* requireClientCert= */ false);
}
/** Helper method for creating DownstreamTlsContext values with names. */ /** Helper method for creating DownstreamTlsContext values with names. */
public static DownstreamTlsContext buildTestDownstreamTlsContext( public static DownstreamTlsContext buildTestDownstreamTlsContext(
String certName, String validationContextName) { String certName, String validationContextCertName) {
return buildDownstreamTlsContext( return buildDownstreamTlsContext(
buildCommonTlsContextWithAdditionalValues( buildCommonTlsContextWithAdditionalValues(
certName, "cert-instance-name", certName,
"unix:/var/run/sds/uds_path", "val-cert-instance-name", validationContextCertName,
validationContextName,
"unix:/var/run/sds/uds_path",
Arrays.asList( Arrays.asList(
StringMatcher.newBuilder() StringMatcher.newBuilder()
.setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob")
.build()), .build()),
Arrays.asList("managed-tls"), Arrays.asList("managed-tls")),
null),
/* requireClientCert= */ false); /* requireClientCert= */ false);
} }
@ -341,103 +160,6 @@ public class CommonTlsContextTestsUtil {
return TestUtils.loadCert(resFile).getAbsolutePath(); return TestUtils.loadCert(resFile).getAbsolutePath();
} }
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
false);
}
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
public static EnvoyServerProtoData.DownstreamTlsContext
buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
return buildDownstreamTlsContextFromFilenamesWithClientAuth(privateKey, certChain, trustCa,
true);
}
private static EnvoyServerProtoData.DownstreamTlsContext
buildDownstreamTlsContextFromFilenamesWithClientAuth(
@Nullable String privateKey,
@Nullable String certChain,
@Nullable String trustCa,
boolean requireClientCert) {
// get temp file for each file
try {
if (certChain != null) {
certChain = getTempFileNameForResourcesFile(certChain);
}
if (privateKey != null) {
privateKey = getTempFileNameForResourcesFile(privateKey);
}
if (trustCa != null) {
trustCa = getTempFileNameForResourcesFile(trustCa);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return buildInternalDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa), requireClientCert);
}
/**
* Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
*/
public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
try {
if (certChain != null) {
certChain = getTempFileNameForResourcesFile(certChain);
}
if (privateKey != null) {
privateKey = getTempFileNameForResourcesFile(privateKey);
}
if (trustCa != null) {
trustCa = getTempFileNameForResourcesFile(trustCa);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return buildUpstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
private static CommonTlsContext buildCommonTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) {
TlsCertificate tlsCert = null;
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
}
CertificateValidationContext certContext = null;
if (!Strings.isNullOrEmpty(trustCa)) {
certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
}
return getCommonTlsContext(tlsCert, certContext);
}
static CommonTlsContext getCommonTlsContext(
TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (tlsCertificate != null) {
builder = builder.addTlsCertificates(tlsCertificate);
}
if (certContext != null) {
builder = builder.setValidationContext(certContext);
}
return builder.build();
}
/** /**
* Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
*/ */
@ -449,6 +171,18 @@ public class CommonTlsContextTestsUtil {
upstreamTlsContext); upstreamTlsContext);
} }
/** Helper method to build UpstreamTlsContext for multiple test classes. */
public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext(
String commonInstanceName, boolean hasIdentityCert) {
return buildUpstreamTlsContextForCertProviderInstance(
hasIdentityCert ? commonInstanceName : null,
hasIdentityCert ? "default" : null,
commonInstanceName,
"ROOT",
null,
null);
}
/** Gets a cert from contents of a resource. */ /** Gets a cert from contents of a resource. */
public static X509Certificate getCertFromResourceName(String resourceName) public static X509Certificate getCertFromResourceName(String resourceName)
throws IOException, CertificateException { throws IOException, CertificateException {
@ -516,22 +250,6 @@ public class CommonTlsContextTestsUtil {
return builder; return builder;
} }
static CommonTlsContext.Builder addCertificateValidationContext(
CommonTlsContext.Builder builder,
String name,
String targetUri,
String channelType,
CertificateValidationContext staticCertValidationContext) {
SdsSecretConfig sdsSecretConfig = buildSdsSecretConfig(name, targetUri, channelType);
CombinedCertificateValidationContext combined =
CombinedCertificateValidationContext.newBuilder()
.setDefaultValidationContext(staticCertValidationContext)
.setValidationContextSdsSecretConfig(sdsSecretConfig)
.build();
return builder.setCombinedValidationContext(combined);
}
/** Helper method to build UpstreamTlsContext for CertProvider tests. */ /** Helper method to build UpstreamTlsContext for CertProvider tests. */
public static EnvoyServerProtoData.UpstreamTlsContext public static EnvoyServerProtoData.UpstreamTlsContext
buildUpstreamTlsContextForCertProviderInstance( buildUpstreamTlsContextForCertProviderInstance(

View File

@ -31,20 +31,19 @@ import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.base.Strings; import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.config.core.v3.DataSource; import com.google.common.util.concurrent.SettableFuture;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.InternalXdsAttributes;
@ -66,6 +65,7 @@ import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2ConnectionDecoder; import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.IOException; import java.io.IOException;
@ -74,6 +74,9 @@ import java.net.SocketAddress;
import java.security.cert.CertStoreException; import java.security.cert.CertStoreException;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -89,70 +92,6 @@ public class SdsProtocolNegotiatorsTest {
private ChannelPipeline pipeline = channel.pipeline(); private ChannelPipeline pipeline = channel.pipeline();
private ChannelHandlerContext channelHandlerCtx; private ChannelHandlerContext channelHandlerCtx;
private static String getTempFileNameForResourcesFile(String resFile) throws IOException {
return Strings.isNullOrEmpty(resFile) ? null : TestUtils.loadCert(resFile).getAbsolutePath();
}
/** Builds DownstreamTlsContext from file-names. */
private static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
return buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/** Builds UpstreamTlsContext from file-names. */
private static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
return CommonTlsContextTestsUtil.buildUpstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/** Builds DownstreamTlsContext from commonTlsContext. */
private static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext
downstreamTlsContext =
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext
.newBuilder()
.setCommonTlsContext(commonTlsContext)
.build();
return DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(downstreamTlsContext);
}
private static CommonTlsContext buildCommonTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
TlsCertificate tlsCert = null;
privateKey = getTempFileNameForResourcesFile(privateKey);
certChain = getTempFileNameForResourcesFile(certChain);
trustCa = getTempFileNameForResourcesFile(trustCa);
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
}
CertificateValidationContext certContext = null;
if (!Strings.isNullOrEmpty(trustCa)) {
certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
}
return getCommonTlsContext(tlsCert, certContext);
}
private static CommonTlsContext getCommonTlsContext(
TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (tlsCertificate != null) {
builder = builder.addTlsCertificates(tlsCertificate);
}
if (certContext != null) {
builder = builder.setValidationContext(certContext);
}
return builder.build();
}
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_noTlsContextAttribute() { public void clientSdsProtocolNegotiatorNewHandler_noTlsContextAttribute() {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class); ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
@ -181,8 +120,7 @@ public class SdsProtocolNegotiatorsTest {
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_withTlsContextAttribute() { public void clientSdsProtocolNegotiatorNewHandler_withTlsContextAttribute() {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContext( CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build());
getCommonTlsContext(/* tlsCertificate= */ null, /* certContext= */ null));
ClientSdsProtocolNegotiator pn = ClientSdsProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(InternalProtocolNegotiators.plaintext()); new ClientSdsProtocolNegotiator(InternalProtocolNegotiators.plaintext());
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
@ -202,12 +140,18 @@ public class SdsProtocolNegotiatorsTest {
} }
@Test @Test
public void clientSdsHandler_addLast() throws IOException { public void clientSdsHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException {
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, null, null, null, null);
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil
.buildUpstreamTlsContext("google_cloud_private_spiffe-client", true);
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(null)); new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient));
SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler); pipeline.addLast(clientSdsHandler);
@ -216,7 +160,23 @@ public class SdsProtocolNegotiatorsTest {
// kick off protocol negotiation. // kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channel.runPendingTasks(); // need this for tasks to execute on eventLoop final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
channel.runPendingTasks();
Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
@ -229,7 +189,8 @@ public class SdsProtocolNegotiatorsTest {
} }
@Test @Test
public void serverSdsHandler_addLast() throws IOException { public void serverSdsHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException {
// we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test // we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test
channel = channel =
new EmbeddedChannel() { new EmbeddedChannel() {
@ -244,12 +205,17 @@ public class SdsProtocolNegotiatorsTest {
} }
}; };
pipeline = channel.pipeline(); pipeline = channel.pipeline();
Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE,
SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null);
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil.buildDownstreamTlsContext(
"google_cloud_private_spiffe-server", true, true);
TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds(
80, downstreamTlsContext, new TlsContextManagerImpl(null)); 80, downstreamTlsContext, tlsContextManager);
SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds,
InternalProtocolNegotiators.serverPlaintext()); InternalProtocolNegotiators.serverPlaintext());
@ -263,7 +229,26 @@ public class SdsProtocolNegotiatorsTest {
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class);
assertThat(channelHandlerCtx).isNotNull(); assertThat(channelHandlerCtx).isNotNull();
SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager);
final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
channel.runPendingTasks(); // need this for tasks to execute on eventLoop channel.runPendingTasks(); // need this for tasks to execute on eventLoop
Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
@ -365,12 +350,17 @@ public class SdsProtocolNegotiatorsTest {
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
throws IOException, InterruptedException { throws InterruptedException, TimeoutException, ExecutionException {
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, null, null, null, null);
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); CommonTlsContextTestsUtil
.buildUpstreamTlsContext("google_cloud_private_spiffe-client", true);
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(null)); new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient));
SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier);
@ -380,7 +370,23 @@ public class SdsProtocolNegotiatorsTest {
// kick off protocol negotiation. // kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
channel.runPendingTasks(); // need this for tasks to execute on eventLoop channel.runPendingTasks(); // need this for tasks to execute on eventLoop
Object fromFuture = future.get(5, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;

View File

@ -1,263 +0,0 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneCertificateValidationContextSecret;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneTlsCertSecret;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.Status.Code;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import java.io.IOException;
import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SdsClientSslContextProvider}. */
@RunWith(JUnit4.class)
public class SdsSslContextProviderTest {
private TestSdsServer.ServerMock serverMock;
private TestSdsServer server;
private Node node;
@Before
public void setUp() throws Exception {
serverMock = mock(TestSdsServer.ServerMock.class);
server = new TestSdsServer(serverMock);
server.startServer(/* name= */ "inproc", /* useUds= */ false, /* useInterceptor= */ false);
node = Node.newBuilder().setId("sds-client-temp-test1").build();
}
@After
public void teardown() throws InterruptedException {
server.shutdown();
}
/** Helper method to build SdsClientSslContextProvider from given names. */
private SdsClientSslContextProvider getSdsClientSslContextProvider(
String certName,
String validationContextName,
Iterable<StringMatcher> matchSubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
certName,
/* certTargetUri= */ "inproc",
validationContextName,
/* validationContextTargetUri= */ "inproc",
matchSubjectAltNames,
alpnProtocols,
/* channelType= */ "inproc");
return SdsClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContext(commonTlsContext),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
}
/** Helper method to build SdsServerSslContextProvider from given names. */
private SdsServerSslContextProvider getSdsServerSslContextProvider(
String certName,
String validationContextName,
Iterable<StringMatcher> matchSubjectAltNames,
Iterable<String> alpnProtocols)
throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
certName,
/* certTargetUri= */ "inproc",
validationContextName,
/* validationContextTargetUri= */ "inproc",
matchSubjectAltNames,
alpnProtocols,
/* channelType= */ "inproc");
return SdsServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildInternalDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false),
node,
MoreExecutors.directExecutor(),
MoreExecutors.directExecutor());
}
@Test
public void testProviderForServer() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider("cert1", "valid1", null, null);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void testProviderForClient() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void testProviderForServer_onlyCert() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ null,
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void getProviderForClient_onlyTrust() throws IOException {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ null,
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
null);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void getProviderForServer_noCert_throwsException() throws IOException {
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ null,
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ null);
TestCallback testCallback = getValueThruCallback(provider);
assertThat(server.lastNack).isNotNull();
assertThat(server.lastNack.getVersionInfo()).isEmpty();
assertThat(server.lastNack.getResponseNonce()).isEmpty();
com.google.rpc.Status errorDetail = server.lastNack.getErrorDetail();
assertThat(errorDetail.getCode()).isEqualTo(Code.UNKNOWN.value());
assertThat(errorDetail.getMessage()).isEqualTo("Secret not updated");
assertThat(testCallback.updatedSslContext).isNull();
}
@Test
public void testProviderForClient_withSubjectAltNames() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
Arrays.asList(
StringMatcher.newBuilder()
.setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob")
.build()),
/* alpnProtocols= */ null);
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void testProviderForClient_withAlpnProtocols() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsClientSslContextProvider provider =
getSdsClientSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(
false, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
}
@Test
public void testProviderForServer_withAlpnProtocols() throws IOException {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
when(serverMock.getSecretFor(/* name= */ "valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsServerSslContextProvider provider =
getSdsServerSslContextProvider(
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
/* matchSubjectAltNames= */ null,
/* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(
true, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
}
}

View File

@ -1,444 +0,0 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.getValueThruCallback;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
import java.security.cert.CertStoreException;
import java.security.cert.CertificateException;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SecretVolumeClientSslContextProvider}. */
@RunWith(JUnit4.class)
public class SecretVolumeSslContextProviderTest {
@Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Test
public void validateCertificateContext_nullAndNotOptional_throwsException() {
// expect exception when certContext is null and not optional
try {
CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
}
}
@Test
public void validateCertificateContext_missingTrustCa_throwsException() {
// expect exception when certContext has no CA and not optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try {
CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
}
}
@Test
public void validateCertificateContext_nullAndOptional() {
// certContext argument can be null when optional
CertificateValidationContext certContext =
CommonTlsContextUtil.validateCertificateContext(
/* certContext= */ null, /* optional= */ true);
assertThat(certContext).isNull();
}
@Test
public void validateCertificateContext_missingTrustCaOptional() {
// certContext argument can have missing CA when optional
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ true))
.isNull();
}
@Test
public void validateCertificateContext_inlineString_throwsException() {
// expect exception when certContext doesn't use filename (inline string)
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateCertificateContext_filename() {
// validation succeeds and returns same instance when filename provided
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(CommonTlsContextUtil.validateCertificateContext(certContext, /* optional= */ false))
.isSameInstanceAs(certContext);
}
@Test
public void validateTlsCertificate_nullAndNotOptional_throwsException() {
// expect exception when tlsCertificate is null and not optional
try {
CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("tlsCertificate is required");
}
}
@Test
public void validateTlsCertificate_nullOptional() {
assertThat(
CommonTlsContextUtil.validateTlsCertificate(
/* tlsCertificate= */ null, /* optional= */ true))
.isNull();
}
@Test
public void validateTlsCertificate_defaultInstance_returnsNull() {
// tlsCertificate is not null but has no value (default instance): expect null
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true)).isNull();
}
@Test
public void validateTlsCertificate_missingCertChainNotOptional_throwsException() {
// expect exception when tlsCertificate has missing certChain and not optional
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateTlsCertificate_missingCertChainOptional_throwsException() {
// expect exception when tlsCertificate has missing certChain even if optional
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateTlsCertificate_missingPrivateKeyNotOptional_throwsException() {
// expect exception when tlsCertificate has missing private key and not optional
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateTlsCertificate_missingPrivateKeyOptional_throwsException() {
// expect exception when tlsCertificate has missing private key even if optional
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateTlsCertificate_optional_returnsSameInstance() {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true))
.isSameInstanceAs(tlsCert);
}
@Test
public void validateTlsCertificate_notOptional_returnsSameInstance() {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
assertThat(CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ false))
.isSameInstanceAs(tlsCert);
}
@Test
public void validateTlsCertificate_certChainInlineString_throwsException() {
// expect exception when tlsCertificate has certChain as inline string
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void validateTlsCertificate_privateKeyInlineString_throwsException() {
// expect exception when tlsCertificate has private key as inline string
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setInlineString("foo"))
.setCertificateChain(DataSource.newBuilder().setFilename("bar"))
.build();
try {
CommonTlsContextUtil.validateTlsCertificate(tlsCert, /* optional= */ true);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void getProviderForServer_defaultTlsCertificate_throwsException() {
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
try {
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildInternalDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null),
/* requireClientCert= */ false));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void getProviderForServer_certContextWithInlineString_throwsException() {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
.build();
try {
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildInternalDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext),
/* requireClientCert= */ false));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected.getMessage()).isEqualTo("filename expected");
}
}
@Test
public void getProviderForClient_defaultCertContext_throwsException() {
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try {
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(
/* tlsCertificate= */ null, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
}
}
@Test
public void getProviderForClient_certWithPrivateKeyInlineString_throwsException() {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename("foo"))
.setPrivateKey(DataSource.newBuilder().setInlineString("bar"))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build();
try {
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
@Test
public void getProviderForClient_certWithCertChainInlineString_throwsException() {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setInlineString("foo"))
.setPrivateKey(DataSource.newBuilder().setFilename("bar"))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename("foo"))
.build();
try {
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
/**
* Helper method to build SecretVolumeSslContextProvider, call buildSslContext on it and
* check returned SslContext.
*/
private static void sslContextForEitherWithBothCertAndTrust(
boolean server, String pemFile, String keyFile, String caFile)
throws IOException, CertificateException, CertStoreException {
SslContext sslContext = null;
if (server) {
SecretVolumeServerSslContextProvider provider =
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
sslContext = provider.buildSslContextFromSecrets();
} else {
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
keyFile, pemFile, caFile));
sslContext = provider.buildSslContextFromSecrets();
}
doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
}
@Test
public void getProviderForServer() throws IOException, CertificateException, CertStoreException {
sslContextForEitherWithBothCertAndTrust(
true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE);
}
@Test
public void getProviderForClient() throws IOException, CertificateException, CertStoreException {
sslContextForEitherWithBothCertAndTrust(false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE);
}
@Test
public void getProviderForServer_onlyCert()
throws IOException, CertificateException, CertStoreException {
sslContextForEitherWithBothCertAndTrust(true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, null);
}
@Test
public void getProviderForClient_onlyTrust()
throws IOException, CertificateException, CertStoreException {
sslContextForEitherWithBothCertAndTrust(false, null, null, CA_PEM_FILE);
}
@Test
public void getProviderForServer_badFile_throwsException()
throws IOException, CertificateException, CertStoreException {
try {
sslContextForEitherWithBothCertAndTrust(true, SERVER_1_PEM_FILE, SERVER_1_PEM_FILE, null);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().contains("File does not contain valid private key");
}
}
@Test
public void getProviderForServer_both_callsback() throws IOException {
SecretVolumeServerSslContextProvider provider =
SecretVolumeServerSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
public void getProviderForClient_both_callsback() throws IOException {
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
// note this test generates stack-trace but can be safely ignored
@Test
public void getProviderForClient_both_callsback_setException() throws IOException {
SecretVolumeClientSslContextProvider provider =
SecretVolumeClientSslContextProvider.getProvider(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_PEM_FILE, CLIENT_PEM_FILE, CA_PEM_FILE));
TestCallback testCallback = getValueThruCallback(provider);
assertThat(testCallback.updatedSslContext).isNull();
assertThat(testCallback.updatedThrowable).isInstanceOf(IllegalArgumentException.class);
assertThat(testCallback.updatedThrowable).hasMessageThat()
.contains("File does not contain valid private key");
}
}

View File

@ -19,9 +19,6 @@ package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider; import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider;
import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher; import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
@ -36,8 +33,6 @@ import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider; import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore; import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -60,62 +55,6 @@ public class ServerSslContextProviderFactoryTest {
new CertProviderServerSslContextProvider.Factory(certificateProviderStore); new CertProviderServerSslContextProvider.Factory(certificateProviderStore);
} }
@Test
public void createSslContextProvider_allFilenames() {
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
null, certProviderServerSslContextProviderFactory);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isNotNull();
}
@Test
public void createSslContextProvider_sdsConfigForTlsCert_expectException() {
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
null, certProviderServerSslContextProviderFactory);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate(
"name", "unix:/tmp/sds/path", CA_PEM_FILE);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildInternalDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("unexpected TlsCertificateSdsSecretConfigs");
}
}
@Test
public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() {
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
null, certProviderServerSslContextProviderFactory);
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
"name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildInternalDownstreamTlsContext(
commonTlsContext, /* requireClientCert= */ false);
try {
SslContextProvider unused =
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("incorrect ValidationContextTypeCase");
}
}
@Test @Test
public void createCertProviderServerSslContextProvider() throws XdsInitializationException { public void createCertProviderServerSslContextProvider() throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor = final CertificateProvider.DistributorWatcher[] watcherCaptor =
@ -267,37 +206,4 @@ public class ServerSslContextProviderFactoryTest {
verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]); verifyWatcher(sslContextProvider, watcherCaptor[1]);
} }
@Test
public void createEmptyCommonTlsContext_exception() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(null, null, null);
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
null, certProviderServerSslContextProviderFactory);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (UnsupportedOperationException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("Unsupported configurations in DownstreamTlsContext!");
}
}
@Test
public void createNullCommonTlsContext_exception() throws IOException {
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
null, certProviderServerSslContextProviderFactory);
DownstreamTlsContext downstreamTlsContext = new DownstreamTlsContext(null, true);
try {
serverSslContextProviderFactory.create(downstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("downstreamTlsContext should have CommonTlsContext");
}
}
} }

View File

@ -17,9 +17,6 @@
package io.grpc.xds.internal.sds; package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
@ -63,8 +60,7 @@ public class SslContextProviderSupplierTest {
private void prepareSupplier() { private void prepareSupplier() {
upstreamTlsContext = upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true);
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
mockSslContextProvider = mock(SslContextProvider.class); mockSslContextProvider = mock(SslContextProvider.class);
doReturn(mockSslContextProvider) doReturn(mockSslContextProvider)
.when(mockTlsContextManager) .when(mockTlsContextManager)

View File

@ -30,6 +30,8 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory;
@ -53,11 +55,14 @@ public class TlsContextManagerTest {
@Test @Test
public void createServerSslContextProvider() { public void createServerSslContextProvider() {
Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE,
SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null);
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); "google_cloud_private_spiffe-server", false, false);
TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForServer);
SslContextProvider serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull(); assertThat(serverSecretProvider).isNotNull();
@ -69,11 +74,14 @@ public class TlsContextManagerTest {
@Test @Test
public void createClientSslContextProvider() { public void createClientSslContextProvider() {
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, null, null, null, null);
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false);
TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient);
SslContextProvider clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull(); assertThat(clientSecretProvider).isNotNull();
@ -85,18 +93,23 @@ public class TlsContextManagerTest {
@Test @Test
public void createServerSslContextProvider_differentInstance() { public void createServerSslContextProvider_differentInstance() {
Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE,
SERVER_1_PEM_FILE, CA_PEM_FILE, "cert-instance2", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE,
CA_PEM_FILE);
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); "google_cloud_private_spiffe-server", false, false);
TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForServer);
SslContextProvider serverSecretProvider = SslContextProvider serverSecretProvider =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext);
assertThat(serverSecretProvider).isNotNull(); assertThat(serverSecretProvider).isNotNull();
DownstreamTlsContext downstreamTlsContext1 = DownstreamTlsContext downstreamTlsContext1 =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE); "cert-instance2", true, true);
SslContextProvider serverSecretProvider1 = SslContextProvider serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1); tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1);
assertThat(serverSecretProvider1).isNotNull(); assertThat(serverSecretProvider1).isNotNull();
@ -105,18 +118,20 @@ public class TlsContextManagerTest {
@Test @Test
public void createClientSslContextProvider_differentInstance() { public void createClientSslContextProvider_differentInstance() {
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false);
TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient);
SslContextProvider clientSecretProvider = SslContextProvider clientSecretProvider =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext);
assertThat(clientSecretProvider).isNotNull(); assertThat(clientSecretProvider).isNotNull();
UpstreamTlsContext upstreamTlsContext1 = UpstreamTlsContext upstreamTlsContext1 =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true);
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider clientSecretProvider1 = SslContextProvider clientSecretProvider1 =
tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1);
@ -126,8 +141,8 @@ public class TlsContextManagerTest {
@Test @Test
public void createServerSslContextProvider_releaseInstance() { public void createServerSslContextProvider_releaseInstance() {
DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( CommonTlsContextTestsUtil.buildDownstreamTlsContext(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); "google_cloud_private_spiffe-server", false, false);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);
@ -145,8 +160,8 @@ public class TlsContextManagerTest {
@Test @Test
public void createClientSslContextProvider_releaseInstance() { public void createClientSslContextProvider_releaseInstance() {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CommonTlsContextTestsUtil
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl tlsContextManagerImpl =
new TlsContextManagerImpl(mockClientFactory, mockServerFactory); new TlsContextManagerImpl(mockClientFactory, mockServerFactory);