mirror of https://github.com/grpc/grpc-java.git
xds: first part of MeshCaCertificateProvider (#7247)
This commit is contained in:
parent
d2182fe197
commit
06ca927a64
|
|
@ -124,6 +124,9 @@ public abstract class CertificateProvider implements Closeable {
|
|||
@Override
|
||||
public abstract void close();
|
||||
|
||||
/** Starts the cert refresh and watcher update cycle. */
|
||||
public abstract void start();
|
||||
|
||||
private final DistributorWatcher watcher;
|
||||
private final boolean notifyCertUpdates;
|
||||
|
||||
|
|
|
|||
|
|
@ -131,8 +131,10 @@ public final class CertificateProviderStore {
|
|||
if (certProviderProvider == null) {
|
||||
throw new IllegalArgumentException("Provider not found.");
|
||||
}
|
||||
return certProviderProvider.createCertificateProvider(
|
||||
CertificateProvider certProvider = certProviderProvider.createCertificateProvider(
|
||||
key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates);
|
||||
certProvider.start();
|
||||
return certProvider;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
* Copyright 2020 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
|
||||
final class MeshCaCertificateProvider extends CertificateProvider {
|
||||
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
|
||||
|
||||
protected MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
|
||||
String meshCaUrl, String zone, long validitySeconds,
|
||||
int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts, GoogleCredentials oauth2Creds) {
|
||||
super(watcher, notifyCertUpdates);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// TODO implement
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
// TODO implement
|
||||
}
|
||||
|
||||
/** Factory for creating channels to MeshCA sever. */
|
||||
abstract static class MeshCaChannelFactory {
|
||||
|
||||
private static final MeshCaChannelFactory DEFAULT_INSTANCE =
|
||||
new MeshCaChannelFactory() {
|
||||
|
||||
/** Creates a channel to the URL in the given list. */
|
||||
@Override
|
||||
ManagedChannel createChannel(String serverUri) {
|
||||
checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!");
|
||||
logger.log(Level.INFO, "Creating channel to {0}", serverUri);
|
||||
|
||||
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(serverUri);
|
||||
return channelBuilder.keepAliveTime(1, TimeUnit.MINUTES).build();
|
||||
}
|
||||
};
|
||||
|
||||
static MeshCaChannelFactory getInstance() {
|
||||
return DEFAULT_INSTANCE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a channel to the server.
|
||||
*/
|
||||
abstract ManagedChannel createChannel(String serverUri);
|
||||
}
|
||||
|
||||
/** Factory for creating channels to MeshCA sever. */
|
||||
abstract static class Factory {
|
||||
private static final Factory DEFAULT_INSTANCE =
|
||||
new Factory() {
|
||||
|
||||
@Override
|
||||
MeshCaCertificateProvider create(
|
||||
DistributorWatcher watcher,
|
||||
boolean notifyCertUpdates,
|
||||
String meshCaUrl,
|
||||
String zone,
|
||||
long validitySeconds,
|
||||
int keySize,
|
||||
String alg,
|
||||
String signatureAlg,
|
||||
MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts,
|
||||
GoogleCredentials oauth2Creds) {
|
||||
return new MeshCaCertificateProvider(
|
||||
watcher,
|
||||
notifyCertUpdates,
|
||||
meshCaUrl,
|
||||
zone,
|
||||
validitySeconds,
|
||||
keySize,
|
||||
alg,
|
||||
signatureAlg,
|
||||
meshCaChannelFactory,
|
||||
backoffPolicyProvider,
|
||||
renewalGracePeriodSeconds,
|
||||
maxRetryAttempts,
|
||||
oauth2Creds);
|
||||
}
|
||||
};
|
||||
|
||||
static Factory getInstance() {
|
||||
return DEFAULT_INSTANCE;
|
||||
}
|
||||
|
||||
abstract MeshCaCertificateProvider create(
|
||||
DistributorWatcher watcher,
|
||||
boolean notifyCertUpdates,
|
||||
String meshCaUrl,
|
||||
String zone,
|
||||
long validitySeconds,
|
||||
int keySize,
|
||||
String alg,
|
||||
String signatureAlg,
|
||||
MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts,
|
||||
GoogleCredentials oauth2Creds);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Copyright 2020 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||
import io.grpc.xds.internal.sts.StsCredentials;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
|
||||
* move this out of the internal package and make this an official API in the future.
|
||||
*/
|
||||
final class MeshCaCertificateProviderProvider implements CertificateProviderProvider {
|
||||
|
||||
private static final String MESHCA_URL_KEY = "meshCaUrl";
|
||||
private static final String RPC_TIMEOUT_SECONDS_KEY = "rpcTimeoutSeconds";
|
||||
private static final String GKECLUSTER_URL_KEY = "gkeClusterUrl";
|
||||
private static final String CERT_VALIDITY_SECONDS_KEY = "certValiditySeconds";
|
||||
private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewalGracePeriodSeconds";
|
||||
private static final String KEY_ALGO_KEY = "keyAlgo"; // aka keyType
|
||||
private static final String KEY_SIZE_KEY = "keySize";
|
||||
private static final String SIGNATURE_ALGO_KEY = "signatureAlgo";
|
||||
private static final String MAX_RETRY_ATTEMPTS_KEY = "maxRetryAttempts";
|
||||
private static final String STS_URL_KEY = "stsUrl";
|
||||
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
|
||||
|
||||
static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
|
||||
static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
|
||||
static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
|
||||
static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
|
||||
static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
|
||||
static final int KEY_SIZE_DEFAULT = 2048;
|
||||
static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
|
||||
static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
|
||||
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
|
||||
|
||||
private static final Pattern CLUSTER_URL_PATTERN = Pattern
|
||||
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
|
||||
|
||||
private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog";
|
||||
private static final String AUDIENCE_PREFIX = "identitynamespace:";
|
||||
static final String MESH_CA_NAME = "meshCA";
|
||||
|
||||
static {
|
||||
CertificateProviderRegistry.getInstance()
|
||||
.register(
|
||||
new MeshCaCertificateProviderProvider(
|
||||
StsCredentials.Factory.getInstance(),
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
|
||||
new ExponentialBackoffPolicy.Provider(),
|
||||
MeshCaCertificateProvider.Factory.getInstance()));
|
||||
}
|
||||
|
||||
final StsCredentials.Factory stsCredentialsFactory;
|
||||
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
|
||||
final BackoffPolicy.Provider backoffPolicyProvider;
|
||||
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||
|
||||
@VisibleForTesting
|
||||
MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
|
||||
this.stsCredentialsFactory = stsCredentialsFactory;
|
||||
this.meshCaChannelFactory = meshCaChannelFactory;
|
||||
this.backoffPolicyProvider = backoffPolicyProvider;
|
||||
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return MESH_CA_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CertificateProvider createCertificateProvider(
|
||||
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
|
||||
|
||||
Config configObj = validateAndTranslateConfig(config);
|
||||
|
||||
// Construct audience from project and gkeClusterUrl
|
||||
String audience =
|
||||
AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl;
|
||||
StsCredentials stsCredentials = stsCredentialsFactory
|
||||
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
|
||||
|
||||
return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
|
||||
configObj.zone,
|
||||
configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
|
||||
configObj.signatureAlgo,
|
||||
meshCaChannelFactory, backoffPolicyProvider,
|
||||
configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
|
||||
}
|
||||
|
||||
private static Config validateAndTranslateConfig(Object config) {
|
||||
// TODO(sanjaypujare): add support for string, struct proto etc
|
||||
checkArgument(config instanceof Map, "Only Map supported for config");
|
||||
@SuppressWarnings("unchecked") Map<String, String> map = (Map<String, String>)config;
|
||||
|
||||
Config configObj = new Config();
|
||||
configObj.meshCaUrl = mapGetOrDefault(map, MESHCA_URL_KEY, MESHCA_URL_DEFAULT);
|
||||
configObj.rpcTimeoutSeconds =
|
||||
mapGetOrDefault(map, RPC_TIMEOUT_SECONDS_KEY, RPC_TIMEOUT_SECONDS_DEFAULT);
|
||||
configObj.gkeClusterUrl =
|
||||
checkNotNull(
|
||||
map.get(GKECLUSTER_URL_KEY), GKECLUSTER_URL_KEY + " is required in the config");
|
||||
configObj.certValiditySeconds =
|
||||
mapGetOrDefault(map, CERT_VALIDITY_SECONDS_KEY, CERT_VALIDITY_SECONDS_DEFAULT);
|
||||
configObj.renewalGracePeriodSeconds =
|
||||
mapGetOrDefault(
|
||||
map, RENEWAL_GRACE_PERIOD_SECONDS_KEY, RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT);
|
||||
configObj.keyAlgo = mapGetOrDefault(map, KEY_ALGO_KEY, KEY_ALGO_DEFAULT);
|
||||
configObj.keySize = mapGetOrDefault(map, KEY_SIZE_KEY, KEY_SIZE_DEFAULT);
|
||||
configObj.signatureAlgo = mapGetOrDefault(map, SIGNATURE_ALGO_KEY, SIGNATURE_ALGO_DEFAULT);
|
||||
configObj.maxRetryAttempts =
|
||||
mapGetOrDefault(map, MAX_RETRY_ATTEMPTS_KEY, MAX_RETRY_ATTEMPTS_DEFAULT);
|
||||
configObj.stsUrl = mapGetOrDefault(map, STS_URL_KEY, STS_URL_DEFAULT);
|
||||
configObj.gkeSaJwtLocation =
|
||||
checkNotNull(
|
||||
map.get(GKE_SA_JWT_LOCATION_KEY),
|
||||
GKE_SA_JWT_LOCATION_KEY + " is required in the config");
|
||||
parseProjectAndZone(configObj.gkeClusterUrl, configObj);
|
||||
return configObj;
|
||||
}
|
||||
|
||||
private static String mapGetOrDefault(Map<String, String> map, String key, String defaultVal) {
|
||||
String value = map.get(key);
|
||||
if (value == null) {
|
||||
return defaultVal;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private static Long mapGetOrDefault(Map<String, String> map, String key, long defaultVal) {
|
||||
String value = map.get(key);
|
||||
if (value == null) {
|
||||
return defaultVal;
|
||||
}
|
||||
return Long.parseLong(value);
|
||||
}
|
||||
|
||||
private static Integer mapGetOrDefault(Map<String, String> map, String key, int defaultVal) {
|
||||
String value = map.get(key);
|
||||
if (value == null) {
|
||||
return defaultVal;
|
||||
}
|
||||
return Integer.parseInt(value);
|
||||
}
|
||||
|
||||
private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) {
|
||||
Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl);
|
||||
checkState(matcher.find(), "gkeClusterUrl does not have correct format");
|
||||
checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts");
|
||||
configObj.project = matcher.group(1);
|
||||
configObj.zone = matcher.group(2);
|
||||
}
|
||||
|
||||
/** POJO class for storing various config values. */
|
||||
@VisibleForTesting
|
||||
static class Config {
|
||||
String meshCaUrl;
|
||||
Long rpcTimeoutSeconds;
|
||||
String gkeClusterUrl;
|
||||
Long certValiditySeconds;
|
||||
Long renewalGracePeriodSeconds;
|
||||
String keyAlgo; // aka keyType
|
||||
Integer keySize;
|
||||
String signatureAlgo;
|
||||
Integer maxRetryAttempts;
|
||||
String stsUrl;
|
||||
String gkeSaJwtLocation;
|
||||
String zone;
|
||||
String project;
|
||||
}
|
||||
}
|
||||
|
|
@ -47,14 +47,14 @@ import java.util.Map;
|
|||
public final class StsCredentials extends GoogleCredentials {
|
||||
private static final long serialVersionUID = 6647041424685484932L;
|
||||
|
||||
private static final HttpTransportFactory defaultHttpTransportFactory =
|
||||
@VisibleForTesting static final HttpTransportFactory defaultHttpTransportFactory =
|
||||
new DefaultHttpTransportFactory();
|
||||
private static final String CLOUD_PLATFORM_SCOPE =
|
||||
"https://www.googleapis.com/auth/cloud-platform";
|
||||
private final String sourceCredentialsFileLocation;
|
||||
private final String identityTokenEndpoint;
|
||||
private final String audience;
|
||||
private transient HttpTransportFactory transportFactory;
|
||||
@VisibleForTesting final String sourceCredentialsFileLocation;
|
||||
@VisibleForTesting final String identityTokenEndpoint;
|
||||
@VisibleForTesting final String audience;
|
||||
@VisibleForTesting transient HttpTransportFactory transportFactory;
|
||||
|
||||
private StsCredentials(
|
||||
String identityTokenEndpoint,
|
||||
|
|
@ -67,33 +67,6 @@ public final class StsCredentials extends GoogleCredentials {
|
|||
this.transportFactory = transportFactory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an StsCredentials.
|
||||
*
|
||||
* @param identityTokenEndpoint URL of the token exchange service to use.
|
||||
* @param audience Audience to use in the STS request.
|
||||
* @param sourceCredentialsFileLocation file-system location that contains the
|
||||
* source creds e.g. JWT contents.
|
||||
*/
|
||||
public static StsCredentials create(
|
||||
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
|
||||
return create(
|
||||
identityTokenEndpoint,
|
||||
audience,
|
||||
sourceCredentialsFileLocation,
|
||||
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static StsCredentials create(
|
||||
String identityTokenEndpoint,
|
||||
String audience,
|
||||
String sourceCredentialsFileLocation,
|
||||
HttpTransportFactory transportFactory) {
|
||||
return new StsCredentials(
|
||||
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AccessToken refreshAccessToken() throws IOException {
|
||||
AccessToken tok = getSourceAccessTokenFromFileLocation();
|
||||
|
|
@ -157,6 +130,48 @@ public final class StsCredentials extends GoogleCredentials {
|
|||
throw new UnsupportedOperationException("toBuilder not supported");
|
||||
}
|
||||
|
||||
/** Factory for creating StsCredentials. */
|
||||
public abstract static class Factory {
|
||||
private static final Factory DEFAULT_INSTANCE =
|
||||
new Factory() {
|
||||
|
||||
@Override
|
||||
public StsCredentials create(
|
||||
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
|
||||
return create(
|
||||
identityTokenEndpoint,
|
||||
audience,
|
||||
sourceCredentialsFileLocation,
|
||||
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
|
||||
}
|
||||
};
|
||||
|
||||
public static Factory getInstance() {
|
||||
return DEFAULT_INSTANCE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an StsCredentials.
|
||||
*
|
||||
* @param identityTokenEndpoint URL of the token exchange service to use.
|
||||
* @param audience Audience to use in the STS request.
|
||||
* @param sourceCredentialsFileLocation file-system location that contains the
|
||||
* source creds e.g. JWT contents.
|
||||
*/
|
||||
public abstract StsCredentials create(
|
||||
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation);
|
||||
|
||||
@VisibleForTesting
|
||||
static StsCredentials create(
|
||||
String identityTokenEndpoint,
|
||||
String audience,
|
||||
String sourceCredentialsFileLocation,
|
||||
HttpTransportFactory transportFactory) {
|
||||
return new StsCredentials(
|
||||
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
|
||||
}
|
||||
}
|
||||
|
||||
private static class DefaultHttpTransportFactory implements HttpTransportFactory {
|
||||
|
||||
private static final HttpTransport netHttpTransport = new NetHttpTransport();
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ public class CertificateProviderStoreTest {
|
|||
Object config;
|
||||
CertificateProviderProvider certProviderProvider;
|
||||
int closeCalled = 0;
|
||||
int startCalled = 0;
|
||||
|
||||
protected TestCertificateProvider(
|
||||
CertificateProvider.DistributorWatcher watcher,
|
||||
|
|
@ -71,6 +72,11 @@ public class CertificateProviderStoreTest {
|
|||
public void close() {
|
||||
closeCalled++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
startCalled++;
|
||||
}
|
||||
}
|
||||
|
||||
@Before
|
||||
|
|
@ -161,6 +167,7 @@ public class CertificateProviderStoreTest {
|
|||
assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class);
|
||||
TestCertificateProvider testCertificateProvider =
|
||||
(TestCertificateProvider) handle1.certProvider;
|
||||
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
|
||||
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
|
||||
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
|
||||
PrivateKey testKey = mock(PrivateKey.class);
|
||||
|
|
@ -335,6 +342,8 @@ public class CertificateProviderStoreTest {
|
|||
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2));
|
||||
verify(mockWatcher1, never())
|
||||
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
|
||||
assertThat(testCertificateProvider1.startCalled).isEqualTo(1);
|
||||
assertThat(testCertificateProvider2.startCalled).isEqualTo(1);
|
||||
handle2.close();
|
||||
assertThat(testCertificateProvider2.closeCalled).isEqualTo(1);
|
||||
handle1.close();
|
||||
|
|
|
|||
|
|
@ -0,0 +1,212 @@
|
|||
/*
|
||||
* Copyright 2020 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.isNull;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||
import io.grpc.xds.internal.sts.StsCredentials;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
|
||||
/** Unit tests for {@link MeshCaCertificateProviderProvider}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class MeshCaCertificateProviderProviderTest {
|
||||
|
||||
public static final String EXPECTED_AUDIENCE =
|
||||
"identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
|
||||
public static final String TMP_PATH_4 = "/tmp/path4";
|
||||
public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl";
|
||||
public static final String GKE_CLUSTER_URL =
|
||||
"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
|
||||
|
||||
@Mock
|
||||
StsCredentials.Factory stsCredentialsFactory;
|
||||
|
||||
@Mock
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
|
||||
|
||||
@Mock
|
||||
BackoffPolicy.Provider backoffPolicyProvider;
|
||||
|
||||
@Mock
|
||||
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||
private MeshCaCertificateProviderProvider provider;
|
||||
|
||||
@Before
|
||||
public void setUp() throws IOException {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
provider =
|
||||
new MeshCaCertificateProviderProvider(
|
||||
stsCredentialsFactory,
|
||||
meshCaChannelFactory,
|
||||
backoffPolicyProvider,
|
||||
meshCaCertificateProviderFactory);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void providerRegisteredName() {
|
||||
CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance()
|
||||
.getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME);
|
||||
assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class);
|
||||
MeshCaCertificateProviderProvider meshCaCertificateProviderProvider =
|
||||
(MeshCaCertificateProviderProvider) certProviderProvider;
|
||||
assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory)
|
||||
.isSameInstanceAs(StsCredentials.Factory.getInstance());
|
||||
assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory)
|
||||
.isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance());
|
||||
assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider)
|
||||
.isInstanceOf(ExponentialBackoffPolicy.Provider.class);
|
||||
assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory)
|
||||
.isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createProvider_minimalConfig() {
|
||||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
verify(stsCredentialsFactory, times(1))
|
||||
.create(
|
||||
eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT),
|
||||
eq(EXPECTED_AUDIENCE),
|
||||
eq(TMP_PATH_4));
|
||||
verify(meshCaCertificateProviderFactory, times(1))
|
||||
.create(
|
||||
eq(distWatcher),
|
||||
eq(true),
|
||||
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT),
|
||||
eq("test-zone2"),
|
||||
eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT),
|
||||
eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT),
|
||||
eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT),
|
||||
eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT),
|
||||
eq(meshCaChannelFactory),
|
||||
eq(backoffPolicyProvider),
|
||||
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
|
||||
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
|
||||
(GoogleCredentials) isNull());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createProvider_missingGkeUrl_expectException() {
|
||||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
map.remove("gkeClusterUrl");
|
||||
try {
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
fail("exception expected");
|
||||
} catch (NullPointerException npe) {
|
||||
assertThat(npe).hasMessageThat().isEqualTo("gkeClusterUrl is required in the config");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createProvider_missingGkeSaJwtLocation_expectException() {
|
||||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
map.remove("gkeSaJwtLocation");
|
||||
try {
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
fail("exception expected");
|
||||
} catch (NullPointerException npe) {
|
||||
assertThat(npe).hasMessageThat().isEqualTo("gkeSaJwtLocation is required in the config");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createProvider_missingProject_expectException() {
|
||||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
|
||||
try {
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
fail("exception expected");
|
||||
} catch (IllegalStateException ex) {
|
||||
assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createProvider_nonDefaultFullConfig() {
|
||||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildFullMap();
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
verify(stsCredentialsFactory, times(1))
|
||||
.create(
|
||||
eq("nonDefaultStsUrl"),
|
||||
eq(EXPECTED_AUDIENCE),
|
||||
eq(TMP_PATH_4));
|
||||
verify(meshCaCertificateProviderFactory, times(1))
|
||||
.create(
|
||||
eq(distWatcher),
|
||||
eq(true),
|
||||
eq(NON_DEFAULT_MESH_CA_URL),
|
||||
eq("test-zone2"),
|
||||
eq(234567L),
|
||||
eq(4096),
|
||||
eq("KEY-ALGO1"),
|
||||
eq("SIG-ALGO2"),
|
||||
eq(meshCaChannelFactory),
|
||||
eq(backoffPolicyProvider),
|
||||
eq(4321L),
|
||||
eq(9),
|
||||
(GoogleCredentials) isNull());
|
||||
}
|
||||
|
||||
private Map<String, String> buildFullMap() {
|
||||
Map<String, String> map = new HashMap<>();
|
||||
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
|
||||
map.put("gkeSaJwtLocation", TMP_PATH_4);
|
||||
map.put("meshCaUrl", NON_DEFAULT_MESH_CA_URL);
|
||||
map.put("rpcTimeoutSeconds", "123");
|
||||
map.put("certValiditySeconds", "234567");
|
||||
map.put("renewalGracePeriodSeconds", "4321");
|
||||
map.put("keyAlgo", "KEY-ALGO1");
|
||||
map.put("keySize", "4096");
|
||||
map.put("signatureAlgo", "SIG-ALGO2");
|
||||
map.put("maxRetryAttempts", "9");
|
||||
map.put("stsUrl", "nonDefaultStsUrl");
|
||||
return map;
|
||||
}
|
||||
|
||||
private Map<String, String> buildMinimalMap() {
|
||||
Map<String, String> map = new HashMap<>();
|
||||
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
|
||||
map.put("gkeSaJwtLocation", TMP_PATH_4);
|
||||
return map;
|
||||
}
|
||||
}
|
||||
|
|
@ -83,7 +83,7 @@ public class StsCredentialsTest {
|
|||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||
StsCredentials stsCredentials =
|
||||
StsCredentials.create(
|
||||
StsCredentials.Factory.create(
|
||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||
AccessToken token = stsCredentials.refreshAccessToken();
|
||||
assertThat(token).isNotNull();
|
||||
|
|
@ -115,7 +115,7 @@ public class StsCredentialsTest {
|
|||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||
StsCredentials stsCredentials =
|
||||
StsCredentials.create(
|
||||
StsCredentials.Factory.create(
|
||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||
CallCredentials callCreds = MoreCallCredentials.from(stsCredentials);
|
||||
CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class);
|
||||
|
|
@ -150,7 +150,7 @@ public class StsCredentialsTest {
|
|||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||
StsCredentials stsCredentials =
|
||||
StsCredentials.create(
|
||||
StsCredentials.Factory.create(
|
||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||
try {
|
||||
stsCredentials.refreshAccessToken();
|
||||
|
|
@ -171,7 +171,7 @@ public class StsCredentialsTest {
|
|||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||
StsCredentials stsCredentials =
|
||||
StsCredentials.create(
|
||||
StsCredentials.Factory.create(
|
||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||
try {
|
||||
stsCredentials.refreshAccessToken();
|
||||
|
|
@ -185,7 +185,7 @@ public class StsCredentialsTest {
|
|||
public void toBuilder_unsupportedException() {
|
||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||
StsCredentials stsCredentials =
|
||||
StsCredentials.create(
|
||||
StsCredentials.Factory.create(
|
||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||
try {
|
||||
stsCredentials.toBuilder();
|
||||
|
|
@ -195,6 +195,13 @@ public class StsCredentialsTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void defaultFactory() {
|
||||
StsCredentials stsCreds = StsCredentials.Factory.getInstance()
|
||||
.create(STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath());
|
||||
assertThat(stsCreds.transportFactory).isEqualTo(StsCredentials.defaultHttpTransportFactory);
|
||||
}
|
||||
|
||||
private static final String ACCESS_TOKEN = "eyJhbGciOiJSU";
|
||||
private static final String MOCK_RESPONSE =
|
||||
"{\"access_token\": \""
|
||||
|
|
|
|||
Loading…
Reference in New Issue