xds: eliminate downstreamTlsContext from XdsServerBuilder (#6901)

* xds: eliminate downstreamTlsContext from XdsServerBuilder


Co-authored-by: Jihun Cho <jihuncho@google.com>
This commit is contained in:
sanjaypujare 2020-04-13 17:37:26 -07:00 committed by GitHub
parent 2cc46acc55
commit 2f07c83fed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 469 additions and 424 deletions

View File

@ -70,6 +70,7 @@ public final class XdsClientWrapperForServerSds {
@Nullable private final XdsClient xdsClient;
private final int port;
private final ScheduledExecutorService timeService;
private final XdsClient.ListenerWatcher listenerWatcher;
/**
* Factory method for creating a {@link XdsClientWrapperForServerSds}.
@ -106,15 +107,14 @@ public final class XdsClientWrapperForServerSds {
this.port = port;
this.xdsClient = xdsClient;
this.timeService = timeService;
xdsClient.watchListenerData(
port,
this.listenerWatcher =
new XdsClient.ListenerWatcher() {
@Override
public void onListenerChanged(XdsClient.ListenerUpdate update) {
logger.log(
Level.INFO,
"Setting myListener from ConfigUpdate listener :{0}",
update.getListener().toString());
"Setting myListener from ConfigUpdate listener: {0}",
update.getListener());
curListener = update.getListener();
}
@ -126,9 +126,10 @@ public final class XdsClientWrapperForServerSds {
curListener = null;
}
// TODO(sanjaypujare): Implement logic for other cases based on final design.
logger.log(Level.SEVERE, "ListenerWatcher in XdsClientWrapperForServerSds:{0}", error);
logger.log(Level.SEVERE, "ListenerWatcher in XdsClientWrapperForServerSds: {0}", error);
}
});
};
xdsClient.watchListenerData(port, listenerWatcher);
}
/**
@ -157,6 +158,11 @@ public final class XdsClientWrapperForServerSds {
return null;
}
@VisibleForTesting
XdsClient.ListenerWatcher getListenerWatcher() {
return listenerWatcher;
}
private static final class FilterChainComparator implements Comparator<FilterChain> {
private final InetSocketAddress localAddress;

View File

@ -16,7 +16,6 @@
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
@ -74,22 +73,19 @@ public final class SdsProtocolNegotiators {
/**
* Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
* Passing {@code null} for downstreamTlsContext will fall back to plaintext.
* If xDS returns no DownstreamTlsContext, it will fall back to plaintext.
*
* @param downstreamTlsContext passed in {@link XdsServerBuilder#tlsContext}.
* @param port the listening port passed to {@link XdsServerBuilder#forPort(int)}.
*/
public static ProtocolNegotiator serverProtocolNegotiator(
@Nullable DownstreamTlsContext downstreamTlsContext, int port,
SynchronizationContext syncContext) {
int port, SynchronizationContext syncContext) {
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
ServerSdsProtocolNegotiator.getXdsClientWrapperForServerSds(port, syncContext);
if (xdsClientWrapperForServerSds == null && downstreamTlsContext == null) {
if (xdsClientWrapperForServerSds == null) {
logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port);
return InternalProtocolNegotiators.serverPlaintext();
} else {
return new ServerSdsProtocolNegotiator(downstreamTlsContext, xdsClientWrapperForServerSds);
return new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
}
}
@ -267,18 +263,13 @@ public final class SdsProtocolNegotiators {
@VisibleForTesting
public static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator {
@Nullable private final DownstreamTlsContext downstreamTlsContext;
@Nullable private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
/** Constructor. */
@VisibleForTesting
public ServerSdsProtocolNegotiator(
@Nullable DownstreamTlsContext downstreamTlsContext,
@Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds) {
checkArgument(downstreamTlsContext != null || xdsClientWrapperForServerSds != null,
"both downstreamTlsContext and xdsClientWrapperForServerSds cannot be null");
this.downstreamTlsContext = downstreamTlsContext;
this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds;
public ServerSdsProtocolNegotiator(XdsClientWrapperForServerSds xdsClientWrapperForServerSds) {
this.xdsClientWrapperForServerSds =
checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds");
}
private static XdsClientWrapperForServerSds getXdsClientWrapperForServerSds(
@ -299,8 +290,7 @@ public final class SdsProtocolNegotiators {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return new HandlerPickerHandler(grpcHandler, downstreamTlsContext,
xdsClientWrapperForServerSds);
return new HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds);
}
@Override
@ -315,16 +305,13 @@ public final class SdsProtocolNegotiators {
static final class HandlerPickerHandler
extends ChannelInboundHandlerAdapter {
private final GrpcHttp2ConnectionHandler grpcHandler;
private final DownstreamTlsContext downstreamTlsContextFromBuilder;
private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
HandlerPickerHandler(
GrpcHttp2ConnectionHandler grpcHandler,
@Nullable DownstreamTlsContext downstreamTlsContext,
@Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds) {
checkNotNull(grpcHandler, "grpcHandler");
this.grpcHandler = grpcHandler;
this.downstreamTlsContextFromBuilder = downstreamTlsContext;
this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds;
}
@ -339,9 +326,6 @@ public final class SdsProtocolNegotiators {
xdsClientWrapperForServerSds == null
? null
: xdsClientWrapperForServerSds.getDownstreamTlsContext(ctx.channel());
if (isTlsContextEmpty(downstreamTlsContext)) {
downstreamTlsContext = downstreamTlsContextFromBuilder;
}
if (isTlsContextEmpty(downstreamTlsContext)) {
logger.log(Level.INFO, "Fallback to plaintext for {0}", ctx.channel().localAddress());
ctx.pipeline()

View File

@ -17,7 +17,6 @@
package io.grpc.xds.internal.sds;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
@ -51,9 +50,6 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
private final NettyServerBuilder delegate;
private final int port;
// TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS
@Nullable private DownstreamTlsContext downstreamTlsContext;
private XdsServerBuilder(NettyServerBuilder nettyDelegate, int port) {
this.delegate = nettyDelegate;
this.port = port;
@ -130,15 +126,6 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
return this;
}
/**
* Set the DownstreamTlsContext for the server. This is a temporary workaround until integration
* with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext.
*/
public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) {
this.downstreamTlsContext = downstreamTlsContext;
return this;
}
/** Creates a gRPC server builder for the given port. */
public static XdsServerBuilder forPort(int port) {
NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port));
@ -173,8 +160,7 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
}
});
InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator =
SdsProtocolNegotiators.serverProtocolNegotiator(
this.downstreamTlsContext, port, syncContext);
SdsProtocolNegotiators.serverProtocolNegotiator(port, syncContext);
return buildServer(serverProtocolNegotiator);
}

View File

@ -20,6 +20,11 @@ import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.XdsLbPolicies.EDS_POLICY_NAME;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
@ -56,7 +61,7 @@ import io.grpc.xds.XdsClient.EndpointWatcher;
import io.grpc.xds.XdsClient.RefCountedXdsClientObjectPool;
import io.grpc.xds.XdsClient.XdsClientFactory;
import io.grpc.xds.XdsLoadBalancerProvider.XdsConfig;
import io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.SslContextProvider;
import io.grpc.xds.internal.sds.TlsContextManager;
import java.net.InetSocketAddress;
@ -79,11 +84,6 @@ import org.mockito.MockitoAnnotations;
*/
@RunWith(JUnit4.class)
public class CdsLoadBalancerTest {
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String BADCLIENT_PEM_FILE = "badclient.pem";
private static final String BADCLIENT_KEY_FILE = "badclient.key";
private static final String CA_PEM_FILE = "ca.pem";
private final RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool(
new XdsClientFactory() {
@ -356,7 +356,7 @@ public class CdsLoadBalancerTest {
verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture());
UpstreamTlsContext upstreamTlsContext =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider =
@ -415,8 +415,8 @@ public class CdsLoadBalancerTest {
reset(mockTlsContextManager);
reset(helper);
UpstreamTlsContext upstreamTlsContext1 =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
BADCLIENT_KEY_FILE, BADCLIENT_PEM_FILE, CA_PEM_FILE);
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> mockSslContextProvider1 =
(SslContextProvider<UpstreamTlsContext>) mock(SslContextProvider.class);
doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource();

View File

@ -18,6 +18,7 @@ package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -139,6 +140,26 @@ public class XdsClientWrapperForServerSdsTest {
private XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
private final DownstreamTlsContext[] tlsContexts = new DownstreamTlsContext[3];
/** Creates XdsClientWrapperForServerSds: also used by other classes. */
public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(
int port, DownstreamTlsContext downstreamTlsContext) {
XdsClient mockXdsClient = mock(XdsClient.class);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
new XdsClientWrapperForServerSds(port, mockXdsClient, null);
generateListenerUpdateToWatcher(
port, downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher());
return xdsClientWrapperForServerSds;
}
static void generateListenerUpdateToWatcher(
int port, DownstreamTlsContext tlsContext, XdsClient.ListenerWatcher registeredWatcher) {
EnvoyServerProtoData.Listener listener =
XdsSdsClientServerTest.buildListener("listener1", "0.0.0.0", port, tlsContext);
XdsClient.ListenerUpdate listenerUpdate =
XdsClient.ListenerUpdate.newBuilder().setListener(listener).build();
registeredWatcher.onListenerChanged(listenerUpdate);
}
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);

View File

@ -0,0 +1,213 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.junit.Assert.fail;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Server;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators;
import io.grpc.xds.internal.sds.XdsChannelBuilder;
import io.grpc.xds.internal.sds.XdsServerBuilder;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.Arrays;
import javax.net.ssl.SSLHandshakeException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS
* modes.
*/
@RunWith(JUnit4.class)
public class XdsSdsClientServerTest {
@Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private int port;
@Before
public void setUp() throws IOException {
port = findFreePort();
}
@Test
public void plaintextClientServer() throws IOException {
Server unused = buildServerWithTlsContext(/* downstreamTlsContext= */ null);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null);
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}
/** TLS channel - no mTLS. */
@Test
public void tlsClientServer_noClientAuthentication() throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, null);
Server unused = buildServerWithTlsContext(downstreamTlsContext);
// for TLS, client only needs trustCa
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr");
assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy");
}
/** mTLS - client auth enabled. */
@Test
public void mtlsClientServer_withClientAuthentication() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
}
/** mTLS - client auth enabled then update server certs to untrusted. */
@Test
public void mtlsClientServer_changeServerContext_expectException() throws IOException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher listenerWatcher =
performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE);
XdsClientWrapperForServerSdsTest.generateListenerUpdateToWatcher(
port, downstreamTlsContext, listenerWatcher);
try {
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
assertThat(sre).hasCauseThat().hasMessageThat().isEqualTo("General OpenSslEngine problem");
}
}
private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext) throws IOException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
final XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds(
port, /* downstreamTlsContext= */ downstreamTlsContext);
SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator =
new SdsProtocolNegotiators.ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
Server unused = getServer(port, serverSdsProtocolNegotiator);
XdsClient.ListenerWatcher listenerWatcher = xdsClientWrapperForServerSds.getListenerWatcher();
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
return listenerWatcher;
}
private Server buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext)
throws IOException {
final XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds(
port, /* downstreamTlsContext= */ downstreamTlsContext);
SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator =
new SdsProtocolNegotiators.ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
return getServer(port, serverSdsProtocolNegotiator);
}
private Server getServer(
int port, SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator)
throws IOException {
XdsServerBuilder builder = XdsServerBuilder.forPort(port).addService(new SimpleServiceImpl());
return cleanupRule.register(builder.buildServer(serverSdsProtocolNegotiator)).start();
}
private static int findFreePort() throws IOException {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(true);
return socket.getLocalPort();
}
}
static EnvoyServerProtoData.Listener buildListener(
String name, String address, int port, DownstreamTlsContext tlsContext) {
EnvoyServerProtoData.FilterChainMatch filterChainMatch = buildFilterChainMatch(port, address);
EnvoyServerProtoData.FilterChain filterChain1 =
new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext);
EnvoyServerProtoData.Listener listener =
new EnvoyServerProtoData.Listener(name, address, Arrays.asList(filterChain1));
return listener;
}
private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub(
UpstreamTlsContext upstreamTlsContext, String overrideAuthority) {
XdsChannelBuilder builder =
XdsChannelBuilder.forTarget("localhost:" + port).tlsContext(upstreamTlsContext);
if (overrideAuthority != null) {
builder = builder.overrideAuthority(overrideAuthority);
}
return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build()));
}
/** Say hello to server. */
private static String unaryRpc(
String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build();
SimpleResponse response = blockingStub.unaryRpc(request);
return response.getResponseMessage();
}
private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
@Override
public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) {
SimpleResponse response =
SimpleResponse.newBuilder()
.setResponseMessage("Hello " + req.getRequestMessage())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}
}
}

View File

@ -44,7 +44,7 @@ public class XdsServerBuilderTest {
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
new XdsClientWrapperForServerSds(port, mockXdsClient, null);
ServerSdsProtocolNegotiator serverSdsProtocolNegotiator =
new ServerSdsProtocolNegotiator(null, xdsClientWrapperForServerSds);
new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
Server xdsServer = builder.buildServer(serverSdsProtocolNegotiator);
xdsServer.start();
xdsServer.shutdown();

View File

@ -17,6 +17,9 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
@ -29,17 +32,13 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class ClientSslContextProviderFactoryTest {
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
ClientSslContextProviderFactory clientSslContextProviderFactory =
new ClientSslContextProviderFactory();
@Test
public void createSslContextProvider_allFilenames() {
UpstreamTlsContext upstreamTlsContext =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> sslContextProvider =

View File

@ -23,12 +23,29 @@ import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValid
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.grpc.internal.testing.TestUtils;
import java.io.IOException;
import java.util.Arrays;
import javax.annotation.Nullable;
/** Utility class for client and server ssl provider tests. */
public class CommonTlsContextTestsUtil {
public static final String SERVER_0_PEM_FILE = "server0.pem";
public static final String SERVER_0_KEY_FILE = "server0.key";
public static final String SERVER_1_PEM_FILE = "server1.pem";
public static final String SERVER_1_KEY_FILE = "server1.key";
public static final String CLIENT_PEM_FILE = "client.pem";
public static final String CLIENT_KEY_FILE = "client.key";
public static final String CA_PEM_FILE = "ca.pem";
/** Bad/untrusted server certs. */
public static final String BAD_SERVER_PEM_FILE = "badserver.pem";
public static final String BAD_SERVER_KEY_FILE = "badserver.key";
public static final String BAD_CLIENT_PEM_FILE = "badclient.pem";
public static final String BAD_CLIENT_KEY_FILE = "badclient.key";
static SdsSecretConfig buildSdsSecretConfig(String name, String targetUri, String channelType) {
SdsSecretConfig sdsSecretConfig = null;
if (!Strings.isNullOrEmpty(name) && !Strings.isNullOrEmpty(targetUri)) {
@ -144,4 +161,85 @@ public class CommonTlsContextTestsUtil {
Arrays.asList("managed-tls"),
null));
}
static String getTempFileNameForResourcesFile(String resFile) throws IOException {
return TestUtils.loadCert(resFile).getAbsolutePath();
}
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
// get temp file for each file
try {
if (certChain != null) {
certChain = getTempFileNameForResourcesFile(certChain);
}
if (privateKey != null) {
privateKey = getTempFileNameForResourcesFile(privateKey);
}
if (trustCa != null) {
trustCa = getTempFileNameForResourcesFile(trustCa);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/**
* Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
*/
public static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
@Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) {
try {
if (certChain != null) {
certChain = getTempFileNameForResourcesFile(certChain);
}
if (privateKey != null) {
privateKey = getTempFileNameForResourcesFile(privateKey);
}
if (trustCa != null) {
trustCa = getTempFileNameForResourcesFile(trustCa);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return SecretVolumeSslContextProviderTest.buildUpstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
private static CommonTlsContext buildCommonTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) {
TlsCertificate tlsCert = null;
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
}
CertificateValidationContext certContext = null;
if (!Strings.isNullOrEmpty(trustCa)) {
certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
}
return getCommonTlsContext(tlsCert, certContext);
}
static CommonTlsContext getCommonTlsContext(
TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (tlsCertificate != null) {
builder = builder.addTlsCertificates(tlsCertificate);
}
if (certContext != null) {
builder = builder.setValidationContext(certContext);
}
return builder.build();
}
}

View File

@ -17,6 +17,8 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@ -217,9 +219,7 @@ public class SdsClientFileBasedMetadataTest {
public void testSecretWatcher_tlsCertificate() throws IOException {
SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
doReturn(
SdsClientTest.getOneTlsCertSecret(
"name1", SdsClientTest.SERVER_0_KEY_FILE, SdsClientTest.SERVER_0_PEM_FILE))
doReturn(SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE))
.when(serverMock)
.getSecretFor("name1");

View File

@ -17,6 +17,11 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doThrow;
@ -65,12 +70,6 @@ import org.mockito.stubbing.Answer;
@RunWith(JUnit4.class)
public class SdsClientTest {
static final String SERVER_0_PEM_FILE = "server0.pem";
static final String SERVER_0_KEY_FILE = "server0.key";
static final String SERVER_1_PEM_FILE = "server1.pem";
static final String SERVER_1_KEY_FILE = "server1.key";
static final String CA_PEM_FILE = "ca.pem";
private TestSdsServer.ServerMock serverMock;
private TestSdsServer server;
private SdsClient sdsClient;

View File

@ -17,6 +17,8 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@ -94,9 +96,7 @@ public class SdsClientUdsFileBasedMetadataTest {
public void testSecretWatcher_tlsCertificate() throws IOException, InterruptedException {
final SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class);
doReturn(
SdsClientTest.getOneTlsCertSecret(
"name1", SdsClientTest.SERVER_0_KEY_FILE, SdsClientTest.SERVER_0_PEM_FILE))
doReturn(SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE))
.when(serverMock)
.getSecretFor("name1");

View File

@ -17,6 +17,10 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@ -45,10 +49,6 @@ import org.mockito.ArgumentMatchers;
@RunWith(JUnit4.class)
public class SdsClientUdsTest {
private static final String SERVER_0_PEM_FILE = "server0.pem";
private static final String SERVER_0_KEY_FILE = "server0.key";
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String SDSCLIENT_TEST_SOCKET = "/tmp/sdsclient-test.socket";
private TestSdsServer.ServerMock serverMock;

View File

@ -17,6 +17,11 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@ -31,6 +36,8 @@ import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.xds.XdsClientWrapperForServerSds;
import io.grpc.xds.XdsClientWrapperForServerSdsTest;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsProtocolNegotiator;
import io.netty.channel.ChannelHandler;
@ -49,6 +56,8 @@ import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Iterator;
import java.util.Map;
import org.junit.Test;
@ -59,12 +68,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class SdsProtocolNegotiatorsTest {
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
private final GrpcHttp2ConnectionHandler grpcHandler =
FakeGrpcHttp2ConnectionHandler.newHandler();
@ -153,7 +156,8 @@ public class SdsProtocolNegotiatorsTest {
@Test
public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() {
UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContext(getCommonTlsContext(null, null));
buildUpstreamTlsContext(
getCommonTlsContext(/* tlsCertificate= */ null, /* certContext= */ null));
ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext);
ChannelHandler newHandler = pn.newHandler(grpcHandler);
assertThat(newHandler).isNotNull();
@ -187,11 +191,23 @@ public class SdsProtocolNegotiatorsTest {
@Test
public void serverSdsHandler_addLast() throws IOException {
// we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test
channel =
new EmbeddedChannel() {
@Override
public SocketAddress localAddress() {
return new InetSocketAddress("172.168.1.1", 80);
}
};
pipeline = channel.pipeline();
DownstreamTlsContext downstreamTlsContext =
buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds(
80, downstreamTlsContext);
SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, downstreamTlsContext, null);
new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds);
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
@ -217,7 +233,8 @@ public class SdsProtocolNegotiatorsTest {
@Test
public void serverSdsHandler_nullTlsContext_expectPlaintext() throws IOException {
SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, null, null);
new SdsProtocolNegotiators.HandlerPickerHandler(
grpcHandler, /* xdsClientWrapperForServerSds= */ null);
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
@ -259,10 +276,9 @@ public class SdsProtocolNegotiatorsTest {
}
@Test
public void serverSdsProtocolNegotiator_passNulls_expectPlaintext() {
public void serverSdsProtocolNegotiator_nullSyncContext_expectPlaintext() {
InternalProtocolNegotiator.ProtocolNegotiator protocolNegotiator =
SdsProtocolNegotiators.serverProtocolNegotiator(null, 7000,
null);
SdsProtocolNegotiators.serverProtocolNegotiator(/* port= */ 7000, /* syncContext= */ null);
assertThat(protocolNegotiator.scheme().toString()).isEqualTo("http");
}

View File

@ -17,6 +17,11 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneCertificateValidationContextSecret;
import static io.grpc.xds.internal.sds.SdsClientTest.getOneTlsCertSecret;
import static io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest.doChecksOnSslContext;
@ -39,12 +44,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class SdsSslContextProviderTest {
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
private TestSdsServer.ServerMock serverMock;
private TestSdsServer server;
private Node node;
@ -182,8 +181,7 @@ public class SdsSslContextProviderTest {
when(serverMock.getSecretFor(/* name= */ "cert1"))
.thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
when(serverMock.getSecretFor("valid1"))
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1",
CA_PEM_FILE));
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(

View File

@ -17,16 +17,18 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.grpc.internal.testing.TestUtils;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
import java.security.cert.CertStoreException;
@ -43,12 +45,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class SecretVolumeSslContextProviderTest {
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
@Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Test
@ -277,8 +273,8 @@ public class SecretVolumeSslContextProviderTest {
TlsCertificate tlsCert = TlsCertificate.getDefaultInstance();
try {
SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil
.buildDownstreamTlsContext(getCommonTlsContext(tlsCert, /* certContext= */ null)));
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -298,8 +294,8 @@ public class SecretVolumeSslContextProviderTest {
.build();
try {
SecretVolumeSslContextProvider.getProviderForServer(
CommonTlsContextTestsUtil
.buildDownstreamTlsContext(getCommonTlsContext(tlsCert, certContext)));
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected.getMessage()).isEqualTo("filename expected");
@ -311,7 +307,9 @@ public class SecretVolumeSslContextProviderTest {
CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
try {
SecretVolumeSslContextProvider.getProviderForClient(
buildUpstreamTlsContext(getCommonTlsContext(/* tlsCertificate= */ null, certContext)));
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(
/* tlsCertificate= */ null, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("certContext is required");
@ -331,7 +329,8 @@ public class SecretVolumeSslContextProviderTest {
.build();
try {
SecretVolumeSslContextProvider.getProviderForClient(
buildUpstreamTlsContext(getCommonTlsContext(tlsCert, certContext)));
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
@ -351,38 +350,27 @@ public class SecretVolumeSslContextProviderTest {
.build();
try {
SecretVolumeSslContextProvider.getProviderForClient(
buildUpstreamTlsContext(getCommonTlsContext(tlsCert, certContext)));
buildUpstreamTlsContext(
CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext)));
Assert.fail("no exception thrown");
} catch (IllegalArgumentException expected) {
assertThat(expected).hasMessageThat().isEqualTo("filename expected");
}
}
private static String getTempFileNameForResourcesFile(String resFile) throws IOException {
return TestUtils.loadCert(resFile).getAbsolutePath();
}
/** Helper method to build SecretVolumeSslContextProvider from given files. */
private static SecretVolumeSslContextProvider<?> getSslContextSecretVolumeSecretProvider(
boolean server, String certChainFilename, String privateKeyFilename, String trustedCaFilename)
throws IOException {
boolean server,
String certChainFilename,
String privateKeyFilename,
String trustedCaFilename) {
// get temp file for each file
if (certChainFilename != null) {
certChainFilename = getTempFileNameForResourcesFile(certChainFilename);
}
if (privateKeyFilename != null) {
privateKeyFilename = getTempFileNameForResourcesFile(privateKeyFilename);
}
if (trustedCaFilename != null) {
trustedCaFilename = getTempFileNameForResourcesFile(trustedCaFilename);
}
return server
? SecretVolumeSslContextProvider.getProviderForServer(
buildDownstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename))
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename))
: SecretVolumeSslContextProvider.getProviderForClient(
buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
privateKeyFilename, certChainFilename, trustedCaFilename));
}
@ -416,56 +404,6 @@ public class SecretVolumeSslContextProviderTest {
}
}
/**
* Helper method to build DownstreamTlsContext for above tests. Called from other classes as well.
*/
static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) {
return CommonTlsContextTestsUtil.buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/**
* Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
*/
public static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) {
return buildUpstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
private static CommonTlsContext buildCommonTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) {
TlsCertificate tlsCert = null;
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
}
CertificateValidationContext certContext = null;
if (!Strings.isNullOrEmpty(trustCa)) {
certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
}
return getCommonTlsContext(tlsCert, certContext);
}
private static CommonTlsContext getCommonTlsContext(
TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (tlsCertificate != null) {
builder = builder.addTlsCertificates(tlsCertificate);
}
if (certContext != null) {
builder = builder.setValidationContext(certContext);
}
return builder.build();
}
/**
* Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
*/

View File

@ -17,6 +17,9 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
@ -29,18 +32,14 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class ServerSslContextProviderFactoryTest {
private static final String SERVER_PEM_FILE = "server1.pem";
private static final String SERVER_KEY_FILE = "server1.key";
private static final String CA_PEM_FILE = "ca.pem";
ServerSslContextProviderFactory serverSslContextProviderFactory =
new ServerSslContextProviderFactory();
@Test
public void createSslContextProvider_allFilenames() {
DownstreamTlsContext downstreamTlsContext =
SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames(
SERVER_KEY_FILE, SERVER_PEM_FILE, CA_PEM_FILE);
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> sslContextProvider =
serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext);
@ -70,7 +69,7 @@ public class ServerSslContextProviderFactoryTest {
public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext(
"name", "unix:/tmp/sds/path", SERVER_KEY_FILE, SERVER_PEM_FILE);
"name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext);

View File

@ -17,6 +17,13 @@
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
@ -40,14 +47,6 @@ import org.mockito.junit.MockitoRule;
@RunWith(JUnit4.class)
public class TlsContextManagerTest {
private static final String SERVER_0_PEM_FILE = "server0.pem";
private static final String SERVER_0_KEY_FILE = "server0.key";
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
@Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();
@Mock
@ -66,7 +65,7 @@ public class TlsContextManagerTest {
@Test
public void createServerSslContextProvider() {
DownstreamTlsContext downstreamTlsContext =
SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
@ -82,7 +81,7 @@ public class TlsContextManagerTest {
@Test
public void createClientSslContextProvider() {
UpstreamTlsContext upstreamTlsContext =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
@ -98,7 +97,7 @@ public class TlsContextManagerTest {
@Test
public void createServerSslContextProvider_differentInstance() {
DownstreamTlsContext downstreamTlsContext =
SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
@ -107,7 +106,7 @@ public class TlsContextManagerTest {
assertThat(serverSecretProvider).isNotNull();
DownstreamTlsContext downstreamTlsContext1 =
SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE);
SslContextProvider<DownstreamTlsContext> serverSecretProvider1 =
tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1);
@ -118,7 +117,7 @@ public class TlsContextManagerTest {
@Test
public void createClientSslContextProvider_differentInstance() {
UpstreamTlsContext upstreamTlsContext =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
/* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance();
@ -127,7 +126,7 @@ public class TlsContextManagerTest {
assertThat(clientSecretProvider).isNotNull();
UpstreamTlsContext upstreamTlsContext1 =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SslContextProvider<UpstreamTlsContext> clientSecretProvider1 =
@ -138,7 +137,7 @@ public class TlsContextManagerTest {
@Test
public void createServerSslContextProvider_releaseInstance() {
DownstreamTlsContext downstreamTlsContext =
SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null);
TlsContextManagerImpl tlsContextManagerImpl =
@ -158,7 +157,7 @@ public class TlsContextManagerTest {
@Test
public void createClientSslContextProvider_releaseInstance() {
UpstreamTlsContext upstreamTlsContext =
SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames(
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
TlsContextManagerImpl tlsContextManagerImpl =

View File

@ -1,193 +0,0 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import com.google.protobuf.BoolValue;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.grpc.Server;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import java.io.IOException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS
* modes.
*/
@RunWith(JUnit4.class)
public class XdsSdsClientServerTest {
@Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
@Test
public void plaintextClientServer() throws IOException {
Server server = getXdsServer(/* downstreamTlsContext= */ null);
buildClientAndTest(
/* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort());
}
/** TLS channel - no mTLS. */
@Test
public void tlsClientServer_noClientAuthentication() throws IOException {
String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
.setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
.build();
CommonTlsContext commonTlsContext =
CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build();
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(false))
.build();
Server server = getXdsServer(downstreamTlsContext);
// for TLS client doesn't need cert but needs trustCa
String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
.build();
CommonTlsContext commonTlsContext1 =
CommonTlsContext.newBuilder().setValidationContext(certContext).build();
UpstreamTlsContext upstreamTlsContext =
UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
}
/** mTLS - client auth enabled. */
@Test
public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException {
String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
.setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
.build();
CommonTlsContext commonTlsContext =
CommonTlsContext.newBuilder()
.addTlsCertificates(tlsCert)
.setValidationContext(certContext)
.build();
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(false))
.build();
Server server = getXdsServer(downstreamTlsContext);
String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath();
String clientKey = TestUtils.loadCert("client.key").getAbsolutePath();
TlsCertificate tlsCert1 =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build())
.setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build())
.build();
CommonTlsContext commonTlsContext1 =
CommonTlsContext.newBuilder()
.addTlsCertificates(tlsCert1)
.setValidationContext(certContext)
.build();
UpstreamTlsContext upstreamTlsContext =
UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
}
private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException {
XdsServerBuilder serverBuilder =
XdsServerBuilder.forPort(0) // get unused port
.addService(new SimpleServiceImpl())
.tlsContext(downstreamTlsContext);
return cleanupRule.register(serverBuilder.build()).start();
}
private void buildClientAndTest(
UpstreamTlsContext upstreamTlsContext,
String overrideAuthority,
String requestMessage,
int serverPort) {
XdsChannelBuilder builder =
XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext);
if (overrideAuthority != null) {
builder = builder.overrideAuthority(overrideAuthority);
}
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build()));
String resp = unaryRpc(requestMessage, blockingStub);
assertThat(resp).isEqualTo("Hello " + requestMessage);
}
/** Say hello to server. */
private static String unaryRpc(
String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build();
SimpleResponse response = blockingStub.unaryRpc(request);
return response.getResponseMessage();
}
private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
@Override
public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) {
SimpleResponse response =
SimpleResponse.newBuilder()
.setResponseMessage("Hello " + req.getRequestMessage())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}
}
}

View File

@ -17,6 +17,11 @@
package io.grpc.xds.internal.sds.trust;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import com.google.protobuf.ByteString;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
@ -36,21 +41,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class SdsTrustManagerFactoryTest {
/** Trust store cert. */
private static final String CA_PEM_FILE = "ca.pem";
/** server cert. */
private static final String SERVER_1_PEM_FILE = "server1.pem";
/** client cert. */
private static final String CLIENT_PEM_FILE = "client.pem";
/** bad server cert. */
private static final String BAD_SERVER_PEM_FILE = "badserver.pem";
/** bad client cert. */
private static final String BAD_CLIENT_PEM_FILE = "badclient.pem";
@Test
public void constructor_fromFile() throws CertificateException, IOException, CertStoreException {
SdsTrustManagerFactory factory =

View File

@ -17,6 +17,10 @@
package io.grpc.xds.internal.sds.trust;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.doReturn;
@ -52,18 +56,6 @@ import sun.security.validator.ValidatorException;
@RunWith(JUnit4.class)
public class SdsX509TrustManagerTest {
/** Trust store cert. */
private static final String CA_PEM_FILE = "ca.pem";
/** server1 has 4 SANs. */
private static final String SERVER_1_PEM_FILE = "server1.pem";
/** client has no SANs. */
private static final String CLIENT_PEM_FILE = "client.pem";
/** Untrusted server. */
private static final String BAD_SERVER_PEM_FILE = "badserver.pem";
@Rule
public final MockitoRule mockitoRule = MockitoJUnit.rule();