diff --git a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java index f83182b..1ec9a04 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java +++ b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java @@ -64,11 +64,17 @@ public class JwtSvid { */ String token; + /** + * Issued at time of JWT-SVID as present in 'iat' claim. + */ + Date issuedAt; + public static final String HEADER_TYP_JWT = "JWT"; public static final String HEADER_TYP_JOSE = "JOSE"; private JwtSvid(final SpiffeId spiffeId, final Set audience, + final Date issuedAt, final Date expiry, final Map claims, final String token) { @@ -77,6 +83,7 @@ public class JwtSvid { this.expiry = expiry; this.claims = claims; this.token = token; + this.issuedAt = issuedAt; } /** @@ -120,6 +127,8 @@ public class JwtSvid { val claimsSet = getJwtClaimsSet(signedJwt); validateAudience(claimsSet.getAudience(), audience); + val issuedAt = claimsSet.getIssueTime(); + val expirationTime = claimsSet.getExpirationTime(); validateExpiration(expirationTime); @@ -132,7 +141,8 @@ public class JwtSvid { verifySignature(signedJwt, jwtAuthority, algorithm, keyId); val claimAudience = new HashSet<>(claimsSet.getAudience()); - return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); + + return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token); } /** @@ -163,13 +173,16 @@ public class JwtSvid { val claimsSet = getJwtClaimsSet(signedJwt); validateAudience(claimsSet.getAudience(), audience); + val issuedAt = claimsSet.getIssueTime(); + val expirationTime = claimsSet.getExpirationTime(); validateExpiration(expirationTime); val spiffeId = getSpiffeIdOfSubject(claimsSet); val claimAudience = new HashSet<>(claimsSet.getAudience()); - return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); + + return new JwtSvid(spiffeId, claimAudience, issuedAt, expirationTime, claimsSet.getClaims(), token); } /** diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java new file mode 100644 index 0000000..ede0d5a --- /dev/null +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java @@ -0,0 +1,338 @@ +package io.spiffe.workloadapi; + + +import io.spiffe.bundle.jwtbundle.JwtBundle; +import io.spiffe.bundle.jwtbundle.JwtBundleSet; +import io.spiffe.bundle.x509bundle.X509Bundle; +import io.spiffe.exception.*; +import io.spiffe.spiffeid.SpiffeId; +import io.spiffe.spiffeid.TrustDomain; +import io.spiffe.svid.jwtsvid.JwtSvid; +import lombok.NonNull; +import lombok.SneakyThrows; +import lombok.extern.java.Log; +import lombok.val; +import org.apache.commons.lang3.tuple.ImmutablePair; + +import java.io.Closeable; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.logging.Level; + +import static io.spiffe.workloadapi.internal.ThreadUtils.await; + +/** + * Represents a source of SPIFFE JWT SVIDs and JWT bundles maintained via the Workload API. + * The JWT SVIDs are cached and fetchJwtSvid methods return from cache + * checking that the JWT SVID has still at least half of its lifetime. + */ +@Log +public class CachedJwtSource implements JwtSource { + static final String TIMEOUT_SYSTEM_PROPERTY = "spiffe.newJwtSource.timeout"; + + static final Duration DEFAULT_TIMEOUT = + Duration.parse(System.getProperty(TIMEOUT_SYSTEM_PROPERTY, "PT0S")); + + // Synchronized map of JWT SVIDs, keyed by a pair of SPIFFE ID and a Set of audiences strings. + // This map is used to cache the JWT SVIDs and avoid fetching them from the Workload API. + private final + Map>, List> jwtSvids = new ConcurrentHashMap<>(); + + private JwtBundleSet bundles; + + private final WorkloadApiClient workloadApiClient; + private volatile boolean closed; + private Clock clock; + + // private constructor + private CachedJwtSource(final WorkloadApiClient workloadApiClient) { + this.clock = Clock.systemDefaultZone(); + this.workloadApiClient = workloadApiClient; + } + + /** + * Creates a new Cached JWT source. It blocks until the initial update with the JWT bundles + * has been received from the Workload API or until the timeout configured + * through the system property `spiffe.newJwtSource.timeout` expires. + * If no timeout is configured, it blocks until it gets a JWT update from the Workload API. + *

+ * It uses the default address socket endpoint from the environment variable to get the Workload API address. + * + * @return an instance of {@link DefaultJwtSource}, with the JWT bundles initialized + * @throws SocketEndpointAddressException if the address to the Workload API is not valid + * @throws JwtSourceException if the source could not be initialized + */ + public static JwtSource newSource() throws JwtSourceException, SocketEndpointAddressException { + JwtSourceOptions options = JwtSourceOptions.builder().initTimeout(DEFAULT_TIMEOUT).build(); + return newSource(options); + } + + /** + * Creates a new JWT source. It blocks until the initial update with the JWT bundles + * has been received from the Workload API, doing retries with an exponential backoff policy, + * or until the initTimeout has expired. + *

+ * If the timeout is not provided in the options, the default timeout is read from the + * system property `spiffe.newJwtSource.timeout`. If none is configured, this method will + * block until the JWT bundles can be retrieved from the Workload API. + *

+ * The {@link WorkloadApiClient} can be provided in the options, if it is not, + * a new client is created. + * + * @param options {@link JwtSourceOptions} + * @return an instance of {@link CachedJwtSource}, with the JWT bundles initialized + * @throws SocketEndpointAddressException if the address to the Workload API is not valid + * @throws JwtSourceException if the source could not be initialized + */ + public static JwtSource newSource(@NonNull final JwtSourceOptions options) + throws SocketEndpointAddressException, JwtSourceException { + if (options.getWorkloadApiClient() == null) { + options.setWorkloadApiClient(createClient(options)); + } + + if (options.getInitTimeout() == null) { + options.setInitTimeout(DEFAULT_TIMEOUT); + } + + CachedJwtSource jwtSource = new CachedJwtSource(options.getWorkloadApiClient()); + + try { + jwtSource.init(options.getInitTimeout()); + } catch (Exception e) { + jwtSource.close(); + throw new JwtSourceException("Error creating JWT source", e); + } + + return jwtSource; + } + + /** + * Fetches a JWT SVID for the given audiences. The JWT SVID is cached and + * returned from the cache if it still has at least half of its lifetime. + * + * @param audience the audience + * @param extraAudiences a list of extra audiences as an array of String + * @return a {@link JwtSvid} + * @throws JwtSvidException + */ + @Override + public JwtSvid fetchJwtSvid(final String audience, final String... extraAudiences) throws JwtSvidException { + if (isClosed()) { + throw new IllegalStateException("JWT SVID source is closed"); + } + + return getJwtSvids(audience, extraAudiences).get(0); + } + + /** + * Fetches a JWT SVID for the given subject and audience. The JWT SVID is cached and + * returned from cache if it has still at least half of its lifetime. + * + * @return a {@link JwtSvid} + * @throws IllegalStateException if the source is closed + */ + @Override + public JwtSvid fetchJwtSvid(final SpiffeId subject, final String audience, final String... extraAudiences) + throws JwtSvidException { + if (isClosed()) { + throw new IllegalStateException("JWT SVID source is closed"); + } + + return getJwtSvids(subject, audience, extraAudiences).get(0); + } + + /** + * Fetches a list of JWT SVIDs for the given audience. The JWT SVIDs are cached and + * returned from cache if they have still at least half of their lifetime. + * + * @return a list of {@link JwtSvid}s + * @throws IllegalStateException if the source is closed + */ + @Override + public List fetchJwtSvids(final String audience, final String... extraAudiences) throws JwtSvidException { + if (isClosed()) { + throw new IllegalStateException("JWT SVID source is closed"); + } + + return getJwtSvids(audience, extraAudiences); + } + + /** + * Fetches a list of JWT SVIDs for the given subject and audience. The JWT SVIDs are cached and + * returned from cache if they have still at least half of their lifetime. + * + * @return a list of {@link JwtSvid}s + * @throws IllegalStateException if the source is closed + */ + @Override + public List fetchJwtSvids(final SpiffeId subject, final String audience, final String... extraAudiences) + throws JwtSvidException { + if (isClosed()) { + throw new IllegalStateException("JWT SVID source is closed"); + } + + return getJwtSvids(subject, audience, extraAudiences); + } + + /** + * Returns the JWT bundle for a given trust domain. + * + * @return an instance of a {@link X509Bundle} + * @throws BundleNotFoundException is there is no bundle for the trust domain provided + * @throws IllegalStateException if the source is closed + */ + @Override + public JwtBundle getBundleForTrustDomain(@NonNull final TrustDomain trustDomain) throws BundleNotFoundException { + if (isClosed()) { + throw new IllegalStateException("JWT bundle source is closed"); + } + return bundles.getBundleForTrustDomain(trustDomain); + } + + /** + * Closes this source, dropping the connection to the Workload API. + * Other source methods will return an error after close has been called. + *

+ * It is marked with {@link SneakyThrows} because it is not expected to throw + * the checked exception defined on the {@link Closeable} interface. + */ + @SneakyThrows + @Override + public void close() { + if (!closed) { + synchronized (this) { + if (!closed) { + workloadApiClient.close(); + closed = true; + } + } + } + } + + // Check if the jwtSvids map contains the cacheKey, returns it if it does and the JWT SVID has not passed its half lifetime. + // If the cache does not contain the key or the JWT SVID has passed its half lifetime, make a new FetchJWTSVID call to the Workload API, + // adds the JWT SVIDs to the cache map and returns them. + // Only one thread can fetch new JWT SVIDs and update the cache at a time. + private List getJwtSvids(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException { + Set audiencesSet = getAudienceSet(audience, extraAudiences); + ImmutablePair> cacheKey = new ImmutablePair<>(subject, audiencesSet); + + List svidList = jwtSvids.get(cacheKey); + if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) { + return svidList; + } + + // even using ConcurrentHashMap, there is a possibility of multiple threads trying to fetch new JWT SVIDs at the same time. + synchronized (this) { + // Check again if the jwtSvids map contains the cacheKey, and return the entry if it exists and the JWT SVID has not passed its half lifetime. + // If it does not exist or the JWT-SVID has passed half its lifetime, call the Workload API to fetch new JWT-SVIDs, + // add them to the cache map, and return the list of JWT-SVIDs. + svidList = jwtSvids.get(cacheKey); + if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) { + return svidList; + } + + if (cacheKey.left == null) { + svidList = workloadApiClient.fetchJwtSvids(audience, extraAudiences); + } else { + svidList = workloadApiClient.fetchJwtSvids(cacheKey.left, audience, extraAudiences); + } + jwtSvids.put(cacheKey, svidList); + return svidList; + } + } + + private List getJwtSvids(String audience, String... extraAudiences) throws JwtSvidException { + return getJwtSvids(null, audience, extraAudiences); + } + + private static Set getAudienceSet(String audience, String[] extraAudiences) { + Set audiencesString; + if (extraAudiences != null && extraAudiences.length > 0) { + audiencesString = new HashSet<>(Arrays.asList(extraAudiences)); + audiencesString.add(audience); + } else { + audiencesString = Collections.singleton(audience); + } + return audiencesString; + } + + private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) { + Instant now = clock.instant(); + val halfLife = new Date(jwtSvid.getExpiry().getTime() - (jwtSvid.getExpiry().getTime() - jwtSvid.getIssuedAt().getTime()) / 2); + val halfLifeInstant = Instant.ofEpochMilli(halfLife.getTime()); + return now.isAfter(halfLifeInstant); + } + + + private void init(final Duration timeout) throws TimeoutException { + CountDownLatch done = new CountDownLatch(1); + setJwtBundlesWatcher(done); + + boolean success; + if (timeout.isZero()) { + await(done); + success = true; + } else { + success = await(done, timeout.getSeconds(), TimeUnit.SECONDS); + } + if (!success) { + throw new TimeoutException("Timeout waiting for JWT bundles update"); + } + } + + private void setJwtBundlesWatcher(final CountDownLatch done) { + workloadApiClient.watchJwtBundles(new Watcher() { + @Override + public void onUpdate(final JwtBundleSet update) { + log.log(Level.INFO, "Received JwtBundleSet update"); + setJwtBundleSet(update); + done.countDown(); + } + + @Override + public void onError(final Throwable error) { + log.log(Level.SEVERE, "Error in JwtBundleSet watcher", error); + done.countDown(); + throw new WatcherException("Error fetching JwtBundleSet", error); + } + }); + } + + private void setJwtBundleSet(final JwtBundleSet update) { + synchronized (this) { + this.bundles = update; + } + } + + private boolean isClosed() { + synchronized (this) { + return closed; + } + } + + private static WorkloadApiClient createClient(final JwtSourceOptions options) + throws SocketEndpointAddressException { + val clientOptions = DefaultWorkloadApiClient.ClientOptions + .builder() + .spiffeSocketPath(options.getSpiffeSocketPath()) + .build(); + return DefaultWorkloadApiClient.newClient(clientOptions); + } + + void setClock(Clock clock) { + this.clock = clock; + } +} diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultJwtSource.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultJwtSource.java index b4721df..d92ee7c 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultJwtSource.java +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultJwtSource.java @@ -3,30 +3,22 @@ package io.spiffe.workloadapi; import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.bundle.jwtbundle.JwtBundleSet; import io.spiffe.bundle.x509bundle.X509Bundle; -import io.spiffe.exception.BundleNotFoundException; -import io.spiffe.exception.JwtSourceException; -import io.spiffe.exception.JwtSvidException; -import io.spiffe.exception.SocketEndpointAddressException; -import io.spiffe.exception.WatcherException; +import io.spiffe.exception.*; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.TrustDomain; import io.spiffe.svid.jwtsvid.JwtSvid; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Data; import lombok.NonNull; -import lombok.Setter; import lombok.SneakyThrows; import lombok.extern.java.Log; import lombok.val; import java.io.Closeable; import java.time.Duration; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Level; -import java.util.List; import static io.spiffe.workloadapi.internal.ThreadUtils.await; @@ -87,18 +79,18 @@ public class DefaultJwtSource implements JwtSource { */ public static JwtSource newSource(@NonNull final JwtSourceOptions options) throws SocketEndpointAddressException, JwtSourceException { - if (options.workloadApiClient == null) { - options.workloadApiClient = createClient(options); + if (options.getWorkloadApiClient()== null) { + options.setWorkloadApiClient(createClient(options)); } - if (options.initTimeout == null) { - options.initTimeout = DEFAULT_TIMEOUT; + if (options.getInitTimeout()== null) { + options.setInitTimeout(DEFAULT_TIMEOUT); } - DefaultJwtSource jwtSource = new DefaultJwtSource(options.workloadApiClient); + DefaultJwtSource jwtSource = new DefaultJwtSource(options.getWorkloadApiClient()); try { - jwtSource.init(options.initTimeout); + jwtSource.init(options.getInitTimeout()); } catch (Exception e) { jwtSource.close(); throw new JwtSourceException("Error creating JWT source", e); @@ -241,42 +233,8 @@ public class DefaultJwtSource implements JwtSource { throws SocketEndpointAddressException { val clientOptions = DefaultWorkloadApiClient.ClientOptions .builder() - .spiffeSocketPath(options.spiffeSocketPath) + .spiffeSocketPath(options.getSpiffeSocketPath()) .build(); return DefaultWorkloadApiClient.newClient(clientOptions); } - - /** - * Options to configure a {@link DefaultJwtSource}. - *

- * spiffeSocketPath Address to the Workload API, if it is not set, the default address will be used. - *

- * initTimeout Timeout for initializing the instance. If it is not defined, the timeout is read - * from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied. - *

- * workloadApiClient A custom instance of a {@link WorkloadApiClient}, if it is not set, - * a new client will be created. - */ - @Data - public static class JwtSourceOptions { - - @Setter(AccessLevel.NONE) - private String spiffeSocketPath; - - @Setter(AccessLevel.NONE) - private Duration initTimeout; - - @Setter(AccessLevel.NONE) - private WorkloadApiClient workloadApiClient; - - @Builder - public JwtSourceOptions( - final String spiffeSocketPath, - final WorkloadApiClient workloadApiClient, - final Duration initTimeout) { - this.spiffeSocketPath = spiffeSocketPath; - this.workloadApiClient = workloadApiClient; - this.initTimeout = initTimeout; - } - } } diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/JwtSourceOptions.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/JwtSourceOptions.java new file mode 100644 index 0000000..e7817a9 --- /dev/null +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/JwtSourceOptions.java @@ -0,0 +1,43 @@ +package io.spiffe.workloadapi; + + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Data; +import lombok.Setter; + +import java.time.Duration; + +/** + * Options to configure a {@link JwtSource}. + *

+ * spiffeSocketPath Address to the Workload API, if it is not set, the default address will be used. + *

+ * initTimeout Timeout for initializing the instance. If it is not defined, the timeout is read + * from the System property `spiffe.newJwtSource.timeout'. If this is also not defined, no default timeout is applied. + *

+ * workloadApiClient A custom instance of a {@link WorkloadApiClient}, if it is not set, + * a new client will be created. + */ +@Data +public class JwtSourceOptions { + + @Setter(AccessLevel.PUBLIC) + private String spiffeSocketPath; + + @Setter(AccessLevel.PUBLIC) + private Duration initTimeout; + + @Setter(AccessLevel.PUBLIC) + private WorkloadApiClient workloadApiClient; + + @Builder + public JwtSourceOptions( + final String spiffeSocketPath, + final WorkloadApiClient workloadApiClient, + final Duration initTimeout) { + this.spiffeSocketPath = spiffeSocketPath; + this.workloadApiClient = workloadApiClient; + this.initTimeout = initTimeout; + } +} diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java index e6e249c..8db35ad 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java @@ -124,6 +124,7 @@ class JwtSvidParseAndValidateTest { jwtBundle.putJwtAuthority("authority3", key3.getPublic()); SpiffeId spiffeId = trustDomain.newSpiffeId("host"); + Date issuedAt = new Date(); Date expiration = new Date(System.currentTimeMillis() + (60 * 60 * 1000)); Set audience = new HashSet() {{add("audience1"); add("audience2");}}; @@ -139,6 +140,7 @@ class JwtSvidParseAndValidateTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JOSE) )) .build()), @@ -151,6 +153,7 @@ class JwtSvidParseAndValidateTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3", JwtSvid.HEADER_TYP_JWT))) .build()), @@ -163,6 +166,7 @@ class JwtSvidParseAndValidateTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key3, "authority3"))) .build()) diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java index 77af9da..78e7b96 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java @@ -112,6 +112,7 @@ class JwtSvidParseInsecureTest { SpiffeId spiffeId = trustDomain.newSpiffeId("host"); Date expiration = new Date(System.currentTimeMillis() + 3600000); + Date issuedAt = new Date(); Set audience = Collections.singleton("audience"); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); @@ -125,6 +126,7 @@ class JwtSvidParseInsecureTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT))) .build()), @@ -136,6 +138,7 @@ class JwtSvidParseInsecureTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", JwtSvid.HEADER_TYP_JWT))) .build()), @@ -147,6 +150,7 @@ class JwtSvidParseInsecureTest { .expectedJwtSvid(newJwtSvidInstance( trustDomain.newSpiffeId("host"), audience, + issuedAt, expiration, claims.getClaims(), TestUtils.generateToken(claims, key1, "authority1", ""))) .build())); @@ -234,13 +238,14 @@ class JwtSvidParseInsecureTest { static JwtSvid newJwtSvidInstance(final SpiffeId spiffeId, final Set audience, + final Date issuedAt, final Date expiry, final Map claims, final String token) { val constructor = JwtSvid.class.getDeclaredConstructors()[0]; constructor.setAccessible(true); try { - return (JwtSvid) constructor.newInstance(spiffeId, audience, expiry, claims, token); + return (JwtSvid) constructor.newInstance(spiffeId, audience, issuedAt, expiry, claims, token); } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java new file mode 100644 index 0000000..9cc9132 --- /dev/null +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java @@ -0,0 +1,522 @@ +package io.spiffe.workloadapi; + +import com.google.common.collect.Sets; +import io.spiffe.bundle.jwtbundle.JwtBundle; +import io.spiffe.exception.BundleNotFoundException; +import io.spiffe.exception.JwtSourceException; +import io.spiffe.exception.JwtSvidException; +import io.spiffe.exception.SocketEndpointAddressException; +import io.spiffe.spiffeid.SpiffeId; +import io.spiffe.spiffeid.TrustDomain; +import io.spiffe.svid.jwtsvid.JwtSvid; +import io.spiffe.utils.TestUtils; +import lombok.val; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static io.spiffe.workloadapi.WorkloadApiClientStub.JWT_TTL; +import static org.junit.jupiter.api.Assertions.*; + +class CachedJwtSourceTest { + private CachedJwtSource jwtSource; + private WorkloadApiClientStub workloadApiClient; + private WorkloadApiClientErrorStub workloadApiClientErrorStub; + private Clock clock; + + @BeforeEach + void setUp() throws JwtSourceException, SocketEndpointAddressException { + workloadApiClient = new WorkloadApiClientStub(); + JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build(); + System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S"); + jwtSource = (CachedJwtSource) CachedJwtSource.newSource(options); + workloadApiClientErrorStub = new WorkloadApiClientErrorStub(); + + clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + workloadApiClient.setClock(clock); + jwtSource.setClock(clock); + } + + @AfterEach + void tearDown() throws IOException { + jwtSource.close(); + } + + @Test + void testGetBundleForTrustDomain() { + try { + JwtBundle bundle = jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org")); + assertNotNull(bundle); + assertEquals(TrustDomain.parse("example.org"), bundle.getTrustDomain()); + } catch (BundleNotFoundException e) { + fail(e); + } + } + + @Test + void testGetBundleForTrustDomain_nullParam() { + try { + jwtSource.getBundleForTrustDomain(null); + fail(); + } catch (NullPointerException e) { + assertEquals("trustDomain is marked non-null but is null", e.getMessage()); + } catch (BundleNotFoundException e) { + fail(); + } + } + + @Test + void testGetBundleForTrustDomain_SourceIsClosed_ThrowsIllegalStateException() throws IOException { + jwtSource.close(); + try { + jwtSource.getBundleForTrustDomain(TrustDomain.parse("example.org")); + fail("expected exception"); + } catch (IllegalStateException e) { + assertEquals("JWT bundle source is closed", e.getMessage()); + assertTrue(workloadApiClient.closed); + } catch (BundleNotFoundException e) { + fail("not expected exception", e); + } + } + + @Test + void testFetchJwtSvidWithSubject() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithSubject_ReturnFromCache() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud3", "aud2", "aud1"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again to get from cache changing the order of the audiences + svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again using different subject + svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud2", "aud3", "aud1"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again using the same audiences + svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/extra-workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // set clock forwards but not enough to expire the JWT SVID in the cache + jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1)))); + + // call again to get from cache, fetchJwtSvid call count should not change + svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // set clock to expire the JWT SVID in the cache + jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1)))); + + // call again, fetchJwtSvid call count should increase + svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache_MultipleThreads() { + // test fetchJwtSvid with several threads trying to read and write the cache + // at the same time, the cache should be updated only once + try { + + jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // set clock to expire the JWT SVID in the cache + Clock offset = Clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1))); + jwtSource.setClock(offset); + workloadApiClient.setClock(offset); + + // create a thread pool with 10 threads + ExecutorService executorService = Executors.newFixedThreadPool(10); + + List> futures = new ArrayList<>(); + + // create 10 tasks to fetch a JWT SVID + for (int i = 0; i < 10; i++) { + futures.add(executorService.submit(() -> jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"))); + } + + // wait for all tasks to finish + for (Future future : futures) { + future.get(); + } + + // verify that the cache was updated only once after the JWT SVID expired + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + + } catch (InterruptedException | ExecutionException | JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithoutSubject() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithoutSubject_ReturnFromCache() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again to get from cache changing the order of the audiences, the call count should not change + svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again using different audience, the call count should increase + svid = jwtSource.fetchJwtSvid("other-audience"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("other-audience"), svid.getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithoutSubject_JwtSvidExpiredInCache() { + try { + JwtSvid svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // set clock forwards but not enough to expire the JWT SVID in the cache + jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).minus(Duration.ofSeconds(1)))); + + // call again to get from cache, fetchJwtSvid call count should not change + svid = jwtSource.fetchJwtSvid("aud3", "aud2", "aud1"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // set clock forwards to expire the JWT SVID in the cache + jwtSource.setClock(clock.offset(clock, JWT_TTL.dividedBy(2).plus(Duration.ofSeconds(1)))); + + // call again, fetchJwtSvid call count should increase + svid = jwtSource.fetchJwtSvid("aud1", "aud2", "aud3"); + assertNotNull(svid); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvid_SourceIsClosed_ThrowsIllegalStateException() throws IOException { + jwtSource.close(); + try { + jwtSource.fetchJwtSvid("aud1", "aud2", "aud3"); + fail("expected exception"); + } catch (IllegalStateException e) { + assertEquals("JWT SVID source is closed", e.getMessage()); + assertTrue(workloadApiClient.closed); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException { + jwtSource.close(); + try { + jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + fail("expected exception"); + } catch (IllegalStateException e) { + assertEquals("JWT SVID source is closed", e.getMessage()); + assertTrue(workloadApiClient.closed); + } catch (JwtSvidException e) { + fail(e); + } + } + + + @Test + void testFetchJwtSvidsWithSubject() { + try { + List svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svids); + assertEquals(1, svids.size()); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidsWithSubject_ReturnFromCache() { + try { + List svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svids); + assertEquals(1, svids.size()); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again to get from cache changing the order of the audiences + svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + assertNotNull(svids); + assertEquals(1, svids.size()); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud3", "aud2", "aud1"), svids.get(0).getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again using different audience + svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "other-audience"); + assertNotNull(svids); + assertEquals(1, svids.size()); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidsWithoutSubject() { + try { + List svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3"); + assertNotNull(svids); + assertEquals(svids.size(), 2); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidsWithoutSubject_ReturnFromCache() { + try { + List svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3"); + assertNotNull(svids); + assertEquals(svids.size(), 2); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again to get from cache changing the order of the audiences + svids = jwtSource.fetchJwtSvids("aud2", "aud3", "aud1"); + assertNotNull(svids); + assertEquals(svids.size(), 2); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience()); + assertEquals(1, workloadApiClient.getFetchJwtSvidCallCount()); + + // call again using different audience + svids = jwtSource.fetchJwtSvids("other-audience"); + assertNotNull(svids); + assertEquals(svids.size(), 2); + assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); + assertEquals(Sets.newHashSet("other-audience"), svids.get(0).getAudience()); + assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId()); + assertEquals(Sets.newHashSet("other-audience"), svids.get(1).getAudience()); + assertEquals(2, workloadApiClient.getFetchJwtSvidCallCount()); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvids_SourceIsClosed_ThrowsIllegalStateException() throws IOException { + jwtSource.close(); + try { + jwtSource.fetchJwtSvids("aud1", "aud2", "aud3"); + fail("expected exception"); + } catch (IllegalStateException e) { + assertEquals("JWT SVID source is closed", e.getMessage()); + assertTrue(workloadApiClient.closed); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void testFetchJwtSvidsWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException { + jwtSource.close(); + try { + jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); + fail("expected exception"); + } catch (IllegalStateException e) { + assertEquals("JWT SVID source is closed", e.getMessage()); + assertTrue(workloadApiClient.closed); + } catch (JwtSvidException e) { + fail(e); + } + } + + @Test + void newSource_success() { + val options = JwtSourceOptions + .builder() + .workloadApiClient(workloadApiClient) + .initTimeout(Duration.ofSeconds(0)) + .build(); + try { + JwtSource jwtSource = DefaultJwtSource.newSource(options); + assertNotNull(jwtSource); + } catch (SocketEndpointAddressException | JwtSourceException e) { + fail(e); + } + } + + @Test + void newSource_nullParam() { + try { + DefaultJwtSource.newSource(null); + fail(); + } catch (NullPointerException e) { + assertEquals("options is marked non-null but is null", e.getMessage()); + } catch (SocketEndpointAddressException | JwtSourceException e) { + fail(); + } + } + + @Test + void newSource_errorFetchingJwtBundles() { + val options = JwtSourceOptions + .builder() + .workloadApiClient(workloadApiClientErrorStub) + .spiffeSocketPath("unix:/tmp/test") + .build(); + try { + DefaultJwtSource.newSource(options); + fail(); + } catch (JwtSourceException e) { + assertEquals("Error creating JWT source", e.getMessage()); + assertEquals("Error fetching JwtBundleSet", e.getCause().getMessage()); + } catch (Exception e) { + fail(); + } + } + + @Test + void newSource_FailsBecauseOfTimeOut() throws Exception { + try { + val options = JwtSourceOptions + .builder() + .spiffeSocketPath("unix:/tmp/test") + .build(); + DefaultJwtSource.newSource(options); + fail(); + } catch (JwtSourceException e) { + assertEquals("Error creating JWT source", e.getMessage()); + assertEquals("Timeout waiting for JWT bundles update", e.getCause().getMessage()); + } catch (SocketEndpointAddressException e) { + fail(); + } + } + + @Test + void newSource_DefaultSocketAddress() throws Exception { + try { + TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, "unix:/tmp/test"); + DefaultJwtSource.newSource(); + fail(); + } catch (JwtSourceException e) { + assertEquals("Error creating JWT source", e.getMessage()); + } catch (SocketEndpointAddressException e) { + fail(); + } + } + + @Test + void newSource_noSocketAddress() throws Exception { + try { + // just in case it's defined in the environment + TestUtils.setEnvironmentVariable(Address.SOCKET_ENV_VARIABLE, ""); + DefaultJwtSource.newSource(); + fail(); + } catch (SocketEndpointAddressException e) { + fail(); + } catch (IllegalStateException e) { + assertEquals("Endpoint Socket Address Environment Variable is not set: SPIFFE_ENDPOINT_SOCKET", e.getMessage()); + } + } +} diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultJwtSourceTest.java similarity index 96% rename from java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java rename to java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultJwtSourceTest.java index 4496b94..1771357 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultJwtSourceTest.java @@ -24,7 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -class JwtSourceTest { +class DefaultJwtSourceTest { private JwtSource jwtSource; private WorkloadApiClientStub workloadApiClient; @@ -33,7 +33,7 @@ class JwtSourceTest { @BeforeEach void setUp() throws JwtSourceException, SocketEndpointAddressException { workloadApiClient = new WorkloadApiClientStub(); - DefaultJwtSource.JwtSourceOptions options = DefaultJwtSource.JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build(); + JwtSourceOptions options = JwtSourceOptions.builder().workloadApiClient(workloadApiClient).build(); System.setProperty(DefaultJwtSource.TIMEOUT_SYSTEM_PROPERTY, "PT1S"); jwtSource = DefaultJwtSource.newSource(options); workloadApiClientErrorStub = new WorkloadApiClientErrorStub(); @@ -152,7 +152,7 @@ class JwtSourceTest { try { List svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3"); assertNotNull(svids); - assertEquals(svids.size(), 2); + assertEquals(2, svids.size()); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId()); assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience()); assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId()); @@ -193,7 +193,7 @@ class JwtSourceTest { @Test void newSource_success() { - val options = DefaultJwtSource.JwtSourceOptions + val options = JwtSourceOptions .builder() .workloadApiClient(workloadApiClient) .initTimeout(Duration.ofSeconds(0)) @@ -220,7 +220,7 @@ class JwtSourceTest { @Test void newSource_errorFetchingJwtBundles() { - val options = DefaultJwtSource.JwtSourceOptions + val options = JwtSourceOptions .builder() .workloadApiClient(workloadApiClientErrorStub) .spiffeSocketPath("unix:/tmp/test") @@ -239,7 +239,7 @@ class JwtSourceTest { @Test void newSource_FailsBecauseOfTimeOut() throws Exception { try { - val options = DefaultJwtSource.JwtSourceOptions + val options = JwtSourceOptions .builder() .spiffeSocketPath("unix:/tmp/test") .build(); diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientStub.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientStub.java index 9543ac5..7e3044f 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientStub.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientStub.java @@ -23,12 +23,15 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.security.KeyPair; +import java.time.Clock; +import java.time.Duration; import java.util.*; import static io.spiffe.utils.TestUtils.toUri; public class WorkloadApiClientStub implements WorkloadApiClient { + static final Duration JWT_TTL = Duration.ofSeconds(60); final String privateKey = "testdata/workloadapi/svid.key.der"; final String svid = "testdata/workloadapi/svid.der"; final String x509Bundle = "testdata/workloadapi/bundle.der"; @@ -36,8 +39,12 @@ public class WorkloadApiClientStub implements WorkloadApiClient { final SpiffeId subject = SpiffeId.parse("spiffe://example.org/workload-server"); final SpiffeId extraSubject = SpiffeId.parse("spiffe://example.org/extra-workload-server"); + int fetchJwtSvidCallCount = 0; + boolean closed; + Clock clock = Clock.systemDefaultZone(); + @Override public X509Context fetchX509Context() { return generateX509Context(); @@ -62,16 +69,19 @@ public class WorkloadApiClientStub implements WorkloadApiClient { @Override public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException { + fetchJwtSvidCallCount++; return generateJwtSvid(subject, audience, extraAudience); } @Override public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException { + fetchJwtSvidCallCount++; return generateJwtSvid(subject, audience, extraAudience); } @Override public List fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException { + fetchJwtSvidCallCount++; List svids = new ArrayList<>(); svids.add(generateJwtSvid(subject, audience, extraAudience)); svids.add(generateJwtSvid(extraSubject, audience, extraAudience)); @@ -80,6 +90,7 @@ public class WorkloadApiClientStub implements WorkloadApiClient { @Override public List fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException { + fetchJwtSvidCallCount++; List svids = new ArrayList<>(); svids.add(generateJwtSvid(subject, audience, extraAudience)); return svids; @@ -132,8 +143,9 @@ public class WorkloadApiClientStub implements WorkloadApiClient { Map claims = new HashMap<>(); claims.put("sub", subject.toString()); claims.put("aud", new ArrayList<>(audParam)); - Date expiration = new Date(System.currentTimeMillis() + 3600000); - claims.put("exp", expiration); + + claims.put("iat", new Date(clock.millis())); + claims.put("exp", new Date(clock.millis() + JWT_TTL.toMillis())); KeyPair keyPair = TestUtils.generateECKeyPair(Curve.P_521); @@ -178,4 +190,20 @@ public class WorkloadApiClientStub implements WorkloadApiClient { throw new RuntimeException(e); } } + + void resetFetchJwtSvidCallCount() { + fetchJwtSvidCallCount = 0; + } + + int getFetchJwtSvidCallCount() { + return fetchJwtSvidCallCount; + } + + Clock getClock() { + return clock; + } + + void setClock(Clock clock) { + this.clock = clock; + } } diff --git a/java-spiffe-core/src/testFixtures/java/io/spiffe/utils/TestUtils.java b/java-spiffe-core/src/testFixtures/java/io/spiffe/utils/TestUtils.java index 9fc3e57..1b912bd 100644 --- a/java-spiffe-core/src/testFixtures/java/io/spiffe/utils/TestUtils.java +++ b/java-spiffe-core/src/testFixtures/java/io/spiffe/utils/TestUtils.java @@ -101,6 +101,7 @@ public class TestUtils { public static JWTClaimsSet buildJWTClaimSetFromClaimsMap(Map claims) { return new JWTClaimsSet.Builder() .subject((String) claims.get("sub")) + .issueTime((Date) claims.get("iat")) .expirationTime((Date) claims.get("exp")) .audience((List) claims.get("aud")) .build();