Testing that buffered streams clean up properly upon disconnect.

This commit is contained in:
nmittler 2015-06-11 15:05:48 -07:00
parent 15104cdc69
commit cb486e461d
2 changed files with 118 additions and 46 deletions

View File

@ -108,6 +108,13 @@ class NettyServerTransport implements ServerTransport {
}
}
/**
* For testing purposes only.
*/
Channel channel() {
return channel;
}
private void notifyTerminated(Throwable t) {
if (t != null) {
log.log(Level.SEVERE, "Transport failed", t);

View File

@ -33,6 +33,7 @@ package io.grpc.transport.netty;
import static com.google.common.base.Charsets.UTF_8;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture;
@ -42,6 +43,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodType;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.testing.TestUtils;
import io.grpc.transport.ClientStream;
import io.grpc.transport.ClientStreamListener;
@ -72,14 +74,15 @@ import java.io.InputStream;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* Tests for {@link NettyClientTransport}.
*/
@RunWith(JUnit4.class)
public class NettyClientTransportTest {
private static final String MESSAGE = "hello";
@Mock
private ClientTransport.Listener clientTransportListener;
@ -88,21 +91,14 @@ public class NettyClientTransportTest {
private NioEventLoopGroup group;
private InetSocketAddress address;
private NettyServer server;
private TestServerListener serverListener = new TestServerListener();
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
group = new NioEventLoopGroup(1);
// Start the server.
address = TestUtils.testServerAddress(TestUtils.pickUnusedPort());
File serverCert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build();
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, 100, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
server.start(new TestServerListener());
}
@After
@ -123,50 +119,113 @@ public class NettyClientTransportTest {
*/
@Test
public void creatingMultipleTlsTransportsShouldSucceed() throws Exception {
startServer(Integer.MAX_VALUE);
// Create a couple client transports.
for (int index = 0; index < 2; ++index) {
NettyClientTransport transport = newTransport();
transport.start(clientTransportListener);
}
// Send a single RPC on each transport.
final List<Rpc> rpcs = new ArrayList<Rpc>(transports.size());
for (NettyClientTransport transport : transports) {
rpcs.add(new Rpc(transport).halfClose());
}
// Wait for the RPCs to complete.
for (Rpc rpc : rpcs) {
rpc.waitForResponse();
}
}
@Test
public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Exception {
// Only allow a single stream active at a time.
startServer(1);
NettyClientTransport transport = newTransport();
transport.start(clientTransportListener);
// Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS
// has been received by the remote endpoint.
new Rpc(transport).halfClose().waitForResponse();
// Create 3 streams, but don't half-close. The transport will buffer the second and third.
Rpc[] rpcs = new Rpc[] { new Rpc(transport), new Rpc(transport), new Rpc(transport) };
// Wait for the response for the stream that was actually created.
rpcs[0].waitForResponse();
// Now forcibly terminate the connection from the server side.
serverListener.transports.get(0).channel().pipeline().firstContext().close();
// Now wait for both listeners to be closed.
for (Rpc rpc : rpcs) {
try {
rpc.waitForClose();
fail("Expected the RPC to fail");
} catch (ExecutionException e) {
// Expected.
}
}
}
private NettyClientTransport newTransport() throws IOException {
// Create the protocol negotiator.
File clientCert = TestUtils.loadCert("ca.pem");
SslContext clientContext = GrpcSslContexts.forClient().trustManager(clientCert).build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, address);
// Create a couple client transports.
for (int index = 0; index < 2; ++index) {
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
group, negotiator, DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
transports.add(transport);
transport.start(clientTransportListener);
return transport;
}
// Send a single RPC on each transport.
final List<SettableFuture> rpcFutures = new ArrayList<SettableFuture>(transports.size());
MethodDescriptor<String, String> method = MethodDescriptor.create(MethodType.UNARY,
"/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE,
private void startServer(int maxStreamsPerConnection) throws IOException {
File serverCert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key).build();
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
server.start(serverListener);
}
private static class Rpc {
static final String MESSAGE = "hello";
static final MethodDescriptor<String, String> METHOD = MethodDescriptor.create(
MethodType.UNARY, "/testService/test", 10, TimeUnit.SECONDS, StringMarshaller.INSTANCE,
StringMarshaller.INSTANCE);
for (NettyClientTransport transport : transports) {
SettableFuture rpcFuture = SettableFuture.create();
rpcFutures.add(rpcFuture);
ClientStream stream = transport.newStream(method, new Metadata.Headers(),
new TestClientStreamListener(rpcFuture));
final ClientStream stream;
final TestClientStreamListener listener = new TestClientStreamListener();
Rpc(NettyClientTransport transport) {
stream = transport.newStream(METHOD, new Metadata.Headers(), listener);
stream.request(1);
stream.writeMessage(messageStream());
stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes()));
stream.flush();
}
Rpc halfClose() {
stream.halfClose();
return this;
}
// Wait for the RPCs to complete.
for (SettableFuture rpcFuture : rpcFutures) {
rpcFuture.get(10, TimeUnit.SECONDS);
}
void waitForResponse() throws InterruptedException, ExecutionException, TimeoutException {
listener.responseFuture.get(10, TimeUnit.SECONDS);
}
private static InputStream messageStream() {
return new ByteArrayInputStream(MESSAGE.getBytes());
void waitForClose() throws InterruptedException, ExecutionException, TimeoutException {
listener.closedFuture.get(10, TimeUnit.SECONDS);
}
}
private static class TestClientStreamListener implements ClientStreamListener {
private final SettableFuture<?> future;
TestClientStreamListener(SettableFuture<?> future) {
this.future = future;
}
private final SettableFuture<Void> closedFuture = SettableFuture.create();
private final SettableFuture<Void> responseFuture = SettableFuture.create();
@Override
public void headersRead(Metadata.Headers headers) {
@ -175,14 +234,17 @@ public class NettyClientTransportTest {
@Override
public void closed(Status status, Metadata.Trailers trailers) {
if (status.isOk()) {
future.set(null);
closedFuture.set(null);
} else {
future.setException(status.asException());
StatusException e = status.asException();
closedFuture.setException(e);
responseFuture.setException(e);
}
}
@Override
public void messageRead(InputStream message) {
responseFuture.set(null);
}
@Override
@ -191,9 +253,11 @@ public class NettyClientTransportTest {
}
private static class TestServerListener implements ServerListener {
final List<NettyServerTransport> transports = new ArrayList<NettyServerTransport>();
@Override
public ServerTransportListener transportCreated(final ServerTransport transport) {
transports.add((NettyServerTransport) transport);
return new ServerTransportListener() {
@Override
@ -205,7 +269,8 @@ public class NettyClientTransportTest {
@Override
public void messageRead(InputStream message) {
// Just echo back the message.
stream.writeMessage(messageStream());
stream.writeMessage(message);
stream.flush();
}
@Override