Merge pull request #60 from maxlambrecht/validate-jwt-alg

Validate 'alg' header when parsing JWT-SVIDs from tokens
This commit is contained in:
Ryan Turner 2021-02-05 10:11:02 -08:00 committed by GitHub
commit e33417b10b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 151 additions and 116 deletions

View File

@ -5,7 +5,7 @@ import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import io.spiffe.Algorithm;
import io.spiffe.internal.JwtSignatureAlgorithm;
import io.spiffe.bundle.BundleSource;
import io.spiffe.exception.AuthorityNotFoundException;
import io.spiffe.exception.BundleNotFoundException;
@ -72,7 +72,7 @@ public class JwtBundle implements BundleSource<JwtBundle> {
try {
val jwkSet = JWKSet.load(bundlePath.toFile());
return toJwtBundle(trustDomain, jwkSet);
} catch (IOException | ParseException | JOSEException e) {
} catch (IllegalArgumentException | IOException | ParseException | JOSEException e) {
val error = "Could not load bundle from file: %s";
throw new JwtBundleException(String.format(error, bundlePath.toString()), e);
}
@ -189,7 +189,7 @@ public class JwtBundle implements BundleSource<JwtBundle> {
}
private static PublicKey getPublicKey(final JWK jwk) throws JOSEException, ParseException, KeyException {
val family = Algorithm.Family.parse(jwk.getKeyType().getValue());
val family = JwtSignatureAlgorithm.Family.parse(jwk.getKeyType().getValue());
final PublicKey publicKey;
switch (family) {

View File

@ -1,13 +1,15 @@
package io.spiffe;
package io.spiffe.internal;
import lombok.NonNull;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/**
* Represents JWT Algorithms.
* Represents JWT Signature Supported Algorithms.
*/
public enum Algorithm {
public enum JwtSignatureAlgorithm {
/**
* ECDSA algorithm using SHA-256 hash algorithm.
@ -52,16 +54,11 @@ public enum Algorithm {
/**
* RSASSA-PSS using SHA-512 and MGF1 padding with SHA-512.
*/
PS512("PS512"),
/**
* Non-Supported algorithm.
*/
OTHER("OTHER");
PS512("PS512");
private final String name;
Algorithm(final String name) {
JwtSignatureAlgorithm(final String name) {
this.name = name;
}
@ -74,13 +71,12 @@ public enum Algorithm {
*/
public enum Family {
RSA("RSA", RS256, RS384, RS512, PS256, PS384, PS512),
EC("EC", ES256, ES384, ES512),
OTHER("UNKNOWN");
EC("EC", ES256, ES384, ES512);
private final String name;
private final Set<Algorithm> algorithms;
private final Set<JwtSignatureAlgorithm> algorithms;
Family(final String name, final Algorithm... algs) {
Family(final String name, final JwtSignatureAlgorithm... algs) {
this.name = name;
algorithms = new HashSet<>();
Collections.addAll(algorithms, algs);
@ -90,7 +86,7 @@ public enum Algorithm {
return name;
}
public boolean contains(final Algorithm a) {
public boolean contains(final JwtSignatureAlgorithm a) {
return algorithms.contains(a);
}
@ -101,14 +97,14 @@ public enum Algorithm {
} else if (s.equals(EC.getName())) {
family = EC;
} else {
family = OTHER;
throw new IllegalArgumentException("Unsupported JWT family algorithm: " + s);
}
return family;
}
}
public static Algorithm parse(final String s) {
final Algorithm algorithm;
public static JwtSignatureAlgorithm parse(@NonNull final String s) {
final JwtSignatureAlgorithm algorithm;
if (s.equals(RS256.getName())) {
algorithm = RS256;
} else if (s.equals(RS384.getName())) {
@ -128,7 +124,7 @@ public enum Algorithm {
} else if (s.equals(PS512.getName())) {
algorithm = PS512;
} else {
algorithm = OTHER;
throw new IllegalArgumentException("Unsupported JWT algorithm: " + s);
}
return algorithm;
}

View File

@ -1,13 +1,14 @@
package io.spiffe.svid.jwtsvid;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import io.spiffe.Algorithm;
import io.spiffe.internal.JwtSignatureAlgorithm;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.AuthorityNotFoundException;
@ -85,7 +86,8 @@ public class JwtSvid {
* @return an instance of a {@link JwtSvid} with a SPIFFE ID parsed from the 'sub', audience from 'aud', and expiry
* from 'exp' claim.
* @throws JwtSvidException when the token expired or the expiration claim is missing,
* when the algorithm is not supported, when the header 'kid' is missing,
* when the algorithm is not supported (See {@link JwtSignatureAlgorithm}),
* when the header 'kid' is missing,
* when the signature cannot be verified, or
* when the 'aud' claim has an audience that is not in the audience list
* provided as parameter
@ -106,8 +108,9 @@ public class JwtSvid {
}
val signedJwt = getSignedJWT(token);
val claimsSet = getJwtClaimsSet(signedJwt);
JwtSignatureAlgorithm algorithm = parseAlgorithm(signedJwt.getHeader().getAlgorithm());
val claimsSet = getJwtClaimsSet(signedJwt);
validateAudience(claimsSet.getAudience(), audience);
val expirationTime = claimsSet.getExpirationTime();
@ -119,7 +122,6 @@ public class JwtSvid {
val keyId = getKeyId(signedJwt.getHeader());
val jwtAuthority = jwtBundle.findJwtAuthority(keyId);
val algorithm = signedJwt.getHeader().getAlgorithm().getName();
verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
val claimAudience = new HashSet<>(claimsSet.getAudience());
@ -136,7 +138,8 @@ public class JwtSvid {
* @return an instance of a {@link JwtSvid} with a SPIFFE ID parsed from the 'sub', audience from 'aud', and expiry
* from 'exp' claim.
* @throws JwtSvidException when the token expired or the expiration claim is missing, or when
* the 'aud' has an audience that is not in the audience provided as parameter
* the 'aud' has an audience that is not in the audience provided as parameter,
* or when the 'alg' is not supported (See {@link JwtSignatureAlgorithm}).
* @throws IllegalArgumentException when the token cannot be parsed
*/
public static JwtSvid parseInsecure(@NonNull final String token, @NonNull final Set<String> audience) throws JwtSvidException {
@ -145,8 +148,9 @@ public class JwtSvid {
}
val signedJwt = getSignedJWT(token);
val claimsSet = getJwtClaimsSet(signedJwt);
parseAlgorithm(signedJwt.getHeader().getAlgorithm());
val claimsSet = getJwtClaimsSet(signedJwt);
validateAudience(claimsSet.getAudience(), audience);
val expirationTime = claimsSet.getExpirationTime();
@ -216,7 +220,7 @@ public class JwtSvid {
return signedJwt;
}
private static void verifySignature(final SignedJWT signedJwt, final PublicKey jwtAuthority, final String algorithm, final String keyId) throws JwtSvidException {
private static void verifySignature(final SignedJWT signedJwt, final PublicKey jwtAuthority, final JwtSignatureAlgorithm algorithm, final String keyId) throws JwtSvidException {
boolean verify;
try {
val verifier = getJwsVerifier(jwtAuthority, algorithm);
@ -230,12 +234,11 @@ public class JwtSvid {
}
}
private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final String algorithm) throws JOSEException, JwtSvidException {
private static JWSVerifier getJwsVerifier(final PublicKey jwtAuthority, final JwtSignatureAlgorithm algorithm) throws JOSEException, JwtSvidException {
JWSVerifier verifier;
val alg = Algorithm.parse(algorithm);
if (Algorithm.Family.EC.contains(alg)) {
if (JwtSignatureAlgorithm.Family.EC.contains(algorithm)) {
verifier = new ECDSAVerifier((ECPublicKey) jwtAuthority);
} else if (Algorithm.Family.RSA.contains(alg)) {
} else if (JwtSignatureAlgorithm.Family.RSA.contains(algorithm)) {
verifier = new RSASSAVerifier((RSAPublicKey) jwtAuthority);
} else {
throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", algorithm));
@ -284,4 +287,16 @@ public class JwtSvid {
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", expectedAudiences, audClaim));
}
}
private static JwtSignatureAlgorithm parseAlgorithm(JWSAlgorithm algorithm) throws JwtSvidException {
if (algorithm == null) {
throw new JwtSvidException("jwt header 'alg' is required");
}
try {
return JwtSignatureAlgorithm.parse(algorithm.getName());
} catch (IllegalArgumentException e) {
throw new JwtSvidException(e.getMessage(), e);
}
}
}

View File

@ -157,15 +157,15 @@ class JwtBundleTest {
}
@Test
void testLoadFile_InvalidKeyType_ThrowsKeyException() throws URISyntaxException, JwtBundleException {
void testLoadFile_InvalidKeyType_ThrowsKeyException() throws URISyntaxException, KeyException {
Path path = Paths.get(toUri("testdata/jwtbundle/jwks_invalid_keytype.json"));
TrustDomain trustDomain = TrustDomain.of("domain.test");
try {
JwtBundle.load(trustDomain, path);
fail("should have thrown exception");
} catch (KeyException e) {
assertEquals("Key Type not supported: OKP", e.getMessage());
} catch (JwtBundleException e) {
assertEquals("Unsupported JWT family algorithm: OKP", e.getCause().getMessage());
}
}

View File

@ -1,79 +0,0 @@
package io.spiffe.svid.jwtsvid;
import io.spiffe.Algorithm;
import lombok.Builder;
import lombok.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class AlgorithmTest {
@ParameterizedTest
@MethodSource("provideTestCases")
void parse(TestCase testCase) {
Algorithm signatureAlgorithm = Algorithm.parse(testCase.name);
assertEquals(testCase.expectedAlgorithm, signatureAlgorithm);
assertEquals(testCase.name, signatureAlgorithm.getName());
}
static Stream<Arguments> provideTestCases() {
return Stream.of(
Arguments.of(TestCase.builder().name("RS256").expectedAlgorithm(Algorithm.RS256).build()),
Arguments.of(TestCase.builder().name("RS384").expectedAlgorithm(Algorithm.RS384).build()),
Arguments.of(TestCase.builder().name("RS512").expectedAlgorithm(Algorithm.RS512).build()),
Arguments.of(TestCase.builder().name("ES256").expectedAlgorithm(Algorithm.ES256).build()),
Arguments.of(TestCase.builder().name("ES384").expectedAlgorithm(Algorithm.ES384).build()),
Arguments.of(TestCase.builder().name("ES512").expectedAlgorithm(Algorithm.ES512).build()),
Arguments.of(TestCase.builder().name("PS256").expectedAlgorithm(Algorithm.PS256).build()),
Arguments.of(TestCase.builder().name("PS384").expectedAlgorithm(Algorithm.PS384).build()),
Arguments.of(TestCase.builder().name("PS512").expectedAlgorithm(Algorithm.PS512).build()),
Arguments.of(TestCase.builder().name("OTHER").expectedAlgorithm(Algorithm.OTHER).build())
);
}
@Value
static class TestCase {
String name;
Algorithm expectedAlgorithm;
@Builder
public TestCase(String name, Algorithm expectedAlgorithm) {
this.name = name;
this.expectedAlgorithm = expectedAlgorithm;
}
}
@Test
void testParseFamilyRSA() {
Algorithm.Family rsa = Algorithm.Family.parse("RSA");
assertEquals(Algorithm.Family.RSA, rsa);
assertTrue(rsa.contains(Algorithm.RS256));
assertTrue(rsa.contains(Algorithm.RS384));
assertTrue(rsa.contains(Algorithm.RS512));
assertTrue(rsa.contains(Algorithm.PS256));
assertTrue(rsa.contains(Algorithm.PS384));
assertTrue(rsa.contains(Algorithm.PS512));
}
@Test
void testParseFamilyEC() {
Algorithm.Family ec = Algorithm.Family.parse("EC");
assertEquals(Algorithm.Family.EC, ec);
assertTrue(ec.contains(Algorithm.ES256));
assertTrue(ec.contains(Algorithm.ES384));
assertTrue(ec.contains(Algorithm.ES512));
}
@Test
void testParseFamilyOTHER() {
Algorithm.Family other = Algorithm.Family.parse("unknown family");
assertEquals(Algorithm.Family.OTHER, other);
}
}

View File

@ -0,0 +1,93 @@
package io.spiffe.svid.jwtsvid;
import io.spiffe.internal.JwtSignatureAlgorithm;
import lombok.Builder;
import lombok.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
class JwtSignatureAlgorithmTest {
@ParameterizedTest
@MethodSource("provideTestCases")
void parse(TestCase testCase) {
JwtSignatureAlgorithm signatureAlgorithm = JwtSignatureAlgorithm.parse(testCase.name);
assertEquals(testCase.expectedAlgorithm, signatureAlgorithm);
assertEquals(testCase.name, signatureAlgorithm.getName());
}
static Stream<Arguments> provideTestCases() {
return Stream.of(
Arguments.of(TestCase.builder().name("RS256").expectedAlgorithm(JwtSignatureAlgorithm.RS256).build()),
Arguments.of(TestCase.builder().name("RS384").expectedAlgorithm(JwtSignatureAlgorithm.RS384).build()),
Arguments.of(TestCase.builder().name("RS512").expectedAlgorithm(JwtSignatureAlgorithm.RS512).build()),
Arguments.of(TestCase.builder().name("ES256").expectedAlgorithm(JwtSignatureAlgorithm.ES256).build()),
Arguments.of(TestCase.builder().name("ES384").expectedAlgorithm(JwtSignatureAlgorithm.ES384).build()),
Arguments.of(TestCase.builder().name("ES512").expectedAlgorithm(JwtSignatureAlgorithm.ES512).build()),
Arguments.of(TestCase.builder().name("PS256").expectedAlgorithm(JwtSignatureAlgorithm.PS256).build()),
Arguments.of(TestCase.builder().name("PS384").expectedAlgorithm(JwtSignatureAlgorithm.PS384).build()),
Arguments.of(TestCase.builder().name("PS512").expectedAlgorithm(JwtSignatureAlgorithm.PS512).build())
);
}
@Value
static class TestCase {
String name;
JwtSignatureAlgorithm expectedAlgorithm;
@Builder
public TestCase(String name, JwtSignatureAlgorithm expectedAlgorithm) {
this.name = name;
this.expectedAlgorithm = expectedAlgorithm;
}
}
@Test
void testParseFamilyRSA() {
JwtSignatureAlgorithm.Family rsa = JwtSignatureAlgorithm.Family.parse("RSA");
assertEquals(JwtSignatureAlgorithm.Family.RSA, rsa);
assertTrue(rsa.contains(JwtSignatureAlgorithm.RS256));
assertTrue(rsa.contains(JwtSignatureAlgorithm.RS384));
assertTrue(rsa.contains(JwtSignatureAlgorithm.RS512));
assertTrue(rsa.contains(JwtSignatureAlgorithm.PS256));
assertTrue(rsa.contains(JwtSignatureAlgorithm.PS384));
assertTrue(rsa.contains(JwtSignatureAlgorithm.PS512));
}
@Test
void testParseFamilyEC() {
JwtSignatureAlgorithm.Family ec = JwtSignatureAlgorithm.Family.parse("EC");
assertEquals(JwtSignatureAlgorithm.Family.EC, ec);
assertTrue(ec.contains(JwtSignatureAlgorithm.ES256));
assertTrue(ec.contains(JwtSignatureAlgorithm.ES384));
assertTrue(ec.contains(JwtSignatureAlgorithm.ES512));
}
@Test
void testParseUnknownFamily() {
try {
JwtSignatureAlgorithm.Family.parse("unknown family");
fail();
} catch (IllegalArgumentException e) {
assertEquals("Unsupported JWT family algorithm: unknown family", e.getMessage());
}
}
@Test
void testParseUnsupportedAlgorithm() {
try {
JwtSignatureAlgorithm.parse("HS256");
fail();
} catch (IllegalArgumentException e) {
assertEquals("Unsupported JWT algorithm: HS256", e.getMessage());
}
}
}

View File

@ -154,7 +154,7 @@ class JwtSvidParseAndValidateTest {
.jwtBundle(jwtBundle)
.expectedAudience(Collections.singleton("audience"))
.generateToken(() -> HS256TOKEN)
.expectedException(new JwtSvidException("Unsupported token signature algorithm HS256"))
.expectedException(new JwtSvidException("Unsupported JWT algorithm: HS256"))
.build()),
Arguments.of(TestCase.builder()
.name("5. missing subject")

View File

@ -28,6 +28,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
class JwtSvidParseInsecureTest {
private static final String HS256TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImF1dGhvcml0eTEifQ." +
"eyJzdWIiOiJzcGlmZmU6Ly90ZXN0LmRvbWFpbi9ob3N0IiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxMjM0MzQzNTM0NTUsImlh" +
"dCI6MTUxNjIzOTAyMiwiYXVkIjoiYXVkaWVuY2UifQ.wNm5pQGSLCw5N9ddgSF2hkgmQpGnG9le_gpiFmyBhao";
@ParameterizedTest
@MethodSource("provideJwtScenarios")
void parseJwt(TestCase testCase) {
@ -147,6 +151,12 @@ class JwtSvidParseInsecureTest {
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(TestUtils.buildJWTClaimSet(audience, "non-spiffe-subject", expiration), key1, "authority1"))
.expectedException(new JwtSvidException("Subject non-spiffe-subject cannot be parsed as a SPIFFE ID"))
.build()),
Arguments.of(TestCase.builder()
.name("unsupported algorithm")
.expectedAudience(Collections.singleton("audience"))
.generateToken(() -> HS256TOKEN)
.expectedException(new JwtSvidException("Unsupported JWT algorithm: HS256"))
.build())
);
}