okhttp: move async mechanism from FrameWriter to sink (AsyncSink) (#4916)

Optimize OkHttp transport's memory use by getting rid of queuing writes in
AsyncFrameWriter. If any write is pending due to connection issue or by flow
control, AsyncFrameWriter can use at least 8K per each task (task includes
buffer) even if the actual payload is very small. To merge pending writes,
Async mechanism is moved from AsyncFrameWriter to AsyncSink (AsyncSink 
is used by okio's FrameWriter). AsyncSink is still relying on okio's buffer to
decide merging writes or not.

Resolves #4860
This commit is contained in:
Jihun Cho 2018-12-28 17:20:03 -08:00 committed by GitHub
parent a4859c1e93
commit 6bf0936f8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 927 additions and 511 deletions

View File

@ -1,276 +0,0 @@
/*
* Copyright 2014 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.internal.SerializingExecutor;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header;
import io.grpc.okhttp.internal.framed.Settings;
import java.io.IOException;
import java.net.Socket;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
class AsyncFrameWriter implements FrameWriter {
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
private FrameWriter frameWriter;
private Socket socket;
// Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
// just waiting on each other.
private final SerializingExecutor executor;
private final TransportExceptionHandler transportExceptionHandler;
private final AtomicLong flushCounter = new AtomicLong();
// Some exceptions are not very useful and add too much noise to the log
private static final Set<String> QUIET_ERRORS =
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed")));
public AsyncFrameWriter(
TransportExceptionHandler transportExceptionHandler, SerializingExecutor executor) {
this.transportExceptionHandler = transportExceptionHandler;
this.executor = executor;
}
/**
* Set the real frameWriter and the corresponding underlying socket, the socket is needed for
* closing.
*
* <p>should only be called by thread of executor.
*/
void becomeConnected(FrameWriter frameWriter, Socket socket) {
Preconditions.checkState(this.frameWriter == null,
"AsyncFrameWriter's setFrameWriter() should only be called once.");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
this.socket = Preconditions.checkNotNull(socket, "socket");
}
@Override
public void connectionPreface() {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.connectionPreface();
}
});
}
@Override
public void ackSettings(final Settings peerSettings) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.ackSettings(peerSettings);
}
});
}
@Override
public void pushPromise(final int streamId, final int promisedStreamId,
final List<Header> requestHeaders) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders);
}
});
}
@Override
public void flush() {
// keep track of version of flushes to skip flush if another flush task is queued.
final long flushCount = flushCounter.incrementAndGet();
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
// There can be a flush starvation if there are continuous flood of flush is queued, this
// is not an issue with OkHttp since it flushes if the buffer is full.
if (flushCounter.get() == flushCount) {
frameWriter.flush();
}
}
});
}
@Override
public void synStream(final boolean outFinished, final boolean inFinished, final int streamId,
final int associatedStreamId, final List<Header> headerBlock) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock);
}
});
}
@Override
public void synReply(final boolean outFinished, final int streamId,
final List<Header> headerBlock) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.synReply(outFinished, streamId, headerBlock);
}
});
}
@Override
public void headers(final int streamId, final List<Header> headerBlock) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.headers(streamId, headerBlock);
}
});
}
@Override
public void rstStream(final int streamId, final ErrorCode errorCode) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.rstStream(streamId, errorCode);
}
});
}
@Override
public void data(final boolean outFinished, final int streamId, final Buffer source,
final int byteCount) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.data(outFinished, streamId, source, byteCount);
}
});
}
@Override
public void settings(final Settings okHttpSettings) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.settings(okHttpSettings);
}
});
}
@Override
public void ping(final boolean ack, final int payload1, final int payload2) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.ping(ack, payload1, payload2);
}
});
}
@Override
public void goAway(final int lastGoodStreamId, final ErrorCode errorCode,
final byte[] debugData) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
// Flush it since after goAway, we are likely to close this writer.
frameWriter.flush();
}
});
}
@Override
public void windowUpdate(final int streamId, final long windowSizeIncrement) {
executor.execute(new WriteRunnable() {
@Override
public void doRun() throws IOException {
frameWriter.windowUpdate(streamId, windowSizeIncrement);
}
});
}
@Override
public void close() {
executor.execute(new Runnable() {
@Override
public void run() {
if (frameWriter != null) {
try {
frameWriter.close();
socket.close();
} catch (IOException e) {
log.log(getLogLevel(e), "Failed closing connection", e);
}
}
}
});
}
/**
* Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions
* should not clutter the log.
*/
@VisibleForTesting
static Level getLogLevel(Throwable t) {
if (t instanceof IOException
&& t.getMessage() != null
&& QUIET_ERRORS.contains(t.getMessage())) {
return Level.FINE;
}
return Level.INFO;
}
private abstract class WriteRunnable implements Runnable {
@Override
public final void run() {
try {
if (frameWriter == null) {
throw new IOException("Unable to perform write due to unavailable frameWriter.");
}
doRun();
} catch (RuntimeException e) {
transportExceptionHandler.onException(e);
} catch (Exception e) {
transportExceptionHandler.onException(e);
}
}
public abstract void doRun() throws IOException;
}
@Override
public int maxDataLength() {
return frameWriter == null ? 0x4000 /* 16384, the minimum required by the HTTP/2 spec */
: frameWriter.maxDataLength();
}
/** A class that handles transport exception. */
interface TransportExceptionHandler {
/** Handles exception. */
void onException(Throwable throwable);
}
}

View File

@ -0,0 +1,162 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.grpc.internal.SerializingExecutor;
import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler;
import java.io.IOException;
import java.net.Socket;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import okio.Buffer;
import okio.Sink;
import okio.Timeout;
/**
* A sink that asynchronously write / flushes a buffer internally. AsyncSink provides flush
* coalescing to minimize network packing transmit.
*/
final class AsyncSink implements Sink {
private final Object lock = new Object();
private final Buffer buffer = new Buffer();
private final SerializingExecutor serializingExecutor;
private final TransportExceptionHandler transportExceptionHandler;
@GuardedBy("lock")
private boolean writeEnqueued = false;
@GuardedBy("lock")
private boolean flushEnqueued = false;
private boolean closed = false;
@Nullable
private Sink sink;
@Nullable
private Socket socket;
private AsyncSink(SerializingExecutor executor, TransportExceptionHandler exceptionHandler) {
this.serializingExecutor = checkNotNull(executor, "executor");
this.transportExceptionHandler = checkNotNull(exceptionHandler, "exceptionHandler");
}
static AsyncSink sink(
SerializingExecutor executor, TransportExceptionHandler exceptionHandler) {
return new AsyncSink(executor, exceptionHandler);
}
/**
* Sets the actual sink. It is allowed to call write / flush operations on the sink iff calling
* this method is scheduled in the executor. The socket is needed for closing.
*
* <p>should only be called once by thread of executor.
*/
void becomeConnected(Sink sink, Socket socket) {
checkState(this.sink == null, "AsyncSink's becomeConnected should only be called once.");
this.sink = checkNotNull(sink, "sink");
this.socket = checkNotNull(socket, "socket");
}
@Override
public void write(Buffer source, long byteCount) throws IOException {
checkNotNull(source, "source");
if (closed) {
throw new IOException("closed");
}
synchronized (lock) {
buffer.write(source, byteCount);
if (writeEnqueued || flushEnqueued || buffer.completeSegmentByteCount() <= 0) {
return;
}
writeEnqueued = true;
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
Buffer buf = new Buffer();
synchronized (lock) {
buf.write(buffer, buffer.completeSegmentByteCount());
writeEnqueued = false;
}
try {
sink.write(buf, buf.size());
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
});
}
@Override
public void flush() throws IOException {
if (closed) {
throw new IOException("closed");
}
synchronized (lock) {
if (flushEnqueued) {
return;
}
flushEnqueued = true;
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
Buffer buf = new Buffer();
synchronized (lock) {
buf.write(buffer, buffer.size());
flushEnqueued = false;
}
try {
sink.write(buf, buf.size());
sink.flush();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
});
}
@Override
public Timeout timeout() {
return Timeout.NONE;
}
@Override
public void close() {
if (closed) {
return;
}
closed = true;
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
buffer.close();
try {
sink.close();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
try {
socket.close();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
});
}
}

View File

@ -0,0 +1,214 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header;
import io.grpc.okhttp.internal.framed.Settings;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
final class ExceptionHandlingFrameWriter implements FrameWriter {
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
// Some exceptions are not very useful and add too much noise to the log
private static final Set<String> QUIET_ERRORS =
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("Socket closed")));
private final TransportExceptionHandler transportExceptionHandler;
private final FrameWriter frameWriter;
ExceptionHandlingFrameWriter(
TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) {
this.transportExceptionHandler =
checkNotNull(transportExceptionHandler, "transportExceptionHandler");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
}
@Override
public void connectionPreface() {
try {
frameWriter.connectionPreface();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void ackSettings(Settings peerSettings) {
try {
frameWriter.ackSettings(peerSettings);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders) {
try {
frameWriter.pushPromise(streamId, promisedStreamId, requestHeaders);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void flush() {
try {
frameWriter.flush();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void synStream(
boolean outFinished,
boolean inFinished,
int streamId,
int associatedStreamId,
List<Header> headerBlock) {
try {
frameWriter.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void synReply(boolean outFinished, int streamId,
List<Header> headerBlock) {
try {
frameWriter.synReply(outFinished, streamId, headerBlock);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void headers(int streamId, List<Header> headerBlock) {
try {
frameWriter.headers(streamId, headerBlock);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void rstStream(int streamId, ErrorCode errorCode) {
try {
frameWriter.rstStream(streamId, errorCode);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public int maxDataLength() {
return frameWriter.maxDataLength();
}
@Override
public void data(boolean outFinished, int streamId, Buffer source, int byteCount) {
try {
frameWriter.data(outFinished, streamId, source, byteCount);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void settings(Settings okHttpSettings) {
try {
frameWriter.settings(okHttpSettings);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void ping(boolean ack, int payload1, int payload2) {
try {
frameWriter.ping(ack, payload1, payload2);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode,
byte[] debugData) {
try {
frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
// Flush it since after goAway, we are likely to close this writer.
frameWriter.flush();
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void windowUpdate(int streamId, long windowSizeIncrement) {
try {
frameWriter.windowUpdate(streamId, windowSizeIncrement);
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
}
@Override
public void close() {
try {
frameWriter.close();
} catch (IOException e) {
log.log(getLogLevel(e), "Failed closing connection", e);
}
}
/**
* Accepts a throwable and returns the appropriate logging level. Uninteresting exceptions
* should not clutter the log.
*/
@VisibleForTesting
static Level getLogLevel(Throwable t) {
if (t instanceof IOException
&& t.getMessage() != null
&& QUIET_ERRORS.contains(t.getMessage())) {
return Level.FINE;
}
return Level.INFO;
}
/** A class that handles transport exception. */
interface TransportExceptionHandler {
/** Handles exception. */
void onException(Throwable throwable);
}
}

View File

@ -62,7 +62,7 @@ class OkHttpClientStream extends AbstractClientStream {
OkHttpClientStream(
MethodDescriptor<?, ?> method,
Metadata headers,
AsyncFrameWriter frameWriter,
ExceptionHandlingFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow,
Object lock,
@ -203,7 +203,7 @@ class OkHttpClientStream extends AbstractClientStream {
@GuardedBy("lock")
private int processedWindow;
@GuardedBy("lock")
private final AsyncFrameWriter frameWriter;
private final ExceptionHandlingFrameWriter frameWriter;
@GuardedBy("lock")
private final OutboundFlowController outboundFlow;
@GuardedBy("lock")
@ -216,7 +216,7 @@ class OkHttpClientStream extends AbstractClientStream {
int maxMessageSize,
StatsTraceContext statsTraceCtx,
Object lock,
AsyncFrameWriter frameWriter,
ExceptionHandlingFrameWriter frameWriter,
OutboundFlowController outboundFlow,
OkHttpClientTransport transport,
int initialWindowSize) {

View File

@ -57,7 +57,7 @@ import io.grpc.internal.SerializingExecutor;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler;
import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler;
import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameReader;
@ -107,7 +107,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
Map<ErrorCode, Status> errorToStatus = new EnumMap<ErrorCode, Status>(ErrorCode.class);
Map<ErrorCode, Status> errorToStatus = new EnumMap<>(ErrorCode.class);
errorToStatus.put(ErrorCode.NO_ERROR,
Status.INTERNAL.withDescription("No error: A GRPC status of OK should have been sent"));
errorToStatus.put(ErrorCode.PROTOCOL_ERROR,
@ -144,15 +144,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private final int initialWindowSize;
private Listener listener;
private FrameReader testFrameReader;
private AsyncFrameWriter frameWriter;
@GuardedBy("lock")
private ExceptionHandlingFrameWriter frameWriter;
private OutboundFlowController outboundFlow;
private final Object lock = new Object();
private final InternalLogId logId;
@GuardedBy("lock")
private int nextStreamId;
@GuardedBy("lock")
private final Map<Integer, OkHttpClientStream> streams =
new HashMap<Integer, OkHttpClientStream>();
private final Map<Integer, OkHttpClientStream> streams = new HashMap<>();
private final Executor executor;
// Wrap on executor, to guarantee some operations be executed serially.
private final SerializingExecutor serializingExecutor;
@ -182,7 +182,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private int maxConcurrentStreams = 0;
@SuppressWarnings("JdkObsolete") // Usage is bursty; want low memory usage when empty
@GuardedBy("lock")
private LinkedList<OkHttpClientStream> pendingStreams = new LinkedList<OkHttpClientStream>();
private LinkedList<OkHttpClientStream> pendingStreams = new LinkedList<>();
private final ConnectionSpec connectionSpec;
private FrameWriter testFrameWriter;
private ScheduledExecutorService scheduler;
@ -219,7 +219,6 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Runnable connectingCallback;
SettableFuture<Void> connectedFuture;
OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent,
Executor executor, @Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec,
@ -322,11 +321,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
@Override
public void ping(final PingCallback callback, Executor executor) {
checkState(frameWriter != null);
long data = 0;
Http2Ping p;
boolean writePing;
synchronized (lock) {
checkState(frameWriter != null);
if (stopped) {
Http2Ping.notifyFailed(callback, executor, getPingFailure());
return;
@ -345,9 +344,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
writePing = true;
transportTracer.reportKeepAliveSent();
}
}
if (writePing) {
frameWriter.ping(false, (int) (data >>> 32), (int) data);
if (writePing) {
frameWriter.ping(false, (int) (data >>> 32), (int) data);
}
}
// If transport concurrently failed/stopped since we released the lock above, this could
// immediately invoke callback (which we shouldn't do while holding a lock)
@ -449,15 +448,15 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
keepAliveWithoutCalls);
keepAliveManager.onTransportStarted();
}
frameWriter = new AsyncFrameWriter(this, serializingExecutor);
outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize);
// Connecting in the serializingExecutor, so that some stream operations like synStream
// will be executed after connected.
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
if (isForTest()) {
if (isForTest()) {
synchronized (lock) {
frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter);
outboundFlow =
new OutboundFlowController(OkHttpClientTransport.this, frameWriter, initialWindowSize);
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
if (connectingCallback != null) {
connectingCallback.run();
}
@ -467,11 +466,25 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
}
frameWriter.becomeConnected(testFrameWriter, socket);
connectedFuture.set(null);
return;
}
});
return null;
}
final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this);
final Variant variant = new Http2();
FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true);
synchronized (lock) {
frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter);
outboundFlow = new OutboundFlowController(this, frameWriter, initialWindowSize);
}
// Connecting in the serializingExecutor, so that some stream operations like synStream
// will be executed after connected.
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
// Use closed source on failure so that the reader immediately shuts down.
BufferedSource source = Okio.buffer(new Source() {
@Override
@ -485,10 +498,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
}
@Override
public void close() {}
public void close() {
}
});
Variant variant = new Http2();
BufferedSink sink;
Socket sock;
SSLSession sslSession = null;
try {
@ -508,7 +520,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
}
sock.setTcpNoDelay(true);
source = Okio.buffer(Okio.source(sock));
sink = Okio.buffer(Okio.sink(sock));
asyncSink.becomeConnected(Okio.sink(sock), sock);
// The return value of OkHttpTlsUpgrader.upgrade is an SSLSocket that has this info
attributes = Attributes
.newBuilder()
@ -526,31 +539,29 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
return;
} finally {
clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true));
executor.execute(clientFrameHandler);
}
FrameWriter rawFrameWriter;
synchronized (lock) {
socket = Preconditions.checkNotNull(sock, "socket");
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
if (sslSession != null) {
securityInfo = new InternalChannelz.Security(new InternalChannelz.Tls(sslSession));
}
}
rawFrameWriter = variant.newWriter(sink, true);
frameWriter.becomeConnected(rawFrameWriter, socket);
try {
// Do these with the raw FrameWriter, so that they will be done in this thread,
// and before any possible pending stream operations.
rawFrameWriter.connectionPreface();
Settings settings = new Settings();
rawFrameWriter.settings(settings);
} catch (Exception e) {
onException(e);
return;
}
});
synchronized (lock) {
frameWriter.connectionPreface();
Settings settings = new Settings();
frameWriter.settings(settings);
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
// ClientFrameHandler need to be started after connectionPreface / settings, otherwise it
// may send goAway immediately.
executor.execute(clientFrameHandler);
synchronized (lock) {
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
}
}
});
@ -558,7 +569,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
}
private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress,
String proxyUsername, String proxyPassword) throws IOException, StatusException {
String proxyUsername, String proxyPassword) throws StatusException {
try {
Socket sock;
// The proxy address may not be resolved
@ -1039,7 +1050,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
OkHttpClientStream stream = getStream(streamId);
if (stream == null) {
if (mayHaveCreatedStream(streamId)) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
synchronized (lock) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
}
in.skip(length);
} else {
onError(ErrorCode.PROTOCOL_ERROR, "Received data for unknown stream: " + streamId);
@ -1059,7 +1072,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
// connection window update
connectionUnacknowledgedBytesRead += length;
if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);
synchronized (lock) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);
}
connectionUnacknowledgedBytesRead = 0;
}
}
@ -1170,7 +1185,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
@Override
public void ping(boolean ack, int payload1, int payload2) {
if (!ack) {
frameWriter.ping(true, payload1, payload2);
synchronized (lock) {
frameWriter.ping(true, payload1, payload2);
}
} else {
Http2Ping p = null;
long ackPayload = (((long) payload1) << 32) | (payload2 & 0xffffffffL);
@ -1222,7 +1239,9 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {
// We don't accept server initiated stream.
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
synchronized (lock) {
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
}
}
@Override

View File

@ -323,12 +323,7 @@ class OutboundFlowController {
try {
// endOfStream is set for the last chunk of data marked as endOfStream
boolean isEndOfStream = buffer.size() == frameBytes && endOfStream;
// AsyncFrameWriter drains buffer in executor. To avoid race, copy to temp.
// TODO(jihuncho) remove temp buff when async logic is moved to AsyncSink.
Buffer temp = new Buffer();
temp.write(buffer, frameBytes);
frameWriter.data(isEndOfStream, streamId, temp, frameBytes);
frameWriter.data(isEndOfStream, streamId, buffer, frameBytes);
} catch (IOException e) {
throw new RuntimeException(e);
}

View File

@ -1,149 +0,0 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.okhttp.AsyncFrameWriter.getLogLevel;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import io.grpc.internal.SerializingExecutor;
import io.grpc.okhttp.AsyncFrameWriter.TransportExceptionHandler;
import io.grpc.okhttp.internal.framed.FrameWriter;
import java.io.IOException;
import java.net.Socket;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
@RunWith(MockitoJUnitRunner.class)
public class AsyncFrameWriterTest {
@Mock private Socket socket;
@Mock private FrameWriter frameWriter;
private QueueingExecutor queueingExecutor = new QueueingExecutor();
private TransportExceptionHandler transportExceptionHandler =
new EscalatingTransportErrorHandler();
private AsyncFrameWriter asyncFrameWriter =
new AsyncFrameWriter(transportExceptionHandler, new SerializingExecutor(queueingExecutor));
@Before
public void setUp() throws Exception {
asyncFrameWriter.becomeConnected(frameWriter, socket);
}
@Test
public void noCoalesceRequired() throws IOException {
asyncFrameWriter.ping(true, 0, 1);
asyncFrameWriter.flush();
queueingExecutor.runAll();
verify(frameWriter, times(1)).ping(anyBoolean(), anyInt(), anyInt());
verify(frameWriter, times(1)).flush();
}
@Test
public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException {
asyncFrameWriter.ping(true, 0, 1);
asyncFrameWriter.flush();
queueingExecutor.runAll();
asyncFrameWriter.ping(true, 0, 2);
asyncFrameWriter.flush();
queueingExecutor.runAll();
verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt());
verify(frameWriter, times(2)).flush();
}
@Test
public void flushCoalescing_shouldMergeTwoQueuedFlushes() throws IOException {
asyncFrameWriter.ping(true, 0, 1);
asyncFrameWriter.flush();
asyncFrameWriter.ping(true, 0, 2);
asyncFrameWriter.flush();
queueingExecutor.runAll();
InOrder inOrder = inOrder(frameWriter);
inOrder.verify(frameWriter, times(2)).ping(anyBoolean(), anyInt(), anyInt());
inOrder.verify(frameWriter).flush();
}
@Test
public void unknownException() {
assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO);
}
@Test
public void quiet() {
assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE);
}
@Test
public void nonquiet() {
assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO);
}
@Test
public void nullMessage() {
IOException e = new IOException();
assertThat(e.getMessage()).isNull();
assertThat(getLogLevel(e)).isEqualTo(Level.INFO);
}
/**
* Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link
* QueueingExecutor#runAll} in serial order.
*/
private static class QueueingExecutor implements Executor {
private final Queue<Runnable> runnables = new ConcurrentLinkedQueue<Runnable>();
@Override
public void execute(Runnable command) {
runnables.add(command);
}
public void runAll() {
Runnable r;
while ((r = runnables.poll()) != null) {
r.run();
}
}
}
/** Rethrows as Assertion error. */
private static class EscalatingTransportErrorHandler implements TransportExceptionHandler {
@Override
public void onException(Throwable throwable) {
throw new AssertionError(throwable);
}
}
}

View File

@ -0,0 +1,289 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import com.google.common.base.Charsets;
import io.grpc.internal.SerializingExecutor;
import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler;
import java.io.IOException;
import java.net.Socket;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import okio.Buffer;
import okio.Sink;
import okio.Timeout;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.InOrder;
/** Tests for {@link AsyncSink}. */
@RunWith(JUnit4.class)
public class AsyncSinkTest {
private Socket socket = mock(Socket.class);
private Sink mockedSink = mock(VoidSink.class, CALLS_REAL_METHODS);
private QueueingExecutor queueingExecutor = new QueueingExecutor();
private TransportExceptionHandler exceptionHandler = mock(TransportExceptionHandler.class);
private AsyncSink sink =
AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler);
@Before
public void setUp() throws Exception {
sink.becomeConnected(mockedSink, socket);
}
@Test
public void noCoalesceRequired() throws IOException {
Buffer buffer = new Buffer();
sink.write(buffer.writeUtf8("hello"), buffer.size());
sink.flush();
queueingExecutor.runAll();
InOrder inOrder = inOrder(mockedSink);
inOrder.verify(mockedSink).write(any(Buffer.class), anyInt());
inOrder.verify(mockedSink).flush();
}
@Test
public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException {
byte[] firstData = "a string".getBytes(Charsets.UTF_8);
byte[] secondData = "a longer string".getBytes(Charsets.UTF_8);
Buffer buffer = new Buffer();
sink.write(buffer.write(firstData), buffer.size());
sink.flush();
queueingExecutor.runAll();
sink.write(buffer.write(secondData), buffer.size());
sink.flush();
queueingExecutor.runAll();
InOrder inOrder = inOrder(mockedSink);
inOrder.verify(mockedSink).write(any(Buffer.class), anyInt());
inOrder.verify(mockedSink).flush();
inOrder.verify(mockedSink).write(any(Buffer.class), anyInt());
inOrder.verify(mockedSink).flush();
}
@Test
public void flushCoalescing_shouldMergeTwoQueuedFlushesAndWrites() throws IOException {
byte[] firstData = "a string".getBytes(Charsets.UTF_8);
byte[] secondData = "a longer string".getBytes(Charsets.UTF_8);
Buffer buffer = new Buffer().write(firstData);
sink.write(buffer, buffer.size());
sink.flush();
buffer = new Buffer().write(secondData);
sink.write(buffer, buffer.size());
sink.flush();
queueingExecutor.runAll();
InOrder inOrder = inOrder(mockedSink);
inOrder.verify(mockedSink)
.write(any(Buffer.class), eq((long) firstData.length + secondData.length));
inOrder.verify(mockedSink).flush();
}
@Test
public void flushCoalescing_shouldMergeWrites() throws IOException {
byte[] firstData = "a string".getBytes(Charsets.UTF_8);
byte[] secondData = "a longer string".getBytes(Charsets.UTF_8);
Buffer buffer = new Buffer();
sink.write(buffer.write(firstData), buffer.size());
sink.write(buffer.write(secondData), buffer.size());
sink.flush();
queueingExecutor.runAll();
InOrder inOrder = inOrder(mockedSink);
inOrder.verify(mockedSink)
.write(any(Buffer.class), eq((long) firstData.length + secondData.length));
inOrder.verify(mockedSink).flush();
}
@Test
public void write_shouldCachePreviousException() throws IOException {
Exception ioException = new IOException("some exception");
doThrow(ioException)
.when(mockedSink).write(any(Buffer.class), anyLong());
Buffer buffer = new Buffer();
buffer.writeUtf8("any message");
sink.write(buffer, buffer.size());
sink.flush();
queueingExecutor.runAll();
sink.write(buffer, buffer.size());
queueingExecutor.runAll();
verify(exceptionHandler, timeout(1000)).onException(ioException);
}
@Test
public void close_writeShouldThrowException() {
sink.close();
queueingExecutor.runAll();
try {
sink.write(new Buffer(), 0);
fail("should throw ioException");
} catch (IOException e) {
assertThat(e).hasMessageThat().contains("closed");
}
}
@Test
public void write_shouldThrowIfAlreadyClosed() throws IOException {
Exception ioException = new IOException("some exception");
doThrow(ioException)
.when(mockedSink).write(any(Buffer.class), anyLong());
Buffer buffer = new Buffer();
buffer.writeUtf8("any message");
sink.write(buffer, buffer.size());
sink.close();
queueingExecutor.runAll();
try {
sink.write(buffer, buffer.size());
queueingExecutor.runAll();
fail("should throw ioException");
} catch (IOException e) {
assertThat(e).hasMessageThat().contains("closed");
}
}
@Test
public void close_flushShouldThrowException() throws IOException {
sink.close();
queueingExecutor.runAll();
try {
sink.flush();
queueingExecutor.runAll();
fail("should fail");
} catch (IOException e) {
assertThat(e).hasMessageThat().contains("closed");
}
}
@Test
public void flush_shouldThrowIfAlreadyClosed() throws IOException {
Buffer buffer = new Buffer();
buffer.writeUtf8("any message");
sink.write(buffer, buffer.size());
sink.close();
queueingExecutor.runAll();
try {
sink.flush();
queueingExecutor.runAll();
fail("should fail");
} catch (IOException e) {
assertThat(e).hasMessageThat().contains("closed");
}
}
@Test
public void write_callSinkIfBufferIsLargerThanSegmentSize() throws IOException {
Buffer buffer = new Buffer();
// OkHttp is using 8192 as segment size.
int payloadSize = 8192 * 2 - 1;
int padding = 10;
buffer.write(new byte[payloadSize]);
int completeSegmentBytes = (int) buffer.completeSegmentByteCount();
assertThat(completeSegmentBytes).isLessThan(payloadSize);
// first trying to send of all complete segments, but not the padding
sink.write(buffer, completeSegmentBytes + padding);
queueingExecutor.runAll();
verify(mockedSink).write(any(Buffer.class), eq((long) completeSegmentBytes));
verify(mockedSink, never()).flush();
assertThat(buffer.size()).isEqualTo((long) (payloadSize - completeSegmentBytes - padding));
// writing smaller than completed segment, shouldn't trigger write to Sink.
reset(mockedSink);
sink.write(buffer, buffer.size());
queueingExecutor.runAll();
verify(mockedSink, never()).write(any(Buffer.class), anyLong());
verify(mockedSink, never()).flush();
assertThat(buffer.exhausted()).isTrue();
// flush should write everything.
sink.flush();
queueingExecutor.runAll();
verify(mockedSink).write(any(Buffer.class),eq((long) payloadSize - completeSegmentBytes));
verify(mockedSink).flush();
}
/**
* Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link
* QueueingExecutor#runAll} in serial order.
*/
private static class QueueingExecutor implements Executor {
private final Queue<Runnable> runnables = new ConcurrentLinkedQueue<>();
@Override
public void execute(Runnable command) {
runnables.add(command);
}
public void runAll() {
Runnable r;
while ((r = runnables.poll()) != null) {
r.run();
}
}
}
/** Test sink to mimic real Sink behavior since write has a side effect. */
private static class VoidSink implements Sink {
@Override
public void write(Buffer source, long byteCount) throws IOException {
// removes byteCount bytes from source.
source.read(new byte[(int) byteCount], 0, (int) byteCount);
}
@Override
public void flush() throws IOException {
// do nothing
}
@Override
public Timeout timeout() {
return Timeout.NONE;
}
@Override
public void close() throws IOException {
// do nothing
}
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.okhttp.ExceptionHandlingFrameWriter.getLogLevel;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler;
import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header;
import java.io.IOException;
import java.util.ArrayList;
import java.util.logging.Level;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class ExceptionHandlingFrameWriterTest {
private FrameWriter mockedFrameWriter = mock(FrameWriter.class);
private TransportExceptionHandler transportExceptionHandler =
mock(TransportExceptionHandler.class);
private ExceptionHandlingFrameWriter exceptionHandlingFrameWriter =
new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter);
@Test
public void exception() throws IOException {
IOException exception = new IOException("some exception");
doThrow(exception).when(mockedFrameWriter)
.synReply(false, 100, new ArrayList<Header>());
exceptionHandlingFrameWriter.synReply(false, 100, new ArrayList<Header>());
verify(transportExceptionHandler).onException(exception);
verify(mockedFrameWriter).synReply(false, 100, new ArrayList<Header>());
}
@Test
public void unknownException() {
assertThat(getLogLevel(new Exception())).isEqualTo(Level.INFO);
}
@Test
public void quiet() {
assertThat(getLogLevel(new IOException("Socket closed"))).isEqualTo(Level.FINE);
}
@Test
public void nonquiet() {
assertThat(getLogLevel(new IOException("foo"))).isEqualTo(Level.INFO);
}
@Test
public void nullMessage() {
IOException e = new IOException();
assertThat(e.getMessage()).isNull();
assertThat(getLogLevel(e)).isEqualTo(Level.INFO);
}
}

View File

@ -37,8 +37,10 @@ import io.grpc.internal.NoopClientStreamListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
@ -60,7 +62,8 @@ public class OkHttpClientStreamTest {
private static final int INITIAL_WINDOW_SIZE = 65535;
@Mock private MethodDescriptor.Marshaller<Void> marshaller;
@Mock private AsyncFrameWriter frameWriter;
@Mock private FrameWriter mockedFrameWriter;
private ExceptionHandlingFrameWriter frameWriter;
@Mock private OkHttpClientTransport transport;
@Mock private OutboundFlowController flowController;
@Captor private ArgumentCaptor<List<Header>> headersCaptor;
@ -81,6 +84,8 @@ public class OkHttpClientStreamTest {
.setResponseMarshaller(marshaller)
.build();
frameWriter =
new ExceptionHandlingFrameWriter(transport, mockedFrameWriter);
stream = new OkHttpClientStream(
methodDescriptor,
new Metadata(),
@ -144,11 +149,11 @@ public class OkHttpClientStreamTest {
stream.transportState().start(1234);
verifyNoMoreInteractions(frameWriter);
verifyNoMoreInteractions(mockedFrameWriter);
}
@Test
public void start_userAgentRemoved() {
public void start_userAgentRemoved() throws IOException {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
@ -157,13 +162,14 @@ public class OkHttpClientStreamTest {
stream.start(new BaseClientStreamListener());
stream.transportState().start(3);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
verify(mockedFrameWriter)
.synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue())
.contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"));
}
@Test
public void start_headerFieldOrder() {
public void start_headerFieldOrder() throws IOException {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
@ -172,7 +178,8 @@ public class OkHttpClientStreamTest {
stream.start(new BaseClientStreamListener());
stream.transportState().start(3);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
verify(mockedFrameWriter)
.synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue()).containsExactly(
Headers.SCHEME_HEADER,
Headers.METHOD_HEADER,
@ -185,7 +192,7 @@ public class OkHttpClientStreamTest {
}
@Test
public void getUnaryRequest() {
public void getUnaryRequest() throws IOException {
MethodDescriptor<?, ?> getMethod = MethodDescriptor.<Void, Void>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("service/method")
@ -200,7 +207,7 @@ public class OkHttpClientStreamTest {
stream.start(new BaseClientStreamListener());
// GET streams send headers after halfClose is called.
verify(frameWriter, times(0)).synStream(
verify(mockedFrameWriter, times(0)).synStream(
eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class));
@ -210,7 +217,7 @@ public class OkHttpClientStreamTest {
verify(transport).streamReadyToStart(eq(stream));
stream.transportState().start(3);
verify(frameWriter)
verify(mockedFrameWriter)
.synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue()).contains(Headers.METHOD_GET_HEADER);
assertThat(headersCaptor.getValue()).contains(

View File

@ -46,7 +46,6 @@ import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
@ -91,9 +90,11 @@ import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@ -113,6 +114,7 @@ import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.AdditionalAnswers;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Matchers;
@ -139,7 +141,6 @@ public class OkHttpClientTransportTest {
@Rule public final Timeout globalTimeout = Timeout.seconds(10);
@Mock
private FrameWriter frameWriter;
private MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();
@ -150,8 +151,10 @@ public class OkHttpClientTransportTest {
private final SSLSocketFactory sslSocketFactory = null;
private final HostnameVerifier hostnameVerifier = null;
private final TransportTracer transportTracer = new TransportTracer();
private final Queue<Buffer> capturedBuffer = new ArrayDeque<>();
private OkHttpClientTransport clientTransport;
private MockFrameReader frameReader;
private Socket socket;
private ExecutorService executor = Executors.newCachedThreadPool();
private long nanoTime; // backs a ticker, for testing ping round-trip time measurement
private SettableFuture<Void> connectedFuture;
@ -166,8 +169,10 @@ public class OkHttpClientTransportTest {
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE);
frameReader = new MockFrameReader();
socket = new MockSocket(frameReader);
MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer);
frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter));
}
@After
@ -217,7 +222,7 @@ public class OkHttpClientTransportTest {
frameReader,
frameWriter,
startId,
new MockSocket(frameReader),
socket,
stopwatchSupplier,
connectingCallback,
connectedFuture,
@ -556,10 +561,9 @@ public class OkHttpClientTransportTest {
assertEquals(12, input.available());
stream.writeMessage(input);
stream.flush();
ArgumentCaptor<Buffer> captor = ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter, timeout(TIME_OUT_MS))
.data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH));
Buffer sentFrame = captor.getValue();
.data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH));
Buffer sentFrame = capturedBuffer.poll();
assertEquals(createMessageFrame(message), sentFrame);
stream.cancel(Status.CANCELLED);
shutdownAndVerify();
@ -888,11 +892,9 @@ public class OkHttpClientTransportTest {
assertEquals(22, input.available());
stream.writeMessage(input);
stream.flush();
ArgumentCaptor<Buffer> captor =
ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter, timeout(TIME_OUT_MS))
.data(eq(false), eq(3), captor.capture(), eq(22 + HEADER_LENGTH));
Buffer sentFrame = captor.getValue();
.data(eq(false), eq(3), any(Buffer.class), eq(22 + HEADER_LENGTH));
Buffer sentFrame = capturedBuffer.poll();
assertEquals(createMessageFrame(sentMessage), sentFrame);
// And read.
@ -976,10 +978,9 @@ public class OkHttpClientTransportTest {
// The second stream should be active now, and the pending data should be sent out.
assertEquals(1, activeStreamCount());
assertEquals(0, clientTransport.getPendingStreamSize());
ArgumentCaptor<Buffer> captor = ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter, timeout(TIME_OUT_MS))
.data(eq(true), eq(5), captor.capture(), eq(5 + HEADER_LENGTH));
Buffer sentFrame = captor.getValue();
.data(eq(true), eq(5), any(Buffer.class), eq(5 + HEADER_LENGTH));
Buffer sentFrame = capturedBuffer.poll();
assertEquals(createMessageFrame(sentMessage), sentFrame);
stream2.cancel(Status.CANCELLED);
shutdownAndVerify();
@ -1456,10 +1457,9 @@ public class OkHttpClientTransportTest {
allowTransportConnected();
// The queued message should be sent out.
ArgumentCaptor<Buffer> captor = ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter, timeout(TIME_OUT_MS))
.data(eq(false), eq(3), captor.capture(), eq(12 + HEADER_LENGTH));
Buffer sentFrame = captor.getValue();
.data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH));
Buffer sentFrame = capturedBuffer.poll();
assertEquals(createMessageFrame(message), sentFrame);
stream.cancel(Status.CANCELLED);
shutdownAndVerify();
@ -2091,7 +2091,6 @@ public class OkHttpClientTransportTest {
verify(frameWriter, timeout(TIME_OUT_MS)).close();
} catch (IOException e) {
throw new RuntimeException(e);
}
frameReader.assertClosed();
}
@ -2113,4 +2112,83 @@ public class OkHttpClientTransportTest {
throws ExecutionException, InterruptedException {
return obj.getStats().get().data;
}
}
/** A FrameWriter to mock with CALL_REAL_METHODS option. */
private static class MockFrameWriter implements FrameWriter {
private Socket socket;
private Queue<Buffer> capturedBuffer;
public MockFrameWriter(Socket socket, Queue<Buffer> capturedBuffer) {
// Sets a socket to close. Some tests assumes that FrameWriter will close underlying sink
// which will eventually close the socket.
this.socket = socket;
this.capturedBuffer = capturedBuffer;
}
void setSocket(Socket socket) {
this.socket = socket;
}
@Override
public void close() throws IOException {
socket.close();
}
@Override
public int maxDataLength() {
return Integer.MAX_VALUE;
}
@Override
public void data(boolean outFinished, int streamId, Buffer source, int byteCount)
throws IOException {
// simulate the side effect, and captures to internal queue.
Buffer capture = new Buffer();
capture.write(source, byteCount);
capturedBuffer.add(capture);
}
// rest of methods are unimplemented
@Override
public void connectionPreface() throws IOException {}
@Override
public void ackSettings(Settings peerSettings) throws IOException {}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {}
@Override
public void flush() throws IOException {}
@Override
public void synStream(boolean outFinished, boolean inFinished, int streamId,
int associatedStreamId, List<Header> headerBlock) throws IOException {}
@Override
public void synReply(boolean outFinished, int streamId, List<Header> headerBlock)
throws IOException {}
@Override
public void headers(int streamId, List<Header> headerBlock) throws IOException {}
@Override
public void rstStream(int streamId, ErrorCode errorCode) throws IOException {}
@Override
public void settings(Settings okHttpSettings) throws IOException {}
@Override
public void ping(boolean ack, int payload1, int payload2) throws IOException {}
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode, byte[] debugData)
throws IOException {}
@Override
public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {}
}
}