diff --git a/xds/build.gradle b/xds/build.gradle index 7b7cc101c5..9839d1a938 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -24,7 +24,8 @@ dependencies { project(':grpc-core'), project(':grpc-netty'), project(':grpc-services'), - project(':grpc-alts') + project(':grpc-alts'), + libraries.netty_epoll compile (libraries.pgv) { // PGV depends on com.google.protobuf:protobuf-java 3.6.1 conflicting with :grpc-protobuf diff --git a/xds/src/main/java/io/grpc/xds/sds/SdsClient.java b/xds/src/main/java/io/grpc/xds/sds/SdsClient.java new file mode 100644 index 0000000000..f43486a4f1 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/sds/SdsClient.java @@ -0,0 +1,386 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.rpc.Code; +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; +import io.envoyproxy.envoy.api.v2.auth.Secret; +import io.envoyproxy.envoy.api.v2.core.ApiConfigSource; +import io.envoyproxy.envoy.api.v2.core.ApiConfigSource.ApiType; +import io.envoyproxy.envoy.api.v2.core.ConfigSource; +import io.envoyproxy.envoy.api.v2.core.GrpcService; +import io.envoyproxy.envoy.api.v2.core.GrpcService.GoogleGrpc; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc; +import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc.SecretDiscoveryServiceStub; +import io.grpc.Internal; +import io.grpc.ManagedChannel; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.StreamObserver; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * SDS client used by an {@link SslContextProvider} to get SDS updates from the SDS server. This is + * most likely a temporary implementation until merged with the XdsClient. + */ +// TODO(sanjaypujare): once XdsClientImpl is ready, merge with it and add retry logic +@Internal +@NotThreadSafe +final class SdsClient { + private static final Logger logger = Logger.getLogger(SdsClient.class.getName()); + private static final String SECRET_TYPE_URL = "type.googleapis.com/envoy.api.v2.auth.Secret"; + private static final EventLoopGroupResource eventLoopGroupResource = + Epoll.isAvailable() ? new EventLoopGroupResource("SdsClient") : null; + + private SecretWatcher watcher; + private final SdsSecretConfig sdsSecretConfig; + private final Node clientNode; + private final Executor watcherExecutor; + private EventLoopGroup eventLoopGroup; + private ManagedChannel channel; + private SecretDiscoveryServiceStub secretDiscoveryServiceStub; + private ResponseObserver responseObserver; + private StreamObserver requestObserver; + private DiscoveryResponse lastResponse; + + /** Factory for creating SdsClient based on input params and for unit tests. */ + static class Factory { + + /** Creates an SdsClient with {@link InProcessChannelBuilder}. */ + @VisibleForTesting + static SdsClient createWithInProcChannel( + SdsSecretConfig sdsSecretConfig, Node node, Executor watcherExecutor, String name) { + ManagedChannel channel = InProcessChannelBuilder.forName(name).directExecutor().build(); + return new SdsClient(sdsSecretConfig, node, watcherExecutor, channel, null); + } + + /** Creates an SdsClient with {@link NettyChannelBuilder} for UDS or IP-based sockets. */ + static SdsClient createWithNettyChannel( + SdsSecretConfig sdsSecretConfig, + Node node, + Executor watcherExecutor, + Executor channelExecutor) { + String targetUri = extractUdsTarget(sdsSecretConfig.getSdsConfig()); + NettyChannelBuilder builder; + EventLoopGroup eventLoopGroup = null; + if (targetUri.startsWith("unix:")) { + checkState(Epoll.isAvailable(), "Epoll is not available"); + eventLoopGroup = SharedResourceHolder.get(eventLoopGroupResource); + builder = + NettyChannelBuilder.forAddress(new DomainSocketAddress(targetUri.substring(5))) + .eventLoopGroup(eventLoopGroup) + .channelType(EpollDomainSocketChannel.class); + } else { + builder = NettyChannelBuilder.forTarget(targetUri); + } + builder = builder.usePlaintext(); + if (channelExecutor != null) { + builder = builder.executor(channelExecutor); + } + ManagedChannel channel = builder.build(); + return new SdsClient(sdsSecretConfig, node, watcherExecutor, channel, eventLoopGroup); + } + + @VisibleForTesting + static String extractUdsTarget(ConfigSource configSource) { + checkNotNull(configSource, "configSource"); + checkArgument( + configSource.hasApiConfigSource(), "only configSource with ApiConfigSource supported"); + ApiConfigSource apiConfigSource = configSource.getApiConfigSource(); + checkArgument( + ApiType.GRPC.equals(apiConfigSource.getApiType()), + "only GRPC ApiConfigSource type supported"); + checkArgument( + apiConfigSource.getGrpcServicesCount() == 1, + "expecting exactly 1 GrpcService in ApiConfigSource"); + GrpcService grpcService = apiConfigSource.getGrpcServices(0); + checkArgument( + grpcService.hasGoogleGrpc() && !grpcService.hasEnvoyGrpc(), + "only GoogleGrpc expected in GrpcService"); + GoogleGrpc googleGrpc = grpcService.getGoogleGrpc(); + // for now don't support any credentials + checkArgument( + !googleGrpc.hasChannelCredentials() + && googleGrpc.getCallCredentialsCount() == 0 + && Strings.isNullOrEmpty(googleGrpc.getCredentialsFactoryName()), + "No credentials supported in GoogleGrpc"); + String targetUri = googleGrpc.getTargetUri(); + checkArgument(!Strings.isNullOrEmpty(targetUri), "targetUri in GoogleGrpc is empty!"); + return targetUri; + } + } + + private SdsClient( + SdsSecretConfig sdsSecretConfig, + Node node, + Executor watcherExecutor, + ManagedChannel channel, + EventLoopGroup eventLoopGroup) { + checkNotNull(sdsSecretConfig, "sdsSecretConfig"); + checkNotNull(node, "node"); + this.sdsSecretConfig = sdsSecretConfig; + this.clientNode = node; + this.watcherExecutor = watcherExecutor; + this.eventLoopGroup = eventLoopGroup; + checkNotNull(channel, "channel"); + this.channel = channel; + } + + /** + * Starts resource discovery with SDS protocol. This method should be the first one to be called + * and should be called only once. + */ + void start() { + if (requestObserver == null) { + secretDiscoveryServiceStub = SecretDiscoveryServiceGrpc.newStub(channel); + responseObserver = new ResponseObserver(); + requestObserver = secretDiscoveryServiceStub.streamSecrets(responseObserver); + } + } + + /** Stops resource discovery. No method in this class should be called after this point. */ + void shutdown() { + if (requestObserver != null) { + requestObserver.onCompleted(); + requestObserver = null; + channel.shutdownNow(); + if (eventLoopGroup != null) { + eventLoopGroup = SharedResourceHolder.release(eventLoopGroupResource, eventLoopGroup); + } + } + } + + /** Response observer for SdsClient. */ + private final class ResponseObserver implements StreamObserver { + ResponseObserver() {} + + @Override + public void onNext(DiscoveryResponse discoveryResponse) { + processDiscoveryResponse(discoveryResponse); + } + + @Override + public void onError(Throwable t) { + sendErrorToWatcher(t); + } + + @Override + public void onCompleted() { + // TODO(sanjaypujare): add retry logic once client implementation is final + } + } + + private void processDiscoveryResponse(final DiscoveryResponse response) { + watcherExecutor.execute( + new Runnable() { + @Override + public void run() { + if (!processSecretsFromDiscoveryResponse(response)) { + sendNack(Code.INTERNAL_VALUE, "Secret not updated"); + return; + } + lastResponse = response; + // send discovery request as ACK + sendDiscoveryRequestOnStream(); + } + }); + } + + private void sendNack(int errorCode, String errorMessage) { + String nonce = ""; + String versionInfo = ""; + + if (lastResponse != null) { + nonce = lastResponse.getNonce(); + versionInfo = lastResponse.getVersionInfo(); + } + DiscoveryRequest.Builder builder = + DiscoveryRequest.newBuilder() + .setTypeUrl(SECRET_TYPE_URL) + .setResponseNonce(nonce) + .setVersionInfo(versionInfo) + .addResourceNames(sdsSecretConfig.getName()) + .setErrorDetail( + com.google.rpc.Status.newBuilder() + .setCode(errorCode) + .setMessage(errorMessage) + .build()) + .setNode(clientNode); + + DiscoveryRequest req = builder.build(); + requestObserver.onNext(req); + } + + private void sendErrorToWatcher(final Throwable t) { + final SecretWatcher localCopy = watcher; + if (localCopy != null) { + watcherExecutor.execute( + new Runnable() { + @Override + public void run() { + try { + localCopy.onError(Status.fromThrowable(t)); + } catch (Throwable throwable) { + logger.log(Level.SEVERE, "exception from onError", throwable); + } + } + }); + } + } + + private boolean processSecretsFromDiscoveryResponse(DiscoveryResponse response) { + List resources = response.getResourcesList(); + checkState(resources.size() == 1, "exactly one resource expected"); + boolean noException = true; + for (Any any : resources) { + final String typeUrl = any.getTypeUrl(); + checkState(SECRET_TYPE_URL.equals(typeUrl), "wrong value for typeUrl %s", typeUrl); + Secret secret = null; + try { + secret = Secret.parseFrom(any.getValue()); + if (!processSecret(secret)) { + noException = false; + } + } catch (InvalidProtocolBufferException e) { + logger.log(Level.SEVERE, "exception from parseFrom", e); + } + } + return noException; + } + + private boolean processSecret(Secret secret) { + checkState( + sdsSecretConfig.getName().equals(secret.getName()), + "expected secret name %s", + sdsSecretConfig.getName()); + boolean noException = true; + final SecretWatcher localCopy = watcher; + if (localCopy != null) { + try { + localCopy.onSecretChanged(secret); + } catch (Throwable throwable) { + noException = false; + logger.log(Level.SEVERE, "exception from onSecretChanged", throwable); + } + } + return noException; + } + + /** Registers a secret watcher for this client's SdsSecretConfig. */ + void watchSecret(SecretWatcher secretWatcher) throws InvalidProtocolBufferException { + checkNotNull(secretWatcher, "secretWatcher"); + checkState(watcher == null, "watcher already set"); + watcher = secretWatcher; + if (lastResponse == null) { + sendDiscoveryRequestOnStream(); + } else { + watcherExecutor.execute( + new Runnable() { + @Override + public void run() { + processSecretsFromDiscoveryResponse(lastResponse); + } + }); + } + } + + /** Unregisters the given endpoints watcher. */ + void cancelSecretWatch(SecretWatcher secretWatcher) { + checkNotNull(secretWatcher, "secretWatcher"); + checkArgument(secretWatcher == watcher, "Incorrect secretWatcher to cancel"); + watcher = null; + } + + /** Secret watcher interface. */ + interface SecretWatcher { + void onSecretChanged(Secret secretUpdate); + + void onError(Status error); + } + + private static final class EventLoopGroupResource implements Resource { + private final String name; + + EventLoopGroupResource(String name) { + this.name = name; + } + + @Override + public EventLoopGroup create() { + // Use Netty's DefaultThreadFactory in order to get the benefit of FastThreadLocal. + ThreadFactory threadFactory = new DefaultThreadFactory(name, /* daemon= */ true); + return new EpollEventLoopGroup(1, threadFactory); + } + + @SuppressWarnings("FutureReturnValueIgnored") + @Override + public void close(EventLoopGroup instance) { + try { + instance.shutdownGracefully(0, 0, TimeUnit.SECONDS).sync(); + } catch (InterruptedException e) { + logger.log(Level.SEVERE, "from EventLoopGroup.shutdownGracefully", e); + Thread.currentThread().interrupt(); // to not "swallow" the exception... + } + } + } + + private void sendDiscoveryRequestOnStream() { + String nonce = ""; + String versionInfo = ""; + + if (lastResponse != null) { + nonce = lastResponse.getNonce(); + versionInfo = lastResponse.getVersionInfo(); + } + DiscoveryRequest.Builder builder = + DiscoveryRequest.newBuilder() + .setTypeUrl(SECRET_TYPE_URL) + .setResponseNonce(nonce) + .setVersionInfo(versionInfo) + .addResourceNames(sdsSecretConfig.getName()) + .setNode(clientNode); + + DiscoveryRequest req = builder.build(); + requestObserver.onNext(req); + } +} diff --git a/xds/src/test/java/io/grpc/xds/sds/SdsClientTest.java b/xds/src/test/java/io/grpc/xds/sds/SdsClientTest.java new file mode 100644 index 0000000000..38664a3942 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/sds/SdsClientTest.java @@ -0,0 +1,335 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.ByteString; +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; +import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; +import io.envoyproxy.envoy.api.v2.auth.Secret; +import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.core.ApiConfigSource; +import io.envoyproxy.envoy.api.v2.core.ConfigSource; +import io.envoyproxy.envoy.api.v2.core.DataSource; +import io.envoyproxy.envoy.api.v2.core.GrpcService; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.grpc.Status; +import io.grpc.internal.testing.TestUtils; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** Unit tests for {@link SdsClient}. */ +@RunWith(JUnit4.class) +public class SdsClientTest { + + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_1_PEM_FILE = "server1.pem"; + private static final String SERVER_1_KEY_FILE = "server1.key"; + private static final String CA_PEM_FILE = "ca.pem"; + + private TestSdsServer.ServerMock serverMock; + private TestSdsServer server; + private SdsClient sdsClient; + private Node node; + private SdsSecretConfig sdsSecretConfig; + + private static ConfigSource buildConfigSource(String targetUri) { + return ConfigSource.newBuilder() + .setApiConfigSource( + ApiConfigSource.newBuilder() + .setApiType(ApiConfigSource.ApiType.GRPC) + .addGrpcServices( + GrpcService.newBuilder() + .setGoogleGrpc( + GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri).build()) + .build()) + .build()) + .build(); + } + + private static String getResourcesFileContent(String resFile) throws IOException { + String tempFile = TestUtils.loadCert(resFile).getAbsolutePath(); + return new String(Files.readAllBytes(Paths.get(tempFile)), StandardCharsets.UTF_8); + } + + @Before + public void setUp() throws IOException { + serverMock = mock(TestSdsServer.ServerMock.class); + server = new TestSdsServer(serverMock); + server.startServer("inproc", false); + ConfigSource configSource = buildConfigSource("inproc"); + sdsSecretConfig = + SdsSecretConfig.newBuilder().setSdsConfig(configSource).setName("name1").build(); + node = Node.newBuilder().setId("sds-client-temp-test1").build(); + sdsClient = + SdsClient.Factory.createWithInProcChannel( + sdsSecretConfig, node, MoreExecutors.directExecutor(), "inproc"); + sdsClient.start(); + } + + @After + public void teardown() throws InterruptedException { + sdsClient.shutdown(); + server.shutdown(); + } + + @Test + public void configSourceUdsTarget() { + ConfigSource configSource = buildConfigSource("unix:/tmp/uds_path"); + String targetUri = SdsClient.Factory.extractUdsTarget(configSource); + assertThat(targetUri).isEqualTo("unix:/tmp/uds_path"); + } + + @Test + public void testSecretWatcher_tlsCertificate() throws IOException { + SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); + + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)); + + sdsClient.watchSecret(mockWatcher); + verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1"); + verifyDiscoveryRequest( + server.lastRequestOnlyForAck, + server.lastResponse.getVersionInfo(), + server.lastResponse.getNonce(), + node, + "name1"); + verifySecretWatcher(mockWatcher, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE); + + reset(mockWatcher); + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneTlsCertSecret("name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE)); + server.generateAsyncResponse("name1"); + verifySecretWatcher(mockWatcher, "name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); + + reset(mockWatcher); + sdsClient.cancelSecretWatch(mockWatcher); + server.generateAsyncResponse("name1"); + verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class)); + } + + @Test + public void testSecretWatcher_certificateValidationContext() throws IOException { + SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); + + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneCertificateValidationContextSecret("name1", CA_PEM_FILE)); + + sdsClient.watchSecret(mockWatcher); + verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1"); + verifyDiscoveryRequest( + server.lastRequestOnlyForAck, + server.lastResponse.getVersionInfo(), + server.lastResponse.getNonce(), + node, + "name1"); + verifySecretWatcher(mockWatcher, "name1", CA_PEM_FILE); + + reset(mockWatcher); + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneCertificateValidationContextSecret("name1", SERVER_1_PEM_FILE)); + server.generateAsyncResponse("name1"); + verifySecretWatcher(mockWatcher, "name1", SERVER_1_PEM_FILE); + + reset(mockWatcher); + sdsClient.cancelSecretWatch(mockWatcher); + server.generateAsyncResponse("name1"); + verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class)); + } + + @Test + public void testSecretWatcher_multipleWatchers_expectException() throws IOException { + SdsClient.SecretWatcher mockWatcher1 = mock(SdsClient.SecretWatcher.class); + SdsClient.SecretWatcher mockWatcher2 = mock(SdsClient.SecretWatcher.class); + + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)); + + sdsClient.watchSecret(mockWatcher1); + verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1"); + verifyDiscoveryRequest( + server.lastRequestOnlyForAck, + server.lastResponse.getVersionInfo(), + server.lastResponse.getNonce(), + node, + "name1"); + verifySecretWatcher(mockWatcher1, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE); + + // add mockWatcher2 + try { + sdsClient.watchSecret(mockWatcher2); + fail("exception expected"); + } catch (IllegalStateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("watcher already set"); + } + } + + @Test + public void testSecretWatcher_onError_expectOnError() throws IOException { + SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); + final ArrayList requestArrayList = new ArrayList<>(); + + when(serverMock.onNext(any(DiscoveryRequest.class))) + .thenAnswer( + new Answer() { + @Override + public Boolean answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + DiscoveryRequest req = (DiscoveryRequest) args[0]; + requestArrayList.add(req); + server.discoveryService.inboundStreamObserver.responseObserver.onError( + Status.NOT_FOUND.asException()); + return true; + } + }); + sdsClient.watchSecret(mockWatcher); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockWatcher, times(1)).onError(statusCaptor.capture()); + Status status = statusCaptor.getValue(); + assertThat(status).isEqualTo(Status.NOT_FOUND); + assertThat(requestArrayList.size()).isEqualTo(1); + } + + @Test + public void testSecretWatcher_onSecretChangedException_expectNack() throws IOException { + SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); + + when(serverMock.getSecretFor("name1")) + .thenReturn(getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)); + doThrow(new RuntimeException("test exception-abc")) + .when(mockWatcher) + .onSecretChanged(any(Secret.class)); + + sdsClient.watchSecret(mockWatcher); + verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1"); + assertThat(server.lastRequestOnlyForAck).isNull(); + assertThat(server.lastNack).isNotNull(); + assertThat(server.lastNack.getVersionInfo()).isEmpty(); + assertThat(server.lastNack.getResponseNonce()).isEmpty(); + com.google.rpc.Status errorDetail = server.lastNack.getErrorDetail(); + assertThat(errorDetail.getCode()).isEqualTo(Status.Code.INTERNAL.value()); + assertThat(errorDetail.getMessage()).isEqualTo("Secret not updated"); + } + + static void verifyDiscoveryRequest( + DiscoveryRequest request, + String versionInfo, + String responseNonce, + Node node, + String... resourceNames) { + assertThat(request).isNotNull(); + assertThat(request.getNode()).isEqualTo(node); + assertThat(request.getResourceNamesList()).isEqualTo(Arrays.asList(resourceNames)); + assertThat(request.getTypeUrl()).isEqualTo("type.googleapis.com/envoy.api.v2.auth.Secret"); + if (versionInfo != null) { + assertThat(request.getVersionInfo()).isEqualTo(versionInfo); + } + if (responseNonce != null) { + assertThat(request.getResponseNonce()).isEqualTo(responseNonce); + } + } + + static void verifySecretWatcher( + SdsClient.SecretWatcher mockWatcher, + String secretName, + String keyFileName, + String certFileName) + throws IOException { + ArgumentCaptor secretCaptor = ArgumentCaptor.forClass(Secret.class); + verify(mockWatcher, times(1)).onSecretChanged(secretCaptor.capture()); + Secret secret = secretCaptor.getValue(); + assertThat(secret.getName()).isEqualTo(secretName); + assertThat(secret.hasTlsCertificate()).isTrue(); + TlsCertificate tlsCertificate = secret.getTlsCertificate(); + assertThat(tlsCertificate.getPrivateKey().getInlineBytes().toStringUtf8()) + .isEqualTo(getResourcesFileContent(keyFileName)); + assertThat(tlsCertificate.getCertificateChain().getInlineBytes().toStringUtf8()) + .isEqualTo(getResourcesFileContent(certFileName)); + } + + private void verifySecretWatcher( + SdsClient.SecretWatcher mockWatcher, String secretName, String caFileName) + throws IOException { + ArgumentCaptor secretCaptor = ArgumentCaptor.forClass(Secret.class); + verify(mockWatcher, times(1)).onSecretChanged(secretCaptor.capture()); + Secret secret = secretCaptor.getValue(); + assertThat(secret.getName()).isEqualTo(secretName); + assertThat(secret.hasValidationContext()).isTrue(); + CertificateValidationContext certificateValidationContext = secret.getValidationContext(); + assertThat(certificateValidationContext.getTrustedCa().getInlineBytes().toStringUtf8()) + .isEqualTo(getResourcesFileContent(caFileName)); + } + + static Secret getOneTlsCertSecret(String name, String keyFileName, String certFileName) + throws IOException { + TlsCertificate tlsCertificate = + TlsCertificate.newBuilder() + .setPrivateKey( + DataSource.newBuilder() + .setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(keyFileName))) + .build()) + .setCertificateChain( + DataSource.newBuilder() + .setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(certFileName))) + .build()) + .build(); + return Secret.newBuilder().setName(name).setTlsCertificate(tlsCertificate).build(); + } + + private Secret getOneCertificateValidationContextSecret(String name, String trustFileName) + throws IOException { + CertificateValidationContext certificateValidationContext = + CertificateValidationContext.newBuilder() + .setTrustedCa( + DataSource.newBuilder() + .setInlineBytes(ByteString.copyFromUtf8(getResourcesFileContent(trustFileName))) + .build()) + .build(); + + return Secret.newBuilder() + .setName(name) + .setValidationContext(certificateValidationContext) + .build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/sds/SdsClientUdsTest.java b/xds/src/test/java/io/grpc/xds/sds/SdsClientUdsTest.java new file mode 100644 index 0000000000..e6b7d3372f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/sds/SdsClientUdsTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.MoreExecutors; +import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; +import io.envoyproxy.envoy.api.v2.auth.Secret; +import io.envoyproxy.envoy.api.v2.core.ApiConfigSource; +import io.envoyproxy.envoy.api.v2.core.ConfigSource; +import io.envoyproxy.envoy.api.v2.core.GrpcService; +import io.envoyproxy.envoy.api.v2.core.Node; +import io.netty.channel.epoll.Epoll; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatchers; + +/** Unit tests for {@link SdsClient} using UDS transport. */ +@RunWith(JUnit4.class) +public class SdsClientUdsTest { + + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_1_PEM_FILE = "server1.pem"; + private static final String SERVER_1_KEY_FILE = "server1.key"; + private static final String SDSCLIENT_TEST_SOCKET = "/tmp/sdsclient-test.socket"; + + private TestSdsServer.ServerMock serverMock; + private TestSdsServer server; + private SdsClient sdsClient; + private Node node; + private SdsSecretConfig sdsSecretConfig; + + private static ConfigSource buildConfigSource(String targetUri) { + return ConfigSource.newBuilder() + .setApiConfigSource( + ApiConfigSource.newBuilder() + .setApiType(ApiConfigSource.ApiType.GRPC) + .addGrpcServices( + GrpcService.newBuilder() + .setGoogleGrpc( + GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri).build()) + .build()) + .build()) + .build(); + } + + @Before + public void setUp() throws IOException { + Assume.assumeTrue(Epoll.isAvailable()); + serverMock = mock(TestSdsServer.ServerMock.class); + server = new TestSdsServer(serverMock); + server.startServer(SDSCLIENT_TEST_SOCKET, true); + ConfigSource configSource = buildConfigSource("unix:" + SDSCLIENT_TEST_SOCKET); + sdsSecretConfig = + SdsSecretConfig.newBuilder().setSdsConfig(configSource).setName("name1").build(); + node = Node.newBuilder().setId("sds-client-temp-test2").build(); + sdsClient = + SdsClient.Factory.createWithNettyChannel( + sdsSecretConfig, node, MoreExecutors.directExecutor(), MoreExecutors.directExecutor()); + sdsClient.start(); + } + + @After + public void teardown() throws InterruptedException { + if (sdsClient != null) { + sdsClient.shutdown(); + } + if (server != null) { + server.shutdown(); + } + } + + @Test + public void testSecretWatcher_tlsCertificate() throws IOException, InterruptedException { + final SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); + + when(serverMock.getSecretFor("name1")) + .thenReturn( + SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)); + + sdsClient.watchSecret(mockWatcher); + // wait until our server received the requests + assertThat(server.requestsCounter.tryAcquire(2, 1000, TimeUnit.MILLISECONDS)).isTrue(); + SdsClientTest.verifyDiscoveryRequest(server.lastGoodRequest, "", "", node, "name1"); + SdsClientTest.verifySecretWatcher(mockWatcher, "name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE); + SdsClientTest.verifyDiscoveryRequest( + server.lastRequestOnlyForAck, + server.lastResponse.getVersionInfo(), + server.lastResponse.getNonce(), + node, + "name1"); + + reset(mockWatcher); + when(serverMock.getSecretFor("name1")) + .thenReturn( + SdsClientTest.getOneTlsCertSecret("name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE)); + server.generateAsyncResponse("name1"); + // wait until our server received the request + assertThat(server.requestsCounter.tryAcquire(1, 1000, TimeUnit.MILLISECONDS)).isTrue(); + SdsClientTest.verifySecretWatcher(mockWatcher, "name1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); + + reset(mockWatcher); + sdsClient.cancelSecretWatch(mockWatcher); + server.generateAsyncResponse("name1"); + // wait until our server received the request + assertThat(server.requestsCounter.tryAcquire(1, 1000, TimeUnit.MILLISECONDS)).isTrue(); + verify(mockWatcher, never()).onSecretChanged(ArgumentMatchers.any(Secret.class)); + } +} diff --git a/xds/src/test/java/io/grpc/xds/sds/TestSdsServer.java b/xds/src/test/java/io/grpc/xds/sds/TestSdsServer.java new file mode 100644 index 0000000000..f62f1deea7 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/sds/TestSdsServer.java @@ -0,0 +1,287 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.ProtocolStringList; +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.api.v2.auth.Secret; +import io.envoyproxy.envoy.service.discovery.v2.SecretDiscoveryServiceGrpc; +import io.grpc.Server; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerDomainSocketChannel; +import io.netty.channel.unix.DomainSocketAddress; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.Semaphore; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** Server used only in testing {@link SdsClient} so does not contain actual server logic. */ +final class TestSdsServer { + private static final Logger logger = Logger.getLogger(TestSdsServer.class.getName()); + + // SecretTypeURL defines the type URL for Envoy secret proto. + private static final String SECRET_TYPE_URL = "type.googleapis.com/envoy.api.v2.auth.Secret"; + + private String currentVersion; + private String lastRespondedNonce; + private List lastResourceNames; + @VisibleForTesting SecretDiscoveryServiceImpl discoveryService; + + /** Used for signalling test clients that request received and processed. */ + @VisibleForTesting final Semaphore requestsCounter = new Semaphore(0); + + private EventLoopGroup elg; + private EventLoopGroup boss; + private Server server; + private final ServerMock serverMock; + + /** last "good" discovery request we processed and sent a response to. */ + @VisibleForTesting DiscoveryRequest lastGoodRequest; + + /** last discovery request that was only used as Ack since it contained no new resources. */ + @VisibleForTesting DiscoveryRequest lastRequestOnlyForAck; + + /** last Nack. */ + @VisibleForTesting DiscoveryRequest lastNack; + + /** last response we sent. */ + @VisibleForTesting DiscoveryResponse lastResponse; + + TestSdsServer(ServerMock serverMock) { + checkNotNull(serverMock, "serverMock"); + this.serverMock = serverMock; + } + + /** + * Starts the server with given transport params. + * + * @param name UDS pathname or server name for {@link InProcessServerBuilder} + * @param useUds creates a UDS based server if true. + */ + void startServer(String name, boolean useUds) throws IOException { + checkNotNull(name, "name"); + discoveryService = new SecretDiscoveryServiceImpl(); + if (useUds) { + elg = new EpollEventLoopGroup(); + boss = new EpollEventLoopGroup(1); + server = + NettyServerBuilder.forAddress(new DomainSocketAddress(name)) + .bossEventLoopGroup(boss) + .workerEventLoopGroup(elg) + .channelType(EpollServerDomainSocketChannel.class) + .addService(discoveryService) + .directExecutor() + .build() + .start(); + } else { + server = + InProcessServerBuilder.forName(name) + .addService(discoveryService) + .directExecutor() + .build() + .start(); + } + } + + @SuppressWarnings("FutureReturnValueIgnored") + void shutdown() throws InterruptedException { + server.shutdown(); + if (boss != null) { + boss.shutdownGracefully().sync(); + } + if (elg != null) { + elg.shutdownGracefully().sync(); + } + } + + /** Interface that allows mocking server behavior. */ + interface ServerMock { + Secret getSecretFor(String name); + + boolean onNext(DiscoveryRequest discoveryRequest); + } + + /** Callers can call this to return an "async" response for given resource names. */ + void generateAsyncResponse(String... names) { + discoveryService.inboundStreamObserver.generateAsyncResponse(Arrays.asList(names)); + } + + /** Main streaming service implementation. */ + final class SecretDiscoveryServiceImpl + extends SecretDiscoveryServiceGrpc.SecretDiscoveryServiceImplBase { + + // we use startTime for generating version string. + final long startTime = System.nanoTime(); + SdsInboundStreamObserver inboundStreamObserver; + + @Override + public StreamObserver streamSecrets( + StreamObserver responseObserver) { + checkNotNull(responseObserver, "responseObserver"); + inboundStreamObserver = new SdsInboundStreamObserver(responseObserver); + return inboundStreamObserver; + } + + @Override + public void fetchSecrets( + DiscoveryRequest discoveryRequest, StreamObserver responseObserver) { + throw new UnsupportedOperationException("unary fetchSecrets not implemented!"); + } + + private DiscoveryResponse buildResponse(DiscoveryRequest discoveryRequest) { + checkNotNull(discoveryRequest, "discoveryRequest"); + String requestVersion = discoveryRequest.getVersionInfo(); + String requestNonce = discoveryRequest.getResponseNonce(); + ProtocolStringList resourceNames = discoveryRequest.getResourceNamesList(); + return buildResponse(requestVersion, requestNonce, resourceNames, false, discoveryRequest); + } + + private DiscoveryResponse buildResponse( + String requestVersion, + String requestNonce, + List resourceNames, + boolean forcedAsync, + DiscoveryRequest discoveryRequest) { + checkNotNull(resourceNames, "resourceNames"); + if (discoveryRequest != null && discoveryRequest.hasErrorDetail()) { + lastNack = discoveryRequest; + return null; + } + // for stale version or nonce don't send a response + if (!Strings.isNullOrEmpty(requestVersion) && !requestVersion.equals(currentVersion)) { + return null; + } + if (!Strings.isNullOrEmpty(requestNonce) && !requestNonce.equals(lastRespondedNonce)) { + return null; + } + // check if any new resources are being requested... + if (!forcedAsync && isSubset(resourceNames, lastResourceNames)) { + if (discoveryRequest != null) { + lastRequestOnlyForAck = discoveryRequest; + } + return null; + } + + final String version = generateVersionFromCurrentTime(); + DiscoveryResponse.Builder responseBuilder = + DiscoveryResponse.newBuilder() + .setVersionInfo(version) + .setNonce(generateAndSaveNonce()) + .setTypeUrl(SECRET_TYPE_URL); + + for (String resourceName : resourceNames) { + buildAndAddResource(responseBuilder, resourceName); + } + DiscoveryResponse response = responseBuilder.build(); + currentVersion = version; + lastResponse = response; + lastResourceNames = resourceNames; + return response; + } + + private String generateVersionFromCurrentTime() { + return "" + ((System.nanoTime() - startTime) / 1000000L); + } + + private void buildAndAddResource( + DiscoveryResponse.Builder responseBuilder, String resourceName) { + Secret secret = serverMock.getSecretFor(resourceName); + ByteString data = secret.toByteString(); + Any anyValue = Any.newBuilder().setTypeUrl(SECRET_TYPE_URL).setValue(data).build(); + responseBuilder.addResources(anyValue); + } + + /** Inbound {@link StreamObserver} to process incoming requests. */ + final class SdsInboundStreamObserver implements StreamObserver { + final StreamObserver responseObserver; + ScheduledExecutorService periodicScheduler; + ScheduledFuture future; + + SdsInboundStreamObserver(StreamObserver responseObserver) { + this.responseObserver = responseObserver; + } + + private void generateAsyncResponse(List nameList) { + checkNotNull(nameList, "nameList"); + if (!nameList.isEmpty()) { + responseObserver.onNext( + buildResponse( + currentVersion, + lastRespondedNonce, + nameList, + /* forcedAsync= */ true, + /* discoveryRequest= */ null)); + } + } + + @Override + public void onNext(DiscoveryRequest discoveryRequest) { + if (!serverMock.onNext(discoveryRequest)) { + DiscoveryResponse discoveryResponse = buildResponse(discoveryRequest); + if (discoveryResponse != null) { + lastGoodRequest = discoveryRequest; + responseObserver.onNext(discoveryResponse); + } + } + requestsCounter.release(); + } + + @Override + public void onError(Throwable t) { + logger.log(Level.SEVERE, "onError", t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + } + } + + /** Checks if resourceNames is a "subset" of lastResourceNames. */ + private static boolean isSubset( + @Nullable List resourceNames, @Nullable List lastResourceNames) { + if (lastResourceNames == null) { + return resourceNames == null || resourceNames.isEmpty(); + } + if (resourceNames == null) { + return true; // since lastResourceNames is NOT null + } + return lastResourceNames.containsAll(resourceNames); + } + + private String generateAndSaveNonce() { + lastRespondedNonce = Long.toHexString(System.currentTimeMillis()); + return lastRespondedNonce; + } +}