xds: implementation of SdsClient to be used by SDS based SslContextProviders (#6400)

This commit is contained in:
sanjaypujare 2019-11-18 10:36:42 -08:00 committed by GitHub
parent add020fd19
commit b05ce13df2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1146 additions and 1 deletions

View File

@ -24,7 +24,8 @@ dependencies {
project(':grpc-core'),
project(':grpc-netty'),
project(':grpc-services'),
project(':grpc-alts')
project(':grpc-alts'),
libraries.netty_epoll
compile (libraries.pgv) {
// PGV depends on com.google.protobuf:protobuf-java 3.6.1 conflicting with :grpc-protobuf

View File

@ -0,0 +1,386 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.rpc.Code;
import io.envoyproxy.envoy.api.v2.DiscoveryRequest;
import io.envoyproxy.envoy.api.v2.DiscoveryResponse;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.api.v2.core.ApiConfigSource;
import io.envoyproxy.envoy.api.v2.core.ApiConfigSource.ApiType;
import io.envoyproxy.envoy.api.v2.core.ConfigSource;
import io.envoyproxy.envoy.api.v2.core.GrpcService;
import io.envoyproxy.envoy.api.v2.core.GrpcService.GoogleGrpc;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc;
import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc.SecretDiscoveryServiceStub;
import io.grpc.Internal;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.concurrent.NotThreadSafe;
/**
* SDS client used by an {@link SslContextProvider} to get SDS updates from the SDS server. This is
* most likely a temporary implementation until merged with the XdsClient.
*/
// TODO(sanjaypujare): once XdsClientImpl is ready, merge with it and add retry logic
@Internal
@NotThreadSafe
final class SdsClient {
private static final Logger logger = Logger.getLogger(SdsClient.class.getName());
private static final String SECRET_TYPE_URL = "type.googleapis.com/envoy.api.v2.auth.Secret";
private static final EventLoopGroupResource eventLoopGroupResource =
Epoll.isAvailable() ? new EventLoopGroupResource("SdsClient") : null;
private SecretWatcher watcher;
private final SdsSecretConfig sdsSecretConfig;
private final Node clientNode;
private final Executor watcherExecutor;
private EventLoopGroup eventLoopGroup;
private ManagedChannel channel;
private SecretDiscoveryServiceStub secretDiscoveryServiceStub;
private ResponseObserver responseObserver;
private StreamObserver<DiscoveryRequest> requestObserver;
private DiscoveryResponse lastResponse;
/** Factory for creating SdsClient based on input params and for unit tests. */
static class Factory {
/** Creates an SdsClient with {@link InProcessChannelBuilder}. */
@VisibleForTesting
static SdsClient createWithInProcChannel(
SdsSecretConfig sdsSecretConfig, Node node, Executor watcherExecutor, String name) {
ManagedChannel channel = InProcessChannelBuilder.forName(name).directExecutor().build();
return new SdsClient(sdsSecretConfig, node, watcherExecutor, channel, null);
}
/** Creates an SdsClient with {@link NettyChannelBuilder} for UDS or IP-based sockets. */
static SdsClient createWithNettyChannel(
SdsSecretConfig sdsSecretConfig,
Node node,
Executor watcherExecutor,
Executor channelExecutor) {
String targetUri = extractUdsTarget(sdsSecretConfig.getSdsConfig());
NettyChannelBuilder builder;
EventLoopGroup eventLoopGroup = null;
if (targetUri.startsWith("unix:")) {
checkState(Epoll.isAvailable(), "Epoll is not available");
eventLoopGroup = SharedResourceHolder.get(eventLoopGroupResource);
builder =
NettyChannelBuilder.forAddress(new DomainSocketAddress(targetUri.substring(5)))
.eventLoopGroup(eventLoopGroup)
.channelType(EpollDomainSocketChannel.class);
} else {
builder = NettyChannelBuilder.forTarget(targetUri);
}
builder = builder.usePlaintext();
if (channelExecutor != null) {
builder = builder.executor(channelExecutor);
}
ManagedChannel channel = builder.build();
return new SdsClient(sdsSecretConfig, node, watcherExecutor, channel, eventLoopGroup);
}
@VisibleForTesting
static String extractUdsTarget(ConfigSource configSource) {
checkNotNull(configSource, "configSource");
checkArgument(
configSource.hasApiConfigSource(), "only configSource with ApiConfigSource supported");
ApiConfigSource apiConfigSource = configSource.getApiConfigSource();
checkArgument(
ApiType.GRPC.equals(apiConfigSource.getApiType()),
"only GRPC ApiConfigSource type supported");
checkArgument(
apiConfigSource.getGrpcServicesCount() == 1,
"expecting exactly 1 GrpcService in ApiConfigSource");
GrpcService grpcService = apiConfigSource.getGrpcServices(0);
checkArgument(
grpcService.hasGoogleGrpc() && !grpcService.hasEnvoyGrpc(),
"only GoogleGrpc expected in GrpcService");
GoogleGrpc googleGrpc = grpcService.getGoogleGrpc();
// for now don't support any credentials
checkArgument(
!googleGrpc.hasChannelCredentials()
&& googleGrpc.getCallCredentialsCount() == 0
&& Strings.isNullOrEmpty(googleGrpc.getCredentialsFactoryName()),
"No credentials supported in GoogleGrpc");
String targetUri = googleGrpc.getTargetUri();
checkArgument(!Strings.isNullOrEmpty(targetUri), "targetUri in GoogleGrpc is empty!");
return targetUri;
}
}
private SdsClient(
SdsSecretConfig sdsSecretConfig,
Node node,
Executor watcherExecutor,
ManagedChannel channel,
EventLoopGroup eventLoopGroup) {
checkNotNull(sdsSecretConfig, "sdsSecretConfig");
checkNotNull(node, "node");
this.sdsSecretConfig = sdsSecretConfig;
this.clientNode = node;
this.watcherExecutor = watcherExecutor;
this.eventLoopGroup = eventLoopGroup;
checkNotNull(channel, "channel");
this.channel = channel;
}
/**
* Starts resource discovery with SDS protocol. This method should be the first one to be called
* and should be called only once.
*/
void start() {
if (requestObserver == null) {
secretDiscoveryServiceStub = SecretDiscoveryServiceGrpc.newStub(channel);
responseObserver = new ResponseObserver();
requestObserver = secretDiscoveryServiceStub.streamSecrets(responseObserver);
}
}
/** Stops resource discovery. No method in this class should be called after this point. */
void shutdown() {
if (requestObserver != null) {
requestObserver.onCompleted();
requestObserver = null;
channel.shutdownNow();
if (eventLoopGroup != null) {
eventLoopGroup = SharedResourceHolder.release(eventLoopGroupResource, eventLoopGroup);
}
}
}
/** Response observer for SdsClient. */
private final class ResponseObserver implements StreamObserver<DiscoveryResponse> {
ResponseObserver() {}
@Override
public void onNext(DiscoveryResponse discoveryResponse) {
processDiscoveryResponse(discoveryResponse);
}
@Override
public void onError(Throwable t) {
sendErrorToWatcher(t);
}
@Override
public void onCompleted() {
// TODO(sanjaypujare): add retry logic once client implementation is final
}
}
private void processDiscoveryResponse(final DiscoveryResponse response) {
watcherExecutor.execute(
new Runnable() {
@Override
public void run() {
if (!processSecretsFromDiscoveryResponse(response)) {
sendNack(Code.INTERNAL_VALUE, "Secret not updated");
return;
}
lastResponse = response;
// send discovery request as ACK
sendDiscoveryRequestOnStream();
}
});
}
private void sendNack(int errorCode, String errorMessage) {
String nonce = "";
String versionInfo = "";
if (lastResponse != null) {
nonce = lastResponse.getNonce();
versionInfo = lastResponse.getVersionInfo();
}
DiscoveryRequest.Builder builder =
DiscoveryRequest.newBuilder()
.setTypeUrl(SECRET_TYPE_URL)
.setResponseNonce(nonce)
.setVersionInfo(versionInfo)
.addResourceNames(sdsSecretConfig.getName())
.setErrorDetail(
com.google.rpc.Status.newBuilder()
.setCode(errorCode)
.setMessage(errorMessage)
.build())
.setNode(clientNode);
DiscoveryRequest req = builder.build();
requestObserver.onNext(req);
}
private void sendErrorToWatcher(final Throwable t) {
final SecretWatcher localCopy = watcher;
if (localCopy != null) {
watcherExecutor.execute(
new Runnable() {
@Override
public void run() {
try {
localCopy.onError(Status.fromThrowable(t));
} catch (Throwable throwable) {
logger.log(Level.SEVERE, "exception from onError", throwable);
}
}
});
}
}
private boolean processSecretsFromDiscoveryResponse(DiscoveryResponse response) {
List<Any> resources = response.getResourcesList();
checkState(resources.size() == 1, "exactly one resource expected");
boolean noException = true;
for (Any any : resources) {
final String typeUrl = any.getTypeUrl();
checkState(SECRET_TYPE_URL.equals(typeUrl), "wrong value for typeUrl %s", typeUrl);
Secret secret = null;
try {
secret = Secret.parseFrom(any.getValue());
if (!processSecret(secret)) {
noException = false;
}
} catch (InvalidProtocolBufferException e) {
logger.log(Level.SEVERE, "exception from parseFrom", e);
}
}
return noException;
}
private boolean processSecret(Secret secret) {
checkState(
sdsSecretConfig.getName().equals(secret.getName()),
"expected secret name %s",
sdsSecretConfig.getName());
boolean noException = true;
final SecretWatcher localCopy = watcher;
if (localCopy != null) {
try {
localCopy.onSecretChanged(secret);
} catch (Throwable throwable) {
noException = false;
logger.log(Level.SEVERE, "exception from onSecretChanged", throwable);
}
}
return noException;
}
/** Registers a secret watcher for this client's SdsSecretConfig. */
void watchSecret(SecretWatcher secretWatcher) throws InvalidProtocolBufferException {
checkNotNull(secretWatcher, "secretWatcher");
checkState(watcher == null, "watcher already set");
watcher = secretWatcher;
if (lastResponse == null) {
sendDiscoveryRequestOnStream();
} else {
watcherExecutor.execute(
new Runnable() {
@Override
public void run() {
processSecretsFromDiscoveryResponse(lastResponse);
}
});
}
}
/** Unregisters the given endpoints watcher. */
void cancelSecretWatch(SecretWatcher secretWatcher) {
checkNotNull(secretWatcher, "secretWatcher");
checkArgument(secretWatcher == watcher, "Incorrect secretWatcher to cancel");
watcher = null;
}
/** Secret watcher interface. */
interface SecretWatcher {
void onSecretChanged(Secret secretUpdate);
void onError(Status error);
}
private static final class EventLoopGroupResource implements Resource<EventLoopGroup> {
private final String name;
EventLoopGroupResource(String name) {
this.name = name;
}
@Override
public EventLoopGroup create() {
// Use Netty's DefaultThreadFactory in order to get the benefit of FastThreadLocal.
ThreadFactory threadFactory = new DefaultThreadFactory(name, /* daemon= */ true);
return new EpollEventLoopGroup(1, threadFactory);
}
@SuppressWarnings("FutureReturnValueIgnored")
@Override
public void close(EventLoopGroup instance) {
try {
instance.shutdownGracefully(0, 0, TimeUnit.SECONDS).sync();
} catch (InterruptedException e) {
logger.log(Level.SEVERE, "from EventLoopGroup.shutdownGracefully", e);
Thread.currentThread().interrupt(); // to not "swallow" the exception...
}
}
}
private void sendDiscoveryRequestOnStream() {
String nonce = "";
String versionInfo = "";
if (lastResponse != null) {
nonce = lastResponse.getNonce();
versionInfo = lastResponse.getVersionInfo();
}
DiscoveryRequest.Builder builder =
DiscoveryRequest.newBuilder()
.setTypeUrl(SECRET_TYPE_URL)
.setResponseNonce(nonce)
.setVersionInfo(versionInfo)
.addResourceNames(sdsSecretConfig.getName())
.setNode(clientNode);
DiscoveryRequest req = builder.build();
requestObserver.onNext(req);
}
}

View File

@ -0,0 +1,335 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.envoyproxy.envoy.api.v2.DiscoveryRequest;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.core.ApiConfigSource;
import io.envoyproxy.envoy.api.v2.core.ConfigSource;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.envoyproxy.envoy.api.v2.core.GrpcService;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.Status;
import io.grpc.internal.testing.TestUtils;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link SdsClient}. */
@RunWith(JUnit4.class)
public class SdsClientTest {
private static final String SERVER_0_PEM_FILE = "server0.pem";
private static final String SERVER_0_KEY_FILE = "server0.key";
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CA_PEM_FILE = "ca.pem";
private TestSdsServer.ServerMock serverMock;
private TestSdsServer server;
private SdsClient sdsClient;
private Node node;
private SdsSecretConfig sdsSecretConfig;
private static ConfigSource buildConfigSource(String targetUri) {
return ConfigSource.newBuilder()
.setApiConfigSource(
ApiConfigSource.newBuilder()
.setApiType(ApiConfigSource.ApiType.GRPC)
.addGrpcServices(
GrpcService.newBuilder()
.setGoogleGrpc(
GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri).build())
.build())
.build())
.build();
}
private static String getResourcesFileContent(String resFile) throws IOException {
String tempFile = TestUtils.loadCert(resFile).getAbsolutePath();
return new String(Files.readAllBytes(Paths.get(tempFile)), StandardCharsets.UTF_8);
}
@Before
public void setUp() throws IOException {
serverMock = mock(TestSdsServer.ServerMock.class);
server = new TestSdsServer(serverMock);
server.startServer("inproc", false);
ConfigSource configSource = buildConfigSource("inproc");
sdsSecretConfig =
SdsSecretConfig.newBuilder().setSdsConfig(configSource).setName("name1").build();
node = Node.newBuilder().setId("sds-client-temp-test1").build();
sdsClient =
SdsClient.Factory.createWithInProcChannel(
sdsSecretConfig, node, MoreExecutors.directExecutor(), "inproc");
sdsClient.start();
}
@After
public void teardown() throws InterruptedException {
sdsClient.shutdown();
server.shutdown();
}
@Test
public void configSourceUdsTarget() {
ConfigSource configSource = buildConfigSource("unix:/tmp/uds_path");
String targetUri = SdsClient.Factory.extractUdsTarget(configSource);
assertThat(targetUri).isEqualTo("unix:/tmp/uds_path");
}
@Test
public void testSecretWatcher_tlsCertificate() throws IOException {
SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE));
sdsClient.watchSecret(mockWatcher);
verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1");
verifyDiscoveryRequest(
server.lastRequestOnlyForAck,
server.lastResponse.getVersionInfo(),
server.lastResponse.getNonce(),
node,
"name1");
verifySecretWatcher(mockWatcher, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE);
reset(mockWatcher);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneTlsCertSecret("name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
server.generateAsyncResponse("name1");
verifySecretWatcher(mockWatcher, "name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
reset(mockWatcher);
sdsClient.cancelSecretWatch(mockWatcher);
server.generateAsyncResponse("name1");
verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class));
}
@Test
public void testSecretWatcher_certificateValidationContext() throws IOException {
SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneCertificateValidationContextSecret("name1", CA_PEM_FILE));
sdsClient.watchSecret(mockWatcher);
verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1");
verifyDiscoveryRequest(
server.lastRequestOnlyForAck,
server.lastResponse.getVersionInfo(),
server.lastResponse.getNonce(),
node,
"name1");
verifySecretWatcher(mockWatcher, "name1", CA_PEM_FILE);
reset(mockWatcher);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneCertificateValidationContextSecret("name1", SERVER_1_PEM_FILE));
server.generateAsyncResponse("name1");
verifySecretWatcher(mockWatcher, "name1", SERVER_1_PEM_FILE);
reset(mockWatcher);
sdsClient.cancelSecretWatch(mockWatcher);
server.generateAsyncResponse("name1");
verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class));
}
@Test
public void testSecretWatcher_multipleWatchers_expectException() throws IOException {
SdsClient.SecretWatcher mockWatcher1 = mock(SdsClient.SecretWatcher.class);
SdsClient.SecretWatcher mockWatcher2 = mock(SdsClient.SecretWatcher.class);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE));
sdsClient.watchSecret(mockWatcher1);
verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1");
verifyDiscoveryRequest(
server.lastRequestOnlyForAck,
server.lastResponse.getVersionInfo(),
server.lastResponse.getNonce(),
node,
"name1");
verifySecretWatcher(mockWatcher1, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE);
// add mockWatcher2
try {
sdsClient.watchSecret(mockWatcher2);
fail("exception expected");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("watcher already set");
}
}
@Test
public void testSecretWatcher_onError_expectOnError() throws IOException {
SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
final ArrayList<DiscoveryRequest> requestArrayList = new ArrayList<>();
when(serverMock.onNext(any(DiscoveryRequest.class)))
.thenAnswer(
new Answer<Boolean>() {
@Override
public Boolean answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
DiscoveryRequest req = (DiscoveryRequest) args[0];
requestArrayList.add(req);
server.discoveryService.inboundStreamObserver.responseObserver.onError(
Status.NOT_FOUND.asException());
return true;
}
});
sdsClient.watchSecret(mockWatcher);
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
verify(mockWatcher, times(1)).onError(statusCaptor.capture());
Status status = statusCaptor.getValue();
assertThat(status).isEqualTo(Status.NOT_FOUND);
assertThat(requestArrayList.size()).isEqualTo(1);
}
@Test
public void testSecretWatcher_onSecretChangedException_expectNack() throws IOException {
SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
when(serverMock.getSecretFor("name1"))
.thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE));
doThrow(new RuntimeException("test exception-abc"))
.when(mockWatcher)
.onSecretChanged(any(Secret.class));
sdsClient.watchSecret(mockWatcher);
verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1");
assertThat(server.lastRequestOnlyForAck).isNull();
assertThat(server.lastNack).isNotNull();
assertThat(server.lastNack.getVersionInfo()).isEmpty();
assertThat(server.lastNack.getResponseNonce()).isEmpty();
com.google.rpc.Status errorDetail = server.lastNack.getErrorDetail();
assertThat(errorDetail.getCode()).isEqualTo(Status.Code.INTERNAL.value());
assertThat(errorDetail.getMessage()).isEqualTo("Secret not updated");
}
static void verifyDiscoveryRequest(
DiscoveryRequest request,
String versionInfo,
String responseNonce,
Node node,
String... resourceNames) {
assertThat(request).isNotNull();
assertThat(request.getNode()).isEqualTo(node);
assertThat(request.getResourceNamesList()).isEqualTo(Arrays.asList(resourceNames));
assertThat(request.getTypeUrl()).isEqualTo("type.googleapis.com/envoy.api.v2.auth.Secret");
if (versionInfo != null) {
assertThat(request.getVersionInfo()).isEqualTo(versionInfo);
}
if (responseNonce != null) {
assertThat(request.getResponseNonce()).isEqualTo(responseNonce);
}
}
static void verifySecretWatcher(
SdsClient.SecretWatcher mockWatcher,
String secretName,
String keyFileName,
String certFileName)
throws IOException {
ArgumentCaptor<Secret> secretCaptor = ArgumentCaptor.forClass(Secret.class);
verify(mockWatcher, times(1)).onSecretChanged(secretCaptor.capture());
Secret secret = secretCaptor.getValue();
assertThat(secret.getName()).isEqualTo(secretName);
assertThat(secret.hasTlsCertificate()).isTrue();
TlsCertificate tlsCertificate = secret.getTlsCertificate();
assertThat(tlsCertificate.getPrivateKey().getInlineBytes().toStringUtf8())
.isEqualTo(getResourcesFileContent(keyFileName));
assertThat(tlsCertificate.getCertificateChain().getInlineBytes().toStringUtf8())
.isEqualTo(getResourcesFileContent(certFileName));
}
private void verifySecretWatcher(
SdsClient.SecretWatcher mockWatcher, String secretName, String caFileName)
throws IOException {
ArgumentCaptor<Secret> secretCaptor = ArgumentCaptor.forClass(Secret.class);
verify(mockWatcher, times(1)).onSecretChanged(secretCaptor.capture());
Secret secret = secretCaptor.getValue();
assertThat(secret.getName()).isEqualTo(secretName);
assertThat(secret.hasValidationContext()).isTrue();
CertificateValidationContext certificateValidationContext = secret.getValidationContext();
assertThat(certificateValidationContext.getTrustedCa().getInlineBytes().toStringUtf8())
.isEqualTo(getResourcesFileContent(caFileName));
}
static Secret getOneTlsCertSecret(String name, String keyFileName, String certFileName)
throws IOException {
TlsCertificate tlsCertificate =
TlsCertificate.newBuilder()
.setPrivateKey(
DataSource.newBuilder()
.setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(keyFileName)))
.build())
.setCertificateChain(
DataSource.newBuilder()
.setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(certFileName)))
.build())
.build();
return Secret.newBuilder().setName(name).setTlsCertificate(tlsCertificate).build();
}
private Secret getOneCertificateValidationContextSecret(String name, String trustFileName)
throws IOException {
CertificateValidationContext certificateValidationContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(
DataSource.newBuilder()
.setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(trustFileName)))
.build())
.build();
return Secret.newBuilder()
.setName(name)
.setValidationContext(certificateValidationContext)
.build();
}
}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.api.v2.core.ApiConfigSource;
import io.envoyproxy.envoy.api.v2.core.ConfigSource;
import io.envoyproxy.envoy.api.v2.core.GrpcService;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.netty.channel.epoll.Epoll;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentMatchers;
/** Unit tests for {@link SdsClient} using UDS transport. */
@RunWith(JUnit4.class)
public class SdsClientUdsTest {
private static final String SERVER_0_PEM_FILE = "server0.pem";
private static final String SERVER_0_KEY_FILE = "server0.key";
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String SDSCLIENT_TEST_SOCKET = "/tmp/sdsclient-test.socket";
private TestSdsServer.ServerMock serverMock;
private TestSdsServer server;
private SdsClient sdsClient;
private Node node;
private SdsSecretConfig sdsSecretConfig;
private static ConfigSource buildConfigSource(String targetUri) {
return ConfigSource.newBuilder()
.setApiConfigSource(
ApiConfigSource.newBuilder()
.setApiType(ApiConfigSource.ApiType.GRPC)
.addGrpcServices(
GrpcService.newBuilder()
.setGoogleGrpc(
GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri).build())
.build())
.build())
.build();
}
@Before
public void setUp() throws IOException {
Assume.assumeTrue(Epoll.isAvailable());
serverMock = mock(TestSdsServer.ServerMock.class);
server = new TestSdsServer(serverMock);
server.startServer(SDSCLIENT_TEST_SOCKET, true);
ConfigSource configSource = buildConfigSource("unix:" + SDSCLIENT_TEST_SOCKET);
sdsSecretConfig =
SdsSecretConfig.newBuilder().setSdsConfig(configSource).setName("name1").build();
node = Node.newBuilder().setId("sds-client-temp-test2").build();
sdsClient =
SdsClient.Factory.createWithNettyChannel(
sdsSecretConfig, node, MoreExecutors.directExecutor(), MoreExecutors.directExecutor());
sdsClient.start();
}
@After
public void teardown() throws InterruptedException {
if (sdsClient != null) {
sdsClient.shutdown();
}
if (server != null) {
server.shutdown();
}
}
@Test
public void testSecretWatcher_tlsCertificate() throws IOException, InterruptedException {
final SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
when(serverMock.getSecretFor("name1"))
.thenReturn(
SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE));
sdsClient.watchSecret(mockWatcher);
// wait until our server received the requests
assertThat(server.requestsCounter.tryAcquire(2, 1000, TimeUnit.MILLISECONDS)).isTrue();
SdsClientTest.verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1");
SdsClientTest.verifySecretWatcher(mockWatcher, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE);
SdsClientTest.verifyDiscoveryRequest(
server.lastRequestOnlyForAck,
server.lastResponse.getVersionInfo(),
server.lastResponse.getNonce(),
node,
"name1");
reset(mockWatcher);
when(serverMock.getSecretFor("name1"))
.thenReturn(
SdsClientTest.getOneTlsCertSecret("name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
server.generateAsyncResponse("name1");
// wait until our server received the request
assertThat(server.requestsCounter.tryAcquire(1, 1000, TimeUnit.MILLISECONDS)).isTrue();
SdsClientTest.verifySecretWatcher(mockWatcher, "name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
reset(mockWatcher);
sdsClient.cancelSecretWatch(mockWatcher);
server.generateAsyncResponse("name1");
// wait until our server received the request
assertThat(server.requestsCounter.tryAcquire(1, 1000, TimeUnit.MILLISECONDS)).isTrue();
verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class));
}
}

View File

@ -0,0 +1,287 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.ProtocolStringList;
import io.envoyproxy.envoy.api.v2.DiscoveryRequest;
import io.envoyproxy.envoy.api.v2.DiscoveryResponse;
import io.envoyproxy.envoy.api.v2.auth.Secret;
import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc;
import io.grpc.Server;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerDomainSocketChannel;
import io.netty.channel.unix.DomainSocketAddress;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/** Server used only in testing {@link SdsClient} so does not contain actual server logic. */
final class TestSdsServer {
private static final Logger logger = Logger.getLogger(TestSdsServer.class.getName());
// SecretTypeURL defines the type URL for Envoy secret proto.
private static final String SECRET_TYPE_URL = "type.googleapis.com/envoy.api.v2.auth.Secret";
private String currentVersion;
private String lastRespondedNonce;
private List<String> lastResourceNames;
@VisibleForTesting SecretDiscoveryServiceImpl discoveryService;
/** Used for signalling test clients that request received and processed. */
@VisibleForTesting final Semaphore requestsCounter = new Semaphore(0);
private EventLoopGroup elg;
private EventLoopGroup boss;
private Server server;
private final ServerMock serverMock;
/** last "good" discovery request we processed and sent a response to. */
@VisibleForTesting DiscoveryRequest lastGoodRequest;
/** last discovery request that was only used as Ack since it contained no new resources. */
@VisibleForTesting DiscoveryRequest lastRequestOnlyForAck;
/** last Nack. */
@VisibleForTesting DiscoveryRequest lastNack;
/** last response we sent. */
@VisibleForTesting DiscoveryResponse lastResponse;
TestSdsServer(ServerMock serverMock) {
checkNotNull(serverMock, "serverMock");
this.serverMock = serverMock;
}
/**
* Starts the server with given transport params.
*
* @param name UDS pathname or server name for {@link InProcessServerBuilder}
* @param useUds creates a UDS based server if true.
*/
void startServer(String name, boolean useUds) throws IOException {
checkNotNull(name, "name");
discoveryService = new SecretDiscoveryServiceImpl();
if (useUds) {
elg = new EpollEventLoopGroup();
boss = new EpollEventLoopGroup(1);
server =
NettyServerBuilder.forAddress(new DomainSocketAddress(name))
.bossEventLoopGroup(boss)
.workerEventLoopGroup(elg)
.channelType(EpollServerDomainSocketChannel.class)
.addService(discoveryService)
.directExecutor()
.build()
.start();
} else {
server =
InProcessServerBuilder.forName(name)
.addService(discoveryService)
.directExecutor()
.build()
.start();
}
}
@SuppressWarnings("FutureReturnValueIgnored")
void shutdown() throws InterruptedException {
server.shutdown();
if (boss != null) {
boss.shutdownGracefully().sync();
}
if (elg != null) {
elg.shutdownGracefully().sync();
}
}
/** Interface that allows mocking server behavior. */
interface ServerMock {
Secret getSecretFor(String name);
boolean onNext(DiscoveryRequest discoveryRequest);
}
/** Callers can call this to return an "async" response for given resource names. */
void generateAsyncResponse(String... names) {
discoveryService.inboundStreamObserver.generateAsyncResponse(Arrays.asList(names));
}
/** Main streaming service implementation. */
final class SecretDiscoveryServiceImpl
extends SecretDiscoveryServiceGrpc.SecretDiscoveryServiceImplBase {
// we use startTime for generating version string.
final long startTime = System.nanoTime();
SdsInboundStreamObserver inboundStreamObserver;
@Override
public StreamObserver<DiscoveryRequest> streamSecrets(
StreamObserver<DiscoveryResponse> responseObserver) {
checkNotNull(responseObserver, "responseObserver");
inboundStreamObserver = new SdsInboundStreamObserver(responseObserver);
return inboundStreamObserver;
}
@Override
public void fetchSecrets(
DiscoveryRequest discoveryRequest, StreamObserver<DiscoveryResponse> responseObserver) {
throw new UnsupportedOperationException("unary fetchSecrets not implemented!");
}
private DiscoveryResponse buildResponse(DiscoveryRequest discoveryRequest) {
checkNotNull(discoveryRequest, "discoveryRequest");
String requestVersion = discoveryRequest.getVersionInfo();
String requestNonce = discoveryRequest.getResponseNonce();
ProtocolStringList resourceNames = discoveryRequest.getResourceNamesList();
return buildResponse(requestVersion, requestNonce, resourceNames, false, discoveryRequest);
}
private DiscoveryResponse buildResponse(
String requestVersion,
String requestNonce,
List<String> resourceNames,
boolean forcedAsync,
DiscoveryRequest discoveryRequest) {
checkNotNull(resourceNames, "resourceNames");
if (discoveryRequest != null && discoveryRequest.hasErrorDetail()) {
lastNack = discoveryRequest;
return null;
}
// for stale version or nonce don't send a response
if (!Strings.isNullOrEmpty(requestVersion) && !requestVersion.equals(currentVersion)) {
return null;
}
if (!Strings.isNullOrEmpty(requestNonce) && !requestNonce.equals(lastRespondedNonce)) {
return null;
}
// check if any new resources are being requested...
if (!forcedAsync && isSubset(resourceNames, lastResourceNames)) {
if (discoveryRequest != null) {
lastRequestOnlyForAck = discoveryRequest;
}
return null;
}
final String version = generateVersionFromCurrentTime();
DiscoveryResponse.Builder responseBuilder =
DiscoveryResponse.newBuilder()
.setVersionInfo(version)
.setNonce(generateAndSaveNonce())
.setTypeUrl(SECRET_TYPE_URL);
for (String resourceName : resourceNames) {
buildAndAddResource(responseBuilder, resourceName);
}
DiscoveryResponse response = responseBuilder.build();
currentVersion = version;
lastResponse = response;
lastResourceNames = resourceNames;
return response;
}
private String generateVersionFromCurrentTime() {
return "" + ((System.nanoTime() - startTime) / 1000000L);
}
private void buildAndAddResource(
DiscoveryResponse.Builder responseBuilder, String resourceName) {
Secret secret = serverMock.getSecretFor(resourceName);
ByteString data = secret.toByteString();
Any anyValue = Any.newBuilder().setTypeUrl(SECRET_TYPE_URL).setValue(data).build();
responseBuilder.addResources(anyValue);
}
/** Inbound {@link StreamObserver} to process incoming requests. */
final class SdsInboundStreamObserver implements StreamObserver<DiscoveryRequest> {
final StreamObserver<DiscoveryResponse> responseObserver;
ScheduledExecutorService periodicScheduler;
ScheduledFuture<?> future;
SdsInboundStreamObserver(StreamObserver<DiscoveryResponse> responseObserver) {
this.responseObserver = responseObserver;
}
private void generateAsyncResponse(List<String> nameList) {
checkNotNull(nameList, "nameList");
if (!nameList.isEmpty()) {
responseObserver.onNext(
buildResponse(
currentVersion,
lastRespondedNonce,
nameList,
/* forcedAsync= */ true,
/* discoveryRequest= */ null));
}
}
@Override
public void onNext(DiscoveryRequest discoveryRequest) {
if (!serverMock.onNext(discoveryRequest)) {
DiscoveryResponse discoveryResponse = buildResponse(discoveryRequest);
if (discoveryResponse != null) {
lastGoodRequest = discoveryRequest;
responseObserver.onNext(discoveryResponse);
}
}
requestsCounter.release();
}
@Override
public void onError(Throwable t) {
logger.log(Level.SEVERE, "onError", t);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
}
}
/** Checks if resourceNames is a "subset" of lastResourceNames. */
private static boolean isSubset(
@Nullable List<String> resourceNames, @Nullable List<String> lastResourceNames) {
if (lastResourceNames == null) {
return resourceNames == null || resourceNames.isEmpty();
}
if (resourceNames == null) {
return true; // since lastResourceNames is NOT null
}
return lastResourceNames.containsAll(resourceNames);
}
private String generateAndSaveNonce() {
lastRespondedNonce = Long.toHexString(System.currentTimeMillis());
return lastRespondedNonce;
}
}