diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java index 13e201600d..356473078d 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java @@ -199,7 +199,7 @@ class AsyncFrameWriter implements FrameWriter { try { doRun(); } catch (IOException ex) { - transport.abort(Status.fromThrowable(ex)); + transport.abort(ex); throw new RuntimeException(ex); } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpBuffer.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpBuffer.java new file mode 100644 index 0000000000..25dfd48772 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpBuffer.java @@ -0,0 +1,67 @@ +package com.google.net.stubby.newtransport.okhttp; + +import com.google.net.stubby.newtransport.AbstractBuffer; +import com.google.net.stubby.newtransport.Buffer; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +/** + * A {@link Buffer} implementation that is backed by an {@link okio.Buffer}. + */ +class OkHttpBuffer extends AbstractBuffer { + private final okio.Buffer buffer; + + OkHttpBuffer(okio.Buffer buffer) { + this.buffer = buffer; + } + + @Override + public int readableBytes() { + return (int) buffer.size(); + } + + @Override + public int readUnsignedByte() { + return buffer.readByte() & 0x000000FF; + } + + @Override + public void skipBytes(int length) { + try { + buffer.skip(length); + } catch (EOFException e) { + throw new IndexOutOfBoundsException(e.getMessage()); + } + } + + @Override + public void readBytes(byte[] dest, int destOffset, int length) { + buffer.read(dest, destOffset, length); + } + + @Override + public void readBytes(ByteBuffer dest) { + // We are not using it. + throw new UnsupportedOperationException(); + } + + @Override + public void readBytes(OutputStream dest, int length) throws IOException { + buffer.writeTo(dest, length); + } + + @Override + public Buffer readBytes(int length) { + okio.Buffer buf = new okio.Buffer(); + buf.write(buffer, length); + return new OkHttpBuffer(buf); + } + + @Override + public void close() { + buffer.clear(); + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java index 8baafc09fc..de4ffcdb9f 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java @@ -2,17 +2,18 @@ package com.google.net.stubby.newtransport.okhttp; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.google.net.stubby.Metadata; import com.google.net.stubby.MethodDescriptor; import com.google.net.stubby.Status; import com.google.net.stubby.newtransport.AbstractClientStream; import com.google.net.stubby.newtransport.AbstractClientTransport; +import com.google.net.stubby.newtransport.Buffers; import com.google.net.stubby.newtransport.ClientStream; import com.google.net.stubby.newtransport.ClientStreamListener; import com.google.net.stubby.newtransport.ClientTransport; -import com.google.net.stubby.newtransport.InputStreamDeframer; +import com.google.net.stubby.newtransport.MessageDeframer2; import com.google.net.stubby.newtransport.StreamState; import com.squareup.okhttp.internal.spdy.ErrorCode; @@ -41,6 +42,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; /** @@ -93,7 +95,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { private final Map streams = Collections.synchronizedMap(new HashMap()); private final Executor executor; - private int unacknowledgedBytesRead; + private int connectionUnacknowledgedBytesRead; private ClientFrameHandler clientFrameHandler; // The status used to finish all active streams when the transport is closed. @GuardedBy("lock") @@ -170,7 +172,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { // The GOAWAY is part of graceful shutdown. frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]); - abort(Status.INTERNAL.withDescription("Transport stopped")); + onGoAway(0, Status.INTERNAL.withDescription("Transport stopped"), null); } stopIfNecessary(); } @@ -186,33 +188,36 @@ public class OkHttpClientTransport extends AbstractClientTransport { } /** - * Finish all active streams with given status, then close the transport. + * Finish all active streams due to a failure, then close the transport. */ - void abort(Status status) { - onGoAway(-1, status); + void abort(Throwable failureCause) { + onGoAway(0, Status.fromThrowable(failureCause), failureCause); } - private void onGoAway(int lastKnownStreamId, Status status) { + private void onGoAway(int lastKnownStreamId, Status status, @Nullable Throwable failureCause) { ArrayList goAwayStreams = new ArrayList(); synchronized (lock) { goAway = true; goAwayStatus = status; - Iterator> it = streams.entrySet().iterator(); - while (it.hasNext()) { - Map.Entry entry = it.next(); - if (entry.getKey() > lastKnownStreamId) { - goAwayStreams.add(entry.getValue()); - it.remove(); + synchronized (streams) { + Iterator> it = streams.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + if (entry.getKey() > lastKnownStreamId) { + goAwayStreams.add(entry.getValue()); + it.remove(); + } } } } // Starting stop, go into STOPPING state so that Channel know this Transport should not be used - // further, will become STOPPED once all streams are complete. + // further, will become STOPPED once all streams are complete, or become FAILED immediately if + // the transport is aborted by some error. State state = state(); if (state == State.RUNNING || state == State.NEW) { - if (status.getCode() == Status.Code.INTERNAL && status.getCause() != null) { - notifyFailed(status.asRuntimeException()); + if (failureCause != null) { + notifyFailed(failureCause); } else { stopAsync(); } @@ -290,7 +295,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { while (frameReader.nextFrame(this)) { } } catch (IOException ioe) { - abort(Status.fromThrowable(ioe)); + abort(ioe); } finally { // Restore the original thread name. Thread.currentThread().setName(threadName); @@ -309,26 +314,19 @@ public class OkHttpClientTransport extends AbstractClientTransport { frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); return; } - InputStreamDeframer deframer = stream.getDeframer(); // Wait until the frame is complete. in.require(length); - deframer.deliverFrame(ByteStreams.limit(in.inputStream(), length), inFinished); - unacknowledgedBytesRead += length; - stream.unacknowledgedBytesRead += length; - if (unacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) { - frameWriter.windowUpdate(0, unacknowledgedBytesRead); - unacknowledgedBytesRead = 0; - } - if (stream.unacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) { - frameWriter.windowUpdate(streamId, stream.unacknowledgedBytesRead); - stream.unacknowledgedBytesRead = 0; - } - if (inFinished) { - if (finishStream(streamId, Status.OK)) { - stopIfNecessary(); - } + Buffer buf = new Buffer(); + buf.write(in.buffer(), length); + stream.deliverData(buf, inFinished, length); + + // connection window update + connectionUnacknowledgedBytesRead += length; + if (connectionUnacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) { + frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + connectionUnacknowledgedBytesRead = 0; } } @@ -343,6 +341,15 @@ public class OkHttpClientTransport extends AbstractClientTransport { List
headerBlock, HeadersMode headersMode) { // TODO(user): handle received headers. + if (inFinished) { + final OkHttpClientStream stream; + stream = streams.get(streamId); + if (stream == null) { + frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + return; + } + stream.deliverHeaders(inFinished); + } } @Override @@ -358,7 +365,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { try { frameWriter.ackSettings(settings); } catch (IOException e) { - abort(Status.fromThrowable(e)); + abort(e); } } @@ -376,7 +383,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { @Override public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) { - onGoAway(lastGoodStreamId, Status.UNAVAILABLE.withDescription("Go away")); + onGoAway(lastGoodStreamId, Status.UNAVAILABLE.withDescription("Go away"), null); } @Override @@ -388,7 +395,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { @Override public void windowUpdate(int arg0, long arg1) { - // TODO(user): flow control. + // TODO(user): outbound flow control. } @Override @@ -410,7 +417,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { stream.streamId = nextStreamId; streams.put(stream.streamId, stream); if (nextStreamId >= Integer.MAX_VALUE - 2) { - onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Stream id exhaust")); + onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Stream id exhaust"), null); } else { nextStreamId += 2; } @@ -422,13 +429,33 @@ public class OkHttpClientTransport extends AbstractClientTransport { @VisibleForTesting class OkHttpClientStream extends AbstractClientStream { int streamId; - final InputStreamDeframer deframer; + final MessageDeframer2 deframer; + @GuardedBy("this") int unacknowledgedBytesRead; + @GuardedBy("this") + boolean windowUpdateDisabled; OkHttpClientStream(MethodDescriptor method, Metadata.Headers headers, ClientStreamListener listener) { super(listener); - deframer = new InputStreamDeframer(inboundMessageHandler()); + if (!GRPC_V2_PROTOCOL) { + throw new RuntimeException("okhttp transport can only work with V2 protocol!"); + } + deframer = new MessageDeframer2(inboundMessageHandler(), new Executor() { + // An executor that synchronized on this stream before executing a task, so that flow + // control processing is properly synchronized. + @Override + public void execute(final Runnable command) { + executor.execute(new Runnable() { + @Override + public void run() { + synchronized (OkHttpClientStream.this) { + command.run(); + } + } + }); + } + }); synchronized (lock) { if (goAway) { setStatus(goAwayStatus, new Metadata.Trailers()); @@ -441,8 +468,24 @@ public class OkHttpClientTransport extends AbstractClientTransport { Headers.createRequestHeaders(headers, defaultPath, defaultAuthority)); } - InputStreamDeframer getDeframer() { - return deframer; + /** + * We synchronized on "this" for delivering frames and updating window size, so that the future + * listeners (executed by synchronizedExecutor) will not be executed in the same time. + */ + synchronized void deliverData(Buffer data, boolean endOfStream, int length) { + deframer.deframe(new OkHttpBuffer(data), endOfStream); + unacknowledgedBytesRead += length; + if (windowUpdateDisabled) { + return; + } + if (unacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) { + frameWriter.windowUpdate(streamId, unacknowledgedBytesRead); + unacknowledgedBytesRead = 0; + } + } + + synchronized void deliverHeaders(boolean endOfStream) { + deframer.deframe(Buffers.empty(), endOfStream); } @Override @@ -461,8 +504,23 @@ public class OkHttpClientTransport extends AbstractClientTransport { } @Override - protected void disableWindowUpdate(ListenableFuture processingFuture) { - // TODO(user): implement inbound flow control. + synchronized protected void disableWindowUpdate(ListenableFuture processingFuture) { + if (processingFuture == null || processingFuture.isDone()) { + return; + } + windowUpdateDisabled = true; + processingFuture.addListener(new Runnable() { + @Override + public void run() { + synchronized (OkHttpClientStream.this) { + windowUpdateDisabled = false; + if (unacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) { + frameWriter.windowUpdate(streamId, unacknowledgedBytesRead); + unacknowledgedBytesRead = 0; + } + } + } + }, MoreExecutors.directExecutor()); } @Override @@ -477,5 +535,13 @@ public class OkHttpClientTransport extends AbstractClientTransport { stopIfNecessary(); } } + + @Override + public void remoteEndClosed() { + super.remoteEndClosed(); + if (finishStream(streamId, Status.OK)) { + stopIfNecessary(); + } + } } } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java index 5b8044bc18..7e46f7d00d 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java @@ -8,14 +8,19 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.Service; +import com.google.common.util.concurrent.Service.State; +import com.google.common.util.concurrent.SettableFuture; import com.google.net.stubby.Metadata; import com.google.net.stubby.MethodDescriptor; import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.AbstractStream; import com.google.net.stubby.newtransport.ClientStreamListener; import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler; import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream; @@ -24,11 +29,10 @@ import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.FrameReader; import okio.Buffer; -import okio.BufferedSource; import org.junit.After; -import org.junit.Assume; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,8 +42,6 @@ import org.mockito.MockitoAnnotations; import java.io.BufferedReader; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -59,14 +61,8 @@ import java.util.concurrent.TimeUnit; public class OkHttpClientTransportTest { private static final int TIME_OUT_MS = 5000000; private static final String NETWORK_ISSUE_MESSAGE = "network issue"; - // The extra bytes that would be added to a message before passing to okhttp transport, they are: - // 4 bytes (1 byte flag, 3 byte length) compression frame header + 5 bytes (1 byte flag, - // 4 bytes length) message frame header. - private static final int EXTRA_BYTES = 9; - - // Flags - private static final byte PAYLOAD_FRAME = 0x0; - public static final byte STATUS_FRAME = 0x3; + // The gRPC header length, which includes 1 byte compression flag and 4 bytes message length. + private static final int HEADER_LENGTH = 5; @Mock private AsyncFrameWriter frameWriter; @@ -80,6 +76,7 @@ public class OkHttpClientTransportTest { @Before public void setup() { + AbstractStream.GRPC_V2_PROTOCOL = true; MockitoAnnotations.initMocks(this); streams = new HashMap(); frameReader = new MockFrameReader(); @@ -94,9 +91,12 @@ public class OkHttpClientTransportTest { @After public void tearDown() { - clientTransport.stopAsync(); - assertTrue(frameReader.closed); - verify(frameWriter).close(); + State state = clientTransport.state(); + if (state == State.NEW || state == State.RUNNING) { + clientTransport.stopAsync(); + assertTrue(frameReader.closed); + verify(frameWriter).close(); + } executor.shutdown(); } @@ -105,7 +105,6 @@ public class OkHttpClientTransportTest { */ @Test public void nextFrameThrowIOException() throws Exception { - Assume.assumeTrue(false); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener1); @@ -121,23 +120,19 @@ public class OkHttpClientTransportTest { assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); assertEquals(Status.INTERNAL.getCode(), listener1.status.getCode()); assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); - assertTrue("Service state: " + clientTransport.state(), - Service.State.TERMINATED == clientTransport.state()); + assertEquals(Service.State.FAILED, clientTransport.state()); } @Test public void readMessages() throws Exception { - Assume.assumeTrue(false); final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method, new Metadata.Headers(), listener); assertTrue(streams.containsKey(3)); for (int i = 0; i < numMessages; i++) { - BufferedSource source = mock(BufferedSource.class); - InputStream inputStream = createMessageFrame(message + i); - when(source.inputStream()).thenReturn(inputStream); - frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available()); + Buffer buffer = createMessageFrame(message + i); + frameHandler.data(i == numMessages - 1 ? true : false, 3, buffer, (int) buffer.size()); } listener.waitUntilStreamClosed(); assertEquals(Status.OK, listener.status); @@ -147,23 +142,16 @@ public class OkHttpClientTransportTest { } } - @Test + @Ignore + /** + * TODO (simonma): Re-implement this test, since status is carried by header instead of data frame + * in V2 protocol. + */ public void readStatus() throws Exception { - Assume.assumeTrue(false); - MockStreamListener listener = new MockStreamListener(); - clientTransport.newStream(method,new Metadata.Headers(), listener); - assertTrue(streams.containsKey(3)); - BufferedSource source = mock(BufferedSource.class); - InputStream inputStream = createStatusFrame((short) Status.UNAVAILABLE.getCode().value()); - when(source.inputStream()).thenReturn(inputStream); - frameHandler.data(true, 3, source, inputStream.available()); - listener.waitUntilStreamClosed(); - assertEquals(Status.UNAVAILABLE.getCode(), listener.status.getCode()); } @Test public void receiveReset() throws Exception { - Assume.assumeTrue(false); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener); assertTrue(streams.containsKey(3)); @@ -174,7 +162,6 @@ public class OkHttpClientTransportTest { @Test public void cancelStream() throws Exception { - Assume.assumeTrue(false); MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener); OkHttpClientStream stream = streams.get(3); @@ -187,7 +174,6 @@ public class OkHttpClientTransportTest { @Test public void writeMessage() throws Exception { - Assume.assumeTrue(false); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener); @@ -196,16 +182,14 @@ public class OkHttpClientTransportTest { assertEquals(12, input.available()); stream.writeMessage(input, input.available(), null); stream.flush(); - ArgumentCaptor captor = - ArgumentCaptor.forClass(Buffer.class); - verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + EXTRA_BYTES)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); Buffer sentFrame = captor.getValue(); - checkSameInputStream(createMessageFrame(message), sentFrame.inputStream()); + assertEquals(createMessageFrame(message), sentFrame); } @Test public void windowUpdate() throws Exception { - Assume.assumeTrue(false); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener1); @@ -216,33 +200,28 @@ public class OkHttpClientTransportTest { int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4; byte[] fakeMessage = new byte[messageLength]; - BufferedSource source = mock(BufferedSource.class); // Stream 1 receives a message - InputStream messageFrame = createMessageFrame(fakeMessage); - int messageFrameLength = messageFrame.available(); - when(source.inputStream()).thenReturn(messageFrame); - frameHandler.data(false, 3, source, messageFrame.available()); + Buffer buffer = createMessageFrame(fakeMessage); + int messageFrameLength = (int) buffer.size(); + frameHandler.data(false, 3, buffer, messageFrameLength); // Stream 2 receives a message - messageFrame = createMessageFrame(fakeMessage); - when(source.inputStream()).thenReturn(messageFrame); - frameHandler.data(false, 5, source, messageFrame.available()); + buffer = createMessageFrame(fakeMessage); + frameHandler.data(false, 5, buffer, messageFrameLength); verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); reset(frameWriter); // Stream 1 receives another message - messageFrame = createMessageFrame(fakeMessage); - when(source.inputStream()).thenReturn(messageFrame); - frameHandler.data(false, 3, source, messageFrame.available()); + buffer = createMessageFrame(fakeMessage); + frameHandler.data(false, 3, buffer, messageFrameLength); verify(frameWriter).windowUpdate(eq(3), eq((long) 2 * messageFrameLength)); // Stream 2 receives another message - messageFrame = createMessageFrame(fakeMessage); - when(source.inputStream()).thenReturn(messageFrame); - frameHandler.data(false, 5, source, messageFrame.available()); + buffer = createMessageFrame(fakeMessage); + frameHandler.data(false, 5, buffer, messageFrameLength); verify(frameWriter).windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); @@ -253,14 +232,38 @@ public class OkHttpClientTransportTest { assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status); stream2.cancel(); - verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + verify(frameWriter).rstStream(eq(5), eq(ErrorCode.CANCEL)); listener2.waitUntilStreamClosed(); assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status); } + @Test + public void windowUpdateWithInboundFlowControl() throws Exception { + SettableFuture future = SettableFuture.create(); + MockStreamListener listener = new MockStreamListener(future); + clientTransport.newStream(method, new Metadata.Headers(), listener); + OkHttpClientStream stream = streams.get(3); + + int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1; + byte[] fakeMessage = new byte[messageLength]; + + Buffer buffer = createMessageFrame(fakeMessage); + long messageFrameLength = buffer.size(); + frameHandler.data(false, 3, buffer, (int) messageFrameLength); + verify(frameWriter).windowUpdate(eq(0), eq(messageFrameLength)); + verify(frameWriter, times(0)).windowUpdate(eq(3), eq(messageFrameLength)); + + future.set(null); + verify(frameWriter).windowUpdate(eq(3), eq(messageFrameLength)); + + stream.cancel(); + verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL)); + listener.waitUntilStreamClosed(); + assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status); + } + @Test public void stopNormally() throws Exception { - Assume.assumeTrue(false); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); clientTransport.newStream(method,new Metadata.Headers(), listener1); @@ -278,7 +281,6 @@ public class OkHttpClientTransportTest { @Test public void receiveGoAway() throws Exception { - Assume.assumeTrue(false); // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); @@ -316,16 +318,14 @@ public class OkHttpClientTransportTest { stream.flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); - verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(22 + EXTRA_BYTES)); + verify(frameWriter).data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH)); Buffer sentFrame = captor.getValue(); - checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream()); + assertEquals(createMessageFrame(sentMessage), sentFrame); // And read. final String receivedMessage = "No, you are fine."; - BufferedSource source = mock(BufferedSource.class); - InputStream inputStream = createMessageFrame(receivedMessage); - when(source.inputStream()).thenReturn(inputStream); - frameHandler.data(true, 3, source, inputStream.available()); + Buffer buffer = createMessageFrame(receivedMessage); + frameHandler.data(true, 3, buffer, (int) buffer.size()); listener1.waitUntilStreamClosed(); assertEquals(1, listener1.messages.size()); assertEquals(receivedMessage, listener1.messages.get(0)); @@ -337,7 +337,6 @@ public class OkHttpClientTransportTest { @Test public void streamIdExhaust() throws Exception { - Assume.assumeTrue(false); int startId = Integer.MAX_VALUE - 2; AsyncFrameWriter writer = mock(AsyncFrameWriter.class); OkHttpClientTransport transport = @@ -361,57 +360,16 @@ public class OkHttpClientTransportTest { assertEquals(Service.State.TERMINATED, transport.state()); } - private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException { - assertEquals(in1.available(), in2.available()); - byte[] b1 = new byte[in1.available()]; - in1.read(b1); - byte[] b2 = new byte[in2.available()]; - in2.read(b2); - for (int i = 0; i < b1.length; i++) { - if (b1[i] != b2[i]) { - fail("Different InputStream."); - } - } - } - - private static InputStream createMessageFrame(String message) throws IOException { + private static Buffer createMessageFrame(String message) { return createMessageFrame(message.getBytes(StandardCharsets.UTF_8)); } - private static InputStream createMessageFrame(byte[] message) throws IOException { - ByteArrayOutputStream os = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(os); - dos.write(PAYLOAD_FRAME); - dos.writeInt(message.length); - dos.write(message); - dos.close(); - byte[] messageFrame = os.toByteArray(); - - // Write the compression header followed by the message frame. - return addCompressionHeader(messageFrame); - } - - private static InputStream createStatusFrame(short code) throws IOException { - ByteArrayOutputStream os = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(os); - dos.write(STATUS_FRAME); - int length = 2; - dos.writeInt(length); - dos.writeShort(code); - dos.close(); - byte[] statusFrame = os.toByteArray(); - - // Write the compression header followed by the status frame. - return addCompressionHeader(statusFrame); - } - - private static InputStream addCompressionHeader(byte[] raw) throws IOException { - ByteArrayOutputStream os = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(os); - dos.writeInt(raw.length); - dos.write(raw); - dos.close(); - return new ByteArrayInputStream(os.toByteArray()); + private static Buffer createMessageFrame(byte[] message) { + Buffer buffer = new Buffer(); + buffer.writeByte(0 /* UNCOMPRESSED */); + buffer.writeInt(message.length); + buffer.write(message); + return buffer; } private static class MockFrameReader implements FrameReader { @@ -456,6 +414,15 @@ public class OkHttpClientTransportTest { Status status; CountDownLatch closed = new CountDownLatch(1); ArrayList messages = new ArrayList(); + final ListenableFuture messageFuture; + + MockStreamListener() { + messageFuture = null; + } + + MockStreamListener(ListenableFuture future) { + messageFuture = future; + } @Override public ListenableFuture headersRead(Metadata.Headers headers) { @@ -468,7 +435,7 @@ public class OkHttpClientTransportTest { if (msg != null) { messages.add(msg); } - return null; + return messageFuture; } @Override