Addressing PR comments:

- rename enum to improve clarity
- add missing validations to socket address parsing
- add test scenarios for address parsing
- improve Address javadoc to use the language of the SPIFFE spec
- some minor refactors

Signed-off-by: Max Lambrecht <maxlambrecht@gmail.com>
This commit is contained in:
Max Lambrecht 2020-06-24 11:43:53 -03:00
parent 96d660ad3a
commit 0005bd5a1c
7 changed files with 96 additions and 91 deletions

View File

@ -1,13 +1,13 @@
package io.spiffe.internal;
public enum PrivateKeyAlgorithm {
public enum AsymmetricKeyAlgorithm {
RSA("RSA"),
EC("EC");
private final String value;
PrivateKeyAlgorithm(final String value) {
AsymmetricKeyAlgorithm(final String value) {
this.value = value;
}
@ -15,7 +15,7 @@ public enum PrivateKeyAlgorithm {
return value;
}
public static PrivateKeyAlgorithm parse(String a) {
public static AsymmetricKeyAlgorithm parse(String a) {
if ("RSA".equalsIgnoreCase(a)) {
return RSA;
} else if ("EC".equalsIgnoreCase(a)) {

View File

@ -36,8 +36,8 @@ import java.util.stream.Collectors;
import static io.spiffe.internal.KeyUsage.CRL_SIGN;
import static io.spiffe.internal.KeyUsage.DIGITAL_SIGNATURE;
import static io.spiffe.internal.KeyUsage.KEY_CERT_SIGN;
import static io.spiffe.internal.PrivateKeyAlgorithm.EC;
import static io.spiffe.internal.PrivateKeyAlgorithm.RSA;
import static io.spiffe.internal.AsymmetricKeyAlgorithm.EC;
import static io.spiffe.internal.AsymmetricKeyAlgorithm.RSA;
import static org.apache.commons.lang3.StringUtils.startsWith;
/**
@ -90,7 +90,7 @@ public class CertificateUtils {
* @throws InvalidKeySpecException
* @throws NoSuchAlgorithmException
*/
public static PrivateKey generatePrivateKey(final byte[] privateKeyBytes, PrivateKeyAlgorithm algorithm, KeyFileFormat keyFileFormat) throws InvalidKeySpecException, NoSuchAlgorithmException, InvalidKeyException {
public static PrivateKey generatePrivateKey(final byte[] privateKeyBytes, AsymmetricKeyAlgorithm algorithm, KeyFileFormat keyFileFormat) throws InvalidKeySpecException, NoSuchAlgorithmException, InvalidKeyException {
EncodedKeySpec kspec = getEncodedKeySpec(privateKeyBytes, keyFileFormat);
return generatePrivateKeyWithSpec(kspec, algorithm);
}
@ -159,7 +159,7 @@ public class CertificateUtils {
* @throws InvalidKeyException if the keys don't match
*/
public static void validatePrivateKey(final PrivateKey privateKey, final X509Certificate x509Certificate) throws InvalidKeyException {
PrivateKeyAlgorithm algorithm = PrivateKeyAlgorithm.parse(privateKey.getAlgorithm());
AsymmetricKeyAlgorithm algorithm = AsymmetricKeyAlgorithm.parse(privateKey.getAlgorithm());
switch (algorithm) {
case RSA:
@ -231,7 +231,7 @@ public class CertificateUtils {
.collect(Collectors.toList());
}
private static PrivateKey generatePrivateKeyWithSpec(final EncodedKeySpec keySpec, PrivateKeyAlgorithm algorithm) throws NoSuchAlgorithmException, InvalidKeySpecException {
private static PrivateKey generatePrivateKeyWithSpec(final EncodedKeySpec keySpec, AsymmetricKeyAlgorithm algorithm) throws NoSuchAlgorithmException, InvalidKeySpecException {
PrivateKey privateKey;
switch (algorithm) {
case EC:

View File

@ -26,6 +26,7 @@ import java.text.ParseException;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -111,8 +112,7 @@ public class JwtSvid {
val signedJwt = getSignedJWT(token);
val claimsSet = getJwtClaimsSet(signedJwt);
val claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience);
validateAudience(claimsSet.getAudience(), audience);
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
@ -126,6 +126,7 @@ public class JwtSvid {
val algorithm = signedJwt.getHeader().getAlgorithm().getName();
verifySignature(signedJwt, jwtAuthority, algorithm, keyId);
val claimAudience = new HashSet<>(claimsSet.getAudience());
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
}
@ -150,14 +151,14 @@ public class JwtSvid {
val signedJwt = getSignedJWT(token);
val claimsSet = getJwtClaimsSet(signedJwt);
val claimAudience = new HashSet<>(claimsSet.getAudience());
validateAudience(claimAudience, audience);
validateAudience(claimsSet.getAudience(), audience);
val expirationTime = claimsSet.getExpirationTime();
validateExpiration(expirationTime);
val spiffeId = getSpiffeIdOfSubject(claimsSet);
val claimAudience = new HashSet<>(claimsSet.getAudience());
return new JwtSvid(spiffeId, claimAudience, expirationTime, claimsSet.getClaims(), token);
}
@ -275,7 +276,7 @@ public class JwtSvid {
}
private static void validateAudience(final Set<String> audClaim, final Set<String> expectedAudience) throws JwtSvidException {
private static void validateAudience(final List<String> audClaim, final Set<String> expectedAudience) throws JwtSvidException {
for (String aud : audClaim) {
if (!expectedAudience.contains(aud)) {
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", expectedAudience, audClaim));

View File

@ -3,7 +3,7 @@ package io.spiffe.svid.x509svid;
import io.spiffe.exception.X509SvidException;
import io.spiffe.internal.CertificateUtils;
import io.spiffe.internal.KeyFileFormat;
import io.spiffe.internal.PrivateKeyAlgorithm;
import io.spiffe.internal.AsymmetricKeyAlgorithm;
import io.spiffe.spiffeid.SpiffeId;
import lombok.NonNull;
import lombok.Value;
@ -154,7 +154,7 @@ public class X509Svid {
}
private static PrivateKey generatePrivateKey(final byte[] privateKeyBytes, final KeyFileFormat keyFileFormat, final List<X509Certificate> x509Certificates) throws X509SvidException {
PrivateKeyAlgorithm algorithm = PrivateKeyAlgorithm.parse(x509Certificates.get(0).getPublicKey().getAlgorithm());
AsymmetricKeyAlgorithm algorithm = AsymmetricKeyAlgorithm.parse(x509Certificates.get(0).getPublicKey().getAlgorithm());
PrivateKey privateKey;
try {
privateKey = CertificateUtils.generatePrivateKey(privateKeyBytes, algorithm, keyFileFormat);

View File

@ -1,15 +1,14 @@
package io.spiffe.workloadapi;
import com.google.common.collect.Sets;
import io.spiffe.exception.SocketEndpointAddressException;
import lombok.val;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.validator.routines.InetAddressValidator;
import java.net.InetAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
/**
* Utility class to get the default Workload API address and parse string addresses.
@ -23,7 +22,7 @@ public class Address {
private static final String UNIX_SCHEME = "unix";
private static final String TCP_SCHEME = "tcp";
private static final List<String> VALID_SCHEMES = Arrays.asList(UNIX_SCHEME, TCP_SCHEME);
private static final Set<String> VALID_SCHEMES = Sets.newHashSet(UNIX_SCHEME, TCP_SCHEME);
/**
* @return the default Workload API address hold by the system environment variable
@ -36,116 +35,115 @@ public class Address {
/**
* Parses and validates a Workload API socket address.
* <p>
* The given address should either have a tcp, or a unix scheme.
* The value of the address is structured as an RFC 3986 URI. The scheme MUST be set to either unix or tcp,
* which indicates that the endpoint is served over a Unix Domain Socket or a TCP listen socket, respectively.
* <p>
* The given address should contain a path.
* If the scheme is set to unix, then the authority component MUST NOT be set, and the path component MUST be set
* to the absolute path of the SPIFFE Workload Endpoint Unix Domain Socket (e.g. unix:///path/to/endpoint.sock).
* The scheme and path components are mandatory, and no other component may be set.
* <p>
* The given address cannot be opaque, cannot have fragments, query values or user info.
* <p>
* If the given address is tcp, it should contain an IP and a port.
* If the scheme is set to tcp, then the host component of the authority MUST be set to an IP address,
* and the port component of the authority MUST be set to the TCP port number of the SPIFFE Workload Endpoint TCP listen socket.
* The scheme, host, and port components are mandatory, and no other component may be set.
* As an example, tcp://127.0.0.1:8000 is valid, and tcp://127.0.0.1:8000/foo is not.
*
* @param address the Workload API socket address as a string
* @return an instance of a {@link URI}
* @throws SocketEndpointAddressException if the address could not be parsed or if it is not valid
* @throws SocketEndpointAddressException if the address could not be parsed or if it doesn't complain to the rules
* defined in the SPIFFE Standard.
* @see <a href="https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE_Workload_Endpoint.md#4-locating-the-endpoint">SPIFFE Workload Endpoint</a>
*/
public static URI parseAddress(final String address) throws SocketEndpointAddressException {
URI parsedAddress;
try {
parsedAddress = new URI(address);
} catch (URISyntaxException e) {
throw new SocketEndpointAddressException(String.format("Workload endpoint socket is not a valid URI: %s", address), e);
}
val parsedAddress = parseUri(address);
val scheme = parsedAddress.getScheme();
if (!isValid(scheme)) {
if (isSchemeNotValid(scheme)) {
throw new SocketEndpointAddressException(String.format("Workload endpoint socket URI must have a tcp:// or unix:// scheme: %s", address));
}
String error = null;
if (UNIX_SCHEME.equals(scheme)) {
error = validateUnixAddress(parsedAddress);
}
if (TCP_SCHEME.equals(scheme)) {
error = validateTcpAddress(parsedAddress);
}
if (StringUtils.isNotBlank(error)) {
throw new SocketEndpointAddressException(String.format(error, address));
validateUnixAddress(parsedAddress);
} else {
validateTcpAddress(parsedAddress);
}
return parsedAddress;
}
private static String validateUnixAddress(final URI parsedAddress) {
if (parsedAddress.isOpaque() && parsedAddress.isAbsolute()) {
return "Workload endpoint unix socket URI must not be opaque: %s";
private static URI parseUri(final String address) throws SocketEndpointAddressException {
final URI parsedAddress;
try {
parsedAddress = new URI(address);
} catch (URISyntaxException e) {
throw new SocketEndpointAddressException(String.format("Workload endpoint socket is not a valid URI: %s", address), e);
}
return parsedAddress;
}
private static void validateUnixAddress(final URI parsedAddress) throws SocketEndpointAddressException {
if (parsedAddress.isOpaque()) {
throw new SocketEndpointAddressException(String.format("Workload endpoint unix socket URI must not be opaque: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getUserInfo())) {
return "Workload endpoint unix socket URI must not include user info: %s";
if (StringUtils.isNotBlank(parsedAddress.getRawAuthority())) {
throw new SocketEndpointAddressException(String.format("Workload endpoint unix socket URI must not include authority component: %s", parsedAddress));
}
if (hasEmptyPath(parsedAddress.getPath())) {
throw new SocketEndpointAddressException(String.format("Workload endpoint unix socket path cannot be blank: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getRawQuery())) {
return "Workload endpoint unix socket URI must not include query values: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint unix socket URI must not include query values: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getFragment())) {
return "Workload endpoint unix socket URI must not include a fragment: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint unix socket URI must not include a fragment: %s", parsedAddress));
}
return "";
}
private static String validateTcpAddress(final URI parsedAddress) {
if (parsedAddress.isOpaque() && parsedAddress.isAbsolute()) {
return "Workload endpoint tcp socket URI must not be opaque: %s";
private static void validateTcpAddress(final URI parsedAddress) throws SocketEndpointAddressException {
if (parsedAddress.isOpaque()) {
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must not be opaque: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getUserInfo())) {
return "Workload endpoint tcp socket URI must not include user info: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must not include user info: %s", parsedAddress));
}
if (StringUtils.isBlank(parsedAddress.getHost())) {
return "Workload endpoint tcp socket URI must include a host: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must include a host: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getPath())) {
return "Workload endpoint tcp socket URI must not include a path: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must not include a path: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getRawQuery())) {
return "Workload endpoint tcp socket URI must not include query values: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must not include query values: %s", parsedAddress));
}
if (StringUtils.isNotBlank(parsedAddress.getFragment())) {
return "Workload endpoint tcp socket URI must not include a fragment: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI must not include a fragment: %s", parsedAddress));
}
String ip = parseIp(parsedAddress.getHost());
if (StringUtils.isBlank(ip)) {
return "Workload endpoint tcp socket URI host component must be an IP:port: %s";
val ipValid = InetAddressValidator.getInstance().isValid(parsedAddress.getHost());
if (!ipValid) {
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI host component must be an IP:port: %s", parsedAddress));
}
int port = parsedAddress.getPort();
if (port == -1) {
return "Workload endpoint tcp socket URI host component must include a port: %s";
throw new SocketEndpointAddressException(String.format("Workload endpoint tcp socket URI host component must include a port: %s", parsedAddress));
}
return "";
}
private static boolean isValid(final String scheme) {
return StringUtils.isNotBlank(scheme) && VALID_SCHEMES.contains(scheme);
private static boolean hasEmptyPath(final String path) {
return StringUtils.isBlank(path) || path.equals("/");
}
private static String parseIp(final String host) {
try {
InetAddress ip = InetAddress.getByName(host);
return ip.getHostAddress();
} catch (UnknownHostException e) {
return null;
}
private static boolean isSchemeNotValid(final String scheme) {
return !VALID_SCHEMES.contains(scheme);
}
private Address() {

View File

@ -23,7 +23,7 @@ import java.security.spec.InvalidKeySpecException;
import java.util.Arrays;
import java.util.List;
import static io.spiffe.internal.PrivateKeyAlgorithm.RSA;
import static io.spiffe.internal.AsymmetricKeyAlgorithm.RSA;
import static io.spiffe.utils.X509CertificateTestUtils.createCertificate;
import static io.spiffe.utils.X509CertificateTestUtils.createRootCA;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -88,7 +88,7 @@ public class CertificateUtilsTest {
byte[] keyBytes = ecKeyPair.getPrivate().getEncoded();
try {
PrivateKey privateKey = CertificateUtils.generatePrivateKey(keyBytes, PrivateKeyAlgorithm.EC, KeyFileFormat.DER);
PrivateKey privateKey = CertificateUtils.generatePrivateKey(keyBytes, AsymmetricKeyAlgorithm.EC, KeyFileFormat.DER);
assertNotNull(privateKey);
assertEquals("EC", privateKey.getAlgorithm());
} catch (InvalidKeySpecException | InvalidKeyException | NoSuchAlgorithmException e) {

View File

@ -26,27 +26,33 @@ public class AddressTest {
static Stream<Arguments> provideTestAddress() {
return Stream.of(
Arguments.of("unix://foo", URI.create("unix://foo")),
Arguments.of("\\t", "Workload endpoint socket is not a valid URI: \\t"),
Arguments.of("blah", "Workload endpoint socket URI must have a tcp:// or unix:// scheme: blah"),
Arguments.of("unix:opaque", "Workload endpoint unix socket URI must not be opaque: unix:opaque"),
Arguments.of("unix://", "Workload endpoint socket is not a valid URI: unix://"),
Arguments.of("unix://foo?whatever", "Workload endpoint unix socket URI must not include query values: unix://foo?whatever"),
Arguments.of("unix://foo#whatever", "Workload endpoint unix socket URI must not include a fragment: unix://foo#whatever"),
Arguments.of("unix://john:doe@foo/path", "Workload endpoint unix socket URI must not include user info: unix://john:doe@foo/path"),
Arguments.of("unix:///foo", URI.create("unix:///foo")),
Arguments.of("unix:/path/to/endpoint.sock", URI.create("unix:/path/to/endpoint.sock")),
Arguments.of("unix:///path/to/endpoint.sock", URI.create("unix:///path/to/endpoint.sock")),
Arguments.of("tcp://127.0.0.1:8000", URI.create("tcp://127.0.0.1:8000")),
Arguments.of("tcp://1.2.3.4:5", URI.create("tcp://1.2.3.4:5")),
Arguments.of("\\t", "Workload endpoint socket is not a valid URI: \\t"),
Arguments.of("///foo", "Workload endpoint socket URI must have a tcp:// or unix:// scheme: ///foo"),
Arguments.of("blah", "Workload endpoint socket URI must have a tcp:// or unix:// scheme: blah"),
Arguments.of("blah:///foo", "Workload endpoint socket URI must have a tcp:// or unix:// scheme: blah:///foo"),
Arguments.of("unix:opaque", "Workload endpoint unix socket URI must not be opaque: unix:opaque"),
Arguments.of("unix:/", "Workload endpoint unix socket path cannot be blank: unix:/"),
Arguments.of("unix://", "Workload endpoint socket is not a valid URI: unix://"),
Arguments.of("unix:///", "Workload endpoint unix socket path cannot be blank: unix:///"),
Arguments.of("unix://foo", "Workload endpoint unix socket URI must not include authority component: unix://foo"),
Arguments.of("unix:///foo?whatever", "Workload endpoint unix socket URI must not include query values: unix:///foo?whatever"),
Arguments.of("unix:///foo#whatever", "Workload endpoint unix socket URI must not include a fragment: unix:///foo#whatever"),
Arguments.of("tcp://127.0.0.1:8000/foo", "Workload endpoint tcp socket URI must not include a path: tcp://127.0.0.1:8000/foo"),
Arguments.of("tcp:opaque", "Workload endpoint tcp socket URI must not be opaque: tcp:opaque"),
Arguments.of("tcp://", "Workload endpoint socket is not a valid URI: tcp://"),
Arguments.of("tcp:///test", "Workload endpoint tcp socket URI must include a host: tcp:///test"),
Arguments.of("tcp://1.2.3.4:5?whatever", "Workload endpoint tcp socket URI must not include query values: tcp://1.2.3.4:5?whatever"),
Arguments.of("tcp://1.2.3.4:5#whatever", "Workload endpoint tcp socket URI must not include a fragment: tcp://1.2.3.4:5#whatever"),
Arguments.of("tcp://john:doe@1.2.3.4:5/path", "Workload endpoint tcp socket URI must not include user info: tcp://john:doe@1.2.3.4:5/path"),
Arguments.of("tcp://1.2.3.4:5/path", "Workload endpoint tcp socket URI must not include a path: tcp://1.2.3.4:5/path"),
Arguments.of("tcp://foo", "Workload endpoint tcp socket URI host component must be an IP:port: tcp://foo"),
Arguments.of("tcp://1.2.3.4", "Workload endpoint tcp socket URI host component must include a port: tcp://1.2.3.4"),
Arguments.of("blah://foo", "Workload endpoint socket URI must have a tcp:// or unix:// scheme: blah://foo")
Arguments.of("tcp://foo:9000", "Workload endpoint tcp socket URI host component must be an IP:port: tcp://foo:9000"),
Arguments.of("tcp://1.2.3.4", "Workload endpoint tcp socket URI host component must include a port: tcp://1.2.3.4")
);
}
}