Improve okhttp client transport, handles go away and add unit test.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72155172
This commit is contained in:
simonma 2014-07-29 09:52:20 -07:00 committed by Eric Anderson
parent 5f334f7c52
commit 7bf17dc4d6
3 changed files with 717 additions and 61 deletions

View File

@ -1,7 +1,8 @@
package com.google.net.stubby.newtransport.okhttp;
import com.google.common.util.concurrent.SerializingExecutor;
import com.google.common.util.concurrent.Service;
import com.google.net.stubby.Status;
import com.google.net.stubby.transport.Transport.Code;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.FrameWriter;
@ -17,9 +18,10 @@ import java.util.concurrent.Executor;
class AsyncFrameWriter implements FrameWriter {
private final FrameWriter frameWriter;
private final Executor executor;
private final Service transport;
private final OkHttpClientTransport transport;
public AsyncFrameWriter(FrameWriter frameWriter, Service transport, Executor executor) {
public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport,
Executor executor) {
this.frameWriter = frameWriter;
this.transport = transport;
// Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
@ -158,6 +160,8 @@ class AsyncFrameWriter implements FrameWriter {
@Override
public void doRun() throws IOException {
frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
// Flush it since after goAway, we are likely to close this writer.
frameWriter.flush();
}
});
}
@ -188,7 +192,7 @@ class AsyncFrameWriter implements FrameWriter {
try {
doRun();
} catch (IOException ex) {
transport.stopAsync();
transport.abort(Status.fromThrowable(ex));
throw new RuntimeException(ex);
}
}

View File

@ -1,5 +1,6 @@
package com.google.net.stubby.newtransport.okhttp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteBuffers;
@ -33,9 +34,10 @@ import okio.Buffer;
import java.io.IOException;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
@ -48,6 +50,7 @@ import javax.annotation.concurrent.GuardedBy;
*/
public class OkHttpClientTransport extends AbstractClientTransport {
/** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
@VisibleForTesting
static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
@ -75,21 +78,40 @@ public class OkHttpClientTransport extends AbstractClientTransport {
private final int port;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
@GuardedBy("this")
private Object lock = new Object();
@GuardedBy("lock")
private int nextStreamId;
private final Map<Integer, OkHttpClientStream> streams =
Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
private final ExecutorService executor = Executors.newCachedThreadPool();
private int unacknowledgedBytesRead;
private ClientFrameHandler clientFrameHandler;
// The status used to finish all active streams when the transport is closed.
@GuardedBy("lock")
private boolean goAway;
@GuardedBy("lock")
private Status goAwayStatus;
public OkHttpClientTransport(String host, int port) {
this.host = host;
this.host = Preconditions.checkNotNull(host);
this.port = port;
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
}
/**
* Create a transport connected to a fake peer for test.
*/
@VisibleForTesting
OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
host = null;
port = -1;
this.nextStreamId = nextStreamId;
this.frameReader = frameReader;
this.frameWriter = frameWriter;
}
@Override
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
return new OkHttpClientStream(method, listener);
@ -97,53 +119,85 @@ public class OkHttpClientTransport extends AbstractClientTransport {
@Override
protected void doStart() {
BufferedSource source;
BufferedSink sink;
try {
Socket socket = new Socket(host, port);
// TODO(user): use SpdyConnection.
source = Okio.buffer(Okio.source(socket));
sink = Okio.buffer(Okio.sink(socket));
} catch (IOException e) {
throw new RuntimeException(e);
// We set host to null for test.
if (host != null) {
BufferedSource source;
BufferedSink sink;
try {
Socket socket = new Socket(host, port);
source = Okio.buffer(Okio.source(socket));
sink = Okio.buffer(Okio.sink(socket));
} catch (IOException e) {
throw new RuntimeException(e);
}
Variant variant = new Http20Draft12();
frameReader = variant.newReader(source, true);
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
}
Variant variant = new Http20Draft12();
frameReader = variant.newReader(source, true);
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
executor.execute(new ClientFrameHandler());
notifyStarted();
clientFrameHandler = new ClientFrameHandler();
executor.execute(clientFrameHandler);
}
@Override
protected void doStop() {
closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
frameWriter.close();
try {
frameReader.close();
} catch (IOException e) {
throw new RuntimeException(e);
boolean normalClose;
synchronized (lock) {
normalClose = !goAway;
}
executor.shutdown();
notifyStopped();
if (normalClose) {
abort(new Status(Code.INTERNAL, "Transport stopped"));
// Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
// The GOAWAY is part of graceful shutdown.
frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
}
stopIfNecessary();
}
@VisibleForTesting
ClientFrameHandler getHandler() {
return clientFrameHandler;
}
@VisibleForTesting
Map<Integer, OkHttpClientStream> getStreams() {
return streams;
}
/**
* Close and remove all streams.
* Finish all active streams with given status, then close the transport.
*/
private void closeAllStreams(Status status) {
Collection<OkHttpClientStream> streamsCopy;
synchronized (streams) {
streamsCopy = streams.values();
streams.clear();
void abort(Status status) {
onGoAway(-1, status);
}
private void onGoAway(int lastKnownStreamId, Status status) {
ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
synchronized (lock) {
goAway = true;
goAwayStatus = status;
Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
while (it.hasNext()) {
Map.Entry<Integer, OkHttpClientStream> entry = it.next();
if (entry.getKey() > lastKnownStreamId) {
goAwayStreams.add(entry.getValue());
it.remove();
}
}
}
for (OkHttpClientStream stream : streamsCopy) {
// Starting stop, go into STOPPING state so that Channel know this Transport should not be used
// further, will become STOPPED once all streams are complete.
stopAsync();
for (OkHttpClientStream stream : goAwayStreams) {
stream.setStatus(status);
}
}
/**
* Called when a HTTP2 stream is closed.
* Called when a stream is closed.
*
* <p> Return false if the stream has already finished.
*/
@ -158,11 +212,40 @@ public class OkHttpClientTransport extends AbstractClientTransport {
return false;
}
/**
* When the transport is in goAway states, we should stop it once all active streams finish.
*/
private void stopIfNecessary() {
boolean shouldStop;
synchronized (lock) {
shouldStop = (goAway && streams.size() == 0);
}
if (shouldStop) {
frameWriter.close();
try {
frameReader.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
executor.shutdown();
notifyStopped();
}
}
/**
* Returns a Grpc status corresponding to the given ErrorCode.
*/
@VisibleForTesting
static Status toGrpcStatus(ErrorCode code) {
return ERROR_CODE_TO_STATUS.get(code);
}
/**
* Runnable which reads frames and dispatches them to in flight calls
*/
private class ClientFrameHandler implements FrameReader.Handler, Runnable {
private ClientFrameHandler() {}
@VisibleForTesting
class ClientFrameHandler implements FrameReader.Handler, Runnable {
ClientFrameHandler() {}
@Override
public void run() {
@ -173,8 +256,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
while (frameReader.nextFrame(this)) {
}
} catch (IOException ioe) {
ioe.printStackTrace();
closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
abort(Status.fromThrowable(ioe));
} finally {
// Restore the original thread name.
Thread.currentThread().setName(threadName);
@ -210,7 +292,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
stream.unacknowledgedBytesRead = 0;
}
if (inFinished) {
finishStream(streamId, Status.OK);
if (finishStream(streamId, Status.OK)) {
stopIfNecessary();
}
}
}
@ -229,7 +313,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
@Override
public void rstStream(int streamId, ErrorCode errorCode) {
finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
if (finishStream(streamId, toGrpcStatus(errorCode))) {
stopIfNecessary();
}
}
@Override
@ -252,18 +338,14 @@ public class OkHttpClientTransport extends AbstractClientTransport {
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
// TODO(user): Log here and implement the real Go away behavior: streams have
// id <= lastGoodStreamId should not be closed.
closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
stopAsync();
onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {
// TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
// PROTOCOL_ERROR.
frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
// We don't accept server initiated stream.
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
}
@Override
@ -284,28 +366,42 @@ public class OkHttpClientTransport extends AbstractClientTransport {
}
}
@GuardedBy("lock")
private void assignStreamId(OkHttpClientStream stream) {
Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
stream.streamId = nextStreamId;
streams.put(stream.streamId, stream);
if (nextStreamId >= Integer.MAX_VALUE - 2) {
onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust"));
} else {
nextStreamId += 2;
}
}
/**
* Client stream for the okhttp transport.
*/
private class OkHttpClientStream extends AbstractStream implements ClientStream {
@VisibleForTesting
class OkHttpClientStream extends AbstractStream implements ClientStream {
int streamId;
final InputStreamDeframer deframer;
int unacknowledgedBytesRead;
public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
super(listener);
Preconditions.checkState(streamId == 0, "StreamId should be 0");
synchronized (OkHttpClientTransport.this) {
streamId = nextStreamId;
nextStreamId += 2;
streams.put(streamId, this);
frameWriter.synStream(false, false, streamId, 0,
Headers.createRequestHeaders(method.getName()));
}
deframer = new InputStreamDeframer(inboundMessageHandler());
synchronized (lock) {
if (goAway) {
setStatus(goAwayStatus);
return;
}
assignStreamId(this);
}
frameWriter.synStream(false, false, streamId, 0,
Headers.createRequestHeaders(method.getName()));
}
public InputStreamDeframer getDeframer() {
InputStreamDeframer getDeframer() {
return deframer;
}
@ -330,8 +426,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
public void cancel() {
Preconditions.checkState(streamId != 0, "streamId should be set");
outboundPhase = Phase.STATUS;
if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
frameWriter.rstStream(streamId, ErrorCode.CANCEL);
stopIfNecessary();
}
}
}

View File

@ -0,0 +1,555 @@
package com.google.net.stubby.newtransport.okhttp;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.Service;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler;
import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream;
import com.google.net.stubby.transport.Transport;
import com.google.net.stubby.transport.Transport.Code;
import com.google.net.stubby.transport.Transport.ContextValue;
import com.google.protobuf.ByteString;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.FrameReader;
import okio.Buffer;
import okio.BufferedSource;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
* Tests for {@link OkHttpClientTransport}.
*/
@RunWith(JUnit4.class)
public class OkHttpClientTransportTest {
private static final int TIME_OUT_MS = 5000000;
private static final String NETWORK_ISSUE_MESSAGE = "network issue";
// Flags
private static final byte PAYLOAD_FRAME = 0x0;
public static final byte CONTEXT_VALUE_FRAME = 0x1;
public static final byte STATUS_FRAME = 0x3;
@Mock
private AsyncFrameWriter frameWriter;
@Mock
MethodDescriptor<?, ?> method;
private OkHttpClientTransport clientTransport;
private MockFrameReader frameReader;
private Map<Integer, OkHttpClientStream> streams;
private ClientFrameHandler frameHandler;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
streams = new HashMap<Integer, OkHttpClientStream>();
frameReader = new MockFrameReader();
clientTransport = new OkHttpClientTransport(frameReader, frameWriter, 3);
clientTransport.startAsync();
frameHandler = clientTransport.getHandler();
streams = clientTransport.getStreams();
when(method.getName()).thenReturn("fakemethod");
}
@After
public void tearDown() {
clientTransport.stopAsync();
assertTrue(frameReader.closed);
verify(frameWriter).close();
}
/**
* When nextFrame throws IOException, the transport should be aborted.
*/
@Test
public void nextFrameThrowIOException() throws Exception {
MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method, listener1);
clientTransport.newStream(method, listener2);
assertEquals(2, streams.size());
assertTrue(streams.containsKey(3));
assertTrue(streams.containsKey(5));
frameReader.throwIOExceptionForNextFrame();
listener1.waitUntilStreamClosed();
listener2.waitUntilStreamClosed();
assertEquals(0, streams.size());
assertEquals(Code.INTERNAL, listener1.status.getCode());
assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
assertEquals(Code.INTERNAL, listener1.status.getCode());
assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
assertTrue("Service state: " + clientTransport.state(),
Service.State.TERMINATED == clientTransport.state());
}
@Test
public void readMessages() throws Exception {
final int numMessages = 10;
final String message = "Hello Client";
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
assertTrue(streams.containsKey(3));
for (int i = 0; i < numMessages; i++) {
BufferedSource source = mock(BufferedSource.class);
InputStream inputStream = createMessageFrame(message + i);
when(source.inputStream()).thenReturn(inputStream);
frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available());
}
listener.waitUntilStreamClosed();
assertEquals(Status.OK, listener.status);
assertEquals(numMessages, listener.messages.size());
for (int i = 0; i < numMessages; i++) {
assertEquals(message + i, listener.messages.get(i));
}
}
@Test
public void readContexts() throws Exception {
final int numContexts = 10;
final String key = "KEY";
final String value = "value";
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
assertTrue(streams.containsKey(3));
for (int i = 0; i < numContexts; i++) {
BufferedSource source = mock(BufferedSource.class);
InputStream inputStream = createContextFrame(key + i, value + i);
when(source.inputStream()).thenReturn(inputStream);
frameHandler.data(i == numContexts - 1 ? true : false, 3, source, inputStream.available());
}
listener.waitUntilStreamClosed();
assertEquals(Status.OK, listener.status);
assertEquals(numContexts, listener.contexts.size());
for (int i = 0; i < numContexts; i++) {
String val = listener.contexts.get(key + i);
assertNotNull(val);
assertEquals(value + i, val);
}
}
@Test
public void readStatus() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
assertTrue(streams.containsKey(3));
BufferedSource source = mock(BufferedSource.class);
InputStream inputStream = createStatusFrame((short) Transport.Code.UNAVAILABLE.getNumber());
when(source.inputStream()).thenReturn(inputStream);
frameHandler.data(true, 3, source, inputStream.available());
listener.waitUntilStreamClosed();
assertEquals(Transport.Code.UNAVAILABLE, listener.status.getCode());
}
@Test
public void receiveReset() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
assertTrue(streams.containsKey(3));
frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR);
listener.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status);
}
@Test
public void cancelStream() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
OkHttpClientStream stream = streams.get(3);
assertNotNull(stream);
stream.cancel();
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
}
@Test
public void writeMessage() throws Exception {
final String message = "Hello Server";
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
OkHttpClientStream stream = streams.get(3);
InputStream input = new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8));
stream.writeMessage(input, input.available(), null);
stream.flush();
ArgumentCaptor<Buffer> captor =
ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter).data(eq(false), eq(3), captor.capture());
Buffer sentFrame = captor.getValue();
checkSameInputStream(createMessageFrame(message), sentFrame.inputStream());
}
@Test
public void writeContext() throws Exception {
final String key = "KEY";
final String value = "VALUE";
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, listener);
OkHttpClientStream stream = streams.get(3);
InputStream input = new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8));
stream.writeContext(key, input, input.available(), null);
stream.flush();
ArgumentCaptor<Buffer> captor =
ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter).data(eq(false), eq(3), captor.capture());
stream.cancel();
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
}
@Test
public void windowUpdate() throws Exception {
MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method, listener1);
clientTransport.newStream(method, listener2);
assertEquals(2, streams.size());
OkHttpClientStream stream1 = streams.get(3);
OkHttpClientStream stream2 = streams.get(5);
int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4;
byte[] fakeMessage = new byte[messageLength];
byte[] contextBody = ContextValue
.newBuilder()
.setKey("KEY")
.setValue(ByteString.copyFrom(fakeMessage))
.build()
.toByteArray();
// Stream 1 receives context
InputStream contextFrame = createContextFrame(contextBody);
int contextFrameLength = contextFrame.available();
BufferedSource source = mock(BufferedSource.class);
when(source.inputStream()).thenReturn(contextFrame);
frameHandler.data(false, 3, source, contextFrame.available());
// Stream 2 receives context
contextFrame = createContextFrame(contextBody);
when(source.inputStream()).thenReturn(contextFrame);
frameHandler.data(false, 5, source, contextFrame.available());
verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * contextFrameLength));
// Stream 1 receives a message
InputStream messageFrame = createMessageFrame(fakeMessage);
int messageFrameLength = messageFrame.available();
when(source.inputStream()).thenReturn(messageFrame);
frameHandler.data(false, 3, source, messageFrame.available());
verify(frameWriter).windowUpdate(eq(3), eq((long) contextFrameLength + messageFrameLength));
// Stream 2 receives a message
messageFrame = createMessageFrame(fakeMessage);
when(source.inputStream()).thenReturn(messageFrame);
frameHandler.data(false, 5, source, messageFrame.available());
verify(frameWriter).windowUpdate(eq(5), eq((long) contextFrameLength + messageFrameLength));
verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
stream1.cancel();
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener1.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status);
stream2.cancel();
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
listener2.waitUntilStreamClosed();
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status);
}
@Test
public void stopNormally() throws Exception {
MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method, listener1);
clientTransport.newStream(method, listener2);
assertEquals(2, streams.size());
clientTransport.stopAsync();
listener1.waitUntilStreamClosed();
listener2.waitUntilStreamClosed();
verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any());
assertEquals(0, streams.size());
assertEquals(Code.INTERNAL, listener1.status.getCode());
assertEquals(Code.INTERNAL, listener2.status.getCode());
assertEquals(Service.State.TERMINATED, clientTransport.state());
}
@Test
public void receiveGoAway() throws Exception {
// start 2 streams.
MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method, listener1);
clientTransport.newStream(method, listener2);
assertEquals(2, streams.size());
// Receive goAway, max good id is 3.
frameHandler.goAway(3, ErrorCode.CANCEL, null);
// Transport should be in STOPPING state.
assertEquals(Service.State.STOPPING, clientTransport.state());
// Stream 2 should be closed.
listener2.waitUntilStreamClosed();
assertEquals(1, streams.size());
assertEquals(Code.UNAVAILABLE, listener2.status.getCode());
// New stream should be failed.
MockStreamListener listener3 = new MockStreamListener();
try {
clientTransport.newStream(method, listener3);
fail("new stream should no be accepted by a go-away transport.");
} catch (IllegalStateException ex) {
// expected.
}
// But stream 1 should be able to send.
final String sentMessage = "Should I also go away?";
OkHttpClientStream stream = streams.get(3);
InputStream input =
new ByteArrayInputStream(sentMessage.getBytes(StandardCharsets.UTF_8));
stream.writeMessage(input, input.available(), null);
stream.flush();
ArgumentCaptor<Buffer> captor =
ArgumentCaptor.forClass(Buffer.class);
verify(frameWriter).data(eq(false), eq(3), captor.capture());
Buffer sentFrame = captor.getValue();
checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream());
// And read.
final String receivedMessage = "No, you are fine.";
BufferedSource source = mock(BufferedSource.class);
InputStream inputStream = createMessageFrame(receivedMessage);
when(source.inputStream()).thenReturn(inputStream);
frameHandler.data(true, 3, source, inputStream.available());
listener1.waitUntilStreamClosed();
assertEquals(1, listener1.messages.size());
assertEquals(receivedMessage, listener1.messages.get(0));
// The transport should be stopped after all active streams finished.
assertTrue("Service state: " + clientTransport.state(),
Service.State.TERMINATED == clientTransport.state());
}
@Test
public void streamIdExhaust() throws Exception {
int startId = Integer.MAX_VALUE - 2;
AsyncFrameWriter writer = mock(AsyncFrameWriter.class);
OkHttpClientTransport transport =
new OkHttpClientTransport(frameReader, writer, startId);
transport.startAsync();
streams = transport.getStreams();
MockStreamListener listener1 = new MockStreamListener();
transport.newStream(method, listener1);
try {
transport.newStream(method, new MockStreamListener());
fail("new stream should not be accepted by a go-away transport.");
} catch (IllegalStateException ex) {
// expected.
}
streams.get(startId).cancel();
listener1.waitUntilStreamClosed();
verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL));
assertEquals(Service.State.TERMINATED, transport.state());
}
private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException {
assertEquals(in1.available(), in2.available());
byte[] b1 = new byte[in1.available()];
in1.read(b1);
byte[] b2 = new byte[in2.available()];
in2.read(b2);
for (int i = 0; i < b1.length; i++) {
if (b1[i] != b2[i]) {
fail("Different InputStream.");
}
}
}
private static InputStream createMessageFrame(String message) throws IOException {
return createMessageFrame(message.getBytes(StandardCharsets.UTF_8));
}
private static InputStream createMessageFrame(byte[] message) throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(os);
dos.write(PAYLOAD_FRAME);
dos.writeInt(message.length);
dos.write(message);
dos.close();
byte[] messageFrame = os.toByteArray();
// Write the compression header followed by the message frame.
return addCompressionHeader(messageFrame);
}
private static InputStream createContextFrame(String key, String value) throws IOException {
byte[] body = ContextValue
.newBuilder()
.setKey(key)
.setValue(ByteString.copyFromUtf8(value))
.build()
.toByteArray();
return createContextFrame(body);
}
private static InputStream createContextFrame(byte[] body) throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(os);
dos.write(CONTEXT_VALUE_FRAME);
dos.writeInt(body.length);
dos.write(body);
dos.close();
byte[] contextFrame = os.toByteArray();
// Write the compression header followed by the context frame.
return addCompressionHeader(contextFrame);
}
private static InputStream createStatusFrame(short code) throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(os);
dos.write(STATUS_FRAME);
int length = 2;
dos.writeInt(length);
dos.writeShort(code);
dos.close();
byte[] statusFrame = os.toByteArray();
// Write the compression header followed by the status frame.
return addCompressionHeader(statusFrame);
}
private static InputStream addCompressionHeader(byte[] raw) throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(os);
dos.writeInt(raw.length);
dos.write(raw);
dos.close();
return new ByteArrayInputStream(os.toByteArray());
}
private static class MockFrameReader implements FrameReader {
boolean closed;
boolean throwExceptionForNextFrame;
@Override
public void close() throws IOException {
closed = true;
}
@Override
public boolean nextFrame(Handler handler) throws IOException {
if (throwExceptionForNextFrame) {
throw new IOException(NETWORK_ISSUE_MESSAGE);
}
synchronized (this) {
try {
wait();
} catch (InterruptedException e) {
throw new IOException(e);
}
}
if (throwExceptionForNextFrame) {
throw new IOException(NETWORK_ISSUE_MESSAGE);
}
return true;
}
synchronized void throwIOExceptionForNextFrame() {
throwExceptionForNextFrame = true;
notifyAll();
}
@Override
public void readConnectionPreface() throws IOException {
// not used.
}
}
private static class MockStreamListener implements StreamListener {
Status status;
CountDownLatch closed = new CountDownLatch(1);
ArrayList<String> messages = new ArrayList<String>();
Map<String, String> contexts = new HashMap<String, String>();
@Override
public ListenableFuture<Void> contextRead(String name, InputStream value, int length) {
String valueStr = getContent(value);
if (valueStr != null) {
// We assume only one context for each name.
contexts.put(name, valueStr);
}
return null;
}
@Override
public ListenableFuture<Void> messageRead(InputStream message, int length) {
String msg = getContent(message);
if (msg != null) {
messages.add(msg);
}
return null;
}
@Override
public void closed(Status status) {
this.status = status;
closed.countDown();
}
void waitUntilStreamClosed() throws InterruptedException {
if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) {
fail("Failed waiting stream to be closed.");
}
}
static String getContent(InputStream message) {
BufferedReader br =
new BufferedReader(new InputStreamReader(message, StandardCharsets.UTF_8));
try {
// Only one line message is used in this test.
return br.readLine();
} catch (IOException e) {
return null;
}
}
}
}