Addressing PR comments:

- refactor acceptedSpiffeIds from List to Set
- refactor tests
- renaming methods to improve clarity
- amendments in javadocs

Signed-off-by: Max Lambrecht <maxlambrecht@gmail.com>
This commit is contained in:
Max Lambrecht 2020-06-23 11:26:00 -03:00
parent dbfb09f0f8
commit ca5511eb91
21 changed files with 228 additions and 226 deletions

View File

@ -7,7 +7,7 @@ import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -30,11 +30,11 @@ public class SpiffeIdUtils {
* @throws IOException if the given spiffeIdsFile cannot be read * @throws IOException if the given spiffeIdsFile cannot be read
* @throws IllegalArgumentException if any of the SPIFFE IDs in the file cannot be parsed * @throws IllegalArgumentException if any of the SPIFFE IDs in the file cannot be parsed
*/ */
public static List<SpiffeId> getSpiffeIdListFromFile(final Path spiffeIdsFile) throws IOException { public static Set<SpiffeId> getSpiffeIdSetFromFile(final Path spiffeIdsFile) throws IOException {
try (Stream<String> lines = Files.lines(spiffeIdsFile)) { try (Stream<String> lines = Files.lines(spiffeIdsFile)) {
return lines return lines
.map(SpiffeId::parse) .map(SpiffeId::parse)
.collect(Collectors.toList()); .collect(Collectors.toSet());
} }
} }
@ -47,15 +47,15 @@ public class SpiffeIdUtils {
* @return a list of {@link SpiffeId} instances. * @return a list of {@link SpiffeId} instances.
* @throws IllegalArgumentException is the string provided is blank * @throws IllegalArgumentException is the string provided is blank
*/ */
public static List<SpiffeId> toListOfSpiffeIds(final String spiffeIds, final char separator) { public static Set<SpiffeId> toSetOfSpiffeIds(final String spiffeIds, final char separator) {
if (isBlank(spiffeIds)) { if (isBlank(spiffeIds)) {
return Collections.emptyList(); return Collections.emptySet();
} }
val array = spiffeIds.split(String.valueOf(separator)); val array = spiffeIds.split(String.valueOf(separator));
return Arrays.stream(array) return Arrays.stream(array)
.map(SpiffeId::parse) .map(SpiffeId::parse)
.collect(Collectors.toList()); .collect(Collectors.toSet());
} }
/** /**
@ -65,8 +65,8 @@ public class SpiffeIdUtils {
* @return a list of {@link SpiffeId} instances * @return a list of {@link SpiffeId} instances
* @throws IllegalArgumentException is the string provided is blank * @throws IllegalArgumentException is the string provided is blank
*/ */
public static List<SpiffeId> toListOfSpiffeIds(final String spiffeIds) { public static Set<SpiffeId> toSetOfSpiffeIds(final String spiffeIds) {
return toListOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR); return toSetOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR);
} }
private SpiffeIdUtils() { private SpiffeIdUtils() {

View File

@ -24,8 +24,9 @@ import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.text.ParseException; import java.text.ParseException;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set;
/** /**
* Represents a SPIFFE JWT-SVID. * Represents a SPIFFE JWT-SVID.
@ -41,7 +42,7 @@ public class JwtSvid {
/** /**
* Audience is the intended recipients of JWT-SVID as present in the 'aud' claim. * Audience is the intended recipients of JWT-SVID as present in the 'aud' claim.
*/ */
List<String> audience; Set<String> audience;
/** /**
* Expiration time of JWT-SVID as present in 'exp' claim. * Expiration time of JWT-SVID as present in 'exp' claim.
@ -68,7 +69,7 @@ public class JwtSvid {
* @param token the token as string * @param token the token as string
*/ */
JwtSvid(@NonNull final SpiffeId spiffeId, JwtSvid(@NonNull final SpiffeId spiffeId,
@NonNull final List<String> audience, @NonNull final Set<String> audience,
@NonNull final Date expiry, @NonNull final Date expiry,
@NonNull final Map<String, Object> claims, @NonNull final Map<String, Object> claims,
@NonNull final String token) { @NonNull final String token) {
@ -99,24 +100,29 @@ public class JwtSvid {
*/ */
public static JwtSvid parseAndValidate(@NonNull final String token, public static JwtSvid parseAndValidate(@NonNull final String token,
@NonNull final BundleSource<JwtBundle> jwtBundleSource, @NonNull final BundleSource<JwtBundle> jwtBundleSource,
@NonNull final List<String> audience) @NonNull final Set<String> audience)
throws JwtSvidException, BundleNotFoundException, AuthorityNotFoundException { throws JwtSvidException, BundleNotFoundException, AuthorityNotFoundException {
if (StringUtils.isBlank(token)) { if (StringUtils.isBlank(token)) {
throw new IllegalArgumentException("Token cannot be blank"); throw new IllegalArgumentException("Token cannot be blank");
} }
final SignedJWT signedJwt;
final JWTClaimsSet claimsSet;
try { try {
val signedJwt = SignedJWT.parse(token); signedJwt = SignedJWT.parse(token);
val claimsSet = signedJwt.getJWTClaimsSet(); claimsSet = signedJwt.getJWTClaimsSet();
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
List<String> claimAudience = claimsSet.getAudience(); Set<String> claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience); validateAudience(claimAudience, audience);
val expirationTime = claimsSet.getExpirationTime(); val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime); validateExpiration(expirationTime);
val spiffeId = getSpiffeId(claimsSet); val spiffeId = getSpiffeIdOfSubject(claimsSet);
val jwtBundle = jwtBundleSource.getBundleForTrustDomain(spiffeId.getTrustDomain()); val jwtBundle = jwtBundleSource.getBundleForTrustDomain(spiffeId.getTrustDomain());
val keyId = getKeyId(signedJwt.getHeader()); val keyId = getKeyId(signedJwt.getHeader());
@ -126,9 +132,6 @@ public class JwtSvid {
verifySignature(signedJwt, jwtAuthority, algorithm, keyId); verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
} }
/** /**
@ -144,27 +147,29 @@ public class JwtSvid {
* the 'aud' has an audience that is not in the audience provided as parameter * the 'aud' has an audience that is not in the audience provided as parameter
* @throws IllegalArgumentException when the token cannot be parsed * @throws IllegalArgumentException when the token cannot be parsed
*/ */
public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final List<String> audience) throws JwtSvidException { public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final Set<String> audience) throws JwtSvidException {
if (StringUtils.isBlank(token)) { if (StringUtils.isBlank(token)) {
throw new IllegalArgumentException("Token cannot be blank"); throw new IllegalArgumentException("Token cannot be blank");
} }
final SignedJWT signedJwt;
final JWTClaimsSet claimsSet;
try { try {
val signedJwt = SignedJWT.parse(token); signedJwt = SignedJWT.parse(token);
val claimsSet = signedJwt.getJWTClaimsSet(); claimsSet = signedJwt.getJWTClaimsSet();
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
List<String> claimAudience = claimsSet.getAudience(); Set<String> claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience); validateAudience(claimAudience, audience);
val expirationTime = claimsSet.getExpirationTime(); val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime); validateExpiration(expirationTime);
val spiffeId = getSpiffeId(claimsSet); val spiffeId = getSpiffeIdOfSubject(claimsSet);
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token); return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
} }
/** /**
@ -173,7 +178,7 @@ public class JwtSvid {
* *
* @return the token as String * @return the token as String
*/ */
public String marshall() { public String marshal() {
return token; return token;
} }
@ -202,9 +207,10 @@ public class JwtSvid {
private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException { private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException {
JWSVerifier verifier; JWSVerifier verifier;
if (Algorithm.Family.EC.contains(Algorithm.parse(algorithm))) { final Algorithm alg = Algorithm.parse(algorithm);
if (Algorithm.Family.EC.contains(alg)) {
verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority); verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority);
} else if (Algorithm.Family.RSA.contains(Algorithm.parse(algorithm))) { } else if (Algorithm.Family.RSA.contains(alg)) {
verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority); verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority);
} else { } else {
throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm)); throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm));
@ -214,9 +220,12 @@ public class JwtSvid {
private static String getKeyId(final JWSHeader header) throws JwtSvidException { private static String getKeyId(final JWSHeader header) throws JwtSvidException {
val keyId = header.getKeyID(); val keyId = header.getKeyID();
if (StringUtils.isBlank(keyId)) { if (keyId == null) {
throw new JwtSvidException("Token header missing key id"); throw new JwtSvidException("Token header missing key id");
} }
if (StringUtils.isBlank(keyId)) {
throw new JwtSvidException("Token header key id contains an empty value");
}
return keyId; return keyId;
} }
@ -230,7 +239,7 @@ public class JwtSvid {
} }
} }
private static SpiffeId getSpiffeId(final JWTClaimsSet claimsSet) throws JwtSvidException { private static SpiffeId getSpiffeIdOfSubject(final JWTClaimsSet claimsSet) throws JwtSvidException {
val subject = claimsSet.getSubject(); val subject = claimsSet.getSubject();
if (StringUtils.isBlank(subject)) { if (StringUtils.isBlank(subject)) {
throw new JwtSvidException("Token missing subject claim"); throw new JwtSvidException("Token missing subject claim");
@ -244,7 +253,7 @@ public class JwtSvid {
} }
private static void validateAudience(final List<String> audClaim, final List<String> audience) throws JwtSvidException { private static void validateAudience(final Set<String> audClaim, final Set<String> audience) throws JwtSvidException {
for (String aud : audClaim) { for (String aud : audClaim) {
if (!audience.contains(aud)) { if (!audience.contains(aud)) {
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", audience, audClaim)); throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", audience, audClaim));

View File

@ -26,7 +26,7 @@ import java.util.List;
* Contains a SPIFFE ID, a private key and a chain of X.509 certificates. * Contains a SPIFFE ID, a private key and a chain of X.509 certificates.
*/ */
@Value @Value
public class X509Svid implements X509SvidSource { public class X509Svid {
SpiffeId spiffeId; SpiffeId spiffeId;
@ -48,8 +48,17 @@ public class X509Svid implements X509SvidSource {
this.privateKey = privateKey; this.privateKey = privateKey;
} }
/**
* @return the Leaf Certificate of the chain
*/
public X509Certificate getLeaf() {
return chain.get(0);
}
/** /**
* Loads the X.509 SVID from PEM encoded files on disk. * Loads the X.509 SVID from PEM encoded files on disk.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
* *
* @param certsFilePath path to X.509 certificate chain file * @param certsFilePath path to X.509 certificate chain file
* @param privateKeyFilePath path to private key file * @param privateKeyFilePath path to private key file
@ -76,6 +85,8 @@ public class X509Svid implements X509SvidSource {
/** /**
* Parses the X.509 SVID from PEM or DER blocks containing certificate chain and key * Parses the X.509 SVID from PEM or DER blocks containing certificate chain and key
* bytes. The key must be a PEM or DER block with PKCS#8. * bytes. The key must be a PEM or DER block with PKCS#8.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
* *
* @param certsBytes chain of certificates as a byte array * @param certsBytes chain of certificates as a byte array
* @param privateKeyBytes private key as byte array * @param privateKeyBytes private key as byte array
@ -89,6 +100,8 @@ public class X509Svid implements X509SvidSource {
/** /**
* Parses the X509-SVID from certificate and key bytes. The certificate must be ASN.1 DER (concatenated with * Parses the X509-SVID from certificate and key bytes. The certificate must be ASN.1 DER (concatenated with
* no intermediate padding if there are more than one certificate). The key must be a PKCS#8 ASN.1 DER. * no intermediate padding if there are more than one certificate). The key must be a PKCS#8 ASN.1 DER.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
* *
* @param certsBytes chain of certificates as a byte array * @param certsBytes chain of certificates as a byte array
* @param privateKeyBytes private key as byte array * @param privateKeyBytes private key as byte array
@ -106,14 +119,6 @@ public class X509Svid implements X509SvidSource {
return chain.toArray(new X509Certificate[0]); return chain.toArray(new X509Certificate[0]);
} }
/**
* @return this instance, implementing a X509Svid interface.
*/
@Override
public X509Svid getX509Svid() {
return this;
}
private static X509Svid createX509Svid(final byte[] certsBytes, final byte[] privateKeyBytes, KeyFileFormat keyFileFormat) throws X509SvidException { private static X509Svid createX509Svid(final byte[] certsBytes, final byte[] privateKeyBytes, KeyFileFormat keyFileFormat) throws X509SvidException {
List<X509Certificate> x509Certificates; List<X509Certificate> x509Certificates;
@ -142,14 +147,19 @@ public class X509Svid implements X509SvidSource {
validateLeafCertificate(x509Certificates.get(0)); validateLeafCertificate(x509Certificates.get(0));
if (x509Certificates.size() > 1) { if (x509Certificates.size() > 1) {
validateSigningCertificates(x509Certificates.subList(1, x509Certificates.size())); validateSigningCertificates(x509Certificates);
} }
return new X509Svid(spiffeId, x509Certificates, privateKey); return new X509Svid(spiffeId, x509Certificates, privateKey);
} }
private static void validateSigningCertificates(final List<X509Certificate> certificates) throws X509SvidException { private static void validateSigningCertificates(final List<X509Certificate> certificates) throws X509SvidException {
for (X509Certificate cert : certificates) { for (int i = 1; i < certificates.size(); i++) {
verifyCaCert(certificates.get(i));
}
}
private static void verifyCaCert(final X509Certificate cert) throws X509SvidException {
if (!CertificateUtils.isCA(cert)) { if (!CertificateUtils.isCA(cert)) {
throw new X509SvidException("Signing certificate must have CA flag set to true"); throw new X509SvidException("Signing certificate must have CA flag set to true");
} }
@ -157,16 +167,15 @@ public class X509Svid implements X509SvidSource {
throw new X509SvidException("Signing certificate must have 'keyCertSign' as key usage"); throw new X509SvidException("Signing certificate must have 'keyCertSign' as key usage");
} }
} }
}
private static void validateLeafCertificate(final X509Certificate leaf) throws X509SvidException { private static void validateLeafCertificate(final X509Certificate leaf) throws X509SvidException {
if (CertificateUtils.isCA(leaf)) { if (CertificateUtils.isCA(leaf)) {
throw new X509SvidException("Leaf certificate must not have CA flag set to true"); throw new X509SvidException("Leaf certificate must not have CA flag set to true");
} }
validateKeyUsage(leaf); validateKeyUsageOfLeafCertificate(leaf);
} }
private static void validateKeyUsage(final X509Certificate leaf) throws X509SvidException { private static void validateKeyUsageOfLeafCertificate(final X509Certificate leaf) throws X509SvidException {
if (!CertificateUtils.hasKeyUsageDigitalSignature(leaf)) { if (!CertificateUtils.hasKeyUsageDigitalSignature(leaf)) {
throw new X509SvidException("Leaf certificate must have 'digitalSignature' as key usage"); throw new X509SvidException("Leaf certificate must have 'digitalSignature' as key usage");
} }

View File

@ -1,18 +1,19 @@
package io.spiffe.svid.x509svid; package io.spiffe.svid.x509svid;
import io.spiffe.bundle.BundleSource; import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.internal.CertificateUtils; import io.spiffe.internal.CertificateUtils;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import io.spiffe.bundle.x509bundle.X509Bundle;
import java.security.cert.CertPathValidatorException; import java.security.cert.CertPathValidatorException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
/** /**
@ -48,17 +49,17 @@ public class X509SvidValidator {
* Checks that the X.509 SVID provided has a SPIFFE ID that is in the list of accepted SPIFFE IDs supplied. * Checks that the X.509 SVID provided has a SPIFFE ID that is in the list of accepted SPIFFE IDs supplied.
* *
* @param x509Certificate a {@link X509Svid} with a SPIFFE ID to be verified * @param x509Certificate a {@link X509Svid} with a SPIFFE ID to be verified
* @param acceptedSpiffedIdsSupplier a {@link Supplier} of a list os SPIFFE IDs that are accepted * @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of SPIFFE IDs that are accepted
* @throws CertificateException is the SPIFFE ID in x509Certificate is not in the list supplied by acceptedSpiffedIdsSupplier, * @throws CertificateException if the SPIFFE ID in x509Certificate is not in the Set supplied by acceptedSpiffeIdsSupplier,
* or if the SPIFFE ID cannot be parsed from the x509Certificate * or if the SPIFFE ID cannot be parsed from the x509Certificate
* @throws NullPointerException if the given x509Certificate or acceptedSpiffedIdsSupplier are null * @throws NullPointerException if the given x509Certificate or acceptedSpiffeIdsSupplier are null
*/ */
public static void verifySpiffeId(@NonNull final X509Certificate x509Certificate, public static void verifySpiffeId(@NonNull final X509Certificate x509Certificate,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffedIdsSupplier) @NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier)
throws CertificateException { throws CertificateException {
val spiffeIdList = acceptedSpiffedIdsSupplier.get(); val spiffeIdSet = acceptedSpiffeIdsSupplier.get();
val spiffeId = CertificateUtils.getSpiffeId(x509Certificate); val spiffeId = CertificateUtils.getSpiffeId(x509Certificate);
if (!spiffeIdList.contains(spiffeId)) { if (!spiffeIdSet.contains(spiffeId)) {
throw new CertificateException(String.format("SPIFFE ID %s in X.509 certificate is not accepted", spiffeId)); throw new CertificateException(String.format("SPIFFE ID %s in X.509 certificate is not accepted", spiffeId));
} }
} }

View File

@ -32,8 +32,10 @@ import java.security.KeyException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -192,7 +194,7 @@ public class WorkloadApiClient implements Closeable {
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API * @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/ */
public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException { public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
List<String> audParam = new ArrayList<>(); Set<String> audParam = new HashSet<>();
audParam.add(audience); audParam.add(audience);
Collections.addAll(audParam, extraAudience); Collections.addAll(audParam, extraAudience);
@ -238,7 +240,7 @@ public class WorkloadApiClient implements Closeable {
throw new JwtSvidException("Error validating JWT SVID", e); throw new JwtSvidException("Error validating JWT SVID", e);
} }
return JwtSvid.parseInsecure(token, Collections.singletonList(audience)); return JwtSvid.parseInsecure(token, Collections.singleton(audience));
} }
/** /**
@ -388,7 +390,7 @@ public class WorkloadApiClient implements Closeable {
throw new X509ContextException("Error processing X509Context: x509SVIDResponse is empty"); throw new X509ContextException("Error processing X509Context: x509SVIDResponse is empty");
} }
private JwtSvid callFetchJwtSvid(SpiffeId subject, List<String> audience) throws JwtSvidException { private JwtSvid callFetchJwtSvid(SpiffeId subject, Set<String> audience) throws JwtSvidException {
Workload.JWTSVIDRequest jwtsvidRequest = Workload.JWTSVIDRequest Workload.JWTSVIDRequest jwtsvidRequest = Workload.JWTSVIDRequest
.newBuilder() .newBuilder()
.setSpiffeId(subject.toString()) .setSpiffeId(subject.toString())

View File

@ -1,5 +1,6 @@
package io.spiffe.spiffeid; package io.spiffe.spiffeid;
import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
@ -8,36 +9,37 @@ import java.net.URISyntaxException;
import java.nio.file.NoSuchFileException; import java.nio.file.NoSuchFileException;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.List; import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
class SpiffeIdUtilsTest { class SpiffeIdUtilsTest {
@Test @Test
void getSpiffeIdListFromFile() throws URISyntaxException { void getSpiffeIdSetFromFile() throws URISyntaxException {
Path path = Paths.get(toUri("testdata/spiffeid/spiffeIds.txt")); Path path = Paths.get(toUri("testdata/spiffeid/spiffeIds.txt"));
try { try {
List<SpiffeId> spiffeIdList = SpiffeIdUtils.getSpiffeIdListFromFile(path); Set<SpiffeId> spiffeIdSet = SpiffeIdUtils.getSpiffeIdSetFromFile(path);
assertNotNull(spiffeIdList); assertNotNull(spiffeIdSet);
assertEquals(3, spiffeIdList.size()); assertEquals(3, spiffeIdSet.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0)); assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1")));
assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1)); assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2")));
assertEquals(SpiffeId.parse("spiffe://example2.org/workload1"), spiffeIdList.get(2)); assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example2.org/workload1")));
} catch (IOException e) { } catch (IOException e) {
fail(e); fail(e);
} }
} }
@Test @Test
void getSpiffeIdListFromNonExistenFile_throwsException() throws IOException { void getSpiffeIdSetFromNonExistenFile_throwsException() throws IOException {
Path path = Paths.get("testdata/spiffeid/non-existent-file"); Path path = Paths.get("testdata/spiffeid/non-existent-file");
try { try {
SpiffeIdUtils.getSpiffeIdListFromFile(path); SpiffeIdUtils.getSpiffeIdSetFromFile(path);
fail("should have thrown exception"); fail("should have thrown exception");
} catch (NoSuchFileException e) { } catch (NoSuchFileException e) {
assertEquals("testdata/spiffeid/non-existent-file", e.getMessage()); assertEquals("testdata/spiffeid/non-existent-file", e.getMessage());
@ -45,15 +47,14 @@ class SpiffeIdUtilsTest {
} }
@Test @Test
void toListOfSpiffeIds() { void toSetOfSpiffeIds() {
String spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 "; val spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 ";
val spiffeIdSet = SpiffeIdUtils.toSetOfSpiffeIds(spiffeIdsAsString, ',');
List<SpiffeId> spiffeIdList = SpiffeIdUtils.toListOfSpiffeIds(spiffeIdsAsString, ','); assertNotNull(spiffeIdSet);
assertEquals(2, spiffeIdSet.size());
assertNotNull(spiffeIdList); assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1")));
assertEquals(2, spiffeIdList.size()); assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2")));
assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0));
assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1));
} }
private URI toUri(String path) throws URISyntaxException { private URI toUri(String path) throws URISyntaxException {

View File

@ -2,24 +2,24 @@ package io.spiffe.svid.jwtsvid;
import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.AuthorityNotFoundException; import io.spiffe.exception.AuthorityNotFoundException;
import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSvidException; import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import lombok.Builder; import lombok.Builder;
import lombok.Value; import lombok.Value;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -43,7 +43,7 @@ class JwtSvidParseAndValidateTest {
assertEquals(testCase.expectedJwtSvid.getAudience(), jwtSvid.getAudience()); assertEquals(testCase.expectedJwtSvid.getAudience(), jwtSvid.getAudience());
assertEquals(testCase.expectedJwtSvid.getExpiry().toInstant().getEpochSecond(), jwtSvid.getExpiry().toInstant().getEpochSecond()); assertEquals(testCase.expectedJwtSvid.getExpiry().toInstant().getEpochSecond(), jwtSvid.getExpiry().toInstant().getEpochSecond());
assertEquals(token, jwtSvid.getToken()); assertEquals(token, jwtSvid.getToken());
assertEquals(token, jwtSvid.marshall()); assertEquals(token, jwtSvid.marshal());
} catch (Exception e) { } catch (Exception e) {
assertEquals(testCase.expectedException.getClass(), e.getClass()); assertEquals(testCase.expectedException.getClass(), e.getClass());
assertEquals(testCase.expectedException.getMessage(), e.getMessage()); assertEquals(testCase.expectedException.getMessage(), e.getMessage());
@ -54,7 +54,7 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_nullToken_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { void testParseAndValidate_nullToken_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain"); TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain); JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
try { try {
JwtSvid.parseAndValidate(null, jwtBundle, audience); JwtSvid.parseAndValidate(null, jwtBundle, audience);
@ -67,7 +67,7 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain"); TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain); JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
try { try {
JwtSvid.parseAndValidate("", jwtBundle, audience); JwtSvid.parseAndValidate("", jwtBundle, audience);
@ -78,7 +78,7 @@ class JwtSvidParseAndValidateTest {
@Test @Test
void testParseAndValidate_nullBundle_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { void testParseAndValidate_nullBundle_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
try { try {
JwtSvid.parseAndValidate("token", null, audience); JwtSvid.parseAndValidate("token", null, audience);
} catch (NullPointerException e) { } catch (NullPointerException e) {
@ -90,7 +90,6 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_nullAudience_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException { void testParseAndValidate_nullAudience_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain"); TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain); JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience");
try { try {
JwtSvid.parseAndValidate("token", jwtBundle, null); JwtSvid.parseAndValidate("token", jwtBundle, null);
@ -112,7 +111,7 @@ class JwtSvidParseAndValidateTest {
SpiffeId spiffeId = trustDomain.newSpiffeId("host"); SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date expiration = new Date(System.currentTimeMillis() + 3600000); Date expiration = new Date(System.currentTimeMillis() + 3600000);
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -179,7 +178,7 @@ class JwtSvidParseAndValidateTest {
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("8. unexpected audience") .name("8. unexpected audience")
.jwtBundle(jwtBundle) .jwtBundle(jwtBundle)
.expectedAudience(Collections.singletonList("another")) .expectedAudience(Collections.singleton("another"))
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])")) .expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])"))
.build()), .build()),
@ -194,32 +193,39 @@ class JwtSvidParseAndValidateTest {
.name("10. missing key id") .name("10. missing key id")
.jwtBundle(jwtBundle) .jwtBundle(jwtBundle)
.expectedAudience(audience) .expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, "")) .generateToken(() -> TestUtils.generateToken(claims, key1, null))
.expectedException(new JwtSvidException("Token header missing key id")) .expectedException(new JwtSvidException("Token header missing key id"))
.build()), .build()),
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("11. no bundle for trust domain") .name("11. key id contains an empty value")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, " "))
.expectedException(new JwtSvidException("Token header key id contains an empty value"))
.build()),
Arguments.of(TestCase.builder()
.name("12. no bundle for trust domain")
.jwtBundle(new JwtBundle(TrustDomain.of("other.domain"))) .jwtBundle(new JwtBundle(TrustDomain.of("other.domain")))
.expectedAudience(audience) .expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new BundleNotFoundException("No JWT bundle found for trust domain test.domain")) .expectedException(new BundleNotFoundException("No JWT bundle found for trust domain test.domain"))
.build()), .build()),
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("12. no authority found for key id") .name("13. no authority found for key id")
.jwtBundle(new JwtBundle(TrustDomain.of("test.domain"))) .jwtBundle(new JwtBundle(TrustDomain.of("test.domain")))
.expectedAudience(audience) .expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new AuthorityNotFoundException("No authority found for the trust domain test.domain and key id authority1")) .expectedException(new AuthorityNotFoundException("No authority found for the trust domain test.domain and key id authority1"))
.build()), .build()),
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("13. signature cannot be verified with authority") .name("14. signature cannot be verified with authority")
.jwtBundle(jwtBundle) .jwtBundle(jwtBundle)
.expectedAudience(audience) .expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key2, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key2, "authority1"))
.expectedException(new JwtSvidException("Signature invalid: cannot be verified with the authority with keyId=authority1")) .expectedException(new JwtSvidException("Signature invalid: cannot be verified with the authority with keyId=authority1"))
.build()), .build()),
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("14. authority algorithm mismatch") .name("15. authority algorithm mismatch")
.jwtBundle(jwtBundle) .jwtBundle(jwtBundle)
.expectedAudience(audience) .expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key3, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key3, "authority1"))
@ -232,13 +238,13 @@ class JwtSvidParseAndValidateTest {
static class TestCase { static class TestCase {
String name; String name;
JwtBundle jwtBundle; JwtBundle jwtBundle;
List<String> audience; Set<String> audience;
Supplier<String> generateToken; Supplier<String> generateToken;
Exception expectedException; Exception expectedException;
JwtSvid expectedJwtSvid; JwtSvid expectedJwtSvid;
@Builder @Builder
public TestCase(String name, JwtBundle jwtBundle, List<String> expectedAudience, Supplier<String> generateToken, public TestCase(String name, JwtBundle jwtBundle, Set<String> expectedAudience, Supplier<String> generateToken,
Exception expectedException, JwtSvid expectedJwtSvid) { Exception expectedException, JwtSvid expectedJwtSvid) {
this.name = name; this.name = name;
this.jwtBundle = jwtBundle; this.jwtBundle = jwtBundle;

View File

@ -2,22 +2,22 @@ package io.spiffe.svid.jwtsvid;
import com.nimbusds.jose.jwk.Curve; import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.JwtSvidException; import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import lombok.Builder; import lombok.Builder;
import lombok.Value; import lombok.Value;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -46,7 +46,7 @@ class JwtSvidParseInsecureTest {
@Test @Test
void testParseInsecure_nullToken_throwsNullPointerException() throws JwtSvidException { void testParseInsecure_nullToken_throwsNullPointerException() throws JwtSvidException {
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
try { try {
JwtSvid.parseInsecure(null, audience); JwtSvid.parseInsecure(null, audience);
@ -57,7 +57,7 @@ class JwtSvidParseInsecureTest {
@Test @Test
void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException { void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException {
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
try { try {
JwtSvid.parseInsecure("", audience); JwtSvid.parseInsecure("", audience);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
@ -71,7 +71,7 @@ class JwtSvidParseInsecureTest {
KeyPair key1 = TestUtils.generateECKeyPair(Curve.P_521); KeyPair key1 = TestUtils.generateECKeyPair(Curve.P_521);
TrustDomain trustDomain = TrustDomain.of("test.domain"); TrustDomain trustDomain = TrustDomain.of("test.domain");
SpiffeId spiffeId = trustDomain.newSpiffeId("host"); SpiffeId spiffeId = trustDomain.newSpiffeId("host");
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
Date expiration = new Date(System.currentTimeMillis() + 3600000); Date expiration = new Date(System.currentTimeMillis() + 3600000);
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -93,7 +93,7 @@ class JwtSvidParseInsecureTest {
SpiffeId spiffeId = trustDomain.newSpiffeId("host"); SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date expiration = new Date(System.currentTimeMillis() + 3600000); Date expiration = new Date(System.currentTimeMillis() + 3600000);
List<String> audience = Collections.singletonList("audience"); Set<String> audience = Collections.singleton("audience");
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration); JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -135,7 +135,7 @@ class JwtSvidParseInsecureTest {
.build()), .build()),
Arguments.of(TestCase.builder() Arguments.of(TestCase.builder()
.name("unexpected audience") .name("unexpected audience")
.expectedAudience(Collections.singletonList("another")) .expectedAudience(Collections.singleton("another"))
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1")) .generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])")) .expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])"))
.build()), .build()),
@ -151,13 +151,13 @@ class JwtSvidParseInsecureTest {
@Value @Value
static class TestCase { static class TestCase {
String name; String name;
List<String> audience; Set<String> audience;
Supplier<String> generateToken; Supplier<String> generateToken;
Exception expectedException; Exception expectedException;
JwtSvid expectedJwtSvid; JwtSvid expectedJwtSvid;
@Builder @Builder
public TestCase(String name, List<String> expectedAudience, Supplier<String> generateToken, public TestCase(String name, Set<String> expectedAudience, Supplier<String> generateToken,
Exception expectedException, JwtSvid expectedJwtSvid) { Exception expectedException, JwtSvid expectedJwtSvid) {
this.name = name; this.name = name;
this.audience = expectedAudience; this.audience = expectedAudience;

View File

@ -264,11 +264,11 @@ public class X509SvidTest {
} }
@Test @Test
void testGetX509Svid() throws URISyntaxException, X509SvidException { void testGetLeaf() throws URISyntaxException, X509SvidException {
Path certPath = Paths.get(toUri(certSingle)); Path certPath = Paths.get(toUri(certSingle));
Path keyPath = Paths.get(toUri(keyRSA)); Path keyPath = Paths.get(toUri(keyRSA));
X509Svid x509Svid = X509Svid.load(certPath, keyPath); X509Svid x509Svid = X509Svid.load(certPath, keyPath);
assertEquals(x509Svid, x509Svid.getX509Svid()); assertEquals(x509Svid.getChain().get(0), x509Svid.getLeaf());
} }
@Test @Test

View File

@ -1,27 +1,28 @@
package io.spiffe.svid.x509svid; package io.spiffe.svid.x509svid;
import io.spiffe.exception.BundleNotFoundException; import com.google.common.collect.Sets;
import lombok.val;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.spiffe.bundle.x509bundle.X509Bundle; import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain; import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.X509CertificateTestUtils.CertAndKeyPair; import io.spiffe.utils.X509CertificateTestUtils.CertAndKeyPair;
import lombok.val;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import static java.util.Collections.EMPTY_LIST;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static io.spiffe.utils.X509CertificateTestUtils.createCertificate; import static io.spiffe.utils.X509CertificateTestUtils.createCertificate;
import static io.spiffe.utils.X509CertificateTestUtils.createRootCA; import static io.spiffe.utils.X509CertificateTestUtils.createRootCA;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class X509SvidValidatorTest { public class X509SvidValidatorTest {
@ -84,19 +85,19 @@ public class X509SvidValidatorTest {
val spiffeId1 = SpiffeId.parse("spiffe://example.org/test"); val spiffeId1 = SpiffeId.parse("spiffe://example.org/test");
val spiffeId2 = SpiffeId.parse("spiffe://example.org/test2"); val spiffeId2 = SpiffeId.parse("spiffe://example.org/test2");
val spiffeIdList = Arrays.asList(spiffeId1, spiffeId2); val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2);
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList); X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet);
} }
@Test @Test
void checkSpiffeId_givenASpiffeIdNotInTheListOfAcceptedIds_throwsCertificateException() throws IOException, CertificateException, URISyntaxException { void checkSpiffeId_givenASpiffeIdNotInTheListOfAcceptedIds_throwsCertificateException() throws IOException, CertificateException, URISyntaxException {
val spiffeId1 = SpiffeId.parse("spiffe://example.org/other1"); val spiffeId1 = SpiffeId.parse("spiffe://example.org/other1");
val spiffeId2 = SpiffeId.parse("spiffe://example.org/other2"); val spiffeId2 = SpiffeId.parse("spiffe://example.org/other2");
List<SpiffeId> spiffeIdList = Arrays.asList(spiffeId1, spiffeId2); val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2);
try { try {
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList); X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet);
fail("Should have thrown CertificateException"); fail("Should have thrown CertificateException");
} catch (CertificateException e) { } catch (CertificateException e) {
assertEquals("SPIFFE ID spiffe://example.org/test in X.509 certificate is not accepted", e.getMessage()); assertEquals("SPIFFE ID spiffe://example.org/test in X.509 certificate is not accepted", e.getMessage());
@ -106,7 +107,7 @@ public class X509SvidValidatorTest {
@Test @Test
void checkSpiffeId_nullX509Certificate_throwsNullPointerException() throws CertificateException { void checkSpiffeId_nullX509Certificate_throwsNullPointerException() throws CertificateException {
try { try {
X509SvidValidator.verifySpiffeId(null, () -> EMPTY_LIST); X509SvidValidator.verifySpiffeId(null, Collections::emptySet);
fail("should have thrown an exception"); fail("should have thrown an exception");
} catch (NullPointerException e) { } catch (NullPointerException e) {
assertEquals("x509Certificate is marked non-null but is null", e.getMessage()); assertEquals("x509Certificate is marked non-null but is null", e.getMessage());
@ -119,7 +120,7 @@ public class X509SvidValidatorTest {
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), null); X509SvidValidator.verifySpiffeId(leaf.getCertificate(), null);
fail("should have thrown an exception"); fail("should have thrown an exception");
} catch (NullPointerException e) { } catch (NullPointerException e) {
assertEquals("acceptedSpiffedIdsSupplier is marked non-null but is null", e.getMessage()); assertEquals("acceptedSpiffeIdsSupplier is marked non-null but is null", e.getMessage());
} }
} }

View File

@ -17,9 +17,11 @@ import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.security.spec.ECGenParameterSpec; import java.security.spec.ECGenParameterSpec;
import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.joining;
import static java.util.stream.Stream.generate; import static java.util.stream.Stream.generate;
@ -87,11 +89,11 @@ public class TestUtils {
} }
} }
public static JWTClaimsSet buildJWTClaimSet(List<String> audience, String spiffeId, Date expiration) { public static JWTClaimsSet buildJWTClaimSet(Set<String> audience, String spiffeId, Date expiration) {
return new JWTClaimsSet.Builder() return new JWTClaimsSet.Builder()
.subject(spiffeId) .subject(spiffeId)
.expirationTime(expiration) .expirationTime(expiration)
.audience(audience) .audience(new ArrayList<>(audience))
.build(); .build();
} }

View File

@ -129,7 +129,7 @@ class FakeWorkloadApi extends SpiffeWorkloadAPIImplBase {
JwtSvid jwtSvid = null; JwtSvid jwtSvid = null;
try { try {
jwtSvid = JwtSvid.parseInsecure(token, Collections.singletonList(audience)); jwtSvid = JwtSvid.parseInsecure(token, Collections.singleton(audience));
} catch (JwtSvidException e) { } catch (JwtSvidException e) {
responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withDescription(e.getMessage()))); responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withDescription(e.getMessage())));
} }

View File

@ -1,28 +1,28 @@
package io.spiffe.workloadapi; package io.spiffe.workloadapi;
import com.google.common.collect.Sets;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Server; import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException; import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException; import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain; import io.spiffe.spiffeid.TrustDomain;
import org.junit.Rule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.svid.jwtsvid.JwtSvid; import io.spiffe.svid.jwtsvid.JwtSvid;
import io.spiffe.workloadapi.grpc.SpiffeWorkloadAPIGrpc; import io.spiffe.workloadapi.grpc.SpiffeWorkloadAPIGrpc;
import io.spiffe.workloadapi.internal.ManagedChannelWrapper; import io.spiffe.workloadapi.internal.ManagedChannelWrapper;
import io.spiffe.workloadapi.internal.SecurityHeaderInterceptor; import io.spiffe.workloadapi.internal.SecurityHeaderInterceptor;
import org.junit.Rule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -98,7 +98,7 @@ class JwtSourceTest {
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid); assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId()); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Arrays.asList("aud1", "aud2", "aud3"), svid.getAudience()); assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
} catch (JwtSvidException e) { } catch (JwtSvidException e) {
fail(e); fail(e);
} }

View File

@ -38,6 +38,7 @@ import java.util.concurrent.Executors;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
class WorkloadApiClientTest { class WorkloadApiClientTest {
@ -163,7 +164,7 @@ class WorkloadApiClientTest {
JwtSvid jwtSvid = workloadApiClient.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3"); JwtSvid jwtSvid = workloadApiClient.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(jwtSvid); assertNotNull(jwtSvid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId()); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId());
assertEquals("aud1", jwtSvid.getAudience().get(0)); assertTrue(jwtSvid.getAudience().contains("aud1"));
assertEquals(3, jwtSvid.getAudience().size()); assertEquals(3, jwtSvid.getAudience().size());
} catch (JwtSvidException e) { } catch (JwtSvidException e) {
fail(e); fail(e);
@ -177,7 +178,7 @@ class WorkloadApiClientTest {
JwtSvid jwtSvid = workloadApiClient.validateJwtSvid(token, "aud1"); JwtSvid jwtSvid = workloadApiClient.validateJwtSvid(token, "aud1");
assertNotNull(jwtSvid); assertNotNull(jwtSvid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId()); assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId());
assertEquals("aud1", jwtSvid.getAudience().get(0)); assertTrue(jwtSvid.getAudience().contains("aud1"));
assertEquals(1, jwtSvid.getAudience().size()); assertEquals(1, jwtSvid.getAudience().size());
} catch (JwtSvidException e) { } catch (JwtSvidException e) {
fail(e); fail(e);

View File

@ -1,3 +1,4 @@
spiffe://example.org/workload1 spiffe://example.org/workload1
spiffe://example.org/workload2 spiffe://example.org/workload2
spiffe://example2.org/workload1 spiffe://example2.org/workload1
spiffe://example.org/workload1

View File

@ -33,11 +33,11 @@ Alternatively, a different Workload API address can be used by passing it to the
X509Source x509Source = X509Source.newSource(sourceOptions); X509Source x509Source = X509Source.newSource(sourceOptions);
Supplier<List<SpiffeId>> spiffeIdListSupplier = () -> Collections.singletonList(SpiffeId.parse("spiffe://example.org/test")); Supplier<Set<SpiffeId>> spiffeIdSetSupplier = () -> Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
SslContextOptions sslContextOptions = SslContextOptions SslContextOptions sslContextOptions = SslContextOptions
.builder() .builder()
.acceptedSpiffeIdsSupplier(spiffeIdListSupplier) .acceptedSpiffeIdsSupplier(spiffeIdSetSupplier)
.x509Source(x509Source) .x509Source(x509Source)
.build(); .build();

View File

@ -11,7 +11,7 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
import java.security.KeyManagementException; import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.List; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
/** /**
@ -72,14 +72,14 @@ public final class SpiffeSslContextFactory {
public static class SslContextOptions { public static class SslContextOptions {
String sslProtocol; String sslProtocol;
X509Source x509Source; X509Source x509Source;
Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier; Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier;
boolean acceptAnySpiffeId; boolean acceptAnySpiffeId;
@Builder @Builder
public SslContextOptions( public SslContextOptions(
final String sslProtocol, final String sslProtocol,
final X509Source x509Source, final X509Source x509Source,
final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier, final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier,
final boolean acceptAnySpiffeId) { final boolean acceptAnySpiffeId) {
this.x509Source = x509Source; this.x509Source = x509Source;
this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier; this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier;

View File

@ -1,9 +1,9 @@
package io.spiffe.provider; package io.spiffe.provider;
import io.spiffe.bundle.BundleSource; import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.svid.x509svid.X509SvidValidator; import io.spiffe.svid.x509svid.X509SvidValidator;
import lombok.NonNull; import lombok.NonNull;
@ -12,9 +12,9 @@ import javax.net.ssl.X509ExtendedTrustManager;
import java.net.Socket; import java.net.Socket;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.Collections;
import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
/** /**
@ -26,7 +26,7 @@ import java.util.function.Supplier;
public final class SpiffeTrustManager extends X509ExtendedTrustManager { public final class SpiffeTrustManager extends X509ExtendedTrustManager {
private final BundleSource<X509Bundle> x509BundleSource; private final BundleSource<X509Bundle> x509BundleSource;
private final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier; private final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier;
private final boolean acceptAnySpiffeId; private final boolean acceptAnySpiffeId;
/** /**
@ -36,10 +36,10 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager {
* and a {@link Supplier} of a List of accepted {@link SpiffeId} to be used during peer SVID validation. * and a {@link Supplier} of a List of accepted {@link SpiffeId} to be used during peer SVID validation.
* *
* @param x509BundleSource an implementation of a {@link BundleSource} * @param x509BundleSource an implementation of a {@link BundleSource}
* @param acceptedSpiffeIdsSupplier a {@link Supplier} of a list of accepted SPIFFE IDs. * @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of accepted SPIFFE IDs.
*/ */
public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource, public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier) { @NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier) {
this.x509BundleSource = x509BundleSource; this.x509BundleSource = x509BundleSource;
this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier; this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier;
this.acceptAnySpiffeId = false; this.acceptAnySpiffeId = false;
@ -57,7 +57,7 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager {
public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource, public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource,
final boolean acceptAnySpiffeId) { final boolean acceptAnySpiffeId) {
this.x509BundleSource = x509BundleSource; this.x509BundleSource = x509BundleSource;
this.acceptedSpiffeIdsSupplier = ArrayList::new; this.acceptedSpiffeIdsSupplier = Collections::emptySet;
this.acceptAnySpiffeId = acceptAnySpiffeId; this.acceptAnySpiffeId = acceptAnySpiffeId;
} }

View File

@ -1,19 +1,19 @@
package io.spiffe.provider; package io.spiffe.provider;
import io.spiffe.bundle.BundleSource; import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509SourceException; import io.spiffe.exception.X509SourceException;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.SpiffeIdUtils; import io.spiffe.spiffeid.SpiffeIdUtils;
import io.spiffe.workloadapi.X509Source; import io.spiffe.workloadapi.X509Source;
import lombok.NonNull; import lombok.NonNull;
import io.spiffe.bundle.x509bundle.X509Bundle;
import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactorySpi; import javax.net.ssl.TrustManagerFactorySpi;
import java.security.KeyStore; import java.security.KeyStore;
import java.util.List; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_ALL_PROPERTY; import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_ALL_PROPERTY;
@ -35,11 +35,11 @@ import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_PROPE
public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi { public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi {
private static final boolean ACCEPT_ANY_SPIFFE_ID; private static final boolean ACCEPT_ANY_SPIFFE_ID;
private static final Supplier<List<SpiffeId>> DEFAULT_SPIFFE_ID_LIST_SUPPLIER; private static final Supplier<Set<SpiffeId>> DEFAULT_SPIFFE_ID_LIST_SUPPLIER;
static { static {
ACCEPT_ANY_SPIFFE_ID = Boolean.parseBoolean(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_ALL_PROPERTY, "false")); ACCEPT_ANY_SPIFFE_ID = Boolean.parseBoolean(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_ALL_PROPERTY, "false"));
DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toListOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY)); DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toSetOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY));
} }
/** /**
@ -113,12 +113,12 @@ public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi {
* and a supplier of accepted SPIFFE IDs. * and a supplier of accepted SPIFFE IDs.
* *
* @param x509BundleSource a {@link BundleSource} to provide the X.509-Bundles * @param x509BundleSource a {@link BundleSource} to provide the X.509-Bundles
* @param acceptedSpiffeIdsSupplier a Supplier to provide a List of SPIFFE IDs that are accepted * @param acceptedSpiffeIdsSupplier a Supplier to provide a Set of SPIFFE IDs that are accepted
* @return an instance of a {@link TrustManager} wrapped in an array. The actual type returned is {@link SpiffeTrustManager} * @return an instance of a {@link TrustManager} wrapped in an array. The actual type returned is {@link SpiffeTrustManager}
*/ */
public TrustManager[] engineGetTrustManagers( public TrustManager[] engineGetTrustManagers(
@NonNull final BundleSource<X509Bundle> x509BundleSource, @NonNull final BundleSource<X509Bundle> x509BundleSource,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier) { @NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier) {
SpiffeTrustManager spiffeTrustManager = new SpiffeTrustManager(x509BundleSource, acceptedSpiffeIdsSupplier); SpiffeTrustManager spiffeTrustManager = new SpiffeTrustManager(x509BundleSource, acceptedSpiffeIdsSupplier);
return new TrustManager[]{spiffeTrustManager}; return new TrustManager[]{spiffeTrustManager};

View File

@ -1,18 +1,18 @@
package io.spiffe.provider; package io.spiffe.provider;
import io.spiffe.bundle.BundleSource; import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.X509SvidException; import io.spiffe.exception.X509SvidException;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain; import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.x509svid.X509Svid;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.svid.x509svid.X509Svid;
import javax.net.ssl.X509TrustManager; import javax.net.ssl.X509TrustManager;
import java.io.IOException; import java.io.IOException;
@ -22,7 +22,7 @@ import java.nio.file.Paths;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@ -36,7 +36,7 @@ public class SpiffeTrustManagerTest {
static X509Bundle x509Bundle; static X509Bundle x509Bundle;
static X509Svid x509Svid; static X509Svid x509Svid;
static X509Svid otherX509Svid; static X509Svid otherX509Svid;
List<SpiffeId> acceptedSpiffeIds; Set<SpiffeId> acceptedSpiffeIds;
X509TrustManager trustManager; X509TrustManager trustManager;
@BeforeAll @BeforeAll
@ -59,19 +59,12 @@ public class SpiffeTrustManagerTest {
void setupMocks() { void setupMocks() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
trustManager = (X509TrustManager) trustManager = (X509TrustManager)
new SpiffeTrustManagerFactory() new SpiffeTrustManagerFactory().engineGetTrustManagers(bundleSource, () -> acceptedSpiffeIds)[0];
.engineGetTrustManagers(
bundleSource,
() -> acceptedSpiffeIds)[0];
} }
@Test @Test
void checkClientTrusted_passAExpiredCertificate_throwsException() throws BundleNotFoundException { void checkClientTrusted_passAExpiredCertificate_throwsException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
val chain = x509Svid.getChainArray(); val chain = x509Svid.getChainArray();
@ -87,11 +80,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkClientTrusted_noBundleForTrustDomain_ThrowCertificateException() throws BundleNotFoundException { void checkClientTrusted_noBundleForTrustDomain_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
val chain = x509Svid.getChainArray(); val chain = x509Svid.getChainArray();
@ -107,11 +96,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkClientTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException { void checkClientTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other"));
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/other")
);
X509Certificate[] chain = x509Svid.getChainArray(); X509Certificate[] chain = x509Svid.getChainArray();
@ -128,11 +113,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkClientTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException { void checkClientTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test"));
Collections
.singletonList(
SpiffeId.parse("spiffe://other.org/test")
);
val chain = otherX509Svid.getChainArray(); val chain = otherX509Svid.getChainArray();
@ -148,11 +129,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkServerTrusted_passAnExpiredCertificate_ThrowsException() throws BundleNotFoundException { void checkServerTrusted_passAnExpiredCertificate_ThrowsException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
val chain = x509Svid.getChainArray(); val chain = x509Svid.getChainArray();
@ -168,11 +145,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkServerTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException { void checkServerTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other"));
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/other")
);
val chain = x509Svid.getChainArray(); val chain = x509Svid.getChainArray();
@ -188,11 +161,7 @@ public class SpiffeTrustManagerTest {
@Test @Test
void checkServerTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException { void checkServerTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds = acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test"));
Collections
.singletonList(
SpiffeId.parse("spiffe://other.org/test")
);
val chain = otherX509Svid.getChainArray(); val chain = otherX509Svid.getChainArray();

View File

@ -3,13 +3,13 @@ package io.spiffe.provider.examples.mtls;
import io.spiffe.exception.SocketEndpointAddressException; import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509SourceException; import io.spiffe.exception.X509SourceException;
import io.spiffe.provider.SpiffeKeyManager; import io.spiffe.provider.SpiffeKeyManager;
import io.spiffe.provider.SpiffeSslContextFactory;
import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions;
import io.spiffe.provider.SpiffeTrustManager; import io.spiffe.provider.SpiffeTrustManager;
import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.SpiffeIdUtils; import io.spiffe.spiffeid.SpiffeIdUtils;
import io.spiffe.workloadapi.X509Source; import io.spiffe.workloadapi.X509Source;
import lombok.val; import lombok.val;
import io.spiffe.provider.SpiffeSslContextFactory;
import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocket;
@ -20,7 +20,7 @@ import java.net.URISyntaxException;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.security.KeyManagementException; import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.List; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
/** /**
@ -35,12 +35,12 @@ import java.util.function.Supplier;
public class HttpsClient { public class HttpsClient {
String spiffeSocket; String spiffeSocket;
Supplier<List<SpiffeId>> acceptedSpiffeIdsListSupplier; Supplier<Set<SpiffeId>> acceptedSpiffeIdsSetSupplier;
int serverPort; int serverPort;
public static void main(String[] args) { public static void main(String[] args) {
String spiffeSocket = "unix:/tmp/agent.sock"; String spiffeSocket = "unix:/tmp/agent.sock";
HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getList()); HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getSet());
try { try {
httpsClient.run(); httpsClient.run();
} catch (KeyManagementException | NoSuchAlgorithmException | IOException | SocketEndpointAddressException | X509SourceException e) { } catch (KeyManagementException | NoSuchAlgorithmException | IOException | SocketEndpointAddressException | X509SourceException e) {
@ -48,10 +48,10 @@ public class HttpsClient {
} }
} }
HttpsClient(int serverPort, String spiffeSocket, Supplier<List<SpiffeId>> acceptedSpiffeIdsListSupplier) { HttpsClient(int serverPort, String spiffeSocket, Supplier<Set<SpiffeId>> acceptedSpiffeIdsSetSupplier) {
this.serverPort = serverPort; this.serverPort = serverPort;
this.spiffeSocket = spiffeSocket; this.spiffeSocket = spiffeSocket;
this.acceptedSpiffeIdsListSupplier = acceptedSpiffeIdsListSupplier; this.acceptedSpiffeIdsSetSupplier = acceptedSpiffeIdsSetSupplier;
} }
void run() throws IOException, SocketEndpointAddressException, KeyManagementException, NoSuchAlgorithmException, X509SourceException { void run() throws IOException, SocketEndpointAddressException, KeyManagementException, NoSuchAlgorithmException, X509SourceException {
@ -64,7 +64,7 @@ public class HttpsClient {
SslContextOptions sslContextOptions = SslContextOptions SslContextOptions sslContextOptions = SslContextOptions
.builder() .builder()
.acceptedSpiffeIdsSupplier(acceptedSpiffeIdsListSupplier) .acceptedSpiffeIdsSupplier(acceptedSpiffeIdsSetSupplier)
.x509Source(x509Source) .x509Source(x509Source)
.build(); .build();
SSLContext sslContext = SpiffeSslContextFactory.getSslContext(sslContextOptions); SSLContext sslContext = SpiffeSslContextFactory.getSslContext(sslContextOptions);
@ -76,9 +76,9 @@ public class HttpsClient {
} }
private static class AcceptedSpiffeIds { private static class AcceptedSpiffeIds {
List<SpiffeId> getList() { Set<SpiffeId> getSet() {
try { try {
return SpiffeIdUtils.getSpiffeIdListFromFile(Paths.get(toUri("testdata/spiffeIds.txt"))); return SpiffeIdUtils.getSpiffeIdSetFromFile(Paths.get(toUri("testdata/spiffeIds.txt")));
} catch (IOException | URISyntaxException e) { } catch (IOException | URISyntaxException e) {
throw new RuntimeException("Error getting list of spiffeIds", e); throw new RuntimeException("Error getting list of spiffeIds", e);
} }