diff --git a/java-spiffe-core/src/main/java/io/spiffe/spiffeid/SpiffeIdUtils.java b/java-spiffe-core/src/main/java/io/spiffe/spiffeid/SpiffeIdUtils.java index 3858a99..002d4d6 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/spiffeid/SpiffeIdUtils.java +++ b/java-spiffe-core/src/main/java/io/spiffe/spiffeid/SpiffeIdUtils.java @@ -7,7 +7,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.List; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,11 +30,11 @@ public class SpiffeIdUtils { * @throws IOException if the given spiffeIdsFile cannot be read * @throws IllegalArgumentException if any of the SPIFFE IDs in the file cannot be parsed */ - public static List getSpiffeIdListFromFile(final Path spiffeIdsFile) throws IOException { + public static Set getSpiffeIdSetFromFile(final Path spiffeIdsFile) throws IOException { try (Stream lines = Files.lines(spiffeIdsFile)) { return lines .map(SpiffeId::parse) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); } } @@ -47,15 +47,15 @@ public class SpiffeIdUtils { * @return a list of {@link SpiffeId} instances. * @throws IllegalArgumentException is the string provided is blank */ - public static List toListOfSpiffeIds(final String spiffeIds, final char separator) { + public static Set toSetOfSpiffeIds(final String spiffeIds, final char separator) { if (isBlank(spiffeIds)) { - return Collections.emptyList(); + return Collections.emptySet(); } val array = spiffeIds.split(String.valueOf(separator)); return Arrays.stream(array) .map(SpiffeId::parse) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); } /** @@ -65,8 +65,8 @@ public class SpiffeIdUtils { * @return a list of {@link SpiffeId} instances * @throws IllegalArgumentException is the string provided is blank */ - public static List toListOfSpiffeIds(final String spiffeIds) { - return toListOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR); + public static Set toSetOfSpiffeIds(final String spiffeIds) { + return toSetOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR); } private SpiffeIdUtils() { diff --git a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java index 5be170a..8ae3528 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java +++ b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java @@ -24,8 +24,9 @@ import java.security.interfaces.ECPublicKey; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Date; -import java.util.List; +import java.util.HashSet; import java.util.Map; +import java.util.Set; /** * Represents a SPIFFE JWT-SVID. @@ -41,7 +42,7 @@ public class JwtSvid { /** * Audience is the intended recipients of JWT-SVID as present in the 'aud' claim. */ - List audience; + Set audience; /** * Expiration time of JWT-SVID as present in 'exp' claim. @@ -68,7 +69,7 @@ public class JwtSvid { * @param token the token as string */ JwtSvid(@NonNull final SpiffeId spiffeId, - @NonNull final List audience, + @NonNull final Set audience, @NonNull final Date expiry, @NonNull final Map claims, @NonNull final String token) { @@ -99,36 +100,38 @@ public class JwtSvid { */ public static JwtSvid parseAndValidate(@NonNull final String token, @NonNull final BundleSource jwtBundleSource, - @NonNull final List audience) + @NonNull final Set audience) throws JwtSvidException, BundleNotFoundException, AuthorityNotFoundException { if (StringUtils.isBlank(token)) { throw new IllegalArgumentException("Token cannot be blank"); } + final SignedJWT signedJwt; + final JWTClaimsSet claimsSet; try { - val signedJwt = SignedJWT.parse(token); - val claimsSet = signedJwt.getJWTClaimsSet(); - - List claimAudience = claimsSet.getAudience(); - validateAudience(claimAudience, audience); - - val expirationTime = claimsSet.getExpirationTime(); - validateExpiration(expirationTime); - - val spiffeId = getSpiffeId(claimsSet); - val jwtBundle = jwtBundleSource.getBundleForTrustDomain(spiffeId.getTrustDomain()); - - val keyId = getKeyId(signedJwt.getHeader()); - val jwtAuthority = jwtBundle.findJwtAuthority(keyId); - - val algorithm = signedJwt.getHeader().getAlgorithm().getName(); - verifySignature(signedJwt, jwtAuthority, algorithm, keyId); - - return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); + signedJwt = SignedJWT.parse(token); + claimsSet = signedJwt.getJWTClaimsSet(); } catch (ParseException e) { throw new IllegalArgumentException("Unable to parse JWT token", e); } + + Set claimAudience = new HashSet<>(claimsSet.getAudience()); + validateAudience(claimAudience, audience); + + val expirationTime = claimsSet.getExpirationTime(); + validateExpiration(expirationTime); + + val spiffeId = getSpiffeIdOfSubject(claimsSet); + val jwtBundle = jwtBundleSource.getBundleForTrustDomain(spiffeId.getTrustDomain()); + + val keyId = getKeyId(signedJwt.getHeader()); + val jwtAuthority = jwtBundle.findJwtAuthority(keyId); + + val algorithm = signedJwt.getHeader().getAlgorithm().getName(); + verifySignature(signedJwt, jwtAuthority, algorithm, keyId); + + return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); } /** @@ -144,27 +147,29 @@ public class JwtSvid { * the 'aud' has an audience that is not in the audience provided as parameter * @throws IllegalArgumentException when the token cannot be parsed */ - public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final List audience) throws JwtSvidException { + public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final Set audience) throws JwtSvidException { if (StringUtils.isBlank(token)) { throw new IllegalArgumentException("Token cannot be blank"); } + final SignedJWT signedJwt; + final JWTClaimsSet claimsSet; try { - val signedJwt = SignedJWT.parse(token); - val claimsSet = signedJwt.getJWTClaimsSet(); - - List claimAudience = claimsSet.getAudience(); - validateAudience(claimAudience, audience); - - val expirationTime = claimsSet.getExpirationTime(); - validateExpiration(expirationTime); - - val spiffeId = getSpiffeId(claimsSet); - - return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); + signedJwt = SignedJWT.parse(token); + claimsSet = signedJwt.getJWTClaimsSet(); } catch (ParseException e) { throw new IllegalArgumentException("Unable to parse JWT token", e); } + + Set claimAudience = new HashSet<>(claimsSet.getAudience()); + validateAudience(claimAudience, audience); + + val expirationTime = claimsSet.getExpirationTime(); + validateExpiration(expirationTime); + + val spiffeId = getSpiffeIdOfSubject(claimsSet); + + return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); } /** @@ -173,7 +178,7 @@ public class JwtSvid { * * @return the token as String */ - public String marshall() { + public String marshal() { return token; } @@ -202,9 +207,10 @@ public class JwtSvid { private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException { JWSVerifier verifier; - if (Algorithm.Family.EC.contains(Algorithm.parse(algorithm))) { + final Algorithm alg = Algorithm.parse(algorithm); + if (Algorithm.Family.EC.contains(alg)) { verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority); - } else if (Algorithm.Family.RSA.contains(Algorithm.parse(algorithm))) { + } else if (Algorithm.Family.RSA.contains(alg)) { verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority); } else { throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm)); @@ -214,9 +220,12 @@ public class JwtSvid { private static String getKeyId(final JWSHeader header) throws JwtSvidException { val keyId = header.getKeyID(); - if (StringUtils.isBlank(keyId)) { + if (keyId == null) { throw new JwtSvidException("Token header missing key id"); } + if (StringUtils.isBlank(keyId)) { + throw new JwtSvidException("Token header key id contains an empty value"); + } return keyId; } @@ -230,7 +239,7 @@ public class JwtSvid { } } - private static SpiffeId getSpiffeId(final JWTClaimsSet claimsSet) throws JwtSvidException { + private static SpiffeId getSpiffeIdOfSubject(final JWTClaimsSet claimsSet) throws JwtSvidException { val subject = claimsSet.getSubject(); if (StringUtils.isBlank(subject)) { throw new JwtSvidException("Token missing subject claim"); @@ -244,7 +253,7 @@ public class JwtSvid { } - private static void validateAudience(final List audClaim, final List audience) throws JwtSvidException { + private static void validateAudience(final Set audClaim, final Set audience) throws JwtSvidException { for (String aud : audClaim) { if (!audience.contains(aud)) { throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", audience, audClaim)); diff --git a/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509Svid.java b/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509Svid.java index fd1d7b8..10ee1bc 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509Svid.java +++ b/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509Svid.java @@ -26,7 +26,7 @@ import java.util.List; * Contains a SPIFFE ID, a private key and a chain of X.509 certificates. */ @Value -public class X509Svid implements X509SvidSource { +public class X509Svid { SpiffeId spiffeId; @@ -48,8 +48,17 @@ public class X509Svid implements X509SvidSource { this.privateKey = privateKey; } + /** + * @return the Leaf Certificate of the chain + */ + public X509Certificate getLeaf() { + return chain.get(0); + } + /** * Loads the X.509 SVID from PEM encoded files on disk. + *

+ * It is assumed that the leaf certificate is always the first certificate in the parsed chain. * * @param certsFilePath path to X.509 certificate chain file * @param privateKeyFilePath path to private key file @@ -76,6 +85,8 @@ public class X509Svid implements X509SvidSource { /** * Parses the X.509 SVID from PEM or DER blocks containing certificate chain and key * bytes. The key must be a PEM or DER block with PKCS#8. + *

+ * It is assumed that the leaf certificate is always the first certificate in the parsed chain. * * @param certsBytes chain of certificates as a byte array * @param privateKeyBytes private key as byte array @@ -89,6 +100,8 @@ public class X509Svid implements X509SvidSource { /** * Parses the X509-SVID from certificate and key bytes. The certificate must be ASN.1 DER (concatenated with * no intermediate padding if there are more than one certificate). The key must be a PKCS#8 ASN.1 DER. + *

+ * It is assumed that the leaf certificate is always the first certificate in the parsed chain. * * @param certsBytes chain of certificates as a byte array * @param privateKeyBytes private key as byte array @@ -106,14 +119,6 @@ public class X509Svid implements X509SvidSource { return chain.toArray(new X509Certificate[0]); } - /** - * @return this instance, implementing a X509Svid interface. - */ - @Override - public X509Svid getX509Svid() { - return this; - } - private static X509Svid createX509Svid(final byte[] certsBytes, final byte[] privateKeyBytes, KeyFileFormat keyFileFormat) throws X509SvidException { List x509Certificates; @@ -142,20 +147,24 @@ public class X509Svid implements X509SvidSource { validateLeafCertificate(x509Certificates.get(0)); if (x509Certificates.size() > 1) { - validateSigningCertificates(x509Certificates.subList(1, x509Certificates.size())); + validateSigningCertificates(x509Certificates); } return new X509Svid(spiffeId, x509Certificates, privateKey); } private static void validateSigningCertificates(final List certificates) throws X509SvidException { - for (X509Certificate cert : certificates) { - if (!CertificateUtils.isCA(cert)) { - throw new X509SvidException("Signing certificate must have CA flag set to true"); - } - if (!CertificateUtils.hasKeyUsageCertSign(cert)) { - throw new X509SvidException("Signing certificate must have 'keyCertSign' as key usage"); - } + for (int i = 1; i < certificates.size(); i++) { + verifyCaCert(certificates.get(i)); + } + } + + private static void verifyCaCert(final X509Certificate cert) throws X509SvidException { + if (!CertificateUtils.isCA(cert)) { + throw new X509SvidException("Signing certificate must have CA flag set to true"); + } + if (!CertificateUtils.hasKeyUsageCertSign(cert)) { + throw new X509SvidException("Signing certificate must have 'keyCertSign' as key usage"); } } @@ -163,10 +172,10 @@ public class X509Svid implements X509SvidSource { if (CertificateUtils.isCA(leaf)) { throw new X509SvidException("Leaf certificate must not have CA flag set to true"); } - validateKeyUsage(leaf); + validateKeyUsageOfLeafCertificate(leaf); } - private static void validateKeyUsage(final X509Certificate leaf) throws X509SvidException { + private static void validateKeyUsageOfLeafCertificate(final X509Certificate leaf) throws X509SvidException { if (!CertificateUtils.hasKeyUsageDigitalSignature(leaf)) { throw new X509SvidException("Leaf certificate must have 'digitalSignature' as key usage"); } diff --git a/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509SvidValidator.java b/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509SvidValidator.java index 2304710..452b723 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509SvidValidator.java +++ b/java-spiffe-core/src/main/java/io/spiffe/svid/x509svid/X509SvidValidator.java @@ -1,18 +1,19 @@ package io.spiffe.svid.x509svid; import io.spiffe.bundle.BundleSource; +import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.internal.CertificateUtils; import io.spiffe.spiffeid.SpiffeId; import lombok.NonNull; import lombok.val; -import io.spiffe.bundle.x509bundle.X509Bundle; import java.security.cert.CertPathValidatorException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.Supplier; /** @@ -48,17 +49,17 @@ public class X509SvidValidator { * Checks that the X.509 SVID provided has a SPIFFE ID that is in the list of accepted SPIFFE IDs supplied. * * @param x509Certificate a {@link X509Svid} with a SPIFFE ID to be verified - * @param acceptedSpiffedIdsSupplier a {@link Supplier} of a list os SPIFFE IDs that are accepted - * @throws CertificateException is the SPIFFE ID in x509Certificate is not in the list supplied by acceptedSpiffedIdsSupplier, + * @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of SPIFFE IDs that are accepted + * @throws CertificateException if the SPIFFE ID in x509Certificate is not in the Set supplied by acceptedSpiffeIdsSupplier, * or if the SPIFFE ID cannot be parsed from the x509Certificate - * @throws NullPointerException if the given x509Certificate or acceptedSpiffedIdsSupplier are null + * @throws NullPointerException if the given x509Certificate or acceptedSpiffeIdsSupplier are null */ public static void verifySpiffeId(@NonNull final X509Certificate x509Certificate, - @NonNull final Supplier> acceptedSpiffedIdsSupplier) + @NonNull final Supplier> acceptedSpiffeIdsSupplier) throws CertificateException { - val spiffeIdList = acceptedSpiffedIdsSupplier.get(); + val spiffeIdSet = acceptedSpiffeIdsSupplier.get(); val spiffeId = CertificateUtils.getSpiffeId(x509Certificate); - if (!spiffeIdList.contains(spiffeId)) { + if (!spiffeIdSet.contains(spiffeId)) { throw new CertificateException(String.format("SPIFFE ID %s in X.509 certificate is not accepted", spiffeId)); } } diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/WorkloadApiClient.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/WorkloadApiClient.java index df2a3a9..5241f04 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/WorkloadApiClient.java +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/WorkloadApiClient.java @@ -32,8 +32,10 @@ import java.security.KeyException; import java.security.cert.CertificateException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -192,7 +194,7 @@ public class WorkloadApiClient implements Closeable { * @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API */ public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException { - List audParam = new ArrayList<>(); + Set audParam = new HashSet<>(); audParam.add(audience); Collections.addAll(audParam, extraAudience); @@ -238,7 +240,7 @@ public class WorkloadApiClient implements Closeable { throw new JwtSvidException("Error validating JWT SVID", e); } - return JwtSvid.parseInsecure(token, Collections.singletonList(audience)); + return JwtSvid.parseInsecure(token, Collections.singleton(audience)); } /** @@ -388,7 +390,7 @@ public class WorkloadApiClient implements Closeable { throw new X509ContextException("Error processing X509Context: x509SVIDResponse is empty"); } - private JwtSvid callFetchJwtSvid(SpiffeId subject, List audience) throws JwtSvidException { + private JwtSvid callFetchJwtSvid(SpiffeId subject, Set audience) throws JwtSvidException { Workload.JWTSVIDRequest jwtsvidRequest = Workload.JWTSVIDRequest .newBuilder() .setSpiffeId(subject.toString()) diff --git a/java-spiffe-core/src/test/java/io/spiffe/spiffeid/SpiffeIdUtilsTest.java b/java-spiffe-core/src/test/java/io/spiffe/spiffeid/SpiffeIdUtilsTest.java index ae2b33c..5503ad5 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/spiffeid/SpiffeIdUtilsTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/spiffeid/SpiffeIdUtilsTest.java @@ -1,5 +1,6 @@ package io.spiffe.spiffeid; +import lombok.val; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -8,36 +9,37 @@ import java.net.URISyntaxException; import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.List; +import java.util.Set; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; class SpiffeIdUtilsTest { @Test - void getSpiffeIdListFromFile() throws URISyntaxException { + void getSpiffeIdSetFromFile() throws URISyntaxException { Path path = Paths.get(toUri("testdata/spiffeid/spiffeIds.txt")); try { - List spiffeIdList = SpiffeIdUtils.getSpiffeIdListFromFile(path); - assertNotNull(spiffeIdList); - assertEquals(3, spiffeIdList.size()); - assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0)); - assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1)); - assertEquals(SpiffeId.parse("spiffe://example2.org/workload1"), spiffeIdList.get(2)); + Set spiffeIdSet = SpiffeIdUtils.getSpiffeIdSetFromFile(path); + assertNotNull(spiffeIdSet); + assertEquals(3, spiffeIdSet.size()); + assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1"))); + assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2"))); + assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example2.org/workload1"))); } catch (IOException e) { fail(e); } } @Test - void getSpiffeIdListFromNonExistenFile_throwsException() throws IOException { + void getSpiffeIdSetFromNonExistenFile_throwsException() throws IOException { Path path = Paths.get("testdata/spiffeid/non-existent-file"); try { - SpiffeIdUtils.getSpiffeIdListFromFile(path); + SpiffeIdUtils.getSpiffeIdSetFromFile(path); fail("should have thrown exception"); } catch (NoSuchFileException e) { assertEquals("testdata/spiffeid/non-existent-file", e.getMessage()); @@ -45,15 +47,14 @@ class SpiffeIdUtilsTest { } @Test - void toListOfSpiffeIds() { - String spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 "; + void toSetOfSpiffeIds() { + val spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 "; + val spiffeIdSet = SpiffeIdUtils.toSetOfSpiffeIds(spiffeIdsAsString, ','); - List spiffeIdList = SpiffeIdUtils.toListOfSpiffeIds(spiffeIdsAsString, ','); - - assertNotNull(spiffeIdList); - assertEquals(2, spiffeIdList.size()); - assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0)); - assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1)); + assertNotNull(spiffeIdSet); + assertEquals(2, spiffeIdSet.size()); + assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1"))); + assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2"))); } private URI toUri(String path) throws URISyntaxException { diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java index d17be69..9ca0a5c 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java @@ -2,24 +2,24 @@ package io.spiffe.svid.jwtsvid; import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jwt.JWTClaimsSet; +import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.exception.AuthorityNotFoundException; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.JwtSvidException; +import io.spiffe.spiffeid.SpiffeId; +import io.spiffe.spiffeid.TrustDomain; +import io.spiffe.utils.TestUtils; import lombok.Builder; import lombok.Value; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import io.spiffe.bundle.jwtbundle.JwtBundle; -import io.spiffe.spiffeid.SpiffeId; -import io.spiffe.spiffeid.TrustDomain; -import io.spiffe.utils.TestUtils; import java.security.KeyPair; import java.util.Collections; import java.util.Date; -import java.util.List; +import java.util.Set; import java.util.function.Supplier; import java.util.stream.Stream; @@ -43,7 +43,7 @@ class JwtSvidParseAndValidateTest { assertEquals(testCase.expectedJwtSvid.getAudience(), jwtSvid.getAudience()); assertEquals(testCase.expectedJwtSvid.getExpiry().toInstant().getEpochSecond(), jwtSvid.getExpiry().toInstant().getEpochSecond()); assertEquals(token, jwtSvid.getToken()); - assertEquals(token, jwtSvid.marshall()); + assertEquals(token, jwtSvid.marshal()); } catch (Exception e) { assertEquals(testCase.expectedException.getClass(), e.getClass()); assertEquals(testCase.expectedException.getMessage(), e.getMessage()); @@ -54,7 +54,7 @@ class JwtSvidParseAndValidateTest { void testParseAndValidate_nullToken_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { TrustDomain trustDomain = TrustDomain.of("test.domain"); JwtBundle jwtBundle = new JwtBundle(trustDomain); - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); try { JwtSvid.parseAndValidate(null, jwtBundle, audience); @@ -67,7 +67,7 @@ class JwtSvidParseAndValidateTest { void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { TrustDomain trustDomain = TrustDomain.of("test.domain"); JwtBundle jwtBundle = new JwtBundle(trustDomain); - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); try { JwtSvid.parseAndValidate("", jwtBundle, audience); @@ -78,7 +78,7 @@ class JwtSvidParseAndValidateTest { @Test void testParseAndValidate_nullBundle_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); try { JwtSvid.parseAndValidate("token", null, audience); } catch (NullPointerException e) { @@ -90,7 +90,6 @@ class JwtSvidParseAndValidateTest { void testParseAndValidate_nullAudience_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { TrustDomain trustDomain = TrustDomain.of("test.domain"); JwtBundle jwtBundle = new JwtBundle(trustDomain); - List audience = Collections.singletonList("audience"); try { JwtSvid.parseAndValidate("token", jwtBundle, null); @@ -112,7 +111,7 @@ class JwtSvidParseAndValidateTest { SpiffeId spiffeId = trustDomain.newSpiffeId("host"); Date expiration = new Date(System.currentTimeMillis() + 3600000); - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); @@ -179,7 +178,7 @@ class JwtSvidParseAndValidateTest { Arguments.of(TestCase.builder() .name("8. unexpected audience") .jwtBundle(jwtBundle) - .expectedAudience(Collections.singletonList("another")) + .expectedAudience(Collections.singleton("another")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])")) .build()), @@ -194,32 +193,39 @@ class JwtSvidParseAndValidateTest { .name("10. missing key id") .jwtBundle(jwtBundle) .expectedAudience(audience) - .generateToken(() -> TestUtils.generateToken(claims, key1, "")) + .generateToken(() -> TestUtils.generateToken(claims, key1, null)) .expectedException(new JwtSvidException("Token header missing key id")) .build()), Arguments.of(TestCase.builder() - .name("11. no bundle for trust domain") + .name("11. key id contains an empty value") + .jwtBundle(jwtBundle) + .expectedAudience(audience) + .generateToken(() -> TestUtils.generateToken(claims, key1, " ")) + .expectedException(new JwtSvidException("Token header key id contains an empty value")) + .build()), + Arguments.of(TestCase.builder() + .name("12. no bundle for trust domain") .jwtBundle(new JwtBundle(TrustDomain.of("other.domain"))) .expectedAudience(audience) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .expectedException(new BundleNotFoundException("No JWT bundle found for trust domain test.domain")) .build()), Arguments.of(TestCase.builder() - .name("12. no authority found for key id") + .name("13. no authority found for key id") .jwtBundle(new JwtBundle(TrustDomain.of("test.domain"))) .expectedAudience(audience) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .expectedException(new AuthorityNotFoundException("No authority found for the trust domain test.domain and key id authority1")) .build()), Arguments.of(TestCase.builder() - .name("13. signature cannot be verified with authority") + .name("14. signature cannot be verified with authority") .jwtBundle(jwtBundle) .expectedAudience(audience) .generateToken(() -> TestUtils.generateToken(claims, key2, "authority1")) .expectedException(new JwtSvidException("Signature invalid: cannot be verified with the authority with keyId=authority1")) .build()), Arguments.of(TestCase.builder() - .name("14. authority algorithm mismatch") + .name("15. authority algorithm mismatch") .jwtBundle(jwtBundle) .expectedAudience(audience) .generateToken(() -> TestUtils.generateToken(claims, key3, "authority1")) @@ -232,13 +238,13 @@ class JwtSvidParseAndValidateTest { static class TestCase { String name; JwtBundle jwtBundle; - List audience; + Set audience; Supplier generateToken; Exception expectedException; JwtSvid expectedJwtSvid; @Builder - public TestCase(String name, JwtBundle jwtBundle, List expectedAudience, Supplier generateToken, + public TestCase(String name, JwtBundle jwtBundle, Set expectedAudience, Supplier generateToken, Exception expectedException, JwtSvid expectedJwtSvid) { this.name = name; this.jwtBundle = jwtBundle; diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java index 77179b5..8fe2a47 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java @@ -2,22 +2,22 @@ package io.spiffe.svid.jwtsvid; import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jwt.JWTClaimsSet; +import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.exception.JwtSvidException; +import io.spiffe.spiffeid.SpiffeId; +import io.spiffe.spiffeid.TrustDomain; +import io.spiffe.utils.TestUtils; import lombok.Builder; import lombok.Value; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import io.spiffe.bundle.jwtbundle.JwtBundle; -import io.spiffe.spiffeid.SpiffeId; -import io.spiffe.spiffeid.TrustDomain; -import io.spiffe.utils.TestUtils; import java.security.KeyPair; import java.util.Collections; import java.util.Date; -import java.util.List; +import java.util.Set; import java.util.function.Supplier; import java.util.stream.Stream; @@ -46,7 +46,7 @@ class JwtSvidParseInsecureTest { @Test void testParseInsecure_nullToken_throwsNullPointerException() throws JwtSvidException { - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); try { JwtSvid.parseInsecure(null, audience); @@ -57,7 +57,7 @@ class JwtSvidParseInsecureTest { @Test void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException { - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); try { JwtSvid.parseInsecure("", audience); } catch (IllegalArgumentException e) { @@ -71,7 +71,7 @@ class JwtSvidParseInsecureTest { KeyPair key1 = TestUtils.generateECKeyPair(Curve.P_521); TrustDomain trustDomain = TrustDomain.of("test.domain"); SpiffeId spiffeId = trustDomain.newSpiffeId("host"); - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); Date expiration = new Date(System.currentTimeMillis() + 3600000); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); @@ -93,7 +93,7 @@ class JwtSvidParseInsecureTest { SpiffeId spiffeId = trustDomain.newSpiffeId("host"); Date expiration = new Date(System.currentTimeMillis() + 3600000); - List audience = Collections.singletonList("audience"); + Set audience = Collections.singleton("audience"); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); @@ -135,7 +135,7 @@ class JwtSvidParseInsecureTest { .build()), Arguments.of(TestCase.builder() .name("unexpected audience") - .expectedAudience(Collections.singletonList("another")) + .expectedAudience(Collections.singleton("another")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])")) .build()), @@ -151,13 +151,13 @@ class JwtSvidParseInsecureTest { @Value static class TestCase { String name; - List audience; + Set audience; Supplier generateToken; Exception expectedException; JwtSvid expectedJwtSvid; @Builder - public TestCase(String name, List expectedAudience, Supplier generateToken, + public TestCase(String name, Set expectedAudience, Supplier generateToken, Exception expectedException, JwtSvid expectedJwtSvid) { this.name = name; this.audience = expectedAudience; diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidTest.java index c0e1a8e..8292f76 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidTest.java @@ -264,11 +264,11 @@ public class X509SvidTest { } @Test - void testGetX509Svid() throws URISyntaxException, X509SvidException { + void testGetLeaf() throws URISyntaxException, X509SvidException { Path certPath = Paths.get(toUri(certSingle)); Path keyPath = Paths.get(toUri(keyRSA)); X509Svid x509Svid = X509Svid.load(certPath, keyPath); - assertEquals(x509Svid, x509Svid.getX509Svid()); + assertEquals(x509Svid.getChain().get(0), x509Svid.getLeaf()); } @Test diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidValidatorTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidValidatorTest.java index 448855e..d70dce6 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidValidatorTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/x509svid/X509SvidValidatorTest.java @@ -1,27 +1,28 @@ package io.spiffe.svid.x509svid; -import io.spiffe.exception.BundleNotFoundException; -import lombok.val; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import com.google.common.collect.Sets; import io.spiffe.bundle.x509bundle.X509Bundle; +import io.spiffe.exception.BundleNotFoundException; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.TrustDomain; import io.spiffe.utils.X509CertificateTestUtils.CertAndKeyPair; +import lombok.val; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.net.URISyntaxException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; -import static java.util.Collections.EMPTY_LIST; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; import static io.spiffe.utils.X509CertificateTestUtils.createCertificate; import static io.spiffe.utils.X509CertificateTestUtils.createRootCA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class X509SvidValidatorTest { @@ -84,19 +85,19 @@ public class X509SvidValidatorTest { val spiffeId1 = SpiffeId.parse("spiffe://example.org/test"); val spiffeId2 = SpiffeId.parse("spiffe://example.org/test2"); - val spiffeIdList = Arrays.asList(spiffeId1, spiffeId2); + val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2); - X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList); + X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet); } @Test void checkSpiffeId_givenASpiffeIdNotInTheListOfAcceptedIds_throwsCertificateException() throws IOException, CertificateException, URISyntaxException { val spiffeId1 = SpiffeId.parse("spiffe://example.org/other1"); val spiffeId2 = SpiffeId.parse("spiffe://example.org/other2"); - List spiffeIdList = Arrays.asList(spiffeId1, spiffeId2); + val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2); try { - X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList); + X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet); fail("Should have thrown CertificateException"); } catch (CertificateException e) { assertEquals("SPIFFE ID spiffe://example.org/test in X.509 certificate is not accepted", e.getMessage()); @@ -106,7 +107,7 @@ public class X509SvidValidatorTest { @Test void checkSpiffeId_nullX509Certificate_throwsNullPointerException() throws CertificateException { try { - X509SvidValidator.verifySpiffeId(null, () -> EMPTY_LIST); + X509SvidValidator.verifySpiffeId(null, Collections::emptySet); fail("should have thrown an exception"); } catch (NullPointerException e) { assertEquals("x509Certificate is marked non-null but is null", e.getMessage()); @@ -119,7 +120,7 @@ public class X509SvidValidatorTest { X509SvidValidator.verifySpiffeId(leaf.getCertificate(), null); fail("should have thrown an exception"); } catch (NullPointerException e) { - assertEquals("acceptedSpiffedIdsSupplier is marked non-null but is null", e.getMessage()); + assertEquals("acceptedSpiffeIdsSupplier is marked non-null but is null", e.getMessage()); } } diff --git a/java-spiffe-core/src/test/java/io/spiffe/utils/TestUtils.java b/java-spiffe-core/src/test/java/io/spiffe/utils/TestUtils.java index 53634df..22d41e2 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/utils/TestUtils.java +++ b/java-spiffe-core/src/test/java/io/spiffe/utils/TestUtils.java @@ -17,9 +17,11 @@ import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.spec.ECGenParameterSpec; +import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.Map; +import java.util.Set; import static java.util.stream.Collectors.joining; import static java.util.stream.Stream.generate; @@ -87,11 +89,11 @@ public class TestUtils { } } - public static JWTClaimsSet buildJWTClaimSet(List audience, String spiffeId, Date expiration) { + public static JWTClaimsSet buildJWTClaimSet(Set audience, String spiffeId, Date expiration) { return new JWTClaimsSet.Builder() .subject(spiffeId) .expirationTime(expiration) - .audience(audience) + .audience(new ArrayList<>(audience)) .build(); } diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/FakeWorkloadApi.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/FakeWorkloadApi.java index 88e3775..5b512f8 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/FakeWorkloadApi.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/FakeWorkloadApi.java @@ -129,7 +129,7 @@ class FakeWorkloadApi extends SpiffeWorkloadAPIImplBase { JwtSvid jwtSvid = null; try { - jwtSvid = JwtSvid.parseInsecure(token, Collections.singletonList(audience)); + jwtSvid = JwtSvid.parseInsecure(token, Collections.singleton(audience)); } catch (JwtSvidException e) { responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withDescription(e.getMessage()))); } diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java index bcabe86..9535759 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/JwtSourceTest.java @@ -1,28 +1,28 @@ package io.spiffe.workloadapi; +import com.google.common.collect.Sets; import io.grpc.ManagedChannel; import io.grpc.Server; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.testing.GrpcCleanupRule; +import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.JwtSourceException; import io.spiffe.exception.JwtSvidException; import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.TrustDomain; -import org.junit.Rule; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.svid.jwtsvid.JwtSvid; import io.spiffe.workloadapi.grpc.SpiffeWorkloadAPIGrpc; import io.spiffe.workloadapi.internal.ManagedChannelWrapper; import io.spiffe.workloadapi.internal.SecurityHeaderInterceptor; +import org.junit.Rule; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.io.IOException; -import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -98,7 +98,7 @@ class JwtSourceTest { JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); assertNotNull(svid); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); - assertEquals(Arrays.asList("aud1", "aud2", "aud3"), svid.getAudience()); + assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience()); } catch (JwtSvidException e) { fail(e); } diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientTest.java index ef64705..6a6cc7b 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/WorkloadApiClientTest.java @@ -38,6 +38,7 @@ import java.util.concurrent.Executors; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; class WorkloadApiClientTest { @@ -163,7 +164,7 @@ class WorkloadApiClientTest { JwtSvid jwtSvid = workloadApiClient.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); assertNotNull(jwtSvid); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId()); - assertEquals("aud1", jwtSvid.getAudience().get(0)); + assertTrue(jwtSvid.getAudience().contains("aud1")); assertEquals(3, jwtSvid.getAudience().size()); } catch (JwtSvidException e) { fail(e); @@ -177,7 +178,7 @@ class WorkloadApiClientTest { JwtSvid jwtSvid = workloadApiClient.validateJwtSvid(token, "aud1"); assertNotNull(jwtSvid); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId()); - assertEquals("aud1", jwtSvid.getAudience().get(0)); + assertTrue(jwtSvid.getAudience().contains("aud1")); assertEquals(1, jwtSvid.getAudience().size()); } catch (JwtSvidException e) { fail(e); diff --git a/java-spiffe-core/src/test/resources/testdata/spiffeid/spiffeIds.txt b/java-spiffe-core/src/test/resources/testdata/spiffeid/spiffeIds.txt index f293245..bed223d 100644 --- a/java-spiffe-core/src/test/resources/testdata/spiffeid/spiffeIds.txt +++ b/java-spiffe-core/src/test/resources/testdata/spiffeid/spiffeIds.txt @@ -1,3 +1,4 @@ spiffe://example.org/workload1 spiffe://example.org/workload2 spiffe://example2.org/workload1 +spiffe://example.org/workload1 diff --git a/java-spiffe-provider/README.md b/java-spiffe-provider/README.md index c3dfce5..e5362ef 100644 --- a/java-spiffe-provider/README.md +++ b/java-spiffe-provider/README.md @@ -33,11 +33,11 @@ Alternatively, a different Workload API address can be used by passing it to the X509Source x509Source = X509Source.newSource(sourceOptions); - Supplier> spiffeIdListSupplier = () -> Collections.singletonList(SpiffeId.parse("spiffe://example.org/test")); + Supplier> spiffeIdSetSupplier = () -> Collections.singleton(SpiffeId.parse("spiffe://example.org/test")); SslContextOptions sslContextOptions = SslContextOptions .builder() - .acceptedSpiffeIdsSupplier(spiffeIdListSupplier) + .acceptedSpiffeIdsSupplier(spiffeIdSetSupplier) .x509Source(x509Source) .build(); diff --git a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeSslContextFactory.java b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeSslContextFactory.java index 76bb578..0a83343 100644 --- a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeSslContextFactory.java +++ b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeSslContextFactory.java @@ -11,7 +11,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; -import java.util.List; +import java.util.Set; import java.util.function.Supplier; /** @@ -72,14 +72,14 @@ public final class SpiffeSslContextFactory { public static class SslContextOptions { String sslProtocol; X509Source x509Source; - Supplier> acceptedSpiffeIdsSupplier; + Supplier> acceptedSpiffeIdsSupplier; boolean acceptAnySpiffeId; @Builder public SslContextOptions( final String sslProtocol, final X509Source x509Source, - final Supplier> acceptedSpiffeIdsSupplier, + final Supplier> acceptedSpiffeIdsSupplier, final boolean acceptAnySpiffeId) { this.x509Source = x509Source; this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier; diff --git a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManager.java b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManager.java index c7b42ce..3b6c420 100644 --- a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManager.java +++ b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManager.java @@ -1,9 +1,9 @@ package io.spiffe.provider; import io.spiffe.bundle.BundleSource; +import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.spiffeid.SpiffeId; -import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.svid.x509svid.X509SvidValidator; import lombok.NonNull; @@ -12,9 +12,9 @@ import javax.net.ssl.X509ExtendedTrustManager; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; +import java.util.Collections; +import java.util.Set; import java.util.function.Supplier; /** @@ -26,7 +26,7 @@ import java.util.function.Supplier; public final class SpiffeTrustManager extends X509ExtendedTrustManager { private final BundleSource x509BundleSource; - private final Supplier> acceptedSpiffeIdsSupplier; + private final Supplier> acceptedSpiffeIdsSupplier; private final boolean acceptAnySpiffeId; /** @@ -36,10 +36,10 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager { * and a {@link Supplier} of a List of accepted {@link SpiffeId} to be used during peer SVID validation. * * @param x509BundleSource an implementation of a {@link BundleSource} - * @param acceptedSpiffeIdsSupplier a {@link Supplier} of a list of accepted SPIFFE IDs. + * @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of accepted SPIFFE IDs. */ public SpiffeTrustManager(@NonNull final BundleSource x509BundleSource, - @NonNull final Supplier> acceptedSpiffeIdsSupplier) { + @NonNull final Supplier> acceptedSpiffeIdsSupplier) { this.x509BundleSource = x509BundleSource; this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier; this.acceptAnySpiffeId = false; @@ -57,7 +57,7 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager { public SpiffeTrustManager(@NonNull final BundleSource x509BundleSource, final boolean acceptAnySpiffeId) { this.x509BundleSource = x509BundleSource; - this.acceptedSpiffeIdsSupplier = ArrayList::new; + this.acceptedSpiffeIdsSupplier = Collections::emptySet; this.acceptAnySpiffeId = acceptAnySpiffeId; } diff --git a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManagerFactory.java b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManagerFactory.java index d265263..80e7e8c 100644 --- a/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManagerFactory.java +++ b/java-spiffe-provider/src/main/java/io/spiffe/provider/SpiffeTrustManagerFactory.java @@ -1,19 +1,19 @@ package io.spiffe.provider; import io.spiffe.bundle.BundleSource; +import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.exception.X509SourceException; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeIdUtils; import io.spiffe.workloadapi.X509Source; import lombok.NonNull; -import io.spiffe.bundle.x509bundle.X509Bundle; import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactorySpi; import java.security.KeyStore; -import java.util.List; +import java.util.Set; import java.util.function.Supplier; import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_ALL_PROPERTY; @@ -35,11 +35,11 @@ import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_PROPE public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi { private static final boolean ACCEPT_ANY_SPIFFE_ID; - private static final Supplier> DEFAULT_SPIFFE_ID_LIST_SUPPLIER; + private static final Supplier> DEFAULT_SPIFFE_ID_LIST_SUPPLIER; static { ACCEPT_ANY_SPIFFE_ID = Boolean.parseBoolean(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_ALL_PROPERTY, "false")); - DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toListOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY)); + DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toSetOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY)); } /** @@ -113,12 +113,12 @@ public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi { * and a supplier of accepted SPIFFE IDs. * * @param x509BundleSource a {@link BundleSource} to provide the X.509-Bundles - * @param acceptedSpiffeIdsSupplier a Supplier to provide a List of SPIFFE IDs that are accepted + * @param acceptedSpiffeIdsSupplier a Supplier to provide a Set of SPIFFE IDs that are accepted * @return an instance of a {@link TrustManager} wrapped in an array. The actual type returned is {@link SpiffeTrustManager} */ public TrustManager[] engineGetTrustManagers( @NonNull final BundleSource x509BundleSource, - @NonNull final Supplier> acceptedSpiffeIdsSupplier) { + @NonNull final Supplier> acceptedSpiffeIdsSupplier) { SpiffeTrustManager spiffeTrustManager = new SpiffeTrustManager(x509BundleSource, acceptedSpiffeIdsSupplier); return new TrustManager[]{spiffeTrustManager}; diff --git a/java-spiffe-provider/src/test/java/io/spiffe/provider/SpiffeTrustManagerTest.java b/java-spiffe-provider/src/test/java/io/spiffe/provider/SpiffeTrustManagerTest.java index e5fb87c..ec4497a 100644 --- a/java-spiffe-provider/src/test/java/io/spiffe/provider/SpiffeTrustManagerTest.java +++ b/java-spiffe-provider/src/test/java/io/spiffe/provider/SpiffeTrustManagerTest.java @@ -1,18 +1,18 @@ package io.spiffe.provider; import io.spiffe.bundle.BundleSource; +import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.X509SvidException; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.TrustDomain; +import io.spiffe.svid.x509svid.X509Svid; import lombok.val; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import io.spiffe.bundle.x509bundle.X509Bundle; -import io.spiffe.svid.x509svid.X509Svid; import javax.net.ssl.X509TrustManager; import java.io.IOException; @@ -22,7 +22,7 @@ import java.nio.file.Paths; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Collections; -import java.util.List; +import java.util.Set; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; @@ -36,7 +36,7 @@ public class SpiffeTrustManagerTest { static X509Bundle x509Bundle; static X509Svid x509Svid; static X509Svid otherX509Svid; - List acceptedSpiffeIds; + Set acceptedSpiffeIds; X509TrustManager trustManager; @BeforeAll @@ -59,19 +59,12 @@ public class SpiffeTrustManagerTest { void setupMocks() { MockitoAnnotations.initMocks(this); trustManager = (X509TrustManager) - new SpiffeTrustManagerFactory() - .engineGetTrustManagers( - bundleSource, - () -> acceptedSpiffeIds)[0]; + new SpiffeTrustManagerFactory().engineGetTrustManagers(bundleSource, () -> acceptedSpiffeIds)[0]; } @Test void checkClientTrusted_passAExpiredCertificate_throwsException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://example.org/test") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test")); val chain = x509Svid.getChainArray(); @@ -87,11 +80,7 @@ public class SpiffeTrustManagerTest { @Test void checkClientTrusted_noBundleForTrustDomain_ThrowCertificateException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://example.org/test") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test")); val chain = x509Svid.getChainArray(); @@ -107,11 +96,7 @@ public class SpiffeTrustManagerTest { @Test void checkClientTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://example.org/other") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other")); X509Certificate[] chain = x509Svid.getChainArray(); @@ -128,11 +113,7 @@ public class SpiffeTrustManagerTest { @Test void checkClientTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://other.org/test") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test")); val chain = otherX509Svid.getChainArray(); @@ -148,11 +129,7 @@ public class SpiffeTrustManagerTest { @Test void checkServerTrusted_passAnExpiredCertificate_ThrowsException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://example.org/test") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test")); val chain = x509Svid.getChainArray(); @@ -168,11 +145,7 @@ public class SpiffeTrustManagerTest { @Test void checkServerTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://example.org/other") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other")); val chain = x509Svid.getChainArray(); @@ -188,11 +161,7 @@ public class SpiffeTrustManagerTest { @Test void checkServerTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException { - acceptedSpiffeIds = - Collections - .singletonList( - SpiffeId.parse("spiffe://other.org/test") - ); + acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test")); val chain = otherX509Svid.getChainArray(); diff --git a/java-spiffe-provider/src/test/java/io/spiffe/provider/examples/mtls/HttpsClient.java b/java-spiffe-provider/src/test/java/io/spiffe/provider/examples/mtls/HttpsClient.java index 8452a1a..8a5b92d 100644 --- a/java-spiffe-provider/src/test/java/io/spiffe/provider/examples/mtls/HttpsClient.java +++ b/java-spiffe-provider/src/test/java/io/spiffe/provider/examples/mtls/HttpsClient.java @@ -3,13 +3,13 @@ package io.spiffe.provider.examples.mtls; import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.exception.X509SourceException; import io.spiffe.provider.SpiffeKeyManager; +import io.spiffe.provider.SpiffeSslContextFactory; +import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions; import io.spiffe.provider.SpiffeTrustManager; import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeIdUtils; import io.spiffe.workloadapi.X509Source; import lombok.val; -import io.spiffe.provider.SpiffeSslContextFactory; -import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; @@ -20,7 +20,7 @@ import java.net.URISyntaxException; import java.nio.file.Paths; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; -import java.util.List; +import java.util.Set; import java.util.function.Supplier; /** @@ -35,12 +35,12 @@ import java.util.function.Supplier; public class HttpsClient { String spiffeSocket; - Supplier> acceptedSpiffeIdsListSupplier; + Supplier> acceptedSpiffeIdsSetSupplier; int serverPort; public static void main(String[] args) { String spiffeSocket = "unix:/tmp/agent.sock"; - HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getList()); + HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getSet()); try { httpsClient.run(); } catch (KeyManagementException | NoSuchAlgorithmException | IOException | SocketEndpointAddressException | X509SourceException e) { @@ -48,10 +48,10 @@ public class HttpsClient { } } - HttpsClient(int serverPort, String spiffeSocket, Supplier> acceptedSpiffeIdsListSupplier) { + HttpsClient(int serverPort, String spiffeSocket, Supplier> acceptedSpiffeIdsSetSupplier) { this.serverPort = serverPort; this.spiffeSocket = spiffeSocket; - this.acceptedSpiffeIdsListSupplier = acceptedSpiffeIdsListSupplier; + this.acceptedSpiffeIdsSetSupplier = acceptedSpiffeIdsSetSupplier; } void run() throws IOException, SocketEndpointAddressException, KeyManagementException, NoSuchAlgorithmException, X509SourceException { @@ -64,7 +64,7 @@ public class HttpsClient { SslContextOptions sslContextOptions = SslContextOptions .builder() - .acceptedSpiffeIdsSupplier(acceptedSpiffeIdsListSupplier) + .acceptedSpiffeIdsSupplier(acceptedSpiffeIdsSetSupplier) .x509Source(x509Source) .build(); SSLContext sslContext = SpiffeSslContextFactory.getSslContext(sslContextOptions); @@ -76,9 +76,9 @@ public class HttpsClient { } private static class AcceptedSpiffeIds { - List getList() { + Set getSet() { try { - return SpiffeIdUtils.getSpiffeIdListFromFile(Paths.get(toUri("testdata/spiffeIds.txt"))); + return SpiffeIdUtils.getSpiffeIdSetFromFile(Paths.get(toUri("testdata/spiffeIds.txt"))); } catch (IOException | URISyntaxException e) { throw new RuntimeException("Error getting list of spiffeIds", e); }