Add initial Protocol Negotiation tests

This commit is contained in:
Carl Mastrangelo 2015-10-14 15:51:34 -07:00
parent 2c2d7171ec
commit dd815bc968
2 changed files with 264 additions and 36 deletions

View File

@ -82,52 +82,61 @@ public final class ProtocolNegotiators {
/**
* Create a TLS handler for HTTP/2 capable of using ALPN/NPN.
*/
public static ChannelHandler serverTls(SSLEngine sslEngine, final ChannelHandler grpcHandler) {
public static ChannelHandler serverTls(SSLEngine sslEngine, ChannelHandler grpcHandler) {
Preconditions.checkNotNull(sslEngine, "sslEngine");
return new TlsChannelInboundHandlerAdapter(new SslHandler(sslEngine, false), grpcHandler);
}
final SslHandler sslHandler = new SslHandler(sslEngine, false);
return new ChannelInboundHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
ctx.pipeline().addFirst(sslHandler);
}
@VisibleForTesting
static final class TlsChannelInboundHandlerAdapter extends ChannelInboundHandlerAdapter {
private final ChannelHandler grpcHandler;
private final SslHandler sslHandler;
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
fail(ctx, cause);
}
TlsChannelInboundHandlerAdapter(SslHandler sslHandler, ChannelHandler grpcHandler) {
this.sslHandler = sslHandler;
this.grpcHandler = grpcHandler;
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
if (HTTP2_VERSIONS.contains(sslHandler(ctx).applicationProtocol())) {
// Successfully negotiated the protocol. Replace this handler with
// the GRPC handler.
ctx.pipeline().replace(this, null, grpcHandler);
} else {
fail(ctx, new Exception(
"Failed protocol negotiation: Unable to find compatible protocol."));
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
ctx.pipeline().addFirst(sslHandler);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
fail(ctx, cause);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
if (HTTP2_VERSIONS.contains(sslHandler(ctx).applicationProtocol())) {
// Successfully negotiated the protocol. Replace this handler with
// the GRPC handler.
ctx.pipeline().replace(this, null, grpcHandler);
} else {
fail(ctx, handshakeEvent.cause());
fail(ctx, new Exception(
"Failed protocol negotiation: Unable to find compatible protocol."));
}
} else {
fail(ctx, handshakeEvent.cause());
}
super.userEventTriggered(ctx, evt);
}
super.userEventTriggered(ctx, evt);
}
private void fail(ChannelHandlerContext ctx, Throwable exception) {
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", exception);
ctx.close();
}
private void fail(ChannelHandlerContext ctx, Throwable exception) {
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", exception);
ctx.close();
}
private SslHandler sslHandler(ChannelHandlerContext ctx) {
return ctx.pipeline().get(SslHandler.class);
}
};
private SslHandler sslHandler(ChannelHandlerContext ctx) {
return ctx.pipeline().get(SslHandler.class);
}
}
/**
@ -235,7 +244,8 @@ public final class ProtocolNegotiators {
return Status.UNAVAILABLE.withDescription(msg).asRuntimeException();
}
private static void logSslEngineDetails(Level level, ChannelHandlerContext ctx, String msg,
@VisibleForTesting
static void logSslEngineDetails(Level level, ChannelHandlerContext ctx, String msg,
@Nullable Throwable t) {
if (!log.isLoggable(level)) {
return;

View File

@ -0,0 +1,218 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.netty;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import com.google.common.collect.Iterables;
import io.grpc.netty.ProtocolNegotiators.TlsChannelInboundHandlerAdapter;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
@RunWith(JUnit4.class)
public class ProtocolNegotiatorsTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
@Mock private ChannelHandler grpcHandler;
private EmbeddedChannel channel = new EmbeddedChannel();
private ChannelPipeline pipeline = channel.pipeline();
private SslHandler sslHandler;
private SSLEngine engine;
private ChannelHandlerContext channelHandlerCtx;
@Before
public void setUp() throws Exception {
engine = SSLContext.getDefault().createSSLEngine();
sslHandler = new SslHandler(engine, false) {
@Override
public String applicationProtocol() {
// Just get any of them.
return Iterables.getFirst(GrpcSslContexts.HTTP2_VERSIONS, "");
}
};
}
@Test
public void tlsHandler_failsOnNullEngine() throws Exception {
thrown.expect(NullPointerException.class);
thrown.expectMessage("ssl");
ProtocolNegotiators.serverTls(null, null);
}
@Test
public void tlsAdapter_exceptionClosesChannel() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
// Use addFirst due to the funny error handling in EmbeddedChannel.
pipeline.addFirst(handler);
pipeline.fireExceptionCaught(new Exception("bad"));
assertFalse(channel.isOpen());
}
@Test
public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
pipeline.addLast(handler);
assertEquals(sslHandler, pipeline.first());
}
@Test
public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object nonSslEvent = new Object();
pipeline.fireUserEventTriggered(nonSslEvent);
// A non ssl event should not cause the grpcHandler to be in the pipeline yet.
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNull(grpcHandlerCtx);
}
@Test
public void tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
SslHandler badSslHandler = new SslHandler(engine, false) {
@Override
public String applicationProtocol() {
return "badprotocol";
}
};
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(badSslHandler, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
pipeline.fireUserEventTriggered(sslEvent);
// No h2 protocol was specified, so this should be closed.
assertFalse(channel.isOpen());
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNull(grpcHandlerCtx);
}
@Test
public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));
pipeline.fireUserEventTriggered(sslEvent);
// No h2 protocol was specified, so this should be closed.
assertFalse(channel.isOpen());
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNull(grpcHandlerCtx);
}
@Test
public void tlsHandler_userEventTriggeredSslEvent_supportedProtocol() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
pipeline.fireUserEventTriggered(sslEvent);
assertTrue(channel.isOpen());
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNotNull(grpcHandlerCtx);
}
@Test
public void engineLog() {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Logger logger = Logger.getLogger(ProtocolNegotiators.class.getName());
Filter oldFilter = logger.getFilter();
try {
logger.setFilter(new Filter() {
@Override
public boolean isLoggable(LogRecord record) {
// We still want to the log method to be exercised, just not printed to stderr.
return false;
}
});
ProtocolNegotiators.logSslEngineDetails(
Level.INFO, channelHandlerCtx, "message", new Exception("bad"));
} finally {
logger.setFilter(oldFilter);
}
}
}