diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java index b5594ea997..bbc6e79cad 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java @@ -31,6 +31,7 @@ package io.grpc.transport.okhttp; +import com.google.common.base.Preconditions; import com.google.common.util.concurrent.SettableFuture; import com.squareup.okhttp.internal.spdy.ErrorCode; @@ -45,20 +46,26 @@ import okio.Buffer; import java.io.IOException; import java.util.List; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; class AsyncFrameWriter implements FrameWriter { - private final FrameWriter frameWriter; - private final Executor executor; + private FrameWriter frameWriter; + // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are + // just waiting on each other. + private final SerializingExecutor executor; private final OkHttpClientTransport transport; - public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport, - Executor executor) { - this.frameWriter = frameWriter; + public AsyncFrameWriter(OkHttpClientTransport transport, SerializingExecutor executor) { this.transport = transport; - // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are - // just waiting on each other. - this.executor = new SerializingExecutor(executor); + this.executor = executor; + } + + /** + * Set the real frameWriter, should only be called by thread of executor. + */ + void setFrameWriter(FrameWriter frameWriter) { + Preconditions.checkState(this.frameWriter == null, + "AsyncFrameWriter's setFrameWriter() should only be called once."); + this.frameWriter = frameWriter; } @Override @@ -206,7 +213,9 @@ class AsyncFrameWriter implements FrameWriter { @Override public void run() { try { - frameWriter.close(); + if (frameWriter != null) { + frameWriter.close(); + } } catch (IOException e) { closeFuture.setException(e); } finally { @@ -228,6 +237,9 @@ class AsyncFrameWriter implements FrameWriter { @Override public final void run() { try { + if (frameWriter == null) { + throw new IOException("Unable to perform write due to unavailable frameWriter."); + } doRun(); } catch (IOException ex) { transport.onIoException(ex); @@ -240,6 +252,7 @@ class AsyncFrameWriter implements FrameWriter { @Override public int maxDataLength() { - return frameWriter.maxDataLength(); + return frameWriter == null ? 0x4000 /* 16384, the minimum required by the HTTP/2 spec */ + : frameWriter.maxDataLength(); } } diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java index d7eb7776a2..c629271564 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java @@ -43,6 +43,7 @@ import com.squareup.okhttp.ConnectionSpec; import com.squareup.okhttp.OkHttpTlsUpgrader; import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.FrameReader; +import com.squareup.okhttp.internal.spdy.FrameWriter; import com.squareup.okhttp.internal.spdy.Header; import com.squareup.okhttp.internal.spdy.HeadersMode; import com.squareup.okhttp.internal.spdy.Http2; @@ -53,6 +54,7 @@ import com.squareup.okhttp.internal.spdy.Variant; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.SerializingExecutor; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.transport.ClientStreamListener; @@ -130,7 +132,7 @@ class OkHttpClientTransport implements ClientTransport { private final Random random = new Random(); private final Ticker ticker; private Listener listener; - private FrameReader frameReader; + private FrameReader testFrameReader; private AsyncFrameWriter frameWriter; private OutboundFlowController outboundFlow; private final Object lock = new Object(); @@ -139,6 +141,8 @@ class OkHttpClientTransport implements ClientTransport { private final Map streams = Collections.synchronizedMap(new HashMap()); private final Executor executor; + // Wrap on executor, to guarantee some operations be executed serially. + private final SerializingExecutor serializingExecutor; private int connectionUnacknowledgedBytesRead; private ClientFrameHandler clientFrameHandler; // The status used to finish all active streams when the transport is closed. @@ -157,6 +161,9 @@ class OkHttpClientTransport implements ClientTransport { @GuardedBy("lock") private LinkedList pendingStreams = new LinkedList(); private final ConnectionSpec connectionSpec; + private FrameWriter testFrameWriter; + // Used by test only. + Runnable connectedCallback; OkHttpClientTransport(String host, int port, String authorityHost, Executor executor, @Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec) { @@ -165,6 +172,7 @@ class OkHttpClientTransport implements ClientTransport { this.authorityHost = authorityHost; defaultAuthority = authorityHost + ":" + port; this.executor = Preconditions.checkNotNull(executor); + serializingExecutor = new SerializingExecutor(executor); // Client initiated streams are odd, server initiated ones are even. Server should not need to // use it. We start clients at 3 to avoid conflicting with HTTP negotiation. nextStreamId = 3; @@ -177,29 +185,25 @@ class OkHttpClientTransport implements ClientTransport { * Create a transport connected to a fake peer for test. */ @VisibleForTesting - OkHttpClientTransport(Executor executor, FrameReader frameReader, AsyncFrameWriter frameWriter, - int nextStreamId, Socket socket) { - this(executor, frameReader, frameWriter, nextStreamId, socket, Ticker.systemTicker()); - } - - /** - * Create a transport connected to a fake peer for test, with a custom ticker. - */ - @VisibleForTesting - OkHttpClientTransport(Executor executor, FrameReader frameReader, AsyncFrameWriter frameWriter, - int nextStreamId, Socket socket, Ticker ticker) { + OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter, + int nextStreamId, Socket socket, Ticker ticker, Runnable connectedCallback) { host = null; port = 0; authorityHost = null; defaultAuthority = "notarealauthority:80"; this.executor = Preconditions.checkNotNull(executor); - this.frameReader = Preconditions.checkNotNull(frameReader); - this.frameWriter = Preconditions.checkNotNull(frameWriter); + serializingExecutor = new SerializingExecutor(executor); + this.testFrameReader = Preconditions.checkNotNull(frameReader); + this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter); this.socket = Preconditions.checkNotNull(socket); - this.outboundFlow = new OutboundFlowController(this, frameWriter); this.nextStreamId = nextStreamId; this.ticker = ticker; this.connectionSpec = null; + this.connectedCallback = Preconditions.checkNotNull(connectedCallback); + } + + private boolean isForTest() { + return host == null; } @Override @@ -315,36 +319,73 @@ class OkHttpClientTransport implements ClientTransport { @Override public void start(Listener listener) { this.listener = Preconditions.checkNotNull(listener, "listener"); - // We set host to null for test. - if (host != null) { - BufferedSource source; - BufferedSink sink; - try { - socket = new Socket(host, port); - if (sslSocketFactory != null) { - socket = OkHttpTlsUpgrader.upgrade( - sslSocketFactory, socket, authorityHost, port, connectionSpec); - } - socket.setTcpNoDelay(true); - source = Okio.buffer(Okio.source(socket)); - sink = Okio.buffer(Okio.sink(socket)); - } catch (IOException e) { - // TODO(jhump): should we instead notify the listener of shutdown+terminated? - // (and probably do all of this work asynchronously instead of in calling thread) - throw Status.UNAVAILABLE.withDescription("Failed connecting").withCause(e) - .asRuntimeException(); - } - Variant variant = new Http2(); - frameReader = variant.newReader(source, true); - frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor); - outboundFlow = new OutboundFlowController(this, frameWriter); - frameWriter.connectionPreface(); - Settings settings = new Settings(); - frameWriter.settings(settings); - } - clientFrameHandler = new ClientFrameHandler(); - executor.execute(clientFrameHandler); + frameWriter = new AsyncFrameWriter(this, serializingExecutor); + outboundFlow = new OutboundFlowController(this, frameWriter); + + // Connecting in the serializingExecutor, so that some stream operations like synStream + // will be executed after connected. + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + if (isForTest()) { + clientFrameHandler = new ClientFrameHandler(testFrameReader); + executor.execute(clientFrameHandler); + connectedCallback.run(); + frameWriter.setFrameWriter(testFrameWriter); + return; + } + BufferedSource source; + BufferedSink sink; + Socket sock; + try { + sock = new Socket(host, port); + if (sslSocketFactory != null) { + sock = OkHttpTlsUpgrader.upgrade( + sslSocketFactory, sock, authorityHost, port, connectionSpec); + } + sock.setTcpNoDelay(true); + source = Okio.buffer(Okio.source(sock)); + sink = Okio.buffer(Okio.sink(sock)); + } catch (IOException e) { + onIoException(e); + // (and probably do all of this work asynchronously instead of in calling thread) + throw new RuntimeException(e); + } + + FrameWriter rawFrameWriter; + synchronized (lock) { + if (stopped) { + // In case user called shutdown() during the connecting. + try { + sock.close(); + } catch (IOException e) { + log.log(Level.WARNING, "Failed closing socket", e); + } + return; + } + socket = sock; + } + + Variant variant = new Http2(); + rawFrameWriter = variant.newWriter(sink, true); + frameWriter.setFrameWriter(rawFrameWriter); + + try { + // Do these with the raw FrameWriter, so that they will be done in this thread, + // and before any possible pending stream operations. + rawFrameWriter.connectionPreface(); + Settings settings = new Settings(); + rawFrameWriter.settings(settings); + } catch (IOException e) { + onIoException(e); + throw new RuntimeException(e); + } + + clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true)); + executor.execute(clientFrameHandler); + } + }); } @Override @@ -356,9 +397,8 @@ class OkHttpClientTransport implements ClientTransport { if (normalClose) { // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams. // The GOAWAY is part of graceful shutdown. - if (frameWriter != null) { - frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]); - } + frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]); + onGoAway(Integer.MAX_VALUE, Status.UNAVAILABLE.withDescription("Transport stopped")); } stopIfNecessary(); @@ -469,6 +509,7 @@ class OkHttpClientTransport implements ClientTransport { void stopIfNecessary() { boolean shouldStop; Http2Ping outstandingPing = null; + boolean socketConnected; synchronized (lock) { shouldStop = (goAway && streams.size() == 0); if (shouldStop) { @@ -481,11 +522,12 @@ class OkHttpClientTransport implements ClientTransport { ping = null; } } + socketConnected = socket != null; } if (shouldStop) { // Wait for the frame writer to close. - if (frameWriter != null) { - frameWriter.close(); + frameWriter.close(); + if (socketConnected) { // Close the socket to break out the reader thread, which will close the // frameReader and notify the listener. try { @@ -529,7 +571,11 @@ class OkHttpClientTransport implements ClientTransport { */ @VisibleForTesting class ClientFrameHandler implements FrameReader.Handler, Runnable { - ClientFrameHandler() {} + FrameReader frameReader; + + ClientFrameHandler(FrameReader frameReader) { + this.frameReader = frameReader; + } @Override public void run() { diff --git a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java index 42d226e7db..09526c222e 100644 --- a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java @@ -46,20 +46,22 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.FrameReader; +import com.squareup.okhttp.internal.spdy.FrameWriter; import com.squareup.okhttp.internal.spdy.Header; import com.squareup.okhttp.internal.spdy.HeadersMode; import com.squareup.okhttp.internal.spdy.OkHttpSettingsUtil; @@ -79,7 +81,9 @@ import okio.Buffer; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -95,7 +99,6 @@ import java.io.InputStreamReader; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -113,8 +116,11 @@ public class OkHttpClientTransportTest { // The gRPC header length, which includes 1 byte compression flag and 4 bytes message length. private static final int HEADER_LENGTH = 5; + @Rule + public Timeout globalTimeout = new Timeout(10 * 1000); + @Mock - private AsyncFrameWriter frameWriter; + private FrameWriter frameWriter; @Mock MethodDescriptor method; @Mock @@ -125,15 +131,25 @@ public class OkHttpClientTransportTest { private ClientFrameHandler frameHandler; private ExecutorService executor; private long nanoTime; // backs a ticker, for testing ping round-trip time measurement + private ConnectedCallback connectedCallback; /** Set up for test. */ @Before public void setUp() { MockitoAnnotations.initMocks(this); - streams = new HashMap(); - frameReader = new MockFrameReader(); - MockSocket socket = new MockSocket(frameReader); executor = Executors.newCachedThreadPool(); + when(method.getFullMethodName()).thenReturn("fakemethod"); + when(method.getType()).thenReturn(MethodType.UNARY); + when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE); + frameReader = new MockFrameReader(); + } + + private void initTransport() { + initTransport(3, new ConnectedCallback(false)); + } + + private void initTransport(int startId, ConnectedCallback connectedCallback) { + this.connectedCallback = connectedCallback; Ticker ticker = new Ticker() { @Override public long read() { @@ -141,21 +157,18 @@ public class OkHttpClientTransportTest { } }; clientTransport = new OkHttpClientTransport( - executor, frameReader, frameWriter, 3, socket, ticker); + executor, frameReader, frameWriter, startId, + new MockSocket(frameReader), ticker, connectedCallback); clientTransport.start(transportListener); - frameHandler = clientTransport.getHandler(); streams = clientTransport.getStreams(); - when(method.getFullMethodName()).thenReturn("fakemethod"); - when(method.getType()).thenReturn(MethodType.UNARY); - when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE); } /** Final test checks and clean up. */ @After - public void tearDown() { + public void tearDown() throws Exception { clientTransport.shutdown(); assertEquals(0, streams.size()); - verify(frameWriter).close(); + verify(frameWriter, timeout(TIME_OUT_MS)).close(); frameReader.assertClosed(); executor.shutdown(); } @@ -165,6 +178,7 @@ public class OkHttpClientTransportTest { */ @Test public void nextFrameThrowIoException() throws Exception { + initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener1).request(1); @@ -186,18 +200,19 @@ public class OkHttpClientTransportTest { @Test public void readMessages() throws Exception { + initTransport(); final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener).request(numMessages); assertTrue(streams.containsKey(3)); - frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); assertNotNull(listener.headers); for (int i = 0; i < numMessages; i++) { Buffer buffer = createMessageFrame(message + i); - frameHandler.data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size()); } - frameHandler.headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); listener.waitUntilStreamClosed(); assertEquals(Status.OK, listener.status); assertNotNull(listener.trailers); @@ -209,34 +224,39 @@ public class OkHttpClientTransportTest { @Test public void receivedHeadersForInvalidStreamShouldKillConnection() throws Exception { + initTransport(); // Empty headers block without correct content type or status - frameHandler.headers(false, false, 3, 0, new ArrayList
(), + frameHandler().headers(false, false, 3, 0, new ArrayList
(), HeadersMode.HTTP_20_HEADERS); - verify(frameWriter).goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @Test public void receivedDataForInvalidStreamShouldKillConnection() throws Exception { - frameHandler.data(false, 3, createMessageFrame(new String(new char[1000])), 1000); - verify(frameWriter).goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); + initTransport(); + frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000); + verify(frameWriter, timeout(TIME_OUT_MS)) + .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @Test public void invalidInboundHeadersCancelStream() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); assertTrue(streams.containsKey(3)); // Empty headers block without correct content type or status - frameHandler.headers(false, false, 3, 0, new ArrayList
(), + frameHandler().headers(false, false, 3, 0, new ArrayList
(), HeadersMode.HTTP_20_HEADERS); // Now wait to receive 1000 bytes of data so we can have a better error message before // cancelling the streaam. - frameHandler.data(false, 3, createMessageFrame(new String(new char[1000])), 1000); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); assertNull(listener.headers); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); assertNotNull(listener.trailers); @@ -244,32 +264,35 @@ public class OkHttpClientTransportTest { @Test public void readStatus() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); assertTrue(streams.containsKey(3)); - frameHandler.headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); listener.waitUntilStreamClosed(); assertEquals(Status.Code.OK, listener.status.getCode()); } @Test public void receiveReset() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); assertTrue(streams.containsKey(3)); - frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR); + frameHandler().rstStream(3, ErrorCode.PROTOCOL_ERROR); listener.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status); } @Test public void cancelStream() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); OkHttpClientStream stream = streams.get(3); assertNotNull(stream); stream.cancel(Status.CANCELLED); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), listener.status.getCode()); @@ -277,6 +300,7 @@ public class OkHttpClientTransportTest { @Test public void headersShouldAddDefaultUserAgent() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); Header userAgentHeader = new Header(HttpUtil.USER_AGENT_KEY.name(), @@ -285,40 +309,45 @@ public class OkHttpClientTransportTest { new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), new Header(Header.TARGET_PATH, "/fakemethod"), userAgentHeader, CONTENT_TYPE_HEADER, TE_HEADER); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); streams.get(3).cancel(Status.CANCELLED); } @Test public void headersShouldOverrideDefaultUserAgent() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); String userAgent = "fakeUserAgent"; Metadata.Headers metadata = new Metadata.Headers(); metadata.put(HttpUtil.USER_AGENT_KEY, userAgent); clientTransport.newStream(method, metadata, listener); List
expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER, - new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), - new Header(Header.TARGET_PATH, "/fakemethod"), - new Header(HttpUtil.USER_AGENT_KEY.name(), - HttpUtil.getGrpcUserAgent("okhttp", userAgent)), - CONTENT_TYPE_HEADER, TE_HEADER); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); + new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), + new Header(Header.TARGET_PATH, "/fakemethod"), + new Header(HttpUtil.USER_AGENT_KEY.name(), + HttpUtil.getGrpcUserAgent("okhttp", userAgent)), + CONTENT_TYPE_HEADER, TE_HEADER); + verify(frameWriter, timeout(TIME_OUT_MS)) + .synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); streams.get(3).cancel(Status.CANCELLED); } @Test public void cancelStreamForDeadlineExceeded() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); OkHttpClientStream stream = streams.get(3); assertNotNull(stream); stream.cancel(Status.DEADLINE_EXCEEDED); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); } @Test public void writeMessage() throws Exception { + initTransport(); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); @@ -328,7 +357,8 @@ public class OkHttpClientTransportTest { stream.writeMessage(input); stream.flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); - verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); Buffer sentFrame = captor.getValue(); assertEquals(createMessageFrame(message), sentFrame); stream.cancel(Status.CANCELLED); @@ -336,6 +366,7 @@ public class OkHttpClientTransportTest { @Test public void windowUpdate() throws Exception { + initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener1).request(2); @@ -344,8 +375,8 @@ public class OkHttpClientTransportTest { OkHttpClientStream stream1 = streams.get(3); OkHttpClientStream stream2 = streams.get(5); - frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); - frameHandler.headers(false, false, 5, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 5, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); int messageLength = Utils.DEFAULT_WINDOW_SIZE / 4; byte[] fakeMessage = new byte[messageLength]; @@ -353,36 +384,40 @@ public class OkHttpClientTransportTest { // Stream 1 receives a message Buffer buffer = createMessageFrame(fakeMessage); int messageFrameLength = (int) buffer.size(); - frameHandler.data(false, 3, buffer, messageFrameLength); + frameHandler().data(false, 3, buffer, messageFrameLength); // Stream 2 receives a message buffer = createMessageFrame(fakeMessage); - frameHandler.data(false, 5, buffer, messageFrameLength); + frameHandler().data(false, 5, buffer, messageFrameLength); - verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); reset(frameWriter); // Stream 1 receives another message buffer = createMessageFrame(fakeMessage); - frameHandler.data(false, 3, buffer, messageFrameLength); + frameHandler().data(false, 3, buffer, messageFrameLength); - verify(frameWriter).windowUpdate(eq(3), eq((long) 2 * messageFrameLength)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .windowUpdate(eq(3), eq((long) 2 * messageFrameLength)); // Stream 2 receives another message buffer = createMessageFrame(fakeMessage); - frameHandler.data(false, 5, buffer, messageFrameLength); + frameHandler().data(false, 5, buffer, messageFrameLength); - verify(frameWriter).windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); - verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); stream1.cancel(Status.CANCELLED); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener1.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), listener1.status.getCode()); stream2.cancel(Status.CANCELLED); - verify(frameWriter).rstStream(eq(5), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(5), eq(ErrorCode.CANCEL)); listener2.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), listener2.status.getCode()); @@ -390,6 +425,7 @@ public class OkHttpClientTransportTest { @Test public void windowUpdateWithInboundFlowControl() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); OkHttpClientStream stream = streams.get(3); @@ -397,16 +433,16 @@ public class OkHttpClientTransportTest { int messageLength = Utils.DEFAULT_WINDOW_SIZE / 2 + 1; byte[] fakeMessage = new byte[messageLength]; - frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); Buffer buffer = createMessageFrame(fakeMessage); long messageFrameLength = buffer.size(); - frameHandler.data(false, 3, buffer, (int) messageFrameLength); - verify(frameWriter).windowUpdate(eq(0), eq(messageFrameLength)); + frameHandler().data(false, 3, buffer, (int) messageFrameLength); + verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(eq(0), eq(messageFrameLength)); // We return the bytes for the stream window as we read the message. - verify(frameWriter).windowUpdate(eq(3), eq(messageFrameLength)); + verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(eq(3), eq(messageFrameLength)); stream.cancel(Status.CANCELLED); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL).getCode(), listener.status.getCode()); @@ -414,6 +450,7 @@ public class OkHttpClientTransportTest { @Test public void outboundFlowControl() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); @@ -422,7 +459,7 @@ public class OkHttpClientTransportTest { InputStream input = new ByteArrayInputStream(new byte[messageLength]); stream.writeMessage(input); stream.flush(); - verify(frameWriter).data( + verify(frameWriter, timeout(TIME_OUT_MS)).data( eq(false), eq(3), any(Buffer.class), eq(messageLength + HEADER_LENGTH)); @@ -432,12 +469,13 @@ public class OkHttpClientTransportTest { stream.flush(); int partiallySentSize = Utils.DEFAULT_WINDOW_SIZE - messageLength - HEADER_LENGTH; - verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(partiallySentSize)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(partiallySentSize)); // Get more credit, the rest data should be sent out. - frameHandler.windowUpdate(3, Utils.DEFAULT_WINDOW_SIZE); - frameHandler.windowUpdate(0, Utils.DEFAULT_WINDOW_SIZE); - verify(frameWriter).data( + frameHandler().windowUpdate(3, Utils.DEFAULT_WINDOW_SIZE); + frameHandler().windowUpdate(0, Utils.DEFAULT_WINDOW_SIZE); + verify(frameWriter, timeout(TIME_OUT_MS)).data( eq(false), eq(3), any(Buffer.class), eq(messageLength + HEADER_LENGTH - partiallySentSize)); stream.cancel(Status.CANCELLED); @@ -446,6 +484,7 @@ public class OkHttpClientTransportTest { @Test public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); int messageLength = 20; @@ -454,30 +493,33 @@ public class OkHttpClientTransportTest { stream.writeMessage(input); stream.flush(); // part of the message can be sent. - verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 10)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 10)); // Avoid connection flow control. - frameHandler.windowUpdate(0, HEADER_LENGTH + 10); + frameHandler().windowUpdate(0, HEADER_LENGTH + 10); // Increase initial window size setInitialWindowSize(HEADER_LENGTH + 20); // The rest data should be sent. - verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(10)); - frameHandler.windowUpdate(0, 10); + verify(frameWriter, timeout(TIME_OUT_MS)).data(eq(false), eq(3), any(Buffer.class), eq(10)); + frameHandler().windowUpdate(0, 10); // Decrease initial window size to HEADER_LENGTH, since we've already sent // out HEADER_LENGTH + 20 bytes data, the window size should be -20 now. setInitialWindowSize(HEADER_LENGTH); // Get 20 tokens back, still can't send any data. - frameHandler.windowUpdate(3, 20); + frameHandler().windowUpdate(3, 20); input = new ByteArrayInputStream(new byte[messageLength]); stream.writeMessage(input); stream.flush(); // Only the previous two write operations happened. - verify(frameWriter, times(2)).data(anyBoolean(), anyInt(), any(Buffer.class), anyInt()); + verify(frameWriter, timeout(TIME_OUT_MS).times(2)) + .data(anyBoolean(), anyInt(), any(Buffer.class), anyInt()); // Get enough tokens to send the pending message. - frameHandler.windowUpdate(3, HEADER_LENGTH + 20); - verify(frameWriter).data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 20)); + frameHandler().windowUpdate(3, HEADER_LENGTH + 20); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(HEADER_LENGTH + 20)); stream.cancel(Status.CANCELLED); listener.waitUntilStreamClosed(); @@ -485,6 +527,7 @@ public class OkHttpClientTransportTest { @Test public void stopNormally() throws Exception { + initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 @@ -493,7 +536,7 @@ public class OkHttpClientTransportTest { = clientTransport.newStream(method, new Metadata.Headers(), listener2); assertEquals(2, streams.size()); clientTransport.shutdown(); - verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any()); + verify(frameWriter, timeout(TIME_OUT_MS)).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any()); assertEquals(2, streams.size()); verify(transportListener).transportShutdown(); @@ -509,6 +552,7 @@ public class OkHttpClientTransportTest { @Test public void receiveGoAway() throws Exception { + initTransport(); // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); @@ -517,7 +561,7 @@ public class OkHttpClientTransportTest { assertEquals(2, streams.size()); // Receive goAway, max good id is 3. - frameHandler.goAway(3, ErrorCode.CANCEL, null); + frameHandler().goAway(3, ErrorCode.CANCEL, null); // Transport should be in STOPPING state. verify(transportListener).transportShutdown(); @@ -529,7 +573,7 @@ public class OkHttpClientTransportTest { assertEquals(Status.CANCELLED.getCode(), listener2.status.getCode()); // New stream should be failed. - assertNewStreamFail(clientTransport); + assertNewStreamFail(); // But stream 1 should be able to send. final String sentMessage = "Should I also go away?"; @@ -541,16 +585,17 @@ public class OkHttpClientTransportTest { stream.flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); - verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH)); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH)); Buffer sentFrame = captor.getValue(); assertEquals(createMessageFrame(sentMessage), sentFrame); // And read. - frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); final String receivedMessage = "No, you are fine."; Buffer buffer = createMessageFrame(receivedMessage); - frameHandler.data(false, 3, buffer, (int) buffer.size()); - frameHandler.headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); + frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); listener1.waitUntilStreamClosed(); assertEquals(1, listener1.messages.size()); assertEquals(receivedMessage, listener1.messages.get(0)); @@ -562,27 +607,23 @@ public class OkHttpClientTransportTest { @Test public void streamIdExhausted() throws Exception { int startId = Integer.MAX_VALUE - 2; - AsyncFrameWriter writer = mock(AsyncFrameWriter.class); - MockFrameReader frameReader = new MockFrameReader(); - OkHttpClientTransport transport = new OkHttpClientTransport( - executor, frameReader, writer, startId, new MockSocket(frameReader)); - transport.start(transportListener); - streams = transport.getStreams(); + initTransport(startId, new ConnectedCallback(false)); MockStreamListener listener1 = new MockStreamListener(); - transport.newStream(method, new Metadata.Headers(), listener1); + clientTransport.newStream(method, new Metadata.Headers(), listener1); - assertNewStreamFail(transport); + assertNewStreamFail(); streams.get(startId).cancel(Status.CANCELLED); listener1.waitUntilStreamClosed(); - verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(startId), eq(ErrorCode.CANCEL)); verify(transportListener).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @Test public void pendingStreamSucceed() throws Exception { + initTransport(); setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); @@ -615,6 +656,7 @@ public class OkHttpClientTransportTest { @Test public void pendingStreamFailedByGoAway() throws Exception { + initTransport(); setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); final CountDownLatch newStreamReturn = new CountDownLatch(1); @@ -628,7 +670,7 @@ public class OkHttpClientTransportTest { }).start(); waitForStreamPending(1); - frameHandler.goAway(0, ErrorCode.CANCEL, null); + frameHandler().goAway(0, ErrorCode.CANCEL, null); assertTrue("newStream() call is still blocking", newStreamReturn.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)); @@ -639,6 +681,7 @@ public class OkHttpClientTransportTest { @Test public void pendingStreamFailedByShutdown() throws Exception { + initTransport(); setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); final CountDownLatch newStreamReturn = new CountDownLatch(1); @@ -664,11 +707,7 @@ public class OkHttpClientTransportTest { @Test public void pendingStreamFailedByIdExhausted() throws Exception { int startId = Integer.MAX_VALUE - 4; - clientTransport = new OkHttpClientTransport( - executor, frameReader, frameWriter, startId, new MockSocket(frameReader)); - clientTransport.start(transportListener); - frameHandler = clientTransport.getHandler(); - streams = clientTransport.getStreams(); + initTransport(startId, new ConnectedCallback(false)); setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); @@ -721,22 +760,23 @@ public class OkHttpClientTransportTest { @Test public void receivingWindowExceeded() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); - frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); int messageLength = Utils.DEFAULT_WINDOW_SIZE + 1; byte[] fakeMessage = new byte[messageLength]; Buffer buffer = createMessageFrame(fakeMessage); int messageFrameLength = (int) buffer.size(); - frameHandler.data(false, 3, buffer, messageFrameLength); + frameHandler().data(false, 3, buffer, messageFrameLength); listener.waitUntilStreamClosed(); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); assertEquals("Received data size exceeded our receiving window size", listener.status.getDescription()); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.FLOW_CONTROL_ERROR)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.FLOW_CONTROL_ERROR)); } @Test @@ -764,27 +804,29 @@ public class OkHttpClientTransportTest { } private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { + initTransport(); OkHttpClientStream stream = clientTransport.newStream( method, new Metadata.Headers(), new MockStreamListener()); - verify(frameWriter).synStream( + verify(frameWriter, timeout(TIME_OUT_MS)).synStream( eq(false), eq(false), eq(3), eq(0), Matchers.anyListOf(Header.class)); if (shouldBeFlushed) { - verify(frameWriter).flush(); + verify(frameWriter, timeout(TIME_OUT_MS)).flush(); } else { - verify(frameWriter, times(0)).flush(); + verify(frameWriter, timeout(TIME_OUT_MS).times(0)).flush(); } stream.cancel(Status.CANCELLED); } @Test public void receiveDataWithoutHeader() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener).request(1); Buffer buffer = createMessageFrame(new byte[1]); - frameHandler.data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size()); // Trigger the failure by a trailer. - frameHandler.headers( + frameHandler().headers( true, true, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); listener.waitUntilStreamClosed(); @@ -795,14 +837,15 @@ public class OkHttpClientTransportTest { @Test public void receiveDataWithoutHeaderAndTrailer() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); Buffer buffer = createMessageFrame(new byte[1]); - frameHandler.data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size()); // Trigger the failure by a data frame. buffer = createMessageFrame(new byte[1]); - frameHandler.data(true, 3, buffer, (int) buffer.size()); + frameHandler().data(true, 3, buffer, (int) buffer.size()); listener.waitUntilStreamClosed(); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); @@ -812,13 +855,14 @@ public class OkHttpClientTransportTest { @Test public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); - clientTransport.newStream(method,new Metadata.Headers(), listener).request(1); + clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); Buffer buffer = createMessageFrame(new byte[1000]); - frameHandler.data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size()); // Once we receive enough detail, we cancel the stream. so we should have sent cancel. - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); listener.waitUntilStreamClosed(); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); @@ -828,43 +872,48 @@ public class OkHttpClientTransportTest { @Test public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); stream.cancel(Status.CANCELLED); Buffer buffer = createMessageFrame( new byte[Utils.DEFAULT_WINDOW_SIZE / 2 + 1]); - frameHandler.data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size()); // Should still update the connection window even stream 3 is gone. - verify(frameWriter).windowUpdate(0, + verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(0, HEADER_LENGTH + Utils.DEFAULT_WINDOW_SIZE / 2 + 1); buffer = createMessageFrame( new byte[Utils.DEFAULT_WINDOW_SIZE / 2 + 1]); // This should kill the connection, since we never created stream 5. - frameHandler.data(false, 5, buffer, (int) buffer.size()); - verify(frameWriter).goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); + frameHandler().data(false, 5, buffer, (int) buffer.size()); + verify(frameWriter, timeout(TIME_OUT_MS)) + .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @Test public void receiveWindowUpdateForUnknownStream() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); stream.cancel(Status.CANCELLED); // This should be ignored. - frameHandler.windowUpdate(3, 73); + frameHandler().windowUpdate(3, 73); listener.waitUntilStreamClosed(); // This should kill the connection, since we never created stream 5. - frameHandler.windowUpdate(5, 73); - verify(frameWriter).goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); + frameHandler().windowUpdate(5, 73); + verify(frameWriter, timeout(TIME_OUT_MS)) + .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @Test public void shouldBeInitiallyReady() throws Exception { + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream( method,new Metadata.Headers(), listener); @@ -876,6 +925,7 @@ public class OkHttpClientTransportTest { @Test public void notifyOnReady() throws Exception { + initTransport(); final int messageLength = 15; setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); @@ -903,14 +953,14 @@ public class OkHttpClientTransportTest { assertFalse(stream.isReady()); // Let the first message out. - frameHandler.windowUpdate(0, HEADER_LENGTH + messageLength); - frameHandler.windowUpdate(3, HEADER_LENGTH + messageLength); + frameHandler().windowUpdate(0, HEADER_LENGTH + messageLength); + frameHandler().windowUpdate(3, HEADER_LENGTH + messageLength); assertFalse(stream.isReady()); assertFalse(listener.isOnReadyCalled()); // Let the second message out. - frameHandler.windowUpdate(0, HEADER_LENGTH + messageLength); - frameHandler.windowUpdate(3, HEADER_LENGTH + messageLength); + frameHandler().windowUpdate(0, HEADER_LENGTH + messageLength); + frameHandler().windowUpdate(3, HEADER_LENGTH + messageLength); assertTrue(stream.isReady()); assertTrue(listener.isOnReadyCalled()); @@ -931,6 +981,7 @@ public class OkHttpClientTransportTest { @Test public void ping() throws Exception { + initTransport(); PingCallbackImpl callback1 = new PingCallbackImpl(); clientTransport.ping(callback1, MoreExecutors.directExecutor()); // add'l ping will be added as listener to outstanding operation @@ -939,7 +990,7 @@ public class OkHttpClientTransportTest { ArgumentCaptor captor1 = ArgumentCaptor.forClass(int.class); ArgumentCaptor captor2 = ArgumentCaptor.forClass(int.class); - verify(frameWriter).ping(eq(false), captor1.capture(), captor2.capture()); + verify(frameWriter, timeout(TIME_OUT_MS)).ping(eq(false), captor1.capture(), captor2.capture()); // callback not invoked until we see acknowledgement assertEquals(0, callback1.invocationCount); assertEquals(0, callback2.invocationCount); @@ -948,7 +999,7 @@ public class OkHttpClientTransportTest { int payload2 = captor2.getValue(); // getting a bad ack won't complete the future // to make the ack "bad", we modify the payload so it doesn't match - frameHandler.ping(true, payload1, payload2 - 1); + frameHandler().ping(true, payload1, payload2 - 1); // operation not complete because ack was wrong assertEquals(0, callback1.invocationCount); assertEquals(0, callback2.invocationCount); @@ -956,7 +1007,7 @@ public class OkHttpClientTransportTest { nanoTime += TimeUnit.MICROSECONDS.toNanos(10101); // reading the proper response should complete the future - frameHandler.ping(true, payload1, payload2); + frameHandler().ping(true, payload1, payload2); assertEquals(1, callback1.invocationCount); assertEquals(10101, callback1.roundTripTime); assertNull(callback1.failureCause); @@ -973,6 +1024,7 @@ public class OkHttpClientTransportTest { @Test public void ping_failsWhenTransportShutdown() throws Exception { + initTransport(); PingCallbackImpl callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(0, callback.invocationCount); @@ -995,6 +1047,7 @@ public class OkHttpClientTransportTest { @Test public void ping_failsIfTransportFails() throws Exception { + initTransport(); PingCallbackImpl callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(0, callback.invocationCount); @@ -1015,6 +1068,87 @@ public class OkHttpClientTransportTest { ((StatusException) callback.failureCause).getStatus().getCode()); } + @Test + public void writeBeforeConnected() throws Exception { + initTransport(3, new ConnectedCallback(true)); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + stream.writeMessage(input); + stream.flush(); + // The message should be queued. + verifyNoMoreInteractions(frameWriter); + + connectedCallback.allowConnected(); + + // The queued message should be sent out. + ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = captor.getValue(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + } + + @Test + public void cancelBeforeConnected() throws Exception { + initTransport(3, new ConnectedCallback(true)); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + stream.writeMessage(input); + stream.flush(); + stream.cancel(Status.CANCELLED); + verifyNoMoreInteractions(frameWriter); + + connectedCallback.allowConnected(); + + // There should be 4 pending operations + verify(frameWriter, timeout(TIME_OUT_MS)).synStream( + eq(false), eq(false), eq(3), eq(0), Matchers.>any()); + verify(frameWriter, timeout(TIME_OUT_MS)).flush(); + verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); + + // TODO(madongfly): Is this really what we want, we may just throw away the messages of + // a cancelled stream. + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + } + + @Test + public void shutdownDuringConnecting() throws Exception { + initTransport(3, new ConnectedCallback(true)); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + OkHttpClientStream stream = clientTransport.newStream(method, new Metadata.Headers(), listener); + + clientTransport.shutdown(); + connectedCallback.allowConnected(); + + // The new stream should be failed, but the started stream should not be affected. + assertNewStreamFail(); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + stream.writeMessage(input); + stream.flush(); + ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = captor.getValue(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + } + + + private ClientFrameHandler frameHandler() throws Exception { + if (frameHandler == null) { + connectedCallback.waitUntilConnected(); + frameHandler = clientTransport.getHandler(); + } + return frameHandler; + } + private void waitForStreamPending(int expected) throws Exception { int duration = TIME_OUT_MS / 10; for (int i = 0; i < 10; i++) { @@ -1026,23 +1160,23 @@ public class OkHttpClientTransportTest { assertEquals(expected, clientTransport.getPendingStreamSize()); } - private void assertNewStreamFail(OkHttpClientTransport transport) throws Exception { + private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); - transport.newStream(method, new Metadata.Headers(), listener); + clientTransport.newStream(method, new Metadata.Headers(), listener); listener.waitUntilStreamClosed(); assertFalse(listener.status.isOk()); } - private void setMaxConcurrentStreams(int num) { + private void setMaxConcurrentStreams(int num) throws Exception { Settings settings = new Settings(); OkHttpSettingsUtil.set(settings, OkHttpSettingsUtil.MAX_CONCURRENT_STREAMS, num); - frameHandler.settings(false, settings); + frameHandler().settings(false, settings); } - private void setInitialWindowSize(int size) { + private void setInitialWindowSize(int size) throws Exception { Settings settings = new Settings(); OkHttpSettingsUtil.set(settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, size); - frameHandler.settings(false, settings); + frameHandler().settings(false, settings); } private static Buffer createMessageFrame(String message) { @@ -1220,4 +1354,32 @@ public class OkHttpClientTransportTest { this.failureCause = cause; } } + + private class ConnectedCallback implements Runnable { + SettableFuture connected; + SettableFuture delayed; + + private ConnectedCallback(boolean delayConnection) { + connected = SettableFuture.create(); + if (delayConnection) { + delayed = SettableFuture.create(); + } + } + + @Override + public void run() { + if (delayed != null) { + Futures.getUnchecked(delayed); + } + connected.set(null); + } + + void allowConnected() { + delayed.set(null); + } + + void waitUntilConnected() throws Exception { + connected.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); + } + } }