mirror of https://github.com/grpc/grpc-java.git
Improve okhttp client transport, handles go away and add unit test.
------------- Created by MOE: http://code.google.com/p/moe-java MOE_MIGRATED_REVID=72155172
This commit is contained in:
parent
5f334f7c52
commit
7bf17dc4d6
|
|
@ -1,7 +1,8 @@
|
|||
package com.google.net.stubby.newtransport.okhttp;
|
||||
|
||||
import com.google.common.util.concurrent.SerializingExecutor;
|
||||
import com.google.common.util.concurrent.Service;
|
||||
import com.google.net.stubby.Status;
|
||||
import com.google.net.stubby.transport.Transport.Code;
|
||||
|
||||
import com.squareup.okhttp.internal.spdy.ErrorCode;
|
||||
import com.squareup.okhttp.internal.spdy.FrameWriter;
|
||||
|
|
@ -17,9 +18,10 @@ import java.util.concurrent.Executor;
|
|||
class AsyncFrameWriter implements FrameWriter {
|
||||
private final FrameWriter frameWriter;
|
||||
private final Executor executor;
|
||||
private final Service transport;
|
||||
private final OkHttpClientTransport transport;
|
||||
|
||||
public AsyncFrameWriter(FrameWriter frameWriter, Service transport, Executor executor) {
|
||||
public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport,
|
||||
Executor executor) {
|
||||
this.frameWriter = frameWriter;
|
||||
this.transport = transport;
|
||||
// Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
|
||||
|
|
@ -158,6 +160,8 @@ class AsyncFrameWriter implements FrameWriter {
|
|||
@Override
|
||||
public void doRun() throws IOException {
|
||||
frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
|
||||
// Flush it since after goAway, we are likely to close this writer.
|
||||
frameWriter.flush();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -188,7 +192,7 @@ class AsyncFrameWriter implements FrameWriter {
|
|||
try {
|
||||
doRun();
|
||||
} catch (IOException ex) {
|
||||
transport.stopAsync();
|
||||
transport.abort(Status.fromThrowable(ex));
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
package com.google.net.stubby.newtransport.okhttp;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.io.ByteBuffers;
|
||||
|
|
@ -33,9 +34,10 @@ import okio.Buffer;
|
|||
import java.io.IOException;
|
||||
import java.net.Socket;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Collection;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
|
@ -48,6 +50,7 @@ import javax.annotation.concurrent.GuardedBy;
|
|||
*/
|
||||
public class OkHttpClientTransport extends AbstractClientTransport {
|
||||
/** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
|
||||
@VisibleForTesting
|
||||
static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
|
||||
|
||||
private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
|
||||
|
|
@ -75,21 +78,40 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
private final int port;
|
||||
private FrameReader frameReader;
|
||||
private AsyncFrameWriter frameWriter;
|
||||
@GuardedBy("this")
|
||||
private Object lock = new Object();
|
||||
@GuardedBy("lock")
|
||||
private int nextStreamId;
|
||||
private final Map<Integer, OkHttpClientStream> streams =
|
||||
Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
|
||||
private final ExecutorService executor = Executors.newCachedThreadPool();
|
||||
private int unacknowledgedBytesRead;
|
||||
private ClientFrameHandler clientFrameHandler;
|
||||
// The status used to finish all active streams when the transport is closed.
|
||||
@GuardedBy("lock")
|
||||
private boolean goAway;
|
||||
@GuardedBy("lock")
|
||||
private Status goAwayStatus;
|
||||
|
||||
public OkHttpClientTransport(String host, int port) {
|
||||
this.host = host;
|
||||
this.host = Preconditions.checkNotNull(host);
|
||||
this.port = port;
|
||||
// Client initiated streams are odd, server initiated ones are even. Server should not need to
|
||||
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
|
||||
nextStreamId = 3;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport connected to a fake peer for test.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
|
||||
host = null;
|
||||
port = -1;
|
||||
this.nextStreamId = nextStreamId;
|
||||
this.frameReader = frameReader;
|
||||
this.frameWriter = frameWriter;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
|
||||
return new OkHttpClientStream(method, listener);
|
||||
|
|
@ -97,53 +119,85 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
|
||||
@Override
|
||||
protected void doStart() {
|
||||
BufferedSource source;
|
||||
BufferedSink sink;
|
||||
try {
|
||||
Socket socket = new Socket(host, port);
|
||||
// TODO(user): use SpdyConnection.
|
||||
source = Okio.buffer(Okio.source(socket));
|
||||
sink = Okio.buffer(Okio.sink(socket));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
// We set host to null for test.
|
||||
if (host != null) {
|
||||
BufferedSource source;
|
||||
BufferedSink sink;
|
||||
try {
|
||||
Socket socket = new Socket(host, port);
|
||||
source = Okio.buffer(Okio.source(socket));
|
||||
sink = Okio.buffer(Okio.sink(socket));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
Variant variant = new Http20Draft12();
|
||||
frameReader = variant.newReader(source, true);
|
||||
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
|
||||
}
|
||||
Variant variant = new Http20Draft12();
|
||||
frameReader = variant.newReader(source, true);
|
||||
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
|
||||
|
||||
executor.execute(new ClientFrameHandler());
|
||||
notifyStarted();
|
||||
clientFrameHandler = new ClientFrameHandler();
|
||||
executor.execute(clientFrameHandler);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doStop() {
|
||||
closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
|
||||
frameWriter.close();
|
||||
try {
|
||||
frameReader.close();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
boolean normalClose;
|
||||
synchronized (lock) {
|
||||
normalClose = !goAway;
|
||||
}
|
||||
executor.shutdown();
|
||||
notifyStopped();
|
||||
if (normalClose) {
|
||||
abort(new Status(Code.INTERNAL, "Transport stopped"));
|
||||
// Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
|
||||
// The GOAWAY is part of graceful shutdown.
|
||||
frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
|
||||
}
|
||||
stopIfNecessary();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
ClientFrameHandler getHandler() {
|
||||
return clientFrameHandler;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Map<Integer, OkHttpClientStream> getStreams() {
|
||||
return streams;
|
||||
}
|
||||
|
||||
/**
|
||||
* Close and remove all streams.
|
||||
* Finish all active streams with given status, then close the transport.
|
||||
*/
|
||||
private void closeAllStreams(Status status) {
|
||||
Collection<OkHttpClientStream> streamsCopy;
|
||||
synchronized (streams) {
|
||||
streamsCopy = streams.values();
|
||||
streams.clear();
|
||||
void abort(Status status) {
|
||||
onGoAway(-1, status);
|
||||
}
|
||||
|
||||
private void onGoAway(int lastKnownStreamId, Status status) {
|
||||
ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
|
||||
synchronized (lock) {
|
||||
goAway = true;
|
||||
goAwayStatus = status;
|
||||
Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
|
||||
while (it.hasNext()) {
|
||||
Map.Entry<Integer, OkHttpClientStream> entry = it.next();
|
||||
if (entry.getKey() > lastKnownStreamId) {
|
||||
goAwayStreams.add(entry.getValue());
|
||||
it.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
for (OkHttpClientStream stream : streamsCopy) {
|
||||
|
||||
// Starting stop, go into STOPPING state so that Channel know this Transport should not be used
|
||||
// further, will become STOPPED once all streams are complete.
|
||||
stopAsync();
|
||||
|
||||
for (OkHttpClientStream stream : goAwayStreams) {
|
||||
stream.setStatus(status);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a HTTP2 stream is closed.
|
||||
* Called when a stream is closed.
|
||||
*
|
||||
* <p> Return false if the stream has already finished.
|
||||
*/
|
||||
|
|
@ -158,11 +212,40 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* When the transport is in goAway states, we should stop it once all active streams finish.
|
||||
*/
|
||||
private void stopIfNecessary() {
|
||||
boolean shouldStop;
|
||||
synchronized (lock) {
|
||||
shouldStop = (goAway && streams.size() == 0);
|
||||
}
|
||||
if (shouldStop) {
|
||||
frameWriter.close();
|
||||
try {
|
||||
frameReader.close();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
executor.shutdown();
|
||||
notifyStopped();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a Grpc status corresponding to the given ErrorCode.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static Status toGrpcStatus(ErrorCode code) {
|
||||
return ERROR_CODE_TO_STATUS.get(code);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runnable which reads frames and dispatches them to in flight calls
|
||||
*/
|
||||
private class ClientFrameHandler implements FrameReader.Handler, Runnable {
|
||||
private ClientFrameHandler() {}
|
||||
@VisibleForTesting
|
||||
class ClientFrameHandler implements FrameReader.Handler, Runnable {
|
||||
ClientFrameHandler() {}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
|
@ -173,8 +256,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
while (frameReader.nextFrame(this)) {
|
||||
}
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
|
||||
abort(Status.fromThrowable(ioe));
|
||||
} finally {
|
||||
// Restore the original thread name.
|
||||
Thread.currentThread().setName(threadName);
|
||||
|
|
@ -210,7 +292,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
stream.unacknowledgedBytesRead = 0;
|
||||
}
|
||||
if (inFinished) {
|
||||
finishStream(streamId, Status.OK);
|
||||
if (finishStream(streamId, Status.OK)) {
|
||||
stopIfNecessary();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -229,7 +313,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
|
||||
@Override
|
||||
public void rstStream(int streamId, ErrorCode errorCode) {
|
||||
finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
|
||||
if (finishStream(streamId, toGrpcStatus(errorCode))) {
|
||||
stopIfNecessary();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -252,18 +338,14 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
|
||||
@Override
|
||||
public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
|
||||
// TODO(user): Log here and implement the real Go away behavior: streams have
|
||||
// id <= lastGoodStreamId should not be closed.
|
||||
closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
|
||||
stopAsync();
|
||||
onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
|
||||
throws IOException {
|
||||
// TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
|
||||
// PROTOCOL_ERROR.
|
||||
frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
|
||||
// We don't accept server initiated stream.
|
||||
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -284,28 +366,42 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
}
|
||||
}
|
||||
|
||||
@GuardedBy("lock")
|
||||
private void assignStreamId(OkHttpClientStream stream) {
|
||||
Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
|
||||
stream.streamId = nextStreamId;
|
||||
streams.put(stream.streamId, stream);
|
||||
if (nextStreamId >= Integer.MAX_VALUE - 2) {
|
||||
onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust"));
|
||||
} else {
|
||||
nextStreamId += 2;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Client stream for the okhttp transport.
|
||||
*/
|
||||
private class OkHttpClientStream extends AbstractStream implements ClientStream {
|
||||
@VisibleForTesting
|
||||
class OkHttpClientStream extends AbstractStream implements ClientStream {
|
||||
int streamId;
|
||||
final InputStreamDeframer deframer;
|
||||
int unacknowledgedBytesRead;
|
||||
|
||||
public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
|
||||
OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
|
||||
super(listener);
|
||||
Preconditions.checkState(streamId == 0, "StreamId should be 0");
|
||||
synchronized (OkHttpClientTransport.this) {
|
||||
streamId = nextStreamId;
|
||||
nextStreamId += 2;
|
||||
streams.put(streamId, this);
|
||||
frameWriter.synStream(false, false, streamId, 0,
|
||||
Headers.createRequestHeaders(method.getName()));
|
||||
}
|
||||
deframer = new InputStreamDeframer(inboundMessageHandler());
|
||||
synchronized (lock) {
|
||||
if (goAway) {
|
||||
setStatus(goAwayStatus);
|
||||
return;
|
||||
}
|
||||
assignStreamId(this);
|
||||
}
|
||||
frameWriter.synStream(false, false, streamId, 0,
|
||||
Headers.createRequestHeaders(method.getName()));
|
||||
}
|
||||
|
||||
public InputStreamDeframer getDeframer() {
|
||||
InputStreamDeframer getDeframer() {
|
||||
return deframer;
|
||||
}
|
||||
|
||||
|
|
@ -330,8 +426,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
|
|||
public void cancel() {
|
||||
Preconditions.checkState(streamId != 0, "streamId should be set");
|
||||
outboundPhase = Phase.STATUS;
|
||||
if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
|
||||
if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
|
||||
frameWriter.rstStream(streamId, ErrorCode.CANCEL);
|
||||
stopIfNecessary();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,555 @@
|
|||
package com.google.net.stubby.newtransport.okhttp;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture;
|
||||
import com.google.common.util.concurrent.Service;
|
||||
import com.google.net.stubby.MethodDescriptor;
|
||||
import com.google.net.stubby.Status;
|
||||
import com.google.net.stubby.newtransport.StreamListener;
|
||||
import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler;
|
||||
import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream;
|
||||
import com.google.net.stubby.transport.Transport;
|
||||
import com.google.net.stubby.transport.Transport.Code;
|
||||
import com.google.net.stubby.transport.Transport.ContextValue;
|
||||
import com.google.protobuf.ByteString;
|
||||
|
||||
import com.squareup.okhttp.internal.spdy.ErrorCode;
|
||||
import com.squareup.okhttp.internal.spdy.FrameReader;
|
||||
|
||||
import okio.Buffer;
|
||||
import okio.BufferedSource;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Tests for {@link OkHttpClientTransport}.
|
||||
*/
|
||||
@RunWith(JUnit4.class)
|
||||
public class OkHttpClientTransportTest {
|
||||
private static final int TIME_OUT_MS = 5000000;
|
||||
private static final String NETWORK_ISSUE_MESSAGE = "network issue";
|
||||
|
||||
// Flags
|
||||
private static final byte PAYLOAD_FRAME = 0x0;
|
||||
public static final byte CONTEXT_VALUE_FRAME = 0x1;
|
||||
public static final byte STATUS_FRAME = 0x3;
|
||||
|
||||
@Mock
|
||||
private AsyncFrameWriter frameWriter;
|
||||
@Mock
|
||||
MethodDescriptor<?, ?> method;
|
||||
private OkHttpClientTransport clientTransport;
|
||||
private MockFrameReader frameReader;
|
||||
private Map<Integer, OkHttpClientStream> streams;
|
||||
private ClientFrameHandler frameHandler;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
streams = new HashMap<Integer, OkHttpClientStream>();
|
||||
frameReader = new MockFrameReader();
|
||||
clientTransport = new OkHttpClientTransport(frameReader, frameWriter, 3);
|
||||
clientTransport.startAsync();
|
||||
frameHandler = clientTransport.getHandler();
|
||||
streams = clientTransport.getStreams();
|
||||
when(method.getName()).thenReturn("fakemethod");
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
clientTransport.stopAsync();
|
||||
assertTrue(frameReader.closed);
|
||||
verify(frameWriter).close();
|
||||
}
|
||||
|
||||
/**
|
||||
* When nextFrame throws IOException, the transport should be aborted.
|
||||
*/
|
||||
@Test
|
||||
public void nextFrameThrowIOException() throws Exception {
|
||||
MockStreamListener listener1 = new MockStreamListener();
|
||||
MockStreamListener listener2 = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener1);
|
||||
clientTransport.newStream(method, listener2);
|
||||
assertEquals(2, streams.size());
|
||||
assertTrue(streams.containsKey(3));
|
||||
assertTrue(streams.containsKey(5));
|
||||
frameReader.throwIOExceptionForNextFrame();
|
||||
listener1.waitUntilStreamClosed();
|
||||
listener2.waitUntilStreamClosed();
|
||||
assertEquals(0, streams.size());
|
||||
assertEquals(Code.INTERNAL, listener1.status.getCode());
|
||||
assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
|
||||
assertEquals(Code.INTERNAL, listener1.status.getCode());
|
||||
assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
|
||||
assertTrue("Service state: " + clientTransport.state(),
|
||||
Service.State.TERMINATED == clientTransport.state());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readMessages() throws Exception {
|
||||
final int numMessages = 10;
|
||||
final String message = "Hello Client";
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
assertTrue(streams.containsKey(3));
|
||||
for (int i = 0; i < numMessages; i++) {
|
||||
BufferedSource source = mock(BufferedSource.class);
|
||||
InputStream inputStream = createMessageFrame(message + i);
|
||||
when(source.inputStream()).thenReturn(inputStream);
|
||||
frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available());
|
||||
}
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(Status.OK, listener.status);
|
||||
assertEquals(numMessages, listener.messages.size());
|
||||
for (int i = 0; i < numMessages; i++) {
|
||||
assertEquals(message + i, listener.messages.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readContexts() throws Exception {
|
||||
final int numContexts = 10;
|
||||
final String key = "KEY";
|
||||
final String value = "value";
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
assertTrue(streams.containsKey(3));
|
||||
for (int i = 0; i < numContexts; i++) {
|
||||
BufferedSource source = mock(BufferedSource.class);
|
||||
InputStream inputStream = createContextFrame(key + i, value + i);
|
||||
when(source.inputStream()).thenReturn(inputStream);
|
||||
frameHandler.data(i == numContexts - 1 ? true : false, 3, source, inputStream.available());
|
||||
}
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(Status.OK, listener.status);
|
||||
assertEquals(numContexts, listener.contexts.size());
|
||||
for (int i = 0; i < numContexts; i++) {
|
||||
String val = listener.contexts.get(key + i);
|
||||
assertNotNull(val);
|
||||
assertEquals(value + i, val);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readStatus() throws Exception {
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
assertTrue(streams.containsKey(3));
|
||||
BufferedSource source = mock(BufferedSource.class);
|
||||
InputStream inputStream = createStatusFrame((short) Transport.Code.UNAVAILABLE.getNumber());
|
||||
when(source.inputStream()).thenReturn(inputStream);
|
||||
frameHandler.data(true, 3, source, inputStream.available());
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(Transport.Code.UNAVAILABLE, listener.status.getCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void receiveReset() throws Exception {
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
assertTrue(streams.containsKey(3));
|
||||
frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR);
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void cancelStream() throws Exception {
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
OkHttpClientStream stream = streams.get(3);
|
||||
assertNotNull(stream);
|
||||
stream.cancel();
|
||||
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeMessage() throws Exception {
|
||||
final String message = "Hello Server";
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
OkHttpClientStream stream = streams.get(3);
|
||||
InputStream input = new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8));
|
||||
stream.writeMessage(input, input.available(), null);
|
||||
stream.flush();
|
||||
ArgumentCaptor<Buffer> captor =
|
||||
ArgumentCaptor.forClass(Buffer.class);
|
||||
verify(frameWriter).data(eq(false), eq(3), captor.capture());
|
||||
Buffer sentFrame = captor.getValue();
|
||||
checkSameInputStream(createMessageFrame(message), sentFrame.inputStream());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeContext() throws Exception {
|
||||
final String key = "KEY";
|
||||
final String value = "VALUE";
|
||||
MockStreamListener listener = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener);
|
||||
OkHttpClientStream stream = streams.get(3);
|
||||
InputStream input = new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8));
|
||||
stream.writeContext(key, input, input.available(), null);
|
||||
stream.flush();
|
||||
ArgumentCaptor<Buffer> captor =
|
||||
ArgumentCaptor.forClass(Buffer.class);
|
||||
verify(frameWriter).data(eq(false), eq(3), captor.capture());
|
||||
stream.cancel();
|
||||
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
|
||||
listener.waitUntilStreamClosed();
|
||||
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void windowUpdate() throws Exception {
|
||||
MockStreamListener listener1 = new MockStreamListener();
|
||||
MockStreamListener listener2 = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener1);
|
||||
clientTransport.newStream(method, listener2);
|
||||
assertEquals(2, streams.size());
|
||||
OkHttpClientStream stream1 = streams.get(3);
|
||||
OkHttpClientStream stream2 = streams.get(5);
|
||||
|
||||
int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4;
|
||||
byte[] fakeMessage = new byte[messageLength];
|
||||
byte[] contextBody = ContextValue
|
||||
.newBuilder()
|
||||
.setKey("KEY")
|
||||
.setValue(ByteString.copyFrom(fakeMessage))
|
||||
.build()
|
||||
.toByteArray();
|
||||
|
||||
// Stream 1 receives context
|
||||
InputStream contextFrame = createContextFrame(contextBody);
|
||||
int contextFrameLength = contextFrame.available();
|
||||
BufferedSource source = mock(BufferedSource.class);
|
||||
when(source.inputStream()).thenReturn(contextFrame);
|
||||
frameHandler.data(false, 3, source, contextFrame.available());
|
||||
|
||||
// Stream 2 receives context
|
||||
contextFrame = createContextFrame(contextBody);
|
||||
when(source.inputStream()).thenReturn(contextFrame);
|
||||
frameHandler.data(false, 5, source, contextFrame.available());
|
||||
|
||||
verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * contextFrameLength));
|
||||
|
||||
// Stream 1 receives a message
|
||||
InputStream messageFrame = createMessageFrame(fakeMessage);
|
||||
int messageFrameLength = messageFrame.available();
|
||||
when(source.inputStream()).thenReturn(messageFrame);
|
||||
frameHandler.data(false, 3, source, messageFrame.available());
|
||||
|
||||
verify(frameWriter).windowUpdate(eq(3), eq((long) contextFrameLength + messageFrameLength));
|
||||
|
||||
// Stream 2 receives a message
|
||||
messageFrame = createMessageFrame(fakeMessage);
|
||||
when(source.inputStream()).thenReturn(messageFrame);
|
||||
frameHandler.data(false, 5, source, messageFrame.available());
|
||||
|
||||
verify(frameWriter).windowUpdate(eq(5), eq((long) contextFrameLength + messageFrameLength));
|
||||
verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
|
||||
|
||||
stream1.cancel();
|
||||
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
|
||||
listener1.waitUntilStreamClosed();
|
||||
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status);
|
||||
|
||||
stream2.cancel();
|
||||
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
|
||||
listener2.waitUntilStreamClosed();
|
||||
assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void stopNormally() throws Exception {
|
||||
MockStreamListener listener1 = new MockStreamListener();
|
||||
MockStreamListener listener2 = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener1);
|
||||
clientTransport.newStream(method, listener2);
|
||||
assertEquals(2, streams.size());
|
||||
clientTransport.stopAsync();
|
||||
listener1.waitUntilStreamClosed();
|
||||
listener2.waitUntilStreamClosed();
|
||||
verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any());
|
||||
assertEquals(0, streams.size());
|
||||
assertEquals(Code.INTERNAL, listener1.status.getCode());
|
||||
assertEquals(Code.INTERNAL, listener2.status.getCode());
|
||||
assertEquals(Service.State.TERMINATED, clientTransport.state());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void receiveGoAway() throws Exception {
|
||||
// start 2 streams.
|
||||
MockStreamListener listener1 = new MockStreamListener();
|
||||
MockStreamListener listener2 = new MockStreamListener();
|
||||
clientTransport.newStream(method, listener1);
|
||||
clientTransport.newStream(method, listener2);
|
||||
assertEquals(2, streams.size());
|
||||
|
||||
// Receive goAway, max good id is 3.
|
||||
frameHandler.goAway(3, ErrorCode.CANCEL, null);
|
||||
|
||||
// Transport should be in STOPPING state.
|
||||
assertEquals(Service.State.STOPPING, clientTransport.state());
|
||||
|
||||
// Stream 2 should be closed.
|
||||
listener2.waitUntilStreamClosed();
|
||||
assertEquals(1, streams.size());
|
||||
assertEquals(Code.UNAVAILABLE, listener2.status.getCode());
|
||||
|
||||
// New stream should be failed.
|
||||
MockStreamListener listener3 = new MockStreamListener();
|
||||
try {
|
||||
clientTransport.newStream(method, listener3);
|
||||
fail("new stream should no be accepted by a go-away transport.");
|
||||
} catch (IllegalStateException ex) {
|
||||
// expected.
|
||||
}
|
||||
|
||||
// But stream 1 should be able to send.
|
||||
final String sentMessage = "Should I also go away?";
|
||||
OkHttpClientStream stream = streams.get(3);
|
||||
InputStream input =
|
||||
new ByteArrayInputStream(sentMessage.getBytes(StandardCharsets.UTF_8));
|
||||
stream.writeMessage(input, input.available(), null);
|
||||
stream.flush();
|
||||
ArgumentCaptor<Buffer> captor =
|
||||
ArgumentCaptor.forClass(Buffer.class);
|
||||
verify(frameWriter).data(eq(false), eq(3), captor.capture());
|
||||
Buffer sentFrame = captor.getValue();
|
||||
checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream());
|
||||
|
||||
// And read.
|
||||
final String receivedMessage = "No, you are fine.";
|
||||
BufferedSource source = mock(BufferedSource.class);
|
||||
InputStream inputStream = createMessageFrame(receivedMessage);
|
||||
when(source.inputStream()).thenReturn(inputStream);
|
||||
frameHandler.data(true, 3, source, inputStream.available());
|
||||
listener1.waitUntilStreamClosed();
|
||||
assertEquals(1, listener1.messages.size());
|
||||
assertEquals(receivedMessage, listener1.messages.get(0));
|
||||
|
||||
// The transport should be stopped after all active streams finished.
|
||||
assertTrue("Service state: " + clientTransport.state(),
|
||||
Service.State.TERMINATED == clientTransport.state());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void streamIdExhaust() throws Exception {
|
||||
int startId = Integer.MAX_VALUE - 2;
|
||||
AsyncFrameWriter writer = mock(AsyncFrameWriter.class);
|
||||
OkHttpClientTransport transport =
|
||||
new OkHttpClientTransport(frameReader, writer, startId);
|
||||
transport.startAsync();
|
||||
streams = transport.getStreams();
|
||||
|
||||
MockStreamListener listener1 = new MockStreamListener();
|
||||
transport.newStream(method, listener1);
|
||||
|
||||
try {
|
||||
transport.newStream(method, new MockStreamListener());
|
||||
fail("new stream should not be accepted by a go-away transport.");
|
||||
} catch (IllegalStateException ex) {
|
||||
// expected.
|
||||
}
|
||||
|
||||
streams.get(startId).cancel();
|
||||
listener1.waitUntilStreamClosed();
|
||||
verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL));
|
||||
assertEquals(Service.State.TERMINATED, transport.state());
|
||||
}
|
||||
|
||||
private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException {
|
||||
assertEquals(in1.available(), in2.available());
|
||||
byte[] b1 = new byte[in1.available()];
|
||||
in1.read(b1);
|
||||
byte[] b2 = new byte[in2.available()];
|
||||
in2.read(b2);
|
||||
for (int i = 0; i < b1.length; i++) {
|
||||
if (b1[i] != b2[i]) {
|
||||
fail("Different InputStream.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static InputStream createMessageFrame(String message) throws IOException {
|
||||
return createMessageFrame(message.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
private static InputStream createMessageFrame(byte[] message) throws IOException {
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(os);
|
||||
dos.write(PAYLOAD_FRAME);
|
||||
dos.writeInt(message.length);
|
||||
dos.write(message);
|
||||
dos.close();
|
||||
byte[] messageFrame = os.toByteArray();
|
||||
|
||||
// Write the compression header followed by the message frame.
|
||||
return addCompressionHeader(messageFrame);
|
||||
}
|
||||
|
||||
private static InputStream createContextFrame(String key, String value) throws IOException {
|
||||
byte[] body = ContextValue
|
||||
.newBuilder()
|
||||
.setKey(key)
|
||||
.setValue(ByteString.copyFromUtf8(value))
|
||||
.build()
|
||||
.toByteArray();
|
||||
return createContextFrame(body);
|
||||
}
|
||||
|
||||
private static InputStream createContextFrame(byte[] body) throws IOException {
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(os);
|
||||
dos.write(CONTEXT_VALUE_FRAME);
|
||||
dos.writeInt(body.length);
|
||||
dos.write(body);
|
||||
dos.close();
|
||||
byte[] contextFrame = os.toByteArray();
|
||||
|
||||
// Write the compression header followed by the context frame.
|
||||
return addCompressionHeader(contextFrame);
|
||||
}
|
||||
|
||||
private static InputStream createStatusFrame(short code) throws IOException {
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(os);
|
||||
dos.write(STATUS_FRAME);
|
||||
int length = 2;
|
||||
dos.writeInt(length);
|
||||
dos.writeShort(code);
|
||||
dos.close();
|
||||
byte[] statusFrame = os.toByteArray();
|
||||
|
||||
// Write the compression header followed by the status frame.
|
||||
return addCompressionHeader(statusFrame);
|
||||
}
|
||||
|
||||
private static InputStream addCompressionHeader(byte[] raw) throws IOException {
|
||||
ByteArrayOutputStream os = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(os);
|
||||
dos.writeInt(raw.length);
|
||||
dos.write(raw);
|
||||
dos.close();
|
||||
return new ByteArrayInputStream(os.toByteArray());
|
||||
}
|
||||
|
||||
private static class MockFrameReader implements FrameReader {
|
||||
boolean closed;
|
||||
boolean throwExceptionForNextFrame;
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
closed = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nextFrame(Handler handler) throws IOException {
|
||||
if (throwExceptionForNextFrame) {
|
||||
throw new IOException(NETWORK_ISSUE_MESSAGE);
|
||||
}
|
||||
synchronized (this) {
|
||||
try {
|
||||
wait();
|
||||
} catch (InterruptedException e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
if (throwExceptionForNextFrame) {
|
||||
throw new IOException(NETWORK_ISSUE_MESSAGE);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
synchronized void throwIOExceptionForNextFrame() {
|
||||
throwExceptionForNextFrame = true;
|
||||
notifyAll();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readConnectionPreface() throws IOException {
|
||||
// not used.
|
||||
}
|
||||
}
|
||||
|
||||
private static class MockStreamListener implements StreamListener {
|
||||
Status status;
|
||||
CountDownLatch closed = new CountDownLatch(1);
|
||||
ArrayList<String> messages = new ArrayList<String>();
|
||||
Map<String, String> contexts = new HashMap<String, String>();
|
||||
|
||||
@Override
|
||||
public ListenableFuture<Void> contextRead(String name, InputStream value, int length) {
|
||||
String valueStr = getContent(value);
|
||||
if (valueStr != null) {
|
||||
// We assume only one context for each name.
|
||||
contexts.put(name, valueStr);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenableFuture<Void> messageRead(InputStream message, int length) {
|
||||
String msg = getContent(message);
|
||||
if (msg != null) {
|
||||
messages.add(msg);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void closed(Status status) {
|
||||
this.status = status;
|
||||
closed.countDown();
|
||||
}
|
||||
|
||||
void waitUntilStreamClosed() throws InterruptedException {
|
||||
if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) {
|
||||
fail("Failed waiting stream to be closed.");
|
||||
}
|
||||
}
|
||||
|
||||
static String getContent(InputStream message) {
|
||||
BufferedReader br =
|
||||
new BufferedReader(new InputStreamReader(message, StandardCharsets.UTF_8));
|
||||
try {
|
||||
// Only one line message is used in this test.
|
||||
return br.readLine();
|
||||
} catch (IOException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue