diff --git a/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java b/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java index 6239ca21b1..effdaf943a 100644 --- a/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java +++ b/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java @@ -3,7 +3,6 @@ package com.google.net.stubby.testing.integration; import static com.google.net.stubby.testing.integration.Messages.PayloadType.COMPRESSABLE; import static com.google.net.stubby.testing.integration.Util.assertEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; import com.google.common.base.Throwables; @@ -41,8 +40,8 @@ import org.junit.Assume; import org.junit.Before; import org.junit.Test; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.LinkedList; import java.util.List; diff --git a/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java b/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java index de8978d2ed..c0230d9324 100644 --- a/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java +++ b/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java @@ -1,7 +1,5 @@ package com.google.net.stubby.testing.integration; -import static org.junit.Assume.assumeTrue; - import com.google.net.stubby.ChannelImpl; import com.google.net.stubby.transport.AbstractStream; import com.google.net.stubby.transport.netty.NettyServerBuilder; @@ -35,10 +33,4 @@ public class Http2OkHttpTest extends AbstractTransportTest { protected ChannelImpl createChannel() { return OkHttpChannelBuilder.forAddress("127.0.0.1", serverPort).build(); } - - @Override - public void clientStreaming() { - // TODO(user): Broken. We assume due to flow control bugs. - assumeTrue(false); - } } diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java index 29b373c8aa..3ca82ff3bc 100644 --- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java @@ -28,7 +28,8 @@ class OkHttpClientStream extends Http2ClientStream { */ static OkHttpClientStream newStream(final Executor executor, ClientStreamListener listener, AsyncFrameWriter frameWriter, - OkHttpClientTransport transport) { + OkHttpClientTransport transport, + OutboundFlowController outboundFlow) { // Create a lock object that can be used by both the executor and methods in the stream // to ensure consistent locking behavior. final Object executorLock = new Object(); @@ -46,7 +47,7 @@ class OkHttpClientStream extends Http2ClientStream { } }; return new OkHttpClientStream(synchronizingExecutor, listener, frameWriter, transport, - executorLock); + executorLock, outboundFlow); } @GuardedBy("executorLock") @@ -54,15 +55,18 @@ class OkHttpClientStream extends Http2ClientStream { @GuardedBy("executorLock") private int processedWindow = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE; private final AsyncFrameWriter frameWriter; + private final OutboundFlowController outboundFlow; private final OkHttpClientTransport transport; // Lock used to synchronize with work done on the executor. private final Object executorLock; + private Object outboundFlowState; private OkHttpClientStream(final Executor executor, final ClientStreamListener listener, AsyncFrameWriter frameWriter, OkHttpClientTransport transport, - Object executorLock) { + Object executorLock, + OutboundFlowController outboundFlow) { super(listener, null, executor); if (!GRPC_V2_PROTOCOL) { throw new RuntimeException("okhttp transport can only work with V2 protocol!"); @@ -70,6 +74,7 @@ class OkHttpClientStream extends Http2ClientStream { this.frameWriter = frameWriter; this.transport = transport; this.executorLock = executorLock; + this.outboundFlow = outboundFlow; } public void transportHeadersReceived(List
headers, boolean endOfStream) { @@ -105,8 +110,7 @@ class OkHttpClientStream extends Http2ClientStream { // Per http2 SPEC, the max data length should be larger than 64K, while our frame size is // only 4K. Preconditions.checkState(buffer.size() < frameWriter.maxDataLength()); - frameWriter.data(endOfStream, id(), buffer, (int) buffer.size()); - frameWriter.flush(); + outboundFlow.data(endOfStream, id(), buffer); } @Override @@ -144,4 +148,12 @@ class OkHttpClientStream extends Http2ClientStream { transport.stopIfNecessary(); } } + + void setOutboundFlowState(Object outboundFlowState) { + this.outboundFlowState = outboundFlowState; + } + + Object getOutboundFlowState() { + return outboundFlowState; + } } diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java index be7bb2793c..d0a90500c3 100644 --- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java @@ -82,6 +82,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { private final String defaultAuthority; private FrameReader frameReader; private AsyncFrameWriter frameWriter; + private OutboundFlowController outboundFlow; private final Object lock = new Object(); @GuardedBy("lock") private int nextStreamId; @@ -118,6 +119,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { this.executor = Preconditions.checkNotNull(executor); this.frameReader = Preconditions.checkNotNull(frameReader); this.frameWriter = Preconditions.checkNotNull(frameWriter); + this.outboundFlow = new OutboundFlowController(this, frameWriter); this.nextStreamId = nextStreamId; } @@ -126,7 +128,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { Metadata.Headers headers, ClientStreamListener listener) { OkHttpClientStream clientStream = OkHttpClientStream.newStream(executor, listener, - frameWriter, this); + frameWriter, this, outboundFlow); if (goAway) { clientStream.setStatus(goAwayStatus, new Metadata.Trailers()); } else { @@ -154,6 +156,7 @@ public class OkHttpClientTransport extends AbstractClientTransport { Variant variant = new Http20Draft14(); frameReader = variant.newReader(source, true); frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor); + outboundFlow = new OutboundFlowController(this, frameWriter); frameWriter.connectionPreface(); Settings settings = new Settings(); frameWriter.settings(settings); @@ -185,7 +188,6 @@ public class OkHttpClientTransport extends AbstractClientTransport { return clientFrameHandler; } - @VisibleForTesting Map getStreams() { return streams; } @@ -395,8 +397,8 @@ public class OkHttpClientTransport extends AbstractClientTransport { } @Override - public void windowUpdate(int arg0, long arg1) { - // TODO(user): outbound flow control. + public void windowUpdate(int streamId, long delta) { + outboundFlow.windowUpdate(streamId, (int) delta); } @Override diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java new file mode 100644 index 0000000000..a1881d0c11 --- /dev/null +++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java @@ -0,0 +1,399 @@ +package com.google.net.stubby.transport.okhttp; + +import static com.google.net.stubby.transport.okhttp.Utils.CONNECTION_STREAM_ID; +import static com.google.net.stubby.transport.okhttp.Utils.DEFAULT_WINDOW_SIZE; +import static com.google.net.stubby.transport.okhttp.Utils.MAX_FRAME_SIZE; +import static java.lang.Math.ceil; +import static java.lang.Math.max; +import static java.lang.Math.min; + +import com.google.common.base.Preconditions; + +import com.squareup.okhttp.internal.spdy.FrameWriter; + +import okio.Buffer; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * Simple outbound flow controller that evenly splits the connection window across all existing + * streams. + */ +class OutboundFlowController { + private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; + private final OkHttpClientTransport transport; + private final FrameWriter frameWriter; + private int initialWindowSize = DEFAULT_WINDOW_SIZE; + private final OutboundFlowState connectionState = new OutboundFlowState(CONNECTION_STREAM_ID); + + OutboundFlowController(OkHttpClientTransport transport, FrameWriter frameWriter) { + this.transport = Preconditions.checkNotNull(transport, "transport"); + this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); + } + + synchronized void initialOutboundWindowSize(int newWindowSize) { + if (newWindowSize < 0) { + throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); + } + + int delta = newWindowSize - initialWindowSize; + initialWindowSize = newWindowSize; + for (OkHttpClientStream stream : getActiveStreams()) { + // Verify that the maximum value is not exceeded by this change. + OutboundFlowState state = state(stream); + state.incrementStreamWindow(delta); + } + + if (delta > 0) { + // The window size increased, send any pending frames for all streams. + writeStreams(); + } + } + + synchronized void windowUpdate(int streamId, int delta) { + if (streamId == CONNECTION_STREAM_ID) { + // Update the connection window and write any pending frames for all streams. + connectionState.incrementStreamWindow(delta); + writeStreams(); + } else { + // Update the stream window and write any pending frames for the stream. + OutboundFlowState state = stateOrFail(streamId); + state.incrementStreamWindow(delta); + + WriteStatus writeStatus = new WriteStatus(); + state.writeBytes(state.writableWindow(), writeStatus); + if (writeStatus.hasWritten()) { + flush(); + } + } + } + + synchronized void data(boolean outFinished, int streamId, Buffer source) { + Preconditions.checkNotNull(source, "source"); + if (streamId <= 0) { + throw new IllegalArgumentException("streamId must be > 0"); + } + + OutboundFlowState state = stateOrFail(streamId); + int window = state.writableWindow(); + boolean framesAlreadyQueued = state.hasFrame(); + + OutboundFlowState.Frame frame = state.newFrame(source, outFinished); + if (!framesAlreadyQueued && window >= frame.size()) { + // Window size is large enough to send entire data frame + frame.write(); + flush(); + return; + } + + // Enqueue the frame to be written when the window size permits. + frame.enqueue(); + + if (framesAlreadyQueued || window <= 0) { + // Stream already has frames pending or is stalled, don't send anything now. + return; + } + + // Create and send a partial frame up to the window size. + frame.split(window).write(); + flush(); + } + + private void flush() { + try { + frameWriter.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private OutboundFlowState state(OkHttpClientStream stream) { + OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); + if (state == null) { + state = new OutboundFlowState(stream.id()); + stream.setOutboundFlowState(state); + } + return state; + } + + private OutboundFlowState state(int streamId) { + OkHttpClientStream stream = transport.getStreams().get(streamId); + return stream != null ? state(stream) : null; + } + + private OutboundFlowState stateOrFail(int streamId) { + OutboundFlowState state = state(streamId); + if (state == null) { + throw new RuntimeException("Missing flow control window for stream: " + streamId); + } + return state; + } + + /** + * Gets all active streams as an array. + */ + private OkHttpClientStream[] getActiveStreams() { + return transport.getStreams().values().toArray(EMPTY_STREAM_ARRAY); + } + + /** + * Writes as much data for all the streams as possible given the current flow control windows. + */ + private void writeStreams() { + OkHttpClientStream[] streams = getActiveStreams(); + int connectionWindow = connectionState.window(); + for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) { + int nextNumStreams = 0; + int windowSlice = (int) ceil(connectionWindow / (float) numStreams); + for (int index = 0; index < numStreams && connectionWindow > 0; ++index) { + OkHttpClientStream stream = streams[index]; + OutboundFlowState state = state(stream); + + int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice)); + if (bytesForStream > 0) { + state.allocateBytes(bytesForStream); + connectionWindow -= bytesForStream; + } + + if (state.unallocatedBytes() > 0) { + // There is more data to process for this stream. Add it to the next + // pass. + streams[nextNumStreams++] = stream; + } + } + numStreams = nextNumStreams; + } + + // Now take one last pass through all of the streams and write any allocated bytes. + WriteStatus writeStatus = new WriteStatus(); + for (OkHttpClientStream stream : getActiveStreams()) { + OutboundFlowState state = state(stream); + state.writeBytes(state.allocatedBytes(), writeStatus); + state.clearAllocatedBytes(); + } + + if (writeStatus.hasWritten()) { + flush(); + } + } + + /** + * Simple status that keeps track of the number of writes performed. + */ + private final class WriteStatus { + int numWrites; + + void incrementNumWrites() { + numWrites++; + } + + boolean hasWritten() { + return numWrites > 0; + } + } + + /** + * The outbound flow control state for a single stream. + */ + private final class OutboundFlowState { + final Queue pendingWriteQueue; + final int streamId; + int queuedBytes; + int window = initialWindowSize; + int allocatedBytes; + + OutboundFlowState(int streamId) { + this.streamId = streamId; + pendingWriteQueue = new ArrayDeque(2); + } + + int window() { + return window; + } + + void allocateBytes(int bytes) { + allocatedBytes += bytes; + } + + int allocatedBytes() { + return allocatedBytes; + } + + int unallocatedBytes() { + return streamableBytes() - allocatedBytes; + } + + void clearAllocatedBytes() { + allocatedBytes = 0; + } + + /** + * Increments the flow control window for this stream by the given delta and returns the new + * value. + */ + int incrementStreamWindow(int delta) { + if (delta > 0 && Integer.MAX_VALUE - delta < window) { + throw new IllegalArgumentException("Window size overflow for stream: " + streamId); + } + window += delta; + + return window; + } + + /** + * Returns the maximum writable window (minimum of the stream and connection windows). + */ + int writableWindow() { + return min(window, connectionState.window()); + } + + int streamableBytes() { + return max(0, min(window, queuedBytes)); + } + + /** + * Creates a new frame with the given values but does not add it to the pending queue. + */ + Frame newFrame(Buffer data, boolean endStream) { + return new Frame(data, endStream); + } + + /** + * Indicates whether or not there are frames in the pending queue. + */ + boolean hasFrame() { + return !pendingWriteQueue.isEmpty(); + } + + /** + * Returns the the head of the pending queue, or {@code null} if empty. + */ + private Frame peek() { + return pendingWriteQueue.peek(); + } + + /** + * Writes up to the number of bytes from the pending queue. + */ + int writeBytes(int bytes, WriteStatus writeStatus) { + int bytesAttempted = 0; + int maxBytes = min(bytes, writableWindow()); + while (hasFrame()) { + Frame pendingWrite = peek(); + if (maxBytes >= pendingWrite.size()) { + // Window size is large enough to send entire data frame + writeStatus.incrementNumWrites(); + bytesAttempted += pendingWrite.size(); + pendingWrite.write(); + } else if (maxBytes <= 0) { + // No data from the current frame can be written - we're done. + // We purposely check this after first testing the size of the + // pending frame to properly handle zero-length frame. + break; + } else { + // We can send a partial frame + Frame partialFrame = pendingWrite.split(maxBytes); + writeStatus.incrementNumWrites(); + bytesAttempted += partialFrame.size(); + partialFrame.write(); + } + + // Update the threshold. + maxBytes = min(bytes - bytesAttempted, writableWindow()); + } + return bytesAttempted; + } + + /** + * A wrapper class around the content of a data frame. + */ + private final class Frame { + final Buffer data; + final boolean endStream; + boolean enqueued; + + Frame(Buffer data, boolean endStream) { + this.data = data; + this.endStream = endStream; + } + + /** + * Gets the total size (in bytes) of this frame including the data and padding. + */ + int size() { + return (int) data.size(); + } + + void enqueue() { + if (!enqueued) { + enqueued = true; + pendingWriteQueue.offer(this); + + // Increment the number of pending bytes for this stream. + queuedBytes += size(); + } + } + + /** + * Writes the frame and decrements the stream and connection window sizes. If the frame is in + * the pending queue, the written bytes are removed from this branch of the priority tree. + */ + void write() { + // Using a do/while loop because if the buffer is empty we still need to call + // the writer once to send the empty frame. + do { + int bytesToWrite = size(); + int frameBytes = min(bytesToWrite, MAX_FRAME_SIZE); + if (frameBytes == bytesToWrite) { + // All the bytes fit into a single HTTP/2 frame, just send it all. + connectionState.incrementStreamWindow(-bytesToWrite); + incrementStreamWindow(-bytesToWrite); + try { + frameWriter.data(endStream, streamId, data, bytesToWrite); + } catch (IOException e) { + throw new RuntimeException(e); + } + if (enqueued) { + // It's enqueued - remove it from the head of the pending write queue. + queuedBytes -= bytesToWrite; + pendingWriteQueue.remove(this); + } + return; + } + + // Split a chunk that will fit into a single HTTP/2 frame and write it. + Frame frame = split(frameBytes); + frame.write(); + } while (size() > 0); + } + + /** + * Creates a new frame that is a view of this frame's data. The {@code maxBytes} are first + * split from the data buffer. If not all the requested bytes are available, the remaining + * bytes are then split from the padding (if available). + * + * @param maxBytes the maximum number of bytes that is allowed in the created frame. + * @return the partial frame. + */ + Frame split(int maxBytes) { + // The requested maxBytes should always be less than the size of this frame. + assert maxBytes < size() : "Attempting to split a frame for the full size."; + + // Get the portion of the data buffer to be split. Limit to the readable bytes. + int dataSplit = min(maxBytes, (int) data.size()); + + Buffer splitSlice = new Buffer(); + splitSlice.write(data, dataSplit); + + Frame frame = new Frame(splitSlice, false); + + if (enqueued) { + queuedBytes -= dataSplit; + } + return frame; + } + } + } +} diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java index 7293b97082..6a42970b8d 100644 --- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java +++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java @@ -10,6 +10,10 @@ import java.util.List; * Common utility methods for OkHttp transport. */ class Utils { + static final int DEFAULT_WINDOW_SIZE = 65535; + static final int CONNECTION_STREAM_ID = 0; + static final int MAX_FRAME_SIZE = 16384; + public static Metadata.Headers convertHeaders(List
http2Headers) { return new Metadata.Headers(convertHeadersToArray(http2Headers)); }