From 6bf0936f8ef0e68f2bacfec2e0f58fec8ad7c53f Mon Sep 17 00:00:00 2001 From: Jihun Cho Date: Fri, 28 Dec 2018 17:20:03 -0800 Subject: [PATCH] okhttp: move async mechanism from FrameWriter to sink (AsyncSink) (#4916) Optimize OkHttp transport's memory use by getting rid of queuing writes in AsyncFrameWriter. If any write is pending due to connection issue or by flow control, AsyncFrameWriter can use at least 8K per each task (task includes buffer) even if the actual payload is very small. To merge pending writes, Async mechanism is moved from AsyncFrameWriter to AsyncSink (AsyncSink is used by okio's FrameWriter). AsyncSink is still relying on okio's buffer to decide merging writes or not. Resolves #4860 --- .../java/io/grpc/okhttp/AsyncFrameWriter.java | 276 ----------------- .../main/java/io/grpc/okhttp/AsyncSink.java | 162 ++++++++++ .../okhttp/ExceptionHandlingFrameWriter.java | 214 +++++++++++++ .../io/grpc/okhttp/OkHttpClientStream.java | 6 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 117 ++++--- .../grpc/okhttp/OutboundFlowController.java | 7 +- .../io/grpc/okhttp/AsyncFrameWriterTest.java | 149 --------- .../java/io/grpc/okhttp/AsyncSinkTest.java | 289 ++++++++++++++++++ .../ExceptionHandlingFrameWriterTest.java | 77 +++++ .../grpc/okhttp/OkHttpClientStreamTest.java | 25 +- .../okhttp/OkHttpClientTransportTest.java | 116 +++++-- 11 files changed, 927 insertions(+), 511 deletions(-) delete mode 100644 okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java create mode 100644 okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java create mode 100644 okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java delete mode 100644 okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java create mode 100644 okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java create mode 100644 okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java deleted file mode 100644 index 210aaa1f08..0000000000 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright 2014 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.okhttp; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.grpc.internal.SerializingExecutor; -import io.grpc.okhttp.internal.framed.ErrorCode; -import io.grpc.okhttp.internal.framed.FrameWriter; -import io.grpc.okhttp.internal.framed.Header; -import io.grpc.okhttp.internal.framed.Settings; -import java.io.IOException; -import java.net.Socket; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicLong; -import java.util.logging.Level; -import java.util.logging.Logger; -import okio.Buffer; - -class AsyncFrameWriter implements FrameWriter { - private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); - private FrameWriter frameWriter; - private Socket socket; - // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are - // just waiting on each other. - private final SerializingExecutor executor; - private final TransportExceptionHandler transportExceptionHandler; - private final AtomicLong flushCounter = new AtomicLong(); - // Some exceptions are not very useful and add too much noise to the log - private static final Set QUIET_ERRORS = - Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed"))); - - public AsyncFrameWriter( - TransportExceptionHandler transportExceptionHandler, SerializingExecutor executor) { - this.transportExceptionHandler = transportExceptionHandler; - this.executor = executor; - } - - /** - * Set the real frameWriter and the corresponding underlying socket, the socket is needed for - * closing. - * - *

should only be called by thread of executor. - */ - void becomeConnected(FrameWriter frameWriter, Socket socket) { - Preconditions.checkState(this.frameWriter == null, - "AsyncFrameWriter's setFrameWriter() should only be called once."); - this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); - this.socket = Preconditions.checkNotNull(socket, "socket"); - } - - @Override - public void connectionPreface() { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.connectionPreface(); - } - }); - } - - @Override - public void ackSettings(final Settings peerSettings) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.ackSettings(peerSettings); - } - }); - } - - @Override - public void pushPromise(final int streamId, final int promisedStreamId, - final List

requestHeaders) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders); - } - }); - } - - @Override - public void flush() { - // keep track of version of flushes to skip flush if another flush task is queued. - final long flushCount = flushCounter.incrementAndGet(); - - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - // There can be a flush starvation if there are continuous flood of flush is queued, this - // is not an issue with OkHttp since it flushes if the buffer is full. - if (flushCounter.get() == flushCount) { - frameWriter.flush(); - } - } - }); - } - - @Override - public void synStream(final boolean outFinished, final boolean inFinished, final int streamId, - final int associatedStreamId, final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock); - } - }); - } - - @Override - public void synReply(final boolean outFinished, final int streamId, - final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.synReply(outFinished, streamId, headerBlock); - } - }); - } - - @Override - public void headers(final int streamId, final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.headers(streamId, headerBlock); - } - }); - } - - @Override - public void rstStream(final int streamId, final ErrorCode errorCode) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.rstStream(streamId, errorCode); - } - }); - } - - @Override - public void data(final boolean outFinished, final int streamId, final Buffer source, - final int byteCount) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.data(outFinished, streamId, source, byteCount); - } - }); - } - - @Override - public void settings(final Settings okHttpSettings) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.settings(okHttpSettings); - } - }); - } - - @Override - public void ping(final boolean ack, final int payload1, final int payload2) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.ping(ack, payload1, payload2); - } - }); - } - - @Override - public void goAway(final int lastGoodStreamId, final ErrorCode errorCode, - final byte[] debugData) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.goAway(lastGoodStreamId, errorCode, debugData); - // Flush it since after goAway, we are likely to close this writer. - frameWriter.flush(); - } - }); - } - - @Override - public void windowUpdate(final int streamId, final long windowSizeIncrement) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.windowUpdate(streamId, windowSizeIncrement); - } - }); - } - - @Override - public void close() { - executor.execute(new Runnable() { - @Override - public void run() { - if (frameWriter != null) { - try { - frameWriter.close(); - socket.close(); - } catch (IOException e) { - log.log(getLogLevel(e), "Failed closing connection", e); - } - } - } - }); - } - - /** - * Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions - * should not clutter the log. - */ - @VisibleForTesting - static Level getLogLevel(Throwable t) { - if (t instanceof IOException - && t.getMessage() != null - && QUIET_ERRORS.contains(t.getMessage())) { - return Level.FINE; - - } - return Level.INFO; - } - - private abstract class WriteRunnable implements Runnable { - @Override - public final void run() { - try { - if (frameWriter == null) { - throw new IOException("Unable to perform write due to unavailable frameWriter."); - } - doRun(); - } catch (RuntimeException e) { - transportExceptionHandler.onException(e); - } catch (Exception e) { - transportExceptionHandler.onException(e); - } - } - - public abstract void doRun() throws IOException; - } - - @Override - public int maxDataLength() { - return frameWriter == null ? 0x4000 /* 16384, the minimum required by the HTTP/2 spec */ - : frameWriter.maxDataLength(); - } - - /** A class that handles transport exception. */ - interface TransportExceptionHandler { - - /** Handles exception. */ - void onException(Throwable throwable); - } -} diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java new file mode 100644 index 0000000000..bd1d4762a2 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -0,0 +1,162 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import io.grpc.internal.SerializingExecutor; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import java.io.IOException; +import java.net.Socket; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; +import okio.Sink; +import okio.Timeout; + +/** + * A sink that asynchronously write / flushes a buffer internally. AsyncSink provides flush + * coalescing to minimize network packing transmit. + */ +final class AsyncSink implements Sink { + + private final Object lock = new Object(); + private final Buffer buffer = new Buffer(); + private final SerializingExecutor serializingExecutor; + private final TransportExceptionHandler transportExceptionHandler; + + @GuardedBy("lock") + private boolean writeEnqueued = false; + @GuardedBy("lock") + private boolean flushEnqueued = false; + private boolean closed = false; + @Nullable + private Sink sink; + @Nullable + private Socket socket; + + private AsyncSink(SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { + this.serializingExecutor = checkNotNull(executor, "executor"); + this.transportExceptionHandler = checkNotNull(exceptionHandler, "exceptionHandler"); + } + + static AsyncSink sink( + SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { + return new AsyncSink(executor, exceptionHandler); + } + + /** + * Sets the actual sink. It is allowed to call write / flush operations on the sink iff calling + * this method is scheduled in the executor. The socket is needed for closing. + * + *

should only be called once by thread of executor. + */ + void becomeConnected(Sink sink, Socket socket) { + checkState(this.sink == null, "AsyncSink's becomeConnected should only be called once."); + this.sink = checkNotNull(sink, "sink"); + this.socket = checkNotNull(socket, "socket"); + } + + @Override + public void write(Buffer source, long byteCount) throws IOException { + checkNotNull(source, "source"); + if (closed) { + throw new IOException("closed"); + } + synchronized (lock) { + buffer.write(source, byteCount); + if (writeEnqueued || flushEnqueued || buffer.completeSegmentByteCount() <= 0) { + return; + } + writeEnqueued = true; + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + Buffer buf = new Buffer(); + synchronized (lock) { + buf.write(buffer, buffer.completeSegmentByteCount()); + writeEnqueued = false; + } + try { + sink.write(buf, buf.size()); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IOException("closed"); + } + synchronized (lock) { + if (flushEnqueued) { + return; + } + flushEnqueued = true; + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + Buffer buf = new Buffer(); + synchronized (lock) { + buf.write(buffer, buffer.size()); + flushEnqueued = false; + } + try { + sink.write(buf, buf.size()); + sink.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } + + @Override + public Timeout timeout() { + return Timeout.NONE; + } + + @Override + public void close() { + if (closed) { + return; + } + closed = true; + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + buffer.close(); + try { + sink.close(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + try { + socket.close(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } +} \ No newline at end of file diff --git a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java new file mode 100644 index 0000000000..0ee4c42367 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java @@ -0,0 +1,214 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.Settings; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import okio.Buffer; + +final class ExceptionHandlingFrameWriter implements FrameWriter { + + private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + // Some exceptions are not very useful and add too much noise to the log + private static final Set QUIET_ERRORS = + Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed"))); + + private final TransportExceptionHandler transportExceptionHandler; + + private final FrameWriter frameWriter; + + ExceptionHandlingFrameWriter( + TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { + this.transportExceptionHandler = + checkNotNull(transportExceptionHandler, "transportExceptionHandler"); + this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); + } + + @Override + public void connectionPreface() { + try { + frameWriter.connectionPreface(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void ackSettings(Settings peerSettings) { + try { + frameWriter.ackSettings(peerSettings); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void pushPromise(int streamId, int promisedStreamId, List

requestHeaders) { + try { + frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void flush() { + try { + frameWriter.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void synStream( + boolean outFinished, + boolean inFinished, + int streamId, + int associatedStreamId, + List
headerBlock) { + try { + frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void synReply(boolean outFinished, int streamId, + List
headerBlock) { + try { + frameWriter.synReply(outFinished, streamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void headers(int streamId, List
headerBlock) { + try { + frameWriter.headers(streamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) { + try { + frameWriter.rstStream(streamId, errorCode); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public int maxDataLength() { + return frameWriter.maxDataLength(); + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) { + try { + frameWriter.data(outFinished, streamId, source, byteCount); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void settings(Settings okHttpSettings) { + try { + frameWriter.settings(okHttpSettings); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void ping(boolean ack, int payload1, int payload2) { + try { + frameWriter.ping(ack, payload1, payload2); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, + byte[] debugData) { + try { + frameWriter.goAway(lastGoodStreamId, errorCode, debugData); + // Flush it since after goAway, we are likely to close this writer. + frameWriter.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void windowUpdate(int streamId, long windowSizeIncrement) { + try { + frameWriter.windowUpdate(streamId, windowSizeIncrement); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void close() { + try { + frameWriter.close(); + } catch (IOException e) { + log.log(getLogLevel(e), "Failed closing connection", e); + } + } + + /** + * Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions + * should not clutter the log. + */ + @VisibleForTesting + static Level getLogLevel(Throwable t) { + if (t instanceof IOException + && t.getMessage() != null + && QUIET_ERRORS.contains(t.getMessage())) { + return Level.FINE; + } + return Level.INFO; + } + + /** A class that handles transport exception. */ + interface TransportExceptionHandler { + /** Handles exception. */ + void onException(Throwable throwable); + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 782cbbc1db..e62d0cf3ce 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -62,7 +62,7 @@ class OkHttpClientStream extends AbstractClientStream { OkHttpClientStream( MethodDescriptor method, Metadata headers, - AsyncFrameWriter frameWriter, + ExceptionHandlingFrameWriter frameWriter, OkHttpClientTransport transport, OutboundFlowController outboundFlow, Object lock, @@ -203,7 +203,7 @@ class OkHttpClientStream extends AbstractClientStream { @GuardedBy("lock") private int processedWindow; @GuardedBy("lock") - private final AsyncFrameWriter frameWriter; + private final ExceptionHandlingFrameWriter frameWriter; @GuardedBy("lock") private final OutboundFlowController outboundFlow; @GuardedBy("lock") @@ -216,7 +216,7 @@ class OkHttpClientStream extends AbstractClientStream { int maxMessageSize, StatsTraceContext statsTraceCtx, Object lock, - AsyncFrameWriter frameWriter, + ExceptionHandlingFrameWriter frameWriter, OutboundFlowController outboundFlow, OkHttpClientTransport transport, int initialWindowSize) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index f1ce827018..81ef3a0cc5 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -57,7 +57,7 @@ import io.grpc.internal.SerializingExecutor; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; -import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; @@ -107,7 +107,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; private static Map buildErrorCodeToStatusMap() { - Map errorToStatus = new EnumMap(ErrorCode.class); + Map errorToStatus = new EnumMap<>(ErrorCode.class); errorToStatus.put(ErrorCode.NO_ERROR, Status.INTERNAL.withDescription("No error: A GRPC status of OK should have been sent")); errorToStatus.put(ErrorCode.PROTOCOL_ERROR, @@ -144,15 +144,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private final int initialWindowSize; private Listener listener; private FrameReader testFrameReader; - private AsyncFrameWriter frameWriter; + @GuardedBy("lock") + private ExceptionHandlingFrameWriter frameWriter; private OutboundFlowController outboundFlow; private final Object lock = new Object(); private final InternalLogId logId; @GuardedBy("lock") private int nextStreamId; @GuardedBy("lock") - private final Map streams = - new HashMap(); + private final Map streams = new HashMap<>(); private final Executor executor; // Wrap on executor, to guarantee some operations be executed serially. private final SerializingExecutor serializingExecutor; @@ -182,7 +182,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private int maxConcurrentStreams = 0; @SuppressWarnings("JdkObsolete") // Usage is bursty; want low memory usage when empty @GuardedBy("lock") - private LinkedList pendingStreams = new LinkedList(); + private LinkedList pendingStreams = new LinkedList<>(); private final ConnectionSpec connectionSpec; private FrameWriter testFrameWriter; private ScheduledExecutorService scheduler; @@ -219,7 +219,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, Executor executor, @Nullable SSLSocketFactory sslSocketFactory, @Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, @@ -322,11 +321,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @Override public void ping(final PingCallback callback, Executor executor) { - checkState(frameWriter != null); long data = 0; Http2Ping p; boolean writePing; synchronized (lock) { + checkState(frameWriter != null); if (stopped) { Http2Ping.notifyFailed(callback, executor, getPingFailure()); return; @@ -345,9 +344,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep writePing = true; transportTracer.reportKeepAliveSent(); } - } - if (writePing) { - frameWriter.ping(false, (int) (data >>> 32), (int) data); + if (writePing) { + frameWriter.ping(false, (int) (data >>> 32), (int) data); + } } // If transport concurrently failed/stopped since we released the lock above, this could // immediately invoke callback (which we shouldn't do while holding a lock) @@ -449,15 +448,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep keepAliveWithoutCalls); keepAliveManager.onTransportStarted(); } - - frameWriter = new AsyncFrameWriter(this, serializingExecutor); - outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize); - // Connecting in the serializingExecutor, so that some stream operations like synStream - // will be executed after connected. - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (isForTest()) { + if (isForTest()) { + synchronized (lock) { + frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter); + outboundFlow = + new OutboundFlowController(OkHttpClientTransport.this, frameWriter, initialWindowSize); + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { if (connectingCallback != null) { connectingCallback.run(); } @@ -467,11 +466,25 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep maxConcurrentStreams = Integer.MAX_VALUE; startPendingStreams(); } - frameWriter.becomeConnected(testFrameWriter, socket); connectedFuture.set(null); - return; } + }); + return null; + } + final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); + final Variant variant = new Http2(); + FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); + + synchronized (lock) { + frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter); + outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize); + } + // Connecting in the serializingExecutor, so that some stream operations like synStream + // will be executed after connected. + serializingExecutor.execute(new Runnable() { + @Override + public void run() { // Use closed source on failure so that the reader immediately shuts down. BufferedSource source = Okio.buffer(new Source() { @Override @@ -485,10 +498,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } @Override - public void close() {} + public void close() { + } }); - Variant variant = new Http2(); - BufferedSink sink; Socket sock; SSLSession sslSession = null; try { @@ -508,7 +520,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } sock.setTcpNoDelay(true); source = Okio.buffer(Okio.source(sock)); - sink = Okio.buffer(Okio.sink(sock)); + asyncSink.becomeConnected(Okio.sink(sock), sock); + // The return value of OkHttpTlsUpgrader.upgrade is an SSLSocket that has this info attributes = Attributes .newBuilder() @@ -526,31 +539,29 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep return; } finally { clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true)); - executor.execute(clientFrameHandler); } - - FrameWriter rawFrameWriter; synchronized (lock) { socket = Preconditions.checkNotNull(sock, "socket"); - maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); if (sslSession != null) { securityInfo = new InternalChannelz.Security(new InternalChannelz.Tls(sslSession)); } } - - rawFrameWriter = variant.newWriter(sink, true); - frameWriter.becomeConnected(rawFrameWriter, socket); - - try { - // Do these with the raw FrameWriter, so that they will be done in this thread, - // and before any possible pending stream operations. - rawFrameWriter.connectionPreface(); - Settings settings = new Settings(); - rawFrameWriter.settings(settings); - } catch (Exception e) { - onException(e); - return; + } + }); + synchronized (lock) { + frameWriter.connectionPreface(); + Settings settings = new Settings(); + frameWriter.settings(settings); + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it + // may send goAway immediately. + executor.execute(clientFrameHandler); + synchronized (lock) { + maxConcurrentStreams = Integer.MAX_VALUE; + startPendingStreams(); } } }); @@ -558,7 +569,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, - String proxyUsername, String proxyPassword) throws IOException, StatusException { + String proxyUsername, String proxyPassword) throws StatusException { try { Socket sock; // The proxy address may not be resolved @@ -1039,7 +1050,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientStream stream = getStream(streamId); if (stream == null) { if (mayHaveCreatedStream(streamId)) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + synchronized (lock) { + frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + } in.skip(length); } else { onError(ErrorCode.PROTOCOL_ERROR, "Received data for unknown stream: " + streamId); @@ -1059,7 +1072,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep // connection window update connectionUnacknowledgedBytesRead += length; if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) { - frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + synchronized (lock) { + frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + } connectionUnacknowledgedBytesRead = 0; } } @@ -1170,7 +1185,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @Override public void ping(boolean ack, int payload1, int payload2) { if (!ack) { - frameWriter.ping(true, payload1, payload2); + synchronized (lock) { + frameWriter.ping(true, payload1, payload2); + } } else { Http2Ping p = null; long ackPayload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); @@ -1222,7 +1239,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep public void pushPromise(int streamId, int promisedStreamId, List
requestHeaders) throws IOException { // We don't accept server initiated stream. - frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); + synchronized (lock) { + frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); + } } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java index e3cad95313..441bb21151 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java @@ -323,12 +323,7 @@ class OutboundFlowController { try { // endOfStream is set for the last chunk of data marked as endOfStream boolean isEndOfStream = buffer.size() == frameBytes && endOfStream; - // AsyncFrameWriter drains buffer in executor. To avoid race, copy to temp. - // TODO(jihuncho) remove temp buff when async logic is moved to AsyncSink. - Buffer temp = new Buffer(); - temp.write(buffer, frameBytes); - - frameWriter.data(isEndOfStream, streamId, temp, frameBytes); + frameWriter.data(isEndOfStream, streamId, buffer, frameBytes); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java deleted file mode 100644 index 45e80707c5..0000000000 --- a/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright 2018 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.okhttp; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.okhttp.AsyncFrameWriter.getLogLevel; -import static org.mockito.Matchers.anyBoolean; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -import io.grpc.internal.SerializingExecutor; -import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler; -import io.grpc.okhttp.internal.framed.FrameWriter; -import java.io.IOException; -import java.net.Socket; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Executor; -import java.util.logging.Level; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.InOrder; -import org.mockito.Mock; -import org.mockito.runners.MockitoJUnitRunner; - -@RunWith(MockitoJUnitRunner.class) -public class AsyncFrameWriterTest { - - @Mock private Socket socket; - @Mock private FrameWriter frameWriter; - - private QueueingExecutor queueingExecutor = new QueueingExecutor(); - private TransportExceptionHandler transportExceptionHandler = - new EscalatingTransportErrorHandler(); - private AsyncFrameWriter asyncFrameWriter = - new AsyncFrameWriter(transportExceptionHandler, new SerializingExecutor(queueingExecutor)); - - @Before - public void setUp() throws Exception { - asyncFrameWriter.becomeConnected(frameWriter, socket); - } - - @Test - public void noCoalesceRequired() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - verify(frameWriter, times(1)).ping(anyBoolean(), anyInt(), anyInt()); - verify(frameWriter, times(1)).flush(); - } - - @Test - public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - asyncFrameWriter.ping(true, 0, 2); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt()); - verify(frameWriter, times(2)).flush(); - } - - @Test - public void flushCoalescing_shouldMergeTwoQueuedFlushes() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - asyncFrameWriter.ping(true, 0, 2); - asyncFrameWriter.flush(); - - queueingExecutor.runAll(); - - InOrder inOrder = inOrder(frameWriter); - inOrder.verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt()); - inOrder.verify(frameWriter).flush(); - } - - @Test - public void unknownException() { - assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO); - } - - @Test - public void quiet() { - assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE); - } - - @Test - public void nonquiet() { - assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO); - } - - @Test - public void nullMessage() { - IOException e = new IOException(); - assertThat(e.getMessage()).isNull(); - assertThat(getLogLevel(e)).isEqualTo(Level.INFO); - } - - /** - * Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link - * QueueingExecutor#runAll} in serial order. - */ - private static class QueueingExecutor implements Executor { - - private final Queue runnables = new ConcurrentLinkedQueue(); - - @Override - public void execute(Runnable command) { - runnables.add(command); - } - - public void runAll() { - Runnable r; - while ((r = runnables.poll()) != null) { - r.run(); - } - } - } - - /** Rethrows as Assertion error. */ - private static class EscalatingTransportErrorHandler implements TransportExceptionHandler { - - @Override - public void onException(Throwable throwable) { - throw new AssertionError(throwable); - } - } -} diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java new file mode 100644 index 0000000000..084a00f72c --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java @@ -0,0 +1,289 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import com.google.common.base.Charsets; +import io.grpc.internal.SerializingExecutor; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import java.io.IOException; +import java.net.Socket; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import okio.Buffer; +import okio.Sink; +import okio.Timeout; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; + +/** Tests for {@link AsyncSink}. */ +@RunWith(JUnit4.class) +public class AsyncSinkTest { + + private Socket socket = mock(Socket.class); + private Sink mockedSink = mock(VoidSink.class, CALLS_REAL_METHODS); + private QueueingExecutor queueingExecutor = new QueueingExecutor(); + private TransportExceptionHandler exceptionHandler = mock(TransportExceptionHandler.class); + private AsyncSink sink = + AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler); + + @Before + public void setUp() throws Exception { + sink.becomeConnected(mockedSink, socket); + } + + @Test + public void noCoalesceRequired() throws IOException { + Buffer buffer = new Buffer(); + sink.write(buffer.writeUtf8("hello"), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + + Buffer buffer = new Buffer(); + sink.write(buffer.write(firstData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + sink.write(buffer.write(secondData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldMergeTwoQueuedFlushesAndWrites() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + Buffer buffer = new Buffer().write(firstData); + sink.write(buffer, buffer.size()); + sink.flush(); + buffer = new Buffer().write(secondData); + sink.write(buffer, buffer.size()); + sink.flush(); + + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink) + .write(any(Buffer.class), eq((long) firstData.length + secondData.length)); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldMergeWrites() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + Buffer buffer = new Buffer(); + sink.write(buffer.write(firstData), buffer.size()); + sink.write(buffer.write(secondData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink) + .write(any(Buffer.class), eq((long) firstData.length + secondData.length)); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void write_shouldCachePreviousException() throws IOException { + Exception ioException = new IOException("some exception"); + doThrow(ioException) + .when(mockedSink).write(any(Buffer.class), anyLong()); + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + + verify(exceptionHandler, timeout(1000)).onException(ioException); + } + + @Test + public void close_writeShouldThrowException() { + sink.close(); + queueingExecutor.runAll(); + try { + sink.write(new Buffer(), 0); + fail("should throw ioException"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void write_shouldThrowIfAlreadyClosed() throws IOException { + Exception ioException = new IOException("some exception"); + doThrow(ioException) + .when(mockedSink).write(any(Buffer.class), anyLong()); + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.close(); + queueingExecutor.runAll(); + try { + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + fail("should throw ioException"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void close_flushShouldThrowException() throws IOException { + sink.close(); + queueingExecutor.runAll(); + try { + sink.flush(); + queueingExecutor.runAll(); + fail("should fail"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void flush_shouldThrowIfAlreadyClosed() throws IOException { + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.close(); + queueingExecutor.runAll(); + try { + sink.flush(); + queueingExecutor.runAll(); + fail("should fail"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void write_callSinkIfBufferIsLargerThanSegmentSize() throws IOException { + Buffer buffer = new Buffer(); + // OkHttp is using 8192 as segment size. + int payloadSize = 8192 * 2 - 1; + int padding = 10; + buffer.write(new byte[payloadSize]); + + int completeSegmentBytes = (int) buffer.completeSegmentByteCount(); + assertThat(completeSegmentBytes).isLessThan(payloadSize); + + // first trying to send of all complete segments, but not the padding + sink.write(buffer, completeSegmentBytes + padding); + queueingExecutor.runAll(); + verify(mockedSink).write(any(Buffer.class), eq((long) completeSegmentBytes)); + verify(mockedSink, never()).flush(); + assertThat(buffer.size()).isEqualTo((long) (payloadSize - completeSegmentBytes - padding)); + + // writing smaller than completed segment, shouldn't trigger write to Sink. + reset(mockedSink); + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + verify(mockedSink, never()).write(any(Buffer.class), anyLong()); + verify(mockedSink, never()).flush(); + assertThat(buffer.exhausted()).isTrue(); + + // flush should write everything. + sink.flush(); + queueingExecutor.runAll(); + verify(mockedSink).write(any(Buffer.class),eq((long) payloadSize - completeSegmentBytes)); + verify(mockedSink).flush(); + } + + /** + * Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link + * QueueingExecutor#runAll} in serial order. + */ + private static class QueueingExecutor implements Executor { + + private final Queue runnables = new ConcurrentLinkedQueue<>(); + + @Override + public void execute(Runnable command) { + runnables.add(command); + } + + public void runAll() { + Runnable r; + while ((r = runnables.poll()) != null) { + r.run(); + } + } + } + + /** Test sink to mimic real Sink behavior since write has a side effect. */ + private static class VoidSink implements Sink { + + @Override + public void write(Buffer source, long byteCount) throws IOException { + // removes byteCount bytes from source. + source.read(new byte[(int) byteCount], 0, (int) byteCount); + } + + @Override + public void flush() throws IOException { + // do nothing + } + + @Override + public Timeout timeout() { + return Timeout.NONE; + } + + @Override + public void close() throws IOException { + // do nothing + } + } +} \ No newline at end of file diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java new file mode 100644 index 0000000000..13cc338d35 --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.okhttp.ExceptionHandlingFrameWriter.getLogLevel; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import java.io.IOException; +import java.util.ArrayList; +import java.util.logging.Level; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExceptionHandlingFrameWriterTest { + + private FrameWriter mockedFrameWriter = mock(FrameWriter.class); + private TransportExceptionHandler transportExceptionHandler = + mock(TransportExceptionHandler.class); + private ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = + new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter); + + @Test + public void exception() throws IOException { + IOException exception = new IOException("some exception"); + doThrow(exception).when(mockedFrameWriter) + .synReply(false, 100, new ArrayList
()); + + exceptionHandlingFrameWriter.synReply(false, 100, new ArrayList
()); + + verify(transportExceptionHandler).onException(exception); + verify(mockedFrameWriter).synReply(false, 100, new ArrayList
()); + } + + @Test + public void unknownException() { + assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO); + } + + @Test + public void quiet() { + assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE); + } + + @Test + public void nonquiet() { + assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO); + } + + @Test + public void nullMessage() { + IOException e = new IOException(); + assertThat(e.getMessage()).isNull(); + assertThat(getLogLevel(e)).isEqualTo(Level.INFO); + } +} \ No newline at end of file diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 9768d0f353..1b7d38fc1a 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -37,8 +37,10 @@ import io.grpc.internal.NoopClientStreamListener; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.Header; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.charset.Charset; import java.util.List; import java.util.concurrent.atomic.AtomicReference; @@ -60,7 +62,8 @@ public class OkHttpClientStreamTest { private static final int INITIAL_WINDOW_SIZE = 65535; @Mock private MethodDescriptor.Marshaller marshaller; - @Mock private AsyncFrameWriter frameWriter; + @Mock private FrameWriter mockedFrameWriter; + private ExceptionHandlingFrameWriter frameWriter; @Mock private OkHttpClientTransport transport; @Mock private OutboundFlowController flowController; @Captor private ArgumentCaptor> headersCaptor; @@ -81,6 +84,8 @@ public class OkHttpClientStreamTest { .setResponseMarshaller(marshaller) .build(); + frameWriter = + new ExceptionHandlingFrameWriter(transport, mockedFrameWriter); stream = new OkHttpClientStream( methodDescriptor, new Metadata(), @@ -144,11 +149,11 @@ public class OkHttpClientStreamTest { stream.transportState().start(1234); - verifyNoMoreInteractions(frameWriter); + verifyNoMoreInteractions(mockedFrameWriter); } @Test - public void start_userAgentRemoved() { + public void start_userAgentRemoved() throws IOException { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, @@ -157,13 +162,14 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); stream.transportState().start(3); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + verify(mockedFrameWriter) + .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()) .contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application")); } @Test - public void start_headerFieldOrder() { + public void start_headerFieldOrder() throws IOException { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, @@ -172,7 +178,8 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); stream.transportState().start(3); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + verify(mockedFrameWriter) + .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()).containsExactly( Headers.SCHEME_HEADER, Headers.METHOD_HEADER, @@ -185,7 +192,7 @@ public class OkHttpClientStreamTest { } @Test - public void getUnaryRequest() { + public void getUnaryRequest() throws IOException { MethodDescriptor getMethod = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("service/method") @@ -200,7 +207,7 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); // GET streams send headers after halfClose is called. - verify(frameWriter, times(0)).synStream( + verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); @@ -210,7 +217,7 @@ public class OkHttpClientStreamTest { verify(transport).streamReadyToStart(eq(stream)); stream.transportState().start(3); - verify(frameWriter) + verify(mockedFrameWriter) .synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()).contains(Headers.METHOD_GET_HEADER); assertThat(headersCaptor.getValue()).contains( diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index cb3f7e5fe2..367d8e18d7 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -46,7 +46,6 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -91,9 +90,11 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -113,6 +114,7 @@ import org.junit.Test; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.AdditionalAnswers; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Matchers; @@ -139,7 +141,6 @@ public class OkHttpClientTransportTest { @Rule public final Timeout globalTimeout = Timeout.seconds(10); - @Mock private FrameWriter frameWriter; private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @@ -150,8 +151,10 @@ public class OkHttpClientTransportTest { private final SSLSocketFactory sslSocketFactory = null; private final HostnameVerifier hostnameVerifier = null; private final TransportTracer transportTracer = new TransportTracer(); + private final Queue capturedBuffer = new ArrayDeque<>(); private OkHttpClientTransport clientTransport; private MockFrameReader frameReader; + private Socket socket; private ExecutorService executor = Executors.newCachedThreadPool(); private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private SettableFuture connectedFuture; @@ -166,8 +169,10 @@ public class OkHttpClientTransportTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE); frameReader = new MockFrameReader(); + socket = new MockSocket(frameReader); + MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer); + frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter)); } @After @@ -217,7 +222,7 @@ public class OkHttpClientTransportTest { frameReader, frameWriter, startId, - new MockSocket(frameReader), + socket, stopwatchSupplier, connectingCallback, connectedFuture, @@ -556,10 +561,9 @@ public class OkHttpClientTransportTest { assertEquals(12, input.available()); stream.writeMessage(input); stream.flush(); - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(message), sentFrame); stream.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -888,11 +892,9 @@ public class OkHttpClientTransportTest { assertEquals(22, input.available()); stream.writeMessage(input); stream.flush(); - ArgumentCaptor captor = - ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(22 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(sentMessage), sentFrame); // And read. @@ -976,10 +978,9 @@ public class OkHttpClientTransportTest { // The second stream should be active now, and the pending data should be sent out. assertEquals(1, activeStreamCount()); assertEquals(0, clientTransport.getPendingStreamSize()); - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(true), eq(5), captor.capture(), eq(5 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(true), eq(5), any(Buffer.class), eq(5 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(sentMessage), sentFrame); stream2.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -1456,10 +1457,9 @@ public class OkHttpClientTransportTest { allowTransportConnected(); // The queued message should be sent out. - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(message), sentFrame); stream.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -2091,7 +2091,6 @@ public class OkHttpClientTransportTest { verify(frameWriter, timeout(TIME_OUT_MS)).close(); } catch (IOException e) { throw new RuntimeException(e); - } frameReader.assertClosed(); } @@ -2113,4 +2112,83 @@ public class OkHttpClientTransportTest { throws ExecutionException, InterruptedException { return obj.getStats().get().data; } -} + + /** A FrameWriter to mock with CALL_REAL_METHODS option. */ + private static class MockFrameWriter implements FrameWriter { + + private Socket socket; + private Queue capturedBuffer; + + public MockFrameWriter(Socket socket, Queue capturedBuffer) { + // Sets a socket to close. Some tests assumes that FrameWriter will close underlying sink + // which will eventually close the socket. + this.socket = socket; + this.capturedBuffer = capturedBuffer; + } + + void setSocket(Socket socket) { + this.socket = socket; + } + + @Override + public void close() throws IOException { + socket.close(); + } + + @Override + public int maxDataLength() { + return Integer.MAX_VALUE; + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) + throws IOException { + // simulate the side effect, and captures to internal queue. + Buffer capture = new Buffer(); + capture.write(source, byteCount); + capturedBuffer.add(capture); + } + + // rest of methods are unimplemented + + @Override + public void connectionPreface() throws IOException {} + + @Override + public void ackSettings(Settings peerSettings) throws IOException {} + + @Override + public void pushPromise(int streamId, int promisedStreamId, List
requestHeaders) + throws IOException {} + + @Override + public void flush() throws IOException {} + + @Override + public void synStream(boolean outFinished, boolean inFinished, int streamId, + int associatedStreamId, List
headerBlock) throws IOException {} + + @Override + public void synReply(boolean outFinished, int streamId, List
headerBlock) + throws IOException {} + + @Override + public void headers(int streamId, List
headerBlock) throws IOException {} + + @Override + public void rstStream(int streamId, ErrorCode errorCode) throws IOException {} + + @Override + public void settings(Settings okHttpSettings) throws IOException {} + + @Override + public void ping(boolean ack, int payload1, int payload2) throws IOException {} + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, byte[] debugData) + throws IOException {} + + @Override + public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {} + } +} \ No newline at end of file