Addressing PR comments:

- refactor acceptedSpiffeIds from List to Set
- refactor tests
- renaming methods to improve clarity
- amendments in javadocs

Signed-off-by: Max Lambrecht <maxlambrecht@gmail.com>
This commit is contained in:
Max Lambrecht 2020-06-23 11:26:00 -03:00
parent dbfb09f0f8
commit ca5511eb91
21 changed files with 228 additions and 226 deletions

View File

@ -7,7 +7,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -30,11 +30,11 @@ public class SpiffeIdUtils {
* @throws IOException if the given spiffeIdsFile cannot be read
* @throws IllegalArgumentException if any of the SPIFFE IDs in the file cannot be parsed
*/
public static List<SpiffeId> getSpiffeIdListFromFile(final Path spiffeIdsFile) throws IOException {
public static Set<SpiffeId> getSpiffeIdSetFromFile(final Path spiffeIdsFile) throws IOException {
try (Stream<String> lines = Files.lines(spiffeIdsFile)) {
return lines
.map(SpiffeId::parse)
.collect(Collectors.toList());
.collect(Collectors.toSet());
}
}
@ -47,15 +47,15 @@ public class SpiffeIdUtils {
* @return a list of {@link SpiffeId} instances.
* @throws IllegalArgumentException is the string provided is blank
*/
public static List<SpiffeId> toListOfSpiffeIds(final String spiffeIds, final char separator) {
public static Set<SpiffeId> toSetOfSpiffeIds(final String spiffeIds, final char separator) {
if (isBlank(spiffeIds)) {
return Collections.emptyList();
return Collections.emptySet();
}
val array = spiffeIds.split(String.valueOf(separator));
return Arrays.stream(array)
.map(SpiffeId::parse)
.collect(Collectors.toList());
.collect(Collectors.toSet());
}
/**
@ -65,8 +65,8 @@ public class SpiffeIdUtils {
* @return a list of {@link SpiffeId} instances
* @throws IllegalArgumentException is the string provided is blank
*/
public static List<SpiffeId> toListOfSpiffeIds(final String spiffeIds) {
return toListOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR);
public static Set<SpiffeId> toSetOfSpiffeIds(final String spiffeIds) {
return toSetOfSpiffeIds(spiffeIds, DEFAULT_CHAR_SEPARATOR);
}
private SpiffeIdUtils() {

View File

@ -24,8 +24,9 @@ import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Date;
import java.util.List;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* Represents a SPIFFE JWT-SVID.
@ -41,7 +42,7 @@ public class JwtSvid {
/**
* Audience is the intended recipients of JWT-SVID as present in the 'aud' claim.
*/
List<String> audience;
Set<String> audience;
/**
* Expiration time of JWT-SVID as present in 'exp' claim.
@ -68,7 +69,7 @@ public class JwtSvid {
* @param token the token as string
*/
JwtSvid(@NonNull final SpiffeId spiffeId,
@NonNull final List<String> audience,
@NonNull final Set<String> audience,
@NonNull final Date expiry,
@NonNull final Map<String, Object> claims,
@NonNull final String token) {
@ -99,24 +100,29 @@ public class JwtSvid {
*/
public static JwtSvid parseAndValidate(@NonNull final String token,
@NonNull final BundleSource<JwtBundle> jwtBundleSource,
@NonNull final List<String> audience)
@NonNull final Set<String> audience)
throws JwtSvidException, BundleNotFoundException, AuthorityNotFoundException {
if (StringUtils.isBlank(token)) {
throw new IllegalArgumentException("Token cannot be blank");
}
final SignedJWT signedJwt;
final JWTClaimsSet claimsSet;
try {
val signedJwt = SignedJWT.parse(token);
val claimsSet = signedJwt.getJWTClaimsSet();
signedJwt = SignedJWT.parse(token);
claimsSet = signedJwt.getJWTClaimsSet();
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
List<String> claimAudience = claimsSet.getAudience();
Set<String> claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience);
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
val spiffeId = getSpiffeId(claimsSet);
val spiffeId = getSpiffeIdOfSubject(claimsSet);
val jwtBundle = jwtBundleSource.getBundleForTrustDomain(spiffeId.getTrustDomain());
val keyId = getKeyId(signedJwt.getHeader());
@ -126,9 +132,6 @@ public class JwtSvid {
verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
}
/**
@ -144,27 +147,29 @@ public class JwtSvid {
* the 'aud' has an audience that is not in the audience provided as parameter
* @throws IllegalArgumentException when the token cannot be parsed
*/
public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final List<String> audience) throws JwtSvidException {
public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final Set<String> audience) throws JwtSvidException {
if (StringUtils.isBlank(token)) {
throw new IllegalArgumentException("Token cannot be blank");
}
final SignedJWT signedJwt;
final JWTClaimsSet claimsSet;
try {
val signedJwt = SignedJWT.parse(token);
val claimsSet = signedJwt.getJWTClaimsSet();
signedJwt = SignedJWT.parse(token);
claimsSet = signedJwt.getJWTClaimsSet();
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
List<String> claimAudience = claimsSet.getAudience();
Set<String> claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience);
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
val spiffeId = getSpiffeId(claimsSet);
val spiffeId = getSpiffeIdOfSubject(claimsSet);
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
} catch (ParseException e) {
throw new IllegalArgumentException("Unable to parse JWT token", e);
}
}
/**
@ -173,7 +178,7 @@ public class JwtSvid {
*
* @return the token as String
*/
public String marshall() {
public String marshal() {
return token;
}
@ -202,9 +207,10 @@ public class JwtSvid {
private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException {
JWSVerifier verifier;
if (Algorithm.Family.EC.contains(Algorithm.parse(algorithm))) {
final Algorithm alg = Algorithm.parse(algorithm);
if (Algorithm.Family.EC.contains(alg)) {
verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority);
} else if (Algorithm.Family.RSA.contains(Algorithm.parse(algorithm))) {
} else if (Algorithm.Family.RSA.contains(alg)) {
verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority);
} else {
throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm));
@ -214,9 +220,12 @@ public class JwtSvid {
private static String getKeyId(final JWSHeader header) throws JwtSvidException {
val keyId = header.getKeyID();
if (StringUtils.isBlank(keyId)) {
if (keyId == null) {
throw new JwtSvidException("Token header missing key id");
}
if (StringUtils.isBlank(keyId)) {
throw new JwtSvidException("Token header key id contains an empty value");
}
return keyId;
}
@ -230,7 +239,7 @@ public class JwtSvid {
}
}
private static SpiffeId getSpiffeId(final JWTClaimsSet claimsSet) throws JwtSvidException {
private static SpiffeId getSpiffeIdOfSubject(final JWTClaimsSet claimsSet) throws JwtSvidException {
val subject = claimsSet.getSubject();
if (StringUtils.isBlank(subject)) {
throw new JwtSvidException("Token missing subject claim");
@ -244,7 +253,7 @@ public class JwtSvid {
}
private static void validateAudience(final List<String> audClaim, final List<String> audience) throws JwtSvidException {
private static void validateAudience(final Set<String> audClaim, final Set<String> audience) throws JwtSvidException {
for (String aud : audClaim) {
if (!audience.contains(aud)) {
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", audience, audClaim));

View File

@ -26,7 +26,7 @@ import java.util.List;
* Contains a SPIFFE ID, a private key and a chain of X.509 certificates.
*/
@Value
public class X509Svid implements X509SvidSource {
public class X509Svid {
SpiffeId spiffeId;
@ -48,8 +48,17 @@ public class X509Svid implements X509SvidSource {
this.privateKey = privateKey;
}
/**
* @return the Leaf Certificate of the chain
*/
public X509Certificate getLeaf() {
return chain.get(0);
}
/**
* Loads the X.509 SVID from PEM encoded files on disk.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
*
* @param certsFilePath path to X.509 certificate chain file
* @param privateKeyFilePath path to private key file
@ -76,6 +85,8 @@ public class X509Svid implements X509SvidSource {
/**
* Parses the X.509 SVID from PEM or DER blocks containing certificate chain and key
* bytes. The key must be a PEM or DER block with PKCS#8.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
*
* @param certsBytes chain of certificates as a byte array
* @param privateKeyBytes private key as byte array
@ -89,6 +100,8 @@ public class X509Svid implements X509SvidSource {
/**
* Parses the X509-SVID from certificate and key bytes. The certificate must be ASN.1 DER (concatenated with
* no intermediate padding if there are more than one certificate). The key must be a PKCS#8 ASN.1 DER.
* <p>
* It is assumed that the leaf certificate is always the first certificate in the parsed chain.
*
* @param certsBytes chain of certificates as a byte array
* @param privateKeyBytes private key as byte array
@ -106,14 +119,6 @@ public class X509Svid implements X509SvidSource {
return chain.toArray(new X509Certificate[0]);
}
/**
* @return this instance, implementing a X509Svid interface.
*/
@Override
public X509Svid getX509Svid() {
return this;
}
private static X509Svid createX509Svid(final byte[] certsBytes, final byte[] privateKeyBytes, KeyFileFormat keyFileFormat) throws X509SvidException {
List<X509Certificate> x509Certificates;
@ -142,14 +147,19 @@ public class X509Svid implements X509SvidSource {
validateLeafCertificate(x509Certificates.get(0));
if (x509Certificates.size() > 1) {
validateSigningCertificates(x509Certificates.subList(1, x509Certificates.size()));
validateSigningCertificates(x509Certificates);
}
return new X509Svid(spiffeId, x509Certificates, privateKey);
}
private static void validateSigningCertificates(final List<X509Certificate> certificates) throws X509SvidException {
for (X509Certificate cert : certificates) {
for (int i = 1; i < certificates.size(); i++) {
verifyCaCert(certificates.get(i));
}
}
private static void verifyCaCert(final X509Certificate cert) throws X509SvidException {
if (!CertificateUtils.isCA(cert)) {
throw new X509SvidException("Signing certificate must have CA flag set to true");
}
@ -157,16 +167,15 @@ public class X509Svid implements X509SvidSource {
throw new X509SvidException("Signing certificate must have 'keyCertSign' as key usage");
}
}
}
private static void validateLeafCertificate(final X509Certificate leaf) throws X509SvidException {
if (CertificateUtils.isCA(leaf)) {
throw new X509SvidException("Leaf certificate must not have CA flag set to true");
}
validateKeyUsage(leaf);
validateKeyUsageOfLeafCertificate(leaf);
}
private static void validateKeyUsage(final X509Certificate leaf) throws X509SvidException {
private static void validateKeyUsageOfLeafCertificate(final X509Certificate leaf) throws X509SvidException {
if (!CertificateUtils.hasKeyUsageDigitalSignature(leaf)) {
throw new X509SvidException("Leaf certificate must have 'digitalSignature' as key usage");
}

View File

@ -1,18 +1,19 @@
package io.spiffe.svid.x509svid;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.internal.CertificateUtils;
import io.spiffe.spiffeid.SpiffeId;
import lombok.NonNull;
import lombok.val;
import io.spiffe.bundle.x509bundle.X509Bundle;
import java.security.cert.CertPathValidatorException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
/**
@ -48,17 +49,17 @@ public class X509SvidValidator {
* Checks that the X.509 SVID provided has a SPIFFE ID that is in the list of accepted SPIFFE IDs supplied.
*
* @param x509Certificate a {@link X509Svid} with a SPIFFE ID to be verified
* @param acceptedSpiffedIdsSupplier a {@link Supplier} of a list os SPIFFE IDs that are accepted
* @throws CertificateException is the SPIFFE ID in x509Certificate is not in the list supplied by acceptedSpiffedIdsSupplier,
* @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of SPIFFE IDs that are accepted
* @throws CertificateException if the SPIFFE ID in x509Certificate is not in the Set supplied by acceptedSpiffeIdsSupplier,
* or if the SPIFFE ID cannot be parsed from the x509Certificate
* @throws NullPointerException if the given x509Certificate or acceptedSpiffedIdsSupplier are null
* @throws NullPointerException if the given x509Certificate or acceptedSpiffeIdsSupplier are null
*/
public static void verifySpiffeId(@NonNull final X509Certificate x509Certificate,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffedIdsSupplier)
@NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier)
throws CertificateException {
val spiffeIdList = acceptedSpiffedIdsSupplier.get();
val spiffeIdSet = acceptedSpiffeIdsSupplier.get();
val spiffeId = CertificateUtils.getSpiffeId(x509Certificate);
if (!spiffeIdList.contains(spiffeId)) {
if (!spiffeIdSet.contains(spiffeId)) {
throw new CertificateException(String.format("SPIFFE ID %s in X.509 certificate is not accepted", spiffeId));
}
}

View File

@ -32,8 +32,10 @@ import java.security.KeyException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
@ -192,7 +194,7 @@ public class WorkloadApiClient implements Closeable {
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/
public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
List<String> audParam = new ArrayList<>();
Set<String> audParam = new HashSet<>();
audParam.add(audience);
Collections.addAll(audParam, extraAudience);
@ -238,7 +240,7 @@ public class WorkloadApiClient implements Closeable {
throw new JwtSvidException("Error validating JWT SVID", e);
}
return JwtSvid.parseInsecure(token, Collections.singletonList(audience));
return JwtSvid.parseInsecure(token, Collections.singleton(audience));
}
/**
@ -388,7 +390,7 @@ public class WorkloadApiClient implements Closeable {
throw new X509ContextException("Error processing X509Context: x509SVIDResponse is empty");
}
private JwtSvid callFetchJwtSvid(SpiffeId subject, List<String> audience) throws JwtSvidException {
private JwtSvid callFetchJwtSvid(SpiffeId subject, Set<String> audience) throws JwtSvidException {
Workload.JWTSVIDRequest jwtsvidRequest = Workload.JWTSVIDRequest
.newBuilder()
.setSpiffeId(subject.toString())

View File

@ -1,5 +1,6 @@
package io.spiffe.spiffeid;
import lombok.val;
import org.junit.jupiter.api.Test;
import java.io.IOException;
@ -8,36 +9,37 @@ import java.net.URISyntaxException;
import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
class SpiffeIdUtilsTest {
@Test
void getSpiffeIdListFromFile() throws URISyntaxException {
void getSpiffeIdSetFromFile() throws URISyntaxException {
Path path = Paths.get(toUri("testdata/spiffeid/spiffeIds.txt"));
try {
List<SpiffeId> spiffeIdList = SpiffeIdUtils.getSpiffeIdListFromFile(path);
assertNotNull(spiffeIdList);
assertEquals(3, spiffeIdList.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0));
assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1));
assertEquals(SpiffeId.parse("spiffe://example2.org/workload1"), spiffeIdList.get(2));
Set<SpiffeId> spiffeIdSet = SpiffeIdUtils.getSpiffeIdSetFromFile(path);
assertNotNull(spiffeIdSet);
assertEquals(3, spiffeIdSet.size());
assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1")));
assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2")));
assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example2.org/workload1")));
} catch (IOException e) {
fail(e);
}
}
@Test
void getSpiffeIdListFromNonExistenFile_throwsException() throws IOException {
void getSpiffeIdSetFromNonExistenFile_throwsException() throws IOException {
Path path = Paths.get("testdata/spiffeid/non-existent-file");
try {
SpiffeIdUtils.getSpiffeIdListFromFile(path);
SpiffeIdUtils.getSpiffeIdSetFromFile(path);
fail("should have thrown exception");
} catch (NoSuchFileException e) {
assertEquals("testdata/spiffeid/non-existent-file", e.getMessage());
@ -45,15 +47,14 @@ class SpiffeIdUtilsTest {
}
@Test
void toListOfSpiffeIds() {
String spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 ";
void toSetOfSpiffeIds() {
val spiffeIdsAsString = " spiffe://example.org/workload1, spiffe://example.org/workload2 ";
val spiffeIdSet = SpiffeIdUtils.toSetOfSpiffeIds(spiffeIdsAsString, ',');
List<SpiffeId> spiffeIdList = SpiffeIdUtils.toListOfSpiffeIds(spiffeIdsAsString, ',');
assertNotNull(spiffeIdList);
assertEquals(2, spiffeIdList.size());
assertEquals(SpiffeId.parse("spiffe://example.org/workload1"), spiffeIdList.get(0));
assertEquals(SpiffeId.parse("spiffe://example.org/workload2"), spiffeIdList.get(1));
assertNotNull(spiffeIdSet);
assertEquals(2, spiffeIdSet.size());
assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload1")));
assertTrue(spiffeIdSet.contains(SpiffeId.parse("spiffe://example.org/workload2")));
}
private URI toUri(String path) throws URISyntaxException {

View File

@ -2,24 +2,24 @@ package io.spiffe.svid.jwtsvid;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jwt.JWTClaimsSet;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.AuthorityNotFoundException;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import lombok.Builder;
import lombok.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import java.security.KeyPair;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Stream;
@ -43,7 +43,7 @@ class JwtSvidParseAndValidateTest {
assertEquals(testCase.expectedJwtSvid.getAudience(), jwtSvid.getAudience());
assertEquals(testCase.expectedJwtSvid.getExpiry().toInstant().getEpochSecond(), jwtSvid.getExpiry().toInstant().getEpochSecond());
assertEquals(token, jwtSvid.getToken());
assertEquals(token, jwtSvid.marshall());
assertEquals(token, jwtSvid.marshal());
} catch (Exception e) {
assertEquals(testCase.expectedException.getClass(), e.getClass());
assertEquals(testCase.expectedException.getMessage(), e.getMessage());
@ -54,7 +54,7 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_nullToken_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
try {
JwtSvid.parseAndValidate(null, jwtBundle, audience);
@ -67,7 +67,7 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
try {
JwtSvid.parseAndValidate("", jwtBundle, audience);
@ -78,7 +78,7 @@ class JwtSvidParseAndValidateTest {
@Test
void testParseAndValidate_nullBundle_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
try {
JwtSvid.parseAndValidate("token", null, audience);
} catch (NullPointerException e) {
@ -90,7 +90,6 @@ class JwtSvidParseAndValidateTest {
void testParseAndValidate_nullAudience_throwsNullPointerException() throws JwtSvidException, AuthorityNotFoundException, BundleNotFoundException {
TrustDomain trustDomain = TrustDomain.of("test.domain");
JwtBundle jwtBundle = new JwtBundle(trustDomain);
List<String> audience = Collections.singletonList("audience");
try {
JwtSvid.parseAndValidate("token", jwtBundle, null);
@ -112,7 +111,7 @@ class JwtSvidParseAndValidateTest {
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date expiration = new Date(System.currentTimeMillis() + 3600000);
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -179,7 +178,7 @@ class JwtSvidParseAndValidateTest {
Arguments.of(TestCase.builder()
.name("8. unexpected audience")
.jwtBundle(jwtBundle)
.expectedAudience(Collections.singletonList("another"))
.expectedAudience(Collections.singleton("another"))
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])"))
.build()),
@ -194,32 +193,39 @@ class JwtSvidParseAndValidateTest {
.name("10. missing key id")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, ""))
.generateToken(() -> TestUtils.generateToken(claims, key1, null))
.expectedException(new JwtSvidException("Token header missing key id"))
.build()),
Arguments.of(TestCase.builder()
.name("11. no bundle for trust domain")
.name("11. key id contains an empty value")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, " "))
.expectedException(new JwtSvidException("Token header key id contains an empty value"))
.build()),
Arguments.of(TestCase.builder()
.name("12. no bundle for trust domain")
.jwtBundle(new JwtBundle(TrustDomain.of("other.domain")))
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new BundleNotFoundException("No JWT bundle found for trust domain test.domain"))
.build()),
Arguments.of(TestCase.builder()
.name("12. no authority found for key id")
.name("13. no authority found for key id")
.jwtBundle(new JwtBundle(TrustDomain.of("test.domain")))
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new AuthorityNotFoundException("No authority found for the trust domain test.domain and key id authority1"))
.build()),
Arguments.of(TestCase.builder()
.name("13. signature cannot be verified with authority")
.name("14. signature cannot be verified with authority")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key2, "authority1"))
.expectedException(new JwtSvidException("Signature invalid: cannot be verified with the authority with keyId=authority1"))
.build()),
Arguments.of(TestCase.builder()
.name("14. authority algorithm mismatch")
.name("15. authority algorithm mismatch")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(claims, key3, "authority1"))
@ -232,13 +238,13 @@ class JwtSvidParseAndValidateTest {
static class TestCase {
String name;
JwtBundle jwtBundle;
List<String> audience;
Set<String> audience;
Supplier<String> generateToken;
Exception expectedException;
JwtSvid expectedJwtSvid;
@Builder
public TestCase(String name, JwtBundle jwtBundle, List<String> expectedAudience, Supplier<String> generateToken,
public TestCase(String name, JwtBundle jwtBundle, Set<String> expectedAudience, Supplier<String> generateToken,
Exception expectedException, JwtSvid expectedJwtSvid) {
this.name = name;
this.jwtBundle = jwtBundle;

View File

@ -2,22 +2,22 @@ package io.spiffe.svid.jwtsvid;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jwt.JWTClaimsSet;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import lombok.Builder;
import lombok.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.TestUtils;
import java.security.KeyPair;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Stream;
@ -46,7 +46,7 @@ class JwtSvidParseInsecureTest {
@Test
void testParseInsecure_nullToken_throwsNullPointerException() throws JwtSvidException {
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
try {
JwtSvid.parseInsecure(null, audience);
@ -57,7 +57,7 @@ class JwtSvidParseInsecureTest {
@Test
void testParseAndValidate_emptyToken_throwsIllegalArgumentException() throws JwtSvidException {
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
try {
JwtSvid.parseInsecure("", audience);
} catch (IllegalArgumentException e) {
@ -71,7 +71,7 @@ class JwtSvidParseInsecureTest {
KeyPair key1 = TestUtils.generateECKeyPair(Curve.P_521);
TrustDomain trustDomain = TrustDomain.of("test.domain");
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
Date expiration = new Date(System.currentTimeMillis() + 3600000);
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -93,7 +93,7 @@ class JwtSvidParseInsecureTest {
SpiffeId spiffeId = trustDomain.newSpiffeId("host");
Date expiration = new Date(System.currentTimeMillis() + 3600000);
List<String> audience = Collections.singletonList("audience");
Set<String> audience = Collections.singleton("audience");
JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, spiffeId.toString(), expiration);
@ -135,7 +135,7 @@ class JwtSvidParseInsecureTest {
.build()),
Arguments.of(TestCase.builder()
.name("unexpected audience")
.expectedAudience(Collections.singletonList("another"))
.expectedAudience(Collections.singleton("another"))
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience])"))
.build()),
@ -151,13 +151,13 @@ class JwtSvidParseInsecureTest {
@Value
static class TestCase {
String name;
List<String> audience;
Set<String> audience;
Supplier<String> generateToken;
Exception expectedException;
JwtSvid expectedJwtSvid;
@Builder
public TestCase(String name, List<String> expectedAudience, Supplier<String> generateToken,
public TestCase(String name, Set<String> expectedAudience, Supplier<String> generateToken,
Exception expectedException, JwtSvid expectedJwtSvid) {
this.name = name;
this.audience = expectedAudience;

View File

@ -264,11 +264,11 @@ public class X509SvidTest {
}
@Test
void testGetX509Svid() throws URISyntaxException, X509SvidException {
void testGetLeaf() throws URISyntaxException, X509SvidException {
Path certPath = Paths.get(toUri(certSingle));
Path keyPath = Paths.get(toUri(keyRSA));
X509Svid x509Svid = X509Svid.load(certPath, keyPath);
assertEquals(x509Svid, x509Svid.getX509Svid());
assertEquals(x509Svid.getChain().get(0), x509Svid.getLeaf());
}
@Test

View File

@ -1,27 +1,28 @@
package io.spiffe.svid.x509svid;
import io.spiffe.exception.BundleNotFoundException;
import lombok.val;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.google.common.collect.Sets;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.utils.X509CertificateTestUtils.CertAndKeyPair;
import lombok.val;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.net.URISyntaxException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import static java.util.Collections.EMPTY_LIST;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static io.spiffe.utils.X509CertificateTestUtils.createCertificate;
import static io.spiffe.utils.X509CertificateTestUtils.createRootCA;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class X509SvidValidatorTest {
@ -84,19 +85,19 @@ public class X509SvidValidatorTest {
val spiffeId1 = SpiffeId.parse("spiffe://example.org/test");
val spiffeId2 = SpiffeId.parse("spiffe://example.org/test2");
val spiffeIdList = Arrays.asList(spiffeId1, spiffeId2);
val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2);
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList);
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet);
}
@Test
void checkSpiffeId_givenASpiffeIdNotInTheListOfAcceptedIds_throwsCertificateException() throws IOException, CertificateException, URISyntaxException {
val spiffeId1 = SpiffeId.parse("spiffe://example.org/other1");
val spiffeId2 = SpiffeId.parse("spiffe://example.org/other2");
List<SpiffeId> spiffeIdList = Arrays.asList(spiffeId1, spiffeId2);
val spiffeIdSet = Sets.newHashSet(spiffeId1, spiffeId2);
try {
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdList);
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), () -> spiffeIdSet);
fail("Should have thrown CertificateException");
} catch (CertificateException e) {
assertEquals("SPIFFE ID spiffe://example.org/test in X.509 certificate is not accepted", e.getMessage());
@ -106,7 +107,7 @@ public class X509SvidValidatorTest {
@Test
void checkSpiffeId_nullX509Certificate_throwsNullPointerException() throws CertificateException {
try {
X509SvidValidator.verifySpiffeId(null, () -> EMPTY_LIST);
X509SvidValidator.verifySpiffeId(null, Collections::emptySet);
fail("should have thrown an exception");
} catch (NullPointerException e) {
assertEquals("x509Certificate is marked non-null but is null", e.getMessage());
@ -119,7 +120,7 @@ public class X509SvidValidatorTest {
X509SvidValidator.verifySpiffeId(leaf.getCertificate(), null);
fail("should have thrown an exception");
} catch (NullPointerException e) {
assertEquals("acceptedSpiffedIdsSupplier is marked non-null but is null", e.getMessage());
assertEquals("acceptedSpiffeIdsSupplier is marked non-null but is null", e.getMessage());
}
}

View File

@ -17,9 +17,11 @@ import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.ECGenParameterSpec;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Stream.generate;
@ -87,11 +89,11 @@ public class TestUtils {
}
}
public static JWTClaimsSet buildJWTClaimSet(List<String> audience, String spiffeId, Date expiration) {
public static JWTClaimsSet buildJWTClaimSet(Set<String> audience, String spiffeId, Date expiration) {
return new JWTClaimsSet.Builder()
.subject(spiffeId)
.expirationTime(expiration)
.audience(audience)
.audience(new ArrayList<>(audience))
.build();
}

View File

@ -129,7 +129,7 @@ class FakeWorkloadApi extends SpiffeWorkloadAPIImplBase {
JwtSvid jwtSvid = null;
try {
jwtSvid = JwtSvid.parseInsecure(token, Collections.singletonList(audience));
jwtSvid = JwtSvid.parseInsecure(token, Collections.singleton(audience));
} catch (JwtSvidException e) {
responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withDescription(e.getMessage())));
}

View File

@ -1,28 +1,28 @@
package io.spiffe.workloadapi;
import com.google.common.collect.Sets;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import org.junit.Rule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.svid.jwtsvid.JwtSvid;
import io.spiffe.workloadapi.grpc.SpiffeWorkloadAPIGrpc;
import io.spiffe.workloadapi.internal.ManagedChannelWrapper;
import io.spiffe.workloadapi.internal.SecurityHeaderInterceptor;
import org.junit.Rule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -98,7 +98,7 @@ class JwtSourceTest {
JwtSvid svid = jwtSource.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svid.getSpiffeId());
assertEquals(Arrays.asList("aud1", "aud2", "aud3"), svid.getAudience());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svid.getAudience());
} catch (JwtSvidException e) {
fail(e);
}

View File

@ -38,6 +38,7 @@ import java.util.concurrent.Executors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
class WorkloadApiClientTest {
@ -163,7 +164,7 @@ class WorkloadApiClientTest {
JwtSvid jwtSvid = workloadApiClient.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(jwtSvid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId());
assertEquals("aud1", jwtSvid.getAudience().get(0));
assertTrue(jwtSvid.getAudience().contains("aud1"));
assertEquals(3, jwtSvid.getAudience().size());
} catch (JwtSvidException e) {
fail(e);
@ -177,7 +178,7 @@ class WorkloadApiClientTest {
JwtSvid jwtSvid = workloadApiClient.validateJwtSvid(token, "aud1");
assertNotNull(jwtSvid);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvid.getSpiffeId());
assertEquals("aud1", jwtSvid.getAudience().get(0));
assertTrue(jwtSvid.getAudience().contains("aud1"));
assertEquals(1, jwtSvid.getAudience().size());
} catch (JwtSvidException e) {
fail(e);

View File

@ -1,3 +1,4 @@
spiffe://example.org/workload1
spiffe://example.org/workload2
spiffe://example2.org/workload1
spiffe://example.org/workload1

View File

@ -33,11 +33,11 @@ Alternatively, a different Workload API address can be used by passing it to the
X509Source x509Source = X509Source.newSource(sourceOptions);
Supplier<List<SpiffeId>> spiffeIdListSupplier = () -> Collections.singletonList(SpiffeId.parse("spiffe://example.org/test"));
Supplier<Set<SpiffeId>> spiffeIdSetSupplier = () -> Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
SslContextOptions sslContextOptions = SslContextOptions
.builder()
.acceptedSpiffeIdsSupplier(spiffeIdListSupplier)
.acceptedSpiffeIdsSupplier(spiffeIdSetSupplier)
.x509Source(x509Source)
.build();

View File

@ -11,7 +11,7 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
/**
@ -72,14 +72,14 @@ public final class SpiffeSslContextFactory {
public static class SslContextOptions {
String sslProtocol;
X509Source x509Source;
Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier;
Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier;
boolean acceptAnySpiffeId;
@Builder
public SslContextOptions(
final String sslProtocol,
final X509Source x509Source,
final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier,
final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier,
final boolean acceptAnySpiffeId) {
this.x509Source = x509Source;
this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier;

View File

@ -1,9 +1,9 @@
package io.spiffe.provider;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.svid.x509svid.X509SvidValidator;
import lombok.NonNull;
@ -12,9 +12,9 @@ import javax.net.ssl.X509ExtendedTrustManager;
import java.net.Socket;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
import java.util.Set;
import java.util.function.Supplier;
/**
@ -26,7 +26,7 @@ import java.util.function.Supplier;
public final class SpiffeTrustManager extends X509ExtendedTrustManager {
private final BundleSource<X509Bundle> x509BundleSource;
private final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier;
private final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier;
private final boolean acceptAnySpiffeId;
/**
@ -36,10 +36,10 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager {
* and a {@link Supplier} of a List of accepted {@link SpiffeId} to be used during peer SVID validation.
*
* @param x509BundleSource an implementation of a {@link BundleSource}
* @param acceptedSpiffeIdsSupplier a {@link Supplier} of a list of accepted SPIFFE IDs.
* @param acceptedSpiffeIdsSupplier a {@link Supplier} of a Set of accepted SPIFFE IDs.
*/
public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier) {
@NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier) {
this.x509BundleSource = x509BundleSource;
this.acceptedSpiffeIdsSupplier = acceptedSpiffeIdsSupplier;
this.acceptAnySpiffeId = false;
@ -57,7 +57,7 @@ public final class SpiffeTrustManager extends X509ExtendedTrustManager {
public SpiffeTrustManager(@NonNull final BundleSource<X509Bundle> x509BundleSource,
final boolean acceptAnySpiffeId) {
this.x509BundleSource = x509BundleSource;
this.acceptedSpiffeIdsSupplier = ArrayList::new;
this.acceptedSpiffeIdsSupplier = Collections::emptySet;
this.acceptAnySpiffeId = acceptAnySpiffeId;
}

View File

@ -1,19 +1,19 @@
package io.spiffe.provider;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509SourceException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.SpiffeIdUtils;
import io.spiffe.workloadapi.X509Source;
import lombok.NonNull;
import io.spiffe.bundle.x509bundle.X509Bundle;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactorySpi;
import java.security.KeyStore;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_ALL_PROPERTY;
@ -35,11 +35,11 @@ import static io.spiffe.provider.SpiffeProviderConstants.SSL_SPIFFE_ACCEPT_PROPE
public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi {
private static final boolean ACCEPT_ANY_SPIFFE_ID;
private static final Supplier<List<SpiffeId>> DEFAULT_SPIFFE_ID_LIST_SUPPLIER;
private static final Supplier<Set<SpiffeId>> DEFAULT_SPIFFE_ID_LIST_SUPPLIER;
static {
ACCEPT_ANY_SPIFFE_ID = Boolean.parseBoolean(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_ALL_PROPERTY, "false"));
DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toListOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY));
DEFAULT_SPIFFE_ID_LIST_SUPPLIER = () -> SpiffeIdUtils.toSetOfSpiffeIds(EnvironmentUtils.getProperty(SSL_SPIFFE_ACCEPT_PROPERTY));
}
/**
@ -113,12 +113,12 @@ public class SpiffeTrustManagerFactory extends TrustManagerFactorySpi {
* and a supplier of accepted SPIFFE IDs.
*
* @param x509BundleSource a {@link BundleSource} to provide the X.509-Bundles
* @param acceptedSpiffeIdsSupplier a Supplier to provide a List of SPIFFE IDs that are accepted
* @param acceptedSpiffeIdsSupplier a Supplier to provide a Set of SPIFFE IDs that are accepted
* @return an instance of a {@link TrustManager} wrapped in an array. The actual type returned is {@link SpiffeTrustManager}
*/
public TrustManager[] engineGetTrustManagers(
@NonNull final BundleSource<X509Bundle> x509BundleSource,
@NonNull final Supplier<List<SpiffeId>> acceptedSpiffeIdsSupplier) {
@NonNull final Supplier<Set<SpiffeId>> acceptedSpiffeIdsSupplier) {
SpiffeTrustManager spiffeTrustManager = new SpiffeTrustManager(x509BundleSource, acceptedSpiffeIdsSupplier);
return new TrustManager[]{spiffeTrustManager};

View File

@ -1,18 +1,18 @@
package io.spiffe.provider;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.X509SvidException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.x509svid.X509Svid;
import lombok.val;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.svid.x509svid.X509Svid;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
@ -22,7 +22,7 @@ import java.nio.file.Paths;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@ -36,7 +36,7 @@ public class SpiffeTrustManagerTest {
static X509Bundle x509Bundle;
static X509Svid x509Svid;
static X509Svid otherX509Svid;
List<SpiffeId> acceptedSpiffeIds;
Set<SpiffeId> acceptedSpiffeIds;
X509TrustManager trustManager;
@BeforeAll
@ -59,19 +59,12 @@ public class SpiffeTrustManagerTest {
void setupMocks() {
MockitoAnnotations.initMocks(this);
trustManager = (X509TrustManager)
new SpiffeTrustManagerFactory()
.engineGetTrustManagers(
bundleSource,
() -> acceptedSpiffeIds)[0];
new SpiffeTrustManagerFactory().engineGetTrustManagers(bundleSource, () -> acceptedSpiffeIds)[0];
}
@Test
void checkClientTrusted_passAExpiredCertificate_throwsException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
val chain = x509Svid.getChainArray();
@ -87,11 +80,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkClientTrusted_noBundleForTrustDomain_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
val chain = x509Svid.getChainArray();
@ -107,11 +96,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkClientTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/other")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other"));
X509Certificate[] chain = x509Svid.getChainArray();
@ -128,11 +113,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkClientTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://other.org/test")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test"));
val chain = otherX509Svid.getChainArray();
@ -148,11 +129,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkServerTrusted_passAnExpiredCertificate_ThrowsException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/test")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/test"));
val chain = x509Svid.getChainArray();
@ -168,11 +145,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkServerTrusted_passCertificateWithNonAcceptedSpiffeId_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://example.org/other")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://example.org/other"));
val chain = x509Svid.getChainArray();
@ -188,11 +161,7 @@ public class SpiffeTrustManagerTest {
@Test
void checkServerTrusted_passCertificateThatDoesntChainToBundle_ThrowCertificateException() throws BundleNotFoundException {
acceptedSpiffeIds =
Collections
.singletonList(
SpiffeId.parse("spiffe://other.org/test")
);
acceptedSpiffeIds = Collections.singleton(SpiffeId.parse("spiffe://other.org/test"));
val chain = otherX509Svid.getChainArray();

View File

@ -3,13 +3,13 @@ package io.spiffe.provider.examples.mtls;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509SourceException;
import io.spiffe.provider.SpiffeKeyManager;
import io.spiffe.provider.SpiffeSslContextFactory;
import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions;
import io.spiffe.provider.SpiffeTrustManager;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.SpiffeIdUtils;
import io.spiffe.workloadapi.X509Source;
import lombok.val;
import io.spiffe.provider.SpiffeSslContextFactory;
import io.spiffe.provider.SpiffeSslContextFactory.SslContextOptions;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
@ -20,7 +20,7 @@ import java.net.URISyntaxException;
import java.nio.file.Paths;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
/**
@ -35,12 +35,12 @@ import java.util.function.Supplier;
public class HttpsClient {
String spiffeSocket;
Supplier<List<SpiffeId>> acceptedSpiffeIdsListSupplier;
Supplier<Set<SpiffeId>> acceptedSpiffeIdsSetSupplier;
int serverPort;
public static void main(String[] args) {
String spiffeSocket = "unix:/tmp/agent.sock";
HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getList());
HttpsClient httpsClient = new HttpsClient(4000, spiffeSocket, () -> new AcceptedSpiffeIds().getSet());
try {
httpsClient.run();
} catch (KeyManagementException | NoSuchAlgorithmException | IOException | SocketEndpointAddressException | X509SourceException e) {
@ -48,10 +48,10 @@ public class HttpsClient {
}
}
HttpsClient(int serverPort, String spiffeSocket, Supplier<List<SpiffeId>> acceptedSpiffeIdsListSupplier) {
HttpsClient(int serverPort, String spiffeSocket, Supplier<Set<SpiffeId>> acceptedSpiffeIdsSetSupplier) {
this.serverPort = serverPort;
this.spiffeSocket = spiffeSocket;
this.acceptedSpiffeIdsListSupplier = acceptedSpiffeIdsListSupplier;
this.acceptedSpiffeIdsSetSupplier = acceptedSpiffeIdsSetSupplier;
}
void run() throws IOException, SocketEndpointAddressException, KeyManagementException, NoSuchAlgorithmException, X509SourceException {
@ -64,7 +64,7 @@ public class HttpsClient {
SslContextOptions sslContextOptions = SslContextOptions
.builder()
.acceptedSpiffeIdsSupplier(acceptedSpiffeIdsListSupplier)
.acceptedSpiffeIdsSupplier(acceptedSpiffeIdsSetSupplier)
.x509Source(x509Source)
.build();
SSLContext sslContext = SpiffeSslContextFactory.getSslContext(sslContextOptions);
@ -76,9 +76,9 @@ public class HttpsClient {
}
private static class AcceptedSpiffeIds {
List<SpiffeId> getList() {
Set<SpiffeId> getSet() {
try {
return SpiffeIdUtils.getSpiffeIdListFromFile(Paths.get(toUri("testdata/spiffeIds.txt")));
return SpiffeIdUtils.getSpiffeIdSetFromFile(Paths.get(toUri("testdata/spiffeIds.txt")));
} catch (IOException | URISyntaxException e) {
throw new RuntimeException("Error getting list of spiffeIds", e);
}