From 06ca927a649c84fb3fc4b6c82aa7695bf3edf728 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 29 Jul 2020 09:10:02 -0700 Subject: [PATCH] xds: first part of MeshCaCertificateProvider (#7247) --- .../certprovider/CertificateProvider.java | 3 + .../CertificateProviderStore.java | 6 +- .../MeshCaCertificateProvider.java | 134 +++++++++++ .../MeshCaCertificateProviderProvider.java | 197 ++++++++++++++++ .../grpc/xds/internal/sts/StsCredentials.java | 79 ++++--- .../CertificateProviderStoreTest.java | 9 + ...MeshCaCertificateProviderProviderTest.java | 212 ++++++++++++++++++ .../xds/internal/sts/StsCredentialsTest.java | 17 +- 8 files changed, 618 insertions(+), 39 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java index 89d6954898..9b96f9957f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java @@ -124,6 +124,9 @@ public abstract class CertificateProvider implements Closeable { @Override public abstract void close(); + /** Starts the cert refresh and watcher update cycle. */ + public abstract void start(); + private final DistributorWatcher watcher; private final boolean notifyCertUpdates; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java index 2a1452fbe9..5b09a58127 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java @@ -131,8 +131,10 @@ public final class CertificateProviderStore { if (certProviderProvider == null) { throw new IllegalArgumentException("Provider not found."); } - return certProviderProvider.createCertificateProvider( - key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates); + CertificateProvider certProvider = certProviderProvider.createCertificateProvider( + key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates); + certProvider.start(); + return certProvider; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java new file mode 100644 index 0000000000..a4d22faa6a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java @@ -0,0 +1,134 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.auth.oauth2.GoogleCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.internal.BackoffPolicy; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */ +final class MeshCaCertificateProvider extends CertificateProvider { + private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName()); + + protected MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates, + String meshCaUrl, String zone, long validitySeconds, + int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory, + BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds, + int maxRetryAttempts, GoogleCredentials oauth2Creds) { + super(watcher, notifyCertUpdates); + } + + @Override + public void start() { + // TODO implement + } + + @Override + public void close() { + // TODO implement + } + + /** Factory for creating channels to MeshCA sever. */ + abstract static class MeshCaChannelFactory { + + private static final MeshCaChannelFactory DEFAULT_INSTANCE = + new MeshCaChannelFactory() { + + /** Creates a channel to the URL in the given list. */ + @Override + ManagedChannel createChannel(String serverUri) { + checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!"); + logger.log(Level.INFO, "Creating channel to {0}", serverUri); + + ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(serverUri); + return channelBuilder.keepAliveTime(1, TimeUnit.MINUTES).build(); + } + }; + + static MeshCaChannelFactory getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates a channel to the server. + */ + abstract ManagedChannel createChannel(String serverUri); + } + + /** Factory for creating channels to MeshCA sever. */ + abstract static class Factory { + private static final Factory DEFAULT_INSTANCE = + new Factory() { + + @Override + MeshCaCertificateProvider create( + DistributorWatcher watcher, + boolean notifyCertUpdates, + String meshCaUrl, + String zone, + long validitySeconds, + int keySize, + String alg, + String signatureAlg, + MeshCaChannelFactory meshCaChannelFactory, + BackoffPolicy.Provider backoffPolicyProvider, + long renewalGracePeriodSeconds, + int maxRetryAttempts, + GoogleCredentials oauth2Creds) { + return new MeshCaCertificateProvider( + watcher, + notifyCertUpdates, + meshCaUrl, + zone, + validitySeconds, + keySize, + alg, + signatureAlg, + meshCaChannelFactory, + backoffPolicyProvider, + renewalGracePeriodSeconds, + maxRetryAttempts, + oauth2Creds); + } + }; + + static Factory getInstance() { + return DEFAULT_INSTANCE; + } + + abstract MeshCaCertificateProvider create( + DistributorWatcher watcher, + boolean notifyCertUpdates, + String meshCaUrl, + String zone, + long validitySeconds, + int keySize, + String alg, + String signatureAlg, + MeshCaChannelFactory meshCaChannelFactory, + BackoffPolicy.Provider backoffPolicyProvider, + long renewalGracePeriodSeconds, + int maxRetryAttempts, + GoogleCredentials oauth2Creds); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java new file mode 100644 index 0000000000..a9c1b01ba0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java @@ -0,0 +1,197 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.xds.internal.sts.StsCredentials; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may + * move this out of the internal package and make this an official API in the future. + */ +final class MeshCaCertificateProviderProvider implements CertificateProviderProvider { + + private static final String MESHCA_URL_KEY = "meshCaUrl"; + private static final String RPC_TIMEOUT_SECONDS_KEY = "rpcTimeoutSeconds"; + private static final String GKECLUSTER_URL_KEY = "gkeClusterUrl"; + private static final String CERT_VALIDITY_SECONDS_KEY = "certValiditySeconds"; + private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewalGracePeriodSeconds"; + private static final String KEY_ALGO_KEY = "keyAlgo"; // aka keyType + private static final String KEY_SIZE_KEY = "keySize"; + private static final String SIGNATURE_ALGO_KEY = "signatureAlgo"; + private static final String MAX_RETRY_ATTEMPTS_KEY = "maxRetryAttempts"; + private static final String STS_URL_KEY = "stsUrl"; + private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation"; + + static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com"; + static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L; + static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours + static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour + static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType + static final int KEY_SIZE_DEFAULT = 2048; + static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA"; + static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3; + static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken"; + + private static final Pattern CLUSTER_URL_PATTERN = Pattern + .compile(".*/projects/(.*)/locations/(.*)/clusters/.*"); + + private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog"; + private static final String AUDIENCE_PREFIX = "identitynamespace:"; + static final String MESH_CA_NAME = "meshCA"; + + static { + CertificateProviderRegistry.getInstance() + .register( + new MeshCaCertificateProviderProvider( + StsCredentials.Factory.getInstance(), + MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(), + new ExponentialBackoffPolicy.Provider(), + MeshCaCertificateProvider.Factory.getInstance())); + } + + final StsCredentials.Factory stsCredentialsFactory; + final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; + final BackoffPolicy.Provider backoffPolicyProvider; + final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; + + @VisibleForTesting + MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory, + MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory, + BackoffPolicy.Provider backoffPolicyProvider, + MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) { + this.stsCredentialsFactory = stsCredentialsFactory; + this.meshCaChannelFactory = meshCaChannelFactory; + this.backoffPolicyProvider = backoffPolicyProvider; + this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory; + } + + @Override + public String getName() { + return MESH_CA_NAME; + } + + @Override + public CertificateProvider createCertificateProvider( + Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { + + Config configObj = validateAndTranslateConfig(config); + + // Construct audience from project and gkeClusterUrl + String audience = + AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl; + StsCredentials stsCredentials = stsCredentialsFactory + .create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation); + + return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl, + configObj.zone, + configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo, + configObj.signatureAlgo, + meshCaChannelFactory, backoffPolicyProvider, + configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials); + } + + private static Config validateAndTranslateConfig(Object config) { + // TODO(sanjaypujare): add support for string, struct proto etc + checkArgument(config instanceof Map, "Only Map supported for config"); + @SuppressWarnings("unchecked") Map map = (Map)config; + + Config configObj = new Config(); + configObj.meshCaUrl = mapGetOrDefault(map, MESHCA_URL_KEY, MESHCA_URL_DEFAULT); + configObj.rpcTimeoutSeconds = + mapGetOrDefault(map, RPC_TIMEOUT_SECONDS_KEY, RPC_TIMEOUT_SECONDS_DEFAULT); + configObj.gkeClusterUrl = + checkNotNull( + map.get(GKECLUSTER_URL_KEY), GKECLUSTER_URL_KEY + " is required in the config"); + configObj.certValiditySeconds = + mapGetOrDefault(map, CERT_VALIDITY_SECONDS_KEY, CERT_VALIDITY_SECONDS_DEFAULT); + configObj.renewalGracePeriodSeconds = + mapGetOrDefault( + map, RENEWAL_GRACE_PERIOD_SECONDS_KEY, RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT); + configObj.keyAlgo = mapGetOrDefault(map, KEY_ALGO_KEY, KEY_ALGO_DEFAULT); + configObj.keySize = mapGetOrDefault(map, KEY_SIZE_KEY, KEY_SIZE_DEFAULT); + configObj.signatureAlgo = mapGetOrDefault(map, SIGNATURE_ALGO_KEY, SIGNATURE_ALGO_DEFAULT); + configObj.maxRetryAttempts = + mapGetOrDefault(map, MAX_RETRY_ATTEMPTS_KEY, MAX_RETRY_ATTEMPTS_DEFAULT); + configObj.stsUrl = mapGetOrDefault(map, STS_URL_KEY, STS_URL_DEFAULT); + configObj.gkeSaJwtLocation = + checkNotNull( + map.get(GKE_SA_JWT_LOCATION_KEY), + GKE_SA_JWT_LOCATION_KEY + " is required in the config"); + parseProjectAndZone(configObj.gkeClusterUrl, configObj); + return configObj; + } + + private static String mapGetOrDefault(Map map, String key, String defaultVal) { + String value = map.get(key); + if (value == null) { + return defaultVal; + } + return value; + } + + private static Long mapGetOrDefault(Map map, String key, long defaultVal) { + String value = map.get(key); + if (value == null) { + return defaultVal; + } + return Long.parseLong(value); + } + + private static Integer mapGetOrDefault(Map map, String key, int defaultVal) { + String value = map.get(key); + if (value == null) { + return defaultVal; + } + return Integer.parseInt(value); + } + + private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) { + Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl); + checkState(matcher.find(), "gkeClusterUrl does not have correct format"); + checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts"); + configObj.project = matcher.group(1); + configObj.zone = matcher.group(2); + } + + /** POJO class for storing various config values. */ + @VisibleForTesting + static class Config { + String meshCaUrl; + Long rpcTimeoutSeconds; + String gkeClusterUrl; + Long certValiditySeconds; + Long renewalGracePeriodSeconds; + String keyAlgo; // aka keyType + Integer keySize; + String signatureAlgo; + Integer maxRetryAttempts; + String stsUrl; + String gkeSaJwtLocation; + String zone; + String project; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sts/StsCredentials.java b/xds/src/main/java/io/grpc/xds/internal/sts/StsCredentials.java index 6581650249..fde5f1f0a9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sts/StsCredentials.java +++ b/xds/src/main/java/io/grpc/xds/internal/sts/StsCredentials.java @@ -47,14 +47,14 @@ import java.util.Map; public final class StsCredentials extends GoogleCredentials { private static final long serialVersionUID = 6647041424685484932L; - private static final HttpTransportFactory defaultHttpTransportFactory = + @VisibleForTesting static final HttpTransportFactory defaultHttpTransportFactory = new DefaultHttpTransportFactory(); private static final String CLOUD_PLATFORM_SCOPE = "https://www.googleapis.com/auth/cloud-platform"; - private final String sourceCredentialsFileLocation; - private final String identityTokenEndpoint; - private final String audience; - private transient HttpTransportFactory transportFactory; + @VisibleForTesting final String sourceCredentialsFileLocation; + @VisibleForTesting final String identityTokenEndpoint; + @VisibleForTesting final String audience; + @VisibleForTesting transient HttpTransportFactory transportFactory; private StsCredentials( String identityTokenEndpoint, @@ -67,33 +67,6 @@ public final class StsCredentials extends GoogleCredentials { this.transportFactory = transportFactory; } - /** - * Creates an StsCredentials. - * - * @param identityTokenEndpoint URL of the token exchange service to use. - * @param audience Audience to use in the STS request. - * @param sourceCredentialsFileLocation file-system location that contains the - * source creds e.g. JWT contents. - */ - public static StsCredentials create( - String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) { - return create( - identityTokenEndpoint, - audience, - sourceCredentialsFileLocation, - getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory)); - } - - @VisibleForTesting - static StsCredentials create( - String identityTokenEndpoint, - String audience, - String sourceCredentialsFileLocation, - HttpTransportFactory transportFactory) { - return new StsCredentials( - identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory); - } - @Override public AccessToken refreshAccessToken() throws IOException { AccessToken tok = getSourceAccessTokenFromFileLocation(); @@ -157,6 +130,48 @@ public final class StsCredentials extends GoogleCredentials { throw new UnsupportedOperationException("toBuilder not supported"); } + /** Factory for creating StsCredentials. */ + public abstract static class Factory { + private static final Factory DEFAULT_INSTANCE = + new Factory() { + + @Override + public StsCredentials create( + String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) { + return create( + identityTokenEndpoint, + audience, + sourceCredentialsFileLocation, + getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory)); + } + }; + + public static Factory getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates an StsCredentials. + * + * @param identityTokenEndpoint URL of the token exchange service to use. + * @param audience Audience to use in the STS request. + * @param sourceCredentialsFileLocation file-system location that contains the + * source creds e.g. JWT contents. + */ + public abstract StsCredentials create( + String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation); + + @VisibleForTesting + static StsCredentials create( + String identityTokenEndpoint, + String audience, + String sourceCredentialsFileLocation, + HttpTransportFactory transportFactory) { + return new StsCredentials( + identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory); + } + } + private static class DefaultHttpTransportFactory implements HttpTransportFactory { private static final HttpTransport netHttpTransport = new NetHttpTransport(); diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java index 521de29d09..8a0dfeb1d7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java @@ -53,6 +53,7 @@ public class CertificateProviderStoreTest { Object config; CertificateProviderProvider certProviderProvider; int closeCalled = 0; + int startCalled = 0; protected TestCertificateProvider( CertificateProvider.DistributorWatcher watcher, @@ -71,6 +72,11 @@ public class CertificateProviderStoreTest { public void close() { closeCalled++; } + + @Override + public void start() { + startCalled++; + } } @Before @@ -161,6 +167,7 @@ public class CertificateProviderStoreTest { assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class); TestCertificateProvider testCertificateProvider = (TestCertificateProvider) handle1.certProvider; + assertThat(testCertificateProvider.startCalled).isEqualTo(1); CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher(); assertThat(distWatcher.downsstreamWatchers).hasSize(2); PrivateKey testKey = mock(PrivateKey.class); @@ -335,6 +342,8 @@ public class CertificateProviderStoreTest { verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2)); verify(mockWatcher1, never()) .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class)); + assertThat(testCertificateProvider1.startCalled).isEqualTo(1); + assertThat(testCertificateProvider2.startCalled).isEqualTo(1); handle2.close(); assertThat(testCertificateProvider2.closeCalled).isEqualTo(1); handle1.close(); diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java new file mode 100644 index 0000000000..d9d4da9350 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java @@ -0,0 +1,212 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.certprovider; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.auth.oauth2.GoogleCredentials; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.xds.internal.sts.StsCredentials; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Unit tests for {@link MeshCaCertificateProviderProvider}. */ +@RunWith(JUnit4.class) +public class MeshCaCertificateProviderProviderTest { + + public static final String EXPECTED_AUDIENCE = + "identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3"; + public static final String TMP_PATH_4 = "/tmp/path4"; + public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl"; + public static final String GKE_CLUSTER_URL = + "https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3"; + + @Mock + StsCredentials.Factory stsCredentialsFactory; + + @Mock + MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; + + @Mock + BackoffPolicy.Provider backoffPolicyProvider; + + @Mock + MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; + private MeshCaCertificateProviderProvider provider; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + provider = + new MeshCaCertificateProviderProvider( + stsCredentialsFactory, + meshCaChannelFactory, + backoffPolicyProvider, + meshCaCertificateProviderFactory); + } + + @Test + public void providerRegisteredName() { + CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance() + .getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME); + assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class); + MeshCaCertificateProviderProvider meshCaCertificateProviderProvider = + (MeshCaCertificateProviderProvider) certProviderProvider; + assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory) + .isSameInstanceAs(StsCredentials.Factory.getInstance()); + assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory) + .isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance()); + assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider) + .isInstanceOf(ExponentialBackoffPolicy.Provider.class); + assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory) + .isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance()); + } + + @Test + public void createProvider_minimalConfig() { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + Map map = buildMinimalMap(); + provider.createCertificateProvider(map, distWatcher, true); + verify(stsCredentialsFactory, times(1)) + .create( + eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT), + eq(EXPECTED_AUDIENCE), + eq(TMP_PATH_4)); + verify(meshCaCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT), + eq("test-zone2"), + eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT), + eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT), + eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT), + eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT), + eq(meshCaChannelFactory), + eq(backoffPolicyProvider), + eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT), + eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT), + (GoogleCredentials) isNull()); + } + + @Test + public void createProvider_missingGkeUrl_expectException() { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + Map map = buildMinimalMap(); + map.remove("gkeClusterUrl"); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (NullPointerException npe) { + assertThat(npe).hasMessageThat().isEqualTo("gkeClusterUrl is required in the config"); + } + } + + @Test + public void createProvider_missingGkeSaJwtLocation_expectException() { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + Map map = buildMinimalMap(); + map.remove("gkeSaJwtLocation"); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (NullPointerException npe) { + assertThat(npe).hasMessageThat().isEqualTo("gkeSaJwtLocation is required in the config"); + } + } + + @Test + public void createProvider_missingProject_expectException() { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + Map map = buildMinimalMap(); + map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3"); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (IllegalStateException ex) { + assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format"); + } + } + + @Test + public void createProvider_nonDefaultFullConfig() { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + Map map = buildFullMap(); + provider.createCertificateProvider(map, distWatcher, true); + verify(stsCredentialsFactory, times(1)) + .create( + eq("nonDefaultStsUrl"), + eq(EXPECTED_AUDIENCE), + eq(TMP_PATH_4)); + verify(meshCaCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq(NON_DEFAULT_MESH_CA_URL), + eq("test-zone2"), + eq(234567L), + eq(4096), + eq("KEY-ALGO1"), + eq("SIG-ALGO2"), + eq(meshCaChannelFactory), + eq(backoffPolicyProvider), + eq(4321L), + eq(9), + (GoogleCredentials) isNull()); + } + + private Map buildFullMap() { + Map map = new HashMap<>(); + map.put("gkeClusterUrl", GKE_CLUSTER_URL); + map.put("gkeSaJwtLocation", TMP_PATH_4); + map.put("meshCaUrl", NON_DEFAULT_MESH_CA_URL); + map.put("rpcTimeoutSeconds", "123"); + map.put("certValiditySeconds", "234567"); + map.put("renewalGracePeriodSeconds", "4321"); + map.put("keyAlgo", "KEY-ALGO1"); + map.put("keySize", "4096"); + map.put("signatureAlgo", "SIG-ALGO2"); + map.put("maxRetryAttempts", "9"); + map.put("stsUrl", "nonDefaultStsUrl"); + return map; + } + + private Map buildMinimalMap() { + Map map = new HashMap<>(); + map.put("gkeClusterUrl", GKE_CLUSTER_URL); + map.put("gkeSaJwtLocation", TMP_PATH_4); + return map; + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/sts/StsCredentialsTest.java b/xds/src/test/java/io/grpc/xds/internal/sts/StsCredentialsTest.java index 4fa13fde14..d6adedeca5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sts/StsCredentialsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sts/StsCredentialsTest.java @@ -83,7 +83,7 @@ public class StsCredentialsTest { HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); when(httpTransportFactory.create()).thenReturn(httpTransport); StsCredentials stsCredentials = - StsCredentials.create( + StsCredentials.Factory.create( STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); AccessToken token = stsCredentials.refreshAccessToken(); assertThat(token).isNotNull(); @@ -115,7 +115,7 @@ public class StsCredentialsTest { HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); when(httpTransportFactory.create()).thenReturn(httpTransport); StsCredentials stsCredentials = - StsCredentials.create( + StsCredentials.Factory.create( STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); CallCredentials callCreds = MoreCallCredentials.from(stsCredentials); CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class); @@ -150,7 +150,7 @@ public class StsCredentialsTest { HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); when(httpTransportFactory.create()).thenReturn(httpTransport); StsCredentials stsCredentials = - StsCredentials.create( + StsCredentials.Factory.create( STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); try { stsCredentials.refreshAccessToken(); @@ -171,7 +171,7 @@ public class StsCredentialsTest { HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); when(httpTransportFactory.create()).thenReturn(httpTransport); StsCredentials stsCredentials = - StsCredentials.create( + StsCredentials.Factory.create( STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); try { stsCredentials.refreshAccessToken(); @@ -185,7 +185,7 @@ public class StsCredentialsTest { public void toBuilder_unsupportedException() { HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); StsCredentials stsCredentials = - StsCredentials.create( + StsCredentials.Factory.create( STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); try { stsCredentials.toBuilder(); @@ -195,6 +195,13 @@ public class StsCredentialsTest { } } + @Test + public void defaultFactory() { + StsCredentials stsCreds = StsCredentials.Factory.getInstance() + .create(STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath()); + assertThat(stsCreds.transportFactory).isEqualTo(StsCredentials.defaultHttpTransportFactory); + } + private static final String ACCESS_TOKEN = "eyJhbGciOiJSU"; private static final String MOCK_RESPONSE = "{\"access_token\": \""