xds: first part of MeshCaCertificateProvider (#7247)

This commit is contained in:
sanjaypujare 2020-07-29 09:10:02 -07:00 committed by GitHub
parent d2182fe197
commit 06ca927a64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 618 additions and 39 deletions

View File

@ -124,6 +124,9 @@ public abstract class CertificateProvider implements Closeable {
@Override @Override
public abstract void close(); public abstract void close();
/** Starts the cert refresh and watcher update cycle. */
public abstract void start();
private final DistributorWatcher watcher; private final DistributorWatcher watcher;
private final boolean notifyCertUpdates; private final boolean notifyCertUpdates;

View File

@ -131,8 +131,10 @@ public final class CertificateProviderStore {
if (certProviderProvider == null) { if (certProviderProvider == null) {
throw new IllegalArgumentException("Provider not found."); throw new IllegalArgumentException("Provider not found.");
} }
return certProviderProvider.createCertificateProvider( CertificateProvider certProvider = certProviderProvider.createCertificateProvider(
key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates); key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates);
certProvider.start();
return certProvider;
} }
} }

View File

@ -0,0 +1,134 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.auth.oauth2.GoogleCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.internal.BackoffPolicy;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
final class MeshCaCertificateProvider extends CertificateProvider {
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
protected MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
String meshCaUrl, String zone, long validitySeconds,
int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
int maxRetryAttempts, GoogleCredentials oauth2Creds) {
super(watcher, notifyCertUpdates);
}
@Override
public void start() {
// TODO implement
}
@Override
public void close() {
// TODO implement
}
/** Factory for creating channels to MeshCA sever. */
abstract static class MeshCaChannelFactory {
private static final MeshCaChannelFactory DEFAULT_INSTANCE =
new MeshCaChannelFactory() {
/** Creates a channel to the URL in the given list. */
@Override
ManagedChannel createChannel(String serverUri) {
checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!");
logger.log(Level.INFO, "Creating channel to {0}", serverUri);
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(serverUri);
return channelBuilder.keepAliveTime(1, TimeUnit.MINUTES).build();
}
};
static MeshCaChannelFactory getInstance() {
return DEFAULT_INSTANCE;
}
/**
* Creates a channel to the server.
*/
abstract ManagedChannel createChannel(String serverUri);
}
/** Factory for creating channels to MeshCA sever. */
abstract static class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory() {
@Override
MeshCaCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String alg,
String signatureAlg,
MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds) {
return new MeshCaCertificateProvider(
watcher,
notifyCertUpdates,
meshCaUrl,
zone,
validitySeconds,
keySize,
alg,
signatureAlg,
meshCaChannelFactory,
backoffPolicyProvider,
renewalGracePeriodSeconds,
maxRetryAttempts,
oauth2Creds);
}
};
static Factory getInstance() {
return DEFAULT_INSTANCE;
}
abstract MeshCaCertificateProvider create(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String alg,
String signatureAlg,
MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds);
}
}

View File

@ -0,0 +1,197 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.xds.internal.sts.StsCredentials;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
* move this out of the internal package and make this an official API in the future.
*/
final class MeshCaCertificateProviderProvider implements CertificateProviderProvider {
private static final String MESHCA_URL_KEY = "meshCaUrl";
private static final String RPC_TIMEOUT_SECONDS_KEY = "rpcTimeoutSeconds";
private static final String GKECLUSTER_URL_KEY = "gkeClusterUrl";
private static final String CERT_VALIDITY_SECONDS_KEY = "certValiditySeconds";
private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewalGracePeriodSeconds";
private static final String KEY_ALGO_KEY = "keyAlgo"; // aka keyType
private static final String KEY_SIZE_KEY = "keySize";
private static final String SIGNATURE_ALGO_KEY = "signatureAlgo";
private static final String MAX_RETRY_ATTEMPTS_KEY = "maxRetryAttempts";
private static final String STS_URL_KEY = "stsUrl";
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
static final int KEY_SIZE_DEFAULT = 2048;
static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
private static final Pattern CLUSTER_URL_PATTERN = Pattern
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog";
private static final String AUDIENCE_PREFIX = "identitynamespace:";
static final String MESH_CA_NAME = "meshCA";
static {
CertificateProviderRegistry.getInstance()
.register(
new MeshCaCertificateProviderProvider(
StsCredentials.Factory.getInstance(),
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
new ExponentialBackoffPolicy.Provider(),
MeshCaCertificateProvider.Factory.getInstance()));
}
final StsCredentials.Factory stsCredentialsFactory;
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
final BackoffPolicy.Provider backoffPolicyProvider;
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
@VisibleForTesting
MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
this.stsCredentialsFactory = stsCredentialsFactory;
this.meshCaChannelFactory = meshCaChannelFactory;
this.backoffPolicyProvider = backoffPolicyProvider;
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
}
@Override
public String getName() {
return MESH_CA_NAME;
}
@Override
public CertificateProvider createCertificateProvider(
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
Config configObj = validateAndTranslateConfig(config);
// Construct audience from project and gkeClusterUrl
String audience =
AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl;
StsCredentials stsCredentials = stsCredentialsFactory
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
configObj.zone,
configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
configObj.signatureAlgo,
meshCaChannelFactory, backoffPolicyProvider,
configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
}
private static Config validateAndTranslateConfig(Object config) {
// TODO(sanjaypujare): add support for string, struct proto etc
checkArgument(config instanceof Map, "Only Map supported for config");
@SuppressWarnings("unchecked") Map<String, String> map = (Map<String, String>)config;
Config configObj = new Config();
configObj.meshCaUrl = mapGetOrDefault(map, MESHCA_URL_KEY, MESHCA_URL_DEFAULT);
configObj.rpcTimeoutSeconds =
mapGetOrDefault(map, RPC_TIMEOUT_SECONDS_KEY, RPC_TIMEOUT_SECONDS_DEFAULT);
configObj.gkeClusterUrl =
checkNotNull(
map.get(GKECLUSTER_URL_KEY), GKECLUSTER_URL_KEY + " is required in the config");
configObj.certValiditySeconds =
mapGetOrDefault(map, CERT_VALIDITY_SECONDS_KEY, CERT_VALIDITY_SECONDS_DEFAULT);
configObj.renewalGracePeriodSeconds =
mapGetOrDefault(
map, RENEWAL_GRACE_PERIOD_SECONDS_KEY, RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT);
configObj.keyAlgo = mapGetOrDefault(map, KEY_ALGO_KEY, KEY_ALGO_DEFAULT);
configObj.keySize = mapGetOrDefault(map, KEY_SIZE_KEY, KEY_SIZE_DEFAULT);
configObj.signatureAlgo = mapGetOrDefault(map, SIGNATURE_ALGO_KEY, SIGNATURE_ALGO_DEFAULT);
configObj.maxRetryAttempts =
mapGetOrDefault(map, MAX_RETRY_ATTEMPTS_KEY, MAX_RETRY_ATTEMPTS_DEFAULT);
configObj.stsUrl = mapGetOrDefault(map, STS_URL_KEY, STS_URL_DEFAULT);
configObj.gkeSaJwtLocation =
checkNotNull(
map.get(GKE_SA_JWT_LOCATION_KEY),
GKE_SA_JWT_LOCATION_KEY + " is required in the config");
parseProjectAndZone(configObj.gkeClusterUrl, configObj);
return configObj;
}
private static String mapGetOrDefault(Map<String, String> map, String key, String defaultVal) {
String value = map.get(key);
if (value == null) {
return defaultVal;
}
return value;
}
private static Long mapGetOrDefault(Map<String, String> map, String key, long defaultVal) {
String value = map.get(key);
if (value == null) {
return defaultVal;
}
return Long.parseLong(value);
}
private static Integer mapGetOrDefault(Map<String, String> map, String key, int defaultVal) {
String value = map.get(key);
if (value == null) {
return defaultVal;
}
return Integer.parseInt(value);
}
private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) {
Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl);
checkState(matcher.find(), "gkeClusterUrl does not have correct format");
checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts");
configObj.project = matcher.group(1);
configObj.zone = matcher.group(2);
}
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {
String meshCaUrl;
Long rpcTimeoutSeconds;
String gkeClusterUrl;
Long certValiditySeconds;
Long renewalGracePeriodSeconds;
String keyAlgo; // aka keyType
Integer keySize;
String signatureAlgo;
Integer maxRetryAttempts;
String stsUrl;
String gkeSaJwtLocation;
String zone;
String project;
}
}

View File

@ -47,14 +47,14 @@ import java.util.Map;
public final class StsCredentials extends GoogleCredentials { public final class StsCredentials extends GoogleCredentials {
private static final long serialVersionUID = 6647041424685484932L; private static final long serialVersionUID = 6647041424685484932L;
private static final HttpTransportFactory defaultHttpTransportFactory = @VisibleForTesting static final HttpTransportFactory defaultHttpTransportFactory =
new DefaultHttpTransportFactory(); new DefaultHttpTransportFactory();
private static final String CLOUD_PLATFORM_SCOPE = private static final String CLOUD_PLATFORM_SCOPE =
"https://www.googleapis.com/auth/cloud-platform"; "https://www.googleapis.com/auth/cloud-platform";
private final String sourceCredentialsFileLocation; @VisibleForTesting final String sourceCredentialsFileLocation;
private final String identityTokenEndpoint; @VisibleForTesting final String identityTokenEndpoint;
private final String audience; @VisibleForTesting final String audience;
private transient HttpTransportFactory transportFactory; @VisibleForTesting transient HttpTransportFactory transportFactory;
private StsCredentials( private StsCredentials(
String identityTokenEndpoint, String identityTokenEndpoint,
@ -67,33 +67,6 @@ public final class StsCredentials extends GoogleCredentials {
this.transportFactory = transportFactory; this.transportFactory = transportFactory;
} }
/**
* Creates an StsCredentials.
*
* @param identityTokenEndpoint URL of the token exchange service to use.
* @param audience Audience to use in the STS request.
* @param sourceCredentialsFileLocation file-system location that contains the
* source creds e.g. JWT contents.
*/
public static StsCredentials create(
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
return create(
identityTokenEndpoint,
audience,
sourceCredentialsFileLocation,
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
}
@VisibleForTesting
static StsCredentials create(
String identityTokenEndpoint,
String audience,
String sourceCredentialsFileLocation,
HttpTransportFactory transportFactory) {
return new StsCredentials(
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
}
@Override @Override
public AccessToken refreshAccessToken() throws IOException { public AccessToken refreshAccessToken() throws IOException {
AccessToken tok = getSourceAccessTokenFromFileLocation(); AccessToken tok = getSourceAccessTokenFromFileLocation();
@ -157,6 +130,48 @@ public final class StsCredentials extends GoogleCredentials {
throw new UnsupportedOperationException("toBuilder not supported"); throw new UnsupportedOperationException("toBuilder not supported");
} }
/** Factory for creating StsCredentials. */
public abstract static class Factory {
private static final Factory DEFAULT_INSTANCE =
new Factory() {
@Override
public StsCredentials create(
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
return create(
identityTokenEndpoint,
audience,
sourceCredentialsFileLocation,
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
}
};
public static Factory getInstance() {
return DEFAULT_INSTANCE;
}
/**
* Creates an StsCredentials.
*
* @param identityTokenEndpoint URL of the token exchange service to use.
* @param audience Audience to use in the STS request.
* @param sourceCredentialsFileLocation file-system location that contains the
* source creds e.g. JWT contents.
*/
public abstract StsCredentials create(
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation);
@VisibleForTesting
static StsCredentials create(
String identityTokenEndpoint,
String audience,
String sourceCredentialsFileLocation,
HttpTransportFactory transportFactory) {
return new StsCredentials(
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
}
}
private static class DefaultHttpTransportFactory implements HttpTransportFactory { private static class DefaultHttpTransportFactory implements HttpTransportFactory {
private static final HttpTransport netHttpTransport = new NetHttpTransport(); private static final HttpTransport netHttpTransport = new NetHttpTransport();

View File

@ -53,6 +53,7 @@ public class CertificateProviderStoreTest {
Object config; Object config;
CertificateProviderProvider certProviderProvider; CertificateProviderProvider certProviderProvider;
int closeCalled = 0; int closeCalled = 0;
int startCalled = 0;
protected TestCertificateProvider( protected TestCertificateProvider(
CertificateProvider.DistributorWatcher watcher, CertificateProvider.DistributorWatcher watcher,
@ -71,6 +72,11 @@ public class CertificateProviderStoreTest {
public void close() { public void close() {
closeCalled++; closeCalled++;
} }
@Override
public void start() {
startCalled++;
}
} }
@Before @Before
@ -161,6 +167,7 @@ public class CertificateProviderStoreTest {
assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class); assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class);
TestCertificateProvider testCertificateProvider = TestCertificateProvider testCertificateProvider =
(TestCertificateProvider) handle1.certProvider; (TestCertificateProvider) handle1.certProvider;
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher(); CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
assertThat(distWatcher.downsstreamWatchers).hasSize(2); assertThat(distWatcher.downsstreamWatchers).hasSize(2);
PrivateKey testKey = mock(PrivateKey.class); PrivateKey testKey = mock(PrivateKey.class);
@ -335,6 +342,8 @@ public class CertificateProviderStoreTest {
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2)); verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2));
verify(mockWatcher1, never()) verify(mockWatcher1, never())
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class)); .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
assertThat(testCertificateProvider1.startCalled).isEqualTo(1);
assertThat(testCertificateProvider2.startCalled).isEqualTo(1);
handle2.close(); handle2.close();
assertThat(testCertificateProvider2.closeCalled).isEqualTo(1); assertThat(testCertificateProvider2.closeCalled).isEqualTo(1);
handle1.close(); handle1.close();

View File

@ -0,0 +1,212 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.auth.oauth2.GoogleCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.xds.internal.sts.StsCredentials;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Unit tests for {@link MeshCaCertificateProviderProvider}. */
@RunWith(JUnit4.class)
public class MeshCaCertificateProviderProviderTest {
public static final String EXPECTED_AUDIENCE =
"identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
public static final String TMP_PATH_4 = "/tmp/path4";
public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl";
public static final String GKE_CLUSTER_URL =
"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
@Mock
StsCredentials.Factory stsCredentialsFactory;
@Mock
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
@Mock
BackoffPolicy.Provider backoffPolicyProvider;
@Mock
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
private MeshCaCertificateProviderProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
provider =
new MeshCaCertificateProviderProvider(
stsCredentialsFactory,
meshCaChannelFactory,
backoffPolicyProvider,
meshCaCertificateProviderFactory);
}
@Test
public void providerRegisteredName() {
CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance()
.getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME);
assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class);
MeshCaCertificateProviderProvider meshCaCertificateProviderProvider =
(MeshCaCertificateProviderProvider) certProviderProvider;
assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory)
.isSameInstanceAs(StsCredentials.Factory.getInstance());
assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory)
.isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance());
assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider)
.isInstanceOf(ExponentialBackoffPolicy.Provider.class);
assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory)
.isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance());
}
@Test
public void createProvider_minimalConfig() {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT),
eq(EXPECTED_AUDIENCE),
eq(TMP_PATH_4));
verify(meshCaCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT),
eq("test-zone2"),
eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT),
eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT),
eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT),
eq(meshCaChannelFactory),
eq(backoffPolicyProvider),
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
(GoogleCredentials) isNull());
}
@Test
public void createProvider_missingGkeUrl_expectException() {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
map.remove("gkeClusterUrl");
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("gkeClusterUrl is required in the config");
}
}
@Test
public void createProvider_missingGkeSaJwtLocation_expectException() {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
map.remove("gkeSaJwtLocation");
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (NullPointerException npe) {
assertThat(npe).hasMessageThat().isEqualTo("gkeSaJwtLocation is required in the config");
}
}
@Test
public void createProvider_missingProject_expectException() {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
} catch (IllegalStateException ex) {
assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format");
}
}
@Test
public void createProvider_nonDefaultFullConfig() {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildFullMap();
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
eq("nonDefaultStsUrl"),
eq(EXPECTED_AUDIENCE),
eq(TMP_PATH_4));
verify(meshCaCertificateProviderFactory, times(1))
.create(
eq(distWatcher),
eq(true),
eq(NON_DEFAULT_MESH_CA_URL),
eq("test-zone2"),
eq(234567L),
eq(4096),
eq("KEY-ALGO1"),
eq("SIG-ALGO2"),
eq(meshCaChannelFactory),
eq(backoffPolicyProvider),
eq(4321L),
eq(9),
(GoogleCredentials) isNull());
}
private Map<String, String> buildFullMap() {
Map<String, String> map = new HashMap<>();
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
map.put("gkeSaJwtLocation", TMP_PATH_4);
map.put("meshCaUrl", NON_DEFAULT_MESH_CA_URL);
map.put("rpcTimeoutSeconds", "123");
map.put("certValiditySeconds", "234567");
map.put("renewalGracePeriodSeconds", "4321");
map.put("keyAlgo", "KEY-ALGO1");
map.put("keySize", "4096");
map.put("signatureAlgo", "SIG-ALGO2");
map.put("maxRetryAttempts", "9");
map.put("stsUrl", "nonDefaultStsUrl");
return map;
}
private Map<String, String> buildMinimalMap() {
Map<String, String> map = new HashMap<>();
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
map.put("gkeSaJwtLocation", TMP_PATH_4);
return map;
}
}

View File

@ -83,7 +83,7 @@ public class StsCredentialsTest {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport); when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials = StsCredentials stsCredentials =
StsCredentials.create( StsCredentials.Factory.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
AccessToken token = stsCredentials.refreshAccessToken(); AccessToken token = stsCredentials.refreshAccessToken();
assertThat(token).isNotNull(); assertThat(token).isNotNull();
@ -115,7 +115,7 @@ public class StsCredentialsTest {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport); when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials = StsCredentials stsCredentials =
StsCredentials.create( StsCredentials.Factory.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
CallCredentials callCreds = MoreCallCredentials.from(stsCredentials); CallCredentials callCreds = MoreCallCredentials.from(stsCredentials);
CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class); CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class);
@ -150,7 +150,7 @@ public class StsCredentialsTest {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport); when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials = StsCredentials stsCredentials =
StsCredentials.create( StsCredentials.Factory.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try { try {
stsCredentials.refreshAccessToken(); stsCredentials.refreshAccessToken();
@ -171,7 +171,7 @@ public class StsCredentialsTest {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport); when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials = StsCredentials stsCredentials =
StsCredentials.create( StsCredentials.Factory.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try { try {
stsCredentials.refreshAccessToken(); stsCredentials.refreshAccessToken();
@ -185,7 +185,7 @@ public class StsCredentialsTest {
public void toBuilder_unsupportedException() { public void toBuilder_unsupportedException() {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class); HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
StsCredentials stsCredentials = StsCredentials stsCredentials =
StsCredentials.create( StsCredentials.Factory.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory); STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try { try {
stsCredentials.toBuilder(); stsCredentials.toBuilder();
@ -195,6 +195,13 @@ public class StsCredentialsTest {
} }
} }
@Test
public void defaultFactory() {
StsCredentials stsCreds = StsCredentials.Factory.getInstance()
.create(STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath());
assertThat(stsCreds.transportFactory).isEqualTo(StsCredentials.defaultHttpTransportFactory);
}
private static final String ACCESS_TOKEN = "eyJhbGciOiJSU"; private static final String ACCESS_TOKEN = "eyJhbGciOiJSU";
private static final String MOCK_RESPONSE = private static final String MOCK_RESPONSE =
"{\"access_token\": \"" "{\"access_token\": \""