From 7bf17dc4d6c6766b2dffc5f6cfc316ba7d65536d Mon Sep 17 00:00:00 2001 From: simonma Date: Tue, 29 Jul 2014 09:52:20 -0700 Subject: [PATCH] Improve okhttp client transport, handles go away and add unit test. ------------- Created by MOE: http://code.google.com/p/moe-java MOE_MIGRATED_REVID=72155172 --- .../newtransport/okhttp/AsyncFrameWriter.java | 12 +- .../okhttp/OkHttpClientTransport.java | 211 +++++-- .../okhttp/OkHttpClientTransportTest.java | 555 ++++++++++++++++++ 3 files changed, 717 insertions(+), 61 deletions(-) create mode 100644 core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java index d5ac86c907..6affe1f075 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java @@ -1,7 +1,8 @@ package com.google.net.stubby.newtransport.okhttp; import com.google.common.util.concurrent.SerializingExecutor; -import com.google.common.util.concurrent.Service; +import com.google.net.stubby.Status; +import com.google.net.stubby.transport.Transport.Code; import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.FrameWriter; @@ -17,9 +18,10 @@ import java.util.concurrent.Executor; class AsyncFrameWriter implements FrameWriter { private final FrameWriter frameWriter; private final Executor executor; - private final Service transport; + private final OkHttpClientTransport transport; - public AsyncFrameWriter(FrameWriter frameWriter, Service transport, Executor executor) { + public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport, + Executor executor) { this.frameWriter = frameWriter; this.transport = transport; // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are @@ -158,6 +160,8 @@ class AsyncFrameWriter implements FrameWriter { @Override public void doRun() throws IOException { frameWriter.goAway(lastGoodStreamId, errorCode, debugData); + // Flush it since after goAway, we are likely to close this writer. + frameWriter.flush(); } }); } @@ -188,7 +192,7 @@ class AsyncFrameWriter implements FrameWriter { try { doRun(); } catch (IOException ex) { - transport.stopAsync(); + transport.abort(Status.fromThrowable(ex)); throw new RuntimeException(ex); } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java index fc974e01b7..b4e9f68442 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java @@ -1,5 +1,6 @@ package com.google.net.stubby.newtransport.okhttp; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteBuffers; @@ -33,9 +34,10 @@ import okio.Buffer; import java.io.IOException; import java.net.Socket; import java.nio.ByteBuffer; -import java.util.Collection; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -48,6 +50,7 @@ import javax.annotation.concurrent.GuardedBy; */ public class OkHttpClientTransport extends AbstractClientTransport { /** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */ + @VisibleForTesting static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024; private static final ImmutableMap ERROR_CODE_TO_STATUS = ImmutableMap @@ -75,21 +78,40 @@ public class OkHttpClientTransport extends AbstractClientTransport { private final int port; private FrameReader frameReader; private AsyncFrameWriter frameWriter; - @GuardedBy("this") + private Object lock = new Object(); + @GuardedBy("lock") private int nextStreamId; private final Map streams = Collections.synchronizedMap(new HashMap()); private final ExecutorService executor = Executors.newCachedThreadPool(); private int unacknowledgedBytesRead; + private ClientFrameHandler clientFrameHandler; + // The status used to finish all active streams when the transport is closed. + @GuardedBy("lock") + private boolean goAway; + @GuardedBy("lock") + private Status goAwayStatus; public OkHttpClientTransport(String host, int port) { - this.host = host; + this.host = Preconditions.checkNotNull(host); this.port = port; // Client initiated streams are odd, server initiated ones are even. Server should not need to // use it. We start clients at 3 to avoid conflicting with HTTP negotiation. nextStreamId = 3; } + /** + * Create a transport connected to a fake peer for test. + */ + @VisibleForTesting + OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) { + host = null; + port = -1; + this.nextStreamId = nextStreamId; + this.frameReader = frameReader; + this.frameWriter = frameWriter; + } + @Override protected ClientStream newStreamInternal(MethodDescriptor method, StreamListener listener) { return new OkHttpClientStream(method, listener); @@ -97,53 +119,85 @@ public class OkHttpClientTransport extends AbstractClientTransport { @Override protected void doStart() { - BufferedSource source; - BufferedSink sink; - try { - Socket socket = new Socket(host, port); - // TODO(user): use SpdyConnection. - source = Okio.buffer(Okio.source(socket)); - sink = Okio.buffer(Okio.sink(socket)); - } catch (IOException e) { - throw new RuntimeException(e); + // We set host to null for test. + if (host != null) { + BufferedSource source; + BufferedSink sink; + try { + Socket socket = new Socket(host, port); + source = Okio.buffer(Okio.source(socket)); + sink = Okio.buffer(Okio.sink(socket)); + } catch (IOException e) { + throw new RuntimeException(e); + } + Variant variant = new Http20Draft12(); + frameReader = variant.newReader(source, true); + frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor); } - Variant variant = new Http20Draft12(); - frameReader = variant.newReader(source, true); - frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor); - executor.execute(new ClientFrameHandler()); notifyStarted(); + clientFrameHandler = new ClientFrameHandler(); + executor.execute(clientFrameHandler); } @Override protected void doStop() { - closeAllStreams(new Status(Code.INTERNAL, "Transport stopped")); - frameWriter.close(); - try { - frameReader.close(); - } catch (IOException e) { - throw new RuntimeException(e); + boolean normalClose; + synchronized (lock) { + normalClose = !goAway; } - executor.shutdown(); - notifyStopped(); + if (normalClose) { + abort(new Status(Code.INTERNAL, "Transport stopped")); + // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams. + // The GOAWAY is part of graceful shutdown. + frameWriter.goAway(0, ErrorCode.NO_ERROR, null); + } + stopIfNecessary(); + } + + @VisibleForTesting + ClientFrameHandler getHandler() { + return clientFrameHandler; + } + + @VisibleForTesting + Map getStreams() { + return streams; } /** - * Close and remove all streams. + * Finish all active streams with given status, then close the transport. */ - private void closeAllStreams(Status status) { - Collection streamsCopy; - synchronized (streams) { - streamsCopy = streams.values(); - streams.clear(); + void abort(Status status) { + onGoAway(-1, status); + } + + private void onGoAway(int lastKnownStreamId, Status status) { + ArrayList goAwayStreams = new ArrayList(); + synchronized (lock) { + goAway = true; + goAwayStatus = status; + Iterator> it = streams.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + if (entry.getKey() > lastKnownStreamId) { + goAwayStreams.add(entry.getValue()); + it.remove(); + } + } } - for (OkHttpClientStream stream : streamsCopy) { + + // Starting stop, go into STOPPING state so that Channel know this Transport should not be used + // further, will become STOPPED once all streams are complete. + stopAsync(); + + for (OkHttpClientStream stream : goAwayStreams) { stream.setStatus(status); } } /** - * Called when a HTTP2 stream is closed. + * Called when a stream is closed. * *

Return false if the stream has already finished. */ @@ -158,11 +212,40 @@ public class OkHttpClientTransport extends AbstractClientTransport { return false; } + /** + * When the transport is in goAway states, we should stop it once all active streams finish. + */ + private void stopIfNecessary() { + boolean shouldStop; + synchronized (lock) { + shouldStop = (goAway && streams.size() == 0); + } + if (shouldStop) { + frameWriter.close(); + try { + frameReader.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + executor.shutdown(); + notifyStopped(); + } + } + + /** + * Returns a Grpc status corresponding to the given ErrorCode. + */ + @VisibleForTesting + static Status toGrpcStatus(ErrorCode code) { + return ERROR_CODE_TO_STATUS.get(code); + } + /** * Runnable which reads frames and dispatches them to in flight calls */ - private class ClientFrameHandler implements FrameReader.Handler, Runnable { - private ClientFrameHandler() {} + @VisibleForTesting + class ClientFrameHandler implements FrameReader.Handler, Runnable { + ClientFrameHandler() {} @Override public void run() { @@ -173,8 +256,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { while (frameReader.nextFrame(this)) { } } catch (IOException ioe) { - ioe.printStackTrace(); - closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage())); + abort(Status.fromThrowable(ioe)); } finally { // Restore the original thread name. Thread.currentThread().setName(threadName); @@ -210,7 +292,9 @@ public class OkHttpClientTransport extends AbstractClientTransport { stream.unacknowledgedBytesRead = 0; } if (inFinished) { - finishStream(streamId, Status.OK); + if (finishStream(streamId, Status.OK)) { + stopIfNecessary(); + } } } @@ -229,7 +313,9 @@ public class OkHttpClientTransport extends AbstractClientTransport { @Override public void rstStream(int streamId, ErrorCode errorCode) { - finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode)); + if (finishStream(streamId, toGrpcStatus(errorCode))) { + stopIfNecessary(); + } } @Override @@ -252,18 +338,14 @@ public class OkHttpClientTransport extends AbstractClientTransport { @Override public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) { - // TODO(user): Log here and implement the real Go away behavior: streams have - // id <= lastGoodStreamId should not be closed. - closeAllStreams(new Status(Code.UNAVAILABLE, "Go away")); - stopAsync(); + onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away")); } @Override public void pushPromise(int streamId, int promisedStreamId, List

requestHeaders) throws IOException { - // TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with - // PROTOCOL_ERROR. - frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM); + // We don't accept server initiated stream. + frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); } @Override @@ -284,28 +366,42 @@ public class OkHttpClientTransport extends AbstractClientTransport { } } + @GuardedBy("lock") + private void assignStreamId(OkHttpClientStream stream) { + Preconditions.checkState(stream.streamId == 0, "StreamId already assigned"); + stream.streamId = nextStreamId; + streams.put(stream.streamId, stream); + if (nextStreamId >= Integer.MAX_VALUE - 2) { + onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust")); + } else { + nextStreamId += 2; + } + } + /** * Client stream for the okhttp transport. */ - private class OkHttpClientStream extends AbstractStream implements ClientStream { + @VisibleForTesting + class OkHttpClientStream extends AbstractStream implements ClientStream { int streamId; final InputStreamDeframer deframer; int unacknowledgedBytesRead; - public OkHttpClientStream(MethodDescriptor method, StreamListener listener) { + OkHttpClientStream(MethodDescriptor method, StreamListener listener) { super(listener); - Preconditions.checkState(streamId == 0, "StreamId should be 0"); - synchronized (OkHttpClientTransport.this) { - streamId = nextStreamId; - nextStreamId += 2; - streams.put(streamId, this); - frameWriter.synStream(false, false, streamId, 0, - Headers.createRequestHeaders(method.getName())); - } deframer = new InputStreamDeframer(inboundMessageHandler()); + synchronized (lock) { + if (goAway) { + setStatus(goAwayStatus); + return; + } + assignStreamId(this); + } + frameWriter.synStream(false, false, streamId, 0, + Headers.createRequestHeaders(method.getName())); } - public InputStreamDeframer getDeframer() { + InputStreamDeframer getDeframer() { return deframer; } @@ -330,8 +426,9 @@ public class OkHttpClientTransport extends AbstractClientTransport { public void cancel() { Preconditions.checkState(streamId != 0, "streamId should be set"); outboundPhase = Phase.STATUS; - if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) { + if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) { frameWriter.rstStream(streamId, ErrorCode.CANCEL); + stopIfNecessary(); } } } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java new file mode 100644 index 0000000000..7c3d7fd895 --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java @@ -0,0 +1,555 @@ +package com.google.net.stubby.newtransport.okhttp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.Service; +import com.google.net.stubby.MethodDescriptor; +import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.StreamListener; +import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler; +import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream; +import com.google.net.stubby.transport.Transport; +import com.google.net.stubby.transport.Transport.Code; +import com.google.net.stubby.transport.Transport.ContextValue; +import com.google.protobuf.ByteString; + +import com.squareup.okhttp.internal.spdy.ErrorCode; +import com.squareup.okhttp.internal.spdy.FrameReader; + +import okio.Buffer; +import okio.BufferedSource; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link OkHttpClientTransport}. + */ +@RunWith(JUnit4.class) +public class OkHttpClientTransportTest { + private static final int TIME_OUT_MS = 5000000; + private static final String NETWORK_ISSUE_MESSAGE = "network issue"; + + // Flags + private static final byte PAYLOAD_FRAME = 0x0; + public static final byte CONTEXT_VALUE_FRAME = 0x1; + public static final byte STATUS_FRAME = 0x3; + + @Mock + private AsyncFrameWriter frameWriter; + @Mock + MethodDescriptor method; + private OkHttpClientTransport clientTransport; + private MockFrameReader frameReader; + private Map streams; + private ClientFrameHandler frameHandler; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + streams = new HashMap(); + frameReader = new MockFrameReader(); + clientTransport = new OkHttpClientTransport(frameReader, frameWriter, 3); + clientTransport.startAsync(); + frameHandler = clientTransport.getHandler(); + streams = clientTransport.getStreams(); + when(method.getName()).thenReturn("fakemethod"); + } + + @After + public void tearDown() { + clientTransport.stopAsync(); + assertTrue(frameReader.closed); + verify(frameWriter).close(); + } + + /** + * When nextFrame throws IOException, the transport should be aborted. + */ + @Test + public void nextFrameThrowIOException() throws Exception { + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + clientTransport.newStream(method, listener1); + clientTransport.newStream(method, listener2); + assertEquals(2, streams.size()); + assertTrue(streams.containsKey(3)); + assertTrue(streams.containsKey(5)); + frameReader.throwIOExceptionForNextFrame(); + listener1.waitUntilStreamClosed(); + listener2.waitUntilStreamClosed(); + assertEquals(0, streams.size()); + assertEquals(Code.INTERNAL, listener1.status.getCode()); + assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); + assertEquals(Code.INTERNAL, listener1.status.getCode()); + assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); + assertTrue("Service state: " + clientTransport.state(), + Service.State.TERMINATED == clientTransport.state()); + } + + @Test + public void readMessages() throws Exception { + final int numMessages = 10; + final String message = "Hello Client"; + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + assertTrue(streams.containsKey(3)); + for (int i = 0; i < numMessages; i++) { + BufferedSource source = mock(BufferedSource.class); + InputStream inputStream = createMessageFrame(message + i); + when(source.inputStream()).thenReturn(inputStream); + frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available()); + } + listener.waitUntilStreamClosed(); + assertEquals(Status.OK, listener.status); + assertEquals(numMessages, listener.messages.size()); + for (int i = 0; i < numMessages; i++) { + assertEquals(message + i, listener.messages.get(i)); + } + } + + @Test + public void readContexts() throws Exception { + final int numContexts = 10; + final String key = "KEY"; + final String value = "value"; + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + assertTrue(streams.containsKey(3)); + for (int i = 0; i < numContexts; i++) { + BufferedSource source = mock(BufferedSource.class); + InputStream inputStream = createContextFrame(key + i, value + i); + when(source.inputStream()).thenReturn(inputStream); + frameHandler.data(i == numContexts - 1 ? true : false, 3, source, inputStream.available()); + } + listener.waitUntilStreamClosed(); + assertEquals(Status.OK, listener.status); + assertEquals(numContexts, listener.contexts.size()); + for (int i = 0; i < numContexts; i++) { + String val = listener.contexts.get(key + i); + assertNotNull(val); + assertEquals(value + i, val); + } + } + + @Test + public void readStatus() throws Exception { + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + assertTrue(streams.containsKey(3)); + BufferedSource source = mock(BufferedSource.class); + InputStream inputStream = createStatusFrame((short) Transport.Code.UNAVAILABLE.getNumber()); + when(source.inputStream()).thenReturn(inputStream); + frameHandler.data(true, 3, source, inputStream.available()); + listener.waitUntilStreamClosed(); + assertEquals(Transport.Code.UNAVAILABLE, listener.status.getCode()); + } + + @Test + public void receiveReset() throws Exception { + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + assertTrue(streams.containsKey(3)); + frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR); + listener.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status); + } + + @Test + public void cancelStream() throws Exception { + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + OkHttpClientStream stream = streams.get(3); + assertNotNull(stream); + stream.cancel(); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status); + } + + @Test + public void writeMessage() throws Exception { + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + OkHttpClientStream stream = streams.get(3); + InputStream input = new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8)); + stream.writeMessage(input, input.available(), null); + stream.flush(); + ArgumentCaptor captor = + ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter).data(eq(false), eq(3), captor.capture()); + Buffer sentFrame = captor.getValue(); + checkSameInputStream(createMessageFrame(message), sentFrame.inputStream()); + } + + @Test + public void writeContext() throws Exception { + final String key = "KEY"; + final String value = "VALUE"; + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, listener); + OkHttpClientStream stream = streams.get(3); + InputStream input = new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8)); + stream.writeContext(key, input, input.available(), null); + stream.flush(); + ArgumentCaptor captor = + ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter).data(eq(false), eq(3), captor.capture()); + stream.cancel(); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status); + } + + @Test + public void windowUpdate() throws Exception { + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + clientTransport.newStream(method, listener1); + clientTransport.newStream(method, listener2); + assertEquals(2, streams.size()); + OkHttpClientStream stream1 = streams.get(3); + OkHttpClientStream stream2 = streams.get(5); + + int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4; + byte[] fakeMessage = new byte[messageLength]; + byte[] contextBody = ContextValue + .newBuilder() + .setKey("KEY") + .setValue(ByteString.copyFrom(fakeMessage)) + .build() + .toByteArray(); + + // Stream 1 receives context + InputStream contextFrame = createContextFrame(contextBody); + int contextFrameLength = contextFrame.available(); + BufferedSource source = mock(BufferedSource.class); + when(source.inputStream()).thenReturn(contextFrame); + frameHandler.data(false, 3, source, contextFrame.available()); + + // Stream 2 receives context + contextFrame = createContextFrame(contextBody); + when(source.inputStream()).thenReturn(contextFrame); + frameHandler.data(false, 5, source, contextFrame.available()); + + verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * contextFrameLength)); + + // Stream 1 receives a message + InputStream messageFrame = createMessageFrame(fakeMessage); + int messageFrameLength = messageFrame.available(); + when(source.inputStream()).thenReturn(messageFrame); + frameHandler.data(false, 3, source, messageFrame.available()); + + verify(frameWriter).windowUpdate(eq(3), eq((long) contextFrameLength + messageFrameLength)); + + // Stream 2 receives a message + messageFrame = createMessageFrame(fakeMessage); + when(source.inputStream()).thenReturn(messageFrame); + frameHandler.data(false, 5, source, messageFrame.available()); + + verify(frameWriter).windowUpdate(eq(5), eq((long) contextFrameLength + messageFrameLength)); + verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); + + stream1.cancel(); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener1.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status); + + stream2.cancel(); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener2.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status); + } + + @Test + public void stopNormally() throws Exception { + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + clientTransport.newStream(method, listener1); + clientTransport.newStream(method, listener2); + assertEquals(2, streams.size()); + clientTransport.stopAsync(); + listener1.waitUntilStreamClosed(); + listener2.waitUntilStreamClosed(); + verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any()); + assertEquals(0, streams.size()); + assertEquals(Code.INTERNAL, listener1.status.getCode()); + assertEquals(Code.INTERNAL, listener2.status.getCode()); + assertEquals(Service.State.TERMINATED, clientTransport.state()); + } + + @Test + public void receiveGoAway() throws Exception { + // start 2 streams. + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + clientTransport.newStream(method, listener1); + clientTransport.newStream(method, listener2); + assertEquals(2, streams.size()); + + // Receive goAway, max good id is 3. + frameHandler.goAway(3, ErrorCode.CANCEL, null); + + // Transport should be in STOPPING state. + assertEquals(Service.State.STOPPING, clientTransport.state()); + + // Stream 2 should be closed. + listener2.waitUntilStreamClosed(); + assertEquals(1, streams.size()); + assertEquals(Code.UNAVAILABLE, listener2.status.getCode()); + + // New stream should be failed. + MockStreamListener listener3 = new MockStreamListener(); + try { + clientTransport.newStream(method, listener3); + fail("new stream should no be accepted by a go-away transport."); + } catch (IllegalStateException ex) { + // expected. + } + + // But stream 1 should be able to send. + final String sentMessage = "Should I also go away?"; + OkHttpClientStream stream = streams.get(3); + InputStream input = + new ByteArrayInputStream(sentMessage.getBytes(StandardCharsets.UTF_8)); + stream.writeMessage(input, input.available(), null); + stream.flush(); + ArgumentCaptor captor = + ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter).data(eq(false), eq(3), captor.capture()); + Buffer sentFrame = captor.getValue(); + checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream()); + + // And read. + final String receivedMessage = "No, you are fine."; + BufferedSource source = mock(BufferedSource.class); + InputStream inputStream = createMessageFrame(receivedMessage); + when(source.inputStream()).thenReturn(inputStream); + frameHandler.data(true, 3, source, inputStream.available()); + listener1.waitUntilStreamClosed(); + assertEquals(1, listener1.messages.size()); + assertEquals(receivedMessage, listener1.messages.get(0)); + + // The transport should be stopped after all active streams finished. + assertTrue("Service state: " + clientTransport.state(), + Service.State.TERMINATED == clientTransport.state()); + } + + @Test + public void streamIdExhaust() throws Exception { + int startId = Integer.MAX_VALUE - 2; + AsyncFrameWriter writer = mock(AsyncFrameWriter.class); + OkHttpClientTransport transport = + new OkHttpClientTransport(frameReader, writer, startId); + transport.startAsync(); + streams = transport.getStreams(); + + MockStreamListener listener1 = new MockStreamListener(); + transport.newStream(method, listener1); + + try { + transport.newStream(method, new MockStreamListener()); + fail("new stream should not be accepted by a go-away transport."); + } catch (IllegalStateException ex) { + // expected. + } + + streams.get(startId).cancel(); + listener1.waitUntilStreamClosed(); + verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL)); + assertEquals(Service.State.TERMINATED, transport.state()); + } + + private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException { + assertEquals(in1.available(), in2.available()); + byte[] b1 = new byte[in1.available()]; + in1.read(b1); + byte[] b2 = new byte[in2.available()]; + in2.read(b2); + for (int i = 0; i < b1.length; i++) { + if (b1[i] != b2[i]) { + fail("Different InputStream."); + } + } + } + + private static InputStream createMessageFrame(String message) throws IOException { + return createMessageFrame(message.getBytes(StandardCharsets.UTF_8)); + } + + private static InputStream createMessageFrame(byte[] message) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.write(PAYLOAD_FRAME); + dos.writeInt(message.length); + dos.write(message); + dos.close(); + byte[] messageFrame = os.toByteArray(); + + // Write the compression header followed by the message frame. + return addCompressionHeader(messageFrame); + } + + private static InputStream createContextFrame(String key, String value) throws IOException { + byte[] body = ContextValue + .newBuilder() + .setKey(key) + .setValue(ByteString.copyFromUtf8(value)) + .build() + .toByteArray(); + return createContextFrame(body); + } + + private static InputStream createContextFrame(byte[] body) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.write(CONTEXT_VALUE_FRAME); + dos.writeInt(body.length); + dos.write(body); + dos.close(); + byte[] contextFrame = os.toByteArray(); + + // Write the compression header followed by the context frame. + return addCompressionHeader(contextFrame); + } + + private static InputStream createStatusFrame(short code) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.write(STATUS_FRAME); + int length = 2; + dos.writeInt(length); + dos.writeShort(code); + dos.close(); + byte[] statusFrame = os.toByteArray(); + + // Write the compression header followed by the status frame. + return addCompressionHeader(statusFrame); + } + + private static InputStream addCompressionHeader(byte[] raw) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.writeInt(raw.length); + dos.write(raw); + dos.close(); + return new ByteArrayInputStream(os.toByteArray()); + } + + private static class MockFrameReader implements FrameReader { + boolean closed; + boolean throwExceptionForNextFrame; + + @Override + public void close() throws IOException { + closed = true; + } + + @Override + public boolean nextFrame(Handler handler) throws IOException { + if (throwExceptionForNextFrame) { + throw new IOException(NETWORK_ISSUE_MESSAGE); + } + synchronized (this) { + try { + wait(); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + if (throwExceptionForNextFrame) { + throw new IOException(NETWORK_ISSUE_MESSAGE); + } + return true; + } + + synchronized void throwIOExceptionForNextFrame() { + throwExceptionForNextFrame = true; + notifyAll(); + } + + @Override + public void readConnectionPreface() throws IOException { + // not used. + } + } + + private static class MockStreamListener implements StreamListener { + Status status; + CountDownLatch closed = new CountDownLatch(1); + ArrayList messages = new ArrayList(); + Map contexts = new HashMap(); + + @Override + public ListenableFuture contextRead(String name, InputStream value, int length) { + String valueStr = getContent(value); + if (valueStr != null) { + // We assume only one context for each name. + contexts.put(name, valueStr); + } + return null; + } + + @Override + public ListenableFuture messageRead(InputStream message, int length) { + String msg = getContent(message); + if (msg != null) { + messages.add(msg); + } + return null; + } + + @Override + public void closed(Status status) { + this.status = status; + closed.countDown(); + } + + void waitUntilStreamClosed() throws InterruptedException { + if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) { + fail("Failed waiting stream to be closed."); + } + } + + static String getContent(InputStream message) { + BufferedReader br = + new BufferedReader(new InputStreamReader(message, StandardCharsets.UTF_8)); + try { + // Only one line message is used in this test. + return br.readLine(); + } catch (IOException e) { + return null; + } + } + } +}