mirror of https://github.com/grpc/grpc-java.git
xds: first part of MeshCaCertificateProvider (#7247)
This commit is contained in:
parent
d2182fe197
commit
06ca927a64
|
|
@ -124,6 +124,9 @@ public abstract class CertificateProvider implements Closeable {
|
||||||
@Override
|
@Override
|
||||||
public abstract void close();
|
public abstract void close();
|
||||||
|
|
||||||
|
/** Starts the cert refresh and watcher update cycle. */
|
||||||
|
public abstract void start();
|
||||||
|
|
||||||
private final DistributorWatcher watcher;
|
private final DistributorWatcher watcher;
|
||||||
private final boolean notifyCertUpdates;
|
private final boolean notifyCertUpdates;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -131,8 +131,10 @@ public final class CertificateProviderStore {
|
||||||
if (certProviderProvider == null) {
|
if (certProviderProvider == null) {
|
||||||
throw new IllegalArgumentException("Provider not found.");
|
throw new IllegalArgumentException("Provider not found.");
|
||||||
}
|
}
|
||||||
return certProviderProvider.createCertificateProvider(
|
CertificateProvider certProvider = certProviderProvider.createCertificateProvider(
|
||||||
key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates);
|
key.config, new CertificateProvider.DistributorWatcher(), key.notifyCertUpdates);
|
||||||
|
certProvider.start();
|
||||||
|
return certProvider;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2020 The gRPC Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.grpc.xds.internal.certprovider;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
|
|
||||||
|
import com.google.auth.oauth2.GoogleCredentials;
|
||||||
|
import io.grpc.ManagedChannel;
|
||||||
|
import io.grpc.ManagedChannelBuilder;
|
||||||
|
import io.grpc.internal.BackoffPolicy;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.logging.Level;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
|
||||||
|
final class MeshCaCertificateProvider extends CertificateProvider {
|
||||||
|
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
|
||||||
|
|
||||||
|
protected MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
|
||||||
|
String meshCaUrl, String zone, long validitySeconds,
|
||||||
|
int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
|
||||||
|
BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
|
||||||
|
int maxRetryAttempts, GoogleCredentials oauth2Creds) {
|
||||||
|
super(watcher, notifyCertUpdates);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() {
|
||||||
|
// TODO implement
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
// TODO implement
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Factory for creating channels to MeshCA sever. */
|
||||||
|
abstract static class MeshCaChannelFactory {
|
||||||
|
|
||||||
|
private static final MeshCaChannelFactory DEFAULT_INSTANCE =
|
||||||
|
new MeshCaChannelFactory() {
|
||||||
|
|
||||||
|
/** Creates a channel to the URL in the given list. */
|
||||||
|
@Override
|
||||||
|
ManagedChannel createChannel(String serverUri) {
|
||||||
|
checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!");
|
||||||
|
logger.log(Level.INFO, "Creating channel to {0}", serverUri);
|
||||||
|
|
||||||
|
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(serverUri);
|
||||||
|
return channelBuilder.keepAliveTime(1, TimeUnit.MINUTES).build();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static MeshCaChannelFactory getInstance() {
|
||||||
|
return DEFAULT_INSTANCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a channel to the server.
|
||||||
|
*/
|
||||||
|
abstract ManagedChannel createChannel(String serverUri);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Factory for creating channels to MeshCA sever. */
|
||||||
|
abstract static class Factory {
|
||||||
|
private static final Factory DEFAULT_INSTANCE =
|
||||||
|
new Factory() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
MeshCaCertificateProvider create(
|
||||||
|
DistributorWatcher watcher,
|
||||||
|
boolean notifyCertUpdates,
|
||||||
|
String meshCaUrl,
|
||||||
|
String zone,
|
||||||
|
long validitySeconds,
|
||||||
|
int keySize,
|
||||||
|
String alg,
|
||||||
|
String signatureAlg,
|
||||||
|
MeshCaChannelFactory meshCaChannelFactory,
|
||||||
|
BackoffPolicy.Provider backoffPolicyProvider,
|
||||||
|
long renewalGracePeriodSeconds,
|
||||||
|
int maxRetryAttempts,
|
||||||
|
GoogleCredentials oauth2Creds) {
|
||||||
|
return new MeshCaCertificateProvider(
|
||||||
|
watcher,
|
||||||
|
notifyCertUpdates,
|
||||||
|
meshCaUrl,
|
||||||
|
zone,
|
||||||
|
validitySeconds,
|
||||||
|
keySize,
|
||||||
|
alg,
|
||||||
|
signatureAlg,
|
||||||
|
meshCaChannelFactory,
|
||||||
|
backoffPolicyProvider,
|
||||||
|
renewalGracePeriodSeconds,
|
||||||
|
maxRetryAttempts,
|
||||||
|
oauth2Creds);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static Factory getInstance() {
|
||||||
|
return DEFAULT_INSTANCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract MeshCaCertificateProvider create(
|
||||||
|
DistributorWatcher watcher,
|
||||||
|
boolean notifyCertUpdates,
|
||||||
|
String meshCaUrl,
|
||||||
|
String zone,
|
||||||
|
long validitySeconds,
|
||||||
|
int keySize,
|
||||||
|
String alg,
|
||||||
|
String signatureAlg,
|
||||||
|
MeshCaChannelFactory meshCaChannelFactory,
|
||||||
|
BackoffPolicy.Provider backoffPolicyProvider,
|
||||||
|
long renewalGracePeriodSeconds,
|
||||||
|
int maxRetryAttempts,
|
||||||
|
GoogleCredentials oauth2Creds);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,197 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2020 The gRPC Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.grpc.xds.internal.certprovider;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
|
import static com.google.common.base.Preconditions.checkNotNull;
|
||||||
|
import static com.google.common.base.Preconditions.checkState;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import io.grpc.internal.BackoffPolicy;
|
||||||
|
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||||
|
import io.grpc.xds.internal.sts.StsCredentials;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may
|
||||||
|
* move this out of the internal package and make this an official API in the future.
|
||||||
|
*/
|
||||||
|
final class MeshCaCertificateProviderProvider implements CertificateProviderProvider {
|
||||||
|
|
||||||
|
private static final String MESHCA_URL_KEY = "meshCaUrl";
|
||||||
|
private static final String RPC_TIMEOUT_SECONDS_KEY = "rpcTimeoutSeconds";
|
||||||
|
private static final String GKECLUSTER_URL_KEY = "gkeClusterUrl";
|
||||||
|
private static final String CERT_VALIDITY_SECONDS_KEY = "certValiditySeconds";
|
||||||
|
private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewalGracePeriodSeconds";
|
||||||
|
private static final String KEY_ALGO_KEY = "keyAlgo"; // aka keyType
|
||||||
|
private static final String KEY_SIZE_KEY = "keySize";
|
||||||
|
private static final String SIGNATURE_ALGO_KEY = "signatureAlgo";
|
||||||
|
private static final String MAX_RETRY_ATTEMPTS_KEY = "maxRetryAttempts";
|
||||||
|
private static final String STS_URL_KEY = "stsUrl";
|
||||||
|
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
|
||||||
|
|
||||||
|
static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
|
||||||
|
static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
|
||||||
|
static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
|
||||||
|
static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
|
||||||
|
static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
|
||||||
|
static final int KEY_SIZE_DEFAULT = 2048;
|
||||||
|
static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
|
||||||
|
static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
|
||||||
|
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
|
||||||
|
|
||||||
|
private static final Pattern CLUSTER_URL_PATTERN = Pattern
|
||||||
|
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
|
||||||
|
|
||||||
|
private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog";
|
||||||
|
private static final String AUDIENCE_PREFIX = "identitynamespace:";
|
||||||
|
static final String MESH_CA_NAME = "meshCA";
|
||||||
|
|
||||||
|
static {
|
||||||
|
CertificateProviderRegistry.getInstance()
|
||||||
|
.register(
|
||||||
|
new MeshCaCertificateProviderProvider(
|
||||||
|
StsCredentials.Factory.getInstance(),
|
||||||
|
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
|
||||||
|
new ExponentialBackoffPolicy.Provider(),
|
||||||
|
MeshCaCertificateProvider.Factory.getInstance()));
|
||||||
|
}
|
||||||
|
|
||||||
|
final StsCredentials.Factory stsCredentialsFactory;
|
||||||
|
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
|
||||||
|
final BackoffPolicy.Provider backoffPolicyProvider;
|
||||||
|
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
|
||||||
|
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
|
||||||
|
BackoffPolicy.Provider backoffPolicyProvider,
|
||||||
|
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
|
||||||
|
this.stsCredentialsFactory = stsCredentialsFactory;
|
||||||
|
this.meshCaChannelFactory = meshCaChannelFactory;
|
||||||
|
this.backoffPolicyProvider = backoffPolicyProvider;
|
||||||
|
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return MESH_CA_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CertificateProvider createCertificateProvider(
|
||||||
|
Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) {
|
||||||
|
|
||||||
|
Config configObj = validateAndTranslateConfig(config);
|
||||||
|
|
||||||
|
// Construct audience from project and gkeClusterUrl
|
||||||
|
String audience =
|
||||||
|
AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl;
|
||||||
|
StsCredentials stsCredentials = stsCredentialsFactory
|
||||||
|
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
|
||||||
|
|
||||||
|
return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
|
||||||
|
configObj.zone,
|
||||||
|
configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
|
||||||
|
configObj.signatureAlgo,
|
||||||
|
meshCaChannelFactory, backoffPolicyProvider,
|
||||||
|
configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Config validateAndTranslateConfig(Object config) {
|
||||||
|
// TODO(sanjaypujare): add support for string, struct proto etc
|
||||||
|
checkArgument(config instanceof Map, "Only Map supported for config");
|
||||||
|
@SuppressWarnings("unchecked") Map<String, String> map = (Map<String, String>)config;
|
||||||
|
|
||||||
|
Config configObj = new Config();
|
||||||
|
configObj.meshCaUrl = mapGetOrDefault(map, MESHCA_URL_KEY, MESHCA_URL_DEFAULT);
|
||||||
|
configObj.rpcTimeoutSeconds =
|
||||||
|
mapGetOrDefault(map, RPC_TIMEOUT_SECONDS_KEY, RPC_TIMEOUT_SECONDS_DEFAULT);
|
||||||
|
configObj.gkeClusterUrl =
|
||||||
|
checkNotNull(
|
||||||
|
map.get(GKECLUSTER_URL_KEY), GKECLUSTER_URL_KEY + " is required in the config");
|
||||||
|
configObj.certValiditySeconds =
|
||||||
|
mapGetOrDefault(map, CERT_VALIDITY_SECONDS_KEY, CERT_VALIDITY_SECONDS_DEFAULT);
|
||||||
|
configObj.renewalGracePeriodSeconds =
|
||||||
|
mapGetOrDefault(
|
||||||
|
map, RENEWAL_GRACE_PERIOD_SECONDS_KEY, RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT);
|
||||||
|
configObj.keyAlgo = mapGetOrDefault(map, KEY_ALGO_KEY, KEY_ALGO_DEFAULT);
|
||||||
|
configObj.keySize = mapGetOrDefault(map, KEY_SIZE_KEY, KEY_SIZE_DEFAULT);
|
||||||
|
configObj.signatureAlgo = mapGetOrDefault(map, SIGNATURE_ALGO_KEY, SIGNATURE_ALGO_DEFAULT);
|
||||||
|
configObj.maxRetryAttempts =
|
||||||
|
mapGetOrDefault(map, MAX_RETRY_ATTEMPTS_KEY, MAX_RETRY_ATTEMPTS_DEFAULT);
|
||||||
|
configObj.stsUrl = mapGetOrDefault(map, STS_URL_KEY, STS_URL_DEFAULT);
|
||||||
|
configObj.gkeSaJwtLocation =
|
||||||
|
checkNotNull(
|
||||||
|
map.get(GKE_SA_JWT_LOCATION_KEY),
|
||||||
|
GKE_SA_JWT_LOCATION_KEY + " is required in the config");
|
||||||
|
parseProjectAndZone(configObj.gkeClusterUrl, configObj);
|
||||||
|
return configObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String mapGetOrDefault(Map<String, String> map, String key, String defaultVal) {
|
||||||
|
String value = map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return defaultVal;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Long mapGetOrDefault(Map<String, String> map, String key, long defaultVal) {
|
||||||
|
String value = map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return defaultVal;
|
||||||
|
}
|
||||||
|
return Long.parseLong(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Integer mapGetOrDefault(Map<String, String> map, String key, int defaultVal) {
|
||||||
|
String value = map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return defaultVal;
|
||||||
|
}
|
||||||
|
return Integer.parseInt(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) {
|
||||||
|
Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl);
|
||||||
|
checkState(matcher.find(), "gkeClusterUrl does not have correct format");
|
||||||
|
checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts");
|
||||||
|
configObj.project = matcher.group(1);
|
||||||
|
configObj.zone = matcher.group(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** POJO class for storing various config values. */
|
||||||
|
@VisibleForTesting
|
||||||
|
static class Config {
|
||||||
|
String meshCaUrl;
|
||||||
|
Long rpcTimeoutSeconds;
|
||||||
|
String gkeClusterUrl;
|
||||||
|
Long certValiditySeconds;
|
||||||
|
Long renewalGracePeriodSeconds;
|
||||||
|
String keyAlgo; // aka keyType
|
||||||
|
Integer keySize;
|
||||||
|
String signatureAlgo;
|
||||||
|
Integer maxRetryAttempts;
|
||||||
|
String stsUrl;
|
||||||
|
String gkeSaJwtLocation;
|
||||||
|
String zone;
|
||||||
|
String project;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -47,14 +47,14 @@ import java.util.Map;
|
||||||
public final class StsCredentials extends GoogleCredentials {
|
public final class StsCredentials extends GoogleCredentials {
|
||||||
private static final long serialVersionUID = 6647041424685484932L;
|
private static final long serialVersionUID = 6647041424685484932L;
|
||||||
|
|
||||||
private static final HttpTransportFactory defaultHttpTransportFactory =
|
@VisibleForTesting static final HttpTransportFactory defaultHttpTransportFactory =
|
||||||
new DefaultHttpTransportFactory();
|
new DefaultHttpTransportFactory();
|
||||||
private static final String CLOUD_PLATFORM_SCOPE =
|
private static final String CLOUD_PLATFORM_SCOPE =
|
||||||
"https://www.googleapis.com/auth/cloud-platform";
|
"https://www.googleapis.com/auth/cloud-platform";
|
||||||
private final String sourceCredentialsFileLocation;
|
@VisibleForTesting final String sourceCredentialsFileLocation;
|
||||||
private final String identityTokenEndpoint;
|
@VisibleForTesting final String identityTokenEndpoint;
|
||||||
private final String audience;
|
@VisibleForTesting final String audience;
|
||||||
private transient HttpTransportFactory transportFactory;
|
@VisibleForTesting transient HttpTransportFactory transportFactory;
|
||||||
|
|
||||||
private StsCredentials(
|
private StsCredentials(
|
||||||
String identityTokenEndpoint,
|
String identityTokenEndpoint,
|
||||||
|
|
@ -67,33 +67,6 @@ public final class StsCredentials extends GoogleCredentials {
|
||||||
this.transportFactory = transportFactory;
|
this.transportFactory = transportFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an StsCredentials.
|
|
||||||
*
|
|
||||||
* @param identityTokenEndpoint URL of the token exchange service to use.
|
|
||||||
* @param audience Audience to use in the STS request.
|
|
||||||
* @param sourceCredentialsFileLocation file-system location that contains the
|
|
||||||
* source creds e.g. JWT contents.
|
|
||||||
*/
|
|
||||||
public static StsCredentials create(
|
|
||||||
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
|
|
||||||
return create(
|
|
||||||
identityTokenEndpoint,
|
|
||||||
audience,
|
|
||||||
sourceCredentialsFileLocation,
|
|
||||||
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
|
|
||||||
}
|
|
||||||
|
|
||||||
@VisibleForTesting
|
|
||||||
static StsCredentials create(
|
|
||||||
String identityTokenEndpoint,
|
|
||||||
String audience,
|
|
||||||
String sourceCredentialsFileLocation,
|
|
||||||
HttpTransportFactory transportFactory) {
|
|
||||||
return new StsCredentials(
|
|
||||||
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AccessToken refreshAccessToken() throws IOException {
|
public AccessToken refreshAccessToken() throws IOException {
|
||||||
AccessToken tok = getSourceAccessTokenFromFileLocation();
|
AccessToken tok = getSourceAccessTokenFromFileLocation();
|
||||||
|
|
@ -157,6 +130,48 @@ public final class StsCredentials extends GoogleCredentials {
|
||||||
throw new UnsupportedOperationException("toBuilder not supported");
|
throw new UnsupportedOperationException("toBuilder not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Factory for creating StsCredentials. */
|
||||||
|
public abstract static class Factory {
|
||||||
|
private static final Factory DEFAULT_INSTANCE =
|
||||||
|
new Factory() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StsCredentials create(
|
||||||
|
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
|
||||||
|
return create(
|
||||||
|
identityTokenEndpoint,
|
||||||
|
audience,
|
||||||
|
sourceCredentialsFileLocation,
|
||||||
|
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static Factory getInstance() {
|
||||||
|
return DEFAULT_INSTANCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an StsCredentials.
|
||||||
|
*
|
||||||
|
* @param identityTokenEndpoint URL of the token exchange service to use.
|
||||||
|
* @param audience Audience to use in the STS request.
|
||||||
|
* @param sourceCredentialsFileLocation file-system location that contains the
|
||||||
|
* source creds e.g. JWT contents.
|
||||||
|
*/
|
||||||
|
public abstract StsCredentials create(
|
||||||
|
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation);
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static StsCredentials create(
|
||||||
|
String identityTokenEndpoint,
|
||||||
|
String audience,
|
||||||
|
String sourceCredentialsFileLocation,
|
||||||
|
HttpTransportFactory transportFactory) {
|
||||||
|
return new StsCredentials(
|
||||||
|
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static class DefaultHttpTransportFactory implements HttpTransportFactory {
|
private static class DefaultHttpTransportFactory implements HttpTransportFactory {
|
||||||
|
|
||||||
private static final HttpTransport netHttpTransport = new NetHttpTransport();
|
private static final HttpTransport netHttpTransport = new NetHttpTransport();
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ public class CertificateProviderStoreTest {
|
||||||
Object config;
|
Object config;
|
||||||
CertificateProviderProvider certProviderProvider;
|
CertificateProviderProvider certProviderProvider;
|
||||||
int closeCalled = 0;
|
int closeCalled = 0;
|
||||||
|
int startCalled = 0;
|
||||||
|
|
||||||
protected TestCertificateProvider(
|
protected TestCertificateProvider(
|
||||||
CertificateProvider.DistributorWatcher watcher,
|
CertificateProvider.DistributorWatcher watcher,
|
||||||
|
|
@ -71,6 +72,11 @@ public class CertificateProviderStoreTest {
|
||||||
public void close() {
|
public void close() {
|
||||||
closeCalled++;
|
closeCalled++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() {
|
||||||
|
startCalled++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
|
|
@ -161,6 +167,7 @@ public class CertificateProviderStoreTest {
|
||||||
assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class);
|
assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class);
|
||||||
TestCertificateProvider testCertificateProvider =
|
TestCertificateProvider testCertificateProvider =
|
||||||
(TestCertificateProvider) handle1.certProvider;
|
(TestCertificateProvider) handle1.certProvider;
|
||||||
|
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
|
||||||
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
|
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
|
||||||
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
|
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
|
||||||
PrivateKey testKey = mock(PrivateKey.class);
|
PrivateKey testKey = mock(PrivateKey.class);
|
||||||
|
|
@ -335,6 +342,8 @@ public class CertificateProviderStoreTest {
|
||||||
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2));
|
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2));
|
||||||
verify(mockWatcher1, never())
|
verify(mockWatcher1, never())
|
||||||
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
|
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
|
||||||
|
assertThat(testCertificateProvider1.startCalled).isEqualTo(1);
|
||||||
|
assertThat(testCertificateProvider2.startCalled).isEqualTo(1);
|
||||||
handle2.close();
|
handle2.close();
|
||||||
assertThat(testCertificateProvider2.closeCalled).isEqualTo(1);
|
assertThat(testCertificateProvider2.closeCalled).isEqualTo(1);
|
||||||
handle1.close();
|
handle1.close();
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,212 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2020 The gRPC Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.grpc.xds.internal.certprovider;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.ArgumentMatchers.isNull;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
import com.google.auth.oauth2.GoogleCredentials;
|
||||||
|
import io.grpc.internal.BackoffPolicy;
|
||||||
|
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||||
|
import io.grpc.xds.internal.sts.StsCredentials;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.MockitoAnnotations;
|
||||||
|
|
||||||
|
/** Unit tests for {@link MeshCaCertificateProviderProvider}. */
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public class MeshCaCertificateProviderProviderTest {
|
||||||
|
|
||||||
|
public static final String EXPECTED_AUDIENCE =
|
||||||
|
"identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
|
||||||
|
public static final String TMP_PATH_4 = "/tmp/path4";
|
||||||
|
public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl";
|
||||||
|
public static final String GKE_CLUSTER_URL =
|
||||||
|
"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3";
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
StsCredentials.Factory stsCredentialsFactory;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
BackoffPolicy.Provider backoffPolicyProvider;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||||
|
private MeshCaCertificateProviderProvider provider;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws IOException {
|
||||||
|
MockitoAnnotations.initMocks(this);
|
||||||
|
provider =
|
||||||
|
new MeshCaCertificateProviderProvider(
|
||||||
|
stsCredentialsFactory,
|
||||||
|
meshCaChannelFactory,
|
||||||
|
backoffPolicyProvider,
|
||||||
|
meshCaCertificateProviderFactory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void providerRegisteredName() {
|
||||||
|
CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance()
|
||||||
|
.getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME);
|
||||||
|
assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class);
|
||||||
|
MeshCaCertificateProviderProvider meshCaCertificateProviderProvider =
|
||||||
|
(MeshCaCertificateProviderProvider) certProviderProvider;
|
||||||
|
assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory)
|
||||||
|
.isSameInstanceAs(StsCredentials.Factory.getInstance());
|
||||||
|
assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory)
|
||||||
|
.isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance());
|
||||||
|
assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider)
|
||||||
|
.isInstanceOf(ExponentialBackoffPolicy.Provider.class);
|
||||||
|
assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory)
|
||||||
|
.isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createProvider_minimalConfig() {
|
||||||
|
CertificateProvider.DistributorWatcher distWatcher =
|
||||||
|
new CertificateProvider.DistributorWatcher();
|
||||||
|
Map<String, String> map = buildMinimalMap();
|
||||||
|
provider.createCertificateProvider(map, distWatcher, true);
|
||||||
|
verify(stsCredentialsFactory, times(1))
|
||||||
|
.create(
|
||||||
|
eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT),
|
||||||
|
eq(EXPECTED_AUDIENCE),
|
||||||
|
eq(TMP_PATH_4));
|
||||||
|
verify(meshCaCertificateProviderFactory, times(1))
|
||||||
|
.create(
|
||||||
|
eq(distWatcher),
|
||||||
|
eq(true),
|
||||||
|
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT),
|
||||||
|
eq("test-zone2"),
|
||||||
|
eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT),
|
||||||
|
eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT),
|
||||||
|
eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT),
|
||||||
|
eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT),
|
||||||
|
eq(meshCaChannelFactory),
|
||||||
|
eq(backoffPolicyProvider),
|
||||||
|
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
|
||||||
|
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
|
||||||
|
(GoogleCredentials) isNull());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createProvider_missingGkeUrl_expectException() {
|
||||||
|
CertificateProvider.DistributorWatcher distWatcher =
|
||||||
|
new CertificateProvider.DistributorWatcher();
|
||||||
|
Map<String, String> map = buildMinimalMap();
|
||||||
|
map.remove("gkeClusterUrl");
|
||||||
|
try {
|
||||||
|
provider.createCertificateProvider(map, distWatcher, true);
|
||||||
|
fail("exception expected");
|
||||||
|
} catch (NullPointerException npe) {
|
||||||
|
assertThat(npe).hasMessageThat().isEqualTo("gkeClusterUrl is required in the config");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createProvider_missingGkeSaJwtLocation_expectException() {
|
||||||
|
CertificateProvider.DistributorWatcher distWatcher =
|
||||||
|
new CertificateProvider.DistributorWatcher();
|
||||||
|
Map<String, String> map = buildMinimalMap();
|
||||||
|
map.remove("gkeSaJwtLocation");
|
||||||
|
try {
|
||||||
|
provider.createCertificateProvider(map, distWatcher, true);
|
||||||
|
fail("exception expected");
|
||||||
|
} catch (NullPointerException npe) {
|
||||||
|
assertThat(npe).hasMessageThat().isEqualTo("gkeSaJwtLocation is required in the config");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createProvider_missingProject_expectException() {
|
||||||
|
CertificateProvider.DistributorWatcher distWatcher =
|
||||||
|
new CertificateProvider.DistributorWatcher();
|
||||||
|
Map<String, String> map = buildMinimalMap();
|
||||||
|
map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
|
||||||
|
try {
|
||||||
|
provider.createCertificateProvider(map, distWatcher, true);
|
||||||
|
fail("exception expected");
|
||||||
|
} catch (IllegalStateException ex) {
|
||||||
|
assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void createProvider_nonDefaultFullConfig() {
|
||||||
|
CertificateProvider.DistributorWatcher distWatcher =
|
||||||
|
new CertificateProvider.DistributorWatcher();
|
||||||
|
Map<String, String> map = buildFullMap();
|
||||||
|
provider.createCertificateProvider(map, distWatcher, true);
|
||||||
|
verify(stsCredentialsFactory, times(1))
|
||||||
|
.create(
|
||||||
|
eq("nonDefaultStsUrl"),
|
||||||
|
eq(EXPECTED_AUDIENCE),
|
||||||
|
eq(TMP_PATH_4));
|
||||||
|
verify(meshCaCertificateProviderFactory, times(1))
|
||||||
|
.create(
|
||||||
|
eq(distWatcher),
|
||||||
|
eq(true),
|
||||||
|
eq(NON_DEFAULT_MESH_CA_URL),
|
||||||
|
eq("test-zone2"),
|
||||||
|
eq(234567L),
|
||||||
|
eq(4096),
|
||||||
|
eq("KEY-ALGO1"),
|
||||||
|
eq("SIG-ALGO2"),
|
||||||
|
eq(meshCaChannelFactory),
|
||||||
|
eq(backoffPolicyProvider),
|
||||||
|
eq(4321L),
|
||||||
|
eq(9),
|
||||||
|
(GoogleCredentials) isNull());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> buildFullMap() {
|
||||||
|
Map<String, String> map = new HashMap<>();
|
||||||
|
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
|
||||||
|
map.put("gkeSaJwtLocation", TMP_PATH_4);
|
||||||
|
map.put("meshCaUrl", NON_DEFAULT_MESH_CA_URL);
|
||||||
|
map.put("rpcTimeoutSeconds", "123");
|
||||||
|
map.put("certValiditySeconds", "234567");
|
||||||
|
map.put("renewalGracePeriodSeconds", "4321");
|
||||||
|
map.put("keyAlgo", "KEY-ALGO1");
|
||||||
|
map.put("keySize", "4096");
|
||||||
|
map.put("signatureAlgo", "SIG-ALGO2");
|
||||||
|
map.put("maxRetryAttempts", "9");
|
||||||
|
map.put("stsUrl", "nonDefaultStsUrl");
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> buildMinimalMap() {
|
||||||
|
Map<String, String> map = new HashMap<>();
|
||||||
|
map.put("gkeClusterUrl", GKE_CLUSTER_URL);
|
||||||
|
map.put("gkeSaJwtLocation", TMP_PATH_4);
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -83,7 +83,7 @@ public class StsCredentialsTest {
|
||||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||||
StsCredentials stsCredentials =
|
StsCredentials stsCredentials =
|
||||||
StsCredentials.create(
|
StsCredentials.Factory.create(
|
||||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||||
AccessToken token = stsCredentials.refreshAccessToken();
|
AccessToken token = stsCredentials.refreshAccessToken();
|
||||||
assertThat(token).isNotNull();
|
assertThat(token).isNotNull();
|
||||||
|
|
@ -115,7 +115,7 @@ public class StsCredentialsTest {
|
||||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||||
StsCredentials stsCredentials =
|
StsCredentials stsCredentials =
|
||||||
StsCredentials.create(
|
StsCredentials.Factory.create(
|
||||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||||
CallCredentials callCreds = MoreCallCredentials.from(stsCredentials);
|
CallCredentials callCreds = MoreCallCredentials.from(stsCredentials);
|
||||||
CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class);
|
CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class);
|
||||||
|
|
@ -150,7 +150,7 @@ public class StsCredentialsTest {
|
||||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||||
StsCredentials stsCredentials =
|
StsCredentials stsCredentials =
|
||||||
StsCredentials.create(
|
StsCredentials.Factory.create(
|
||||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||||
try {
|
try {
|
||||||
stsCredentials.refreshAccessToken();
|
stsCredentials.refreshAccessToken();
|
||||||
|
|
@ -171,7 +171,7 @@ public class StsCredentialsTest {
|
||||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||||
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
when(httpTransportFactory.create()).thenReturn(httpTransport);
|
||||||
StsCredentials stsCredentials =
|
StsCredentials stsCredentials =
|
||||||
StsCredentials.create(
|
StsCredentials.Factory.create(
|
||||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||||
try {
|
try {
|
||||||
stsCredentials.refreshAccessToken();
|
stsCredentials.refreshAccessToken();
|
||||||
|
|
@ -185,7 +185,7 @@ public class StsCredentialsTest {
|
||||||
public void toBuilder_unsupportedException() {
|
public void toBuilder_unsupportedException() {
|
||||||
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
|
||||||
StsCredentials stsCredentials =
|
StsCredentials stsCredentials =
|
||||||
StsCredentials.create(
|
StsCredentials.Factory.create(
|
||||||
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
|
||||||
try {
|
try {
|
||||||
stsCredentials.toBuilder();
|
stsCredentials.toBuilder();
|
||||||
|
|
@ -195,6 +195,13 @@ public class StsCredentialsTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void defaultFactory() {
|
||||||
|
StsCredentials stsCreds = StsCredentials.Factory.getInstance()
|
||||||
|
.create(STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath());
|
||||||
|
assertThat(stsCreds.transportFactory).isEqualTo(StsCredentials.defaultHttpTransportFactory);
|
||||||
|
}
|
||||||
|
|
||||||
private static final String ACCESS_TOKEN = "eyJhbGciOiJSU";
|
private static final String ACCESS_TOKEN = "eyJhbGciOiJSU";
|
||||||
private static final String MOCK_RESPONSE =
|
private static final String MOCK_RESPONSE =
|
||||||
"{\"access_token\": \""
|
"{\"access_token\": \""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue