From cb486e461db49b70dddb33dd2818c3742f7fc963 Mon Sep 17 00:00:00 2001 From: nmittler Date: Thu, 11 Jun 2015 15:05:48 -0700 Subject: [PATCH] Testing that buffered streams clean up properly upon disconnect. --- .../transport/netty/NettyServerTransport.java | 7 + .../netty/NettyClientTransportTest.java | 157 +++++++++++++----- 2 files changed, 118 insertions(+), 46 deletions(-) diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/transport/netty/NettyServerTransport.java index d08b697c01..83dc0fbd3d 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyServerTransport.java @@ -108,6 +108,13 @@ class NettyServerTransport implements ServerTransport { } } + /** + * For testing purposes only. + */ + Channel channel() { + return channel; + } + private void notifyTerminated(Throwable t) { if (t != null) { log.log(Level.SEVERE, "Transport failed", t); diff --git a/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java index f4e504da66..73f89f1f20 100644 --- a/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java @@ -33,6 +33,7 @@ package io.grpc.transport.netty; import static com.google.common.base.Charsets.UTF_8; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static org.junit.Assert.fail; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; @@ -42,6 +43,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodType; import io.grpc.Status; +import io.grpc.StatusException; import io.grpc.testing.TestUtils; import io.grpc.transport.ClientStream; import io.grpc.transport.ClientStreamListener; @@ -72,14 +74,15 @@ import java.io.InputStream; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; /** * Tests for {@link NettyClientTransport}. */ @RunWith(JUnit4.class) public class NettyClientTransportTest { - private static final String MESSAGE = "hello"; @Mock private ClientTransport.Listener clientTransportListener; @@ -88,21 +91,14 @@ public class NettyClientTransportTest { private NioEventLoopGroup group; private InetSocketAddress address; private NettyServer server; + private TestServerListener serverListener = new TestServerListener(); @Before public void setup() throws Exception { MockitoAnnotations.initMocks(this); group = new NioEventLoopGroup(1); - - // Start the server. address = TestUtils.testServerAddress(TestUtils.pickUnusedPort()); - File serverCert = TestUtils.loadCert("server1.pem"); - File key = TestUtils.loadCert("server1.key"); - SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build(); - server = new NettyServer(address, NioServerSocketChannel.class, - group, group, serverContext, 100, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE); - server.start(new TestServerListener()); } @After @@ -123,50 +119,113 @@ public class NettyClientTransportTest { */ @Test public void creatingMultipleTlsTransportsShouldSucceed() throws Exception { + startServer(Integer.MAX_VALUE); + + // Create a couple client transports. + for (int index = 0; index < 2; ++index) { + NettyClientTransport transport = newTransport(); + transport.start(clientTransportListener); + } + + // Send a single RPC on each transport. + final List rpcs = new ArrayList(transports.size()); + for (NettyClientTransport transport : transports) { + rpcs.add(new Rpc(transport).halfClose()); + } + + // Wait for the RPCs to complete. + for (Rpc rpc : rpcs) { + rpc.waitForResponse(); + } + } + + @Test + public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Exception { + // Only allow a single stream active at a time. + startServer(1); + + NettyClientTransport transport = newTransport(); + transport.start(clientTransportListener); + + // Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS + // has been received by the remote endpoint. + new Rpc(transport).halfClose().waitForResponse(); + + // Create 3 streams, but don't half-close. The transport will buffer the second and third. + Rpc[] rpcs = new Rpc[] { new Rpc(transport), new Rpc(transport), new Rpc(transport) }; + + // Wait for the response for the stream that was actually created. + rpcs[0].waitForResponse(); + + // Now forcibly terminate the connection from the server side. + serverListener.transports.get(0).channel().pipeline().firstContext().close(); + + // Now wait for both listeners to be closed. + for (Rpc rpc : rpcs) { + try { + rpc.waitForClose(); + fail("Expected the RPC to fail"); + } catch (ExecutionException e) { + // Expected. + } + } + } + + private NettyClientTransport newTransport() throws IOException { // Create the protocol negotiator. File clientCert = TestUtils.loadCert("ca.pem"); SslContext clientContext = GrpcSslContexts.forClient().trustManager(clientCert).build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, address); - // Create a couple client transports. - for (int index = 0; index < 2; ++index) { - NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, - group, negotiator, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE); - transports.add(transport); - transport.start(clientTransportListener); - } - - // Send a single RPC on each transport. - final List rpcFutures = new ArrayList(transports.size()); - MethodDescriptor method = MethodDescriptor.create(MethodType.UNARY, - "/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE, - StringMarshaller.INSTANCE); - for (NettyClientTransport transport : transports) { - SettableFuture rpcFuture = SettableFuture.create(); - rpcFutures.add(rpcFuture); - ClientStream stream = transport.newStream(method, new Metadata.Headers(), - new TestClientStreamListener(rpcFuture)); - stream.request(1); - stream.writeMessage(messageStream()); - stream.halfClose(); - } - - // Wait for the RPCs to complete. - for (SettableFuture rpcFuture : rpcFutures) { - rpcFuture.get(10, TimeUnit.SECONDS); - } + NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, + group, negotiator, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE); + transports.add(transport); + return transport; } - private static InputStream messageStream() { - return new ByteArrayInputStream(MESSAGE.getBytes()); + private void startServer(int maxStreamsPerConnection) throws IOException { + File serverCert = TestUtils.loadCert("server1.pem"); + File key = TestUtils.loadCert("server1.key"); + SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build(); + server = new NettyServer(address, NioServerSocketChannel.class, + group, group, serverContext, maxStreamsPerConnection, + DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE); + server.start(serverListener); + } + + private static class Rpc { + static final String MESSAGE = "hello"; + static final MethodDescriptor METHOD = MethodDescriptor.create( + MethodType.UNARY, "/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE, + StringMarshaller.INSTANCE); + + final ClientStream stream; + final TestClientStreamListener listener = new TestClientStreamListener(); + + Rpc(NettyClientTransport transport) { + stream = transport.newStream(METHOD, new Metadata.Headers(), listener); + stream.request(1); + stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes())); + stream.flush(); + } + + Rpc halfClose() { + stream.halfClose(); + return this; + } + + void waitForResponse() throws InterruptedException, ExecutionException, TimeoutException { + listener.responseFuture.get(10, TimeUnit.SECONDS); + } + + void waitForClose() throws InterruptedException, ExecutionException, TimeoutException { + listener.closedFuture.get(10, TimeUnit.SECONDS); + } } private static class TestClientStreamListener implements ClientStreamListener { - private final SettableFuture future; - - TestClientStreamListener(SettableFuture future) { - this.future = future; - } + private final SettableFuture closedFuture = SettableFuture.create(); + private final SettableFuture responseFuture = SettableFuture.create(); @Override public void headersRead(Metadata.Headers headers) { @@ -175,14 +234,17 @@ public class NettyClientTransportTest { @Override public void closed(Status status, Metadata.Trailers trailers) { if (status.isOk()) { - future.set(null); + closedFuture.set(null); } else { - future.setException(status.asException()); + StatusException e = status.asException(); + closedFuture.setException(e); + responseFuture.setException(e); } } @Override public void messageRead(InputStream message) { + responseFuture.set(null); } @Override @@ -191,9 +253,11 @@ public class NettyClientTransportTest { } private static class TestServerListener implements ServerListener { + final List transports = new ArrayList(); @Override public ServerTransportListener transportCreated(final ServerTransport transport) { + transports.add((NettyServerTransport) transport); return new ServerTransportListener() { @Override @@ -205,7 +269,8 @@ public class NettyClientTransportTest { @Override public void messageRead(InputStream message) { // Just echo back the message. - stream.writeMessage(messageStream()); + stream.writeMessage(message); + stream.flush(); } @Override