Testing that buffered streams clean up properly upon disconnect.

This commit is contained in:
nmittler 2015-06-11 15:05:48 -07:00
parent 15104cdc69
commit cb486e461d
2 changed files with 118 additions and 46 deletions

View File

@ -108,6 +108,13 @@ class NettyServerTransport implements ServerTransport {
} }
} }
/**
* For testing purposes only.
*/
Channel channel() {
return channel;
}
private void notifyTerminated(Throwable t) { private void notifyTerminated(Throwable t) {
if (t != null) { if (t != null) {
log.log(Level.SEVERE, "Transport failed", t); log.log(Level.SEVERE, "Transport failed", t);

View File

@ -33,6 +33,7 @@ package io.grpc.transport.netty;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
@ -42,6 +43,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodType; import io.grpc.MethodType;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.testing.TestUtils; import io.grpc.testing.TestUtils;
import io.grpc.transport.ClientStream; import io.grpc.transport.ClientStream;
import io.grpc.transport.ClientStreamListener; import io.grpc.transport.ClientStreamListener;
@ -72,14 +74,15 @@ import java.io.InputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/** /**
* Tests for {@link NettyClientTransport}. * Tests for {@link NettyClientTransport}.
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class NettyClientTransportTest { public class NettyClientTransportTest {
private static final String MESSAGE = "hello";
@Mock @Mock
private ClientTransport.Listener clientTransportListener; private ClientTransport.Listener clientTransportListener;
@ -88,21 +91,14 @@ public class NettyClientTransportTest {
private NioEventLoopGroup group; private NioEventLoopGroup group;
private InetSocketAddress address; private InetSocketAddress address;
private NettyServer server; private NettyServer server;
private TestServerListener serverListener = new TestServerListener();
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
group = new NioEventLoopGroup(1); group = new NioEventLoopGroup(1);
// Start the server.
address = TestUtils.testServerAddress(TestUtils.pickUnusedPort()); address = TestUtils.testServerAddress(TestUtils.pickUnusedPort());
File serverCert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build();
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, 100, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
server.start(new TestServerListener());
} }
@After @After
@ -123,50 +119,113 @@ public class NettyClientTransportTest {
*/ */
@Test @Test
public void creatingMultipleTlsTransportsShouldSucceed() throws Exception { public void creatingMultipleTlsTransportsShouldSucceed() throws Exception {
startServer(Integer.MAX_VALUE);
// Create a couple client transports.
for (int index = 0; index < 2; ++index) {
NettyClientTransport transport = newTransport();
transport.start(clientTransportListener);
}
// Send a single RPC on each transport.
final List<Rpc> rpcs = new ArrayList<Rpc>(transports.size());
for (NettyClientTransport transport : transports) {
rpcs.add(new Rpc(transport).halfClose());
}
// Wait for the RPCs to complete.
for (Rpc rpc : rpcs) {
rpc.waitForResponse();
}
}
@Test
public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Exception {
// Only allow a single stream active at a time.
startServer(1);
NettyClientTransport transport = newTransport();
transport.start(clientTransportListener);
// Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS
// has been received by the remote endpoint.
new Rpc(transport).halfClose().waitForResponse();
// Create 3 streams, but don't half-close. The transport will buffer the second and third.
Rpc[] rpcs = new Rpc[] { new Rpc(transport), new Rpc(transport), new Rpc(transport) };
// Wait for the response for the stream that was actually created.
rpcs[0].waitForResponse();
// Now forcibly terminate the connection from the server side.
serverListener.transports.get(0).channel().pipeline().firstContext().close();
// Now wait for both listeners to be closed.
for (Rpc rpc : rpcs) {
try {
rpc.waitForClose();
fail("Expected the RPC to fail");
} catch (ExecutionException e) {
// Expected.
}
}
}
private NettyClientTransport newTransport() throws IOException {
// Create the protocol negotiator. // Create the protocol negotiator.
File clientCert = TestUtils.loadCert("ca.pem"); File clientCert = TestUtils.loadCert("ca.pem");
SslContext clientContext = GrpcSslContexts.forClient().trustManager(clientCert).build(); SslContext clientContext = GrpcSslContexts.forClient().trustManager(clientCert).build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, address); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, address);
// Create a couple client transports.
for (int index = 0; index < 2; ++index) {
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
group, negotiator, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE); group, negotiator, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
transports.add(transport); transports.add(transport);
transport.start(clientTransportListener); return transport;
} }
// Send a single RPC on each transport. private void startServer(int maxStreamsPerConnection) throws IOException {
final List<SettableFuture> rpcFutures = new ArrayList<SettableFuture>(transports.size()); File serverCert = TestUtils.loadCert("server1.pem");
MethodDescriptor<String, String> method = MethodDescriptor.create(MethodType.UNARY, File key = TestUtils.loadCert("server1.key");
"/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE, SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build();
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
server.start(serverListener);
}
private static class Rpc {
static final String MESSAGE = "hello";
static final MethodDescriptor<String, String> METHOD = MethodDescriptor.create(
MethodType.UNARY, "/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE,
StringMarshaller.INSTANCE); StringMarshaller.INSTANCE);
for (NettyClientTransport transport : transports) {
SettableFuture rpcFuture = SettableFuture.create(); final ClientStream stream;
rpcFutures.add(rpcFuture); final TestClientStreamListener listener = new TestClientStreamListener();
ClientStream stream = transport.newStream(method, new Metadata.Headers(),
new TestClientStreamListener(rpcFuture)); Rpc(NettyClientTransport transport) {
stream = transport.newStream(METHOD, new Metadata.Headers(), listener);
stream.request(1); stream.request(1);
stream.writeMessage(messageStream()); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes()));
stream.flush();
}
Rpc halfClose() {
stream.halfClose(); stream.halfClose();
return this;
} }
// Wait for the RPCs to complete. void waitForResponse() throws InterruptedException, ExecutionException, TimeoutException {
for (SettableFuture rpcFuture : rpcFutures) { listener.responseFuture.get(10, TimeUnit.SECONDS);
rpcFuture.get(10, TimeUnit.SECONDS);
}
} }
private static InputStream messageStream() { void waitForClose() throws InterruptedException, ExecutionException, TimeoutException {
return new ByteArrayInputStream(MESSAGE.getBytes()); listener.closedFuture.get(10, TimeUnit.SECONDS);
}
} }
private static class TestClientStreamListener implements ClientStreamListener { private static class TestClientStreamListener implements ClientStreamListener {
private final SettableFuture<?> future; private final SettableFuture<Void> closedFuture = SettableFuture.create();
private final SettableFuture<Void> responseFuture = SettableFuture.create();
TestClientStreamListener(SettableFuture<?> future) {
this.future = future;
}
@Override @Override
public void headersRead(Metadata.Headers headers) { public void headersRead(Metadata.Headers headers) {
@ -175,14 +234,17 @@ public class NettyClientTransportTest {
@Override @Override
public void closed(Status status, Metadata.Trailers trailers) { public void closed(Status status, Metadata.Trailers trailers) {
if (status.isOk()) { if (status.isOk()) {
future.set(null); closedFuture.set(null);
} else { } else {
future.setException(status.asException()); StatusException e = status.asException();
closedFuture.setException(e);
responseFuture.setException(e);
} }
} }
@Override @Override
public void messageRead(InputStream message) { public void messageRead(InputStream message) {
responseFuture.set(null);
} }
@Override @Override
@ -191,9 +253,11 @@ public class NettyClientTransportTest {
} }
private static class TestServerListener implements ServerListener { private static class TestServerListener implements ServerListener {
final List<NettyServerTransport> transports = new ArrayList<NettyServerTransport>();
@Override @Override
public ServerTransportListener transportCreated(final ServerTransport transport) { public ServerTransportListener transportCreated(final ServerTransport transport) {
transports.add((NettyServerTransport) transport);
return new ServerTransportListener() { return new ServerTransportListener() {
@Override @Override
@ -205,7 +269,8 @@ public class NettyClientTransportTest {
@Override @Override
public void messageRead(InputStream message) { public void messageRead(InputStream message) {
// Just echo back the message. // Just echo back the message.
stream.writeMessage(messageStream()); stream.writeMessage(message);
stream.flush();
} }
@Override @Override