Add `CachedJwtSource` (#116)

Add CachedJwtSource

Signed-off-by: Max Lambrecht <max.lambrecht@hpe.com>
This commit is contained in:
Max Lambrecht 2023-04-04 13:49:32 -05:00 committed by GitHub
parent 5e16f7a632
commit 3ca77c1de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 974 additions and 62 deletions

View File

@ -64,11 +64,17 @@ public class JwtSvid {
*/
String token;
/**
* Issued at time of JWT-SVID as present in 'iat' claim.
*/
Date issuedAt;
public static final String HEADER_TYP_JWT = "JWT";
public static final String HEADER_TYP_JOSE = "JOSE";
private JwtSvid(final SpiffeId spiffeId,
final Set<String> audience,
final Date issuedAt,
final Date expiry,
final Map<String, Object> claims,
final String token) {
@ -77,6 +83,7 @@ public class JwtSvid {
this.expiry = expiry;
this.claims = claims;
this.token = token;
this.issuedAt = issuedAt;
}
/**
@ -120,6 +127,8 @@ public class JwtSvid {
val claimsSet = getJwtClaimsSet(signedJwt);
validateAudience(claimsSet.getAudience(), audience);
val issuedAt = claimsSet.getIssueTime();
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
@ -132,7 +141,8 @@ public class JwtSvid {
verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
val claimAudience = new HashSet<>(claimsSet.getAudience());
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token);
}
/**
@ -163,13 +173,16 @@ public class JwtSvid {
val claimsSet = getJwtClaimsSet(signedJwt);
validateAudience(claimsSet.getAudience(), audience);
val issuedAt = claimsSet.getIssueTime();
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
val spiffeId = getSpiffeIdOfSubject(claimsSet);
val claimAudience = new HashSet<>(claimsSet.getAudience());
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token);
}
/**

View File

@ -0,0 +1,338 @@
package io.spiffe.workloadapi;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.*;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
import lombok.NonNull;
import lombok.SneakyThrows;
import lombok.extern.java.Log;
import lombok.val;
import org.apache.commons.lang3.tuple.ImmutablePair;
import java.io.Closeable;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import static io.spiffe.workloadapi.internal.ThreadUtils.await;
/**
* Represents a source of SPIFFE JWT SVIDs and JWT bundles maintained via the Workload API.
* The JWT SVIDs are cached and fetchJwtSvid methods return from cache
* checking that the JWT SVID has still at least half of its lifetime.
*/
@Log
public class CachedJwtSource implements JwtSource {
static final String TIMEOUT_SYSTEM_PROPERTY = "spiffe.newJwtSource.timeout";
static final Duration DEFAULT_TIMEOUT =
Duration.parse(System.getProperty(TIMEOUT_SYSTEM_PROPERTY, "PT0S"));
// Synchronized map of JWT SVIDs, keyed by a pair of SPIFFE ID and a Set of audiences strings.
// This map is used to cache the JWT SVIDs and avoid fetching them from the Workload API.
private final
Map<ImmutablePair<SpiffeId, Set<String>>, List<JwtSvid>> jwtSvids = new ConcurrentHashMap<>();
private JwtBundleSet bundles;
private final WorkloadApiClient workloadApiClient;
private volatile boolean closed;
private Clock clock;
// private constructor
private CachedJwtSource(final WorkloadApiClient workloadApiClient) {
this.clock = Clock.systemDefaultZone();
this.workloadApiClient = workloadApiClient;
}
/**
* Creates a new Cached JWT source. It blocks until the initial update with the JWT bundles
* has been received from the Workload API or until the timeout configured
* through the system property `spiffe.newJwtSource.timeout` expires.
* If no timeout is configured, it blocks until it gets a JWT update from the Workload API.
* <p>
* It uses the default address socket endpoint from the environment variable to get the Workload API address.
*
* @return an instance of {@link DefaultJwtSource}, with the JWT bundles initialized
* @throws SocketEndpointAddressException if the address to the Workload API is not valid
* @throws JwtSourceException if the source could not be initialized
*/
public static JwtSource newSource() throws JwtSourceException, SocketEndpointAddressException {
JwtSourceOptions options = JwtSourceOptions.builder().initTimeout(DEFAULT_TIMEOUT).build();
return newSource(options);
}
/**
* Creates a new JWT source. It blocks until the initial update with the JWT bundles
* has been received from the Workload API, doing retries with an exponential backoff policy,
* or until the initTimeout has expired.
* <p>
* If the timeout is not provided in the options, the default timeout is read from the
* system property `spiffe.newJwtSource.timeout`. If none is configured, this method will
* block until the JWT bundles can be retrieved from the Workload API.
* <p>
* The {@link WorkloadApiClient} can be provided in the options, if it is not,
* a new client is created.
*
* @param options {@link JwtSourceOptions}
* @return an instance of {@link CachedJwtSource}, with the JWT bundles initialized
* @throws SocketEndpointAddressException if the address to the Workload API is not valid
* @throws JwtSourceException if the source could not be initialized
*/
public static JwtSource newSource(@NonNull final JwtSourceOptions options)
throws SocketEndpointAddressException, JwtSourceException {
if (options.getWorkloadApiClient() == null) {
options.setWorkloadApiClient(createClient(options));
}
if (options.getInitTimeout() == null) {
options.setInitTimeout(DEFAULT_TIMEOUT);
}
CachedJwtSource jwtSource = new CachedJwtSource(options.getWorkloadApiClient());
try {
jwtSource.init(options.getInitTimeout());
} catch (Exception e) {
jwtSource.close();
throw new JwtSourceException("Error creating JWT source", e);
}
return jwtSource;
}
/**
* Fetches a JWT SVID for the given audiences. The JWT SVID is cached and
* returned from the cache if it still has at least half of its lifetime.
*
* @param audience the audience
* @param extraAudiences a list of extra audiences as an array of String
* @return a {@link JwtSvid}
* @throws JwtSvidException
*/
@Override
public JwtSvid fetchJwtSvid(final String audience, final String... extraAudiences) throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return getJwtSvids(audience, extraAudiences).get(0);
}
/**
* Fetches a JWT SVID for the given subject and audience. The JWT SVID is cached and
* returned from cache if it has still at least half of its lifetime.
*
* @return a {@link JwtSvid}
* @throws IllegalStateException if the source is closed
*/
@Override
public JwtSvid fetchJwtSvid(final SpiffeId subject, final String audience, final String... extraAudiences)
throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return getJwtSvids(subject, audience, extraAudiences).get(0);
}
/**
* Fetches a list of JWT SVIDs for the given audience. The JWT SVIDs are cached and
* returned from cache if they have still at least half of their lifetime.
*
* @return a list of {@link JwtSvid}s
* @throws IllegalStateException if the source is closed
*/
@Override
public List<JwtSvid> fetchJwtSvids(final String audience, final String... extraAudiences) throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return getJwtSvids(audience, extraAudiences);
}
/**
* Fetches a list of JWT SVIDs for the given subject and audience. The JWT SVIDs are cached and
* returned from cache if they have still at least half of their lifetime.
*
* @return a list of {@link JwtSvid}s
* @throws IllegalStateException if the source is closed
*/
@Override
public List<JwtSvid> fetchJwtSvids(final SpiffeId subject, final String audience, final String... extraAudiences)
throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return getJwtSvids(subject, audience, extraAudiences);
}
/**
* Returns the JWT bundle for a given trust domain.
*
* @return an instance of a {@link X509Bundle}
* @throws BundleNotFoundException is there is no bundle for the trust domain provided
* @throws IllegalStateException if the source is closed
*/
@Override
public JwtBundle getBundleForTrustDomain(@NonNull final TrustDomain trustDomain) throws BundleNotFoundException {
if (isClosed()) {
throw new IllegalStateException("JWT bundle source is closed");
}
return bundles.getBundleForTrustDomain(trustDomain);
}
/**
* Closes this source, dropping the connection to the Workload API.
* Other source methods will return an error after close has been called.
* <p>
* It is marked with {@link SneakyThrows} because it is not expected to throw
* the checked exception defined on the {@link Closeable} interface.
*/
@SneakyThrows
@Override
public void close() {
if (!closed) {
synchronized (this) {
if (!closed) {
workloadApiClient.close();
closed = true;
}
}
}
}
// Check if the jwtSvids map contains the cacheKey, returns it if it does and the JWT SVID has not passed its half lifetime.
// If the cache does not contain the key or the JWT SVID has passed its half lifetime, make a new FetchJWTSVID call to the Workload API,
// adds the JWT SVIDs to the cache map and returns them.
// Only one thread can fetch new JWT SVIDs and update the cache at a time.
private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException {
Set<String> audiencesSet = getAudienceSet(audience, extraAudiences);
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, audiencesSet);
List<JwtSvid> svidList = jwtSvids.get(cacheKey);
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
return svidList;
}
// even using ConcurrentHashMap, there is a possibility of multiple threads trying to fetch new JWT SVIDs at the same time.
synchronized (this) {
// Check again if the jwtSvids map contains the cacheKey, and return the entry if it exists and the JWT SVID has not passed its half lifetime.
// If it does not exist or the JWT-SVID has passed half its lifetime, call the Workload API to fetch new JWT-SVIDs,
// add them to the cache map, and return the list of JWT-SVIDs.
svidList = jwtSvids.get(cacheKey);
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
return svidList;
}
if (cacheKey.left == null) {
svidList = workloadApiClient.fetchJwtSvids(audience, extraAudiences);
} else {
svidList = workloadApiClient.fetchJwtSvids(cacheKey.left, audience, extraAudiences);
}
jwtSvids.put(cacheKey, svidList);
return svidList;
}
}
private List<JwtSvid> getJwtSvids(String audience, String... extraAudiences) throws JwtSvidException {
return getJwtSvids(null, audience, extraAudiences);
}
private static Set<String> getAudienceSet(String audience, String[] extraAudiences) {
Set<String> audiencesString;
if (extraAudiences != null && extraAudiences.length > 0) {
audiencesString = new HashSet<>(Arrays.asList(extraAudiences));
audiencesString.add(audience);
} else {
audiencesString = Collections.singleton(audience);
}
return audiencesString;
}
private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) {
Instant now = clock.instant();
val halfLife = new Date(jwtSvid.getExpiry().getTime() - (jwtSvid.getExpiry().getTime() - jwtSvid.getIssuedAt().getTime()) / 2);
val halfLifeInstant = Instant.ofEpochMilli(halfLife.getTime());
return now.isAfter(halfLifeInstant);
}
private void init(final Duration timeout) throws TimeoutException {
CountDownLatch done = new CountDownLatch(1);
setJwtBundlesWatcher(done);
boolean success;
if (timeout.isZero()) {
await(done);
success = true;
} else {
success = await(done, timeout.getSeconds(), TimeUnit.SECONDS);
}
if (!success) {
throw new TimeoutException("Timeout waiting for JWT bundles update");
}
}
private void setJwtBundlesWatcher(final CountDownLatch done) {
workloadApiClient.watchJwtBundles(new Watcher<JwtBundleSet>() {
@Override
public void onUpdate(final JwtBundleSet update) {
log.log(Level.INFO, "Received JwtBundleSet update");
setJwtBundleSet(update);
done.countDown();
}
@Override
public void onError(final Throwable error) {
log.log(Level.SEVERE, "Error in JwtBundleSet watcher", error);
done.countDown();
throw new WatcherException("Error fetching JwtBundleSet", error);
}
});
}
private void setJwtBundleSet(final JwtBundleSet update) {
synchronized (this) {
this.bundles = update;
}
}
private boolean isClosed() {
synchronized (this) {
return closed;
}
}
private static WorkloadApiClient createClient(final JwtSourceOptions options)
throws SocketEndpointAddressException {
val clientOptions = DefaultWorkloadApiClient.ClientOptions
.builder()
.spiffeSocketPath(options.getSpiffeSocketPath())
.build();
return DefaultWorkloadApiClient.newClient(clientOptions);
}
void setClock(Clock clock) {
this.clock = clock;
}
}

View File

@ -3,30 +3,22 @@ package io.spiffe.workloadapi;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.WatcherException;
import io.spiffe.exception.*;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.java.Log;
import lombok.val;
import java.io.Closeable;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.List;
import static io.spiffe.workloadapi.internal.ThreadUtils.await;
@ -87,18 +79,18 @@ public class DefaultJwtSource implements JwtSource {
*/
public static JwtSource newSource(@NonNull final JwtSourceOptions options)
throws SocketEndpointAddressException, JwtSourceException {
if (options.workloadApiClient == null) {
options.workloadApiClient = createClient(options);
if (options.getWorkloadApiClient()== null) {
options.setWorkloadApiClient(createClient(options));
}
if (options.initTimeout == null) {
options.initTimeout = DEFAULT_TIMEOUT;
if (options.getInitTimeout()== null) {
options.setInitTimeout(DEFAULT_TIMEOUT);
}
DefaultJwtSource jwtSource = new DefaultJwtSource(options.workloadApiClient);
DefaultJwtSource jwtSource = new DefaultJwtSource(options.getWorkloadApiClient());
try {
jwtSource.init(options.initTimeout);
jwtSource.init(options.getInitTimeout());
} catch (Exception e) {
jwtSource.close();
throw new JwtSourceException("Error creating JWT source", e);
@ -241,42 +233,8 @@ public class DefaultJwtSource implements JwtSource {
throws SocketEndpointAddressException {
val clientOptions = DefaultWorkloadApiClient.ClientOptions
.builder()
.spiffeSocketPath(options.spiffeSocketPath)
.spiffeSocketPath(options.getSpiffeSocketPath())
.build();
return DefaultWorkloadApiClient.newClient(clientOptions);
}
/**
* Options to configure a {@link DefaultJwtSource}.
* <p>
* <code>spiffeSocketPath</code> Address to the Workload API, if it is not set, the default address will be used.
* <p>
* <code>initTimeout</code> Timeout for initializing the instance. If it is not defined, the timeout is read
* from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied.
* <p>
* <code>workloadApiClient</code> A custom instance of a {@link WorkloadApiClient}, if it is not set,
* a new client will be created.
*/
@Data
public static class JwtSourceOptions {
@Setter(AccessLevel.NONE)
private String spiffeSocketPath;
@Setter(AccessLevel.NONE)
private Duration initTimeout;
@Setter(AccessLevel.NONE)
private WorkloadApiClient workloadApiClient;
@Builder
public JwtSourceOptions(
final String spiffeSocketPath,
final WorkloadApiClient workloadApiClient,
final Duration initTimeout) {
this.spiffeSocketPath = spiffeSocketPath;
this.workloadApiClient = workloadApiClient;
this.initTimeout = initTimeout;
}
}
}

View File

@ -0,0 +1,43 @@
package io.spiffe.workloadapi;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Data;
import lombok.Setter;
import java.time.Duration;
/**
* Options to configure a {@link JwtSource}.
* <p>
* <code>spiffeSocketPath</code> Address to the Workload API, if it is not set, the default address will be used.
* <p>
* <code>initTimeout</code> Timeout for initializing the instance. If it is not defined, the timeout is read
* from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied.
* <p>
* <code>workloadApiClient</code> A custom instance of a {@link WorkloadApiClient}, if it is not set,
* a new client will be created.
*/
@Data
public class JwtSourceOptions {
@Setter(AccessLevel.PUBLIC)
private String spiffeSocketPath;
@Setter(AccessLevel.PUBLIC)
private Duration initTimeout;
@Setter(AccessLevel.PUBLIC)
private WorkloadApiClient workloadApiClient;
@Builder
public JwtSourceOptions(
final String spiffeSocketPath,
final WorkloadApiClient workloadApiClient,
final Duration initTimeout) {
this.spiffeSocketPath = spiffeSocketPath;
this.workloadApiClient = workloadApiClient;
this.initTimeout = initTimeout;
}
}

View File

@ -124,6 +124,7 @@ class JwtSvidParseAndValidateTest {
jwtBundle.putJwtAuthority("authority3", key3.getPublic());
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date issuedAt = new Date();
Date expiration = new Date(System.currentTimeMillis() + (60 * 60 * 1000));
Set<String> audience = new HashSet<String>() {{add("audience1"); add("audience2");}};
@ -139,6 +140,7 @@ class JwtSvidParseAndValidateTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JOSE) ))
.build()),
@ -151,6 +153,7 @@ class JwtSvidParseAndValidateTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3", JwtSvid.HEADER_TYP_JWT)))
.build()),
@ -163,6 +166,7 @@ class JwtSvidParseAndValidateTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3")))
.build())

View File

@ -112,6 +112,7 @@ class JwtSvidParseInsecureTest {
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date expiration = new Date(System.currentTimeMillis() + 3600000);
Date issuedAt = new Date();
Set<String> audience = Collections.singleton("audience");
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -125,6 +126,7 @@ class JwtSvidParseInsecureTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT)))
.build()),
@ -136,6 +138,7 @@ class JwtSvidParseInsecureTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT)))
.build()),
@ -147,6 +150,7 @@ class JwtSvidParseInsecureTest {
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", "")))
.build()));
@ -234,13 +238,14 @@ class JwtSvidParseInsecureTest {
static JwtSvid newJwtSvidInstance(final SpiffeId spiffeId,
final Set<String> audience,
final Date issuedAt,
final Date expiry,
final Map<String, Object> claims,
final String token) {
val constructor = JwtSvid.class.getDeclaredConstructors()[0];
constructor.setAccessible(true);
try {
return (JwtSvid) constructor.newInstance(spiffeId, audience, expiry, claims, token);
return (JwtSvid) constructor.newInstance(spiffeId, audience, issuedAt, expiry, claims, token);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}

View File

@ -0,0 +1,522 @@
package io.spiffe.workloadapi;
import com.google.common.collect.Sets;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
import io.spiffe.utils.TestUtils;
import lombok.val;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static io.spiffe.workloadapi.WorkloadApiClientStub.JWT_TTL;
import static org.junit.jupiter.api.Assertions.*;
class CachedJwtSourceTest {
private CachedJwtSource jwtSource;
private WorkloadApiClientStub workloadApiClient;
private WorkloadApiClientErrorStub workloadApiClientErrorStub;
private Clock clock;
@BeforeEach
void setUp() throws JwtSourceException, SocketEndpointAddressException {
workloadApiClient = new WorkloadApiClientStub();
JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S");
jwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
workloadApiClientErrorStub = new WorkloadApiClientErrorStub();
clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
workloadApiClient.setClock(clock);
jwtSource.setClock(clock);
}
@AfterEach
void tearDown() throws IOException {
jwtSource.close();
}
@Test
void testGetBundleForTrustDomain() {
try {
JwtBundle bundle = jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org"));
assertNotNull(bundle);
assertEquals(TrustDomain.parse("example.org"), bundle.getTrustDomain());
} catch (BundleNotFoundException e) {
fail(e);
}
}
@Test
void testGetBundleForTrustDomain_nullParam() {
try {
jwtSource.getBundleForTrustDomain(null);
fail();
} catch (NullPointerException e) {
assertEquals("trustDomain is marked non-null but is null", e.getMessage());
} catch (BundleNotFoundException e) {
fail();
}
}
@Test
void testGetBundleForTrustDomain_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org"));
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT bundle source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (BundleNotFoundException e) {
fail("not expected exception", e);
}
}
@Test
void testFetchJwtSvidWithSubject() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithSubject_ReturnFromCache() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud3", "aud2", "aud1");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again to get from cache changing the order of the audiences
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again using different subject
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud2", "aud3", "aud1");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
// call again using the same audiences
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// set clock forwards but not enough to expire the JWT SVID in the cache
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1))));
// call again to get from cache, fetchJwtSvid call count should not change
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// set clock to expire the JWT SVID in the cache
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1))));
// call again, fetchJwtSvid call count should increase
svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache_MultipleThreads() {
// test fetchJwtSvid with several threads trying to read and write the cache
// at the same time, the cache should be updated only once
try {
jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// set clock to expire the JWT SVID in the cache
Clock offset = Clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1)));
jwtSource.setClock(offset);
workloadApiClient.setClock(offset);
// create a thread pool with 10 threads
ExecutorService executorService = Executors.newFixedThreadPool(10);
List<Future<JwtSvid>> futures = new ArrayList<>();
// create 10 tasks to fetch a JWT SVID
for (int i = 0; i < 10; i++) {
futures.add(executorService.submit(() -> jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3")));
}
// wait for all tasks to finish
for (Future<JwtSvid> future : futures) {
future.get();
}
// verify that the cache was updated only once after the JWT SVID expired
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (InterruptedException | ExecutionException | JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithoutSubject() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithoutSubject_ReturnFromCache() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again to get from cache changing the order of the audiences, the call count should not change
svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again using different audience, the call count should increase
svid = jwtSource.fetchJwtSvid("other-audience");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("other-audience"), svid.getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithoutSubject_JwtSvidExpiredInCache() {
try {
JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// set clock forwards but not enough to expire the JWT SVID in the cache
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1))));
// call again to get from cache, fetchJwtSvid call count should not change
svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// set clock forwards to expire the JWT SVID in the cache
jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1))));
// call again, fetchJwtSvid call count should increase
svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvid_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvid("aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithSubject() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(1, svids.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithSubject_ReturnFromCache() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(1, svids.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again to get from cache changing the order of the audiences
svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(1, svids.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud3", "aud2", "aud1"), svids.get(0).getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again using different audience
svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "other-audience");
assertNotNull(svids);
assertEquals(1, svids.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithoutSubject() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithoutSubject_ReturnFromCache() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again to get from cache changing the order of the audiences
svids = jwtSource.fetchJwtSvids("aud2", "aud3", "aud1");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount());
// call again using different audience
svids = jwtSource.fetchJwtSvids("other-audience");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
assertEquals(Sets.newHashSet("other-audience"), svids.get(1).getAudience());
assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvids_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void newSource_success() {
val options = JwtSourceOptions
.builder()
.workloadApiClient(workloadApiClient)
.initTimeout(Duration.ofSeconds(0))
.build();
try {
JwtSource jwtSource = DefaultJwtSource.newSource(options);
assertNotNull(jwtSource);
} catch (SocketEndpointAddressException | JwtSourceException e) {
fail(e);
}
}
@Test
void newSource_nullParam() {
try {
DefaultJwtSource.newSource(null);
fail();
} catch (NullPointerException e) {
assertEquals("options is marked non-null but is null", e.getMessage());
} catch (SocketEndpointAddressException | JwtSourceException e) {
fail();
}
}
@Test
void newSource_errorFetchingJwtBundles() {
val options = JwtSourceOptions
.builder()
.workloadApiClient(workloadApiClientErrorStub)
.spiffeSocketPath("unix:/tmp/test")
.build();
try {
DefaultJwtSource.newSource(options);
fail();
} catch (JwtSourceException e) {
assertEquals("Error creating JWT source", e.getMessage());
assertEquals("Error fetching JwtBundleSet", e.getCause().getMessage());
} catch (Exception e) {
fail();
}
}
@Test
void newSource_FailsBecauseOfTimeOut() throws Exception {
try {
val options = JwtSourceOptions
.builder()
.spiffeSocketPath("unix:/tmp/test")
.build();
DefaultJwtSource.newSource(options);
fail();
} catch (JwtSourceException e) {
assertEquals("Error creating JWT source", e.getMessage());
assertEquals("Timeout waiting for JWT bundles update", e.getCause().getMessage());
} catch (SocketEndpointAddressException e) {
fail();
}
}
@Test
void newSource_DefaultSocketAddress() throws Exception {
try {
TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, "unix:/tmp/test");
DefaultJwtSource.newSource();
fail();
} catch (JwtSourceException e) {
assertEquals("Error creating JWT source", e.getMessage());
} catch (SocketEndpointAddressException e) {
fail();
}
}
@Test
void newSource_noSocketAddress() throws Exception {
try {
// just in case it's defined in the environment
TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, "");
DefaultJwtSource.newSource();
fail();
} catch (SocketEndpointAddressException e) {
fail();
} catch (IllegalStateException e) {
assertEquals("Endpoint Socket Address Environment Variable is not set: SPIFFE_ENDPOINT_SOCKET", e.getMessage());
}
}
}

View File

@ -24,7 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
class JwtSourceTest {
class DefaultJwtSourceTest {
private JwtSource jwtSource;
private WorkloadApiClientStub workloadApiClient;
@ -33,7 +33,7 @@ class JwtSourceTest {
@BeforeEach
void setUp() throws JwtSourceException, SocketEndpointAddressException {
workloadApiClient = new WorkloadApiClientStub();
DefaultJwtSource.JwtSourceOptions options = DefaultJwtSource.JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build();
System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S");
jwtSource = DefaultJwtSource.newSource(options);
workloadApiClientErrorStub = new WorkloadApiClientErrorStub();
@ -152,7 +152,7 @@ class JwtSourceTest {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(2, svids.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
@ -193,7 +193,7 @@ class JwtSourceTest {
@Test
void newSource_success() {
val options = DefaultJwtSource.JwtSourceOptions
val options = JwtSourceOptions
.builder()
.workloadApiClient(workloadApiClient)
.initTimeout(Duration.ofSeconds(0))
@ -220,7 +220,7 @@ class JwtSourceTest {
@Test
void newSource_errorFetchingJwtBundles() {
val options = DefaultJwtSource.JwtSourceOptions
val options = JwtSourceOptions
.builder()
.workloadApiClient(workloadApiClientErrorStub)
.spiffeSocketPath("unix:/tmp/test")
@ -239,7 +239,7 @@ class JwtSourceTest {
@Test
void newSource_FailsBecauseOfTimeOut() throws Exception {
try {
val options = DefaultJwtSource.JwtSourceOptions
val options = JwtSourceOptions
.builder()
.spiffeSocketPath("unix:/tmp/test")
.build();

View File

@ -23,12 +23,15 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.KeyPair;
import java.time.Clock;
import java.time.Duration;
import java.util.*;
import static io.spiffe.utils.TestUtils.toUri;
public class WorkloadApiClientStub implements WorkloadApiClient {
static final Duration JWT_TTL = Duration.ofSeconds(60);
final String privateKey = "testdata/workloadapi/svid.key.der";
final String svid = "testdata/workloadapi/svid.der";
final String x509Bundle = "testdata/workloadapi/bundle.der";
@ -36,8 +39,12 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
final SpiffeId subject = SpiffeId.parse("spiffe://example.org/workload-server");
final SpiffeId extraSubject = SpiffeId.parse("spiffe://example.org/extra-workload-server");
int fetchJwtSvidCallCount = 0;
boolean closed;
Clock clock = Clock.systemDefaultZone();
@Override
public X509Context fetchX509Context() {
return generateX509Context();
@ -62,16 +69,19 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
@Override
public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
fetchJwtSvidCallCount++;
return generateJwtSvid(subject, audience, extraAudience);
}
@Override
public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
fetchJwtSvidCallCount++;
return generateJwtSvid(subject, audience, extraAudience);
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
fetchJwtSvidCallCount++;
List<JwtSvid> svids = new ArrayList<>();
svids.add(generateJwtSvid(subject, audience, extraAudience));
svids.add(generateJwtSvid(extraSubject, audience, extraAudience));
@ -80,6 +90,7 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException {
fetchJwtSvidCallCount++;
List<JwtSvid> svids = new ArrayList<>();
svids.add(generateJwtSvid(subject, audience, extraAudience));
return svids;
@ -132,8 +143,9 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
Map<String, Object> claims = new HashMap<>();
claims.put("sub", subject.toString());
claims.put("aud", new ArrayList<>(audParam));
Date expiration = new Date(System.currentTimeMillis() + 3600000);
claims.put("exp", expiration);
claims.put("iat", new Date(clock.millis()));
claims.put("exp", new Date(clock.millis() + JWT_TTL.toMillis()));
KeyPair keyPair = TestUtils.generateECKeyPair(Curve.P_521);
@ -178,4 +190,20 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
throw new RuntimeException(e);
}
}
void resetFetchJwtSvidCallCount() {
fetchJwtSvidCallCount = 0;
}
int getFetchJwtSvidCallCount() {
return fetchJwtSvidCallCount;
}
Clock getClock() {
return clock;
}
void setClock(Clock clock) {
this.clock = clock;
}
}

View File

@ -101,6 +101,7 @@ public class TestUtils {
public static JWTClaimsSet buildJWTClaimSetFromClaimsMap(Map<String, Object> claims) {
return new JWTClaimsSet.Builder()
.subject((String) claims.get("sub"))
.issueTime((Date) claims.get("iat"))
.expirationTime((Date) claims.get("exp"))
.audience((List<String>) claims.get("aud"))
.build();