diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java
index bbc6e79cad..8615e489b4 100644
--- a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java
+++ b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java
@@ -32,7 +32,6 @@
package io.grpc.transport.okhttp;
import com.google.common.base.Preconditions;
-import com.google.common.util.concurrent.SettableFuture;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.FrameWriter;
@@ -44,11 +43,15 @@ import io.grpc.SerializingExecutor;
import okio.Buffer;
import java.io.IOException;
+import java.net.Socket;
import java.util.List;
-import java.util.concurrent.ExecutionException;
+import java.util.logging.Level;
+import java.util.logging.Logger;
class AsyncFrameWriter implements FrameWriter {
+ private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
private FrameWriter frameWriter;
+ private Socket socket;
// Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
// just waiting on each other.
private final SerializingExecutor executor;
@@ -60,12 +63,16 @@ class AsyncFrameWriter implements FrameWriter {
}
/**
- * Set the real frameWriter, should only be called by thread of executor.
+ * Set the real frameWriter and the corresponding underlying socket, the socket is needed for
+ * closing.
+ *
+ *
should only be called by thread of executor.
*/
- void setFrameWriter(FrameWriter frameWriter) {
+ void becomeConnected(FrameWriter frameWriter, Socket socket) {
Preconditions.checkState(this.frameWriter == null,
"AsyncFrameWriter's setFrameWriter() should only be called once.");
- this.frameWriter = frameWriter;
+ this.frameWriter = Preconditions.checkNotNull(frameWriter);
+ this.socket = Preconditions.checkNotNull(socket);
}
@Override
@@ -207,30 +214,19 @@ class AsyncFrameWriter implements FrameWriter {
@Override
public void close() {
- // Wait for the frameWriter to close.
- final SettableFuture> closeFuture = SettableFuture.create();
executor.execute(new Runnable() {
@Override
public void run() {
- try {
- if (frameWriter != null) {
+ if (frameWriter != null) {
+ try {
frameWriter.close();
+ socket.close();
+ } catch (IOException e) {
+ log.log(Level.WARNING, "Failed closing connection", e);
}
- } catch (IOException e) {
- closeFuture.setException(e);
- } finally {
- closeFuture.set(null);
}
}
});
- try {
- closeFuture.get();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new RuntimeException(e);
- } catch (ExecutionException e) {
- throw new RuntimeException(e);
- }
}
private abstract class WriteRunnable implements Runnable {
diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java
index 0f92bad3c2..fa042944dd 100644
--- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java
+++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java
@@ -37,7 +37,6 @@ import static com.google.common.base.Preconditions.checkState;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.Header;
-import io.grpc.Metadata;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;
import io.grpc.transport.ClientStreamListener;
@@ -69,8 +68,8 @@ class OkHttpClientStream extends Http2ClientStream {
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow,
- MethodType type) {
- return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow, type);
+ MethodType type, Object lock) {
+ return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow, type, lock);
}
@GuardedBy("lock")
@@ -80,7 +79,7 @@ class OkHttpClientStream extends Http2ClientStream {
private final AsyncFrameWriter frameWriter;
private final OutboundFlowController outboundFlow;
private final OkHttpClientTransport transport;
- private final Object lock = new Object();
+ private final Object lock;
private Object outboundFlowState;
private volatile Integer id;
@@ -88,12 +87,14 @@ class OkHttpClientStream extends Http2ClientStream {
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow,
- MethodType type) {
+ MethodType type,
+ Object lock) {
super(new OkHttpWritableBufferAllocator(), listener);
this.frameWriter = frameWriter;
this.transport = transport;
this.outboundFlow = outboundFlow;
this.type = type;
+ this.lock = lock;
}
/**
@@ -139,33 +140,30 @@ class OkHttpClientStream extends Http2ClientStream {
onSentBytes(numBytes);
}
+ /**
+ * Must be called with holding the transport lock.
+ */
public void transportHeadersReceived(List headers, boolean endOfStream) {
- synchronized (lock) {
- if (endOfStream) {
- transportTrailersReceived(Utils.convertTrailers(headers));
- } else {
- transportHeadersReceived(Utils.convertHeaders(headers));
- }
+ if (endOfStream) {
+ transportTrailersReceived(Utils.convertTrailers(headers));
+ } else {
+ transportHeadersReceived(Utils.convertHeaders(headers));
}
}
/**
- * We synchronized on "lock" for delivering frames and updating window size, because
- * the {@link #request(int)} call can be called in other thread for delivering frames.
+ * Must be called with holding the transport lock.
*/
public void transportDataReceived(okio.Buffer frame, boolean endOfStream) {
- synchronized (lock) {
- long length = frame.size();
- window -= length;
- if (window < 0) {
- frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR);
- Status status = Status.INTERNAL.withDescription(
- "Received data size exceeded our receiving window size");
- transport.finishStream(id(), status, null);
- return;
- }
- super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
+ long length = frame.size();
+ window -= length;
+ if (window < 0) {
+ frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR);
+ transport.finishStream(id(), Status.INTERNAL.withDescription(
+ "Received data size exceeded our receiving window size"), null);
+ return;
}
+ super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
}
@Override
@@ -199,14 +197,6 @@ class OkHttpClientStream extends Http2ClientStream {
}
}
- @Override
- public void transportReportStatus(Status newStatus, boolean stopDelivery,
- Metadata.Trailers trailers) {
- synchronized (lock) {
- super.transportReportStatus(newStatus, stopDelivery, trailers);
- }
- }
-
@Override
protected void sendCancel(Status reason) {
transport.finishStream(id(), reason, ErrorCode.CANCEL);
diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java
index c629271564..c4ed67932a 100644
--- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java
+++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java
@@ -70,7 +70,6 @@ import okio.Okio;
import java.io.IOException;
import java.net.Socket;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
@@ -93,6 +92,7 @@ import javax.net.ssl.SSLSocketFactory;
class OkHttpClientTransport implements ClientTransport {
private static final Map ERROR_CODE_TO_STATUS;
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
+ private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
static {
Map errorToStatus = new HashMap();
@@ -138,8 +138,9 @@ class OkHttpClientTransport implements ClientTransport {
private final Object lock = new Object();
@GuardedBy("lock")
private int nextStreamId;
+ @GuardedBy("lock")
private final Map streams =
- Collections.synchronizedMap(new HashMap());
+ new HashMap();
private final Executor executor;
// Wrap on executor, to guarantee some operations be executed serially.
private final SerializingExecutor serializingExecutor;
@@ -245,8 +246,8 @@ class OkHttpClientTransport implements ClientTransport {
Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(listener, "listener");
- OkHttpClientStream clientStream =
- OkHttpClientStream.newStream(listener, frameWriter, this, outboundFlow, method.getType());
+ OkHttpClientStream clientStream = OkHttpClientStream.newStream(
+ listener, frameWriter, this, outboundFlow, method.getType(), lock);
String defaultPath = "/" + method.getFullMethodName();
List requestHeaders =
@@ -332,7 +333,7 @@ class OkHttpClientTransport implements ClientTransport {
clientFrameHandler = new ClientFrameHandler(testFrameReader);
executor.execute(clientFrameHandler);
connectedCallback.run();
- frameWriter.setFrameWriter(testFrameWriter);
+ frameWriter.becomeConnected(testFrameWriter, socket);
return;
}
BufferedSource source;
@@ -369,7 +370,7 @@ class OkHttpClientTransport implements ClientTransport {
Variant variant = new Http2();
rawFrameWriter = variant.newWriter(sink, true);
- frameWriter.setFrameWriter(rawFrameWriter);
+ frameWriter.becomeConnected(rawFrameWriter, socket);
try {
// Do these with the raw FrameWriter, so that they will be done in this thread,
@@ -390,25 +391,35 @@ class OkHttpClientTransport implements ClientTransport {
@Override
public void shutdown() {
- boolean normalClose;
synchronized (lock) {
- normalClose = !goAway;
+ if (goAway) {
+ return;
+ }
}
- if (normalClose) {
- // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
- // The GOAWAY is part of graceful shutdown.
- frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]);
- onGoAway(Integer.MAX_VALUE, Status.UNAVAILABLE.withDescription("Transport stopped"));
- }
- stopIfNecessary();
+ // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
+ // The GOAWAY is part of graceful shutdown.
+ frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]);
+
+ onGoAway(Integer.MAX_VALUE, Status.UNAVAILABLE.withDescription("Transport stopped"));
}
+ /**
+ * Gets all active streams as an array.
+ */
+ OkHttpClientStream[] getActiveStreams() {
+ synchronized (lock) {
+ return streams.values().toArray(EMPTY_STREAM_ARRAY);
+ }
+ }
+
+
@VisibleForTesting
ClientFrameHandler getHandler() {
return clientFrameHandler;
}
+ @VisibleForTesting
Map getStreams() {
return streams;
}
@@ -438,37 +449,32 @@ class OkHttpClientTransport implements ClientTransport {
private void onGoAway(int lastKnownStreamId, Status status) {
boolean notifyShutdown;
- ArrayList goAwayStreams = new ArrayList();
- List pendingStreamsCopy;
synchronized (lock) {
notifyShutdown = !goAway;
goAway = true;
goAwayStatus = status;
- synchronized (streams) {
- Iterator> it = streams.entrySet().iterator();
- while (it.hasNext()) {
- Map.Entry entry = it.next();
- if (entry.getKey() > lastKnownStreamId) {
- goAwayStreams.add(entry.getValue());
- it.remove();
- }
+ Iterator> it = streams.entrySet().iterator();
+ while (it.hasNext()) {
+ Map.Entry entry = it.next();
+ if (entry.getKey() > lastKnownStreamId) {
+ it.remove();
+ entry.getValue().transportReportStatus(status, false, new Metadata.Trailers());
}
}
- pendingStreamsCopy = pendingStreams;
- pendingStreams = new LinkedList();
+
+ for (PendingStream stream : pendingStreams) {
+ stream.clientStream.transportReportStatus(status, true, new Metadata.Trailers());
+ stream.createdFuture.set(null);
+ }
+ pendingStreams.clear();
}
if (notifyShutdown) {
+ // TODO(madongfly): Another thread may called stopIfNecessary() and closed the socket, so that
+ // the reading thread calls listener.transportTerminated() and race with this call.
listener.transportShutdown();
}
- for (OkHttpClientStream stream : goAwayStreams) {
- stream.transportReportStatus(status, false, new Metadata.Trailers());
- }
- for (PendingStream stream : pendingStreamsCopy) {
- stream.clientStream.transportReportStatus(
- status, true, new Metadata.Trailers());
- stream.createdFuture.set(null);
- }
+
stopIfNecessary();
}
@@ -486,19 +492,20 @@ class OkHttpClientTransport implements ClientTransport {
* @param errorCode reset the stream with this ErrorCode if not null.
*/
void finishStream(int streamId, @Nullable Status status, @Nullable ErrorCode errorCode) {
- OkHttpClientStream stream;
- stream = streams.remove(streamId);
- if (stream != null) {
- if (errorCode != null) {
- frameWriter.rstStream(streamId, ErrorCode.CANCEL);
- }
- if (status != null) {
- boolean isCancelled = (status.getCode() == Code.CANCELLED
- || status.getCode() == Code.DEADLINE_EXCEEDED);
- stream.transportReportStatus(status, isCancelled, new Metadata.Trailers());
- }
- if (!startPendingStreams()) {
- stopIfNecessary();
+ synchronized (lock) {
+ OkHttpClientStream stream = streams.remove(streamId);
+ if (stream != null) {
+ if (errorCode != null) {
+ frameWriter.rstStream(streamId, ErrorCode.CANCEL);
+ }
+ if (status != null) {
+ boolean isCancelled = (status.getCode() == Code.CANCELLED
+ || status.getCode() == Code.DEADLINE_EXCEEDED);
+ stream.transportReportStatus(status, isCancelled, new Metadata.Trailers());
+ }
+ if (!startPendingStreams()) {
+ stopIfNecessary();
+ }
}
}
}
@@ -507,38 +514,20 @@ class OkHttpClientTransport implements ClientTransport {
* When the transport is in goAway states, we should stop it once all active streams finish.
*/
void stopIfNecessary() {
- boolean shouldStop;
- Http2Ping outstandingPing = null;
- boolean socketConnected;
synchronized (lock) {
- shouldStop = (goAway && streams.size() == 0);
- if (shouldStop) {
- if (stopped) {
- // We've already stopped, don't stop again.
- shouldStop = false;
- } else {
+ if (goAway && streams.size() == 0) {
+ if (!stopped) {
stopped = true;
- outstandingPing = ping;
- ping = null;
+ // We will close the underlying socket in the writing thread to break out the reader
+ // thread, which will close the frameReader and notify the listener.
+ frameWriter.close();
+
+ if (ping != null) {
+ ping.failed(getPingFailure());
+ ping = null;
+ }
}
}
- socketConnected = socket != null;
- }
- if (shouldStop) {
- // Wait for the frame writer to close.
- frameWriter.close();
- if (socketConnected) {
- // Close the socket to break out the reader thread, which will close the
- // frameReader and notify the listener.
- try {
- socket.close();
- } catch (IOException e) {
- log.log(Level.WARNING, "Failed closing socket", e);
- }
- }
- }
- if (outstandingPing != null) {
- outstandingPing.failed(getPingFailure());
}
}
@@ -558,6 +547,12 @@ class OkHttpClientTransport implements ClientTransport {
}
}
+ OkHttpClientStream getStream(int streamId) {
+ synchronized (lock) {
+ return streams.get(streamId);
+ }
+ }
+
/**
* Returns a Grpc status corresponding to the given ErrorCode.
*/
@@ -607,8 +602,7 @@ class OkHttpClientTransport implements ClientTransport {
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
throws IOException {
- final OkHttpClientStream stream;
- stream = streams.get(streamId);
+ OkHttpClientStream stream = getStream(streamId);
if (stream == null) {
if (mayHaveCreatedStream(streamId)) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
@@ -622,7 +616,9 @@ class OkHttpClientTransport implements ClientTransport {
Buffer buf = new Buffer();
buf.write(in.buffer(), length);
- stream.transportDataReceived(buf, inFinished);
+ synchronized (lock) {
+ stream.transportDataReceived(buf, inFinished);
+ }
}
// connection window update
@@ -643,18 +639,23 @@ class OkHttpClientTransport implements ClientTransport {
int associatedStreamId,
List headerBlock,
HeadersMode headersMode) {
- OkHttpClientStream stream;
- stream = streams.get(streamId);
- if (stream == null) {
- if (mayHaveCreatedStream(streamId)) {
- frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
+ boolean unknownStream = false;
+ synchronized (lock) {
+ OkHttpClientStream stream = streams.get(streamId);
+ if (stream == null) {
+ if (mayHaveCreatedStream(streamId)) {
+ frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
+ } else {
+ unknownStream = true;
+ }
} else {
- // We don't expect any server-initiated streams.
- onError(ErrorCode.PROTOCOL_ERROR, "Received header for unknown stream: " + streamId);
+ stream.transportHeadersReceived(headerBlock, inFinished);
}
- return;
}
- stream.transportHeadersReceived(headerBlock, inFinished);
+ if (unknownStream) {
+ // We don't expect any server-initiated streams.
+ onError(ErrorCode.PROTOCOL_ERROR, "Received header for unknown stream: " + streamId);
+ }
}
@Override
@@ -748,7 +749,7 @@ class OkHttpClientTransport implements ClientTransport {
return;
}
- OkHttpClientStream stream = streams.get(streamId);
+ OkHttpClientStream stream = getStream(streamId);
if (stream != null) {
outboundFlow.windowUpdate(stream, (int) delta);
} else if (!mayHaveCreatedStream(streamId)) {
diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java
index ebe062be0b..3112c5635b 100644
--- a/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java
+++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java
@@ -54,7 +54,6 @@ import javax.annotation.Nullable;
* streams.
*/
class OutboundFlowController {
- private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
private final OkHttpClientTransport transport;
private final FrameWriter frameWriter;
private int initialWindowSize = DEFAULT_WINDOW_SIZE;
@@ -72,7 +71,7 @@ class OutboundFlowController {
int delta = newWindowSize - initialWindowSize;
initialWindowSize = newWindowSize;
- for (OkHttpClientStream stream : getActiveStreams()) {
+ for (OkHttpClientStream stream : transport.getActiveStreams()) {
OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState();
if (state == null) {
// Create the OutboundFlowState with the new window size.
@@ -116,7 +115,7 @@ class OutboundFlowController {
throw new IllegalArgumentException("Invalid streamId: " + streamId);
}
- OkHttpClientStream stream = transport.getStreams().get(streamId);
+ OkHttpClientStream stream = transport.getStream(streamId);
if (stream == null) {
// This is possible for a stream that has received end-of-stream from server (but hasn't sent
// end-of-stream), and was removed from the transport stream map.
@@ -173,18 +172,11 @@ class OutboundFlowController {
return state;
}
- /**
- * Gets all active streams as an array.
- */
- private OkHttpClientStream[] getActiveStreams() {
- return transport.getStreams().values().toArray(EMPTY_STREAM_ARRAY);
- }
-
/**
* Writes as much data for all the streams as possible given the current flow control windows.
*/
private void writeStreams() {
- OkHttpClientStream[] streams = getActiveStreams();
+ OkHttpClientStream[] streams = transport.getActiveStreams();
int connectionWindow = connectionState.window();
for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) {
int nextNumStreams = 0;
@@ -210,7 +202,7 @@ class OutboundFlowController {
// Now take one last pass through all of the streams and write any allocated bytes.
WriteStatus writeStatus = new WriteStatus();
- for (OkHttpClientStream stream : getActiveStreams()) {
+ for (OkHttpClientStream stream : transport.getActiveStreams()) {
OutboundFlowState state = state(stream);
state.writeBytes(state.allocatedBytes(), writeStatus);
state.clearAllocatedBytes();
diff --git a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
index cec61987d4..58a8b41d30 100644
--- a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
+++ b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
@@ -195,7 +195,7 @@ public class OkHttpClientTransportTest {
assertEquals("Protocol error\n" + NETWORK_ISSUE_MESSAGE, listener1.status.getDescription());
assertEquals(Status.INTERNAL.getCode(), listener2.status.getCode());
assertEquals("Protocol error\n" + NETWORK_ISSUE_MESSAGE, listener2.status.getDescription());
- verify(transportListener).transportShutdown();
+ verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
}