diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index 708268333f..087127d5cf 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -70,6 +70,7 @@ public final class XdsClientWrapperForServerSds { @Nullable private final XdsClient xdsClient; private final int port; private final ScheduledExecutorService timeService; + private final XdsClient.ListenerWatcher listenerWatcher; /** * Factory method for creating a {@link XdsClientWrapperForServerSds}. @@ -106,15 +107,14 @@ public final class XdsClientWrapperForServerSds { this.port = port; this.xdsClient = xdsClient; this.timeService = timeService; - xdsClient.watchListenerData( - port, + this.listenerWatcher = new XdsClient.ListenerWatcher() { @Override public void onListenerChanged(XdsClient.ListenerUpdate update) { logger.log( Level.INFO, - "Setting myListener from ConfigUpdate listener :{0}", - update.getListener().toString()); + "Setting myListener from ConfigUpdate listener: {0}", + update.getListener()); curListener = update.getListener(); } @@ -126,9 +126,10 @@ public final class XdsClientWrapperForServerSds { curListener = null; } // TODO(sanjaypujare): Implement logic for other cases based on final design. - logger.log(Level.SEVERE, "ListenerWatcher in XdsClientWrapperForServerSds:{0}", error); + logger.log(Level.SEVERE, "ListenerWatcher in XdsClientWrapperForServerSds: {0}", error); } - }); + }; + xdsClient.watchListenerData(port, listenerWatcher); } /** @@ -157,6 +158,11 @@ public final class XdsClientWrapperForServerSds { return null; } + @VisibleForTesting + XdsClient.ListenerWatcher getListenerWatcher() { + return listenerWatcher; + } + private static final class FilterChainComparator implements Comparator { private final InetSocketAddress localAddress; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 8270d2dad1..c50fec650a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -16,7 +16,6 @@ package io.grpc.xds.internal.sds; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; @@ -74,22 +73,19 @@ public final class SdsProtocolNegotiators { /** * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}. - * Passing {@code null} for downstreamTlsContext will fall back to plaintext. * If xDS returns no DownstreamTlsContext, it will fall back to plaintext. * - * @param downstreamTlsContext passed in {@link XdsServerBuilder#tlsContext}. * @param port the listening port passed to {@link XdsServerBuilder#forPort(int)}. */ public static ProtocolNegotiator serverProtocolNegotiator( - @Nullable DownstreamTlsContext downstreamTlsContext, int port, - SynchronizationContext syncContext) { + int port, SynchronizationContext syncContext) { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = ServerSdsProtocolNegotiator.getXdsClientWrapperForServerSds(port, syncContext); - if (xdsClientWrapperForServerSds == null && downstreamTlsContext == null) { + if (xdsClientWrapperForServerSds == null) { logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port); return InternalProtocolNegotiators.serverPlaintext(); } else { - return new ServerSdsProtocolNegotiator(downstreamTlsContext, xdsClientWrapperForServerSds); + return new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); } } @@ -267,18 +263,13 @@ public final class SdsProtocolNegotiators { @VisibleForTesting public static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator { - @Nullable private final DownstreamTlsContext downstreamTlsContext; - @Nullable private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; + private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; /** Constructor. */ @VisibleForTesting - public ServerSdsProtocolNegotiator( - @Nullable DownstreamTlsContext downstreamTlsContext, - @Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { - checkArgument(downstreamTlsContext != null || xdsClientWrapperForServerSds != null, - "both downstreamTlsContext and xdsClientWrapperForServerSds cannot be null"); - this.downstreamTlsContext = downstreamTlsContext; - this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds; + public ServerSdsProtocolNegotiator(XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { + this.xdsClientWrapperForServerSds = + checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds"); } private static XdsClientWrapperForServerSds getXdsClientWrapperForServerSds( @@ -299,8 +290,7 @@ public final class SdsProtocolNegotiators { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - return new HandlerPickerHandler(grpcHandler, downstreamTlsContext, - xdsClientWrapperForServerSds); + return new HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds); } @Override @@ -315,16 +305,13 @@ public final class SdsProtocolNegotiators { static final class HandlerPickerHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler grpcHandler; - private final DownstreamTlsContext downstreamTlsContextFromBuilder; private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; HandlerPickerHandler( GrpcHttp2ConnectionHandler grpcHandler, - @Nullable DownstreamTlsContext downstreamTlsContext, @Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; - this.downstreamTlsContextFromBuilder = downstreamTlsContext; this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds; } @@ -339,9 +326,6 @@ public final class SdsProtocolNegotiators { xdsClientWrapperForServerSds == null ? null : xdsClientWrapperForServerSds.getDownstreamTlsContext(ctx.channel()); - if (isTlsContextEmpty(downstreamTlsContext)) { - downstreamTlsContext = downstreamTlsContextFromBuilder; - } if (isTlsContextEmpty(downstreamTlsContext)) { logger.log(Level.INFO, "Fallback to plaintext for {0}", ctx.channel().localAddress()); ctx.pipeline() diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java index 6950f1f98c..25d802c92f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.sds; import com.google.common.annotations.VisibleForTesting; -import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.grpc.BindableService; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; @@ -51,9 +50,6 @@ public final class XdsServerBuilder extends ServerBuilder { private final NettyServerBuilder delegate; private final int port; - // TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS - @Nullable private DownstreamTlsContext downstreamTlsContext; - private XdsServerBuilder(NettyServerBuilder nettyDelegate, int port) { this.delegate = nettyDelegate; this.port = port; @@ -130,15 +126,6 @@ public final class XdsServerBuilder extends ServerBuilder { return this; } - /** - * Set the DownstreamTlsContext for the server. This is a temporary workaround until integration - * with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext. - */ - public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) { - this.downstreamTlsContext = downstreamTlsContext; - return this; - } - /** Creates a gRPC server builder for the given port. */ public static XdsServerBuilder forPort(int port) { NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port)); @@ -173,8 +160,7 @@ public final class XdsServerBuilder extends ServerBuilder { } }); InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator = - SdsProtocolNegotiators.serverProtocolNegotiator( - this.downstreamTlsContext, port, syncContext); + SdsProtocolNegotiators.serverProtocolNegotiator(port, syncContext); return buildServer(serverProtocolNegotiator); } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java index c6c2e7e342..7faa0154e0 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java @@ -20,6 +20,11 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsLbPolicies.EDS_POLICY_NAME; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -56,7 +61,7 @@ import io.grpc.xds.XdsClient.EndpointWatcher; import io.grpc.xds.XdsClient.RefCountedXdsClientObjectPool; import io.grpc.xds.XdsClient.XdsClientFactory; import io.grpc.xds.XdsLoadBalancerProvider.XdsConfig; -import io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProvider; import io.grpc.xds.internal.sds.TlsContextManager; import java.net.InetSocketAddress; @@ -79,11 +84,6 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class CdsLoadBalancerTest { - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String BADCLIENT_PEM_FILE = "badclient.pem"; - private static final String BADCLIENT_KEY_FILE = "badclient.key"; - private static final String CA_PEM_FILE = "ca.pem"; private final RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool( new XdsClientFactory() { @@ -356,7 +356,7 @@ public class CdsLoadBalancerTest { verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture()); UpstreamTlsContext upstreamTlsContext = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider mockSslContextProvider = @@ -415,8 +415,8 @@ public class CdsLoadBalancerTest { reset(mockTlsContextManager); reset(helper); UpstreamTlsContext upstreamTlsContext1 = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( - BADCLIENT_KEY_FILE, BADCLIENT_PEM_FILE, CA_PEM_FILE); + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider mockSslContextProvider1 = (SslContextProvider) mock(SslContextProvider.class); doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource(); diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java index f6d6f8304e..e0882bf94c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java @@ -18,6 +18,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -139,6 +140,26 @@ public class XdsClientWrapperForServerSdsTest { private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private final DownstreamTlsContext[] tlsContexts = new DownstreamTlsContext[3]; + /** Creates XdsClientWrapperForServerSds: also used by other classes. */ + public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds( + int port, DownstreamTlsContext downstreamTlsContext) { + XdsClient mockXdsClient = mock(XdsClient.class); + XdsClientWrapperForServerSds xdsClientWrapperForServerSds = + new XdsClientWrapperForServerSds(port, mockXdsClient, null); + generateListenerUpdateToWatcher( + port, downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher()); + return xdsClientWrapperForServerSds; + } + + static void generateListenerUpdateToWatcher( + int port, DownstreamTlsContext tlsContext, XdsClient.ListenerWatcher registeredWatcher) { + EnvoyServerProtoData.Listener listener = + XdsSdsClientServerTest.buildListener("listener1", "0.0.0.0", port, tlsContext); + XdsClient.ListenerUpdate listenerUpdate = + XdsClient.ListenerUpdate.newBuilder().setListener(listener).build(); + registeredWatcher.onListenerChanged(listenerUpdate); + } + @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java new file mode 100644 index 0000000000..8d1e09ff46 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -0,0 +1,213 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsClientWrapperForServerSdsTest.buildFilterChainMatch; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static org.junit.Assert.fail; + +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.grpc.Server; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SdsProtocolNegotiators; +import io.grpc.xds.internal.sds.XdsChannelBuilder; +import io.grpc.xds.internal.sds.XdsServerBuilder; +import java.io.IOException; +import java.net.ServerSocket; +import java.util.Arrays; +import javax.net.ssl.SSLHandshakeException; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS + * modes. + */ +@RunWith(JUnit4.class) +public class XdsSdsClientServerTest { + + @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private int port; + + @Before + public void setUp() throws IOException { + port = findFreePort(); + } + + @Test + public void plaintextClientServer() throws IOException { + Server unused = buildServerWithTlsContext(/* downstreamTlsContext= */ null); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + /** TLS channel - no mTLS. */ + @Test + public void tlsClientServer_noClientAuthentication() throws IOException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, null); + Server unused = buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only needs trustCa + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + /** mTLS - client auth enabled. */ + @Test + public void mtlsClientServer_withClientAuthentication() throws IOException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + } + + /** mTLS - client auth enabled then update server certs to untrusted. */ + @Test + public void mtlsClientServer_changeServerContext_expectException() throws IOException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + XdsClient.ListenerWatcher listenerWatcher = + performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); + XdsClientWrapperForServerSdsTest.generateListenerUpdateToWatcher( + port, downstreamTlsContext, listenerWatcher); + try { + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class); + assertThat(sre).hasCauseThat().hasMessageThat().isEqualTo("General OpenSslEngine problem"); + } + } + + private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher( + UpstreamTlsContext upstreamTlsContext) throws IOException { + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); + + final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = + XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds( + port, /* downstreamTlsContext= */ downstreamTlsContext); + SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator = + new SdsProtocolNegotiators.ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); + Server unused = getServer(port, serverSdsProtocolNegotiator); + + XdsClient.ListenerWatcher listenerWatcher = xdsClientWrapperForServerSds.getListenerWatcher(); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + return listenerWatcher; + } + + private Server buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) + throws IOException { + final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = + XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds( + port, /* downstreamTlsContext= */ downstreamTlsContext); + SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator = + new SdsProtocolNegotiators.ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); + return getServer(port, serverSdsProtocolNegotiator); + } + + private Server getServer( + int port, SdsProtocolNegotiators.ServerSdsProtocolNegotiator serverSdsProtocolNegotiator) + throws IOException { + XdsServerBuilder builder = XdsServerBuilder.forPort(port).addService(new SimpleServiceImpl()); + return cleanupRule.register(builder.buildServer(serverSdsProtocolNegotiator)).start(); + } + + private static int findFreePort() throws IOException { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } + } + + static EnvoyServerProtoData.Listener buildListener( + String name, String address, int port, DownstreamTlsContext tlsContext) { + EnvoyServerProtoData.FilterChainMatch filterChainMatch = buildFilterChainMatch(port, address); + EnvoyServerProtoData.FilterChain filterChain1 = + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + EnvoyServerProtoData.Listener listener = + new EnvoyServerProtoData.Listener(name, address, Arrays.asList(filterChain1)); + return listener; + } + + private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( + UpstreamTlsContext upstreamTlsContext, String overrideAuthority) { + XdsChannelBuilder builder = + XdsChannelBuilder.forTarget("localhost:" + port).tlsContext(upstreamTlsContext); + if (overrideAuthority != null) { + builder = builder.overrideAuthority(overrideAuthority); + } + return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build())); + } + + /** Say hello to server. */ + private static String unaryRpc( + String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { + SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); + SimpleResponse response = blockingStub.unaryRpc(request); + return response.getResponseMessage(); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + + @Override + public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder() + .setResponseMessage("Hello " + req.getRequestMessage()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 5562347a97..d2546b1d23 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -44,7 +44,7 @@ public class XdsServerBuilderTest { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(port, mockXdsClient, null); ServerSdsProtocolNegotiator serverSdsProtocolNegotiator = - new ServerSdsProtocolNegotiator(null, xdsClientWrapperForServerSds); + new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); Server xdsServer = builder.buildServer(serverSdsProtocolNegotiator); xdsServer.start(); xdsServer.shutdown(); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index a72b70931c..be86950560 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -17,6 +17,9 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; @@ -29,17 +32,13 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ClientSslContextProviderFactoryTest { - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String CA_PEM_FILE = "ca.pem"; - ClientSslContextProviderFactory clientSslContextProviderFactory = new ClientSslContextProviderFactory(); @Test public void createSslContextProvider_allFilenames() { UpstreamTlsContext upstreamTlsContext = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider sslContextProvider = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index a7dc265298..0f1bb235d3 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -23,12 +23,29 @@ import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValid import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.envoyproxy.envoy.api.v2.core.DataSource; +import io.grpc.internal.testing.TestUtils; +import java.io.IOException; import java.util.Arrays; +import javax.annotation.Nullable; /** Utility class for client and server ssl provider tests. */ public class CommonTlsContextTestsUtil { + public static final String SERVER_0_PEM_FILE = "server0.pem"; + public static final String SERVER_0_KEY_FILE = "server0.key"; + public static final String SERVER_1_PEM_FILE = "server1.pem"; + public static final String SERVER_1_KEY_FILE = "server1.key"; + public static final String CLIENT_PEM_FILE = "client.pem"; + public static final String CLIENT_KEY_FILE = "client.key"; + public static final String CA_PEM_FILE = "ca.pem"; + /** Bad/untrusted server certs. */ + public static final String BAD_SERVER_PEM_FILE = "badserver.pem"; + public static final String BAD_SERVER_KEY_FILE = "badserver.key"; + public static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; + public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; + static SdsSecretConfig buildSdsSecretConfig(String name, String targetUri, String channelType) { SdsSecretConfig sdsSecretConfig = null; if (!Strings.isNullOrEmpty(name) && !Strings.isNullOrEmpty(targetUri)) { @@ -144,4 +161,85 @@ public class CommonTlsContextTestsUtil { Arrays.asList("managed-tls"), null)); } + + static String getTempFileNameForResourcesFile(String resFile) throws IOException { + return TestUtils.loadCert(resFile).getAbsolutePath(); + } + + /** + * Helper method to build DownstreamTlsContext for above tests. Called from other classes as well. + */ + public static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( + @Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) { + // get temp file for each file + try { + if (certChain != null) { + certChain = getTempFileNameForResourcesFile(certChain); + } + if (privateKey != null) { + privateKey = getTempFileNameForResourcesFile(privateKey); + } + if (trustCa != null) { + trustCa = getTempFileNameForResourcesFile(trustCa); + } + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + return buildDownstreamTlsContext( + buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); + } + + /** + * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. + */ + public static UpstreamTlsContext buildUpstreamTlsContextFromFilenames( + @Nullable String privateKey, @Nullable String certChain, @Nullable String trustCa) { + try { + if (certChain != null) { + certChain = getTempFileNameForResourcesFile(certChain); + } + if (privateKey != null) { + privateKey = getTempFileNameForResourcesFile(privateKey); + } + if (trustCa != null) { + trustCa = getTempFileNameForResourcesFile(trustCa); + } + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + return SecretVolumeSslContextProviderTest.buildUpstreamTlsContext( + buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); + } + + private static CommonTlsContext buildCommonTlsContextFromFilenames( + String privateKey, String certChain, String trustCa) { + TlsCertificate tlsCert = null; + if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) { + tlsCert = + TlsCertificate.newBuilder() + .setCertificateChain(DataSource.newBuilder().setFilename(certChain)) + .setPrivateKey(DataSource.newBuilder().setFilename(privateKey)) + .build(); + } + CertificateValidationContext certContext = null; + if (!Strings.isNullOrEmpty(trustCa)) { + certContext = + CertificateValidationContext.newBuilder() + .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) + .build(); + } + return getCommonTlsContext(tlsCert, certContext); + } + + static CommonTlsContext getCommonTlsContext( + TlsCertificate tlsCertificate, CertificateValidationContext certContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (tlsCertificate != null) { + builder = builder.addTlsCertificates(tlsCertificate); + } + if (certContext != null) { + builder = builder.setValidationContext(certContext); + } + return builder.build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientFileBasedMetadataTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientFileBasedMetadataTest.java index fda58feaaa..6d84b9fca9 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientFileBasedMetadataTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientFileBasedMetadataTest.java @@ -17,6 +17,8 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -217,9 +219,7 @@ public class SdsClientFileBasedMetadataTest { public void testSecretWatcher_tlsCertificate() throws IOException { SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); - doReturn( - SdsClientTest.getOneTlsCertSecret( - "name1", SdsClientTest.SERVER_0_KEY_FILE, SdsClientTest.SERVER_0_PEM_FILE)) + doReturn(SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)) .when(serverMock) .getSecretFor("name1"); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientTest.java index 0927e7e286..da5c0c1bc3 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientTest.java @@ -17,6 +17,11 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doThrow; @@ -65,12 +70,6 @@ import org.mockito.stubbing.Answer; @RunWith(JUnit4.class) public class SdsClientTest { - static final String SERVER_0_PEM_FILE = "server0.pem"; - static final String SERVER_0_KEY_FILE = "server0.key"; - static final String SERVER_1_PEM_FILE = "server1.pem"; - static final String SERVER_1_KEY_FILE = "server1.key"; - static final String CA_PEM_FILE = "ca.pem"; - private TestSdsServer.ServerMock serverMock; private TestSdsServer server; private SdsClient sdsClient; diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsFileBasedMetadataTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsFileBasedMetadataTest.java index 02fde83dfc..25fe4cf6fd 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsFileBasedMetadataTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsFileBasedMetadataTest.java @@ -17,6 +17,8 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -94,9 +96,7 @@ public class SdsClientUdsFileBasedMetadataTest { public void testSecretWatcher_tlsCertificate() throws IOException, InterruptedException { final SdsClient.SecretWatcher mockWatcher = mock(SdsClient.SecretWatcher.class); - doReturn( - SdsClientTest.getOneTlsCertSecret( - "name1", SdsClientTest.SERVER_0_KEY_FILE, SdsClientTest.SERVER_0_PEM_FILE)) + doReturn(SdsClientTest.getOneTlsCertSecret("name1", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE)) .when(serverMock) .getSecretFor("name1"); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsTest.java index 80f9ac18f6..9f5431f6a0 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsClientUdsTest.java @@ -17,6 +17,10 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -45,10 +49,6 @@ import org.mockito.ArgumentMatchers; @RunWith(JUnit4.class) public class SdsClientUdsTest { - private static final String SERVER_0_PEM_FILE = "server0.pem"; - private static final String SERVER_0_KEY_FILE = "server0.key"; - private static final String SERVER_1_PEM_FILE = "server1.pem"; - private static final String SERVER_1_KEY_FILE = "server1.key"; private static final String SDSCLIENT_TEST_SOCKET = "/tmp/sdsclient-test.socket"; private TestSdsServer.ServerMock serverMock; diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index ce05235749..b7a9336ccf 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -17,6 +17,11 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -31,6 +36,8 @@ import io.grpc.internal.testing.TestUtils; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.xds.XdsClientWrapperForServerSds; +import io.grpc.xds.XdsClientWrapperForServerSdsTest; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsProtocolNegotiator; import io.netty.channel.ChannelHandler; @@ -49,6 +56,8 @@ import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.Iterator; import java.util.Map; import org.junit.Test; @@ -59,12 +68,6 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SdsProtocolNegotiatorsTest { - private static final String SERVER_1_PEM_FILE = "server1.pem"; - private static final String SERVER_1_KEY_FILE = "server1.key"; - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String CA_PEM_FILE = "ca.pem"; - private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -153,7 +156,8 @@ public class SdsProtocolNegotiatorsTest { @Test public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() { UpstreamTlsContext upstreamTlsContext = - buildUpstreamTlsContext(getCommonTlsContext(null, null)); + buildUpstreamTlsContext( + getCommonTlsContext(/* tlsCertificate= */ null, /* certContext= */ null)); ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext); ChannelHandler newHandler = pn.newHandler(grpcHandler); assertThat(newHandler).isNotNull(); @@ -187,11 +191,23 @@ public class SdsProtocolNegotiatorsTest { @Test public void serverSdsHandler_addLast() throws IOException { + // we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test + channel = + new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return new InetSocketAddress("172.168.1.1", 80); + } + }; + pipeline = channel.pipeline(); DownstreamTlsContext downstreamTlsContext = buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); + XdsClientWrapperForServerSds xdsClientWrapperForServerSds = + XdsClientWrapperForServerSdsTest.createXdsClientWrapperForServerSds( + 80, downstreamTlsContext); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, downstreamTlsContext, null); + new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -217,7 +233,8 @@ public class SdsProtocolNegotiatorsTest { @Test public void serverSdsHandler_nullTlsContext_expectPlaintext() throws IOException { SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, null, null); + new SdsProtocolNegotiators.HandlerPickerHandler( + grpcHandler, /* xdsClientWrapperForServerSds= */ null); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -259,10 +276,9 @@ public class SdsProtocolNegotiatorsTest { } @Test - public void serverSdsProtocolNegotiator_passNulls_expectPlaintext() { + public void serverSdsProtocolNegotiator_nullSyncContext_expectPlaintext() { InternalProtocolNegotiator.ProtocolNegotiator protocolNegotiator = - SdsProtocolNegotiators.serverProtocolNegotiator(null, 7000, - null); + SdsProtocolNegotiators.serverProtocolNegotiator(/* port= */ 7000, /* syncContext= */ null); assertThat(protocolNegotiator.scheme().toString()).isEqualTo("http"); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java index 1a31b6c6b7..db8e13d525 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsSslContextProviderTest.java @@ -17,6 +17,11 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static io.grpc.xds.internal.sds.SdsClientTest.getOneCertificateValidationContextSecret; import static io.grpc.xds.internal.sds.SdsClientTest.getOneTlsCertSecret; import static io.grpc.xds.internal.sds.SecretVolumeSslContextProviderTest.doChecksOnSslContext; @@ -39,12 +44,6 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SdsSslContextProviderTest { - private static final String SERVER_1_PEM_FILE = "server1.pem"; - private static final String SERVER_1_KEY_FILE = "server1.key"; - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String CA_PEM_FILE = "ca.pem"; - private TestSdsServer.ServerMock serverMock; private TestSdsServer server; private Node node; @@ -182,8 +181,7 @@ public class SdsSslContextProviderTest { when(serverMock.getSecretFor(/* name= */ "cert1")) .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE)); when(serverMock.getSecretFor("valid1")) - .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", - CA_PEM_FILE)); + .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE)); SdsSslContextProvider provider = getSdsSslContextProvider( diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java index 6477188621..6d5e5bdc83 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SecretVolumeSslContextProviderTest.java @@ -17,16 +17,18 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import com.google.common.base.Strings; import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.envoyproxy.envoy.api.v2.core.DataSource; -import io.grpc.internal.testing.TestUtils; import io.netty.handler.ssl.SslContext; import java.io.IOException; import java.security.cert.CertStoreException; @@ -43,12 +45,6 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SecretVolumeSslContextProviderTest { - private static final String SERVER_1_PEM_FILE = "server1.pem"; - private static final String SERVER_1_KEY_FILE = "server1.key"; - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String CA_PEM_FILE = "ca.pem"; - @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); @Test @@ -277,8 +273,8 @@ public class SecretVolumeSslContextProviderTest { TlsCertificate tlsCert = TlsCertificate.getDefaultInstance(); try { SecretVolumeSslContextProvider.getProviderForServer( - CommonTlsContextTestsUtil - .buildDownstreamTlsContext(getCommonTlsContext(tlsCert, /* certContext= */ null))); + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, /* certContext= */ null))); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -298,8 +294,8 @@ public class SecretVolumeSslContextProviderTest { .build(); try { SecretVolumeSslContextProvider.getProviderForServer( - CommonTlsContextTestsUtil - .buildDownstreamTlsContext(getCommonTlsContext(tlsCert, certContext))); + CommonTlsContextTestsUtil.buildDownstreamTlsContext( + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected.getMessage()).isEqualTo("filename expected"); @@ -311,7 +307,9 @@ public class SecretVolumeSslContextProviderTest { CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); try { SecretVolumeSslContextProvider.getProviderForClient( - buildUpstreamTlsContext(getCommonTlsContext(/* tlsCertificate= */ null, certContext))); + buildUpstreamTlsContext( + CommonTlsContextTestsUtil.getCommonTlsContext( + /* tlsCertificate= */ null, certContext))); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("certContext is required"); @@ -331,7 +329,8 @@ public class SecretVolumeSslContextProviderTest { .build(); try { SecretVolumeSslContextProvider.getProviderForClient( - buildUpstreamTlsContext(getCommonTlsContext(tlsCert, certContext))); + buildUpstreamTlsContext( + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); @@ -351,38 +350,27 @@ public class SecretVolumeSslContextProviderTest { .build(); try { SecretVolumeSslContextProvider.getProviderForClient( - buildUpstreamTlsContext(getCommonTlsContext(tlsCert, certContext))); + buildUpstreamTlsContext( + CommonTlsContextTestsUtil.getCommonTlsContext(tlsCert, certContext))); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().isEqualTo("filename expected"); } } - private static String getTempFileNameForResourcesFile(String resFile) throws IOException { - return TestUtils.loadCert(resFile).getAbsolutePath(); - } - /** Helper method to build SecretVolumeSslContextProvider from given files. */ private static SecretVolumeSslContextProvider getSslContextSecretVolumeSecretProvider( - boolean server, String certChainFilename, String privateKeyFilename, String trustedCaFilename) - throws IOException { + boolean server, + String certChainFilename, + String privateKeyFilename, + String trustedCaFilename) { - // get temp file for each file - if (certChainFilename != null) { - certChainFilename = getTempFileNameForResourcesFile(certChainFilename); - } - if (privateKeyFilename != null) { - privateKeyFilename = getTempFileNameForResourcesFile(privateKeyFilename); - } - if (trustedCaFilename != null) { - trustedCaFilename = getTempFileNameForResourcesFile(trustedCaFilename); - } return server ? SecretVolumeSslContextProvider.getProviderForServer( - buildDownstreamTlsContextFromFilenames( - privateKeyFilename, certChainFilename, trustedCaFilename)) + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + privateKeyFilename, certChainFilename, trustedCaFilename)) : SecretVolumeSslContextProvider.getProviderForClient( - buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( privateKeyFilename, certChainFilename, trustedCaFilename)); } @@ -416,56 +404,6 @@ public class SecretVolumeSslContextProviderTest { } } - /** - * Helper method to build DownstreamTlsContext for above tests. Called from other classes as well. - */ - static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( - String privateKey, String certChain, String trustCa) { - return CommonTlsContextTestsUtil.buildDownstreamTlsContext( - buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); - } - - /** - * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. - */ - public static UpstreamTlsContext buildUpstreamTlsContextFromFilenames( - String privateKey, String certChain, String trustCa) { - return buildUpstreamTlsContext( - buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); - } - - private static CommonTlsContext buildCommonTlsContextFromFilenames( - String privateKey, String certChain, String trustCa) { - TlsCertificate tlsCert = null; - if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) { - tlsCert = - TlsCertificate.newBuilder() - .setCertificateChain(DataSource.newBuilder().setFilename(certChain)) - .setPrivateKey(DataSource.newBuilder().setFilename(privateKey)) - .build(); - } - CertificateValidationContext certContext = null; - if (!Strings.isNullOrEmpty(trustCa)) { - certContext = - CertificateValidationContext.newBuilder() - .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) - .build(); - } - return getCommonTlsContext(tlsCert, certContext); - } - - private static CommonTlsContext getCommonTlsContext( - TlsCertificate tlsCertificate, CertificateValidationContext certContext) { - CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); - if (tlsCertificate != null) { - builder = builder.addTlsCertificates(tlsCertificate); - } - if (certContext != null) { - builder = builder.setValidationContext(certContext); - } - return builder.build(); - } - /** * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index b61c629e43..e21bc55f2c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -17,6 +17,9 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; @@ -29,18 +32,14 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ServerSslContextProviderFactoryTest { - private static final String SERVER_PEM_FILE = "server1.pem"; - private static final String SERVER_KEY_FILE = "server1.key"; - private static final String CA_PEM_FILE = "ca.pem"; - ServerSslContextProviderFactory serverSslContextProviderFactory = new ServerSslContextProviderFactory(); @Test public void createSslContextProvider_allFilenames() { DownstreamTlsContext downstreamTlsContext = - SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( - SERVER_KEY_FILE, SERVER_PEM_FILE, CA_PEM_FILE); + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( + SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); SslContextProvider sslContextProvider = serverSslContextProviderFactory.createSslContextProvider(downstreamTlsContext); @@ -70,7 +69,7 @@ public class ServerSslContextProviderFactoryTest { public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() { CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext( - "name", "unix:/tmp/sds/path", SERVER_KEY_FILE, SERVER_PEM_FILE); + "name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext(commonTlsContext); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java index 253005ed8d..0c3617f11e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java @@ -17,6 +17,13 @@ package io.grpc.xds.internal.sds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -40,14 +47,6 @@ import org.mockito.junit.MockitoRule; @RunWith(JUnit4.class) public class TlsContextManagerTest { - private static final String SERVER_0_PEM_FILE = "server0.pem"; - private static final String SERVER_0_KEY_FILE = "server0.key"; - private static final String SERVER_1_PEM_FILE = "server1.pem"; - private static final String SERVER_1_KEY_FILE = "server1.key"; - private static final String CLIENT_PEM_FILE = "client.pem"; - private static final String CLIENT_KEY_FILE = "client.key"; - private static final String CA_PEM_FILE = "ca.pem"; - @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @Mock @@ -66,7 +65,7 @@ public class TlsContextManagerTest { @Test public void createServerSslContextProvider() { DownstreamTlsContext downstreamTlsContext = - SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); @@ -82,7 +81,7 @@ public class TlsContextManagerTest { @Test public void createClientSslContextProvider() { UpstreamTlsContext upstreamTlsContext = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); @@ -98,7 +97,7 @@ public class TlsContextManagerTest { @Test public void createServerSslContextProvider_differentInstance() { DownstreamTlsContext downstreamTlsContext = - SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); @@ -107,7 +106,7 @@ public class TlsContextManagerTest { assertThat(serverSecretProvider).isNotNull(); DownstreamTlsContext downstreamTlsContext1 = - SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE); SslContextProvider serverSecretProvider1 = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1); @@ -118,7 +117,7 @@ public class TlsContextManagerTest { @Test public void createClientSslContextProvider_differentInstance() { UpstreamTlsContext upstreamTlsContext = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); @@ -127,7 +126,7 @@ public class TlsContextManagerTest { assertThat(clientSecretProvider).isNotNull(); UpstreamTlsContext upstreamTlsContext1 = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider clientSecretProvider1 = @@ -138,7 +137,7 @@ public class TlsContextManagerTest { @Test public void createServerSslContextProvider_releaseInstance() { DownstreamTlsContext downstreamTlsContext = - SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); TlsContextManagerImpl tlsContextManagerImpl = @@ -158,7 +157,7 @@ public class TlsContextManagerTest { @Test public void createClientSslContextProvider_releaseInstance() { UpstreamTlsContext upstreamTlsContext = - SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); TlsContextManagerImpl tlsContextManagerImpl = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/XdsSdsClientServerTest.java deleted file mode 100644 index f3063c4436..0000000000 --- a/xds/src/test/java/io/grpc/xds/internal/sds/XdsSdsClientServerTest.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.protobuf.BoolValue; -import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; -import io.envoyproxy.envoy.api.v2.auth.TlsCertificate; -import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; -import io.envoyproxy.envoy.api.v2.core.DataSource; -import io.grpc.Server; -import io.grpc.internal.testing.TestUtils; -import io.grpc.stub.StreamObserver; -import io.grpc.testing.GrpcCleanupRule; -import io.grpc.testing.protobuf.SimpleRequest; -import io.grpc.testing.protobuf.SimpleResponse; -import io.grpc.testing.protobuf.SimpleServiceGrpc; -import java.io.IOException; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS - * modes. - */ -@RunWith(JUnit4.class) -public class XdsSdsClientServerTest { - - @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - - @Test - public void plaintextClientServer() throws IOException { - Server server = getXdsServer(/* downstreamTlsContext= */ null); - buildClientAndTest( - /* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort()); - } - - /** TLS channel - no mTLS. */ - @Test - public void tlsClientServer_noClientAuthentication() throws IOException { - String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath(); - String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath(); - - TlsCertificate tlsCert = - TlsCertificate.newBuilder() - .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build()) - .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build()) - .build(); - - CommonTlsContext commonTlsContext = - CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build(); - - DownstreamTlsContext downstreamTlsContext = - DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setRequireClientCertificate(BoolValue.of(false)) - .build(); - - Server server = getXdsServer(downstreamTlsContext); - - // for TLS client doesn't need cert but needs trustCa - String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath(); - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder() - .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build()) - .build(); - - CommonTlsContext commonTlsContext1 = - CommonTlsContext.newBuilder().setValidationContext(certContext).build(); - - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build(); - buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort()); - } - - /** mTLS - client auth enabled. */ - @Test - public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException { - String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath(); - String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath(); - String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath(); - - TlsCertificate tlsCert = - TlsCertificate.newBuilder() - .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build()) - .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build()) - .build(); - - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder() - .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build()) - .build(); - - CommonTlsContext commonTlsContext = - CommonTlsContext.newBuilder() - .addTlsCertificates(tlsCert) - .setValidationContext(certContext) - .build(); - - DownstreamTlsContext downstreamTlsContext = - DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setRequireClientCertificate(BoolValue.of(false)) - .build(); - - Server server = getXdsServer(downstreamTlsContext); - - String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath(); - String clientKey = TestUtils.loadCert("client.key").getAbsolutePath(); - - TlsCertificate tlsCert1 = - TlsCertificate.newBuilder() - .setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build()) - .setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build()) - .build(); - - CommonTlsContext commonTlsContext1 = - CommonTlsContext.newBuilder() - .addTlsCertificates(tlsCert1) - .setValidationContext(certContext) - .build(); - - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build(); - - buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort()); - } - - private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException { - XdsServerBuilder serverBuilder = - XdsServerBuilder.forPort(0) // get unused port - .addService(new SimpleServiceImpl()) - .tlsContext(downstreamTlsContext); - return cleanupRule.register(serverBuilder.build()).start(); - } - - private void buildClientAndTest( - UpstreamTlsContext upstreamTlsContext, - String overrideAuthority, - String requestMessage, - int serverPort) { - - XdsChannelBuilder builder = - XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext); - if (overrideAuthority != null) { - builder = builder.overrideAuthority(overrideAuthority); - } - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build())); - String resp = unaryRpc(requestMessage, blockingStub); - assertThat(resp).isEqualTo("Hello " + requestMessage); - } - - /** Say hello to server. */ - private static String unaryRpc( - String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { - SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); - SimpleResponse response = blockingStub.unaryRpc(request); - return response.getResponseMessage(); - } - - private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { - - @Override - public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { - SimpleResponse response = - SimpleResponse.newBuilder() - .setResponseMessage("Hello " + req.getRequestMessage()) - .build(); - responseObserver.onNext(response); - responseObserver.onCompleted(); - } - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java index 466c90c322..72baa2df9e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java @@ -17,6 +17,11 @@ package io.grpc.xds.internal.sds.trust; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import com.google.protobuf.ByteString; import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext; @@ -36,21 +41,6 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SdsTrustManagerFactoryTest { - /** Trust store cert. */ - private static final String CA_PEM_FILE = "ca.pem"; - - /** server cert. */ - private static final String SERVER_1_PEM_FILE = "server1.pem"; - - /** client cert. */ - private static final String CLIENT_PEM_FILE = "client.pem"; - - /** bad server cert. */ - private static final String BAD_SERVER_PEM_FILE = "badserver.pem"; - - /** bad client cert. */ - private static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; - @Test public void constructor_fromFile() throws CertificateException, IOException, CertStoreException { SdsTrustManagerFactory factory = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java index 78cadd4453..66ad6f8ad6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java @@ -17,6 +17,10 @@ package io.grpc.xds.internal.sds.trust; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.doReturn; @@ -52,18 +56,6 @@ import sun.security.validator.ValidatorException; @RunWith(JUnit4.class) public class SdsX509TrustManagerTest { - /** Trust store cert. */ - private static final String CA_PEM_FILE = "ca.pem"; - - /** server1 has 4 SANs. */ - private static final String SERVER_1_PEM_FILE = "server1.pem"; - - /** client has no SANs. */ - private static final String CLIENT_PEM_FILE = "client.pem"; - - /** Untrusted server. */ - private static final String BAD_SERVER_PEM_FILE = "badserver.pem"; - @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule();