diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java deleted file mode 100644 index 210aaa1f08..0000000000 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncFrameWriter.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright 2014 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.okhttp; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.grpc.internal.SerializingExecutor; -import io.grpc.okhttp.internal.framed.ErrorCode; -import io.grpc.okhttp.internal.framed.FrameWriter; -import io.grpc.okhttp.internal.framed.Header; -import io.grpc.okhttp.internal.framed.Settings; -import java.io.IOException; -import java.net.Socket; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicLong; -import java.util.logging.Level; -import java.util.logging.Logger; -import okio.Buffer; - -class AsyncFrameWriter implements FrameWriter { - private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); - private FrameWriter frameWriter; - private Socket socket; - // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are - // just waiting on each other. - private final SerializingExecutor executor; - private final TransportExceptionHandler transportExceptionHandler; - private final AtomicLong flushCounter = new AtomicLong(); - // Some exceptions are not very useful and add too much noise to the log - private static final Set QUIET_ERRORS = - Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed"))); - - public AsyncFrameWriter( - TransportExceptionHandler transportExceptionHandler, SerializingExecutor executor) { - this.transportExceptionHandler = transportExceptionHandler; - this.executor = executor; - } - - /** - * Set the real frameWriter and the corresponding underlying socket, the socket is needed for - * closing. - * - *

should only be called by thread of executor. - */ - void becomeConnected(FrameWriter frameWriter, Socket socket) { - Preconditions.checkState(this.frameWriter == null, - "AsyncFrameWriter's setFrameWriter() should only be called once."); - this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); - this.socket = Preconditions.checkNotNull(socket, "socket"); - } - - @Override - public void connectionPreface() { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.connectionPreface(); - } - }); - } - - @Override - public void ackSettings(final Settings peerSettings) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.ackSettings(peerSettings); - } - }); - } - - @Override - public void pushPromise(final int streamId, final int promisedStreamId, - final List

requestHeaders) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders); - } - }); - } - - @Override - public void flush() { - // keep track of version of flushes to skip flush if another flush task is queued. - final long flushCount = flushCounter.incrementAndGet(); - - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - // There can be a flush starvation if there are continuous flood of flush is queued, this - // is not an issue with OkHttp since it flushes if the buffer is full. - if (flushCounter.get() == flushCount) { - frameWriter.flush(); - } - } - }); - } - - @Override - public void synStream(final boolean outFinished, final boolean inFinished, final int streamId, - final int associatedStreamId, final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock); - } - }); - } - - @Override - public void synReply(final boolean outFinished, final int streamId, - final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.synReply(outFinished, streamId, headerBlock); - } - }); - } - - @Override - public void headers(final int streamId, final List
headerBlock) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.headers(streamId, headerBlock); - } - }); - } - - @Override - public void rstStream(final int streamId, final ErrorCode errorCode) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.rstStream(streamId, errorCode); - } - }); - } - - @Override - public void data(final boolean outFinished, final int streamId, final Buffer source, - final int byteCount) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.data(outFinished, streamId, source, byteCount); - } - }); - } - - @Override - public void settings(final Settings okHttpSettings) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.settings(okHttpSettings); - } - }); - } - - @Override - public void ping(final boolean ack, final int payload1, final int payload2) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.ping(ack, payload1, payload2); - } - }); - } - - @Override - public void goAway(final int lastGoodStreamId, final ErrorCode errorCode, - final byte[] debugData) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.goAway(lastGoodStreamId, errorCode, debugData); - // Flush it since after goAway, we are likely to close this writer. - frameWriter.flush(); - } - }); - } - - @Override - public void windowUpdate(final int streamId, final long windowSizeIncrement) { - executor.execute(new WriteRunnable() { - @Override - public void doRun() throws IOException { - frameWriter.windowUpdate(streamId, windowSizeIncrement); - } - }); - } - - @Override - public void close() { - executor.execute(new Runnable() { - @Override - public void run() { - if (frameWriter != null) { - try { - frameWriter.close(); - socket.close(); - } catch (IOException e) { - log.log(getLogLevel(e), "Failed closing connection", e); - } - } - } - }); - } - - /** - * Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions - * should not clutter the log. - */ - @VisibleForTesting - static Level getLogLevel(Throwable t) { - if (t instanceof IOException - && t.getMessage() != null - && QUIET_ERRORS.contains(t.getMessage())) { - return Level.FINE; - - } - return Level.INFO; - } - - private abstract class WriteRunnable implements Runnable { - @Override - public final void run() { - try { - if (frameWriter == null) { - throw new IOException("Unable to perform write due to unavailable frameWriter."); - } - doRun(); - } catch (RuntimeException e) { - transportExceptionHandler.onException(e); - } catch (Exception e) { - transportExceptionHandler.onException(e); - } - } - - public abstract void doRun() throws IOException; - } - - @Override - public int maxDataLength() { - return frameWriter == null ? 0x4000 /* 16384, the minimum required by the HTTP/2 spec */ - : frameWriter.maxDataLength(); - } - - /** A class that handles transport exception. */ - interface TransportExceptionHandler { - - /** Handles exception. */ - void onException(Throwable throwable); - } -} diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java new file mode 100644 index 0000000000..bd1d4762a2 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -0,0 +1,162 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import io.grpc.internal.SerializingExecutor; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import java.io.IOException; +import java.net.Socket; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; +import okio.Sink; +import okio.Timeout; + +/** + * A sink that asynchronously write / flushes a buffer internally. AsyncSink provides flush + * coalescing to minimize network packing transmit. + */ +final class AsyncSink implements Sink { + + private final Object lock = new Object(); + private final Buffer buffer = new Buffer(); + private final SerializingExecutor serializingExecutor; + private final TransportExceptionHandler transportExceptionHandler; + + @GuardedBy("lock") + private boolean writeEnqueued = false; + @GuardedBy("lock") + private boolean flushEnqueued = false; + private boolean closed = false; + @Nullable + private Sink sink; + @Nullable + private Socket socket; + + private AsyncSink(SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { + this.serializingExecutor = checkNotNull(executor, "executor"); + this.transportExceptionHandler = checkNotNull(exceptionHandler, "exceptionHandler"); + } + + static AsyncSink sink( + SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { + return new AsyncSink(executor, exceptionHandler); + } + + /** + * Sets the actual sink. It is allowed to call write / flush operations on the sink iff calling + * this method is scheduled in the executor. The socket is needed for closing. + * + *

should only be called once by thread of executor. + */ + void becomeConnected(Sink sink, Socket socket) { + checkState(this.sink == null, "AsyncSink's becomeConnected should only be called once."); + this.sink = checkNotNull(sink, "sink"); + this.socket = checkNotNull(socket, "socket"); + } + + @Override + public void write(Buffer source, long byteCount) throws IOException { + checkNotNull(source, "source"); + if (closed) { + throw new IOException("closed"); + } + synchronized (lock) { + buffer.write(source, byteCount); + if (writeEnqueued || flushEnqueued || buffer.completeSegmentByteCount() <= 0) { + return; + } + writeEnqueued = true; + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + Buffer buf = new Buffer(); + synchronized (lock) { + buf.write(buffer, buffer.completeSegmentByteCount()); + writeEnqueued = false; + } + try { + sink.write(buf, buf.size()); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IOException("closed"); + } + synchronized (lock) { + if (flushEnqueued) { + return; + } + flushEnqueued = true; + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + Buffer buf = new Buffer(); + synchronized (lock) { + buf.write(buffer, buffer.size()); + flushEnqueued = false; + } + try { + sink.write(buf, buf.size()); + sink.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } + + @Override + public Timeout timeout() { + return Timeout.NONE; + } + + @Override + public void close() { + if (closed) { + return; + } + closed = true; + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + buffer.close(); + try { + sink.close(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + try { + socket.close(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + }); + } +} \ No newline at end of file diff --git a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java new file mode 100644 index 0000000000..0ee4c42367 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java @@ -0,0 +1,214 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.Settings; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import okio.Buffer; + +final class ExceptionHandlingFrameWriter implements FrameWriter { + + private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + // Some exceptions are not very useful and add too much noise to the log + private static final Set QUIET_ERRORS = + Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed"))); + + private final TransportExceptionHandler transportExceptionHandler; + + private final FrameWriter frameWriter; + + ExceptionHandlingFrameWriter( + TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { + this.transportExceptionHandler = + checkNotNull(transportExceptionHandler, "transportExceptionHandler"); + this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); + } + + @Override + public void connectionPreface() { + try { + frameWriter.connectionPreface(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void ackSettings(Settings peerSettings) { + try { + frameWriter.ackSettings(peerSettings); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void pushPromise(int streamId, int promisedStreamId, List

requestHeaders) { + try { + frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void flush() { + try { + frameWriter.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void synStream( + boolean outFinished, + boolean inFinished, + int streamId, + int associatedStreamId, + List
headerBlock) { + try { + frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void synReply(boolean outFinished, int streamId, + List
headerBlock) { + try { + frameWriter.synReply(outFinished, streamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void headers(int streamId, List
headerBlock) { + try { + frameWriter.headers(streamId, headerBlock); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) { + try { + frameWriter.rstStream(streamId, errorCode); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public int maxDataLength() { + return frameWriter.maxDataLength(); + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) { + try { + frameWriter.data(outFinished, streamId, source, byteCount); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void settings(Settings okHttpSettings) { + try { + frameWriter.settings(okHttpSettings); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void ping(boolean ack, int payload1, int payload2) { + try { + frameWriter.ping(ack, payload1, payload2); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, + byte[] debugData) { + try { + frameWriter.goAway(lastGoodStreamId, errorCode, debugData); + // Flush it since after goAway, we are likely to close this writer. + frameWriter.flush(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void windowUpdate(int streamId, long windowSizeIncrement) { + try { + frameWriter.windowUpdate(streamId, windowSizeIncrement); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + } + + @Override + public void close() { + try { + frameWriter.close(); + } catch (IOException e) { + log.log(getLogLevel(e), "Failed closing connection", e); + } + } + + /** + * Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions + * should not clutter the log. + */ + @VisibleForTesting + static Level getLogLevel(Throwable t) { + if (t instanceof IOException + && t.getMessage() != null + && QUIET_ERRORS.contains(t.getMessage())) { + return Level.FINE; + } + return Level.INFO; + } + + /** A class that handles transport exception. */ + interface TransportExceptionHandler { + /** Handles exception. */ + void onException(Throwable throwable); + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 782cbbc1db..e62d0cf3ce 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -62,7 +62,7 @@ class OkHttpClientStream extends AbstractClientStream { OkHttpClientStream( MethodDescriptor method, Metadata headers, - AsyncFrameWriter frameWriter, + ExceptionHandlingFrameWriter frameWriter, OkHttpClientTransport transport, OutboundFlowController outboundFlow, Object lock, @@ -203,7 +203,7 @@ class OkHttpClientStream extends AbstractClientStream { @GuardedBy("lock") private int processedWindow; @GuardedBy("lock") - private final AsyncFrameWriter frameWriter; + private final ExceptionHandlingFrameWriter frameWriter; @GuardedBy("lock") private final OutboundFlowController outboundFlow; @GuardedBy("lock") @@ -216,7 +216,7 @@ class OkHttpClientStream extends AbstractClientStream { int maxMessageSize, StatsTraceContext statsTraceCtx, Object lock, - AsyncFrameWriter frameWriter, + ExceptionHandlingFrameWriter frameWriter, OutboundFlowController outboundFlow, OkHttpClientTransport transport, int initialWindowSize) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index f1ce827018..81ef3a0cc5 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -57,7 +57,7 @@ import io.grpc.internal.SerializingExecutor; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; -import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; @@ -107,7 +107,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; private static Map buildErrorCodeToStatusMap() { - Map errorToStatus = new EnumMap(ErrorCode.class); + Map errorToStatus = new EnumMap<>(ErrorCode.class); errorToStatus.put(ErrorCode.NO_ERROR, Status.INTERNAL.withDescription("No error: A GRPC status of OK should have been sent")); errorToStatus.put(ErrorCode.PROTOCOL_ERROR, @@ -144,15 +144,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private final int initialWindowSize; private Listener listener; private FrameReader testFrameReader; - private AsyncFrameWriter frameWriter; + @GuardedBy("lock") + private ExceptionHandlingFrameWriter frameWriter; private OutboundFlowController outboundFlow; private final Object lock = new Object(); private final InternalLogId logId; @GuardedBy("lock") private int nextStreamId; @GuardedBy("lock") - private final Map streams = - new HashMap(); + private final Map streams = new HashMap<>(); private final Executor executor; // Wrap on executor, to guarantee some operations be executed serially. private final SerializingExecutor serializingExecutor; @@ -182,7 +182,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private int maxConcurrentStreams = 0; @SuppressWarnings("JdkObsolete") // Usage is bursty; want low memory usage when empty @GuardedBy("lock") - private LinkedList pendingStreams = new LinkedList(); + private LinkedList pendingStreams = new LinkedList<>(); private final ConnectionSpec connectionSpec; private FrameWriter testFrameWriter; private ScheduledExecutorService scheduler; @@ -219,7 +219,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, Executor executor, @Nullable SSLSocketFactory sslSocketFactory, @Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, @@ -322,11 +321,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @Override public void ping(final PingCallback callback, Executor executor) { - checkState(frameWriter != null); long data = 0; Http2Ping p; boolean writePing; synchronized (lock) { + checkState(frameWriter != null); if (stopped) { Http2Ping.notifyFailed(callback, executor, getPingFailure()); return; @@ -345,9 +344,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep writePing = true; transportTracer.reportKeepAliveSent(); } - } - if (writePing) { - frameWriter.ping(false, (int) (data >>> 32), (int) data); + if (writePing) { + frameWriter.ping(false, (int) (data >>> 32), (int) data); + } } // If transport concurrently failed/stopped since we released the lock above, this could // immediately invoke callback (which we shouldn't do while holding a lock) @@ -449,15 +448,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep keepAliveWithoutCalls); keepAliveManager.onTransportStarted(); } - - frameWriter = new AsyncFrameWriter(this, serializingExecutor); - outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize); - // Connecting in the serializingExecutor, so that some stream operations like synStream - // will be executed after connected. - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (isForTest()) { + if (isForTest()) { + synchronized (lock) { + frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter); + outboundFlow = + new OutboundFlowController(OkHttpClientTransport.this, frameWriter, initialWindowSize); + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { if (connectingCallback != null) { connectingCallback.run(); } @@ -467,11 +466,25 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep maxConcurrentStreams = Integer.MAX_VALUE; startPendingStreams(); } - frameWriter.becomeConnected(testFrameWriter, socket); connectedFuture.set(null); - return; } + }); + return null; + } + final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); + final Variant variant = new Http2(); + FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); + + synchronized (lock) { + frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter); + outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize); + } + // Connecting in the serializingExecutor, so that some stream operations like synStream + // will be executed after connected. + serializingExecutor.execute(new Runnable() { + @Override + public void run() { // Use closed source on failure so that the reader immediately shuts down. BufferedSource source = Okio.buffer(new Source() { @Override @@ -485,10 +498,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } @Override - public void close() {} + public void close() { + } }); - Variant variant = new Http2(); - BufferedSink sink; Socket sock; SSLSession sslSession = null; try { @@ -508,7 +520,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } sock.setTcpNoDelay(true); source = Okio.buffer(Okio.source(sock)); - sink = Okio.buffer(Okio.sink(sock)); + asyncSink.becomeConnected(Okio.sink(sock), sock); + // The return value of OkHttpTlsUpgrader.upgrade is an SSLSocket that has this info attributes = Attributes .newBuilder() @@ -526,31 +539,29 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep return; } finally { clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true)); - executor.execute(clientFrameHandler); } - - FrameWriter rawFrameWriter; synchronized (lock) { socket = Preconditions.checkNotNull(sock, "socket"); - maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); if (sslSession != null) { securityInfo = new InternalChannelz.Security(new InternalChannelz.Tls(sslSession)); } } - - rawFrameWriter = variant.newWriter(sink, true); - frameWriter.becomeConnected(rawFrameWriter, socket); - - try { - // Do these with the raw FrameWriter, so that they will be done in this thread, - // and before any possible pending stream operations. - rawFrameWriter.connectionPreface(); - Settings settings = new Settings(); - rawFrameWriter.settings(settings); - } catch (Exception e) { - onException(e); - return; + } + }); + synchronized (lock) { + frameWriter.connectionPreface(); + Settings settings = new Settings(); + frameWriter.settings(settings); + } + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it + // may send goAway immediately. + executor.execute(clientFrameHandler); + synchronized (lock) { + maxConcurrentStreams = Integer.MAX_VALUE; + startPendingStreams(); } } }); @@ -558,7 +569,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep } private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, - String proxyUsername, String proxyPassword) throws IOException, StatusException { + String proxyUsername, String proxyPassword) throws StatusException { try { Socket sock; // The proxy address may not be resolved @@ -1039,7 +1050,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientStream stream = getStream(streamId); if (stream == null) { if (mayHaveCreatedStream(streamId)) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + synchronized (lock) { + frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + } in.skip(length); } else { onError(ErrorCode.PROTOCOL_ERROR, "Received data for unknown stream: " + streamId); @@ -1059,7 +1072,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep // connection window update connectionUnacknowledgedBytesRead += length; if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) { - frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + synchronized (lock) { + frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + } connectionUnacknowledgedBytesRead = 0; } } @@ -1170,7 +1185,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @Override public void ping(boolean ack, int payload1, int payload2) { if (!ack) { - frameWriter.ping(true, payload1, payload2); + synchronized (lock) { + frameWriter.ping(true, payload1, payload2); + } } else { Http2Ping p = null; long ackPayload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); @@ -1222,7 +1239,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep public void pushPromise(int streamId, int promisedStreamId, List
requestHeaders) throws IOException { // We don't accept server initiated stream. - frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); + synchronized (lock) { + frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); + } } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java index e3cad95313..441bb21151 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java @@ -323,12 +323,7 @@ class OutboundFlowController { try { // endOfStream is set for the last chunk of data marked as endOfStream boolean isEndOfStream = buffer.size() == frameBytes && endOfStream; - // AsyncFrameWriter drains buffer in executor. To avoid race, copy to temp. - // TODO(jihuncho) remove temp buff when async logic is moved to AsyncSink. - Buffer temp = new Buffer(); - temp.write(buffer, frameBytes); - - frameWriter.data(isEndOfStream, streamId, temp, frameBytes); + frameWriter.data(isEndOfStream, streamId, buffer, frameBytes); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java deleted file mode 100644 index 45e80707c5..0000000000 --- a/okhttp/src/test/java/io/grpc/okhttp/AsyncFrameWriterTest.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright 2018 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.okhttp; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.okhttp.AsyncFrameWriter.getLogLevel; -import static org.mockito.Matchers.anyBoolean; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -import io.grpc.internal.SerializingExecutor; -import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler; -import io.grpc.okhttp.internal.framed.FrameWriter; -import java.io.IOException; -import java.net.Socket; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Executor; -import java.util.logging.Level; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.InOrder; -import org.mockito.Mock; -import org.mockito.runners.MockitoJUnitRunner; - -@RunWith(MockitoJUnitRunner.class) -public class AsyncFrameWriterTest { - - @Mock private Socket socket; - @Mock private FrameWriter frameWriter; - - private QueueingExecutor queueingExecutor = new QueueingExecutor(); - private TransportExceptionHandler transportExceptionHandler = - new EscalatingTransportErrorHandler(); - private AsyncFrameWriter asyncFrameWriter = - new AsyncFrameWriter(transportExceptionHandler, new SerializingExecutor(queueingExecutor)); - - @Before - public void setUp() throws Exception { - asyncFrameWriter.becomeConnected(frameWriter, socket); - } - - @Test - public void noCoalesceRequired() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - verify(frameWriter, times(1)).ping(anyBoolean(), anyInt(), anyInt()); - verify(frameWriter, times(1)).flush(); - } - - @Test - public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - asyncFrameWriter.ping(true, 0, 2); - asyncFrameWriter.flush(); - queueingExecutor.runAll(); - - verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt()); - verify(frameWriter, times(2)).flush(); - } - - @Test - public void flushCoalescing_shouldMergeTwoQueuedFlushes() throws IOException { - asyncFrameWriter.ping(true, 0, 1); - asyncFrameWriter.flush(); - asyncFrameWriter.ping(true, 0, 2); - asyncFrameWriter.flush(); - - queueingExecutor.runAll(); - - InOrder inOrder = inOrder(frameWriter); - inOrder.verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt()); - inOrder.verify(frameWriter).flush(); - } - - @Test - public void unknownException() { - assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO); - } - - @Test - public void quiet() { - assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE); - } - - @Test - public void nonquiet() { - assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO); - } - - @Test - public void nullMessage() { - IOException e = new IOException(); - assertThat(e.getMessage()).isNull(); - assertThat(getLogLevel(e)).isEqualTo(Level.INFO); - } - - /** - * Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link - * QueueingExecutor#runAll} in serial order. - */ - private static class QueueingExecutor implements Executor { - - private final Queue runnables = new ConcurrentLinkedQueue(); - - @Override - public void execute(Runnable command) { - runnables.add(command); - } - - public void runAll() { - Runnable r; - while ((r = runnables.poll()) != null) { - r.run(); - } - } - } - - /** Rethrows as Assertion error. */ - private static class EscalatingTransportErrorHandler implements TransportExceptionHandler { - - @Override - public void onException(Throwable throwable) { - throw new AssertionError(throwable); - } - } -} diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java new file mode 100644 index 0000000000..084a00f72c --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java @@ -0,0 +1,289 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import com.google.common.base.Charsets; +import io.grpc.internal.SerializingExecutor; +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import java.io.IOException; +import java.net.Socket; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import okio.Buffer; +import okio.Sink; +import okio.Timeout; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; + +/** Tests for {@link AsyncSink}. */ +@RunWith(JUnit4.class) +public class AsyncSinkTest { + + private Socket socket = mock(Socket.class); + private Sink mockedSink = mock(VoidSink.class, CALLS_REAL_METHODS); + private QueueingExecutor queueingExecutor = new QueueingExecutor(); + private TransportExceptionHandler exceptionHandler = mock(TransportExceptionHandler.class); + private AsyncSink sink = + AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler); + + @Before + public void setUp() throws Exception { + sink.becomeConnected(mockedSink, socket); + } + + @Test + public void noCoalesceRequired() throws IOException { + Buffer buffer = new Buffer(); + sink.write(buffer.writeUtf8("hello"), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + + Buffer buffer = new Buffer(); + sink.write(buffer.write(firstData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + sink.write(buffer.write(secondData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + inOrder.verify(mockedSink).write(any(Buffer.class), anyInt()); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldMergeTwoQueuedFlushesAndWrites() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + Buffer buffer = new Buffer().write(firstData); + sink.write(buffer, buffer.size()); + sink.flush(); + buffer = new Buffer().write(secondData); + sink.write(buffer, buffer.size()); + sink.flush(); + + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink) + .write(any(Buffer.class), eq((long) firstData.length + secondData.length)); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void flushCoalescing_shouldMergeWrites() throws IOException { + byte[] firstData = "a string".getBytes(Charsets.UTF_8); + byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + Buffer buffer = new Buffer(); + sink.write(buffer.write(firstData), buffer.size()); + sink.write(buffer.write(secondData), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + InOrder inOrder = inOrder(mockedSink); + inOrder.verify(mockedSink) + .write(any(Buffer.class), eq((long) firstData.length + secondData.length)); + inOrder.verify(mockedSink).flush(); + } + + @Test + public void write_shouldCachePreviousException() throws IOException { + Exception ioException = new IOException("some exception"); + doThrow(ioException) + .when(mockedSink).write(any(Buffer.class), anyLong()); + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + + verify(exceptionHandler, timeout(1000)).onException(ioException); + } + + @Test + public void close_writeShouldThrowException() { + sink.close(); + queueingExecutor.runAll(); + try { + sink.write(new Buffer(), 0); + fail("should throw ioException"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void write_shouldThrowIfAlreadyClosed() throws IOException { + Exception ioException = new IOException("some exception"); + doThrow(ioException) + .when(mockedSink).write(any(Buffer.class), anyLong()); + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.close(); + queueingExecutor.runAll(); + try { + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + fail("should throw ioException"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void close_flushShouldThrowException() throws IOException { + sink.close(); + queueingExecutor.runAll(); + try { + sink.flush(); + queueingExecutor.runAll(); + fail("should fail"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void flush_shouldThrowIfAlreadyClosed() throws IOException { + Buffer buffer = new Buffer(); + buffer.writeUtf8("any message"); + sink.write(buffer, buffer.size()); + sink.close(); + queueingExecutor.runAll(); + try { + sink.flush(); + queueingExecutor.runAll(); + fail("should fail"); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("closed"); + } + } + + @Test + public void write_callSinkIfBufferIsLargerThanSegmentSize() throws IOException { + Buffer buffer = new Buffer(); + // OkHttp is using 8192 as segment size. + int payloadSize = 8192 * 2 - 1; + int padding = 10; + buffer.write(new byte[payloadSize]); + + int completeSegmentBytes = (int) buffer.completeSegmentByteCount(); + assertThat(completeSegmentBytes).isLessThan(payloadSize); + + // first trying to send of all complete segments, but not the padding + sink.write(buffer, completeSegmentBytes + padding); + queueingExecutor.runAll(); + verify(mockedSink).write(any(Buffer.class), eq((long) completeSegmentBytes)); + verify(mockedSink, never()).flush(); + assertThat(buffer.size()).isEqualTo((long) (payloadSize - completeSegmentBytes - padding)); + + // writing smaller than completed segment, shouldn't trigger write to Sink. + reset(mockedSink); + sink.write(buffer, buffer.size()); + queueingExecutor.runAll(); + verify(mockedSink, never()).write(any(Buffer.class), anyLong()); + verify(mockedSink, never()).flush(); + assertThat(buffer.exhausted()).isTrue(); + + // flush should write everything. + sink.flush(); + queueingExecutor.runAll(); + verify(mockedSink).write(any(Buffer.class),eq((long) payloadSize - completeSegmentBytes)); + verify(mockedSink).flush(); + } + + /** + * Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link + * QueueingExecutor#runAll} in serial order. + */ + private static class QueueingExecutor implements Executor { + + private final Queue runnables = new ConcurrentLinkedQueue<>(); + + @Override + public void execute(Runnable command) { + runnables.add(command); + } + + public void runAll() { + Runnable r; + while ((r = runnables.poll()) != null) { + r.run(); + } + } + } + + /** Test sink to mimic real Sink behavior since write has a side effect. */ + private static class VoidSink implements Sink { + + @Override + public void write(Buffer source, long byteCount) throws IOException { + // removes byteCount bytes from source. + source.read(new byte[(int) byteCount], 0, (int) byteCount); + } + + @Override + public void flush() throws IOException { + // do nothing + } + + @Override + public Timeout timeout() { + return Timeout.NONE; + } + + @Override + public void close() throws IOException { + // do nothing + } + } +} \ No newline at end of file diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java new file mode 100644 index 0000000000..13cc338d35 --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.okhttp.ExceptionHandlingFrameWriter.getLogLevel; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import java.io.IOException; +import java.util.ArrayList; +import java.util.logging.Level; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExceptionHandlingFrameWriterTest { + + private FrameWriter mockedFrameWriter = mock(FrameWriter.class); + private TransportExceptionHandler transportExceptionHandler = + mock(TransportExceptionHandler.class); + private ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = + new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter); + + @Test + public void exception() throws IOException { + IOException exception = new IOException("some exception"); + doThrow(exception).when(mockedFrameWriter) + .synReply(false, 100, new ArrayList
()); + + exceptionHandlingFrameWriter.synReply(false, 100, new ArrayList
()); + + verify(transportExceptionHandler).onException(exception); + verify(mockedFrameWriter).synReply(false, 100, new ArrayList
()); + } + + @Test + public void unknownException() { + assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO); + } + + @Test + public void quiet() { + assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE); + } + + @Test + public void nonquiet() { + assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO); + } + + @Test + public void nullMessage() { + IOException e = new IOException(); + assertThat(e.getMessage()).isNull(); + assertThat(getLogLevel(e)).isEqualTo(Level.INFO); + } +} \ No newline at end of file diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 9768d0f353..1b7d38fc1a 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -37,8 +37,10 @@ import io.grpc.internal.NoopClientStreamListener; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.Header; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.charset.Charset; import java.util.List; import java.util.concurrent.atomic.AtomicReference; @@ -60,7 +62,8 @@ public class OkHttpClientStreamTest { private static final int INITIAL_WINDOW_SIZE = 65535; @Mock private MethodDescriptor.Marshaller marshaller; - @Mock private AsyncFrameWriter frameWriter; + @Mock private FrameWriter mockedFrameWriter; + private ExceptionHandlingFrameWriter frameWriter; @Mock private OkHttpClientTransport transport; @Mock private OutboundFlowController flowController; @Captor private ArgumentCaptor> headersCaptor; @@ -81,6 +84,8 @@ public class OkHttpClientStreamTest { .setResponseMarshaller(marshaller) .build(); + frameWriter = + new ExceptionHandlingFrameWriter(transport, mockedFrameWriter); stream = new OkHttpClientStream( methodDescriptor, new Metadata(), @@ -144,11 +149,11 @@ public class OkHttpClientStreamTest { stream.transportState().start(1234); - verifyNoMoreInteractions(frameWriter); + verifyNoMoreInteractions(mockedFrameWriter); } @Test - public void start_userAgentRemoved() { + public void start_userAgentRemoved() throws IOException { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, @@ -157,13 +162,14 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); stream.transportState().start(3); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + verify(mockedFrameWriter) + .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()) .contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application")); } @Test - public void start_headerFieldOrder() { + public void start_headerFieldOrder() throws IOException { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, @@ -172,7 +178,8 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); stream.transportState().start(3); - verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + verify(mockedFrameWriter) + .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()).containsExactly( Headers.SCHEME_HEADER, Headers.METHOD_HEADER, @@ -185,7 +192,7 @@ public class OkHttpClientStreamTest { } @Test - public void getUnaryRequest() { + public void getUnaryRequest() throws IOException { MethodDescriptor getMethod = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("service/method") @@ -200,7 +207,7 @@ public class OkHttpClientStreamTest { stream.start(new BaseClientStreamListener()); // GET streams send headers after halfClose is called. - verify(frameWriter, times(0)).synStream( + verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); @@ -210,7 +217,7 @@ public class OkHttpClientStreamTest { verify(transport).streamReadyToStart(eq(stream)); stream.transportState().start(3); - verify(frameWriter) + verify(mockedFrameWriter) .synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture()); assertThat(headersCaptor.getValue()).contains(Headers.METHOD_GET_HEADER); assertThat(headersCaptor.getValue()).contains( diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index cb3f7e5fe2..367d8e18d7 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -46,7 +46,6 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -91,9 +90,11 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -113,6 +114,7 @@ import org.junit.Test; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.AdditionalAnswers; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Matchers; @@ -139,7 +141,6 @@ public class OkHttpClientTransportTest { @Rule public final Timeout globalTimeout = Timeout.seconds(10); - @Mock private FrameWriter frameWriter; private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @@ -150,8 +151,10 @@ public class OkHttpClientTransportTest { private final SSLSocketFactory sslSocketFactory = null; private final HostnameVerifier hostnameVerifier = null; private final TransportTracer transportTracer = new TransportTracer(); + private final Queue capturedBuffer = new ArrayDeque<>(); private OkHttpClientTransport clientTransport; private MockFrameReader frameReader; + private Socket socket; private ExecutorService executor = Executors.newCachedThreadPool(); private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private SettableFuture connectedFuture; @@ -166,8 +169,10 @@ public class OkHttpClientTransportTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE); frameReader = new MockFrameReader(); + socket = new MockSocket(frameReader); + MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer); + frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter)); } @After @@ -217,7 +222,7 @@ public class OkHttpClientTransportTest { frameReader, frameWriter, startId, - new MockSocket(frameReader), + socket, stopwatchSupplier, connectingCallback, connectedFuture, @@ -556,10 +561,9 @@ public class OkHttpClientTransportTest { assertEquals(12, input.available()); stream.writeMessage(input); stream.flush(); - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(message), sentFrame); stream.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -888,11 +892,9 @@ public class OkHttpClientTransportTest { assertEquals(22, input.available()); stream.writeMessage(input); stream.flush(); - ArgumentCaptor captor = - ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(22 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(sentMessage), sentFrame); // And read. @@ -976,10 +978,9 @@ public class OkHttpClientTransportTest { // The second stream should be active now, and the pending data should be sent out. assertEquals(1, activeStreamCount()); assertEquals(0, clientTransport.getPendingStreamSize()); - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(true), eq(5), captor.capture(), eq(5 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(true), eq(5), any(Buffer.class), eq(5 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(sentMessage), sentFrame); stream2.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -1456,10 +1457,9 @@ public class OkHttpClientTransportTest { allowTransportConnected(); // The queued message should be sent out. - ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH)); - Buffer sentFrame = captor.getValue(); + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); assertEquals(createMessageFrame(message), sentFrame); stream.cancel(Status.CANCELLED); shutdownAndVerify(); @@ -2091,7 +2091,6 @@ public class OkHttpClientTransportTest { verify(frameWriter, timeout(TIME_OUT_MS)).close(); } catch (IOException e) { throw new RuntimeException(e); - } frameReader.assertClosed(); } @@ -2113,4 +2112,83 @@ public class OkHttpClientTransportTest { throws ExecutionException, InterruptedException { return obj.getStats().get().data; } -} + + /** A FrameWriter to mock with CALL_REAL_METHODS option. */ + private static class MockFrameWriter implements FrameWriter { + + private Socket socket; + private Queue capturedBuffer; + + public MockFrameWriter(Socket socket, Queue capturedBuffer) { + // Sets a socket to close. Some tests assumes that FrameWriter will close underlying sink + // which will eventually close the socket. + this.socket = socket; + this.capturedBuffer = capturedBuffer; + } + + void setSocket(Socket socket) { + this.socket = socket; + } + + @Override + public void close() throws IOException { + socket.close(); + } + + @Override + public int maxDataLength() { + return Integer.MAX_VALUE; + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) + throws IOException { + // simulate the side effect, and captures to internal queue. + Buffer capture = new Buffer(); + capture.write(source, byteCount); + capturedBuffer.add(capture); + } + + // rest of methods are unimplemented + + @Override + public void connectionPreface() throws IOException {} + + @Override + public void ackSettings(Settings peerSettings) throws IOException {} + + @Override + public void pushPromise(int streamId, int promisedStreamId, List
requestHeaders) + throws IOException {} + + @Override + public void flush() throws IOException {} + + @Override + public void synStream(boolean outFinished, boolean inFinished, int streamId, + int associatedStreamId, List
headerBlock) throws IOException {} + + @Override + public void synReply(boolean outFinished, int streamId, List
headerBlock) + throws IOException {} + + @Override + public void headers(int streamId, List
headerBlock) throws IOException {} + + @Override + public void rstStream(int streamId, ErrorCode errorCode) throws IOException {} + + @Override + public void settings(Settings okHttpSettings) throws IOException {} + + @Override + public void ping(boolean ack, int payload1, int payload2) throws IOException {} + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, byte[] debugData) + throws IOException {} + + @Override + public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {} + } +} \ No newline at end of file