core,netty: split stream state on client-side; AbstractStream2

OkHttp will need to be migrated in a future commit.
This commit is contained in:
Eric Anderson 2016-07-17 23:15:05 -07:00
parent 93dd02ca9c
commit cad7124c27
17 changed files with 1193 additions and 313 deletions

View File

@ -0,0 +1,303 @@
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Metadata;
import io.grpc.Status;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
* The abstract base class for {@link ClientStream} implementations. Extending classes only need to
* implement {@link #transportState()} and {@link #abstractClientStreamSink()}. Must only be called
* from the sending application thread.
*/
public abstract class AbstractClientStream2 extends AbstractStream2
implements ClientStream, MessageFramer.Sink {
private static final Logger log = Logger.getLogger(AbstractClientStream2.class.getName());
/**
* A sink for outbound operations, separated from the stream simply to avoid name
* collisions/confusion. Only called from application thread.
*/
protected interface Sink {
/**
* Sends an outbound frame to the remote end point.
*
* @param frame a buffer containing the chunk of data to be sent, or {@code null} if {@code
* endOfStream} with no data to send
* @param endOfStream {@code true} if this is the last frame; {@code flush} is guaranteed to be
* {@code true} if this is {@code true}
* @param flush {@code true} if more data may not be arriving soon
*/
void writeFrame(@Nullable WritableBuffer frame, boolean endOfStream, boolean flush);
/**
* Requests up to the given number of messages from the call to be delivered to the client. This
* should end up triggering {@link TransportState#requestMessagesFromDeframer(int)} on the
* transport thread.
*/
void request(int numMessages);
/**
* Tears down the stream, typically in the event of a timeout. This method may be called
* multiple times and from any thread.
*
* <p>This is a clone of {@link ClientStream#cancel()}; {@link AbstractClientStream2#cancel}
* delegates to this method.
*/
void cancel(Status status);
}
private final MessageFramer framer;
private boolean outboundClosed;
/**
* Whether cancel() has been called. This is not strictly necessary, but removes the delay between
* cancel() being called and isReady() beginning to return false, since cancel is commonly
* processed asynchronously.
*/
private volatile boolean cancelled;
protected AbstractClientStream2(WritableBufferAllocator bufferAllocator) {
framer = new MessageFramer(this, bufferAllocator);
}
/** {@inheritDoc} */
@Override
protected abstract TransportState transportState();
@Override
public void start(ClientStreamListener listener) {
transportState().setListener(listener);
}
/**
* Sink for transport to be called to perform outbound operations. Each stream must have its own
* unique sink.
*/
protected abstract Sink abstractClientStreamSink();
@Override
protected final MessageFramer framer() {
return framer;
}
@Override
public final void request(int numMessages) {
abstractClientStreamSink().request(numMessages);
}
@Override
public final void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
Preconditions.checkArgument(frame != null || endOfStream, "null frame before EOS");
abstractClientStreamSink().writeFrame(frame, endOfStream, flush);
}
@Override
public final void halfClose() {
if (!outboundClosed) {
outboundClosed = true;
endOfMessages();
}
}
@Override
public final void cancel(Status reason) {
Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status");
cancelled = true;
abstractClientStreamSink().cancel(reason);
}
@Override
public final boolean isReady() {
return super.isReady() && !cancelled;
}
/** This should only called from the transport thread. */
protected abstract static class TransportState extends AbstractStream2.TransportState {
/** Whether listener.closed() has been called. */
private boolean listenerClosed;
private ClientStreamListener listener;
private Runnable deliveryStalledTask;
private boolean headersReceived;
/**
* Whether the stream is closed from the transport's perspective. This can differ from {@link
* #listenerClosed} because there may still be messages buffered to deliver to the application.
*/
private boolean statusReported;
protected TransportState(int maxMessageSize) {
super(maxMessageSize);
}
@VisibleForTesting
public final void setListener(ClientStreamListener listener) {
Preconditions.checkState(this.listener == null, "Already called setListener");
this.listener = Preconditions.checkNotNull(listener, "listener");
}
@Override
public final void deliveryStalled() {
if (deliveryStalledTask != null) {
deliveryStalledTask.run();
deliveryStalledTask = null;
}
}
@Override
public final void endOfStream() {
deliveryStalled();
}
@Override
protected final ClientStreamListener listener() {
return listener;
}
/**
* Called by transport implementations when they receive headers.
*
* @param headers the parsed headers
*/
protected void inboundHeadersReceived(Metadata headers) {
Preconditions.checkState(!statusReported, "Received headers on closed stream");
headersReceived = true;
listener().headersRead(headers);
}
/**
* Processes the contents of a received data frame from the server.
*
* @param frame the received data frame. Its ownership is transferred to this method.
*/
protected void inboundDataReceived(ReadableBuffer frame) {
Preconditions.checkNotNull(frame, "frame");
boolean needToCloseFrame = true;
try {
if (statusReported) {
log.log(Level.INFO, "Received data on closed stream");
return;
}
if (!headersReceived) {
transportReportStatus(
Status.INTERNAL.withDescription("headers not received before payload"),
false, new Metadata());
return;
}
needToCloseFrame = false;
deframe(frame, false);
} finally {
if (needToCloseFrame) {
frame.close();
}
}
}
/**
* Processes the trailers and status from the server.
*
* @param trailers the received trailers
* @param status the status extracted from the trailers
*/
protected void inboundTrailersReceived(Metadata trailers, Status status) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
if (statusReported) {
log.log(Level.INFO, "Received trailers on closed stream:\n {1}\n {2}",
new Object[]{status, trailers});
return;
}
transportReportStatus(status, false, trailers);
}
/**
* Report stream closure with status to the application layer if not already reported. This
* method must be called from the transport thread.
*
* @param status the new status to set
* @param stopDelivery if {@code true}, interrupts any further delivery of inbound messages that
* may already be queued up in the deframer. If {@code false}, the listener will be
* notified immediately after all currently completed messages in the deframer have been
* delivered to the application.
* @param trailers new instance of {@code Trailers}, either empty or those returned by the
* server
*/
public final void transportReportStatus(final Status status, boolean stopDelivery,
final Metadata trailers) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
// If stopDelivery, we continue in case previous invocation is waiting for stall
if (statusReported && !stopDelivery) {
return;
}
statusReported = true;
onStreamDeallocated();
// If not stopping delivery, then we must wait until the deframer is stalled (i.e., it has no
// complete messages to deliver).
if (stopDelivery || isDeframerStalled()) {
deliveryStalledTask = null;
closeListener(status, trailers);
} else {
deliveryStalledTask = new Runnable() {
@Override
public void run() {
closeListener(status, trailers);
}
};
}
}
/**
* Closes the listener if not previously closed.
*
* @throws IllegalStateException if the call has not yet been started.
*/
private void closeListener(Status status, Metadata trailers) {
if (!listenerClosed) {
listenerClosed = true;
closeDeframer();
listener().closed(status, trailers);
}
}
}
}

View File

@ -43,7 +43,6 @@ import javax.annotation.Nullable;
* Abstract base class for {@link ServerStream} implementations. Extending classes only need to
* implement {@link #transportState()} and {@link #abstractServerStreamSink()}. Must only be called
* from the sending application thread.
*/
public abstract class AbstractServerStream extends AbstractStream2
implements ServerStream, MessageFramer.Sink {
@ -158,6 +157,11 @@ public abstract class AbstractServerStream extends AbstractStream2
abstractServerStreamSink().cancel(status);
}
@Override
public final boolean isReady() {
return super.isReady();
}
@Override public Attributes attributes() {
return Attributes.EMPTY;
}
@ -241,6 +245,7 @@ public abstract class AbstractServerStream extends AbstractStream2
private void closeListener(Status newStatus) {
if (!listenerClosed) {
listenerClosed = true;
onStreamDeallocated();
closeDeframer();
listener().closed(newStatus);
}

View File

@ -97,7 +97,7 @@ public abstract class AbstractStream2 implements Stream {
}
@Override
public final boolean isReady() {
public boolean isReady() {
if (framer().isClosed()) {
return false;
}
@ -139,6 +139,12 @@ public abstract class AbstractStream2 implements Stream {
*/
@GuardedBy("onReadyLock")
private boolean allocated;
/**
* Indicates that the stream no longer exists for the transport. Implies that the application
* should be discouraged from sending, because doing so would have no effect.
*/
@GuardedBy("onReadyLock")
private boolean deallocated;
protected TransportState(int maxMessageSize) {
deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize);
@ -174,6 +180,13 @@ public abstract class AbstractStream2 implements Stream {
deframer.close();
}
/**
* Indicates whether delivery is currently stalled, pending receipt of more data.
*/
protected final boolean isDeframerStalled() {
return deframer.isStalled();
}
/**
* Called to parse a received frame and attempt delivery of any completed
* messages. Must be called from the transport thread.
@ -214,7 +227,7 @@ public abstract class AbstractStream2 implements Stream {
private boolean isReady() {
synchronized (onReadyLock) {
return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD && !deallocated;
}
}
@ -233,6 +246,19 @@ public abstract class AbstractStream2 implements Stream {
notifyIfReady();
}
/**
* Notify that the stream does not exist in a usable state any longer. This causes {@link
* AbstractStream2#isReady()} to return {@code false} from this point forward.
*
* <p>This does not generally need to be called explicitly by the transport, as it is handled
* implicitly by {@link AbstractClientStream2} and {@link AbstractServerStream}.
*/
protected final void onStreamDeallocated() {
synchronized (onReadyLock) {
deallocated = true;
}
}
/**
* Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint.
@ -256,6 +282,8 @@ public abstract class AbstractStream2 implements Stream {
public final void onSentBytes(int numBytes) {
boolean doNotify;
synchronized (onReadyLock) {
checkState(allocated,
"onStreamAllocated was not called, but it seems the stream is active");
boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
numSentBytesQueued -= numBytes;
boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;

View File

@ -0,0 +1,251 @@
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.Status;
import java.nio.charset.Charset;
import javax.annotation.Nullable;
/**
* Base implementation for client streams using HTTP2 as the transport.
*/
public abstract class Http2ClientStreamTransportState extends AbstractClientStream2.TransportState {
/**
* Metadata marshaller for HTTP status lines.
*/
private static final InternalMetadata.TrustedAsciiMarshaller<Integer> HTTP_STATUS_MARSHALLER =
new InternalMetadata.TrustedAsciiMarshaller<Integer>() {
@Override
public byte[] toAsciiString(Integer value) {
throw new UnsupportedOperationException();
}
/**
* RFC 7231 says status codes are 3 digits long.
*
* @see: <a href="https://tools.ietf.org/html/rfc7231#section-6">RFC 7231</a>
*/
@Override
public Integer parseAsciiString(byte[] serialized) {
if (serialized.length >= 3) {
return (serialized[0] - '0') * 100 + (serialized[1] - '0') * 10 + (serialized[2] - '0');
}
throw new NumberFormatException(
"Malformed status code " + new String(serialized, InternalMetadata.US_ASCII));
}
};
private static final Metadata.Key<Integer> HTTP2_STATUS = InternalMetadata.keyOf(":status",
HTTP_STATUS_MARSHALLER);
/** When non-{@code null}, {@link #transportErrorMetadata} must also be non-{@code null}. */
private Status transportError;
private Metadata transportErrorMetadata;
private Charset errorCharset = Charsets.UTF_8;
private boolean contentTypeChecked;
protected Http2ClientStreamTransportState(int maxMessageSize) {
super(maxMessageSize);
}
/**
* Called to process a failure in HTTP/2 processing. It should notify the transport to cancel the
* stream and call {@code transportReportStatus()}.
*/
protected abstract void http2ProcessingFailed(Status status, Metadata trailers);
/**
* Called by subclasses whenever {@code Headers} are received from the transport.
*
* @param headers the received headers
*/
protected void transportHeadersReceived(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
if (transportError != null) {
// Already received a transport error so just augment it.
transportError = transportError.augmentDescription(headers.toString());
return;
}
Status httpStatus = statusFromHttpStatus(headers);
if (httpStatus == null) {
transportError = Status.INTERNAL.withDescription(
"received non-terminal headers with no :status");
} else if (!httpStatus.isOk()) {
transportError = httpStatus;
} else {
transportError = checkContentType(headers);
}
if (transportError != null) {
// Note we don't immediately report the transport error, instead we wait for more data on the
// stream so we can accumulate more detail into the error before reporting it.
transportError = transportError.augmentDescription("\n" + headers);
transportErrorMetadata = headers;
errorCharset = extractCharset(headers);
} else {
stripTransportDetails(headers);
inboundHeadersReceived(headers);
}
}
/**
* Called by subclasses whenever a data frame is received from the transport.
*
* @param frame the received data frame
* @param endOfStream {@code true} if there will be no more data received for this stream
*/
protected void transportDataReceived(ReadableBuffer frame, boolean endOfStream) {
if (transportError != null) {
// We've already detected a transport error and now we're just accumulating more detail
// for it.
transportError = transportError.augmentDescription("DATA-----------------------------\n"
+ ReadableBuffers.readAsString(frame, errorCharset));
frame.close();
if (transportError.getDescription().length() > 1000 || endOfStream) {
http2ProcessingFailed(transportError, transportErrorMetadata);
}
} else {
inboundDataReceived(frame);
if (endOfStream) {
// This is a protocol violation as we expect to receive trailers.
transportError =
Status.INTERNAL.withDescription("Received unexpected EOS on DATA frame from server.");
transportErrorMetadata = new Metadata();
transportReportStatus(transportError, false, transportErrorMetadata);
}
}
}
/**
* Called by subclasses for the terminal trailer metadata on a stream.
*
* @param trailers the received terminal trailer metadata
*/
protected void transportTrailersReceived(Metadata trailers) {
Preconditions.checkNotNull(trailers, "trailers");
if (transportError != null) {
// Already received a transport error so just augment it.
transportError = transportError.augmentDescription(trailers.toString());
} else {
transportError = checkContentType(trailers);
transportErrorMetadata = trailers;
}
if (transportError != null) {
http2ProcessingFailed(transportError, transportErrorMetadata);
} else {
Status status = statusFromTrailers(trailers);
stripTransportDetails(trailers);
inboundTrailersReceived(trailers, status);
}
}
private static Status statusFromHttpStatus(Metadata metadata) {
Integer httpStatus = metadata.get(HTTP2_STATUS);
if (httpStatus != null) {
Status status = GrpcUtil.httpStatusToGrpcStatus(httpStatus);
return status.isOk() ? status
: status.augmentDescription("extracted status from HTTP :status " + httpStatus);
}
return null;
}
/**
* Extract the response status from trailers.
*/
private static Status statusFromTrailers(Metadata trailers) {
Status status = trailers.get(Status.CODE_KEY);
if (status == null) {
status = statusFromHttpStatus(trailers);
if (status == null || status.isOk()) {
status = Status.UNKNOWN.withDescription("missing GRPC status in response");
} else {
status = status.withDescription(
"missing GRPC status, inferred error from HTTP status code");
}
}
String message = trailers.get(Status.MESSAGE_KEY);
if (message != null) {
status = status.augmentDescription(message);
}
return status;
}
/**
* Inspect the content type field from received headers or trailers and return an error Status if
* content type is invalid or not present. Returns null if no error was found.
*/
@Nullable
private Status checkContentType(Metadata headers) {
if (contentTypeChecked) {
return null;
}
contentTypeChecked = true;
String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY);
if (!GrpcUtil.isGrpcContentType(contentType)) {
return Status.INTERNAL.withDescription("Invalid content-type: " + contentType);
}
return null;
}
/**
* Inspect the raw metadata and figure out what charset is being used.
*/
private static Charset extractCharset(Metadata headers) {
String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY);
if (contentType != null) {
String[] split = contentType.split("charset=");
try {
return Charset.forName(split[split.length - 1].trim());
} catch (Exception t) {
// Ignore and assume UTF-8
}
}
return Charsets.UTF_8;
}
/**
* Strip HTTP transport implementation details so they don't leak via metadata into
* the application layer.
*/
private static void stripTransportDetails(Metadata metadata) {
metadata.discardAll(HTTP2_STATUS);
metadata.discardAll(Status.CODE_KEY);
metadata.discardAll(Status.MESSAGE_KEY);
}
}

View File

@ -0,0 +1,279 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/**
* Test for {@link AbstractClientStream2}. This class tries to test functionality in
* AbstractClientStream2, but not in any super classes.
*/
@RunWith(JUnit4.class)
public class AbstractClientStream2Test {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock private ClientStreamListener mockListener;
@Captor private ArgumentCaptor<Status> statusCaptor;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
}
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
@Override
public WritableBuffer allocate(int capacityHint) {
return new ByteWritableBuffer(capacityHint);
}
};
@Test
public void cancel_doNotAcceptOk() {
for (Code code : Code.values()) {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
if (code != Code.OK) {
stream.cancel(Status.fromCodeValue(code.value()));
} else {
try {
stream.cancel(Status.fromCodeValue(code.value()));
fail();
} catch (IllegalArgumentException e) {
// ignore
}
}
}
}
@Test
public void cancel_failsOnNull() {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
stream.cancel(null);
}
@Test
public void cancel_notifiesOnlyOnce() {
final BaseTransportState state = new BaseTransportState();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, state, new BaseSink() {
@Override
public void cancel(Status errorStatus) {
// Cancel should eventually result in a transportReportStatus on the transport thread
state.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata());
}
});
stream.start(mockListener);
stream.cancel(Status.DEADLINE_EXCEEDED);
stream.cancel(Status.DEADLINE_EXCEEDED);
verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
@Test
public void startFailsOnNullListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
thrown.expect(NullPointerException.class);
stream.start(null);
}
@Test
public void cantCallStartTwice() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
thrown.expect(IllegalStateException.class);
stream.start(mockListener);
}
@Test
public void inboundDataReceived_failsOnNullFrame() {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
stream.transportState().inboundDataReceived(null);
}
@Test
public void inboundDataReceived_failsOnNoHeaders() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
stream.transportState().inboundDataReceived(ReadableBuffers.empty());
verify(mockListener).closed(statusCaptor.capture(), any(Metadata.class));
assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode());
}
@Test
public void inboundHeadersReceived_notifiesListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void inboundHeadersReceived_failsIfStatusReported() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
stream.transportState().transportReportStatus(Status.CANCELLED, false, new Metadata());
thrown.expect(IllegalStateException.class);
stream.transportState().inboundHeadersReceived(new Metadata());
}
@Test
public void inboundHeadersReceived_acceptsGzipEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding());
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void inboundHeadersReceived_acceptsIdentityEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, Codec.Identity.NONE.getMessageEncoding());
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void rstStreamClosesStream() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
// The application will call request when waiting for a message, which will in turn call this
// on the transport thread.
stream.transportState().requestMessagesFromDeframer(1);
// Send first byte of 2 byte message
stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1}), false);
Status status = Status.INTERNAL;
// Simulate getting a reset
stream.transportState().transportReportStatus(status, false /*stop delivery*/, new Metadata());
verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
/**
* No-op base class for testing.
*/
private static class BaseAbstractClientStream extends AbstractClientStream2 {
private final TransportState state;
private final Sink sink;
public BaseAbstractClientStream(WritableBufferAllocator allocator) {
this(allocator, new BaseTransportState(), new BaseSink());
}
public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
Sink sink) {
super(allocator);
this.state = state;
this.sink = sink;
}
@Override
protected Sink abstractClientStreamSink() {
return sink;
}
@Override
public TransportState transportState() {
return state;
}
@Override
public void setAuthority(String authority) {}
}
private static class BaseSink implements AbstractClientStream2.Sink {
@Override
public void request(int numMessages) {}
@Override
public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {}
@Override
public void cancel(Status reason) {}
}
private static class BaseTransportState extends AbstractClientStream2.TransportState {
public BaseTransportState() {
super(DEFAULT_MAX_MESSAGE_SIZE);
}
@Override
protected void deframeFailed(Throwable cause) {}
@Override
public void bytesRead(int processedBytes) {}
}
}

View File

@ -39,17 +39,17 @@ import io.grpc.Status;
* Command sent from a Netty client stream to the handler to cancel the stream.
*/
class CancelClientStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final NettyClientStream stream;
private final NettyClientStream.TransportState stream;
private final Status reason;
CancelClientStreamCommand(NettyClientStream stream, Status reason) {
CancelClientStreamCommand(NettyClientStream.TransportState stream, Status reason) {
this.stream = Preconditions.checkNotNull(stream, "stream");
Preconditions.checkNotNull(reason, "reason");
Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status");
this.reason = reason;
}
NettyClientStream stream() {
NettyClientStream.TransportState stream() {
return stream;
}

View File

@ -41,15 +41,15 @@ import io.netty.handler.codec.http2.Http2Headers;
*/
class CreateStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final Http2Headers headers;
private final NettyClientStream stream;
private final NettyClientStream.TransportState stream;
CreateStreamCommand(Http2Headers headers,
NettyClientStream stream) {
NettyClientStream.TransportState stream) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.headers = Preconditions.checkNotNull(headers, "headers");
}
NettyClientStream stream() {
NettyClientStream.TransportState stream() {
return stream;
}

View File

@ -246,17 +246,16 @@ class NettyClientHandler extends AbstractNettyHandler {
}
private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream) {
NettyClientStream stream = clientStream(requireHttp2Stream(streamId));
NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId));
stream.transportHeadersReceived(headers, endStream);
}
/**
* Handler for an inbound HTTP/2 DATA frame.
*/
private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfStream) {
flowControlPing().onDataRead(data.readableBytes(), padding);
NettyClientStream stream = clientStream(requireHttp2Stream(streamId));
NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId));
stream.transportDataReceived(data, endOfStream);
}
@ -265,7 +264,7 @@ class NettyClientHandler extends AbstractNettyHandler {
* Handler for an inbound HTTP/2 RST_STREAM frame, terminating a stream.
*/
private void onRstStreamRead(int streamId, long errorCode) {
NettyClientStream stream = clientStream(connection().stream(streamId));
NettyClientStream.TransportState stream = clientStream(connection().stream(streamId));
if (stream != null) {
Status status = GrpcUtil.Http2Error.statusForCode((int) errorCode)
.augmentDescription("Received Rst Stream");
@ -295,7 +294,7 @@ class NettyClientHandler extends AbstractNettyHandler {
connection().forEachActiveStream(new Http2StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
NettyClientStream clientStream = clientStream(stream);
NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) {
clientStream.transportReportStatus(
lifecycleManager.getShutdownStatus(), false, new Metadata());
@ -322,7 +321,7 @@ class NettyClientHandler extends AbstractNettyHandler {
protected void onStreamError(ChannelHandlerContext ctx, Throwable cause,
Http2Exception.StreamException http2Ex) {
// Close the stream with a status that contains the cause.
NettyClientStream stream = clientStream(connection().stream(http2Ex.streamId()));
NettyClientStream.TransportState stream = clientStream(connection().stream(http2Ex.streamId()));
if (stream != null) {
stream.transportReportStatus(Utils.statusFromThrowable(cause), false, new Metadata());
} else {
@ -370,9 +369,9 @@ class NettyClientHandler extends AbstractNettyHandler {
return;
}
final NettyClientStream stream = command.stream();
final NettyClientStream.TransportState stream = command.stream();
final Http2Headers headers = command.headers();
stream.id(streamId);
stream.setId(streamId);
// Create an intermediate promise so that we can intercept the failure reported back to the
// application.
@ -418,7 +417,7 @@ class NettyClientHandler extends AbstractNettyHandler {
*/
private void cancelStream(ChannelHandlerContext ctx, CancelClientStreamCommand cmd,
ChannelPromise promise) {
NettyClientStream stream = cmd.stream();
NettyClientStream.TransportState stream = cmd.stream();
stream.transportReportStatus(cmd.reason(), true, new Metadata());
encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise);
}
@ -507,7 +506,7 @@ class NettyClientHandler extends AbstractNettyHandler {
connection().forEachActiveStream(new Http2StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
NettyClientStream clientStream = clientStream(stream);
NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) {
clientStream.transportReportStatus(msg.getStatus(), true, new Metadata());
resetStream(ctx, stream.id(), Http2Error.CANCEL.code(), ctx.newPromise());
@ -531,7 +530,7 @@ class NettyClientHandler extends AbstractNettyHandler {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
if (stream.id() > lastKnownStream) {
NettyClientStream clientStream = clientStream(stream);
NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) {
clientStream.transportReportStatus(goAwayStatus, false, new Metadata());
}
@ -566,8 +565,8 @@ class NettyClientHandler extends AbstractNettyHandler {
/**
* Gets the client stream associated to the given HTTP/2 stream object.
*/
private NettyClientStream clientStream(Http2Stream stream) {
return stream == null ? null : (NettyClientStream) stream.getProperty(streamKey);
private NettyClientStream.TransportState clientStream(Http2Stream stream) {
return stream == null ? null : (NettyClientStream.TransportState) stream.getProperty(streamKey);
}
private int incrementAndGetNextStreamId() throws StatusException {

View File

@ -41,9 +41,10 @@ import io.grpc.InternalMethodDescriptor;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.AbstractClientStream2;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.Http2ClientStreamTransportState;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
@ -56,43 +57,51 @@ import io.netty.util.AsciiString;
import javax.annotation.Nullable;
/**
* Client stream for a Netty transport.
* Client stream for a Netty transport. Must only be called from the sending application
* thread.
*/
abstract class NettyClientStream extends Http2ClientStream implements StreamIdHolder {
class NettyClientStream extends AbstractClientStream2 {
private static final InternalMethodDescriptor methodDescriptorAccessor =
new InternalMethodDescriptor(InternalKnownTransport.NETTY);
private final Sink sink = new Sink();
private final TransportState state;
private final WriteQueue writeQueue;
private final MethodDescriptor<?, ?> method;
/** {@code null} after start. */
private Metadata headers;
private final Channel channel;
private final NettyClientHandler handler;
private AsciiString authority;
private final AsciiString scheme;
private final AsciiString userAgent;
private AsciiString authority;
private Http2Stream http2Stream;
private int id;
private WriteQueue writeQueue;
NettyClientStream(MethodDescriptor<?, ?> method, Metadata headers, Channel channel,
NettyClientHandler handler, int maxMessageSize, AsciiString authority, AsciiString scheme,
NettyClientStream(TransportState state, MethodDescriptor<?, ?> method, Metadata headers,
Channel channel, AsciiString authority, AsciiString scheme,
AsciiString userAgent) {
super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize);
super(new NettyWritableBufferAllocator(channel.alloc()));
this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue();
this.method = checkNotNull(method, "method");
this.headers = checkNotNull(headers, "headers");
this.writeQueue = handler.getWriteQueue();
this.channel = checkNotNull(channel, "channel");
this.handler = checkNotNull(handler, "handler");
this.authority = checkNotNull(authority, "authority");
this.scheme = checkNotNull(scheme, "scheme");
this.userAgent = userAgent;
}
@Override
protected TransportState transportState() {
return state;
}
@Override
protected Sink abstractClientStreamSink() {
return sink;
}
@Override
public void setAuthority(String authority) {
checkState(listener() == null, "must be call before start");
checkState(headers != null, "must be call before start");
this.authority = AsciiString.of(checkNotNull(authority, "authority"));
}
@ -116,114 +125,134 @@ abstract class NettyClientStream extends Http2ClientStream implements StreamIdHo
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
// Stream creation failed. Close the stream if not already closed.
Status s = statusFromFailedFuture(future);
transportReportStatus(s, true, new Metadata());
Status s = transportState().statusFromFailedFuture(future);
transportState().transportReportStatus(s, true, new Metadata());
}
}
};
// Write the command requesting the creation of the stream.
writeQueue.enqueue(new CreateStreamCommand(http2Headers, this),
writeQueue.enqueue(new CreateStreamCommand(http2Headers, transportState()),
!method.getType().clientSendsOneMessage()).addListener(failureListener);
}
@Override
public void transportReportStatus(Status newStatus, boolean stopDelivery, Metadata trailers) {
super.transportReportStatus(newStatus, stopDelivery, trailers);
}
private class Sink implements AbstractClientStream2.Sink {
@Override
public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
ByteBuf bytebuf = frame == null ? EMPTY_BUFFER : ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
}
}), flush);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush);
}
}
/**
* Intended to be overriden by NettyClientTransport, which has more information about failures.
* May only be called from event loop.
*/
protected abstract Status statusFromFailedFuture(ChannelFuture f);
@Override
public void request(int numMessages) {
if (channel.eventLoop().inEventLoop()) {
// Processing data read in the event loop so can call into the deframer immediately
transportState().requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(transportState(), numMessages), true);
}
}
@Override
public void request(int numMessages) {
if (channel.eventLoop().inEventLoop()) {
// Processing data read in the event loop so can call into the deframer immediately
requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(this, numMessages), true);
@Override
public void cancel(Status status) {
writeQueue.enqueue(new CancelClientStreamCommand(transportState(), status), true);
}
}
@Override
public int id() {
return id;
}
/** This should only called from the transport thread. */
public abstract static class TransportState extends Http2ClientStreamTransportState
implements StreamIdHolder {
private final NettyClientHandler handler;
private int id;
private Http2Stream http2Stream;
public void id(int id) {
checkArgument(id != ABSENT_ID, "Can't use absent id");
this.id = id;
}
/**
* Sets the underlying Netty {@link Http2Stream} for this stream. This must be called in the
* context of the transport thread.
*/
public void setHttp2Stream(Http2Stream http2Stream) {
checkNotNull(http2Stream, "http2Stream");
checkState(this.http2Stream == null, "Can only set http2Stream once");
this.http2Stream = http2Stream;
// Now that the stream has actually been initialized, call the listener's onReady callback if
// appropriate.
onStreamAllocated();
}
/**
* Gets the underlying Netty {@link Http2Stream} for this stream.
*/
@Nullable
public Http2Stream http2Stream() {
return http2Stream;
}
void transportHeadersReceived(Http2Headers headers, boolean endOfStream) {
if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers));
} else {
transportHeadersReceived(Utils.convertHeaders(headers));
public TransportState(NettyClientHandler handler, int maxMessageSize) {
super(maxMessageSize);
this.handler = checkNotNull(handler, "handler");
}
}
void transportDataReceived(ByteBuf frame, boolean endOfStream) {
transportDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
@Override
protected void sendCancel(Status reason) {
// Send the cancel command to the handler.
writeQueue.enqueue(new CancelClientStreamCommand(this, reason), true);
}
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
ByteBuf bytebuf = frame == null ? EMPTY_BUFFER : ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(this, bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
}
}), flush);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(new SendGrpcFrameCommand(this, bytebuf, endOfStream), flush);
@Override
public int id() {
return id;
}
}
@Override
protected void returnProcessedBytes(int processedBytes) {
handler.returnProcessedBytes(http2Stream, processedBytes);
writeQueue.scheduleFlush();
public void setId(int id) {
checkArgument(id > 0, "id must be positive");
this.id = id;
}
/**
* Sets the underlying Netty {@link Http2Stream} for this stream. This must be called in the
* context of the transport thread.
*/
public void setHttp2Stream(Http2Stream http2Stream) {
checkNotNull(http2Stream, "http2Stream");
checkState(this.http2Stream == null, "Can only set http2Stream once");
this.http2Stream = http2Stream;
// Now that the stream has actually been initialized, call the listener's onReady callback if
// appropriate.
onStreamAllocated();
}
/**
* Gets the underlying Netty {@link Http2Stream} for this stream.
*/
@Nullable
public Http2Stream http2Stream() {
return http2Stream;
}
/**
* Intended to be overriden by NettyClientTransport, which has more information about failures.
* May only be called from event loop.
*/
protected abstract Status statusFromFailedFuture(ChannelFuture f);
@Override
protected void http2ProcessingFailed(Status status, Metadata trailers) {
transportReportStatus(status, false, trailers);
handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true);
}
@Override
public void bytesRead(int processedBytes) {
handler.returnProcessedBytes(http2Stream, processedBytes);
handler.getWriteQueue().scheduleFlush();
}
@Override
protected void deframeFailed(Throwable cause) {
http2ProcessingFailed(Status.fromThrowable(cause), new Metadata());
}
void transportHeadersReceived(Http2Headers headers, boolean endOfStream) {
if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers));
} else {
transportHeadersReceived(Utils.convertHeaders(headers));
}
}
void transportDataReceived(ByteBuf frame, boolean endOfStream) {
transportDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
}
}

View File

@ -120,13 +120,14 @@ class NettyClientTransport implements ConnectionClientTransport {
callOptions) {
Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers");
return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority,
negotiationHandler.scheme(), userAgent) {
@Override
protected Status statusFromFailedFuture(ChannelFuture f) {
return NettyClientTransport.this.statusFromFailedFuture(f);
}
};
return new NettyClientStream(
new NettyClientStream.TransportState(handler, maxMessageSize) {
@Override
protected Status statusFromFailedFuture(ChannelFuture f) {
return NettyClientTransport.this.statusFromFailedFuture(f);
}
},
method, headers, channel, authority, negotiationHandler.scheme(), userAgent);
}
@Override

View File

@ -32,7 +32,6 @@
package io.grpc.netty;
import io.grpc.internal.AbstractStream2;
import io.grpc.internal.Stream;
/**
* Command which requests messages from the deframer.
@ -40,26 +39,14 @@ import io.grpc.internal.Stream;
class RequestMessagesCommand extends WriteQueue.AbstractQueuedCommand {
private final int numMessages;
private final Stream stream;
private final AbstractStream2.TransportState state;
public RequestMessagesCommand(Stream stream, int numMessages) {
this.state = null;
this.numMessages = numMessages;
this.stream = stream;
}
public RequestMessagesCommand(AbstractStream2.TransportState state, int numMessages) {
this.state = state;
this.numMessages = numMessages;
this.stream = null;
}
void requestMessages() {
if (stream != null) {
stream.request(numMessages);
} else {
state.requestMessagesFromDeframer(numMessages);
}
state.requestMessagesFromDeframer(numMessages);
}
}

View File

@ -32,6 +32,7 @@
package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTPS;
@ -41,6 +42,7 @@ import static io.grpc.netty.Utils.TE_HEADER;
import static io.grpc.netty.Utils.TE_TRAILERS;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@ -49,17 +51,16 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.calls;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.base.Ticker;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientTransport;
import io.grpc.internal.ClientTransport.PingCallback;
import io.grpc.internal.GrpcUtil;
@ -87,19 +88,17 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.InputStream;
/**
* Tests for {@link NettyClientHandler}.
*/
@RunWith(JUnit4.class)
public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHandler> {
// TODO(zhangkun83): mocking concrete classes is not safe. Consider making NettyClientStream an
// interface.
@Mock
private NettyClientStream stream;
private NettyClientStream.TransportState streamTransportState;
private Http2Headers grpcHeaders;
private long nanoTime; // backs a ticker, for testing ping round-trip time measurement
private int flowControlWindow = DEFAULT_WINDOW_SIZE;
@ -108,6 +107,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Mock
private NettyClientTransport.Listener listener;
@Mock
private ClientStreamListener streamListener;
/**
* Set up for test.
@ -118,6 +119,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
lifecycleManager = new ClientTransportLifecycleManager(listener);
initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));
streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
streamTransportState.setListener(streamListener);
grpcHeaders = new DefaultHttp2Headers()
.scheme(HTTPS)
@ -138,15 +141,14 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
// Force the stream to be buffered.
receiveMaxConcurrentStreams(0);
// Create a new stream with id 3.
ChannelFuture createFuture = enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(eq(3));
when(stream.id()).thenReturn(3);
ChannelFuture createFuture = enqueue(
new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(3, streamTransportState.id());
// Cancel the stream.
cancelStream(Status.CANCELLED);
assertTrue(createFuture.isSuccess());
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true),
any(Metadata.class));
verify(streamListener).closed(eq(Status.CANCELLED), any(Metadata.class));
}
@Test
@ -224,7 +226,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
createStream();
// Send a frame and verify that it was written.
ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream, content(), true));
ChannelFuture future = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(3), eq(content()), eq(0), eq(true),
@ -233,43 +235,43 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test
public void sendForUnknownStreamShouldFail() throws Exception {
when(stream.id()).thenReturn(3);
ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream, content(), true));
ChannelFuture future = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
assertTrue(future.isDone());
assertFalse(future.isSuccess());
}
@Test
public void inboundHeadersShouldForwardToStream() throws Exception {
public void inboundShouldForwardToStream() throws Exception {
createStream();
// Read a headers frame first.
Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC);
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.set(as("magic"), as("value"));
ByteBuf headersFrame = headersFrame(3, headers);
channelRead(headersFrame);
verify(stream).transportHeadersReceived(headers, false);
}
ArgumentCaptor<Metadata> captor = ArgumentCaptor.forClass(Metadata.class);
verify(streamListener).headersRead(captor.capture());
assertEquals("value",
captor.getValue().get(Metadata.Key.of("magic", Metadata.ASCII_STRING_MARSHALLER)));
@Test
public void inboundDataShouldForwardToStream() throws Exception {
ByteBuf data = content().copy();
createStream();
streamTransportState.requestMessagesFromDeframer(1);
// Create a data frame and then trigger the handler to read it.
// Need to retain to simulate what is done by the stream.
ByteBuf frame = dataFrame(3, false).retain();
ByteBuf frame = grpcDataFrame(3, false, contentAsArray());
channelRead(frame);
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(stream).transportDataReceived(captor.capture(), eq(false));
assertTrue(ByteBufUtil.equals(data, captor.getValue()));
ArgumentCaptor<InputStream> isCaptor = ArgumentCaptor.forClass(InputStream.class);
verify(streamListener).messageRead(isCaptor.capture());
assertArrayEquals(ByteBufUtil.getBytes(content()),
ByteStreams.toByteArray(isCaptor.getValue()));
isCaptor.getValue().close();
}
@Test
public void receivedGoAwayShouldCancelBufferedStream() throws Exception {
// Force the stream to be buffered.
receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream));
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
channelRead(goAwayFrame(0));
assertTrue(future.isDone());
assertFalse(future.isSuccess());
@ -280,13 +282,12 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test
public void receivedGoAwayShouldFailUnknownStreams() throws Exception {
enqueue(new CreateStreamCommand(grpcHeaders, stream));
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
// Read a GOAWAY that indicates our stream was never processed by the server.
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(stream).transportReportStatus(captor.capture(), eq(false),
notNull(Metadata.class));
verify(streamListener).closed(captor.capture(), notNull(Metadata.class));
assertEquals(Status.CANCELLED.getCode(), captor.getValue().getCode());
assertEquals("HTTP/2 error code: CANCEL\nReceived Goaway\nthis is a test",
captor.getValue().getDescription());
@ -296,7 +297,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
public void receivedGoAwayShouldFailUnknownBufferedStreams() throws Exception {
receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream));
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
// Read a GOAWAY that indicates our stream was never processed by the server.
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
@ -314,7 +315,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
// Now try to create a stream.
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream));
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertTrue(future.isDone());
assertFalse(future.isSuccess());
Status status = Status.fromThrowable(future.cause());
@ -326,19 +327,17 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test
public void cancelStreamShouldCreateAndThenFailBufferedStream() throws Exception {
receiveMaxConcurrentStreams(0);
enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(3);
when(stream.id()).thenReturn(3);
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(3, streamTransportState.id());
cancelStream(Status.CANCELLED);
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true),
any(Metadata.class));
verify(streamListener).closed(eq(Status.CANCELLED), any(Metadata.class));
}
@Test
public void channelShutdownShouldCancelBufferedStreams() throws Exception {
// Force a stream to get added to the pending queue.
receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream));
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
handler().channelInactive(ctx());
assertTrue(future.isDone());
@ -351,9 +350,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
handler().channelInactive(ctx());
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
InOrder inOrder = inOrder(stream);
inOrder.verify(stream, calls(1)).transportReportStatus(captor.capture(), eq(false),
notNull(Metadata.class));
verify(streamListener).closed(captor.capture(), notNull(Metadata.class));
assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode());
}
@ -375,12 +372,18 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test
public void createIncrementsIdsForActualAndBufferdStreams() throws Exception {
receiveMaxConcurrentStreams(2);
enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(eq(3));
enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(eq(5));
enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(eq(7));
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(3, streamTransportState.id());
streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
streamTransportState.setListener(streamListener);
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(5, streamTransportState.id());
streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
streamTransportState.setListener(streamListener);
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(7, streamTransportState.id());
}
@Test
@ -467,7 +470,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class));
ByteBuf payload = captor.getValue();
channelRead(dataFrame(3, false));
channelRead(grpcDataFrame(3, false, contentAsArray()));
long pingData = handler().flowControlPing().payload();
ByteBuf buffer = handler().ctx().alloc().buffer(8);
buffer.writeLong(pingData);
@ -507,18 +510,13 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
channelRead(serializedSettings);
}
private ByteBuf dataFrame(int streamId, boolean endStream) {
return dataFrame(streamId, endStream, content());
}
private ChannelFuture createStream() throws Exception {
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream));
when(stream.id()).thenReturn(3);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
return future;
}
private ChannelFuture cancelStream(Status status) throws Exception {
return enqueue(new CancelClientStreamCommand(stream, status));
return enqueue(new CancelClientStreamCommand(streamTransportState, status));
}
@Override
@ -567,4 +565,15 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
this.failureCause = cause;
}
}
class TransportStateImpl extends NettyClientStream.TransportState {
public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
super(handler, maxMessageSize);
}
@Override
protected Status statusFromFailedFuture(ChannelFuture f) {
return Utils.statusFromThrowable(f.cause());
}
}
}

View File

@ -51,6 +51,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableListMultimap;
@ -63,7 +64,6 @@ import io.grpc.internal.GrpcUtil;
import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
@ -109,16 +109,15 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void closeShouldSucceed() {
// Force stream creation.
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
stream().halfClose();
assertTrue(stream().canReceive());
assertFalse(stream().canSend());
verifyNoMoreInteractions(listener);
}
@Test
public void cancelShouldSendCommand() {
// Set stream id to indicate it has been created
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
stream().cancel(Status.CANCELLED);
ArgumentCaptor<CancelClientStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelClientStreamCommand.class);
@ -129,7 +128,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void deadlineExceededCancelShouldSendCommand() {
// Set stream id to indicate it has been created
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
stream().cancel(Status.DEADLINE_EXCEEDED);
ArgumentCaptor<CancelClientStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelClientStreamCommand.class);
@ -146,12 +145,12 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void writeMessageShouldSendRequest() throws Exception {
// Force stream creation.
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)),
any(ChannelPromise.class),
eq(true));
}
@ -159,112 +158,109 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void writeMessageShouldSendRequestUnknownLength() throws Exception {
// Force stream creation.
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
byte[] msg = smallMessage();
stream.writeMessage(new BufferedInputStream(new ByteArrayInputStream(msg)));
stream.flush();
// Two writes occur, one for the GRPC frame header and the second with the payload
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE).slice(0, 5), false)),
eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(0, 5), false)),
any(ChannelPromise.class),
eq(false));
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE).slice(5, 11), false)),
eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(5, 11), false)),
any(ChannelPromise.class),
eq(true));
}
@Test
public void setStatusWithOkShouldCloseStream() {
stream().id(STREAM_ID);
stream().transportReportStatus(Status.OK, true, new Metadata());
stream().transportState().setId(STREAM_ID);
stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
verify(listener).closed(same(Status.OK), any(Metadata.class));
assertTrue(stream.isClosed());
}
@Test
public void setStatusWithErrorShouldCloseStream() {
Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata());
stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
verify(listener).closed(eq(errorStatus), any(Metadata.class));
assertTrue(stream.isClosed());
}
@Test
public void setStatusWithOkShouldNotOverrideError() {
Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata());
stream().transportReportStatus(Status.OK, true, new Metadata());
stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
verify(listener).closed(any(Status.class), any(Metadata.class));
assertTrue(stream.isClosed());
}
@Test
public void setStatusWithErrorShouldNotOverridePreviousError() {
Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata());
stream().transportReportStatus(Status.fromThrowable(new RuntimeException("fake")), true,
new Metadata());
stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
stream().transportState().transportReportStatus(
Status.fromThrowable(new RuntimeException("fake")), true, new Metadata());
verify(listener).closed(any(Status.class), any(Metadata.class));
assertTrue(stream.isClosed());
}
@Override
@Test
public void inboundMessageShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response.
stream().id(STREAM_ID);
stream().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportState().setId(STREAM_ID);
stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
super.inboundMessageShouldCallListener();
}
@Test
public void inboundHeadersShouldCallListenerHeadersRead() throws Exception {
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
Http2Headers headers = grpcResponseHeaders();
stream().transportHeadersReceived(headers, false);
stream().transportState().transportHeadersReceived(headers, false);
verify(listener).headersRead(any(Metadata.class));
}
@Test
public void inboundTrailersClosesCall() throws Exception {
stream().id(STREAM_ID);
stream().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportState().setId(STREAM_ID);
stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
super.inboundMessageShouldCallListener();
stream().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);
stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);
}
@Test
public void inboundStatusShouldSetStatus() throws Exception {
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
// Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture(), any(Metadata.class));
assertEquals(Status.INTERNAL.getCode(), captor.getValue().getCode());
assertTrue(stream.isClosed());
}
@Test
public void invalidInboundHeadersCancelStream() throws Exception {
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
Http2Headers headers = grpcResponseHeaders();
headers.set("random", "4");
headers.remove(CONTENT_TYPE_HEADER);
// Remove once b/16290036 is fixed.
headers.status(new AsciiString("500"));
stream().transportHeadersReceived(headers, false);
stream().transportState().transportHeadersReceived(headers, false);
verify(listener, never()).closed(any(Status.class), any(Metadata.class));
// We are now waiting for 100 bytes of error context on the stream, cancel has not yet been
// sent
verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(100).writeZero(100), false);
stream().transportState().transportDataReceived(Unpooled.buffer(100).writeZero(100), false);
verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false);
stream().transportState().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false);
// Now verify that cancel is sent and an error is reported to the listener
verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true));
@ -274,20 +270,19 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode());
assertEquals("4", metadataCaptor.getValue()
.get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER)));
assertTrue(stream.isClosed());
}
@Test
public void invalidInboundContentTypeShouldCancelStream() {
// Set stream id to indicate it has been created
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK).set(CONTENT_TYPE_HEADER,
new AsciiString("application/bad", UTF_8));
stream().transportHeadersReceived(headers, false);
stream().transportState().transportHeadersReceived(headers, false);
Http2Headers trailers = new DefaultHttp2Headers()
.set(new AsciiString("grpc-status", UTF_8), new AsciiString("0", UTF_8));
stream().transportHeadersReceived(trailers, true);
stream().transportState().transportHeadersReceived(trailers, true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(listener).closed(captor.capture(), metadataCaptor.capture());
@ -300,7 +295,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void nonGrpcResponseShouldSetStatus() throws Exception {
stream().transportDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true);
stream().transportState().transportDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture(), any(Metadata.class));
assertEquals(Status.Code.INTERNAL, captor.getValue().getCode());
@ -308,13 +303,13 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void deframedDataAfterCancelShouldBeIgnored() throws Exception {
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
// Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
// Receive 2 consecutive empty frames. Only one is delivered at a time to the listener.
stream().transportDataReceived(simpleGrpcFrame(), false);
stream().transportDataReceived(simpleGrpcFrame(), false);
stream().transportState().transportDataReceived(simpleGrpcFrame(), false);
stream().transportState().transportDataReceived(simpleGrpcFrame(), false);
// Only allow the first to be delivered.
stream().request(1);
@ -323,14 +318,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
// data frames have been processed. Since cancellation will interrupt message delivery,
// this status will never be processed and the listener will instead only see the
// cancellation.
stream().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
// Verify that the first was delivered.
verify(listener).messageRead(any(InputStream.class));
// Now set the error status.
Metadata trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED));
stream().transportReportStatus(Status.CANCELLED, true, trailers);
stream().transportState().transportReportStatus(Status.CANCELLED, true, trailers);
// Now allow the delivery of the second.
stream().request(1);
@ -342,14 +337,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test
public void dataFrameWithEosShouldDeframeAndThenFail() {
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
stream().request(1);
// Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
// Receive a DATA frame with EOS set.
stream().transportDataReceived(simpleGrpcFrame(), true);
stream().transportState().transportDataReceived(simpleGrpcFrame(), true);
// Verify that the message was delivered.
verify(listener).messageRead(any(InputStream.class));
@ -363,14 +358,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
public void setHttp2StreamShouldNotifyReady() {
listener = mock(ClientStreamListener.class);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler,
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"),
AsciiString.of("agent"));
stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("http"), AsciiString.of("agent"));
stream.start(listener);
stream().id(STREAM_ID);
stream().transportState().setId(STREAM_ID);
verify(listener, never()).onReady();
assertFalse(stream.isReady());
stream().setHttp2Stream(http2Stream);
stream().transportState().setHttp2Stream(http2Stream);
verify(listener).onReady();
assertTrue(stream.isReady());
}
@ -383,9 +378,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
Mockito.reset(writeQueue);
when(writeQueue.enqueue(any(QueuedCommand.class), any(boolean.class))).thenReturn(future);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler,
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"),
AsciiString.of("good agent"));
stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("http"), AsciiString.of("good agent"));
stream.start(listener);
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
@ -407,14 +402,12 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
}
}).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
NettyClientStream stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel,
handler, DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"),
AsciiString.of("agent"));
NettyClientStream stream = new NettyClientStream(
new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor, new Metadata(),
channel, AsciiString.of("localhost"), AsciiString.of("http"), AsciiString.of("agent"));
stream.start(listener);
assertTrue(stream.canSend());
assertTrue(stream.canReceive());
stream.id(STREAM_ID);
stream.setHttp2Stream(http2Stream);
stream.transportState().setId(STREAM_ID);
stream.transportState().setHttp2Stream(http2Stream);
reset(listener);
return stream;
}
@ -447,11 +440,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
return Utils.convertTrailers(trailers, true);
}
class NettyClientStreamImpl extends NettyClientStream {
NettyClientStreamImpl(MethodDescriptor<?, ?> method, Metadata headers, Channel channel,
NettyClientHandler handler, int maxMessageSize, AsciiString authority, AsciiString scheme,
AsciiString userAgent) {
super(method, headers, channel, handler, maxMessageSize, authority, scheme, userAgent);
class TransportStateImpl extends NettyClientStream.TransportState {
public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
super(handler, maxMessageSize);
}
@Override

View File

@ -177,8 +177,8 @@ public class NettyClientTransportTest {
} catch (ExecutionException e) {
Status status = Status.fromThrowable(e);
assertEquals(INTERNAL, status.getCode());
System.err.println(status.getDescription());
assertTrue(status.getDescription().contains("deframing"));
assertTrue("Missing exceeds maximum from: " + status.getDescription(),
status.getDescription().contains("exceeds maximum"));
}
}

View File

@ -39,7 +39,10 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.internal.MessageFramer;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
@ -68,6 +71,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;
import java.io.ByteArrayInputStream;
/**
* Base class for Netty handler unit tests.
*/
@ -144,6 +149,25 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
handler().channelRead(ctx, obj);
}
protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
final ByteBuf compressionFrame = Unpooled.buffer(content.length);
MessageFramer framer = new MessageFramer(new MessageFramer.Sink() {
@Override
public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
}
}
}, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT));
framer.writePayload(new ByteArrayInputStream(content));
framer.flush();
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
}
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
// Need to retain the content since the frameWriter releases it.
content.retain();

View File

@ -59,20 +59,16 @@ import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.MessageFramer;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.WritableBuffer;
import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2CodecUtil;
import io.netty.handler.codec.http2.Http2Error;
@ -91,7 +87,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
/**
@ -159,11 +154,12 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
stream.request(1);
// Create a data frame and then trigger the handler to read it.
ByteBuf frame = dataFrame(STREAM_ID, endStream);
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
channelRead(frame);
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(streamListener).messageRead(captor.capture());
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(captor.getValue()));
captor.getValue().close();
if (endStream) {
verify(streamListener).halfClosed();
@ -333,33 +329,10 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
stream = streamCaptor.getValue();
}
private ByteBuf dataFrame(int streamId, boolean endStream) {
byte[] contentAsArray = contentAsArray();
final ByteBuf compressionFrame = Unpooled.buffer(contentAsArray.length);
MessageFramer framer = new MessageFramer(new MessageFramer.Sink() {
@Override
public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
}
}
}, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT));
framer.writePayload(new ByteArrayInputStream(contentAsArray));
framer.flush();
if (endStream) {
framer.close();
}
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
}
private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
ByteBuf buf = NettyTestUtil.messageFrame("");
try {
return super.dataFrame(streamId, endStream, buf);
return dataFrame(streamId, endStream, buf);
} finally {
buf.release();
}

View File

@ -138,7 +138,8 @@ public abstract class NettyStreamTestBase<T extends Stream> {
((NettyServerStream) stream).transportState()
.inboundDataReceived(messageFrame(MESSAGE), false);
} else {
((NettyClientStream) stream).transportDataReceived(messageFrame(MESSAGE), false);
((NettyClientStream) stream).transportState()
.transportDataReceived(messageFrame(MESSAGE), false);
}
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(listener()).messageRead(captor.capture());