diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java index bbc6e79cad..8615e489b4 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/AsyncFrameWriter.java @@ -32,7 +32,6 @@ package io.grpc.transport.okhttp; import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.SettableFuture; import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.FrameWriter; @@ -44,11 +43,15 @@ import io.grpc.SerializingExecutor; import okio.Buffer; import java.io.IOException; +import java.net.Socket; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.logging.Level; +import java.util.logging.Logger; class AsyncFrameWriter implements FrameWriter { + private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); private FrameWriter frameWriter; + private Socket socket; // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are // just waiting on each other. private final SerializingExecutor executor; @@ -60,12 +63,16 @@ class AsyncFrameWriter implements FrameWriter { } /** - * Set the real frameWriter, should only be called by thread of executor. + * Set the real frameWriter and the corresponding underlying socket, the socket is needed for + * closing. + * + *

should only be called by thread of executor. */ - void setFrameWriter(FrameWriter frameWriter) { + void becomeConnected(FrameWriter frameWriter, Socket socket) { Preconditions.checkState(this.frameWriter == null, "AsyncFrameWriter's setFrameWriter() should only be called once."); - this.frameWriter = frameWriter; + this.frameWriter = Preconditions.checkNotNull(frameWriter); + this.socket = Preconditions.checkNotNull(socket); } @Override @@ -207,30 +214,19 @@ class AsyncFrameWriter implements FrameWriter { @Override public void close() { - // Wait for the frameWriter to close. - final SettableFuture closeFuture = SettableFuture.create(); executor.execute(new Runnable() { @Override public void run() { - try { - if (frameWriter != null) { + if (frameWriter != null) { + try { frameWriter.close(); + socket.close(); + } catch (IOException e) { + log.log(Level.WARNING, "Failed closing connection", e); } - } catch (IOException e) { - closeFuture.setException(e); - } finally { - closeFuture.set(null); } } }); - try { - closeFuture.get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } catch (ExecutionException e) { - throw new RuntimeException(e); - } } private abstract class WriteRunnable implements Runnable { diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java index 0f92bad3c2..fa042944dd 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientStream.java @@ -37,7 +37,6 @@ import static com.google.common.base.Preconditions.checkState; import com.squareup.okhttp.internal.spdy.ErrorCode; import com.squareup.okhttp.internal.spdy.Header; -import io.grpc.Metadata; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.transport.ClientStreamListener; @@ -69,8 +68,8 @@ class OkHttpClientStream extends Http2ClientStream { AsyncFrameWriter frameWriter, OkHttpClientTransport transport, OutboundFlowController outboundFlow, - MethodType type) { - return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow, type); + MethodType type, Object lock) { + return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow, type, lock); } @GuardedBy("lock") @@ -80,7 +79,7 @@ class OkHttpClientStream extends Http2ClientStream { private final AsyncFrameWriter frameWriter; private final OutboundFlowController outboundFlow; private final OkHttpClientTransport transport; - private final Object lock = new Object(); + private final Object lock; private Object outboundFlowState; private volatile Integer id; @@ -88,12 +87,14 @@ class OkHttpClientStream extends Http2ClientStream { AsyncFrameWriter frameWriter, OkHttpClientTransport transport, OutboundFlowController outboundFlow, - MethodType type) { + MethodType type, + Object lock) { super(new OkHttpWritableBufferAllocator(), listener); this.frameWriter = frameWriter; this.transport = transport; this.outboundFlow = outboundFlow; this.type = type; + this.lock = lock; } /** @@ -139,33 +140,30 @@ class OkHttpClientStream extends Http2ClientStream { onSentBytes(numBytes); } + /** + * Must be called with holding the transport lock. + */ public void transportHeadersReceived(List

headers, boolean endOfStream) { - synchronized (lock) { - if (endOfStream) { - transportTrailersReceived(Utils.convertTrailers(headers)); - } else { - transportHeadersReceived(Utils.convertHeaders(headers)); - } + if (endOfStream) { + transportTrailersReceived(Utils.convertTrailers(headers)); + } else { + transportHeadersReceived(Utils.convertHeaders(headers)); } } /** - * We synchronized on "lock" for delivering frames and updating window size, because - * the {@link #request(int)} call can be called in other thread for delivering frames. + * Must be called with holding the transport lock. */ public void transportDataReceived(okio.Buffer frame, boolean endOfStream) { - synchronized (lock) { - long length = frame.size(); - window -= length; - if (window < 0) { - frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR); - Status status = Status.INTERNAL.withDescription( - "Received data size exceeded our receiving window size"); - transport.finishStream(id(), status, null); - return; - } - super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream); + long length = frame.size(); + window -= length; + if (window < 0) { + frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR); + transport.finishStream(id(), Status.INTERNAL.withDescription( + "Received data size exceeded our receiving window size"), null); + return; } + super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream); } @Override @@ -199,14 +197,6 @@ class OkHttpClientStream extends Http2ClientStream { } } - @Override - public void transportReportStatus(Status newStatus, boolean stopDelivery, - Metadata.Trailers trailers) { - synchronized (lock) { - super.transportReportStatus(newStatus, stopDelivery, trailers); - } - } - @Override protected void sendCancel(Status reason) { transport.finishStream(id(), reason, ErrorCode.CANCEL); diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java index c629271564..c4ed67932a 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OkHttpClientTransport.java @@ -70,7 +70,6 @@ import okio.Okio; import java.io.IOException; import java.net.Socket; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -93,6 +92,7 @@ import javax.net.ssl.SSLSocketFactory; class OkHttpClientTransport implements ClientTransport { private static final Map ERROR_CODE_TO_STATUS; private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; static { Map errorToStatus = new HashMap(); @@ -138,8 +138,9 @@ class OkHttpClientTransport implements ClientTransport { private final Object lock = new Object(); @GuardedBy("lock") private int nextStreamId; + @GuardedBy("lock") private final Map streams = - Collections.synchronizedMap(new HashMap()); + new HashMap(); private final Executor executor; // Wrap on executor, to guarantee some operations be executed serially. private final SerializingExecutor serializingExecutor; @@ -245,8 +246,8 @@ class OkHttpClientTransport implements ClientTransport { Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(listener, "listener"); - OkHttpClientStream clientStream = - OkHttpClientStream.newStream(listener, frameWriter, this, outboundFlow, method.getType()); + OkHttpClientStream clientStream = OkHttpClientStream.newStream( + listener, frameWriter, this, outboundFlow, method.getType(), lock); String defaultPath = "/" + method.getFullMethodName(); List
requestHeaders = @@ -332,7 +333,7 @@ class OkHttpClientTransport implements ClientTransport { clientFrameHandler = new ClientFrameHandler(testFrameReader); executor.execute(clientFrameHandler); connectedCallback.run(); - frameWriter.setFrameWriter(testFrameWriter); + frameWriter.becomeConnected(testFrameWriter, socket); return; } BufferedSource source; @@ -369,7 +370,7 @@ class OkHttpClientTransport implements ClientTransport { Variant variant = new Http2(); rawFrameWriter = variant.newWriter(sink, true); - frameWriter.setFrameWriter(rawFrameWriter); + frameWriter.becomeConnected(rawFrameWriter, socket); try { // Do these with the raw FrameWriter, so that they will be done in this thread, @@ -390,25 +391,35 @@ class OkHttpClientTransport implements ClientTransport { @Override public void shutdown() { - boolean normalClose; synchronized (lock) { - normalClose = !goAway; + if (goAway) { + return; + } } - if (normalClose) { - // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams. - // The GOAWAY is part of graceful shutdown. - frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]); - onGoAway(Integer.MAX_VALUE, Status.UNAVAILABLE.withDescription("Transport stopped")); - } - stopIfNecessary(); + // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams. + // The GOAWAY is part of graceful shutdown. + frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]); + + onGoAway(Integer.MAX_VALUE, Status.UNAVAILABLE.withDescription("Transport stopped")); } + /** + * Gets all active streams as an array. + */ + OkHttpClientStream[] getActiveStreams() { + synchronized (lock) { + return streams.values().toArray(EMPTY_STREAM_ARRAY); + } + } + + @VisibleForTesting ClientFrameHandler getHandler() { return clientFrameHandler; } + @VisibleForTesting Map getStreams() { return streams; } @@ -438,37 +449,32 @@ class OkHttpClientTransport implements ClientTransport { private void onGoAway(int lastKnownStreamId, Status status) { boolean notifyShutdown; - ArrayList goAwayStreams = new ArrayList(); - List pendingStreamsCopy; synchronized (lock) { notifyShutdown = !goAway; goAway = true; goAwayStatus = status; - synchronized (streams) { - Iterator> it = streams.entrySet().iterator(); - while (it.hasNext()) { - Map.Entry entry = it.next(); - if (entry.getKey() > lastKnownStreamId) { - goAwayStreams.add(entry.getValue()); - it.remove(); - } + Iterator> it = streams.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + if (entry.getKey() > lastKnownStreamId) { + it.remove(); + entry.getValue().transportReportStatus(status, false, new Metadata.Trailers()); } } - pendingStreamsCopy = pendingStreams; - pendingStreams = new LinkedList(); + + for (PendingStream stream : pendingStreams) { + stream.clientStream.transportReportStatus(status, true, new Metadata.Trailers()); + stream.createdFuture.set(null); + } + pendingStreams.clear(); } if (notifyShutdown) { + // TODO(madongfly): Another thread may called stopIfNecessary() and closed the socket, so that + // the reading thread calls listener.transportTerminated() and race with this call. listener.transportShutdown(); } - for (OkHttpClientStream stream : goAwayStreams) { - stream.transportReportStatus(status, false, new Metadata.Trailers()); - } - for (PendingStream stream : pendingStreamsCopy) { - stream.clientStream.transportReportStatus( - status, true, new Metadata.Trailers()); - stream.createdFuture.set(null); - } + stopIfNecessary(); } @@ -486,19 +492,20 @@ class OkHttpClientTransport implements ClientTransport { * @param errorCode reset the stream with this ErrorCode if not null. */ void finishStream(int streamId, @Nullable Status status, @Nullable ErrorCode errorCode) { - OkHttpClientStream stream; - stream = streams.remove(streamId); - if (stream != null) { - if (errorCode != null) { - frameWriter.rstStream(streamId, ErrorCode.CANCEL); - } - if (status != null) { - boolean isCancelled = (status.getCode() == Code.CANCELLED - || status.getCode() == Code.DEADLINE_EXCEEDED); - stream.transportReportStatus(status, isCancelled, new Metadata.Trailers()); - } - if (!startPendingStreams()) { - stopIfNecessary(); + synchronized (lock) { + OkHttpClientStream stream = streams.remove(streamId); + if (stream != null) { + if (errorCode != null) { + frameWriter.rstStream(streamId, ErrorCode.CANCEL); + } + if (status != null) { + boolean isCancelled = (status.getCode() == Code.CANCELLED + || status.getCode() == Code.DEADLINE_EXCEEDED); + stream.transportReportStatus(status, isCancelled, new Metadata.Trailers()); + } + if (!startPendingStreams()) { + stopIfNecessary(); + } } } } @@ -507,38 +514,20 @@ class OkHttpClientTransport implements ClientTransport { * When the transport is in goAway states, we should stop it once all active streams finish. */ void stopIfNecessary() { - boolean shouldStop; - Http2Ping outstandingPing = null; - boolean socketConnected; synchronized (lock) { - shouldStop = (goAway && streams.size() == 0); - if (shouldStop) { - if (stopped) { - // We've already stopped, don't stop again. - shouldStop = false; - } else { + if (goAway && streams.size() == 0) { + if (!stopped) { stopped = true; - outstandingPing = ping; - ping = null; + // We will close the underlying socket in the writing thread to break out the reader + // thread, which will close the frameReader and notify the listener. + frameWriter.close(); + + if (ping != null) { + ping.failed(getPingFailure()); + ping = null; + } } } - socketConnected = socket != null; - } - if (shouldStop) { - // Wait for the frame writer to close. - frameWriter.close(); - if (socketConnected) { - // Close the socket to break out the reader thread, which will close the - // frameReader and notify the listener. - try { - socket.close(); - } catch (IOException e) { - log.log(Level.WARNING, "Failed closing socket", e); - } - } - } - if (outstandingPing != null) { - outstandingPing.failed(getPingFailure()); } } @@ -558,6 +547,12 @@ class OkHttpClientTransport implements ClientTransport { } } + OkHttpClientStream getStream(int streamId) { + synchronized (lock) { + return streams.get(streamId); + } + } + /** * Returns a Grpc status corresponding to the given ErrorCode. */ @@ -607,8 +602,7 @@ class OkHttpClientTransport implements ClientTransport { @Override public void data(boolean inFinished, int streamId, BufferedSource in, int length) throws IOException { - final OkHttpClientStream stream; - stream = streams.get(streamId); + OkHttpClientStream stream = getStream(streamId); if (stream == null) { if (mayHaveCreatedStream(streamId)) { frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); @@ -622,7 +616,9 @@ class OkHttpClientTransport implements ClientTransport { Buffer buf = new Buffer(); buf.write(in.buffer(), length); - stream.transportDataReceived(buf, inFinished); + synchronized (lock) { + stream.transportDataReceived(buf, inFinished); + } } // connection window update @@ -643,18 +639,23 @@ class OkHttpClientTransport implements ClientTransport { int associatedStreamId, List
headerBlock, HeadersMode headersMode) { - OkHttpClientStream stream; - stream = streams.get(streamId); - if (stream == null) { - if (mayHaveCreatedStream(streamId)) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + boolean unknownStream = false; + synchronized (lock) { + OkHttpClientStream stream = streams.get(streamId); + if (stream == null) { + if (mayHaveCreatedStream(streamId)) { + frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + } else { + unknownStream = true; + } } else { - // We don't expect any server-initiated streams. - onError(ErrorCode.PROTOCOL_ERROR, "Received header for unknown stream: " + streamId); + stream.transportHeadersReceived(headerBlock, inFinished); } - return; } - stream.transportHeadersReceived(headerBlock, inFinished); + if (unknownStream) { + // We don't expect any server-initiated streams. + onError(ErrorCode.PROTOCOL_ERROR, "Received header for unknown stream: " + streamId); + } } @Override @@ -748,7 +749,7 @@ class OkHttpClientTransport implements ClientTransport { return; } - OkHttpClientStream stream = streams.get(streamId); + OkHttpClientStream stream = getStream(streamId); if (stream != null) { outboundFlow.windowUpdate(stream, (int) delta); } else if (!mayHaveCreatedStream(streamId)) { diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java index ebe062be0b..3112c5635b 100644 --- a/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/transport/okhttp/OutboundFlowController.java @@ -54,7 +54,6 @@ import javax.annotation.Nullable; * streams. */ class OutboundFlowController { - private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; private final OkHttpClientTransport transport; private final FrameWriter frameWriter; private int initialWindowSize = DEFAULT_WINDOW_SIZE; @@ -72,7 +71,7 @@ class OutboundFlowController { int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; - for (OkHttpClientStream stream : getActiveStreams()) { + for (OkHttpClientStream stream : transport.getActiveStreams()) { OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); if (state == null) { // Create the OutboundFlowState with the new window size. @@ -116,7 +115,7 @@ class OutboundFlowController { throw new IllegalArgumentException("Invalid streamId: " + streamId); } - OkHttpClientStream stream = transport.getStreams().get(streamId); + OkHttpClientStream stream = transport.getStream(streamId); if (stream == null) { // This is possible for a stream that has received end-of-stream from server (but hasn't sent // end-of-stream), and was removed from the transport stream map. @@ -173,18 +172,11 @@ class OutboundFlowController { return state; } - /** - * Gets all active streams as an array. - */ - private OkHttpClientStream[] getActiveStreams() { - return transport.getStreams().values().toArray(EMPTY_STREAM_ARRAY); - } - /** * Writes as much data for all the streams as possible given the current flow control windows. */ private void writeStreams() { - OkHttpClientStream[] streams = getActiveStreams(); + OkHttpClientStream[] streams = transport.getActiveStreams(); int connectionWindow = connectionState.window(); for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) { int nextNumStreams = 0; @@ -210,7 +202,7 @@ class OutboundFlowController { // Now take one last pass through all of the streams and write any allocated bytes. WriteStatus writeStatus = new WriteStatus(); - for (OkHttpClientStream stream : getActiveStreams()) { + for (OkHttpClientStream stream : transport.getActiveStreams()) { OutboundFlowState state = state(stream); state.writeBytes(state.allocatedBytes(), writeStatus); state.clearAllocatedBytes(); diff --git a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java index cec61987d4..58a8b41d30 100644 --- a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java @@ -195,7 +195,7 @@ public class OkHttpClientTransportTest { assertEquals("Protocol error\n" + NETWORK_ISSUE_MESSAGE, listener1.status.getDescription()); assertEquals(Status.INTERNAL.getCode(), listener2.status.getCode()); assertEquals("Protocol error\n" + NETWORK_ISSUE_MESSAGE, listener2.status.getDescription()); - verify(transportListener).transportShutdown(); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); }