Add FetchJWTSVIDs function for workloadapi and jwtSource (#90)

Signed-off-by: Yuhan Li <liyuhan.loveyana@bytedance.com>
This commit is contained in:
M1a0 2022-04-28 05:21:24 +08:00 committed by GitHub
parent 5a9fa55fe6
commit 6cdc17eb9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 376 additions and 14 deletions

View File

@ -2,6 +2,9 @@ package io.spiffe.svid.jwtsvid;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import lombok.NonNull;
import java.util.List;
/**
* Represents a source of SPIFFE JWT-SVIDs.
@ -28,4 +31,25 @@ public interface JwtSvidSource {
* @throws JwtSvidException when there is an error fetching the JWT SVID
*/
JwtSvid fetchJwtSvid(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException;
/**
* Fetches all SPIFFE JWT-SVIDs on one-shot blocking call.
*
* @param audience the audience of the JWT-SVID
* @param extraAudience the extra audience for the JWT_SVID
* @return all of {@link JwtSvid} object
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/
List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException;
/**
* Fetches all SPIFFE JWT-SVIDs on one-shot blocking call.
*
* @param subject a SPIFFE ID
* @param audience the audience of the JWT-SVID
* @param extraAudience the extra audience for the JWT_SVID
* @return all of {@link JwtSvid} object
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/
List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException;
}

View File

@ -26,6 +26,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.List;
import static io.spiffe.workloadapi.internal.ThreadUtils.await;
@ -130,6 +131,30 @@ public class DefaultJwtSource implements JwtSource {
return workloadApiClient.fetchJwtSvid(subject, audience, extraAudiences);
}
@Override
public List<JwtSvid> fetchJwtSvids(String audience, String... extraAudiences) throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return workloadApiClient.fetchJwtSvids(audience, extraAudiences);
}
/**
* Fetches all new JWT SVIDs from the Workload API for the given subject SPIFFE ID and audiences.
*
* @return all {@link JwtSvid}s
* @throws IllegalStateException if the source is closed
*/
@Override
public List<JwtSvid> fetchJwtSvids(final SpiffeId subject, final String audience, final String... extraAudiences)
throws JwtSvidException {
if (isClosed()) {
throw new IllegalStateException("JWT SVID source is closed");
}
return workloadApiClient.fetchJwtSvids(subject, audience, extraAudiences);
}
/**
* Returns the JWT bundle for a given trust domain.
*

View File

@ -241,6 +241,39 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
}
}
/**
* {@inheritDoc}
*/
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
final Set<String> audParam = createAudienceSet(audience, extraAudience);
try (val cancellableContext = Context.current().withCancellation()) {
return cancellableContext.call(() -> callFetchJwtSvids(audParam));
} catch (Exception e) {
throw new JwtSvidException("Error fetching JWT SVID", e);
}
}
/**
* {@inheritDoc}
* @return
*/
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull final SpiffeId subject,
@NonNull final String audience,
final String... extraAudience)
throws JwtSvidException {
final Set<String> audParam = createAudienceSet(audience, extraAudience);
try (val cancellableContext = Context.current().withCancellation()) {
return cancellableContext.call(() -> callFetchJwtSvids(subject, audParam));
} catch (Exception e) {
throw new JwtSvidException("Error fetching JWT SVID", e);
}
}
/**
* {@inheritDoc}
*/
@ -330,7 +363,7 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
.addAllAudience(audience)
.build();
val response = workloadApiBlockingStub.fetchJWTSVID(jwtSvidRequest);
return processJwtSvidResponse(response, audience);
return processJwtSvidResponse(response, audience, true).get(0);
}
private JwtSvid callFetchJwtSvid(final Set<String> audience) throws JwtSvidException {
@ -338,14 +371,40 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
.addAllAudience(audience)
.build();
val response = workloadApiBlockingStub.fetchJWTSVID(jwtSvidRequest);
return processJwtSvidResponse(response, audience);
return processJwtSvidResponse(response, audience, true).get(0);
}
private JwtSvid processJwtSvidResponse(Workload.JWTSVIDResponse response, Set<String> audience) throws JwtSvidException {
private List<JwtSvid> callFetchJwtSvids(final SpiffeId subject, final Set<String> audience) throws JwtSvidException {
val jwtSvidRequest = Workload.JWTSVIDRequest.newBuilder()
.setSpiffeId(subject.toString())
.addAllAudience(audience)
.build();
val response = workloadApiBlockingStub.fetchJWTSVID(jwtSvidRequest);
return processJwtSvidResponse(response, audience, false);
}
private List<JwtSvid> callFetchJwtSvids(final Set<String> audience) throws JwtSvidException {
val jwtSvidRequest = Workload.JWTSVIDRequest.newBuilder()
.addAllAudience(audience)
.build();
val response = workloadApiBlockingStub.fetchJWTSVID(jwtSvidRequest);
return processJwtSvidResponse(response, audience, false);
}
private List<JwtSvid> processJwtSvidResponse(Workload.JWTSVIDResponse response, Set<String> audience, boolean firstOnly) throws JwtSvidException {
if (response.getSvidsList() == null || response.getSvidsList().isEmpty()) {
throw new JwtSvidException("JWT SVID response from the Workload API is empty");
}
return JwtSvid.parseInsecure(response.getSvids(0).getSvid(), audience);
int n = response.getSvidsCount();
if (firstOnly) {
n = 1;
}
ArrayList<JwtSvid> svids = new ArrayList<>(n);
for (int i = 0; i < n; i++) {
val svid = JwtSvid.parseInsecure(response.getSvids(i).getSvid(), audience);
svids.add(svid);
}
return svids;
}
private JwtBundleSet callFetchBundles() throws JwtBundleException {

View File

@ -11,6 +11,7 @@ import io.spiffe.svid.jwtsvid.JwtSvid;
import lombok.NonNull;
import java.io.Closeable;
import java.util.List;
/**
* Represents a client to interact with the Workload API.
@ -78,6 +79,27 @@ public interface WorkloadApiClient extends Closeable {
*/
JwtSvid fetchJwtSvid(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException;
/**
* Fetches all SPIFFE JWT-SVIDs on one-shot blocking call.
*
* @param audience the audience of the JWT-SVID
* @param extraAudience the extra audience for the JWT_SVID
* @return all of {@link JwtSvid} object
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/
List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException;
/**
* Fetches a SPIFFE JWT-SVID on one-shot blocking call.
*
* @param subject a SPIFFE ID
* @param audience the audience of the JWT-SVID
* @param extraAudience the extra audience for the JWT_SVID
* @return all of {@link JwtSvid} objectÏ
* @throws JwtSvidException if there is an error fetching or processing the JWT from the Workload API
*/
List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException;
/**
* Fetches the JWT bundles for JWT-SVID validation, keyed by trust domain.
*

View File

@ -116,6 +116,28 @@ class DefaultWorkloadApiClientEmptyResponseTest {
}
}
@Test
void testFetchJwtSvids_throwsJwtSvidException() {
try {
workloadApiClient.fetchJwtSvids("aud1", "aud2");
fail();
} catch (JwtSvidException e) {
assertEquals("Error fetching JWT SVID", e.getMessage());
assertEquals("JWT SVID response from the Workload API is empty", e.getCause().getMessage());
}
}
@Test
void testFetchJwtSvidsPassingSpiffeId_throwsJwtSvidException() {
try {
workloadApiClient.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/test"), "aud1", "aud2");
fail();
} catch (JwtSvidException e) {
assertEquals("Error fetching JWT SVID", e.getMessage());
assertEquals("JWT SVID response from the Workload API is empty", e.getCause().getMessage());
}
}
@Test
void testValidateJwtSvid_throwsJwtSvidException() {
try {

View File

@ -115,6 +115,26 @@ class DefaultWorkloadApiClientInvalidArgumentTest {
}
}
@Test
void testFetchJwtSvids_throwsJwtSvidException() {
try {
workloadApiClient.fetchJwtSvids("aud1", "aud2");
fail();
} catch (JwtSvidException e) {
assertEquals("Error fetching JWT SVID", e.getMessage());
}
}
@Test
void testFetchJwtSvidsPassingSpiffeId_throwsJwtSvidException() {
try {
workloadApiClient.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/test"), "aud1", "aud2");
fail();
} catch (JwtSvidException e) {
assertEquals("Error fetching JWT SVID", e.getMessage());
}
}
@Test
void testValidateJwtSvid_throwsJwtSvidException() {
try {

View File

@ -279,6 +279,74 @@ class DefaultWorkloadApiClientTest {
}
}
@Test
void testFetchJwtSvids() {
try {
List<JwtSvid> jwtSvids = workloadApiClient.fetchJwtSvids("aud1", "aud2", "aud3");
System.out.println(jwtSvids.toString());
assertNotNull(jwtSvids);
assertEquals(jwtSvids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), jwtSvids.get(0).getSpiffeId());
assertTrue(jwtSvids.get(0).getAudience().contains("aud1"));
assertEquals(3, jwtSvids.get(0).getAudience().size());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), jwtSvids.get(1).getSpiffeId());
assertTrue(jwtSvids.get(1).getAudience().contains("aud1"));
assertEquals(3, jwtSvids.get(1).getAudience().size());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsPassingSpiffeId() {
try {
List<JwtSvid> jwtSvids = workloadApiClient.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/test"), "aud1", "aud2", "aud3");
assertNotNull(jwtSvids);
assertEquals(jwtSvids.size(), 1);
assertEquals(SpiffeId.parse("spiffe://example.org/test"), jwtSvids.get(0).getSpiffeId());
assertTrue(jwtSvids.get(0).getAudience().contains("aud1"));
assertEquals(3, jwtSvids.get(0).getAudience().size());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvids_nullAudience() {
try {
workloadApiClient.fetchJwtSvid(null, new String[]{"aud2", "aud3"});
fail();
} catch (NullPointerException e) {
assertEquals("audience is marked non-null but is null", e.getMessage());
} catch (JwtSvidException e) {
fail();
}
}
@Test
void testFetchJwtSvids_withSpiffeIdAndNullAudience() {
try {
workloadApiClient.fetchJwtSvid(SpiffeId.parse("spiffe://example.org/text"), null, "aud2", "aud3");
fail();
} catch (NullPointerException e) {
assertEquals("audience is marked non-null but is null", e.getMessage());
} catch (JwtSvidException e) {
fail();
}
}
@Test
void testFetchJwtSvids_nullSpiffeId() {
try {
workloadApiClient.fetchJwtSvid(null, "aud1", new String[]{"aud2", "aud3"});
fail();
} catch (NullPointerException e) {
assertEquals("subject is marked non-null but is null", e.getMessage());
} catch (JwtSvidException e) {
fail();
}
}
@Test
void testValidateJwtSvid() {
String token = generateToken("spiffe://example.org/workload-server", Collections.singletonList("aud1"));

View File

@ -108,25 +108,49 @@ class FakeWorkloadApi extends SpiffeWorkloadAPIImplBase {
@Override
public void fetchJWTSVID(Workload.JWTSVIDRequest request, StreamObserver<Workload.JWTSVIDResponse> responseObserver) {
String spiffeId = request.getSpiffeId();
String extraSpiffeId = "spiffe://example.org/extra-workload-server";
boolean firstOnly = true;
if (StringUtils.isBlank(spiffeId)) {
firstOnly = false;
spiffeId = "spiffe://example.org/workload-server";
}
Date expiration = new Date(System.currentTimeMillis() + 3600000);
Map<String, Object> claims = new HashMap<>();
claims.put("sub", spiffeId);
claims.put("aud", getAudienceList(request.getAudienceList()));
Date expiration = new Date(System.currentTimeMillis() + 3600000);
claims.put("exp", expiration);
Map<String, Object> extraClaims = new HashMap<>();
extraClaims.put("sub", extraSpiffeId);
extraClaims.put("aud", getAudienceList(request.getAudienceList()));
extraClaims.put("exp", expiration);
KeyPair keyPair = TestUtils.generateECKeyPair(Curve.P_521);
String token = TestUtils.generateToken(claims, keyPair, "authority1");
String extraToken = TestUtils.generateToken(extraClaims, keyPair, "authority1");
Workload.JWTSVID jwtsvid = Workload.JWTSVID
.newBuilder()
.setSpiffeId(spiffeId)
.setSvid(token)
.build();
Workload.JWTSVIDResponse response = Workload.JWTSVIDResponse.newBuilder().addSvids(jwtsvid).build();
Workload.JWTSVID extraJwtsvid = Workload.JWTSVID
.newBuilder()
.setSpiffeId(extraSpiffeId)
.setSvid(extraToken)
.build();
Workload.JWTSVIDResponse.Builder builder = Workload.JWTSVIDResponse.newBuilder();
builder.addSvids(jwtsvid);
if (!firstOnly) {
builder.addSvids(extraJwtsvid);
}
Workload.JWTSVIDResponse response = builder.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}

View File

@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -132,6 +133,64 @@ class JwtSourceTest {
}
}
@Test
void testFetchJwtSvidsWithSubject() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(svids.size(), 1);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithoutSubject() {
try {
List<JwtSvid> svids = jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
assertNotNull(svids);
assertEquals(svids.size(), 2);
assertEquals(SpiffeId.parse("spiffe://example.org/workload-server"), svids.get(0).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(0).getAudience());
assertEquals(SpiffeId.parse("spiffe://example.org/extra-workload-server"), svids.get(1).getSpiffeId());
assertEquals(Sets.newHashSet("aud1", "aud2", "aud3"), svids.get(1).getAudience());
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvids_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvids("aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void testFetchJwtSvidsWithSubject_SourceIsClosed_ThrowsIllegalStateException() throws IOException {
jwtSource.close();
try {
jwtSource.fetchJwtSvids(SpiffeId.parse("spiffe://example.org/workload-server"), "aud1", "aud2", "aud3");
fail("expected exception");
} catch (IllegalStateException e) {
assertEquals("JWT SVID source is closed", e.getMessage());
assertTrue(workloadApiClient.closed);
} catch (JwtSvidException e) {
fail(e);
}
}
@Test
void newSource_success() {
val options = DefaultJwtSource.JwtSourceOptions

View File

@ -11,6 +11,7 @@ import io.spiffe.svid.jwtsvid.JwtSvid;
import lombok.NonNull;
import java.io.IOException;
import java.util.List;
public class WorkloadApiClientErrorStub implements WorkloadApiClient {
@ -44,6 +45,16 @@ public class WorkloadApiClientErrorStub implements WorkloadApiClient {
throw new JwtSvidException("Testing exception");
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");
}
@Override
public JwtBundleSet fetchJwtBundles() throws JwtBundleException {
throw new JwtBundleException("Testing exception");

View File

@ -23,14 +23,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.*;
import static io.spiffe.utils.TestUtils.toUri;
@ -41,6 +34,7 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
final String x509Bundle = "testdata/workloadapi/bundle.der";
final String jwtBundle = "testdata/workloadapi/bundle.json";
final SpiffeId subject = SpiffeId.parse("spiffe://example.org/workload-server");
final SpiffeId extraSubject = SpiffeId.parse("spiffe://example.org/extra-workload-server");
boolean closed;
@ -76,6 +70,21 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
return generateJwtSvid(subject, audience, extraAudience);
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
List<JwtSvid> svids = new ArrayList<>();
svids.add(generateJwtSvid(subject, audience, extraAudience));
svids.add(generateJwtSvid(extraSubject, audience, extraAudience));
return svids;
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull SpiffeId subject, @NonNull String audience, String... extraAudience) throws JwtSvidException {
List<JwtSvid> svids = new ArrayList<>();
svids.add(generateJwtSvid(subject, audience, extraAudience));
return svids;
}
@Override
public JwtBundleSet fetchJwtBundles() throws JwtBundleException {
return generateJwtBundleSet();

View File

@ -13,6 +13,7 @@ import io.spiffe.workloadapi.WorkloadApiClient;
import io.spiffe.workloadapi.X509Context;
import lombok.NonNull;
import java.util.List;
import java.io.IOException;
public class WorkloadApiClientErrorStub implements WorkloadApiClient {
@ -46,7 +47,15 @@ public class WorkloadApiClientErrorStub implements WorkloadApiClient {
public JwtSvid fetchJwtSvid(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");
}
@Override
public JwtBundleSet fetchJwtBundles() throws JwtBundleException {
throw new JwtBundleException("Testing exception");

View File

@ -24,6 +24,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
public class WorkloadApiClientStub implements WorkloadApiClient {
@ -64,6 +65,15 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
return null;
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull String audience, String... extraAudience) throws JwtSvidException {
return null;
}
@Override
public List<JwtSvid> fetchJwtSvids(@NonNull final SpiffeId subject, @NonNull final String audience, final String... extraAudience) throws JwtSvidException {
return null;
}
@Override
public JwtBundleSet fetchJwtBundles() throws JwtBundleException {
return null;