diff --git a/java-spiffe-core/src/main/java/io/spiffe/bundle/jwtbundle/JwtBundle.java b/java-spiffe-core/src/main/java/io/spiffe/bundle/jwtbundle/JwtBundle.java index 047de11..0b83ef0 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/bundle/jwtbundle/JwtBundle.java +++ b/java-spiffe-core/src/main/java/io/spiffe/bundle/jwtbundle/JwtBundle.java @@ -5,7 +5,7 @@ import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.RSAKey; -import io.spiffe.Algorithm; +import io.spiffe.internal.JwtSignatureAlgorithm; import io.spiffe.bundle.BundleSource; import io.spiffe.exception.AuthorityNotFoundException; import io.spiffe.exception.BundleNotFoundException; @@ -72,7 +72,7 @@ public class JwtBundle implements BundleSource { try { val jwkSet = JWKSet.load(bundlePath.toFile()); return toJwtBundle(trustDomain, jwkSet); - } catch (IOException | ParseException | JOSEException e) { + } catch (IllegalArgumentException | IOException | ParseException | JOSEException e) { val error = "Could not load bundle from file: %s"; throw new JwtBundleException(String.format(error, bundlePath.toString()), e); } @@ -189,7 +189,7 @@ public class JwtBundle implements BundleSource { } private static PublicKey getPublicKey(final JWK jwk) throws JOSEException, ParseException, KeyException { - val family = Algorithm.Family.parse(jwk.getKeyType().getValue()); + val family = JwtSignatureAlgorithm.Family.parse(jwk.getKeyType().getValue()); final PublicKey publicKey; switch (family) { diff --git a/java-spiffe-core/src/main/java/io/spiffe/Algorithm.java b/java-spiffe-core/src/main/java/io/spiffe/internal/JwtSignatureAlgorithm.java similarity index 77% rename from java-spiffe-core/src/main/java/io/spiffe/Algorithm.java rename to java-spiffe-core/src/main/java/io/spiffe/internal/JwtSignatureAlgorithm.java index ea293c2..6f05209 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/Algorithm.java +++ b/java-spiffe-core/src/main/java/io/spiffe/internal/JwtSignatureAlgorithm.java @@ -1,13 +1,15 @@ -package io.spiffe; +package io.spiffe.internal; + +import lombok.NonNull; import java.util.Collections; import java.util.HashSet; import java.util.Set; /** - * Represents JWT Algorithms. + * Represents JWT Signature Supported Algorithms. */ -public enum Algorithm { +public enum JwtSignatureAlgorithm { /** * ECDSA algorithm using SHA-256 hash algorithm. @@ -52,16 +54,11 @@ public enum Algorithm { /** * RSASSA-PSS using SHA-512 and MGF1 padding with SHA-512. */ - PS512("PS512"), - - /** - * Non-Supported algorithm. - */ - OTHER("OTHER"); + PS512("PS512"); private final String name; - Algorithm(final String name) { + JwtSignatureAlgorithm(final String name) { this.name = name; } @@ -74,13 +71,12 @@ public enum Algorithm { */ public enum Family { RSA("RSA", RS256, RS384, RS512, PS256, PS384, PS512), - EC("EC", ES256, ES384, ES512), - OTHER("UNKNOWN"); + EC("EC", ES256, ES384, ES512); private final String name; - private final Set algorithms; + private final Set algorithms; - Family(final String name, final Algorithm... algs) { + Family(final String name, final JwtSignatureAlgorithm... algs) { this.name = name; algorithms = new HashSet<>(); Collections.addAll(algorithms, algs); @@ -90,7 +86,7 @@ public enum Algorithm { return name; } - public boolean contains(final Algorithm a) { + public boolean contains(final JwtSignatureAlgorithm a) { return algorithms.contains(a); } @@ -101,14 +97,14 @@ public enum Algorithm { } else if (s.equals(EC.getName())) { family = EC; } else { - family = OTHER; + throw new IllegalArgumentException("Unsupported JWT family algorithm: " + s); } return family; } } - public static Algorithm parse(final String s) { - final Algorithm algorithm; + public static JwtSignatureAlgorithm parse(@NonNull final String s) { + final JwtSignatureAlgorithm algorithm; if (s.equals(RS256.getName())) { algorithm = RS256; } else if (s.equals(RS384.getName())) { @@ -128,7 +124,7 @@ public enum Algorithm { } else if (s.equals(PS512.getName())) { algorithm = PS512; } else { - algorithm = OTHER; + throw new IllegalArgumentException("Unsupported JWT algorithm: " + s); } return algorithm; } diff --git a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java index 89e0339..8594102 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java +++ b/java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java @@ -1,13 +1,14 @@ package io.spiffe.svid.jwtsvid; import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSVerifier; import com.nimbusds.jose.crypto.ECDSAVerifier; import com.nimbusds.jose.crypto.RSASSAVerifier; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; -import io.spiffe.Algorithm; +import io.spiffe.internal.JwtSignatureAlgorithm; import io.spiffe.bundle.BundleSource; import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.exception.AuthorityNotFoundException; @@ -85,7 +86,8 @@ public class JwtSvid { * @return an instance of a {@link JwtSvid} with a SPIFFE ID parsed from the 'sub', audience from 'aud', and expiry * from 'exp' claim. * @throws JwtSvidException when the token expired or the expiration claim is missing, - * when the algorithm is not supported, when the header 'kid' is missing, + * when the algorithm is not supported (See {@link JwtSignatureAlgorithm}), + * when the header 'kid' is missing, * when the signature cannot be verified, or * when the 'aud' claim has an audience that is not in the audience list * provided as parameter @@ -106,8 +108,9 @@ public class JwtSvid { } val signedJwt = getSignedJWT(token); - val claimsSet = getJwtClaimsSet(signedJwt); + JwtSignatureAlgorithm algorithm = parseAlgorithm(signedJwt.getHeader().getAlgorithm()); + val claimsSet = getJwtClaimsSet(signedJwt); validateAudience(claimsSet.getAudience(), audience); val expirationTime = claimsSet.getExpirationTime(); @@ -119,7 +122,6 @@ public class JwtSvid { val keyId = getKeyId(signedJwt.getHeader()); val jwtAuthority = jwtBundle.findJwtAuthority(keyId); - val algorithm = signedJwt.getHeader().getAlgorithm().getName(); verifySignature(signedJwt, jwtAuthority, algorithm, keyId); val claimAudience = new HashSet<>(claimsSet.getAudience()); @@ -136,7 +138,8 @@ public class JwtSvid { * @return an instance of a {@link JwtSvid} with a SPIFFE ID parsed from the 'sub', audience from 'aud', and expiry * from 'exp' claim. * @throws JwtSvidException when the token expired or the expiration claim is missing, or when - * the 'aud' has an audience that is not in the audience provided as parameter + * the 'aud' has an audience that is not in the audience provided as parameter, + * or when the 'alg' is not supported (See {@link JwtSignatureAlgorithm}). * @throws IllegalArgumentException when the token cannot be parsed */ public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final Set audience) throws JwtSvidException { @@ -145,8 +148,9 @@ public class JwtSvid { } val signedJwt = getSignedJWT(token); - val claimsSet = getJwtClaimsSet(signedJwt); + parseAlgorithm(signedJwt.getHeader().getAlgorithm()); + val claimsSet = getJwtClaimsSet(signedJwt); validateAudience(claimsSet.getAudience(), audience); val expirationTime = claimsSet.getExpirationTime(); @@ -216,7 +220,7 @@ public class JwtSvid { return signedJwt; } - private static void verifySignature(final SignedJWT signedJwt, final PublicKey jwtAuthority, final String algorithm, final String keyId) throws JwtSvidException { + private static void verifySignature(final SignedJWT signedJwt, final PublicKey jwtAuthority, final JwtSignatureAlgorithm algorithm, final String keyId) throws JwtSvidException { boolean verify; try { val verifier = getJwsVerifier(jwtAuthority, algorithm); @@ -230,12 +234,11 @@ public class JwtSvid { } } - private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException { + private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final JwtSignatureAlgorithm algorithm) throws JOSEException, JwtSvidException { JWSVerifier verifier; - val alg = Algorithm.parse(algorithm); - if (Algorithm.Family.EC.contains(alg)) { + if (JwtSignatureAlgorithm.Family.EC.contains(algorithm)) { verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority); - } else if (Algorithm.Family.RSA.contains(alg)) { + } else if (JwtSignatureAlgorithm.Family.RSA.contains(algorithm)) { verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority); } else { throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm)); @@ -284,4 +287,16 @@ public class JwtSvid { throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", expectedAudiences, audClaim)); } } + + private static JwtSignatureAlgorithm parseAlgorithm(JWSAlgorithm algorithm) throws JwtSvidException { + if (algorithm == null) { + throw new JwtSvidException("jwt header 'alg' is required"); + } + + try { + return JwtSignatureAlgorithm.parse(algorithm.getName()); + } catch (IllegalArgumentException e) { + throw new JwtSvidException(e.getMessage(), e); + } + } } diff --git a/java-spiffe-core/src/test/java/io/spiffe/bundle/jwtbundle/JwtBundleTest.java b/java-spiffe-core/src/test/java/io/spiffe/bundle/jwtbundle/JwtBundleTest.java index 423c76f..8ff9760 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/bundle/jwtbundle/JwtBundleTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/bundle/jwtbundle/JwtBundleTest.java @@ -157,15 +157,15 @@ class JwtBundleTest { } @Test - void testLoadFile_InvalidKeyType_ThrowsKeyException() throws URISyntaxException, JwtBundleException { + void testLoadFile_InvalidKeyType_ThrowsKeyException() throws URISyntaxException, KeyException { Path path = Paths.get(toUri("testdata/jwtbundle/jwks_invalid_keytype.json")); TrustDomain trustDomain = TrustDomain.of("domain.test"); try { JwtBundle.load(trustDomain, path); fail("should have thrown exception"); - } catch (KeyException e) { - assertEquals("Key Type not supported: OKP", e.getMessage()); + } catch (JwtBundleException e) { + assertEquals("Unsupported JWT family algorithm: OKP", e.getCause().getMessage()); } } diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/AlgorithmTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/AlgorithmTest.java deleted file mode 100644 index 34b42fd..0000000 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/AlgorithmTest.java +++ /dev/null @@ -1,79 +0,0 @@ -package io.spiffe.svid.jwtsvid; - -import io.spiffe.Algorithm; -import lombok.Builder; -import lombok.Value; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class AlgorithmTest { - - @ParameterizedTest - @MethodSource("provideTestCases") - void parse(TestCase testCase) { - Algorithm signatureAlgorithm = Algorithm.parse(testCase.name); - assertEquals(testCase.expectedAlgorithm, signatureAlgorithm); - assertEquals(testCase.name, signatureAlgorithm.getName()); - } - - static Stream provideTestCases() { - return Stream.of( - Arguments.of(TestCase.builder().name("RS256").expectedAlgorithm(Algorithm.RS256).build()), - Arguments.of(TestCase.builder().name("RS384").expectedAlgorithm(Algorithm.RS384).build()), - Arguments.of(TestCase.builder().name("RS512").expectedAlgorithm(Algorithm.RS512).build()), - Arguments.of(TestCase.builder().name("ES256").expectedAlgorithm(Algorithm.ES256).build()), - Arguments.of(TestCase.builder().name("ES384").expectedAlgorithm(Algorithm.ES384).build()), - Arguments.of(TestCase.builder().name("ES512").expectedAlgorithm(Algorithm.ES512).build()), - Arguments.of(TestCase.builder().name("PS256").expectedAlgorithm(Algorithm.PS256).build()), - Arguments.of(TestCase.builder().name("PS384").expectedAlgorithm(Algorithm.PS384).build()), - Arguments.of(TestCase.builder().name("PS512").expectedAlgorithm(Algorithm.PS512).build()), - Arguments.of(TestCase.builder().name("OTHER").expectedAlgorithm(Algorithm.OTHER).build()) - ); - } - - @Value - static class TestCase { - String name; - Algorithm expectedAlgorithm; - - @Builder - public TestCase(String name, Algorithm expectedAlgorithm) { - this.name = name; - this.expectedAlgorithm = expectedAlgorithm; - } - } - - @Test - void testParseFamilyRSA() { - Algorithm.Family rsa = Algorithm.Family.parse("RSA"); - assertEquals(Algorithm.Family.RSA, rsa); - assertTrue(rsa.contains(Algorithm.RS256)); - assertTrue(rsa.contains(Algorithm.RS384)); - assertTrue(rsa.contains(Algorithm.RS512)); - assertTrue(rsa.contains(Algorithm.PS256)); - assertTrue(rsa.contains(Algorithm.PS384)); - assertTrue(rsa.contains(Algorithm.PS512)); - } - - @Test - void testParseFamilyEC() { - Algorithm.Family ec = Algorithm.Family.parse("EC"); - assertEquals(Algorithm.Family.EC, ec); - assertTrue(ec.contains(Algorithm.ES256)); - assertTrue(ec.contains(Algorithm.ES384)); - assertTrue(ec.contains(Algorithm.ES512)); - } - - @Test - void testParseFamilyOTHER() { - Algorithm.Family other = Algorithm.Family.parse("unknown family"); - assertEquals(Algorithm.Family.OTHER, other); - } -} \ No newline at end of file diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSignatureAlgorithmTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSignatureAlgorithmTest.java new file mode 100644 index 0000000..bd6c087 --- /dev/null +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSignatureAlgorithmTest.java @@ -0,0 +1,93 @@ +package io.spiffe.svid.jwtsvid; + +import io.spiffe.internal.JwtSignatureAlgorithm; +import lombok.Builder; +import lombok.Value; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class JwtSignatureAlgorithmTest { + + @ParameterizedTest + @MethodSource("provideTestCases") + void parse(TestCase testCase) { + JwtSignatureAlgorithm signatureAlgorithm = JwtSignatureAlgorithm.parse(testCase.name); + assertEquals(testCase.expectedAlgorithm, signatureAlgorithm); + assertEquals(testCase.name, signatureAlgorithm.getName()); + } + + static Stream provideTestCases() { + return Stream.of( + Arguments.of(TestCase.builder().name("RS256").expectedAlgorithm(JwtSignatureAlgorithm.RS256).build()), + Arguments.of(TestCase.builder().name("RS384").expectedAlgorithm(JwtSignatureAlgorithm.RS384).build()), + Arguments.of(TestCase.builder().name("RS512").expectedAlgorithm(JwtSignatureAlgorithm.RS512).build()), + Arguments.of(TestCase.builder().name("ES256").expectedAlgorithm(JwtSignatureAlgorithm.ES256).build()), + Arguments.of(TestCase.builder().name("ES384").expectedAlgorithm(JwtSignatureAlgorithm.ES384).build()), + Arguments.of(TestCase.builder().name("ES512").expectedAlgorithm(JwtSignatureAlgorithm.ES512).build()), + Arguments.of(TestCase.builder().name("PS256").expectedAlgorithm(JwtSignatureAlgorithm.PS256).build()), + Arguments.of(TestCase.builder().name("PS384").expectedAlgorithm(JwtSignatureAlgorithm.PS384).build()), + Arguments.of(TestCase.builder().name("PS512").expectedAlgorithm(JwtSignatureAlgorithm.PS512).build()) + ); + } + + @Value + static class TestCase { + String name; + JwtSignatureAlgorithm expectedAlgorithm; + + @Builder + public TestCase(String name, JwtSignatureAlgorithm expectedAlgorithm) { + this.name = name; + this.expectedAlgorithm = expectedAlgorithm; + } + } + + @Test + void testParseFamilyRSA() { + JwtSignatureAlgorithm.Family rsa = JwtSignatureAlgorithm.Family.parse("RSA"); + assertEquals(JwtSignatureAlgorithm.Family.RSA, rsa); + assertTrue(rsa.contains(JwtSignatureAlgorithm.RS256)); + assertTrue(rsa.contains(JwtSignatureAlgorithm.RS384)); + assertTrue(rsa.contains(JwtSignatureAlgorithm.RS512)); + assertTrue(rsa.contains(JwtSignatureAlgorithm.PS256)); + assertTrue(rsa.contains(JwtSignatureAlgorithm.PS384)); + assertTrue(rsa.contains(JwtSignatureAlgorithm.PS512)); + } + + @Test + void testParseFamilyEC() { + JwtSignatureAlgorithm.Family ec = JwtSignatureAlgorithm.Family.parse("EC"); + assertEquals(JwtSignatureAlgorithm.Family.EC, ec); + assertTrue(ec.contains(JwtSignatureAlgorithm.ES256)); + assertTrue(ec.contains(JwtSignatureAlgorithm.ES384)); + assertTrue(ec.contains(JwtSignatureAlgorithm.ES512)); + } + + @Test + void testParseUnknownFamily() { + try { + JwtSignatureAlgorithm.Family.parse("unknown family"); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Unsupported JWT family algorithm: unknown family", e.getMessage()); + } + } + + @Test + void testParseUnsupportedAlgorithm() { + try { + JwtSignatureAlgorithm.parse("HS256"); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Unsupported JWT algorithm: HS256", e.getMessage()); + } + } +} \ No newline at end of file diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java index 8a74075..44429c4 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java @@ -154,7 +154,7 @@ class JwtSvidParseAndValidateTest { .jwtBundle(jwtBundle) .expectedAudience(Collections.singleton("audience")) .generateToken(() -> HS256TOKEN) - .expectedException(new JwtSvidException("Unsupported token signature algorithm HS256")) + .expectedException(new JwtSvidException("Unsupported JWT algorithm: HS256")) .build()), Arguments.of(TestCase.builder() .name("5. missing subject") diff --git a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java index 4913098..5b6539f 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseInsecureTest.java @@ -28,6 +28,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; class JwtSvidParseInsecureTest { + private static final String HS256TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImF1dGhvcml0eTEifQ." + + "eyJzdWIiOiJzcGlmZmU6Ly90ZXN0LmRvbWFpbi9ob3N0IiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxMjM0MzQzNTM0NTUsImlh" + + "dCI6MTUxNjIzOTAyMiwiYXVkIjoiYXVkaWVuY2UifQ.wNm5pQGSLCw5N9ddgSF2hkgmQpGnG9le_gpiFmyBhao"; + @ParameterizedTest @MethodSource("provideJwtScenarios") void parseJwt(TestCase testCase) { @@ -147,6 +151,12 @@ class JwtSvidParseInsecureTest { .expectedAudience(audience) .generateToken(() -> TestUtils.generateToken(TestUtils.buildJWTClaimSet(audience, "non-spiffe-subject", expiration), key1, "authority1")) .expectedException(new JwtSvidException("Subject non-spiffe-subject cannot be parsed as a SPIFFE ID")) + .build()), + Arguments.of(TestCase.builder() + .name("unsupported algorithm") + .expectedAudience(Collections.singleton("audience")) + .generateToken(() -> HS256TOKEN) + .expectedException(new JwtSvidException("Unsupported JWT algorithm: HS256")) .build()) ); }