Adding crude outbound flow control to OkHttp transport.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=80390743
This commit is contained in:
nathanmittler 2014-11-20 09:35:16 -08:00 committed by Eric Anderson
parent d7f78773ea
commit 0d8477c85c
6 changed files with 427 additions and 19 deletions

View File

@ -3,7 +3,6 @@ package com.google.net.stubby.testing.integration;
import static com.google.net.stubby.testing.integration.Messages.PayloadType.COMPRESSABLE; import static com.google.net.stubby.testing.integration.Messages.PayloadType.COMPRESSABLE;
import static com.google.net.stubby.testing.integration.Util.assertEquals; import static com.google.net.stubby.testing.integration.Util.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
@ -41,8 +40,8 @@ import org.junit.Assume;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.Arrays;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;

View File

@ -1,7 +1,5 @@
package com.google.net.stubby.testing.integration; package com.google.net.stubby.testing.integration;
import static org.junit.Assume.assumeTrue;
import com.google.net.stubby.ChannelImpl; import com.google.net.stubby.ChannelImpl;
import com.google.net.stubby.transport.AbstractStream; import com.google.net.stubby.transport.AbstractStream;
import com.google.net.stubby.transport.netty.NettyServerBuilder; import com.google.net.stubby.transport.netty.NettyServerBuilder;
@ -35,10 +33,4 @@ public class Http2OkHttpTest extends AbstractTransportTest {
protected ChannelImpl createChannel() { protected ChannelImpl createChannel() {
return OkHttpChannelBuilder.forAddress("127.0.0.1", serverPort).build(); return OkHttpChannelBuilder.forAddress("127.0.0.1", serverPort).build();
} }
@Override
public void clientStreaming() {
// TODO(user): Broken. We assume due to flow control bugs.
assumeTrue(false);
}
} }

View File

@ -28,7 +28,8 @@ class OkHttpClientStream extends Http2ClientStream {
*/ */
static OkHttpClientStream newStream(final Executor executor, ClientStreamListener listener, static OkHttpClientStream newStream(final Executor executor, ClientStreamListener listener,
AsyncFrameWriter frameWriter, AsyncFrameWriter frameWriter,
OkHttpClientTransport transport) { OkHttpClientTransport transport,
OutboundFlowController outboundFlow) {
// Create a lock object that can be used by both the executor and methods in the stream // Create a lock object that can be used by both the executor and methods in the stream
// to ensure consistent locking behavior. // to ensure consistent locking behavior.
final Object executorLock = new Object(); final Object executorLock = new Object();
@ -46,7 +47,7 @@ class OkHttpClientStream extends Http2ClientStream {
} }
}; };
return new OkHttpClientStream(synchronizingExecutor, listener, frameWriter, transport, return new OkHttpClientStream(synchronizingExecutor, listener, frameWriter, transport,
executorLock); executorLock, outboundFlow);
} }
@GuardedBy("executorLock") @GuardedBy("executorLock")
@ -54,15 +55,18 @@ class OkHttpClientStream extends Http2ClientStream {
@GuardedBy("executorLock") @GuardedBy("executorLock")
private int processedWindow = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE; private int processedWindow = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE;
private final AsyncFrameWriter frameWriter; private final AsyncFrameWriter frameWriter;
private final OutboundFlowController outboundFlow;
private final OkHttpClientTransport transport; private final OkHttpClientTransport transport;
// Lock used to synchronize with work done on the executor. // Lock used to synchronize with work done on the executor.
private final Object executorLock; private final Object executorLock;
private Object outboundFlowState;
private OkHttpClientStream(final Executor executor, private OkHttpClientStream(final Executor executor,
final ClientStreamListener listener, final ClientStreamListener listener,
AsyncFrameWriter frameWriter, AsyncFrameWriter frameWriter,
OkHttpClientTransport transport, OkHttpClientTransport transport,
Object executorLock) { Object executorLock,
OutboundFlowController outboundFlow) {
super(listener, null, executor); super(listener, null, executor);
if (!GRPC_V2_PROTOCOL) { if (!GRPC_V2_PROTOCOL) {
throw new RuntimeException("okhttp transport can only work with V2 protocol!"); throw new RuntimeException("okhttp transport can only work with V2 protocol!");
@ -70,6 +74,7 @@ class OkHttpClientStream extends Http2ClientStream {
this.frameWriter = frameWriter; this.frameWriter = frameWriter;
this.transport = transport; this.transport = transport;
this.executorLock = executorLock; this.executorLock = executorLock;
this.outboundFlow = outboundFlow;
} }
public void transportHeadersReceived(List<Header> headers, boolean endOfStream) { public void transportHeadersReceived(List<Header> headers, boolean endOfStream) {
@ -105,8 +110,7 @@ class OkHttpClientStream extends Http2ClientStream {
// Per http2 SPEC, the max data length should be larger than 64K, while our frame size is // Per http2 SPEC, the max data length should be larger than 64K, while our frame size is
// only 4K. // only 4K.
Preconditions.checkState(buffer.size() < frameWriter.maxDataLength()); Preconditions.checkState(buffer.size() < frameWriter.maxDataLength());
frameWriter.data(endOfStream, id(), buffer, (int) buffer.size()); outboundFlow.data(endOfStream, id(), buffer);
frameWriter.flush();
} }
@Override @Override
@ -144,4 +148,12 @@ class OkHttpClientStream extends Http2ClientStream {
transport.stopIfNecessary(); transport.stopIfNecessary();
} }
} }
void setOutboundFlowState(Object outboundFlowState) {
this.outboundFlowState = outboundFlowState;
}
Object getOutboundFlowState() {
return outboundFlowState;
}
} }

View File

@ -82,6 +82,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
private final String defaultAuthority; private final String defaultAuthority;
private FrameReader frameReader; private FrameReader frameReader;
private AsyncFrameWriter frameWriter; private AsyncFrameWriter frameWriter;
private OutboundFlowController outboundFlow;
private final Object lock = new Object(); private final Object lock = new Object();
@GuardedBy("lock") @GuardedBy("lock")
private int nextStreamId; private int nextStreamId;
@ -118,6 +119,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
this.executor = Preconditions.checkNotNull(executor); this.executor = Preconditions.checkNotNull(executor);
this.frameReader = Preconditions.checkNotNull(frameReader); this.frameReader = Preconditions.checkNotNull(frameReader);
this.frameWriter = Preconditions.checkNotNull(frameWriter); this.frameWriter = Preconditions.checkNotNull(frameWriter);
this.outboundFlow = new OutboundFlowController(this, frameWriter);
this.nextStreamId = nextStreamId; this.nextStreamId = nextStreamId;
} }
@ -126,7 +128,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
Metadata.Headers headers, Metadata.Headers headers,
ClientStreamListener listener) { ClientStreamListener listener) {
OkHttpClientStream clientStream = OkHttpClientStream.newStream(executor, listener, OkHttpClientStream clientStream = OkHttpClientStream.newStream(executor, listener,
frameWriter, this); frameWriter, this, outboundFlow);
if (goAway) { if (goAway) {
clientStream.setStatus(goAwayStatus, new Metadata.Trailers()); clientStream.setStatus(goAwayStatus, new Metadata.Trailers());
} else { } else {
@ -154,6 +156,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
Variant variant = new Http20Draft14(); Variant variant = new Http20Draft14();
frameReader = variant.newReader(source, true); frameReader = variant.newReader(source, true);
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor); frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
outboundFlow = new OutboundFlowController(this, frameWriter);
frameWriter.connectionPreface(); frameWriter.connectionPreface();
Settings settings = new Settings(); Settings settings = new Settings();
frameWriter.settings(settings); frameWriter.settings(settings);
@ -185,7 +188,6 @@ public class OkHttpClientTransport extends AbstractClientTransport {
return clientFrameHandler; return clientFrameHandler;
} }
@VisibleForTesting
Map<Integer, OkHttpClientStream> getStreams() { Map<Integer, OkHttpClientStream> getStreams() {
return streams; return streams;
} }
@ -395,8 +397,8 @@ public class OkHttpClientTransport extends AbstractClientTransport {
} }
@Override @Override
public void windowUpdate(int arg0, long arg1) { public void windowUpdate(int streamId, long delta) {
// TODO(user): outbound flow control. outboundFlow.windowUpdate(streamId, (int) delta);
} }
@Override @Override

View File

@ -0,0 +1,399 @@
package com.google.net.stubby.transport.okhttp;
import static com.google.net.stubby.transport.okhttp.Utils.CONNECTION_STREAM_ID;
import static com.google.net.stubby.transport.okhttp.Utils.DEFAULT_WINDOW_SIZE;
import static com.google.net.stubby.transport.okhttp.Utils.MAX_FRAME_SIZE;
import static java.lang.Math.ceil;
import static java.lang.Math.max;
import static java.lang.Math.min;
import com.google.common.base.Preconditions;
import com.squareup.okhttp.internal.spdy.FrameWriter;
import okio.Buffer;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Queue;
/**
* Simple outbound flow controller that evenly splits the connection window across all existing
* streams.
*/
class OutboundFlowController {
private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
private final OkHttpClientTransport transport;
private final FrameWriter frameWriter;
private int initialWindowSize = DEFAULT_WINDOW_SIZE;
private final OutboundFlowState connectionState = new OutboundFlowState(CONNECTION_STREAM_ID);
OutboundFlowController(OkHttpClientTransport transport, FrameWriter frameWriter) {
this.transport = Preconditions.checkNotNull(transport, "transport");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
}
synchronized void initialOutboundWindowSize(int newWindowSize) {
if (newWindowSize < 0) {
throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize);
}
int delta = newWindowSize - initialWindowSize;
initialWindowSize = newWindowSize;
for (OkHttpClientStream stream : getActiveStreams()) {
// Verify that the maximum value is not exceeded by this change.
OutboundFlowState state = state(stream);
state.incrementStreamWindow(delta);
}
if (delta > 0) {
// The window size increased, send any pending frames for all streams.
writeStreams();
}
}
synchronized void windowUpdate(int streamId, int delta) {
if (streamId == CONNECTION_STREAM_ID) {
// Update the connection window and write any pending frames for all streams.
connectionState.incrementStreamWindow(delta);
writeStreams();
} else {
// Update the stream window and write any pending frames for the stream.
OutboundFlowState state = stateOrFail(streamId);
state.incrementStreamWindow(delta);
WriteStatus writeStatus = new WriteStatus();
state.writeBytes(state.writableWindow(), writeStatus);
if (writeStatus.hasWritten()) {
flush();
}
}
}
synchronized void data(boolean outFinished, int streamId, Buffer source) {
Preconditions.checkNotNull(source, "source");
if (streamId <= 0) {
throw new IllegalArgumentException("streamId must be > 0");
}
OutboundFlowState state = stateOrFail(streamId);
int window = state.writableWindow();
boolean framesAlreadyQueued = state.hasFrame();
OutboundFlowState.Frame frame = state.newFrame(source, outFinished);
if (!framesAlreadyQueued && window >= frame.size()) {
// Window size is large enough to send entire data frame
frame.write();
flush();
return;
}
// Enqueue the frame to be written when the window size permits.
frame.enqueue();
if (framesAlreadyQueued || window <= 0) {
// Stream already has frames pending or is stalled, don't send anything now.
return;
}
// Create and send a partial frame up to the window size.
frame.split(window).write();
flush();
}
private void flush() {
try {
frameWriter.flush();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private OutboundFlowState state(OkHttpClientStream stream) {
OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState();
if (state == null) {
state = new OutboundFlowState(stream.id());
stream.setOutboundFlowState(state);
}
return state;
}
private OutboundFlowState state(int streamId) {
OkHttpClientStream stream = transport.getStreams().get(streamId);
return stream != null ? state(stream) : null;
}
private OutboundFlowState stateOrFail(int streamId) {
OutboundFlowState state = state(streamId);
if (state == null) {
throw new RuntimeException("Missing flow control window for stream: " + streamId);
}
return state;
}
/**
* Gets all active streams as an array.
*/
private OkHttpClientStream[] getActiveStreams() {
return transport.getStreams().values().toArray(EMPTY_STREAM_ARRAY);
}
/**
* Writes as much data for all the streams as possible given the current flow control windows.
*/
private void writeStreams() {
OkHttpClientStream[] streams = getActiveStreams();
int connectionWindow = connectionState.window();
for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) {
int nextNumStreams = 0;
int windowSlice = (int) ceil(connectionWindow / (float) numStreams);
for (int index = 0; index < numStreams && connectionWindow > 0; ++index) {
OkHttpClientStream stream = streams[index];
OutboundFlowState state = state(stream);
int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice));
if (bytesForStream > 0) {
state.allocateBytes(bytesForStream);
connectionWindow -= bytesForStream;
}
if (state.unallocatedBytes() > 0) {
// There is more data to process for this stream. Add it to the next
// pass.
streams[nextNumStreams++] = stream;
}
}
numStreams = nextNumStreams;
}
// Now take one last pass through all of the streams and write any allocated bytes.
WriteStatus writeStatus = new WriteStatus();
for (OkHttpClientStream stream : getActiveStreams()) {
OutboundFlowState state = state(stream);
state.writeBytes(state.allocatedBytes(), writeStatus);
state.clearAllocatedBytes();
}
if (writeStatus.hasWritten()) {
flush();
}
}
/**
* Simple status that keeps track of the number of writes performed.
*/
private final class WriteStatus {
int numWrites;
void incrementNumWrites() {
numWrites++;
}
boolean hasWritten() {
return numWrites > 0;
}
}
/**
* The outbound flow control state for a single stream.
*/
private final class OutboundFlowState {
final Queue<Frame> pendingWriteQueue;
final int streamId;
int queuedBytes;
int window = initialWindowSize;
int allocatedBytes;
OutboundFlowState(int streamId) {
this.streamId = streamId;
pendingWriteQueue = new ArrayDeque<Frame>(2);
}
int window() {
return window;
}
void allocateBytes(int bytes) {
allocatedBytes += bytes;
}
int allocatedBytes() {
return allocatedBytes;
}
int unallocatedBytes() {
return streamableBytes() - allocatedBytes;
}
void clearAllocatedBytes() {
allocatedBytes = 0;
}
/**
* Increments the flow control window for this stream by the given delta and returns the new
* value.
*/
int incrementStreamWindow(int delta) {
if (delta > 0 && Integer.MAX_VALUE - delta < window) {
throw new IllegalArgumentException("Window size overflow for stream: " + streamId);
}
window += delta;
return window;
}
/**
* Returns the maximum writable window (minimum of the stream and connection windows).
*/
int writableWindow() {
return min(window, connectionState.window());
}
int streamableBytes() {
return max(0, min(window, queuedBytes));
}
/**
* Creates a new frame with the given values but does not add it to the pending queue.
*/
Frame newFrame(Buffer data, boolean endStream) {
return new Frame(data, endStream);
}
/**
* Indicates whether or not there are frames in the pending queue.
*/
boolean hasFrame() {
return !pendingWriteQueue.isEmpty();
}
/**
* Returns the the head of the pending queue, or {@code null} if empty.
*/
private Frame peek() {
return pendingWriteQueue.peek();
}
/**
* Writes up to the number of bytes from the pending queue.
*/
int writeBytes(int bytes, WriteStatus writeStatus) {
int bytesAttempted = 0;
int maxBytes = min(bytes, writableWindow());
while (hasFrame()) {
Frame pendingWrite = peek();
if (maxBytes >= pendingWrite.size()) {
// Window size is large enough to send entire data frame
writeStatus.incrementNumWrites();
bytesAttempted += pendingWrite.size();
pendingWrite.write();
} else if (maxBytes <= 0) {
// No data from the current frame can be written - we're done.
// We purposely check this after first testing the size of the
// pending frame to properly handle zero-length frame.
break;
} else {
// We can send a partial frame
Frame partialFrame = pendingWrite.split(maxBytes);
writeStatus.incrementNumWrites();
bytesAttempted += partialFrame.size();
partialFrame.write();
}
// Update the threshold.
maxBytes = min(bytes - bytesAttempted, writableWindow());
}
return bytesAttempted;
}
/**
* A wrapper class around the content of a data frame.
*/
private final class Frame {
final Buffer data;
final boolean endStream;
boolean enqueued;
Frame(Buffer data, boolean endStream) {
this.data = data;
this.endStream = endStream;
}
/**
* Gets the total size (in bytes) of this frame including the data and padding.
*/
int size() {
return (int) data.size();
}
void enqueue() {
if (!enqueued) {
enqueued = true;
pendingWriteQueue.offer(this);
// Increment the number of pending bytes for this stream.
queuedBytes += size();
}
}
/**
* Writes the frame and decrements the stream and connection window sizes. If the frame is in
* the pending queue, the written bytes are removed from this branch of the priority tree.
*/
void write() {
// Using a do/while loop because if the buffer is empty we still need to call
// the writer once to send the empty frame.
do {
int bytesToWrite = size();
int frameBytes = min(bytesToWrite, MAX_FRAME_SIZE);
if (frameBytes == bytesToWrite) {
// All the bytes fit into a single HTTP/2 frame, just send it all.
connectionState.incrementStreamWindow(-bytesToWrite);
incrementStreamWindow(-bytesToWrite);
try {
frameWriter.data(endStream, streamId, data, bytesToWrite);
} catch (IOException e) {
throw new RuntimeException(e);
}
if (enqueued) {
// It's enqueued - remove it from the head of the pending write queue.
queuedBytes -= bytesToWrite;
pendingWriteQueue.remove(this);
}
return;
}
// Split a chunk that will fit into a single HTTP/2 frame and write it.
Frame frame = split(frameBytes);
frame.write();
} while (size() > 0);
}
/**
* Creates a new frame that is a view of this frame's data. The {@code maxBytes} are first
* split from the data buffer. If not all the requested bytes are available, the remaining
* bytes are then split from the padding (if available).
*
* @param maxBytes the maximum number of bytes that is allowed in the created frame.
* @return the partial frame.
*/
Frame split(int maxBytes) {
// The requested maxBytes should always be less than the size of this frame.
assert maxBytes < size() : "Attempting to split a frame for the full size.";
// Get the portion of the data buffer to be split. Limit to the readable bytes.
int dataSplit = min(maxBytes, (int) data.size());
Buffer splitSlice = new Buffer();
splitSlice.write(data, dataSplit);
Frame frame = new Frame(splitSlice, false);
if (enqueued) {
queuedBytes -= dataSplit;
}
return frame;
}
}
}
}

View File

@ -10,6 +10,10 @@ import java.util.List;
* Common utility methods for OkHttp transport. * Common utility methods for OkHttp transport.
*/ */
class Utils { class Utils {
static final int DEFAULT_WINDOW_SIZE = 65535;
static final int CONNECTION_STREAM_ID = 0;
static final int MAX_FRAME_SIZE = 16384;
public static Metadata.Headers convertHeaders(List<Header> http2Headers) { public static Metadata.Headers convertHeaders(List<Header> http2Headers) {
return new Metadata.Headers(convertHeadersToArray(http2Headers)); return new Metadata.Headers(convertHeadersToArray(http2Headers));
} }