Add `CachedJwtSource` (#116)
Add CachedJwtSource Signed-off-by: Max Lambrecht <max.lambrecht@hpe.com>
This commit is contained in:
parent
5e16f7a632
commit
3ca77c1de2
|
|
@ -64,11 +64,17 @@ public class JwtSvid {
|
|||
*/
|
||||
String token;
|
||||
|
||||
/**
|
||||
* Issued at time of JWT-SVID as present in 'iat' claim.
|
||||
*/
|
||||
Date issuedAt;
|
||||
|
||||
public static final String HEADER_TYP_JWT = "JWT";
|
||||
public static final String HEADER_TYP_JOSE = "JOSE";
|
||||
|
||||
private JwtSvid(final SpiffeId spiffeId,
|
||||
final Set<String> audience,
|
||||
final Date issuedAt,
|
||||
final Date expiry,
|
||||
final Map<String, Object> claims,
|
||||
final String token) {
|
||||
|
|
@ -77,6 +83,7 @@ public class JwtSvid {
|
|||
this.expiry = expiry;
|
||||
this.claims = claims;
|
||||
this.token = token;
|
||||
this.issuedAt = issuedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -120,6 +127,8 @@ public class JwtSvid {
|
|||
val claimsSet = getJwtClaimsSet(signedJwt);
|
||||
validateAudience(claimsSet.getAudience(), audience);
|
||||
|
||||
val issuedAt = claimsSet.getIssueTime();
|
||||
|
||||
val expirationTime = claimsSet.getExpirationTime();
|
||||
validateExpiration(expirationTime);
|
||||
|
||||
|
|
@ -132,7 +141,8 @@ public class JwtSvid {
|
|||
verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
|
||||
|
||||
val claimAudience = new HashSet<>(claimsSet.getAudience());
|
||||
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
|
||||
|
||||
return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -163,13 +173,16 @@ public class JwtSvid {
|
|||
val claimsSet = getJwtClaimsSet(signedJwt);
|
||||
validateAudience(claimsSet.getAudience(), audience);
|
||||
|
||||
val issuedAt = claimsSet.getIssueTime();
|
||||
|
||||
val expirationTime = claimsSet.getExpirationTime();
|
||||
validateExpiration(expirationTime);
|
||||
|
||||
val spiffeId = getSpiffeIdOfSubject(claimsSet);
|
||||
|
||||
val claimAudience = new HashSet<>(claimsSet.getAudience());
|
||||
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
|
||||
|
||||
return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -0,0 +1,338 @@
|
|||
package io.spiffe.workloadapi;
|
||||
|
||||
|
||||
import io.spiffe.bundle.jwtbundle.JwtBundle;
|
||||
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
|
||||
import io.spiffe.bundle.x509bundle.X509Bundle;
|
||||
import io.spiffe.exception.*;
|
||||
import io.spiffe.spiffeid.SpiffeId;
|
||||
import io.spiffe.spiffeid.TrustDomain;
|
||||
import io.spiffe.svid.jwtsvid.JwtSvid;
|
||||
import lombok.NonNull;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.java.Log;
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Date;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.logging.Level;
|
||||
|
||||
import static io.spiffe.workloadapi.internal.ThreadUtils.await;
|
||||
|
||||
/**
|
||||
* Represents a source of SPIFFE JWT SVIDs and JWT bundles maintained via the Workload API.
|
||||
* The JWT SVIDs are cached and fetchJwtSvid methods return from cache
|
||||
* checking that the JWT SVID has still at least half of its lifetime.
|
||||
*/
|
||||
@Log
|
||||
public class CachedJwtSource implements JwtSource {
|
||||
static final String TIMEOUT_SYSTEM_PROPERTY = "spiffe.newJwtSource.timeout";
|
||||
|
||||
static final Duration DEFAULT_TIMEOUT =
|
||||
Duration.parse(System.getProperty(TIMEOUT_SYSTEM_PROPERTY, "PT0S"));
|
||||
|
||||
// Synchronized map of JWT SVIDs, keyed by a pair of SPIFFE ID and a Set of audiences strings.
|
||||
// This map is used to cache the JWT SVIDs and avoid fetching them from the Workload API.
|
||||
private final
|
||||
Map<ImmutablePair<SpiffeId, Set<String>>, List<JwtSvid>> jwtSvids = new ConcurrentHashMap<>();
|
||||
|
||||
private JwtBundleSet bundles;
|
||||
|
||||
private final WorkloadApiClient workloadApiClient;
|
||||
private volatile boolean closed;
|
||||
private Clock clock;
|
||||
|
||||
// private constructor
|
||||
private CachedJwtSource(final WorkloadApiClient workloadApiClient) {
|
||||
this.clock = Clock.systemDefaultZone();
|
||||
this.workloadApiClient = workloadApiClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Cached JWT source. It blocks until the initial update with the JWT bundles
|
||||
* has been received from the Workload API or until the timeout configured
|
||||
* through the system property `spiffe.newJwtSource.timeout` expires.
|
||||
* If no timeout is configured, it blocks until it gets a JWT update from the Workload API.
|
||||
* <p>
|
||||
* It uses the default address socket endpoint from the environment variable to get the Workload API address.
|
||||
*
|
||||
* @return an instance of {@link DefaultJwtSource}, with the JWT bundles initialized
|
||||
* @throws SocketEndpointAddressException if the address to the Workload API is not valid
|
||||
* @throws JwtSourceException if the source could not be initialized
|
||||
*/
|
||||
public static JwtSource newSource() throws JwtSourceException, SocketEndpointAddressException {
|
||||
JwtSourceOptions options = JwtSourceOptions.builder().initTimeout(DEFAULT_TIMEOUT).build();
|
||||
return newSource(options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new JWT source. It blocks until the initial update with the JWT bundles
|
||||
* has been received from the Workload API, doing retries with an exponential backoff policy,
|
||||
* or until the initTimeout has expired.
|
||||
* <p>
|
||||
* If the timeout is not provided in the options, the default timeout is read from the
|
||||
* system property `spiffe.newJwtSource.timeout`. If none is configured, this method will
|
||||
* block until the JWT bundles can be retrieved from the Workload API.
|
||||
* <p>
|
||||
* The {@link WorkloadApiClient} can be provided in the options, if it is not,
|
||||
* a new client is created.
|
||||
*
|
||||
* @param options {@link JwtSourceOptions}
|
||||
* @return an instance of {@link CachedJwtSource}, with the JWT bundles initialized
|
||||
* @throws SocketEndpointAddressException if the address to the Workload API is not valid
|
||||
* @throws JwtSourceException if the source could not be initialized
|
||||
*/
|
||||
public static JwtSource newSource(@NonNull final JwtSourceOptions options)
|
||||
throws SocketEndpointAddressException, JwtSourceException {
|
||||
if (options.getWorkloadApiClient() == null) {
|
||||
options.setWorkloadApiClient(createClient(options));
|
||||
}
|
||||
|
||||
if (options.getInitTimeout() == null) {
|
||||
options.setInitTimeout(DEFAULT_TIMEOUT);
|
||||
}
|
||||
|
||||
CachedJwtSource jwtSource = new CachedJwtSource(options.getWorkloadApiClient());
|
||||
|
||||
try {
|
||||
jwtSource.init(options.getInitTimeout());
|
||||
} catch (Exception e) {
|
||||
jwtSource.close();
|
||||
throw new JwtSourceException("Error creating JWT source", e);
|
||||
}
|
||||
|
||||
return jwtSource;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches a JWT SVID for the given audiences. The JWT SVID is cached and
|
||||
* returned from the cache if it still has at least half of its lifetime.
|
||||
*
|
||||
* @param audience the audience
|
||||
* @param extraAudiences a list of extra audiences as an array of String
|
||||
* @return a {@link JwtSvid}
|
||||
* @throws JwtSvidException
|
||||
*/
|
||||
@Override
|
||||
public JwtSvid fetchJwtSvid(final String audience, final String... extraAudiences) throws JwtSvidException {
|
||||
if (isClosed()) {
|
||||
throw new IllegalStateException("JWT SVID source is closed");
|
||||
}
|
||||
|
||||
return getJwtSvids(audience, extraAudiences).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches a JWT SVID for the given subject and audience. The JWT SVID is cached and
|
||||
* returned from cache if it has still at least half of its lifetime.
|
||||
*
|
||||
* @return a {@link JwtSvid}
|
||||
* @throws IllegalStateException if the source is closed
|
||||
*/
|
||||
@Override
|
||||
public JwtSvid fetchJwtSvid(final SpiffeId subject, final String audience, final String... extraAudiences)
|
||||
throws JwtSvidException {
|
||||
if (isClosed()) {
|
||||
throw new IllegalStateException("JWT SVID source is closed");
|
||||
}
|
||||
|
||||
return getJwtSvids(subject, audience, extraAudiences).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches a list of JWT SVIDs for the given audience. The JWT SVIDs are cached and
|
||||
* returned from cache if they have still at least half of their lifetime.
|
||||
*
|
||||
* @return a list of {@link JwtSvid}s
|
||||
* @throws IllegalStateException if the source is closed
|
||||
*/
|
||||
@Override
|
||||
public List<JwtSvid> fetchJwtSvids(final String audience, final String... extraAudiences) throws JwtSvidException {
|
||||
if (isClosed()) {
|
||||
throw new IllegalStateException("JWT SVID source is closed");
|
||||
}
|
||||
|
||||
return getJwtSvids(audience, extraAudiences);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches a list of JWT SVIDs for the given subject and audience. The JWT SVIDs are cached and
|
||||
* returned from cache if they have still at least half of their lifetime.
|
||||
*
|
||||
* @return a list of {@link JwtSvid}s
|
||||
* @throws IllegalStateException if the source is closed
|
||||
*/
|
||||
@Override
|
||||
public List<JwtSvid> fetchJwtSvids(final SpiffeId subject, final String audience, final String... extraAudiences)
|
||||
throws JwtSvidException {
|
||||
if (isClosed()) {
|
||||
throw new IllegalStateException("JWT SVID source is closed");
|
||||
}
|
||||
|
||||
return getJwtSvids(subject, audience, extraAudiences);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the JWT bundle for a given trust domain.
|
||||
*
|
||||
* @return an instance of a {@link X509Bundle}
|
||||
* @throws BundleNotFoundException is there is no bundle for the trust domain provided
|
||||
* @throws IllegalStateException if the source is closed
|
||||
*/
|
||||
@Override
|
||||
public JwtBundle getBundleForTrustDomain(@NonNull final TrustDomain trustDomain) throws BundleNotFoundException {
|
||||
if (isClosed()) {
|
||||
throw new IllegalStateException("JWT bundle source is closed");
|
||||
}
|
||||
return bundles.getBundleForTrustDomain(trustDomain);
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes this source, dropping the connection to the Workload API.
|
||||
* Other source methods will return an error after close has been called.
|
||||
* <p>
|
||||
* It is marked with {@link SneakyThrows} because it is not expected to throw
|
||||
* the checked exception defined on the {@link Closeable} interface.
|
||||
*/
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void close() {
|
||||
if (!closed) {
|
||||
synchronized (this) {
|
||||
if (!closed) {
|
||||
workloadApiClient.close();
|
||||
closed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the jwtSvids map contains the cacheKey, returns it if it does and the JWT SVID has not passed its half lifetime.
|
||||
// If the cache does not contain the key or the JWT SVID has passed its half lifetime, make a new FetchJWTSVID call to the Workload API,
|
||||
// adds the JWT SVIDs to the cache map and returns them.
|
||||
// Only one thread can fetch new JWT SVIDs and update the cache at a time.
|
||||
private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException {
|
||||
Set<String> audiencesSet = getAudienceSet(audience, extraAudiences);
|
||||
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, audiencesSet);
|
||||
|
||||
List<JwtSvid> svidList = jwtSvids.get(cacheKey);
|
||||
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
|
||||
return svidList;
|
||||
}
|
||||
|
||||
// even using ConcurrentHashMap, there is a possibility of multiple threads trying to fetch new JWT SVIDs at the same time.
|
||||
synchronized (this) {
|
||||
// Check again if the jwtSvids map contains the cacheKey, and return the entry if it exists and the JWT SVID has not passed its half lifetime.
|
||||
// If it does not exist or the JWT-SVID has passed half its lifetime, call the Workload API to fetch new JWT-SVIDs,
|
||||
// add them to the cache map, and return the list of JWT-SVIDs.
|
||||
svidList = jwtSvids.get(cacheKey);
|
||||
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
|
||||
return svidList;
|
||||
}
|
||||
|
||||
if (cacheKey.left == null) {
|
||||
svidList = workloadApiClient.fetchJwtSvids(audience, extraAudiences);
|
||||
} else {
|
||||
svidList = workloadApiClient.fetchJwtSvids(cacheKey.left, audience, extraAudiences);
|
||||
}
|
||||
jwtSvids.put(cacheKey, svidList);
|
||||
return svidList;
|
||||
}
|
||||
}
|
||||
|
||||
private List<JwtSvid> getJwtSvids(String audience, String... extraAudiences) throws JwtSvidException {
|
||||
return getJwtSvids(null, audience, extraAudiences);
|
||||
}
|
||||
|
||||
private static Set<String> getAudienceSet(String audience, String[] extraAudiences) {
|
||||
Set<String> audiencesString;
|
||||
if (extraAudiences != null && extraAudiences.length > 0) {
|
||||
audiencesString = new HashSet<>(Arrays.asList(extraAudiences));
|
||||
audiencesString.add(audience);
|
||||
} else {
|
||||
audiencesString = Collections.singleton(audience);
|
||||
}
|
||||
return audiencesString;
|
||||
}
|
||||
|
||||
private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) {
|
||||
Instant now = clock.instant();
|
||||
val halfLife = new Date(jwtSvid.getExpiry().getTime() - (jwtSvid.getExpiry().getTime() - jwtSvid.getIssuedAt().getTime()) / 2);
|
||||
val halfLifeInstant = Instant.ofEpochMilli(halfLife.getTime());
|
||||
return now.isAfter(halfLifeInstant);
|
||||
}
|
||||
|
||||
|
||||
private void init(final Duration timeout) throws TimeoutException {
|
||||
CountDownLatch done = new CountDownLatch(1);
|
||||
setJwtBundlesWatcher(done);
|
||||
|
||||
boolean success;
|
||||
if (timeout.isZero()) {
|
||||
await(done);
|
||||
success = true;
|
||||
} else {
|
||||
success = await(done, timeout.getSeconds(), TimeUnit.SECONDS);
|
||||
}
|
||||
if (!success) {
|
||||
throw new TimeoutException("Timeout waiting for JWT bundles update");
|
||||
}
|
||||
}
|
||||
|
||||
private void setJwtBundlesWatcher(final CountDownLatch done) {
|
||||
workloadApiClient.watchJwtBundles(new Watcher<JwtBundleSet>() {
|
||||
@Override
|
||||
public void onUpdate(final JwtBundleSet update) {
|
||||
log.log(Level.INFO, "Received JwtBundleSet update");
|
||||
setJwtBundleSet(update);
|
||||
done.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(final Throwable error) {
|
||||
log.log(Level.SEVERE, "Error in JwtBundleSet watcher", error);
|
||||
done.countDown();
|
||||
throw new WatcherException("Error fetching JwtBundleSet", error);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void setJwtBundleSet(final JwtBundleSet update) {
|
||||
synchronized (this) {
|
||||
this.bundles = update;
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isClosed() {
|
||||
synchronized (this) {
|
||||
return closed;
|
||||
}
|
||||
}
|
||||
|
||||
private static WorkloadApiClient createClient(final JwtSourceOptions options)
|
||||
throws SocketEndpointAddressException {
|
||||
val clientOptions = DefaultWorkloadApiClient.ClientOptions
|
||||
.builder()
|
||||
.spiffeSocketPath(options.getSpiffeSocketPath())
|
||||
.build();
|
||||
return DefaultWorkloadApiClient.newClient(clientOptions);
|
||||
}
|
||||
|
||||
void setClock(Clock clock) {
|
||||
this.clock = clock;
|
||||
}
|
||||
}
|
||||
|
|
@ -3,30 +3,22 @@ package io.spiffe.workloadapi;
|
|||
import io.spiffe.bundle.jwtbundle.JwtBundle;
|
||||
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
|
||||
import io.spiffe.bundle.x509bundle.X509Bundle;
|
||||
import io.spiffe.exception.BundleNotFoundException;
|
||||
import io.spiffe.exception.JwtSourceException;
|
||||
import io.spiffe.exception.JwtSvidException;
|
||||
import io.spiffe.exception.SocketEndpointAddressException;
|
||||
import io.spiffe.exception.WatcherException;
|
||||
import io.spiffe.exception.*;
|
||||
import io.spiffe.spiffeid.SpiffeId;
|
||||
import io.spiffe.spiffeid.TrustDomain;
|
||||
import io.spiffe.svid.jwtsvid.JwtSvid;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.java.Log;
|
||||
import lombok.val;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.logging.Level;
|
||||
import java.util.List;
|
||||
|
||||
import static io.spiffe.workloadapi.internal.ThreadUtils.await;
|
||||
|
||||
|
|
@ -87,18 +79,18 @@ public class DefaultJwtSource implements JwtSource {
|
|||
*/
|
||||
public static JwtSource newSource(@NonNull final JwtSourceOptions options)
|
||||
throws SocketEndpointAddressException, JwtSourceException {
|
||||
if (options.workloadApiClient == null) {
|
||||
options.workloadApiClient = createClient(options);
|
||||
if (options.getWorkloadApiClient()== null) {
|
||||
options.setWorkloadApiClient(createClient(options));
|
||||
}
|
||||
|
||||
if (options.initTimeout == null) {
|
||||
options.initTimeout = DEFAULT_TIMEOUT;
|
||||
if (options.getInitTimeout()== null) {
|
||||
options.setInitTimeout(DEFAULT_TIMEOUT);
|
||||
}
|
||||
|
||||
DefaultJwtSource jwtSource = new DefaultJwtSource(options.workloadApiClient);
|
||||
DefaultJwtSource jwtSource = new DefaultJwtSource(options.getWorkloadApiClient());
|
||||
|
||||
try {
|
||||
jwtSource.init(options.initTimeout);
|
||||
jwtSource.init(options.getInitTimeout());
|
||||
} catch (Exception e) {
|
||||
jwtSource.close();
|
||||
throw new JwtSourceException("Error creating JWT source", e);
|
||||
|
|
@ -241,42 +233,8 @@ public class DefaultJwtSource implements JwtSource {
|
|||
throws SocketEndpointAddressException {
|
||||
val clientOptions = DefaultWorkloadApiClient.ClientOptions
|
||||
.builder()
|
||||
.spiffeSocketPath(options.spiffeSocketPath)
|
||||
.spiffeSocketPath(options.getSpiffeSocketPath())
|
||||
.build();
|
||||
return DefaultWorkloadApiClient.newClient(clientOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Options to configure a {@link DefaultJwtSource}.
|
||||
* <p>
|
||||
* <code>spiffeSocketPath</code> Address to the Workload API, if it is not set, the default address will be used.
|
||||
* <p>
|
||||
* <code>initTimeout</code> Timeout for initializing the instance. If it is not defined, the timeout is read
|
||||
* from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied.
|
||||
* <p>
|
||||
* <code>workloadApiClient</code> A custom instance of a {@link WorkloadApiClient}, if it is not set,
|
||||
* a new client will be created.
|
||||
*/
|
||||
@Data
|
||||
public static class JwtSourceOptions {
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private String spiffeSocketPath;
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private Duration initTimeout;
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private WorkloadApiClient workloadApiClient;
|
||||
|
||||
@Builder
|
||||
public JwtSourceOptions(
|
||||
final String spiffeSocketPath,
|
||||
final WorkloadApiClient workloadApiClient,
|
||||
final Duration initTimeout) {
|
||||
this.spiffeSocketPath = spiffeSocketPath;
|
||||
this.workloadApiClient = workloadApiClient;
|
||||
this.initTimeout = initTimeout;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
package io.spiffe.workloadapi;
|
||||
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* Options to configure a {@link JwtSource}.
|
||||
* <p>
|
||||
* <code>spiffeSocketPath</code> Address to the Workload API, if it is not set, the default address will be used.
|
||||
* <p>
|
||||
* <code>initTimeout</code> Timeout for initializing the instance. If it is not defined, the timeout is read
|
||||
* from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied.
|
||||
* <p>
|
||||
* <code>workloadApiClient</code> A custom instance of a {@link WorkloadApiClient}, if it is not set,
|
||||
* a new client will be created.
|
||||
*/
|
||||
@Data
|
||||
public class JwtSourceOptions {
|
||||
|
||||
@Setter(AccessLevel.PUBLIC)
|
||||
private String spiffeSocketPath;
|
||||
|
||||
@Setter(AccessLevel.PUBLIC)
|
||||
private Duration initTimeout;
|
||||
|
||||
@Setter(AccessLevel.PUBLIC)
|
||||
private WorkloadApiClient workloadApiClient;
|
||||
|
||||
@Builder
|
||||
public JwtSourceOptions(
|
||||
final String spiffeSocketPath,
|
||||
final WorkloadApiClient workloadApiClient,
|
||||
final Duration initTimeout) {
|
||||
this.spiffeSocketPath = spiffeSocketPath;
|
||||
this.workloadApiClient = workloadApiClient;
|
||||
this.initTimeout = initTimeout;
|
||||
}
|
||||
}
|
||||
|
|
@ -124,6 +124,7 @@ class JwtSvidParseAndValidateTest {
|
|||
jwtBundle.putJwtAuthority("authority3", key3.getPublic());
|
||||
|
||||
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
|
||||
Date issuedAt = new Date();
|
||||
Date expiration = new Date(System.currentTimeMillis() + (60 * 60 * 1000));
|
||||
Set<String> audience = new HashSet<String>() {{add("audience1"); add("audience2");}};
|
||||
|
||||
|
|
@ -139,6 +140,7 @@ class JwtSvidParseAndValidateTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JOSE) ))
|
||||
.build()),
|
||||
|
|
@ -151,6 +153,7 @@ class JwtSvidParseAndValidateTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3", JwtSvid.HEADER_TYP_JWT)))
|
||||
.build()),
|
||||
|
|
@ -163,6 +166,7 @@ class JwtSvidParseAndValidateTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3")))
|
||||
.build())
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ class JwtSvidParseInsecureTest {
|
|||
|
||||
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
|
||||
Date expiration = new Date(System.currentTimeMillis() + 3600000);
|
||||
Date issuedAt = new Date();
|
||||
Set<String> audience = Collections.singleton("audience");
|
||||
|
||||
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
|
||||
|
|
@ -125,6 +126,7 @@ class JwtSvidParseInsecureTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT)))
|
||||
.build()),
|
||||
|
|
@ -136,6 +138,7 @@ class JwtSvidParseInsecureTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT)))
|
||||
.build()),
|
||||
|
|
@ -147,6 +150,7 @@ class JwtSvidParseInsecureTest {
|
|||
.expectedJwtSvid(newJwtSvidInstance(
|
||||
trustDomain.newSpiffeId("host"),
|
||||
audience,
|
||||
issuedAt,
|
||||
expiration,
|
||||
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", "")))
|
||||
.build()));
|
||||
|
|
@ -234,13 +238,14 @@ class JwtSvidParseInsecureTest {
|
|||
|
||||
static JwtSvid newJwtSvidInstance(final SpiffeId spiffeId,
|
||||
final Set<String> audience,
|
||||
final Date issuedAt,
|
||||
final Date expiry,
|
||||
final Map<String, Object> claims,
|
||||
final String token) {
|
||||
val constructor = JwtSvid.class.getDeclaredConstructors()[0];
|
||||
constructor.setAccessible(true);
|
||||
try {
|
||||
return (JwtSvid) constructor.newInstance(spiffeId, audience, expiry, claims, token);
|
||||
return (JwtSvid) constructor.newInstance(spiffeId, audience, issuedAt, expiry, claims, token);
|
||||
} catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,522 @@
|
|||
package io.spiffe.workloadapi;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import io.spiffe.bundle.jwtbundle.JwtBundle;
|
||||
import io.spiffe.exception.BundleNotFoundException;
|
||||
import io.spiffe.exception.JwtSourceException;
|
||||
import io.spiffe.exception.JwtSvidException;
|
||||
import io.spiffe.exception.SocketEndpointAddressException;
|
||||
import io.spiffe.spiffeid.SpiffeId;
|
||||
import io.spiffe.spiffeid.TrustDomain;
|
||||
import io.spiffe.svid.jwtsvid.JwtSvid;
|
||||
import io.spiffe.utils.TestUtils;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.time.ZoneId;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
import static io.spiffe.workloadapi.WorkloadApiClientStub.JWT_TTL;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class CachedJwtSourceTest {
|
||||
private CachedJwtSource jwtSource;
|
||||
private WorkloadApiClientStub workloadApiClient;
|
||||
private WorkloadApiClientErrorStub workloadApiClientErrorStub;
|
||||
private Clock clock;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws JwtSourceException, SocketEndpointAddressException {
|
||||
workloadApiClient = new WorkloadApiClientStub();
|
||||
JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
|
||||
System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S");
|
||||
jwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
|
||||
workloadApiClientErrorStub = new WorkloadApiClientErrorStub();
|
||||
|
||||
clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
|
||||
workloadApiClient.setClock(clock);
|
||||
jwtSource.setClock(clock);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws IOException {
|
||||
jwtSource.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetBundleForTrustDomain() {
|
||||
try {
|
||||
JwtBundle bundle = jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org"));
|
||||
assertNotNull(bundle);
|
||||
assertEquals(TrustDomain.parse("example.org"), bundle.getTrustDomain());
|
||||
} catch (BundleNotFoundException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetBundleForTrustDomain_nullParam() {
|
||||
try {
|
||||
jwtSource.getBundleForTrustDomain(null);
|
||||
fail();
|
||||
} catch (NullPointerException e) {
|
||||
assertEquals("trustDomain is marked non-null but is null", e.getMessage());
|
||||
} catch (BundleNotFoundException e) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetBundleForTrustDomain_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
|
||||
jwtSource.close();
|
||||
try {
|
||||
jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org"));
|
||||
fail("expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("JWT bundle source is closed", e.getMessage());
|
||||
assertTrue(workloadApiClient.closed);
|
||||
} catch (BundleNotFoundException e) {
|
||||
fail("not expected exception", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithSubject() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithSubject_ReturnFromCache() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud3", "aud2", "aud1");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again to get from cache changing the order of the audiences
|
||||
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again using different subject
|
||||
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud2", "aud3", "aud1");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again using the same audiences
|
||||
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// set clock forwards but not enough to expire the JWT SVID in the cache
|
||||
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1))));
|
||||
|
||||
// call again to get from cache, fetchJwtSvid call count should not change
|
||||
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// set clock to expire the JWT SVID in the cache
|
||||
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1))));
|
||||
|
||||
// call again, fetchJwtSvid call count should increase
|
||||
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache_MultipleThreads() {
|
||||
// test fetchJwtSvid with several threads trying to read and write the cache
|
||||
// at the same time, the cache should be updated only once
|
||||
try {
|
||||
|
||||
jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// set clock to expire the JWT SVID in the cache
|
||||
Clock offset = Clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1)));
|
||||
jwtSource.setClock(offset);
|
||||
workloadApiClient.setClock(offset);
|
||||
|
||||
// create a thread pool with 10 threads
|
||||
ExecutorService executorService = Executors.newFixedThreadPool(10);
|
||||
|
||||
List<Future<JwtSvid>> futures = new ArrayList<>();
|
||||
|
||||
// create 10 tasks to fetch a JWT SVID
|
||||
for (int i = 0; i < 10; i++) {
|
||||
futures.add(executorService.submit(() -> jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3")));
|
||||
}
|
||||
|
||||
// wait for all tasks to finish
|
||||
for (Future<JwtSvid> future : futures) {
|
||||
future.get();
|
||||
}
|
||||
|
||||
// verify that the cache was updated only once after the JWT SVID expired
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
} catch (InterruptedException | ExecutionException | JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithoutSubject() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithoutSubject_ReturnFromCache() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again to get from cache changing the order of the audiences, the call count should not change
|
||||
svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again using different audience, the call count should increase
|
||||
svid = jwtSource.fetchJwtSvid("other-audience");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("other-audience"), svid.getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithoutSubject_JwtSvidExpiredInCache() {
|
||||
try {
|
||||
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// set clock forwards but not enough to expire the JWT SVID in the cache
|
||||
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1))));
|
||||
|
||||
// call again to get from cache, fetchJwtSvid call count should not change
|
||||
svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// set clock forwards to expire the JWT SVID in the cache
|
||||
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1))));
|
||||
|
||||
// call again, fetchJwtSvid call count should increase
|
||||
svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
|
||||
assertNotNull(svid);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvid_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
|
||||
jwtSource.close();
|
||||
try {
|
||||
jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
|
||||
fail("expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("JWT SVID source is closed", e.getMessage());
|
||||
assertTrue(workloadApiClient.closed);
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
|
||||
jwtSource.close();
|
||||
try {
|
||||
jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
fail("expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("JWT SVID source is closed", e.getMessage());
|
||||
assertTrue(workloadApiClient.closed);
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidsWithSubject() {
|
||||
try {
|
||||
List<JwtSvid> svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(1, svids.size());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidsWithSubject_ReturnFromCache() {
|
||||
try {
|
||||
List<JwtSvid> svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(1, svids.size());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again to get from cache changing the order of the audiences
|
||||
svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(1, svids.size());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud3", "aud2", "aud1"), svids.get(0).getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again using different audience
|
||||
svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "other-audience");
|
||||
assertNotNull(svids);
|
||||
assertEquals(1, svids.size());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidsWithoutSubject() {
|
||||
try {
|
||||
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(svids.size(), 2);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidsWithoutSubject_ReturnFromCache() {
|
||||
try {
|
||||
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(svids.size(), 2);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again to get from cache changing the order of the audiences
|
||||
svids = jwtSource.fetchJwtSvids("aud2", "aud3", "aud1");
|
||||
assertNotNull(svids);
|
||||
assertEquals(svids.size(), 2);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
|
||||
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
|
||||
// call again using different audience
|
||||
svids = jwtSource.fetchJwtSvids("other-audience");
|
||||
assertNotNull(svids);
|
||||
assertEquals(svids.size(), 2);
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("other-audience"), svids.get(1).getAudience());
|
||||
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvids_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
|
||||
jwtSource.close();
|
||||
try {
|
||||
jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
|
||||
fail("expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("JWT SVID source is closed", e.getMessage());
|
||||
assertTrue(workloadApiClient.closed);
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFetchJwtSvidsWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
|
||||
jwtSource.close();
|
||||
try {
|
||||
jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
|
||||
fail("expected exception");
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("JWT SVID source is closed", e.getMessage());
|
||||
assertTrue(workloadApiClient.closed);
|
||||
} catch (JwtSvidException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_success() {
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.workloadApiClient(workloadApiClient)
|
||||
.initTimeout(Duration.ofSeconds(0))
|
||||
.build();
|
||||
try {
|
||||
JwtSource jwtSource = DefaultJwtSource.newSource(options);
|
||||
assertNotNull(jwtSource);
|
||||
} catch (SocketEndpointAddressException | JwtSourceException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_nullParam() {
|
||||
try {
|
||||
DefaultJwtSource.newSource(null);
|
||||
fail();
|
||||
} catch (NullPointerException e) {
|
||||
assertEquals("options is marked non-null but is null", e.getMessage());
|
||||
} catch (SocketEndpointAddressException | JwtSourceException e) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_errorFetchingJwtBundles() {
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.workloadApiClient(workloadApiClientErrorStub)
|
||||
.spiffeSocketPath("unix:/tmp/test")
|
||||
.build();
|
||||
try {
|
||||
DefaultJwtSource.newSource(options);
|
||||
fail();
|
||||
} catch (JwtSourceException e) {
|
||||
assertEquals("Error creating JWT source", e.getMessage());
|
||||
assertEquals("Error fetching JwtBundleSet", e.getCause().getMessage());
|
||||
} catch (Exception e) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_FailsBecauseOfTimeOut() throws Exception {
|
||||
try {
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.spiffeSocketPath("unix:/tmp/test")
|
||||
.build();
|
||||
DefaultJwtSource.newSource(options);
|
||||
fail();
|
||||
} catch (JwtSourceException e) {
|
||||
assertEquals("Error creating JWT source", e.getMessage());
|
||||
assertEquals("Timeout waiting for JWT bundles update", e.getCause().getMessage());
|
||||
} catch (SocketEndpointAddressException e) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_DefaultSocketAddress() throws Exception {
|
||||
try {
|
||||
TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, "unix:/tmp/test");
|
||||
DefaultJwtSource.newSource();
|
||||
fail();
|
||||
} catch (JwtSourceException e) {
|
||||
assertEquals("Error creating JWT source", e.getMessage());
|
||||
} catch (SocketEndpointAddressException e) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void newSource_noSocketAddress() throws Exception {
|
||||
try {
|
||||
// just in case it's defined in the environment
|
||||
TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, "");
|
||||
DefaultJwtSource.newSource();
|
||||
fail();
|
||||
} catch (SocketEndpointAddressException e) {
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
assertEquals("Endpoint Socket Address Environment Variable is not set: SPIFFE_ENDPOINT_SOCKET", e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -24,7 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
class JwtSourceTest {
|
||||
class DefaultJwtSourceTest {
|
||||
|
||||
private JwtSource jwtSource;
|
||||
private WorkloadApiClientStub workloadApiClient;
|
||||
|
|
@ -33,7 +33,7 @@ class JwtSourceTest {
|
|||
@BeforeEach
|
||||
void setUp() throws JwtSourceException, SocketEndpointAddressException {
|
||||
workloadApiClient = new WorkloadApiClientStub();
|
||||
DefaultJwtSource.JwtSourceOptions options = DefaultJwtSource.JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
|
||||
JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
|
||||
System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S");
|
||||
jwtSource = DefaultJwtSource.newSource(options);
|
||||
workloadApiClientErrorStub = new WorkloadApiClientErrorStub();
|
||||
|
|
@ -152,7 +152,7 @@ class JwtSourceTest {
|
|||
try {
|
||||
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
|
||||
assertNotNull(svids);
|
||||
assertEquals(svids.size(), 2);
|
||||
assertEquals(2, svids.size());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
|
||||
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
|
||||
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
|
||||
|
|
@ -193,7 +193,7 @@ class JwtSourceTest {
|
|||
|
||||
@Test
|
||||
void newSource_success() {
|
||||
val options = DefaultJwtSource.JwtSourceOptions
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.workloadApiClient(workloadApiClient)
|
||||
.initTimeout(Duration.ofSeconds(0))
|
||||
|
|
@ -220,7 +220,7 @@ class JwtSourceTest {
|
|||
|
||||
@Test
|
||||
void newSource_errorFetchingJwtBundles() {
|
||||
val options = DefaultJwtSource.JwtSourceOptions
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.workloadApiClient(workloadApiClientErrorStub)
|
||||
.spiffeSocketPath("unix:/tmp/test")
|
||||
|
|
@ -239,7 +239,7 @@ class JwtSourceTest {
|
|||
@Test
|
||||
void newSource_FailsBecauseOfTimeOut() throws Exception {
|
||||
try {
|
||||
val options = DefaultJwtSource.JwtSourceOptions
|
||||
val options = JwtSourceOptions
|
||||
.builder()
|
||||
.spiffeSocketPath("unix:/tmp/test")
|
||||
.build();
|
||||
|
|
@ -23,12 +23,15 @@ import java.nio.file.Files;
|
|||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.security.KeyPair;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
|
||||
import static io.spiffe.utils.TestUtils.toUri;
|
||||
|
||||
public class WorkloadApiClientStub implements WorkloadApiClient {
|
||||
|
||||
static final Duration JWT_TTL = Duration.ofSeconds(60);
|
||||
final String privateKey = "testdata/workloadapi/svid.key.der";
|
||||
final String svid = "testdata/workloadapi/svid.der";
|
||||
final String x509Bundle = "testdata/workloadapi/bundle.der";
|
||||
|
|
@ -36,8 +39,12 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
|
|||
final SpiffeId subject = SpiffeId.parse("spiffe://example.org/workload-server");
|
||||
final SpiffeId extraSubject = SpiffeId.parse("spiffe://example.org/extra-workload-server");
|
||||
|
||||
int fetchJwtSvidCallCount = 0;
|
||||
|
||||
boolean closed;
|
||||
|
||||
Clock clock = Clock.systemDefaultZone();
|
||||
|
||||
@Override
|
||||
public X509Context fetchX509Context() {
|
||||
return generateX509Context();
|
||||
|
|
@ -62,16 +69,19 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
|
|||
|
||||
@Override
|
||||
public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
|
||||
fetchJwtSvidCallCount++;
|
||||
return generateJwtSvid(subject, audience, extraAudience);
|
||||
}
|
||||
|
||||
@Override
|
||||
public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
|
||||
fetchJwtSvidCallCount++;
|
||||
return generateJwtSvid(subject, audience, extraAudience);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
|
||||
fetchJwtSvidCallCount++;
|
||||
List<JwtSvid> svids = new ArrayList<>();
|
||||
svids.add(generateJwtSvid(subject, audience, extraAudience));
|
||||
svids.add(generateJwtSvid(extraSubject, audience, extraAudience));
|
||||
|
|
@ -80,6 +90,7 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
|
|||
|
||||
@Override
|
||||
public List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException {
|
||||
fetchJwtSvidCallCount++;
|
||||
List<JwtSvid> svids = new ArrayList<>();
|
||||
svids.add(generateJwtSvid(subject, audience, extraAudience));
|
||||
return svids;
|
||||
|
|
@ -132,8 +143,9 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
|
|||
Map<String, Object> claims = new HashMap<>();
|
||||
claims.put("sub", subject.toString());
|
||||
claims.put("aud", new ArrayList<>(audParam));
|
||||
Date expiration = new Date(System.currentTimeMillis() + 3600000);
|
||||
claims.put("exp", expiration);
|
||||
|
||||
claims.put("iat", new Date(clock.millis()));
|
||||
claims.put("exp", new Date(clock.millis() + JWT_TTL.toMillis()));
|
||||
|
||||
KeyPair keyPair = TestUtils.generateECKeyPair(Curve.P_521);
|
||||
|
||||
|
|
@ -178,4 +190,20 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
void resetFetchJwtSvidCallCount() {
|
||||
fetchJwtSvidCallCount = 0;
|
||||
}
|
||||
|
||||
int getFetchJwtSvidCallCount() {
|
||||
return fetchJwtSvidCallCount;
|
||||
}
|
||||
|
||||
Clock getClock() {
|
||||
return clock;
|
||||
}
|
||||
|
||||
void setClock(Clock clock) {
|
||||
this.clock = clock;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ public class TestUtils {
|
|||
public static JWTClaimsSet buildJWTClaimSetFromClaimsMap(Map<String, Object> claims) {
|
||||
return new JWTClaimsSet.Builder()
|
||||
.subject((String) claims.get("sub"))
|
||||
.issueTime((Date) claims.get("iat"))
|
||||
.expirationTime((Date) claims.get("exp"))
|
||||
.audience((List<String>) claims.get("aud"))
|
||||
.build();
|
||||
|
|
|
|||
Loading…
Reference in New Issue