Merge pull request #63 from maxlambrecht/fetch-x509-bundles

Implement FetchX509Bundles method on WorkloadAPI client
This commit is contained in:
Ryan Turner 2021-03-15 12:49:32 -07:00 committed by GitHub
commit 1177178a1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 590 additions and 58 deletions

View File

@ -2,9 +2,11 @@ package io.spiffe.workloadapi;
import io.grpc.Context;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.svid.jwtsvid.JwtSvid;
@ -37,6 +39,7 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Level;
import static io.spiffe.workloadapi.StreamObservers.getJwtBundleStreamObserver;
import static io.spiffe.workloadapi.StreamObservers.getX509BundlesStreamObserver;
import static io.spiffe.workloadapi.StreamObservers.getX509ContextStreamObserver;
/**
@ -180,6 +183,33 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
this.cancellableContexts.add(cancellableContext);
}
/**
* {@inheritDoc}
*/
@Override
public X509BundleSet fetchX509Bundles() throws X509BundleException {
try (val cancellableContext = Context.current().withCancellation()) {
return cancellableContext.call(this::callFetchX509Bundles);
} catch (Exception e) {
throw new X509BundleException("Error fetching X.509 bundles", e);
}
}
/**
* {@inheritDoc}
*/
@Override
public void watchX509Bundles(@NonNull final Watcher<X509BundleSet> watcher) {
val retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
val cancellableContext = Context.current().withCancellation();
val streamObserver =
getX509BundlesStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);
cancellableContext.run(() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
}
/**
* {@inheritDoc}
*/
@ -289,6 +319,11 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
return GrpcConversionUtils.toX509Context(x509SvidResponse);
}
private X509BundleSet callFetchX509Bundles() throws X509BundleException {
val x509BundlesResponse = workloadApiBlockingStub.fetchX509Bundles(newX509BundlesRequest());
return GrpcConversionUtils.toX509BundleSet(x509BundlesResponse);
}
private JwtSvid callFetchJwtSvid(final SpiffeId subject, final Set<String> audience) throws JwtSvidException {
val jwtSvidRequest = Workload.JWTSVIDRequest.newBuilder()
.setSpiffeId(subject.toString())
@ -307,7 +342,7 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
}
private JwtSvid processJwtSvidResponse(Workload.JWTSVIDResponse response, Set<String> audience) throws JwtSvidException {
if (response.getSvidsList() == null || response.getSvidsList().size() == 0) {
if (response.getSvidsList() == null || response.getSvidsList().isEmpty()) {
throw new JwtSvidException("JWT SVID response from the Workload API is empty");
}
return JwtSvid.parseInsecure(response.getSvids(0).getSvid(), audience);
@ -316,7 +351,7 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
private JwtBundleSet callFetchBundles() throws JwtBundleException {
val request = Workload.JWTBundlesRequest.newBuilder().build();
val bundlesResponse = workloadApiBlockingStub.fetchJWTBundles(request);
return GrpcConversionUtils.toBundleSet(bundlesResponse);
return GrpcConversionUtils.toJwtBundleSet(bundlesResponse);
}
private Set<String> createAudienceSet(final String audience, final String[] extraAudience) {
@ -330,6 +365,10 @@ public final class DefaultWorkloadApiClient implements WorkloadApiClient {
return Workload.X509SVIDRequest.newBuilder().build();
}
private Workload.X509BundlesRequest newX509BundlesRequest() {
return Workload.X509BundlesRequest.newBuilder().build();
}
private Workload.JWTBundlesRequest newJwtBundlesRequest() {
return Workload.JWTBundlesRequest.newBuilder().build();
}

View File

@ -39,7 +39,7 @@ final class GrpcConversionUtils {
}
static X509Context toX509Context(final Workload.X509SVIDResponse x509SvidResponse) throws X509ContextException {
if (x509SvidResponse.getSvidsList() == null || x509SvidResponse.getSvidsList().size() == 0) {
if (x509SvidResponse.getSvidsList() == null || x509SvidResponse.getSvidsList().isEmpty()) {
throw new X509ContextException("X.509 Context response from the Workload API is empty");
}
@ -49,16 +49,39 @@ final class GrpcConversionUtils {
return X509Context.of(x509SvidList, bundleSet);
}
static JwtBundleSet toBundleSet(final Iterator<Workload.JWTBundlesResponse> bundlesResponseIterator) throws JwtBundleException {
public static X509BundleSet toX509BundleSet(Iterator<Workload.X509BundlesResponse> bundlesResponseIterator) throws X509BundleException {
if (!bundlesResponseIterator.hasNext()) {
throw new X509BundleException("X.509 Bundle response from the Workload API is empty");
}
val bundlesResponse = bundlesResponseIterator.next();
return toX509BundleSet(bundlesResponse);
}
static X509BundleSet toX509BundleSet(final Workload.X509BundlesResponse bundlesResponse) throws X509BundleException {
val bundlesCount = bundlesResponse.getBundlesCount();
if (bundlesCount == 0) {
throw new X509BundleException("X.509 Bundle response from the Workload API is empty");
}
final List<X509Bundle> x509Bundles = new ArrayList<>(bundlesCount);
for (Map.Entry<String, ByteString> entry : bundlesResponse.getBundlesMap().entrySet()) {
X509Bundle x509Bundle = createX509Bundle(entry);
x509Bundles.add(x509Bundle);
}
return X509BundleSet.of(x509Bundles);
}
static JwtBundleSet toJwtBundleSet(final Iterator<Workload.JWTBundlesResponse> bundlesResponseIterator) throws JwtBundleException {
if (!bundlesResponseIterator.hasNext()) {
throw new JwtBundleException("JWT Bundle response from the Workload API is empty");
}
val bundlesResponse = bundlesResponseIterator.next();
return toBundleSet(bundlesResponse);
return toJwtBundleSet(bundlesResponse);
}
static JwtBundleSet toBundleSet(final Workload.JWTBundlesResponse bundlesResponse) throws JwtBundleException {
static JwtBundleSet toJwtBundleSet(final Workload.JWTBundlesResponse bundlesResponse) throws JwtBundleException {
if (bundlesResponse.getBundlesMap().size() == 0) {
throw new JwtBundleException("JWT Bundle response from the Workload API is empty");
}
@ -146,4 +169,10 @@ final class GrpcConversionUtils {
byte[] bundleBytes = entry.getValue().toByteArray();
return JwtBundle.parse(trustDomain, bundleBytes);
}
private static X509Bundle createX509Bundle(Map.Entry<String, ByteString> bundleEntry) throws X509BundleException {
TrustDomain trustDomain = TrustDomain.of(bundleEntry.getKey());
byte[] bundleBytes = bundleEntry.getValue().toByteArray();
return X509Bundle.parse(trustDomain, bundleBytes);
}
}

View File

@ -4,7 +4,9 @@ import io.grpc.Context;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.workloadapi.grpc.SpiffeWorkloadAPIGrpc;
import io.spiffe.workloadapi.grpc.Workload;
@ -18,6 +20,7 @@ import java.util.logging.Level;
final class StreamObservers {
private static final String INVALID_ARGUMENT = "INVALID_ARGUMENT";
private static final String STREAM_IS_COMPLETED = "Workload API stream is completed";
private StreamObservers() {
}
@ -48,7 +51,7 @@ final class StreamObservers {
private void handleWatchX509ContextError(final Throwable t) {
if (isErrorNotRetryable(t)) {
watcher.onError(new X509ContextException("Canceling X.509 Context watch", t));
watcher.onError(new X509ContextException("Cancelling X.509 Context watch", t));
} else {
handleX509ContextRetry(t);
}
@ -62,14 +65,66 @@ final class StreamObservers {
() -> workloadApiAsyncStub.fetchX509SVID(newX509SvidRequest(),
this)));
} else {
watcher.onError(new X509ContextException("Canceling X.509 Context watch", t));
watcher.onError(new X509ContextException("Cancelling X.509 Context watch", t));
}
}
@Override
public void onCompleted() {
cancellableContext.close();
log.info("Workload API stream is completed");
log.info(STREAM_IS_COMPLETED);
}
};
}
static StreamObserver<Workload.X509BundlesResponse> getX509BundlesStreamObserver(
final Watcher<X509BundleSet> watcher,
final RetryHandler retryHandler,
final Context.CancellableContext cancellableContext,
final SpiffeWorkloadAPIGrpc.SpiffeWorkloadAPIStub workloadApiAsyncStub) {
return new StreamObserver<Workload.X509BundlesResponse>() {
@Override
public void onNext(final Workload.X509BundlesResponse value) {
try {
val x509Context = GrpcConversionUtils.toX509BundleSet(value);
watcher.onUpdate(x509Context);
retryHandler.reset();
} catch (X509BundleException e) {
watcher.onError(new X509ContextException("Error processing X.509 bundles update", e));
}
}
@Override
public void onError(final Throwable t) {
log.log(Level.SEVERE, "X.509 bundles observer error", t);
handleWatchX509BundlesError(t);
}
private void handleWatchX509BundlesError(final Throwable t) {
if (isErrorNotRetryable(t)) {
watcher.onError(new X509ContextException("Cancelling X.509 bundles watch", t));
} else {
handleX509BundlesRetry(t);
}
}
private void handleX509BundlesRetry(Throwable t) {
if (retryHandler.shouldRetry()) {
log.log(Level.INFO, "Retrying connecting to Workload API to register X.509 bundles watcher");
retryHandler.scheduleRetry(() ->
cancellableContext.run(
() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(),
this)));
} else {
watcher.onError(new X509BundleException("Cancelling X.509 bundles watch", t));
}
}
@Override
public void onCompleted() {
cancellableContext.close();
log.info(STREAM_IS_COMPLETED);
}
};
}
@ -84,7 +139,7 @@ final class StreamObservers {
@Override
public void onNext(final Workload.JWTBundlesResponse value) {
try {
val jwtBundleSet = GrpcConversionUtils.toBundleSet(value);
val jwtBundleSet = GrpcConversionUtils.toJwtBundleSet(value);
watcher.onUpdate(jwtBundleSet);
retryHandler.reset();
} catch (JwtBundleException e) {
@ -100,7 +155,7 @@ final class StreamObservers {
private void handleWatchJwtBundleError(final Throwable t) {
if (isErrorNotRetryable(t)) {
watcher.onError(new JwtBundleException("Canceling JWT Bundles watch", t));
watcher.onError(new JwtBundleException("Cancelling JWT Bundles watch", t));
} else {
handleJwtBundleRetry(t);
}
@ -113,14 +168,14 @@ final class StreamObservers {
cancellableContext.run(() -> workloadApiAsyncStub.fetchJWTBundles(newJwtBundlesRequest(),
this)));
} else {
watcher.onError(new JwtBundleException("Canceling JWT Bundles watch", t));
watcher.onError(new JwtBundleException("Cancelling JWT Bundles watch", t));
}
}
@Override
public void onCompleted() {
cancellableContext.close();
log.info("Workload API stream is completed");
log.info(STREAM_IS_COMPLETED);
}
};
}
@ -133,6 +188,10 @@ final class StreamObservers {
return Workload.X509SVIDRequest.newBuilder().build();
}
private static Workload.X509BundlesRequest newX509BundlesRequest() {
return Workload.X509BundlesRequest.newBuilder().build();
}
private static Workload.JWTBundlesRequest newJwtBundlesRequest() {
return Workload.JWTBundlesRequest.newBuilder().build();
}

View File

@ -1,8 +1,10 @@
package io.spiffe.workloadapi;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.svid.jwtsvid.JwtSvid;
@ -36,6 +38,25 @@ public interface WorkloadApiClient extends Closeable {
*/
void watchX509Context(@NonNull Watcher<X509Context> watcher);
/**
* Fetches the X.509 bundles on a one-shot blocking call.
*
* @return an instance of a {@link X509BundleSet} containing the X.509 bundles keyed by TrustDomain
* @throws X509BundleException if there is an error fetching or processing the X.509 bundles
*/
X509BundleSet fetchX509Bundles() throws X509BundleException;
/**
* Watches for X.509 bundles updates.
* <p>
* A new Stream to the Workload API is opened for each call to this method, so that the client starts getting
* updates immediately after the Stream is ready and doesn't have to wait until the Workload API dispatches
* the next update.
*
* @param watcher an instance that implements a {@link Watcher} for {@link X509BundleSet}.
*/
void watchX509Bundles(@NonNull Watcher<X509BundleSet> watcher);
/**
* Fetches a SPIFFE JWT-SVID on one-shot blocking call.
*

View File

@ -42,6 +42,13 @@ message X509SVID {
bytes bundle = 4;
}
message X509BundlesRequest {}
message X509BundlesResponse {
// x509 certificates, keyed by trust domain URI
map<string, bytes> bundles = 1;
}
message JWTSVID {
string spiffe_id = 1;
@ -91,5 +98,6 @@ service SpiffeWorkloadAPI {
// well as related information like trust bundles and CRLs. As
// this information changes, subsequent messages will be sent.
rpc FetchX509SVID(X509SVIDRequest) returns (stream X509SVIDResponse);
rpc FetchX509Bundles(X509BundlesRequest) returns (stream X509BundlesResponse);
}

View File

@ -2,8 +2,10 @@ package io.spiffe.workloadapi;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import org.junit.Rule;
@ -34,7 +36,7 @@ class DefaultWorkloadApiClientCorruptedResponsesTest {
}
@Test
public void testFetchX509Context_throwsX509ContextException() throws Exception {
void testFetchX509Context_throwsX509ContextException() throws Exception {
try {
workloadApiClient.fetchX509Context();
fail();
@ -44,9 +46,8 @@ class DefaultWorkloadApiClientCorruptedResponsesTest {
}
@Test
public void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<X509Context> contextWatcher = new Watcher<X509Context>() {
@Override
public void onUpdate(X509Context update) {
@ -55,13 +56,41 @@ class DefaultWorkloadApiClientCorruptedResponsesTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Error processing X.509 Context update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Context(contextWatcher);
done.await();
assertEquals("Error processing X.509 Context update", error[0]);
}
@Test
void testFetchX509Bundles_throwsX509BundleException() {
try {
workloadApiClient.fetchX509Bundles();
fail();
} catch (X509BundleException e) {
assertEquals("Error fetching X.509 bundles", e.getMessage());
}
}
@Test
void testWatchX509Bundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
Watcher<X509BundleSet> contextWatcher = new Watcher<X509BundleSet>() {
@Override
public void onUpdate(X509BundleSet update) {
fail();
}
@Override
public void onError(Throwable e) {
assertEquals("Error processing X.509 bundles update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Bundles(contextWatcher);
done.await();
}
@Test
@ -97,7 +126,6 @@ class DefaultWorkloadApiClientCorruptedResponsesTest {
@Test
void testWatchJwtBundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<JwtBundleSet> contextWatcher = new Watcher<JwtBundleSet>() {
@Override
public void onUpdate(JwtBundleSet update) {
@ -106,12 +134,11 @@ class DefaultWorkloadApiClientCorruptedResponsesTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Error processing JWT bundles update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchJwtBundles(contextWatcher);
done.await();
assertEquals("Error processing JWT bundles update", error[0]);
}
}

View File

@ -2,8 +2,10 @@ package io.spiffe.workloadapi;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import org.junit.Rule;
@ -35,7 +37,7 @@ class DefaultWorkloadApiClientEmptyResponseTest {
@Test
public void testFetchX509Context_throwsX509ContextException() throws Exception {
void testFetchX509Context_throwsX509ContextException() throws Exception {
try {
workloadApiClient.fetchX509Context();
fail();
@ -45,9 +47,8 @@ class DefaultWorkloadApiClientEmptyResponseTest {
}
@Test
public void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<X509Context> contextWatcher = new Watcher<X509Context>() {
@Override
public void onUpdate(X509Context update) {
@ -56,13 +57,41 @@ class DefaultWorkloadApiClientEmptyResponseTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Error processing X.509 Context update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Context(contextWatcher);
done.await();
assertEquals("Error processing X.509 Context update", error[0]);
}
@Test
void testFetchX509Bundles_throwsX509BundleException() {
try {
workloadApiClient.fetchX509Bundles();
fail();
} catch (X509BundleException e) {
assertEquals("Error fetching X.509 bundles", e.getMessage());
}
}
@Test
void testWatchX509Bundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
Watcher<X509BundleSet> contextWatcher = new Watcher<X509BundleSet>() {
@Override
public void onUpdate(X509BundleSet update) {
fail();
}
@Override
public void onError(Throwable e) {
assertEquals("Error processing X.509 bundles update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Bundles(contextWatcher);
done.await();
}
@Test
@ -110,7 +139,6 @@ class DefaultWorkloadApiClientEmptyResponseTest {
@Test
void testWatchJwtBundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<JwtBundleSet> contextWatcher = new Watcher<JwtBundleSet>() {
@Override
public void onUpdate(JwtBundleSet update) {
@ -119,12 +147,11 @@ class DefaultWorkloadApiClientEmptyResponseTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Error processing JWT bundles update", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchJwtBundles(contextWatcher);
done.await();
assertEquals("Error processing JWT bundles update", error[0]);
}
}

View File

@ -3,8 +3,10 @@ package io.spiffe.workloadapi;
import io.grpc.Status;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import org.junit.Rule;
@ -18,7 +20,7 @@ import java.util.concurrent.CountDownLatch;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
class DefaultWorkloadApiClientInvalidaArgumentTest {
class DefaultWorkloadApiClientInvalidArgumentTest {
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@ -36,7 +38,7 @@ class DefaultWorkloadApiClientInvalidaArgumentTest {
@Test
public void testFetchX509Context_throwsX509ContextException() throws Exception {
void testFetchX509Context_throwsX509ContextException() throws Exception {
try {
workloadApiClient.fetchX509Context();
fail();
@ -46,9 +48,8 @@ class DefaultWorkloadApiClientInvalidaArgumentTest {
}
@Test
public void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<X509Context> contextWatcher = new Watcher<X509Context>() {
@Override
public void onUpdate(X509Context update) {
@ -57,13 +58,41 @@ class DefaultWorkloadApiClientInvalidaArgumentTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Cancelling X.509 Context watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Context(contextWatcher);
done.await();
assertEquals("Canceling X.509 Context watch", error[0]);
}
@Test
void testFetchX509Bundles_throwsX509BundleException() {
try {
workloadApiClient.fetchX509Bundles();
fail();
} catch (X509BundleException e) {
assertEquals("Error fetching X.509 bundles", e.getMessage());
}
}
@Test
void testWatchX509Bundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
Watcher<X509BundleSet> contextWatcher = new Watcher<X509BundleSet>() {
@Override
public void onUpdate(X509BundleSet update) {
fail();
}
@Override
public void onError(Throwable e) {
assertEquals("Cancelling X.509 bundles watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Bundles(contextWatcher);
done.await();
}
@Test
@ -109,7 +138,6 @@ class DefaultWorkloadApiClientInvalidaArgumentTest {
@Test
void testWatchJwtBundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<JwtBundleSet> contextWatcher = new Watcher<JwtBundleSet>() {
@Override
public void onUpdate(JwtBundleSet update) {
@ -118,12 +146,11 @@ class DefaultWorkloadApiClientInvalidaArgumentTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Cancelling JWT Bundles watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchJwtBundles(contextWatcher);
done.await();
assertEquals("Canceling JWT Bundles watch", error[0]);
}
}

View File

@ -3,6 +3,7 @@ package io.spiffe.workloadapi;
import io.grpc.Status;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.X509ContextException;
import org.junit.Rule;
import org.junit.jupiter.api.AfterEach;
@ -34,7 +35,7 @@ class DefaultWorkloadApiClientRetryableErrorTest {
@Test
public void testFetchX509Context_throwsX509ContextException() throws Exception {
void testFetchX509Context_throwsX509ContextException() throws Exception {
try {
workloadApiClient.fetchX509Context();
fail();
@ -44,9 +45,8 @@ class DefaultWorkloadApiClientRetryableErrorTest {
}
@Test
public void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
void testWatchX509Context_onErrorIsCalledOnWatcher() throws Exception {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<X509Context> contextWatcher = new Watcher<X509Context>() {
@Override
public void onUpdate(X509Context update) {
@ -55,19 +55,36 @@ class DefaultWorkloadApiClientRetryableErrorTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Cancelling X.509 Context watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Context(contextWatcher);
done.await(5, TimeUnit.SECONDS);
assertEquals("Canceling X.509 Context watch", error[0]);
done.await();
}
@Test
void testWatchX509Bundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
Watcher<X509BundleSet> contextWatcher = new Watcher<X509BundleSet>() {
@Override
public void onUpdate(X509BundleSet update) {
fail();
}
@Override
public void onError(Throwable e) {
assertEquals("Cancelling X.509 bundles watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchX509Bundles(contextWatcher);
done.await();
}
@Test
void testWatchJwtBundles_onErrorIsCalledOnWatched() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
final String[] error = new String[1];
Watcher<JwtBundleSet> contextWatcher = new Watcher<JwtBundleSet>() {
@Override
public void onUpdate(JwtBundleSet update) {
@ -76,12 +93,11 @@ class DefaultWorkloadApiClientRetryableErrorTest {
@Override
public void onError(Throwable e) {
error[0] = e.getMessage();
assertEquals("Cancelling JWT Bundles watch", e.getMessage());
done.countDown();
}
};
workloadApiClient.watchJwtBundles(contextWatcher);
done.await(5, TimeUnit.SECONDS);
assertEquals("Canceling JWT Bundles watch", error[0]);
done.await();
}
}

View File

@ -5,10 +5,12 @@ import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
@ -28,7 +30,6 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -92,7 +93,7 @@ class DefaultWorkloadApiClientTest {
}
@Test
public void testFetchX509Context() throws Exception {
void testFetchX509Context() throws Exception {
X509Context x509Context = workloadApiClient.fetchX509Context();
@ -126,7 +127,7 @@ class DefaultWorkloadApiClientTest {
};
workloadApiClient.watchX509Context(contextWatcher);
done.await(1, TimeUnit.SECONDS);
done.await();
X509Context update = x509Context[0];
assertNotNull(update);
@ -151,6 +152,70 @@ class DefaultWorkloadApiClientTest {
}
}
@Test
void testFetchX509Bundles() {
X509BundleSet x509BundleSet = null;
try {
x509BundleSet = workloadApiClient.fetchX509Bundles();
} catch (X509BundleException e) {
fail(e);
}
assertNotNull(x509BundleSet);
try {
X509Bundle bundle = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("example.org"));
assertNotNull(bundle);
X509Bundle otherBundle = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("domain.test"));
assertNotNull(otherBundle);
} catch (BundleNotFoundException e) {
fail(e);
}
}
@Test
void testWatchX509Bundles() throws InterruptedException {
CountDownLatch done = new CountDownLatch(1);
final X509BundleSet[] x509BundleSet = new X509BundleSet[1];
Watcher<X509BundleSet> x509BundleSetWatcher = new Watcher<X509BundleSet>() {
@Override
public void onUpdate(X509BundleSet update) {
x509BundleSet[0] = update;
done.countDown();
}
@Override
public void onError(Throwable e) {
}
};
workloadApiClient.watchX509Bundles(x509BundleSetWatcher);
done.await();
X509BundleSet update = x509BundleSet[0];
assertNotNull(update);
try {
X509Bundle bundle1 = update.getBundleForTrustDomain(TrustDomain.of("example.org"));
assertNotNull(bundle1);
X509Bundle bundle2 = update.getBundleForTrustDomain(TrustDomain.of("domain.test"));
assertNotNull(bundle2);
} catch (BundleNotFoundException e) {
fail(e);
}
}
@Test
void testWatchX509BundlesNullWatcher_throwsNullPointerException() {
try {
workloadApiClient.watchX509Bundles(null);
} catch (NullPointerException e) {
assertEquals("watcher is marked non-null but is null", e.getMessage());
}
}
@Test
void testFetchJwtSvid() {
@ -289,7 +354,7 @@ class DefaultWorkloadApiClientTest {
};
workloadApiClient.watchJwtBundles(jwtBundleSetWatcher);
done.await(1, TimeUnit.SECONDS);
done.await();
JwtBundleSet update = jwtBundleSet[0];
assertNotNull(update);
@ -303,7 +368,7 @@ class DefaultWorkloadApiClientTest {
}
@Test
void testWatchSvidBundlesNullWatcher_throwsNullPointerException() {
void testWatchJwtBundlesNullWatcher_throwsNullPointerException() {
try {
workloadApiClient.watchJwtBundles(null);
} catch (NullPointerException e) {

View File

@ -141,7 +141,7 @@ class DefaultX509SourceTest {
}
@Test
void newSource_errorFetchingJwtBundles() {
void newSource_errorFetchingX509Context() {
val options = DefaultX509Source.X509SourceOptions
.builder()
.workloadApiClient(workloadApiClientErrorStub)

View File

@ -81,6 +81,29 @@ class FakeWorkloadApi extends SpiffeWorkloadAPIImplBase {
}
}
@Override
public void fetchX509Bundles(Workload.X509BundlesRequest request, StreamObserver<Workload.X509BundlesResponse> responseObserver) {
try {
Path pathBundle = Paths.get(toUri(x509Bundle));
byte[] bundleBytes = Files.readAllBytes(pathBundle);
ByteString bundleByteString = ByteString.copyFrom(bundleBytes);
Path pathFederateBundle = Paths.get(toUri(federatedBundle));
byte[] federatedBundleBytes = Files.readAllBytes(pathFederateBundle);
ByteString federatedByteString = ByteString.copyFrom(federatedBundleBytes);
Workload.X509BundlesResponse response = Workload.X509BundlesResponse
.newBuilder()
.putBundles(TrustDomain.of("example.org").getName(), bundleByteString)
.putBundles(TrustDomain.of("domain.test").getName(), federatedByteString)
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
} catch (URISyntaxException | IOException e) {
throw new Error("Failed FakeSpiffeWorkloadApiService.fetchX509Bundles", e);
}
}
@Override
public void fetchJWTSVID(Workload.JWTSVIDRequest request, StreamObserver<Workload.JWTSVIDResponse> responseObserver) {

View File

@ -48,6 +48,25 @@ class FakeWorkloadApiCorruptedResponses extends SpiffeWorkloadAPIImplBase {
}
}
@Override
public void fetchX509Bundles(Workload.X509BundlesRequest request, StreamObserver<Workload.X509BundlesResponse> responseObserver) {
Path pathBundle = null;
try {
pathBundle = Paths.get(toUri(corrupted));
byte[] bundleBytes = Files.readAllBytes(pathBundle);
ByteString corruptedByteString = ByteString.copyFrom(bundleBytes);
Workload.X509BundlesResponse response = Workload.X509BundlesResponse
.newBuilder()
.putBundles("example.org", corruptedByteString)
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
} catch (URISyntaxException | IOException e) {
throw new Error("Failed FakeSpiffeWorkloadApiService.fetchX509Bundles", e);
}
}
@Override
public void fetchJWTSVID(Workload.JWTSVIDRequest request, StreamObserver<Workload.JWTSVIDResponse> responseObserver) {

View File

@ -12,6 +12,12 @@ class FakeWorkloadApiEmptyResponse extends SpiffeWorkloadAPIImplBase {
responseObserver.onCompleted();
}
@Override
public void fetchX509Bundles(Workload.X509BundlesRequest request, StreamObserver<Workload.X509BundlesResponse> responseObserver) {
responseObserver.onNext(Workload.X509BundlesResponse.newBuilder().build());
responseObserver.onCompleted();
}
@Override
public void fetchJWTSVID(Workload.JWTSVIDRequest request, StreamObserver<Workload.JWTSVIDResponse> responseObserver) {
responseObserver.onNext(Workload.JWTSVIDResponse.newBuilder().build());

View File

@ -20,6 +20,11 @@ class FakeWorkloadApiExceptions extends SpiffeWorkloadAPIImplBase {
responseObserver.onError(exception);
}
@Override
public void fetchX509Bundles(Workload.X509BundlesRequest request, StreamObserver<Workload.X509BundlesResponse> responseObserver) {
responseObserver.onError(exception);
}
@Override
public void fetchJWTSVID(Workload.JWTSVIDRequest request, StreamObserver<Workload.JWTSVIDResponse> responseObserver) {
responseObserver.onError(exception);

View File

@ -1,20 +1,38 @@
package io.spiffe.workloadapi;
import com.google.protobuf.ByteString;
import io.grpc.stub.StreamObserver;
import io.spiffe.bundle.x509bundle.X509Bundle;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.workloadapi.grpc.Workload;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import static io.spiffe.utils.TestUtils.toUri;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;
class GrpcConversionUtilsTest {
final String x509Bundle = "testdata/workloadapi/bundle.der";
final String federatedBundle = "testdata/workloadapi/federated-bundle.pem";
@Test
void toX509Context_emptyResponse() {
void test_toX509Context_emptyResponse() {
Iterator<Workload.X509SVIDResponse> iterator = Collections.emptyIterator();
try {
GrpcConversionUtils.toX509Context(iterator);
@ -24,21 +42,92 @@ class GrpcConversionUtilsTest {
}
@Test
void toBundleSet() {
void test_toJwtBundleSet_emtpyResponse() {
Iterator<Workload.JWTBundlesResponse> iterator = Collections.emptyIterator();
try {
GrpcConversionUtils.toBundleSet(iterator);
GrpcConversionUtils.toJwtBundleSet(iterator);
} catch (JwtBundleException e) {
assertEquals("JWT Bundle response from the Workload API is empty", e.getMessage());
}
}
@Test
void parseX509Bundle_corruptedBytes() {
void test_parseX509Bundle_corruptedBytes() {
try {
GrpcConversionUtils.parseX509Bundle(TrustDomain.of("example.org"), "corrupted".getBytes());
} catch (X509ContextException e) {
assertEquals("X.509 Bundles could not be processed", e.getMessage());
}
}
@Test
void test_toX509BundleSet_from_X509BundlesResponse() throws URISyntaxException, IOException {
Workload.X509BundlesResponse response = createX509BundlesResponse();
try {
X509BundleSet x509BundleSet = GrpcConversionUtils.toX509BundleSet(response);
X509Bundle bundle1 = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("example.org"));
X509Bundle bundle2 = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("domain.test"));
assertEquals(1, bundle1.getX509Authorities().size());
assertEquals(1, bundle2.getX509Authorities().size());
} catch (X509BundleException | BundleNotFoundException e) {
fail();
}
}
@Test
void test_toX509BundleSet_from_X509BundlesResponseIterator() throws URISyntaxException, IOException {
Workload.X509BundlesResponse response = createX509BundlesResponse();
final Iterator<Workload.X509BundlesResponse> iterator = Collections.singleton(response).iterator();
try {
X509BundleSet x509BundleSet = GrpcConversionUtils.toX509BundleSet(iterator);
X509Bundle bundle1 = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("example.org"));
X509Bundle bundle2 = x509BundleSet.getBundleForTrustDomain(TrustDomain.of("domain.test"));
assertEquals(1, bundle1.getX509Authorities().size());
assertEquals(1, bundle2.getX509Authorities().size());
} catch (X509BundleException | BundleNotFoundException e) {
fail();
}
}
@Test
void test_toX509BundleSet_fromEmptyResponse() {
Workload.X509BundlesResponse response = Workload.X509BundlesResponse.newBuilder().build();
try {
GrpcConversionUtils.toX509BundleSet(response);
fail();
} catch (X509BundleException e) {
assertEquals("X.509 Bundle response from the Workload API is empty", e.getMessage());
}
}
@Test
void test_toX509BundleSet_fromEmptyIterator() {
final Iterator<Workload.X509BundlesResponse> iterator = Collections.emptyListIterator();
try {
GrpcConversionUtils.toX509BundleSet(iterator);
fail();
} catch (X509BundleException e) {
assertEquals("X.509 Bundle response from the Workload API is empty", e.getMessage());
}
}
private Workload.X509BundlesResponse createX509BundlesResponse() throws URISyntaxException, IOException {
Path pathBundle = Paths.get(toUri(x509Bundle));
byte[] bundleBytes = Files.readAllBytes(pathBundle);
ByteString bundleByteString = ByteString.copyFrom(bundleBytes);
Path pathFederateBundle = Paths.get(toUri(federatedBundle));
byte[] federatedBundleBytes = Files.readAllBytes(pathFederateBundle);
ByteString federatedByteString = ByteString.copyFrom(federatedBundleBytes);
return Workload.X509BundlesResponse
.newBuilder()
.putBundles(TrustDomain.of("example.org").getName(), bundleByteString)
.putBundles(TrustDomain.of("domain.test").getName(), federatedByteString)
.build();
}
}

View File

@ -1,8 +1,10 @@
package io.spiffe.workloadapi;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.svid.jwtsvid.JwtSvid;
@ -22,6 +24,16 @@ public class WorkloadApiClientErrorStub implements WorkloadApiClient {
watcher.onError(new X509ContextException("Testing exception"));
}
@Override
public X509BundleSet fetchX509Bundles() throws X509BundleException {
throw new X509BundleException("Testing exception");
}
@Override
public void watchX509Bundles(@NonNull Watcher<X509BundleSet> watcher) {
watcher.onError(new X509BundleException("Testing exception"));
}
@Override
public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");

View File

@ -24,6 +24,7 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
@ -54,6 +55,17 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
watcher.onUpdate(update);
}
@Override
public X509BundleSet fetchX509Bundles() {
return generateX509BundleSet();
}
@Override
public void watchX509Bundles(@NonNull Watcher<X509BundleSet> watcher) {
val x509BundleSet = generateX509BundleSet();
watcher.onUpdate(x509BundleSet);
}
@Override
public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
return generateJwtSvid(subject, audience, extraAudience);
@ -91,6 +103,18 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
}
}
private X509BundleSet generateX509BundleSet() {
try {
val pathBundle = Paths.get(toUri(x509Bundle));
byte[] bundleBytes = Files.readAllBytes(pathBundle);
val x509Bundle1 = X509Bundle.parse(TrustDomain.of("example.org"), bundleBytes);
val x509Bundle2 = X509Bundle.parse(TrustDomain.of("domain.test"), bundleBytes);
return X509BundleSet.of(Arrays.asList(x509Bundle1, x509Bundle2));
} catch (IOException | X509BundleException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
private JwtSvid generateJwtSvid(final @NonNull SpiffeId subject, final @NonNull String audience, final String[] extraAudience) throws JwtSvidException {
final Set<String> audParam = new HashSet<>();
audParam.add(audience);

View File

@ -1,8 +1,10 @@
package io.spiffe.helper.keystore;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.bundle.x509bundle.X509BundleSet;
import io.spiffe.exception.JwtBundleException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.X509BundleException;
import io.spiffe.exception.X509ContextException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.svid.jwtsvid.JwtSvid;
@ -25,6 +27,16 @@ public class WorkloadApiClientErrorStub implements WorkloadApiClient {
watcher.onError(new X509ContextException("Testing exception"));
}
@Override
public X509BundleSet fetchX509Bundles() throws X509BundleException {
throw new X509BundleException("Testing exception");
}
@Override
public void watchX509Bundles(@NonNull Watcher<X509BundleSet> watcher) {
watcher.onError(new X509BundleException("Testing exception"));
}
@Override
public JwtSvid fetchJwtSvid(@NonNull final String audience, final String... extraAudience) throws JwtSvidException {
throw new JwtSvidException("Testing exception");

View File

@ -23,6 +23,7 @@ import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
public class WorkloadApiClientStub implements WorkloadApiClient {
@ -42,6 +43,17 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
watcher.onUpdate(update);
}
@Override
public X509BundleSet fetchX509Bundles() throws X509BundleException {
return getX509BundleSet();
}
@Override
public void watchX509Bundles(@NonNull Watcher<X509BundleSet> watcher) {
val update = getX509BundleSet();
watcher.onUpdate(update);
}
@Override
public JwtSvid fetchJwtSvid(@NonNull String audience, String... extraAudience) throws JwtSvidException {
return null;
@ -88,6 +100,18 @@ public class WorkloadApiClientStub implements WorkloadApiClient {
}
}
private X509BundleSet getX509BundleSet() {
try {
Path pathBundle = Paths.get(toUri(x509Bundle));
byte[] bundleBytes = Files.readAllBytes(pathBundle);
val bundle1 = X509Bundle.parse(TrustDomain.of("example.org"), bundleBytes);
val bundle2 = X509Bundle.parse(TrustDomain.of("domain.test"), bundleBytes);
return X509BundleSet.of(Arrays.asList(bundle1, bundle2));
} catch (IOException | X509BundleException e) {
throw new RuntimeException(e);
}
}
private X509Svid getX509Svid() {
try {
Path pathCert = Paths.get(toUri(svid));