xds: Client and server proto negotiators and handlers added to SdsProtocolNegotiators (#6319)

This commit is contained in:
sanjaypujare 2019-10-24 15:27:53 -07:00 committed by GitHub
parent 30f8f26f7a
commit 48b41dce9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 774 additions and 16 deletions

View File

@ -67,6 +67,82 @@ public final class InternalProtocolNegotiators {
return new TlsNegotiator();
}
/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be
* negotiated, the server TLS {@code handler} is added and writes to the {@link
* io.netty.channel.Channel} may happen immediately, even before the TLS Handshake is complete.
*/
public static InternalProtocolNegotiator.ProtocolNegotiator serverTls(SslContext sslContext) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(sslContext);
final class ServerTlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
public AsciiString scheme() {
return negotiator.scheme();
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return negotiator.newHandler(grpcHandler);
}
@Override
public void close() {
negotiator.close();
}
}
return new ServerTlsNegotiator();
}
/** Returns a {@link ProtocolNegotiator} for plaintext client channel. */
public static InternalProtocolNegotiator.ProtocolNegotiator plaintext() {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.plaintext();
final class PlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
public AsciiString scheme() {
return negotiator.scheme();
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return negotiator.newHandler(grpcHandler);
}
@Override
public void close() {
negotiator.close();
}
}
return new PlaintextNegotiator();
}
/** Returns a {@link ProtocolNegotiator} for plaintext server channel. */
public static InternalProtocolNegotiator.ProtocolNegotiator serverPlaintext() {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverPlaintext();
final class ServerPlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
public AsciiString scheme() {
return negotiator.scheme();
}
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return negotiator.newHandler(grpcHandler);
}
@Override
public void close() {
negotiator.close();
}
}
return new ServerPlaintextNegotiator();
}
/**
* Internal version of {@link WaitUntilActiveHandler}.
*/

View File

@ -16,14 +16,17 @@
package io.grpc.xds.sds;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.xds.sds.internal.SdsProtocolNegotiators;
import java.net.SocketAddress;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
/**
* A version of {@link ManagedChannelBuilder} to create xDS managed channels that will use SDS to
@ -34,9 +37,11 @@ public final class XdsChannelBuilder extends ForwardingChannelBuilder<XdsChannel
private final NettyChannelBuilder delegate;
// TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
@Nullable private UpstreamTlsContext upstreamTlsContext;
private XdsChannelBuilder(NettyChannelBuilder delegate) {
this.delegate = delegate;
SdsProtocolNegotiators.setProtocolNegotiatorFactory(delegate);
}
/**
@ -66,6 +71,15 @@ public final class XdsChannelBuilder extends ForwardingChannelBuilder<XdsChannel
return new XdsChannelBuilder(NettyChannelBuilder.forTarget(target));
}
/**
* Set the UpstreamTlsContext for this channel. This is a temporary workaround until CDS is
* implemented in the XDS client. Passing {@code null} will fall back to plaintext.
*/
public XdsChannelBuilder tlsContext(@Nullable UpstreamTlsContext upstreamTlsContext) {
this.upstreamTlsContext = upstreamTlsContext;
return this;
}
@Override
protected ManagedChannelBuilder<?> delegate() {
return delegate;
@ -73,6 +87,8 @@ public final class XdsChannelBuilder extends ForwardingChannelBuilder<XdsChannel
@Override
public ManagedChannel build() {
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate, SdsProtocolNegotiators.clientProtocolNegotiatorFactory(upstreamTlsContext));
return delegate.build();
}
}

View File

@ -16,6 +16,7 @@
package io.grpc.xds.sds;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
@ -44,6 +45,9 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
private final NettyServerBuilder delegate;
// TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS
@Nullable private DownstreamTlsContext downstreamTlsContext;
private XdsServerBuilder(NettyServerBuilder nettyDelegate) {
this.delegate = nettyDelegate;
}
@ -119,6 +123,15 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
return this;
}
/**
* Set the DownstreamTlsContext for the server. This is a temporary workaround until integration
* with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext.
*/
public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) {
this.downstreamTlsContext = downstreamTlsContext;
return this;
}
/** Creates a gRPC server builder for the given port. */
public static XdsServerBuilder forPort(int port) {
NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port));
@ -128,7 +141,8 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
@Override
public Server build() {
// note: doing it in build() will overwrite any previously set ProtocolNegotiator
delegate.protocolNegotiator(SdsProtocolNegotiators.serverProtocolNegotiator());
delegate.protocolNegotiator(
SdsProtocolNegotiators.serverProtocolNegotiator(this.downstreamTlsContext));
return delegate.build();
}
}

View File

@ -16,15 +16,30 @@
package io.grpc.xds.sds.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.xds.sds.SecretProvider;
import io.grpc.xds.sds.TlsContextManager;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
/**
* Provides client and server side gRPC {@link ProtocolNegotiator}s that use SDS to provide the SSL
@ -35,12 +50,40 @@ public final class SdsProtocolNegotiators {
private static final AsciiString SCHEME = AsciiString.of("https");
/**
* Returns a {@link ProtocolNegotiatorFactory} to be used on {@link NettyChannelBuilder}. Passing
* {@code null} for upstreamTlsContext will fall back to plaintext.
*/
// TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory(
@Nullable UpstreamTlsContext upstreamTlsContext) {
return new ClientSdsProtocolNegotiatorFactory(upstreamTlsContext);
}
/**
* Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
* Passing {@code null} for downstreamTlsContext will fall back to plaintext.
*/
// TODO (sanjaypujare) integrate with xDS client to get LDS
public static ProtocolNegotiator serverProtocolNegotiator(
@Nullable DownstreamTlsContext downstreamTlsContext) {
return new ServerSdsProtocolNegotiator(downstreamTlsContext);
}
private static final class ClientSdsProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
// TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
private final UpstreamTlsContext upstreamTlsContext;
ClientSdsProtocolNegotiatorFactory(UpstreamTlsContext upstreamTlsContext) {
this.upstreamTlsContext = upstreamTlsContext;
}
@Override
public InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator() {
final ClientSdsProtocolNegotiator negotiator = new ClientSdsProtocolNegotiator();
final ClientSdsProtocolNegotiator negotiator =
new ClientSdsProtocolNegotiator(upstreamTlsContext);
final class LocalSdsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
@ -63,7 +106,15 @@ public final class SdsProtocolNegotiators {
}
}
private static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
@VisibleForTesting
static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
// TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
UpstreamTlsContext upstreamTlsContext;
ClientSdsProtocolNegotiator(UpstreamTlsContext upstreamTlsContext) {
this.upstreamTlsContext = upstreamTlsContext;
}
@Override
public AsciiString scheme() {
@ -72,16 +123,111 @@ public final class SdsProtocolNegotiators {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
// TODO(sanjaypujare): once implemented return ClientSdsHandler
throw new UnsupportedOperationException("Not implemented yet");
// once CDS is implemented we will retrieve upstreamTlsContext as follows:
// grpcHandler.getEagAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT);
if (isTlsContextEmpty(upstreamTlsContext)) {
return InternalProtocolNegotiators.plaintext().newHandler(grpcHandler);
}
return new ClientSdsHandler(grpcHandler, upstreamTlsContext);
}
private static boolean isTlsContextEmpty(UpstreamTlsContext upstreamTlsContext) {
return upstreamTlsContext == null || !upstreamTlsContext.hasCommonTlsContext();
}
@Override
public void close() {}
}
private static class BufferReadsHandler extends ChannelInboundHandlerAdapter {
private final List<Object> reads = new ArrayList<>();
private boolean readComplete;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
reads.add(msg);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
readComplete = true;
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
for (Object msg : reads) {
super.channelRead(ctx, msg);
}
if (readComplete) {
super.channelReadComplete(ctx);
}
}
}
@VisibleForTesting
static final class ClientSdsHandler
extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
private final GrpcHttp2ConnectionHandler grpcHandler;
private final UpstreamTlsContext upstreamTlsContext;
ClientSdsHandler(
GrpcHttp2ConnectionHandler grpcHandler, UpstreamTlsContext upstreamTlsContext) {
super(
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
// here and then manually add 'next' when we call fireProtocolNegotiationEvent()
new ChannelHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().remove(this);
}
});
checkNotNull(grpcHandler, "grpcHandler");
this.grpcHandler = grpcHandler;
this.upstreamTlsContext = upstreamTlsContext;
}
@Override
protected void handlerAdded0(final ChannelHandlerContext ctx) {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
SecretProvider<SslContext> sslContextProvider =
TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext);
sslContextProvider.addCallback(
new SecretProvider.Callback<SslContext>() {
@Override
public void updateSecret(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
// Delegate rest of handshake to TLS handler
ctx.pipeline().addAfter(ctx.name(), null, handler);
fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads);
}
@Override
public void onException(Throwable throwable) {
ctx.fireExceptionCaught(throwable);
}
},
ctx.executor());
}
}
private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator {
// TODO (sanjaypujare) integrate with xDS client to get LDS. LDS watcher will
// inject/update the downstreamTlsContext from LDS
private DownstreamTlsContext downstreamTlsContext;
ServerSdsProtocolNegotiator(DownstreamTlsContext downstreamTlsContext) {
this.downstreamTlsContext = downstreamTlsContext;
}
@Override
public AsciiString scheme() {
return SCHEME;
@ -89,22 +235,72 @@ public final class SdsProtocolNegotiators {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
// TODO(sanjaypujare): once implemented return ServerSdsHandler
throw new UnsupportedOperationException("Not implemented yet");
if (isTlsContextEmpty(downstreamTlsContext)) {
return InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHandler);
}
return new ServerSdsHandler(grpcHandler, downstreamTlsContext);
}
private static boolean isTlsContextEmpty(DownstreamTlsContext downstreamTlsContext) {
return downstreamTlsContext == null || !downstreamTlsContext.hasCommonTlsContext();
}
@Override
public void close() {}
}
/** Sets the {@link ProtocolNegotiatorFactory} on a NettyChannelBuilder. */
public static void setProtocolNegotiatorFactory(NettyChannelBuilder builder) {
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
builder, new ClientSdsProtocolNegotiatorFactory());
}
@VisibleForTesting
static final class ServerSdsHandler
extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
private final GrpcHttp2ConnectionHandler grpcHandler;
private final DownstreamTlsContext downstreamTlsContext;
/** Creates an SDS based {@link ProtocolNegotiator} for a server. */
public static ProtocolNegotiator serverProtocolNegotiator() {
return new ServerSdsProtocolNegotiator();
ServerSdsHandler(
GrpcHttp2ConnectionHandler grpcHandler, DownstreamTlsContext downstreamTlsContext) {
super(
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
// here and then manually add 'next' when we call fireProtocolNegotiationEvent()
new ChannelHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().remove(this);
}
});
checkNotNull(grpcHandler, "grpcHandler");
this.grpcHandler = grpcHandler;
this.downstreamTlsContext = downstreamTlsContext;
}
@Override
protected void handlerAdded0(final ChannelHandlerContext ctx) {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
SecretProvider<SslContext> sslContextProvider =
TlsContextManager.getInstance()
.findOrCreateServerSslContextProvider(downstreamTlsContext);
sslContextProvider.addCallback(
new SecretProvider.Callback<SslContext>() {
@Override
public void updateSecret(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler);
// Delegate rest of handshake to TLS handler
ctx.pipeline().addAfter(ctx.name(), null, handler);
fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads);
}
@Override
public void onException(Throwable throwable) {
ctx.fireExceptionCaught(throwable);
}
},
ctx.executor());
}
}
}

View File

@ -0,0 +1,193 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds;
import static com.google.common.truth.Truth.assertThat;
import com.google.protobuf.BoolValue;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.grpc.Server;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import java.io.IOException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS
* modes.
*/
@RunWith(JUnit4.class)
public class XdsSdsClientServerTest {
@Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
@Test
public void plaintextClientServer() throws IOException {
Server server = getXdsServer(/* downstreamTlsContext= */ null);
buildClientAndTest(
/* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort());
}
/** TLS channel - no mTLS. */
@Test
public void tlsClientServer_noClientAuthentication() throws IOException {
String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
.setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
.build();
CommonTlsContext commonTlsContext =
CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build();
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(false))
.build();
Server server = getXdsServer(downstreamTlsContext);
// for TLS client doesn't need cert but needs trustCa
String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
.build();
CommonTlsContext commonTlsContext1 =
CommonTlsContext.newBuilder().setValidationContext(certContext).build();
UpstreamTlsContext upstreamTlsContext =
UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
}
/** mTLS - client auth enabled. */
@Test
public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException {
String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
.setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
.build();
CommonTlsContext commonTlsContext =
CommonTlsContext.newBuilder()
.addTlsCertificates(tlsCert)
.setValidationContext(certContext)
.build();
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder()
.setCommonTlsContext(commonTlsContext)
.setRequireClientCertificate(BoolValue.of(false))
.build();
Server server = getXdsServer(downstreamTlsContext);
String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath();
String clientKey = TestUtils.loadCert("client.key").getAbsolutePath();
TlsCertificate tlsCert1 =
TlsCertificate.newBuilder()
.setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build())
.setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build())
.build();
CommonTlsContext commonTlsContext1 =
CommonTlsContext.newBuilder()
.addTlsCertificates(tlsCert1)
.setValidationContext(certContext)
.build();
UpstreamTlsContext upstreamTlsContext =
UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
}
private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException {
XdsServerBuilder serverBuilder =
XdsServerBuilder.forPort(0) // get unused port
.addService(new SimpleServiceImpl())
.tlsContext(downstreamTlsContext);
return cleanupRule.register(serverBuilder.build()).start();
}
private void buildClientAndTest(
UpstreamTlsContext upstreamTlsContext,
String overrideAuthority,
String requestMessage,
int serverPort) {
XdsChannelBuilder builder =
XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext);
if (overrideAuthority != null) {
builder = builder.overrideAuthority(overrideAuthority);
}
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build()));
String resp = unaryRpc(requestMessage, blockingStub);
assertThat(resp).isEqualTo("Hello " + requestMessage);
}
/** Say hello to server. */
private static String unaryRpc(
String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build();
SimpleResponse response = blockingStub.unaryRpc(request);
return response.getResponseMessage();
}
private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
@Override
public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) {
SimpleResponse response =
SimpleResponse.newBuilder()
.setResponseMessage("Hello " + req.getRequestMessage())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}
}
}

View File

@ -0,0 +1,263 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.sds.internal;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import com.google.common.base.Strings;
import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.DataSource;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsHandler;
import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsProtocolNegotiator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SdsProtocolNegotiators}. */
@RunWith(JUnit4.class)
public class SdsProtocolNegotiatorsTest {
private static final String SERVER_1_PEM_FILE = "server1.pem";
private static final String SERVER_1_KEY_FILE = "server1.key";
private static final String CLIENT_PEM_FILE = "client.pem";
private static final String CLIENT_KEY_FILE = "client.key";
private static final String CA_PEM_FILE = "ca.pem";
private final GrpcHttp2ConnectionHandler grpcHandler =
FakeGrpcHttp2ConnectionHandler.newHandler();
private EmbeddedChannel channel = new EmbeddedChannel();
private ChannelPipeline pipeline = channel.pipeline();
private ChannelHandlerContext channelHandlerCtx;
private static String getTempFileNameForResourcesFile(String resFile) throws IOException {
return Strings.isNullOrEmpty(resFile) ? null : TestUtils.loadCert(resFile).getAbsolutePath();
}
/** Builds DownstreamTlsContext from file-names. */
private static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
return buildDownstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/** Builds UpstreamTlsContext from file-names. */
private static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
return buildUpstreamTlsContext(
buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
}
/** Builds UpstreamTlsContext from commonTlsContext. */
private static UpstreamTlsContext buildUpstreamTlsContext(CommonTlsContext commonTlsContext) {
UpstreamTlsContext upstreamTlsContext =
UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
return upstreamTlsContext;
}
/** Builds DownstreamTlsContext from commonTlsContext. */
private static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
DownstreamTlsContext downstreamTlsContext =
DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
return downstreamTlsContext;
}
private static CommonTlsContext buildCommonTlsContextFromFilenames(
String privateKey, String certChain, String trustCa) throws IOException {
TlsCertificate tlsCert = null;
privateKey = getTempFileNameForResourcesFile(privateKey);
certChain = getTempFileNameForResourcesFile(certChain);
trustCa = getTempFileNameForResourcesFile(trustCa);
if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
}
CertificateValidationContext certContext = null;
if (!Strings.isNullOrEmpty(trustCa)) {
certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
}
return getCommonTlsContext(tlsCert, certContext);
}
private static CommonTlsContext getCommonTlsContext(
TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
if (tlsCertificate != null) {
builder = builder.addTlsCertificates(tlsCertificate);
}
if (certContext != null) {
builder = builder.setValidationContext(certContext);
}
return builder.build();
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_nullTlsContext() {
ClientSdsProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(/* upstreamTlsContext= */ null);
ChannelHandler newHandler = pn.newHandler(grpcHandler);
assertThat(newHandler).isNotNull();
// ProtocolNegotiators.WaitUntilActiveHandler not accessible, get canonical name
assertThat(newHandler.getClass().getCanonicalName())
.contains("io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler");
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() {
UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContext(getCommonTlsContext(null, null));
ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext);
ChannelHandler newHandler = pn.newHandler(grpcHandler);
assertThat(newHandler).isNotNull();
assertThat(newHandler).isInstanceOf(ClientSdsHandler.class);
}
@Test
public void clientSdsHandler_addLast() throws IOException {
UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
pipeline.addLast(clientSdsHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it
// kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull();
// pipeline should have SslHandler and ClientTlsHandler
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
// ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name
assertThat(iterator.next().getValue().getClass().getCanonicalName())
.contains("ProtocolNegotiators.ClientTlsHandler");
}
@Test
public void serverSdsHandler_addLast() throws IOException {
DownstreamTlsContext downstreamTlsContext =
buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
SdsProtocolNegotiators.ServerSdsHandler serverSdsHandler =
new SdsProtocolNegotiators.ServerSdsHandler(grpcHandler, downstreamTlsContext);
pipeline.addLast(serverSdsHandler);
channelHandlerCtx = pipeline.context(serverSdsHandler);
assertNotNull(channelHandlerCtx); // serverSdsHandler ctx is non-null since we just added it
// kick off protocol negotiation
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
channelHandlerCtx = pipeline.context(serverSdsHandler);
assertThat(channelHandlerCtx).isNull();
// pipeline should have SslHandler and ServerTlsHandler
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
// ProtocolNegotiators.ServerTlsHandler.class is not accessible, get canonical name
assertThat(iterator.next().getValue().getClass().getCanonicalName())
.contains("ProtocolNegotiators.ServerTlsHandler");
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
throws IOException, InterruptedException {
UpstreamTlsContext upstreamTlsContext =
buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
pipeline.addLast(clientSdsHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertNotNull(channelHandlerCtx); // non-null since we just added it
// kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull();
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
pipeline.fireUserEventTriggered(sslEvent);
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
assertTrue(channel.isOpen());
}
private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
FakeGrpcHttp2ConnectionHandler(
ChannelPromise channelUnused,
Http2ConnectionDecoder decoder,
Http2ConnectionEncoder encoder,
Http2Settings initialSettings) {
super(channelUnused, decoder, encoder, initialSettings);
}
static FakeGrpcHttp2ConnectionHandler newHandler() {
DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
DefaultHttp2ConnectionEncoder encoder =
new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
DefaultHttp2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader());
Http2Settings settings = new Http2Settings();
return new FakeGrpcHttp2ConnectionHandler(
/*channelUnused=*/ null, decoder, encoder, settings);
}
@Override
public String getAuthority() {
return "authority";
}
}
}