core,netty: split stream state on client-side; AbstractStream2

OkHttp will need to be migrated in a future commit.
This commit is contained in:
Eric Anderson 2016-07-17 23:15:05 -07:00
parent 93dd02ca9c
commit cad7124c27
17 changed files with 1193 additions and 313 deletions

View File

@ -0,0 +1,303 @@
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Metadata;
import io.grpc.Status;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
* The abstract base class for {@link ClientStream} implementations. Extending classes only need to
* implement {@link #transportState()} and {@link #abstractClientStreamSink()}. Must only be called
* from the sending application thread.
*/
public abstract class AbstractClientStream2 extends AbstractStream2
implements ClientStream, MessageFramer.Sink {
private static final Logger log = Logger.getLogger(AbstractClientStream2.class.getName());
/**
* A sink for outbound operations, separated from the stream simply to avoid name
* collisions/confusion. Only called from application thread.
*/
protected interface Sink {
/**
* Sends an outbound frame to the remote end point.
*
* @param frame a buffer containing the chunk of data to be sent, or {@code null} if {@code
* endOfStream} with no data to send
* @param endOfStream {@code true} if this is the last frame; {@code flush} is guaranteed to be
* {@code true} if this is {@code true}
* @param flush {@code true} if more data may not be arriving soon
*/
void writeFrame(@Nullable WritableBuffer frame, boolean endOfStream, boolean flush);
/**
* Requests up to the given number of messages from the call to be delivered to the client. This
* should end up triggering {@link TransportState#requestMessagesFromDeframer(int)} on the
* transport thread.
*/
void request(int numMessages);
/**
* Tears down the stream, typically in the event of a timeout. This method may be called
* multiple times and from any thread.
*
* <p>This is a clone of {@link ClientStream#cancel()}; {@link AbstractClientStream2#cancel}
* delegates to this method.
*/
void cancel(Status status);
}
private final MessageFramer framer;
private boolean outboundClosed;
/**
* Whether cancel() has been called. This is not strictly necessary, but removes the delay between
* cancel() being called and isReady() beginning to return false, since cancel is commonly
* processed asynchronously.
*/
private volatile boolean cancelled;
protected AbstractClientStream2(WritableBufferAllocator bufferAllocator) {
framer = new MessageFramer(this, bufferAllocator);
}
/** {@inheritDoc} */
@Override
protected abstract TransportState transportState();
@Override
public void start(ClientStreamListener listener) {
transportState().setListener(listener);
}
/**
* Sink for transport to be called to perform outbound operations. Each stream must have its own
* unique sink.
*/
protected abstract Sink abstractClientStreamSink();
@Override
protected final MessageFramer framer() {
return framer;
}
@Override
public final void request(int numMessages) {
abstractClientStreamSink().request(numMessages);
}
@Override
public final void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
Preconditions.checkArgument(frame != null || endOfStream, "null frame before EOS");
abstractClientStreamSink().writeFrame(frame, endOfStream, flush);
}
@Override
public final void halfClose() {
if (!outboundClosed) {
outboundClosed = true;
endOfMessages();
}
}
@Override
public final void cancel(Status reason) {
Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status");
cancelled = true;
abstractClientStreamSink().cancel(reason);
}
@Override
public final boolean isReady() {
return super.isReady() && !cancelled;
}
/** This should only called from the transport thread. */
protected abstract static class TransportState extends AbstractStream2.TransportState {
/** Whether listener.closed() has been called. */
private boolean listenerClosed;
private ClientStreamListener listener;
private Runnable deliveryStalledTask;
private boolean headersReceived;
/**
* Whether the stream is closed from the transport's perspective. This can differ from {@link
* #listenerClosed} because there may still be messages buffered to deliver to the application.
*/
private boolean statusReported;
protected TransportState(int maxMessageSize) {
super(maxMessageSize);
}
@VisibleForTesting
public final void setListener(ClientStreamListener listener) {
Preconditions.checkState(this.listener == null, "Already called setListener");
this.listener = Preconditions.checkNotNull(listener, "listener");
}
@Override
public final void deliveryStalled() {
if (deliveryStalledTask != null) {
deliveryStalledTask.run();
deliveryStalledTask = null;
}
}
@Override
public final void endOfStream() {
deliveryStalled();
}
@Override
protected final ClientStreamListener listener() {
return listener;
}
/**
* Called by transport implementations when they receive headers.
*
* @param headers the parsed headers
*/
protected void inboundHeadersReceived(Metadata headers) {
Preconditions.checkState(!statusReported, "Received headers on closed stream");
headersReceived = true;
listener().headersRead(headers);
}
/**
* Processes the contents of a received data frame from the server.
*
* @param frame the received data frame. Its ownership is transferred to this method.
*/
protected void inboundDataReceived(ReadableBuffer frame) {
Preconditions.checkNotNull(frame, "frame");
boolean needToCloseFrame = true;
try {
if (statusReported) {
log.log(Level.INFO, "Received data on closed stream");
return;
}
if (!headersReceived) {
transportReportStatus(
Status.INTERNAL.withDescription("headers not received before payload"),
false, new Metadata());
return;
}
needToCloseFrame = false;
deframe(frame, false);
} finally {
if (needToCloseFrame) {
frame.close();
}
}
}
/**
* Processes the trailers and status from the server.
*
* @param trailers the received trailers
* @param status the status extracted from the trailers
*/
protected void inboundTrailersReceived(Metadata trailers, Status status) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
if (statusReported) {
log.log(Level.INFO, "Received trailers on closed stream:\n {1}\n {2}",
new Object[]{status, trailers});
return;
}
transportReportStatus(status, false, trailers);
}
/**
* Report stream closure with status to the application layer if not already reported. This
* method must be called from the transport thread.
*
* @param status the new status to set
* @param stopDelivery if {@code true}, interrupts any further delivery of inbound messages that
* may already be queued up in the deframer. If {@code false}, the listener will be
* notified immediately after all currently completed messages in the deframer have been
* delivered to the application.
* @param trailers new instance of {@code Trailers}, either empty or those returned by the
* server
*/
public final void transportReportStatus(final Status status, boolean stopDelivery,
final Metadata trailers) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
// If stopDelivery, we continue in case previous invocation is waiting for stall
if (statusReported && !stopDelivery) {
return;
}
statusReported = true;
onStreamDeallocated();
// If not stopping delivery, then we must wait until the deframer is stalled (i.e., it has no
// complete messages to deliver).
if (stopDelivery || isDeframerStalled()) {
deliveryStalledTask = null;
closeListener(status, trailers);
} else {
deliveryStalledTask = new Runnable() {
@Override
public void run() {
closeListener(status, trailers);
}
};
}
}
/**
* Closes the listener if not previously closed.
*
* @throws IllegalStateException if the call has not yet been started.
*/
private void closeListener(Status status, Metadata trailers) {
if (!listenerClosed) {
listenerClosed = true;
closeDeframer();
listener().closed(status, trailers);
}
}
}
}

View File

@ -43,7 +43,6 @@ import javax.annotation.Nullable;
* Abstract base class for {@link ServerStream} implementations. Extending classes only need to * Abstract base class for {@link ServerStream} implementations. Extending classes only need to
* implement {@link #transportState()} and {@link #abstractServerStreamSink()}. Must only be called * implement {@link #transportState()} and {@link #abstractServerStreamSink()}. Must only be called
* from the sending application thread. * from the sending application thread.
*/ */
public abstract class AbstractServerStream extends AbstractStream2 public abstract class AbstractServerStream extends AbstractStream2
implements ServerStream, MessageFramer.Sink { implements ServerStream, MessageFramer.Sink {
@ -158,6 +157,11 @@ public abstract class AbstractServerStream extends AbstractStream2
abstractServerStreamSink().cancel(status); abstractServerStreamSink().cancel(status);
} }
@Override
public final boolean isReady() {
return super.isReady();
}
@Override public Attributes attributes() { @Override public Attributes attributes() {
return Attributes.EMPTY; return Attributes.EMPTY;
} }
@ -241,6 +245,7 @@ public abstract class AbstractServerStream extends AbstractStream2
private void closeListener(Status newStatus) { private void closeListener(Status newStatus) {
if (!listenerClosed) { if (!listenerClosed) {
listenerClosed = true; listenerClosed = true;
onStreamDeallocated();
closeDeframer(); closeDeframer();
listener().closed(newStatus); listener().closed(newStatus);
} }

View File

@ -97,7 +97,7 @@ public abstract class AbstractStream2 implements Stream {
} }
@Override @Override
public final boolean isReady() { public boolean isReady() {
if (framer().isClosed()) { if (framer().isClosed()) {
return false; return false;
} }
@ -139,6 +139,12 @@ public abstract class AbstractStream2 implements Stream {
*/ */
@GuardedBy("onReadyLock") @GuardedBy("onReadyLock")
private boolean allocated; private boolean allocated;
/**
* Indicates that the stream no longer exists for the transport. Implies that the application
* should be discouraged from sending, because doing so would have no effect.
*/
@GuardedBy("onReadyLock")
private boolean deallocated;
protected TransportState(int maxMessageSize) { protected TransportState(int maxMessageSize) {
deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize); deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize);
@ -174,6 +180,13 @@ public abstract class AbstractStream2 implements Stream {
deframer.close(); deframer.close();
} }
/**
* Indicates whether delivery is currently stalled, pending receipt of more data.
*/
protected final boolean isDeframerStalled() {
return deframer.isStalled();
}
/** /**
* Called to parse a received frame and attempt delivery of any completed * Called to parse a received frame and attempt delivery of any completed
* messages. Must be called from the transport thread. * messages. Must be called from the transport thread.
@ -214,7 +227,7 @@ public abstract class AbstractStream2 implements Stream {
private boolean isReady() { private boolean isReady() {
synchronized (onReadyLock) { synchronized (onReadyLock) {
return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD && !deallocated;
} }
} }
@ -233,6 +246,19 @@ public abstract class AbstractStream2 implements Stream {
notifyIfReady(); notifyIfReady();
} }
/**
* Notify that the stream does not exist in a usable state any longer. This causes {@link
* AbstractStream2#isReady()} to return {@code false} from this point forward.
*
* <p>This does not generally need to be called explicitly by the transport, as it is handled
* implicitly by {@link AbstractClientStream2} and {@link AbstractServerStream}.
*/
protected final void onStreamDeallocated() {
synchronized (onReadyLock) {
deallocated = true;
}
}
/** /**
* Event handler to be called by the subclass when a number of bytes are being queued for * Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint. * sending to the remote endpoint.
@ -256,6 +282,8 @@ public abstract class AbstractStream2 implements Stream {
public final void onSentBytes(int numBytes) { public final void onSentBytes(int numBytes) {
boolean doNotify; boolean doNotify;
synchronized (onReadyLock) { synchronized (onReadyLock) {
checkState(allocated,
"onStreamAllocated was not called, but it seems the stream is active");
boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
numSentBytesQueued -= numBytes; numSentBytesQueued -= numBytes;
boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;

View File

@ -0,0 +1,251 @@
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.Status;
import java.nio.charset.Charset;
import javax.annotation.Nullable;
/**
* Base implementation for client streams using HTTP2 as the transport.
*/
public abstract class Http2ClientStreamTransportState extends AbstractClientStream2.TransportState {
/**
* Metadata marshaller for HTTP status lines.
*/
private static final InternalMetadata.TrustedAsciiMarshaller<Integer> HTTP_STATUS_MARSHALLER =
new InternalMetadata.TrustedAsciiMarshaller<Integer>() {
@Override
public byte[] toAsciiString(Integer value) {
throw new UnsupportedOperationException();
}
/**
* RFC 7231 says status codes are 3 digits long.
*
* @see: <a href="https://tools.ietf.org/html/rfc7231#section-6">RFC 7231</a>
*/
@Override
public Integer parseAsciiString(byte[] serialized) {
if (serialized.length >= 3) {
return (serialized[0] - '0') * 100 + (serialized[1] - '0') * 10 + (serialized[2] - '0');
}
throw new NumberFormatException(
"Malformed status code " + new String(serialized, InternalMetadata.US_ASCII));
}
};
private static final Metadata.Key<Integer> HTTP2_STATUS = InternalMetadata.keyOf(":status",
HTTP_STATUS_MARSHALLER);
/** When non-{@code null}, {@link #transportErrorMetadata} must also be non-{@code null}. */
private Status transportError;
private Metadata transportErrorMetadata;
private Charset errorCharset = Charsets.UTF_8;
private boolean contentTypeChecked;
protected Http2ClientStreamTransportState(int maxMessageSize) {
super(maxMessageSize);
}
/**
* Called to process a failure in HTTP/2 processing. It should notify the transport to cancel the
* stream and call {@code transportReportStatus()}.
*/
protected abstract void http2ProcessingFailed(Status status, Metadata trailers);
/**
* Called by subclasses whenever {@code Headers} are received from the transport.
*
* @param headers the received headers
*/
protected void transportHeadersReceived(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
if (transportError != null) {
// Already received a transport error so just augment it.
transportError = transportError.augmentDescription(headers.toString());
return;
}
Status httpStatus = statusFromHttpStatus(headers);
if (httpStatus == null) {
transportError = Status.INTERNAL.withDescription(
"received non-terminal headers with no :status");
} else if (!httpStatus.isOk()) {
transportError = httpStatus;
} else {
transportError = checkContentType(headers);
}
if (transportError != null) {
// Note we don't immediately report the transport error, instead we wait for more data on the
// stream so we can accumulate more detail into the error before reporting it.
transportError = transportError.augmentDescription("\n" + headers);
transportErrorMetadata = headers;
errorCharset = extractCharset(headers);
} else {
stripTransportDetails(headers);
inboundHeadersReceived(headers);
}
}
/**
* Called by subclasses whenever a data frame is received from the transport.
*
* @param frame the received data frame
* @param endOfStream {@code true} if there will be no more data received for this stream
*/
protected void transportDataReceived(ReadableBuffer frame, boolean endOfStream) {
if (transportError != null) {
// We've already detected a transport error and now we're just accumulating more detail
// for it.
transportError = transportError.augmentDescription("DATA-----------------------------\n"
+ ReadableBuffers.readAsString(frame, errorCharset));
frame.close();
if (transportError.getDescription().length() > 1000 || endOfStream) {
http2ProcessingFailed(transportError, transportErrorMetadata);
}
} else {
inboundDataReceived(frame);
if (endOfStream) {
// This is a protocol violation as we expect to receive trailers.
transportError =
Status.INTERNAL.withDescription("Received unexpected EOS on DATA frame from server.");
transportErrorMetadata = new Metadata();
transportReportStatus(transportError, false, transportErrorMetadata);
}
}
}
/**
* Called by subclasses for the terminal trailer metadata on a stream.
*
* @param trailers the received terminal trailer metadata
*/
protected void transportTrailersReceived(Metadata trailers) {
Preconditions.checkNotNull(trailers, "trailers");
if (transportError != null) {
// Already received a transport error so just augment it.
transportError = transportError.augmentDescription(trailers.toString());
} else {
transportError = checkContentType(trailers);
transportErrorMetadata = trailers;
}
if (transportError != null) {
http2ProcessingFailed(transportError, transportErrorMetadata);
} else {
Status status = statusFromTrailers(trailers);
stripTransportDetails(trailers);
inboundTrailersReceived(trailers, status);
}
}
private static Status statusFromHttpStatus(Metadata metadata) {
Integer httpStatus = metadata.get(HTTP2_STATUS);
if (httpStatus != null) {
Status status = GrpcUtil.httpStatusToGrpcStatus(httpStatus);
return status.isOk() ? status
: status.augmentDescription("extracted status from HTTP :status " + httpStatus);
}
return null;
}
/**
* Extract the response status from trailers.
*/
private static Status statusFromTrailers(Metadata trailers) {
Status status = trailers.get(Status.CODE_KEY);
if (status == null) {
status = statusFromHttpStatus(trailers);
if (status == null || status.isOk()) {
status = Status.UNKNOWN.withDescription("missing GRPC status in response");
} else {
status = status.withDescription(
"missing GRPC status, inferred error from HTTP status code");
}
}
String message = trailers.get(Status.MESSAGE_KEY);
if (message != null) {
status = status.augmentDescription(message);
}
return status;
}
/**
* Inspect the content type field from received headers or trailers and return an error Status if
* content type is invalid or not present. Returns null if no error was found.
*/
@Nullable
private Status checkContentType(Metadata headers) {
if (contentTypeChecked) {
return null;
}
contentTypeChecked = true;
String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY);
if (!GrpcUtil.isGrpcContentType(contentType)) {
return Status.INTERNAL.withDescription("Invalid content-type: " + contentType);
}
return null;
}
/**
* Inspect the raw metadata and figure out what charset is being used.
*/
private static Charset extractCharset(Metadata headers) {
String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY);
if (contentType != null) {
String[] split = contentType.split("charset=");
try {
return Charset.forName(split[split.length - 1].trim());
} catch (Exception t) {
// Ignore and assume UTF-8
}
}
return Charsets.UTF_8;
}
/**
* Strip HTTP transport implementation details so they don't leak via metadata into
* the application layer.
*/
private static void stripTransportDetails(Metadata metadata) {
metadata.discardAll(HTTP2_STATUS);
metadata.discardAll(Status.CODE_KEY);
metadata.discardAll(Status.MESSAGE_KEY);
}
}

View File

@ -0,0 +1,279 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/**
* Test for {@link AbstractClientStream2}. This class tries to test functionality in
* AbstractClientStream2, but not in any super classes.
*/
@RunWith(JUnit4.class)
public class AbstractClientStream2Test {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock private ClientStreamListener mockListener;
@Captor private ArgumentCaptor<Status> statusCaptor;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
}
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
@Override
public WritableBuffer allocate(int capacityHint) {
return new ByteWritableBuffer(capacityHint);
}
};
@Test
public void cancel_doNotAcceptOk() {
for (Code code : Code.values()) {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
if (code != Code.OK) {
stream.cancel(Status.fromCodeValue(code.value()));
} else {
try {
stream.cancel(Status.fromCodeValue(code.value()));
fail();
} catch (IllegalArgumentException e) {
// ignore
}
}
}
}
@Test
public void cancel_failsOnNull() {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
stream.cancel(null);
}
@Test
public void cancel_notifiesOnlyOnce() {
final BaseTransportState state = new BaseTransportState();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, state, new BaseSink() {
@Override
public void cancel(Status errorStatus) {
// Cancel should eventually result in a transportReportStatus on the transport thread
state.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata());
}
});
stream.start(mockListener);
stream.cancel(Status.DEADLINE_EXCEEDED);
stream.cancel(Status.DEADLINE_EXCEEDED);
verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
@Test
public void startFailsOnNullListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
thrown.expect(NullPointerException.class);
stream.start(null);
}
@Test
public void cantCallStartTwice() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
thrown.expect(IllegalStateException.class);
stream.start(mockListener);
}
@Test
public void inboundDataReceived_failsOnNullFrame() {
ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(listener);
thrown.expect(NullPointerException.class);
stream.transportState().inboundDataReceived(null);
}
@Test
public void inboundDataReceived_failsOnNoHeaders() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
stream.transportState().inboundDataReceived(ReadableBuffers.empty());
verify(mockListener).closed(statusCaptor.capture(), any(Metadata.class));
assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode());
}
@Test
public void inboundHeadersReceived_notifiesListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void inboundHeadersReceived_failsIfStatusReported() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
stream.transportState().transportReportStatus(Status.CANCELLED, false, new Metadata());
thrown.expect(IllegalStateException.class);
stream.transportState().inboundHeadersReceived(new Metadata());
}
@Test
public void inboundHeadersReceived_acceptsGzipEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding());
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void inboundHeadersReceived_acceptsIdentityEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, Codec.Identity.NONE.getMessageEncoding());
stream.transportState().inboundHeadersReceived(headers);
verify(mockListener).headersRead(headers);
}
@Test
public void rstStreamClosesStream() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator);
stream.start(mockListener);
// The application will call request when waiting for a message, which will in turn call this
// on the transport thread.
stream.transportState().requestMessagesFromDeframer(1);
// Send first byte of 2 byte message
stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1}), false);
Status status = Status.INTERNAL;
// Simulate getting a reset
stream.transportState().transportReportStatus(status, false /*stop delivery*/, new Metadata());
verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
/**
* No-op base class for testing.
*/
private static class BaseAbstractClientStream extends AbstractClientStream2 {
private final TransportState state;
private final Sink sink;
public BaseAbstractClientStream(WritableBufferAllocator allocator) {
this(allocator, new BaseTransportState(), new BaseSink());
}
public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
Sink sink) {
super(allocator);
this.state = state;
this.sink = sink;
}
@Override
protected Sink abstractClientStreamSink() {
return sink;
}
@Override
public TransportState transportState() {
return state;
}
@Override
public void setAuthority(String authority) {}
}
private static class BaseSink implements AbstractClientStream2.Sink {
@Override
public void request(int numMessages) {}
@Override
public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {}
@Override
public void cancel(Status reason) {}
}
private static class BaseTransportState extends AbstractClientStream2.TransportState {
public BaseTransportState() {
super(DEFAULT_MAX_MESSAGE_SIZE);
}
@Override
protected void deframeFailed(Throwable cause) {}
@Override
public void bytesRead(int processedBytes) {}
}
}

View File

@ -39,17 +39,17 @@ import io.grpc.Status;
* Command sent from a Netty client stream to the handler to cancel the stream. * Command sent from a Netty client stream to the handler to cancel the stream.
*/ */
class CancelClientStreamCommand extends WriteQueue.AbstractQueuedCommand { class CancelClientStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final NettyClientStream stream; private final NettyClientStream.TransportState stream;
private final Status reason; private final Status reason;
CancelClientStreamCommand(NettyClientStream stream, Status reason) { CancelClientStreamCommand(NettyClientStream.TransportState stream, Status reason) {
this.stream = Preconditions.checkNotNull(stream, "stream"); this.stream = Preconditions.checkNotNull(stream, "stream");
Preconditions.checkNotNull(reason, "reason"); Preconditions.checkNotNull(reason, "reason");
Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status"); Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status");
this.reason = reason; this.reason = reason;
} }
NettyClientStream stream() { NettyClientStream.TransportState stream() {
return stream; return stream;
} }

View File

@ -41,15 +41,15 @@ import io.netty.handler.codec.http2.Http2Headers;
*/ */
class CreateStreamCommand extends WriteQueue.AbstractQueuedCommand { class CreateStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final Http2Headers headers; private final Http2Headers headers;
private final NettyClientStream stream; private final NettyClientStream.TransportState stream;
CreateStreamCommand(Http2Headers headers, CreateStreamCommand(Http2Headers headers,
NettyClientStream stream) { NettyClientStream.TransportState stream) {
this.stream = Preconditions.checkNotNull(stream, "stream"); this.stream = Preconditions.checkNotNull(stream, "stream");
this.headers = Preconditions.checkNotNull(headers, "headers"); this.headers = Preconditions.checkNotNull(headers, "headers");
} }
NettyClientStream stream() { NettyClientStream.TransportState stream() {
return stream; return stream;
} }

View File

@ -246,17 +246,16 @@ class NettyClientHandler extends AbstractNettyHandler {
} }
private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream) { private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream) {
NettyClientStream stream = clientStream(requireHttp2Stream(streamId)); NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId));
stream.transportHeadersReceived(headers, endStream); stream.transportHeadersReceived(headers, endStream);
} }
/** /**
* Handler for an inbound HTTP/2 DATA frame. * Handler for an inbound HTTP/2 DATA frame.
*/ */
private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfStream) { private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfStream) {
flowControlPing().onDataRead(data.readableBytes(), padding); flowControlPing().onDataRead(data.readableBytes(), padding);
NettyClientStream stream = clientStream(requireHttp2Stream(streamId)); NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId));
stream.transportDataReceived(data, endOfStream); stream.transportDataReceived(data, endOfStream);
} }
@ -265,7 +264,7 @@ class NettyClientHandler extends AbstractNettyHandler {
* Handler for an inbound HTTP/2 RST_STREAM frame, terminating a stream. * Handler for an inbound HTTP/2 RST_STREAM frame, terminating a stream.
*/ */
private void onRstStreamRead(int streamId, long errorCode) { private void onRstStreamRead(int streamId, long errorCode) {
NettyClientStream stream = clientStream(connection().stream(streamId)); NettyClientStream.TransportState stream = clientStream(connection().stream(streamId));
if (stream != null) { if (stream != null) {
Status status = GrpcUtil.Http2Error.statusForCode((int) errorCode) Status status = GrpcUtil.Http2Error.statusForCode((int) errorCode)
.augmentDescription("Received Rst Stream"); .augmentDescription("Received Rst Stream");
@ -295,7 +294,7 @@ class NettyClientHandler extends AbstractNettyHandler {
connection().forEachActiveStream(new Http2StreamVisitor() { connection().forEachActiveStream(new Http2StreamVisitor() {
@Override @Override
public boolean visit(Http2Stream stream) throws Http2Exception { public boolean visit(Http2Stream stream) throws Http2Exception {
NettyClientStream clientStream = clientStream(stream); NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) { if (clientStream != null) {
clientStream.transportReportStatus( clientStream.transportReportStatus(
lifecycleManager.getShutdownStatus(), false, new Metadata()); lifecycleManager.getShutdownStatus(), false, new Metadata());
@ -322,7 +321,7 @@ class NettyClientHandler extends AbstractNettyHandler {
protected void onStreamError(ChannelHandlerContext ctx, Throwable cause, protected void onStreamError(ChannelHandlerContext ctx, Throwable cause,
Http2Exception.StreamException http2Ex) { Http2Exception.StreamException http2Ex) {
// Close the stream with a status that contains the cause. // Close the stream with a status that contains the cause.
NettyClientStream stream = clientStream(connection().stream(http2Ex.streamId())); NettyClientStream.TransportState stream = clientStream(connection().stream(http2Ex.streamId()));
if (stream != null) { if (stream != null) {
stream.transportReportStatus(Utils.statusFromThrowable(cause), false, new Metadata()); stream.transportReportStatus(Utils.statusFromThrowable(cause), false, new Metadata());
} else { } else {
@ -370,9 +369,9 @@ class NettyClientHandler extends AbstractNettyHandler {
return; return;
} }
final NettyClientStream stream = command.stream(); final NettyClientStream.TransportState stream = command.stream();
final Http2Headers headers = command.headers(); final Http2Headers headers = command.headers();
stream.id(streamId); stream.setId(streamId);
// Create an intermediate promise so that we can intercept the failure reported back to the // Create an intermediate promise so that we can intercept the failure reported back to the
// application. // application.
@ -418,7 +417,7 @@ class NettyClientHandler extends AbstractNettyHandler {
*/ */
private void cancelStream(ChannelHandlerContext ctx, CancelClientStreamCommand cmd, private void cancelStream(ChannelHandlerContext ctx, CancelClientStreamCommand cmd,
ChannelPromise promise) { ChannelPromise promise) {
NettyClientStream stream = cmd.stream(); NettyClientStream.TransportState stream = cmd.stream();
stream.transportReportStatus(cmd.reason(), true, new Metadata()); stream.transportReportStatus(cmd.reason(), true, new Metadata());
encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise); encoder().writeRstStream(ctx, stream.id(), Http2Error.CANCEL.code(), promise);
} }
@ -507,7 +506,7 @@ class NettyClientHandler extends AbstractNettyHandler {
connection().forEachActiveStream(new Http2StreamVisitor() { connection().forEachActiveStream(new Http2StreamVisitor() {
@Override @Override
public boolean visit(Http2Stream stream) throws Http2Exception { public boolean visit(Http2Stream stream) throws Http2Exception {
NettyClientStream clientStream = clientStream(stream); NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) { if (clientStream != null) {
clientStream.transportReportStatus(msg.getStatus(), true, new Metadata()); clientStream.transportReportStatus(msg.getStatus(), true, new Metadata());
resetStream(ctx, stream.id(), Http2Error.CANCEL.code(), ctx.newPromise()); resetStream(ctx, stream.id(), Http2Error.CANCEL.code(), ctx.newPromise());
@ -531,7 +530,7 @@ class NettyClientHandler extends AbstractNettyHandler {
@Override @Override
public boolean visit(Http2Stream stream) throws Http2Exception { public boolean visit(Http2Stream stream) throws Http2Exception {
if (stream.id() > lastKnownStream) { if (stream.id() > lastKnownStream) {
NettyClientStream clientStream = clientStream(stream); NettyClientStream.TransportState clientStream = clientStream(stream);
if (clientStream != null) { if (clientStream != null) {
clientStream.transportReportStatus(goAwayStatus, false, new Metadata()); clientStream.transportReportStatus(goAwayStatus, false, new Metadata());
} }
@ -566,8 +565,8 @@ class NettyClientHandler extends AbstractNettyHandler {
/** /**
* Gets the client stream associated to the given HTTP/2 stream object. * Gets the client stream associated to the given HTTP/2 stream object.
*/ */
private NettyClientStream clientStream(Http2Stream stream) { private NettyClientStream.TransportState clientStream(Http2Stream stream) {
return stream == null ? null : (NettyClientStream) stream.getProperty(streamKey); return stream == null ? null : (NettyClientStream.TransportState) stream.getProperty(streamKey);
} }
private int incrementAndGetNextStreamId() throws StatusException { private int incrementAndGetNextStreamId() throws StatusException {

View File

@ -41,9 +41,10 @@ import io.grpc.InternalMethodDescriptor;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.AbstractClientStream2;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream; import io.grpc.internal.Http2ClientStreamTransportState;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -56,43 +57,51 @@ import io.netty.util.AsciiString;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /**
* Client stream for a Netty transport. * Client stream for a Netty transport. Must only be called from the sending application
* thread.
*/ */
abstract class NettyClientStream extends Http2ClientStream implements StreamIdHolder { class NettyClientStream extends AbstractClientStream2 {
private static final InternalMethodDescriptor methodDescriptorAccessor = private static final InternalMethodDescriptor methodDescriptorAccessor =
new InternalMethodDescriptor(InternalKnownTransport.NETTY); new InternalMethodDescriptor(InternalKnownTransport.NETTY);
private final Sink sink = new Sink();
private final TransportState state;
private final WriteQueue writeQueue;
private final MethodDescriptor<?, ?> method; private final MethodDescriptor<?, ?> method;
/** {@code null} after start. */ /** {@code null} after start. */
private Metadata headers; private Metadata headers;
private final Channel channel; private final Channel channel;
private final NettyClientHandler handler; private AsciiString authority;
private final AsciiString scheme; private final AsciiString scheme;
private final AsciiString userAgent; private final AsciiString userAgent;
private AsciiString authority;
private Http2Stream http2Stream; NettyClientStream(TransportState state, MethodDescriptor<?, ?> method, Metadata headers,
private int id; Channel channel, AsciiString authority, AsciiString scheme,
private WriteQueue writeQueue;
NettyClientStream(MethodDescriptor<?, ?> method, Metadata headers, Channel channel,
NettyClientHandler handler, int maxMessageSize, AsciiString authority, AsciiString scheme,
AsciiString userAgent) { AsciiString userAgent) {
super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize); super(new NettyWritableBufferAllocator(channel.alloc()));
this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue();
this.method = checkNotNull(method, "method"); this.method = checkNotNull(method, "method");
this.headers = checkNotNull(headers, "headers"); this.headers = checkNotNull(headers, "headers");
this.writeQueue = handler.getWriteQueue();
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.handler = checkNotNull(handler, "handler");
this.authority = checkNotNull(authority, "authority"); this.authority = checkNotNull(authority, "authority");
this.scheme = checkNotNull(scheme, "scheme"); this.scheme = checkNotNull(scheme, "scheme");
this.userAgent = userAgent; this.userAgent = userAgent;
} }
@Override
protected TransportState transportState() {
return state;
}
@Override
protected Sink abstractClientStreamSink() {
return sink;
}
@Override @Override
public void setAuthority(String authority) { public void setAuthority(String authority) {
checkState(listener() == null, "must be call before start"); checkState(headers != null, "must be call before start");
this.authority = AsciiString.of(checkNotNull(authority, "authority")); this.authority = AsciiString.of(checkNotNull(authority, "authority"));
} }
@ -116,114 +125,134 @@ abstract class NettyClientStream extends Http2ClientStream implements StreamIdHo
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) { if (!future.isSuccess()) {
// Stream creation failed. Close the stream if not already closed. // Stream creation failed. Close the stream if not already closed.
Status s = statusFromFailedFuture(future); Status s = transportState().statusFromFailedFuture(future);
transportReportStatus(s, true, new Metadata()); transportState().transportReportStatus(s, true, new Metadata());
} }
} }
}; };
// Write the command requesting the creation of the stream. // Write the command requesting the creation of the stream.
writeQueue.enqueue(new CreateStreamCommand(http2Headers, this), writeQueue.enqueue(new CreateStreamCommand(http2Headers, transportState()),
!method.getType().clientSendsOneMessage()).addListener(failureListener); !method.getType().clientSendsOneMessage()).addListener(failureListener);
} }
@Override private class Sink implements AbstractClientStream2.Sink {
public void transportReportStatus(Status newStatus, boolean stopDelivery, Metadata trailers) { @Override
super.transportReportStatus(newStatus, stopDelivery, trailers); public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
} ByteBuf bytebuf = frame == null ? EMPTY_BUFFER : ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
}
}), flush);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush);
}
}
/** @Override
* Intended to be overriden by NettyClientTransport, which has more information about failures. public void request(int numMessages) {
* May only be called from event loop. if (channel.eventLoop().inEventLoop()) {
*/ // Processing data read in the event loop so can call into the deframer immediately
protected abstract Status statusFromFailedFuture(ChannelFuture f); transportState().requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(transportState(), numMessages), true);
}
}
@Override @Override
public void request(int numMessages) { public void cancel(Status status) {
if (channel.eventLoop().inEventLoop()) { writeQueue.enqueue(new CancelClientStreamCommand(transportState(), status), true);
// Processing data read in the event loop so can call into the deframer immediately
requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(this, numMessages), true);
} }
} }
@Override /** This should only called from the transport thread. */
public int id() { public abstract static class TransportState extends Http2ClientStreamTransportState
return id; implements StreamIdHolder {
} private final NettyClientHandler handler;
private int id;
private Http2Stream http2Stream;
public void id(int id) { public TransportState(NettyClientHandler handler, int maxMessageSize) {
checkArgument(id != ABSENT_ID, "Can't use absent id"); super(maxMessageSize);
this.id = id; this.handler = checkNotNull(handler, "handler");
}
/**
* Sets the underlying Netty {@link Http2Stream} for this stream. This must be called in the
* context of the transport thread.
*/
public void setHttp2Stream(Http2Stream http2Stream) {
checkNotNull(http2Stream, "http2Stream");
checkState(this.http2Stream == null, "Can only set http2Stream once");
this.http2Stream = http2Stream;
// Now that the stream has actually been initialized, call the listener's onReady callback if
// appropriate.
onStreamAllocated();
}
/**
* Gets the underlying Netty {@link Http2Stream} for this stream.
*/
@Nullable
public Http2Stream http2Stream() {
return http2Stream;
}
void transportHeadersReceived(Http2Headers headers, boolean endOfStream) {
if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers));
} else {
transportHeadersReceived(Utils.convertHeaders(headers));
} }
}
void transportDataReceived(ByteBuf frame, boolean endOfStream) { @Override
transportDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream); public int id() {
} return id;
@Override
protected void sendCancel(Status reason) {
// Send the cancel command to the handler.
writeQueue.enqueue(new CancelClientStreamCommand(this, reason), true);
}
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
ByteBuf bytebuf = frame == null ? EMPTY_BUFFER : ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(this, bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
}
}), flush);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(new SendGrpcFrameCommand(this, bytebuf, endOfStream), flush);
} }
}
@Override public void setId(int id) {
protected void returnProcessedBytes(int processedBytes) { checkArgument(id > 0, "id must be positive");
handler.returnProcessedBytes(http2Stream, processedBytes); this.id = id;
writeQueue.scheduleFlush(); }
/**
* Sets the underlying Netty {@link Http2Stream} for this stream. This must be called in the
* context of the transport thread.
*/
public void setHttp2Stream(Http2Stream http2Stream) {
checkNotNull(http2Stream, "http2Stream");
checkState(this.http2Stream == null, "Can only set http2Stream once");
this.http2Stream = http2Stream;
// Now that the stream has actually been initialized, call the listener's onReady callback if
// appropriate.
onStreamAllocated();
}
/**
* Gets the underlying Netty {@link Http2Stream} for this stream.
*/
@Nullable
public Http2Stream http2Stream() {
return http2Stream;
}
/**
* Intended to be overriden by NettyClientTransport, which has more information about failures.
* May only be called from event loop.
*/
protected abstract Status statusFromFailedFuture(ChannelFuture f);
@Override
protected void http2ProcessingFailed(Status status, Metadata trailers) {
transportReportStatus(status, false, trailers);
handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true);
}
@Override
public void bytesRead(int processedBytes) {
handler.returnProcessedBytes(http2Stream, processedBytes);
handler.getWriteQueue().scheduleFlush();
}
@Override
protected void deframeFailed(Throwable cause) {
http2ProcessingFailed(Status.fromThrowable(cause), new Metadata());
}
void transportHeadersReceived(Http2Headers headers, boolean endOfStream) {
if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers));
} else {
transportHeadersReceived(Utils.convertHeaders(headers));
}
}
void transportDataReceived(ByteBuf frame, boolean endOfStream) {
transportDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
} }
} }

View File

@ -120,13 +120,14 @@ class NettyClientTransport implements ConnectionClientTransport {
callOptions) { callOptions) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority, return new NettyClientStream(
negotiationHandler.scheme(), userAgent) { new NettyClientStream.TransportState(handler, maxMessageSize) {
@Override @Override
protected Status statusFromFailedFuture(ChannelFuture f) { protected Status statusFromFailedFuture(ChannelFuture f) {
return NettyClientTransport.this.statusFromFailedFuture(f); return NettyClientTransport.this.statusFromFailedFuture(f);
} }
}; },
method, headers, channel, authority, negotiationHandler.scheme(), userAgent);
} }
@Override @Override

View File

@ -32,7 +32,6 @@
package io.grpc.netty; package io.grpc.netty;
import io.grpc.internal.AbstractStream2; import io.grpc.internal.AbstractStream2;
import io.grpc.internal.Stream;
/** /**
* Command which requests messages from the deframer. * Command which requests messages from the deframer.
@ -40,26 +39,14 @@ import io.grpc.internal.Stream;
class RequestMessagesCommand extends WriteQueue.AbstractQueuedCommand { class RequestMessagesCommand extends WriteQueue.AbstractQueuedCommand {
private final int numMessages; private final int numMessages;
private final Stream stream;
private final AbstractStream2.TransportState state; private final AbstractStream2.TransportState state;
public RequestMessagesCommand(Stream stream, int numMessages) {
this.state = null;
this.numMessages = numMessages;
this.stream = stream;
}
public RequestMessagesCommand(AbstractStream2.TransportState state, int numMessages) { public RequestMessagesCommand(AbstractStream2.TransportState state, int numMessages) {
this.state = state; this.state = state;
this.numMessages = numMessages; this.numMessages = numMessages;
this.stream = null;
} }
void requestMessages() { void requestMessages() {
if (stream != null) { state.requestMessagesFromDeframer(numMessages);
stream.request(numMessages);
} else {
state.requestMessagesFromDeframer(numMessages);
}
} }
} }

View File

@ -32,6 +32,7 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTPS; import static io.grpc.netty.Utils.HTTPS;
@ -41,6 +42,7 @@ import static io.grpc.netty.Utils.TE_HEADER;
import static io.grpc.netty.Utils.TE_TRAILERS; import static io.grpc.netty.Utils.TE_TRAILERS;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
@ -49,17 +51,16 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.notNull; import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.calls;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.base.Ticker; import com.google.common.base.Ticker;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusException; import io.grpc.StatusException;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport;
import io.grpc.internal.ClientTransport.PingCallback; import io.grpc.internal.ClientTransport.PingCallback;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
@ -87,19 +88,17 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import java.io.InputStream;
/** /**
* Tests for {@link NettyClientHandler}. * Tests for {@link NettyClientHandler}.
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHandler> { public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHandler> {
// TODO(zhangkun83): mocking concrete classes is not safe. Consider making NettyClientStream an private NettyClientStream.TransportState streamTransportState;
// interface.
@Mock
private NettyClientStream stream;
private Http2Headers grpcHeaders; private Http2Headers grpcHeaders;
private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private long nanoTime; // backs a ticker, for testing ping round-trip time measurement
private int flowControlWindow = DEFAULT_WINDOW_SIZE; private int flowControlWindow = DEFAULT_WINDOW_SIZE;
@ -108,6 +107,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Mock @Mock
private NettyClientTransport.Listener listener; private NettyClientTransport.Listener listener;
@Mock
private ClientStreamListener streamListener;
/** /**
* Set up for test. * Set up for test.
@ -118,6 +119,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
lifecycleManager = new ClientTransportLifecycleManager(listener); lifecycleManager = new ClientTransportLifecycleManager(listener);
initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));
streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
streamTransportState.setListener(streamListener);
grpcHeaders = new DefaultHttp2Headers() grpcHeaders = new DefaultHttp2Headers()
.scheme(HTTPS) .scheme(HTTPS)
@ -138,15 +141,14 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
// Force the stream to be buffered. // Force the stream to be buffered.
receiveMaxConcurrentStreams(0); receiveMaxConcurrentStreams(0);
// Create a new stream with id 3. // Create a new stream with id 3.
ChannelFuture createFuture = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture createFuture = enqueue(
verify(stream).id(eq(3)); new CreateStreamCommand(grpcHeaders, streamTransportState));
when(stream.id()).thenReturn(3); assertEquals(3, streamTransportState.id());
// Cancel the stream. // Cancel the stream.
cancelStream(Status.CANCELLED); cancelStream(Status.CANCELLED);
assertTrue(createFuture.isSuccess()); assertTrue(createFuture.isSuccess());
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), verify(streamListener).closed(eq(Status.CANCELLED), any(Metadata.class));
any(Metadata.class));
} }
@Test @Test
@ -224,7 +226,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
createStream(); createStream();
// Send a frame and verify that it was written. // Send a frame and verify that it was written.
ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream, content(), true)); ChannelFuture future = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
assertTrue(future.isSuccess()); assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(3), eq(content()), eq(0), eq(true), verifyWrite().writeData(eq(ctx()), eq(3), eq(content()), eq(0), eq(true),
@ -233,43 +235,43 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test @Test
public void sendForUnknownStreamShouldFail() throws Exception { public void sendForUnknownStreamShouldFail() throws Exception {
when(stream.id()).thenReturn(3); ChannelFuture future = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream, content(), true));
assertTrue(future.isDone()); assertTrue(future.isDone());
assertFalse(future.isSuccess()); assertFalse(future.isSuccess());
} }
@Test @Test
public void inboundHeadersShouldForwardToStream() throws Exception { public void inboundShouldForwardToStream() throws Exception {
createStream(); createStream();
// Read a headers frame first. // Read a headers frame first.
Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.set(as("magic"), as("value"));
ByteBuf headersFrame = headersFrame(3, headers); ByteBuf headersFrame = headersFrame(3, headers);
channelRead(headersFrame); channelRead(headersFrame);
verify(stream).transportHeadersReceived(headers, false); ArgumentCaptor<Metadata> captor = ArgumentCaptor.forClass(Metadata.class);
} verify(streamListener).headersRead(captor.capture());
assertEquals("value",
captor.getValue().get(Metadata.Key.of("magic", Metadata.ASCII_STRING_MARSHALLER)));
@Test streamTransportState.requestMessagesFromDeframer(1);
public void inboundDataShouldForwardToStream() throws Exception {
ByteBuf data = content().copy();
createStream();
// Create a data frame and then trigger the handler to read it. // Create a data frame and then trigger the handler to read it.
// Need to retain to simulate what is done by the stream. ByteBuf frame = grpcDataFrame(3, false, contentAsArray());
ByteBuf frame = dataFrame(3, false).retain();
channelRead(frame); channelRead(frame);
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); ArgumentCaptor<InputStream> isCaptor = ArgumentCaptor.forClass(InputStream.class);
verify(stream).transportDataReceived(captor.capture(), eq(false)); verify(streamListener).messageRead(isCaptor.capture());
assertTrue(ByteBufUtil.equals(data, captor.getValue())); assertArrayEquals(ByteBufUtil.getBytes(content()),
ByteStreams.toByteArray(isCaptor.getValue()));
isCaptor.getValue().close();
} }
@Test @Test
public void receivedGoAwayShouldCancelBufferedStream() throws Exception { public void receivedGoAwayShouldCancelBufferedStream() throws Exception {
// Force the stream to be buffered. // Force the stream to be buffered.
receiveMaxConcurrentStreams(0); receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
channelRead(goAwayFrame(0)); channelRead(goAwayFrame(0));
assertTrue(future.isDone()); assertTrue(future.isDone());
assertFalse(future.isSuccess()); assertFalse(future.isSuccess());
@ -280,13 +282,12 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test @Test
public void receivedGoAwayShouldFailUnknownStreams() throws Exception { public void receivedGoAwayShouldFailUnknownStreams() throws Exception {
enqueue(new CreateStreamCommand(grpcHeaders, stream)); enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
// Read a GOAWAY that indicates our stream was never processed by the server. // Read a GOAWAY that indicates our stream was never processed by the server.
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8))); channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(stream).transportReportStatus(captor.capture(), eq(false), verify(streamListener).closed(captor.capture(), notNull(Metadata.class));
notNull(Metadata.class));
assertEquals(Status.CANCELLED.getCode(), captor.getValue().getCode()); assertEquals(Status.CANCELLED.getCode(), captor.getValue().getCode());
assertEquals("HTTP/2 error code: CANCEL\nReceived Goaway\nthis is a test", assertEquals("HTTP/2 error code: CANCEL\nReceived Goaway\nthis is a test",
captor.getValue().getDescription()); captor.getValue().getDescription());
@ -296,7 +297,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
public void receivedGoAwayShouldFailUnknownBufferedStreams() throws Exception { public void receivedGoAwayShouldFailUnknownBufferedStreams() throws Exception {
receiveMaxConcurrentStreams(0); receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
// Read a GOAWAY that indicates our stream was never processed by the server. // Read a GOAWAY that indicates our stream was never processed by the server.
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8))); channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
@ -314,7 +315,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8))); channelRead(goAwayFrame(0, 8 /* Cancel */, Unpooled.copiedBuffer("this is a test", UTF_8)));
// Now try to create a stream. // Now try to create a stream.
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertTrue(future.isDone()); assertTrue(future.isDone());
assertFalse(future.isSuccess()); assertFalse(future.isSuccess());
Status status = Status.fromThrowable(future.cause()); Status status = Status.fromThrowable(future.cause());
@ -326,19 +327,17 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test @Test
public void cancelStreamShouldCreateAndThenFailBufferedStream() throws Exception { public void cancelStreamShouldCreateAndThenFailBufferedStream() throws Exception {
receiveMaxConcurrentStreams(0); receiveMaxConcurrentStreams(0);
enqueue(new CreateStreamCommand(grpcHeaders, stream)); enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
verify(stream).id(3); assertEquals(3, streamTransportState.id());
when(stream.id()).thenReturn(3);
cancelStream(Status.CANCELLED); cancelStream(Status.CANCELLED);
verify(stream).transportReportStatus(eq(Status.CANCELLED), eq(true), verify(streamListener).closed(eq(Status.CANCELLED), any(Metadata.class));
any(Metadata.class));
} }
@Test @Test
public void channelShutdownShouldCancelBufferedStreams() throws Exception { public void channelShutdownShouldCancelBufferedStreams() throws Exception {
// Force a stream to get added to the pending queue. // Force a stream to get added to the pending queue.
receiveMaxConcurrentStreams(0); receiveMaxConcurrentStreams(0);
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
handler().channelInactive(ctx()); handler().channelInactive(ctx());
assertTrue(future.isDone()); assertTrue(future.isDone());
@ -351,9 +350,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
handler().channelInactive(ctx()); handler().channelInactive(ctx());
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
InOrder inOrder = inOrder(stream); verify(streamListener).closed(captor.capture(), notNull(Metadata.class));
inOrder.verify(stream, calls(1)).transportReportStatus(captor.capture(), eq(false),
notNull(Metadata.class));
assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode()); assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode());
} }
@ -375,12 +372,18 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
@Test @Test
public void createIncrementsIdsForActualAndBufferdStreams() throws Exception { public void createIncrementsIdsForActualAndBufferdStreams() throws Exception {
receiveMaxConcurrentStreams(2); receiveMaxConcurrentStreams(2);
enqueue(new CreateStreamCommand(grpcHeaders, stream)); enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
verify(stream).id(eq(3)); assertEquals(3, streamTransportState.id());
enqueue(new CreateStreamCommand(grpcHeaders, stream));
verify(stream).id(eq(5)); streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
enqueue(new CreateStreamCommand(grpcHeaders, stream)); streamTransportState.setListener(streamListener);
verify(stream).id(eq(7)); enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(5, streamTransportState.id());
streamTransportState = new TransportStateImpl(handler(), DEFAULT_MAX_MESSAGE_SIZE);
streamTransportState.setListener(streamListener);
enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
assertEquals(7, streamTransportState.id());
} }
@Test @Test
@ -467,7 +470,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class)); verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class));
ByteBuf payload = captor.getValue(); ByteBuf payload = captor.getValue();
channelRead(dataFrame(3, false)); channelRead(grpcDataFrame(3, false, contentAsArray()));
long pingData = handler().flowControlPing().payload(); long pingData = handler().flowControlPing().payload();
ByteBuf buffer = handler().ctx().alloc().buffer(8); ByteBuf buffer = handler().ctx().alloc().buffer(8);
buffer.writeLong(pingData); buffer.writeLong(pingData);
@ -507,18 +510,13 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
channelRead(serializedSettings); channelRead(serializedSettings);
} }
private ByteBuf dataFrame(int streamId, boolean endStream) {
return dataFrame(streamId, endStream, content());
}
private ChannelFuture createStream() throws Exception { private ChannelFuture createStream() throws Exception {
ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, stream)); ChannelFuture future = enqueue(new CreateStreamCommand(grpcHeaders, streamTransportState));
when(stream.id()).thenReturn(3);
return future; return future;
} }
private ChannelFuture cancelStream(Status status) throws Exception { private ChannelFuture cancelStream(Status status) throws Exception {
return enqueue(new CancelClientStreamCommand(stream, status)); return enqueue(new CancelClientStreamCommand(streamTransportState, status));
} }
@Override @Override
@ -567,4 +565,15 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
this.failureCause = cause; this.failureCause = cause;
} }
} }
class TransportStateImpl extends NettyClientStream.TransportState {
public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
super(handler, maxMessageSize);
}
@Override
protected Status statusFromFailedFuture(ChannelFuture f) {
return Utils.statusFromThrowable(f.cause());
}
}
} }

View File

@ -51,6 +51,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableListMultimap;
@ -63,7 +64,6 @@ import io.grpc.internal.GrpcUtil;
import io.grpc.netty.WriteQueue.QueuedCommand; import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.DefaultHttp2Headers;
@ -109,16 +109,15 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void closeShouldSucceed() { public void closeShouldSucceed() {
// Force stream creation. // Force stream creation.
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().halfClose(); stream().halfClose();
assertTrue(stream().canReceive()); verifyNoMoreInteractions(listener);
assertFalse(stream().canSend());
} }
@Test @Test
public void cancelShouldSendCommand() { public void cancelShouldSendCommand() {
// Set stream id to indicate it has been created // Set stream id to indicate it has been created
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().cancel(Status.CANCELLED); stream().cancel(Status.CANCELLED);
ArgumentCaptor<CancelClientStreamCommand> commandCaptor = ArgumentCaptor<CancelClientStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelClientStreamCommand.class); ArgumentCaptor.forClass(CancelClientStreamCommand.class);
@ -129,7 +128,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void deadlineExceededCancelShouldSendCommand() { public void deadlineExceededCancelShouldSendCommand() {
// Set stream id to indicate it has been created // Set stream id to indicate it has been created
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().cancel(Status.DEADLINE_EXCEEDED); stream().cancel(Status.DEADLINE_EXCEEDED);
ArgumentCaptor<CancelClientStreamCommand> commandCaptor = ArgumentCaptor<CancelClientStreamCommand> commandCaptor =
ArgumentCaptor.forClass(CancelClientStreamCommand.class); ArgumentCaptor.forClass(CancelClientStreamCommand.class);
@ -146,12 +145,12 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void writeMessageShouldSendRequest() throws Exception { public void writeMessageShouldSendRequest() throws Exception {
// Force stream creation. // Force stream creation.
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
byte[] msg = smallMessage(); byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg)); stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush(); stream.flush();
verify(writeQueue).enqueue( verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)), eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)),
any(ChannelPromise.class), any(ChannelPromise.class),
eq(true)); eq(true));
} }
@ -159,112 +158,109 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void writeMessageShouldSendRequestUnknownLength() throws Exception { public void writeMessageShouldSendRequestUnknownLength() throws Exception {
// Force stream creation. // Force stream creation.
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
byte[] msg = smallMessage(); byte[] msg = smallMessage();
stream.writeMessage(new BufferedInputStream(new ByteArrayInputStream(msg))); stream.writeMessage(new BufferedInputStream(new ByteArrayInputStream(msg)));
stream.flush(); stream.flush();
// Two writes occur, one for the GRPC frame header and the second with the payload // Two writes occur, one for the GRPC frame header and the second with the payload
verify(writeQueue).enqueue( verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE).slice(0, 5), false)), eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(0, 5), false)),
any(ChannelPromise.class), any(ChannelPromise.class),
eq(false)); eq(false));
verify(writeQueue).enqueue( verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE).slice(5, 11), false)), eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(5, 11), false)),
any(ChannelPromise.class), any(ChannelPromise.class),
eq(true)); eq(true));
} }
@Test @Test
public void setStatusWithOkShouldCloseStream() { public void setStatusWithOkShouldCloseStream() {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().transportReportStatus(Status.OK, true, new Metadata()); stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
verify(listener).closed(same(Status.OK), any(Metadata.class)); verify(listener).closed(same(Status.OK), any(Metadata.class));
assertTrue(stream.isClosed());
} }
@Test @Test
public void setStatusWithErrorShouldCloseStream() { public void setStatusWithErrorShouldCloseStream() {
Status errorStatus = Status.INTERNAL; Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata()); stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
verify(listener).closed(eq(errorStatus), any(Metadata.class)); verify(listener).closed(eq(errorStatus), any(Metadata.class));
assertTrue(stream.isClosed());
} }
@Test @Test
public void setStatusWithOkShouldNotOverrideError() { public void setStatusWithOkShouldNotOverrideError() {
Status errorStatus = Status.INTERNAL; Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata()); stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
stream().transportReportStatus(Status.OK, true, new Metadata()); stream().transportState().transportReportStatus(Status.OK, true, new Metadata());
verify(listener).closed(any(Status.class), any(Metadata.class)); verify(listener).closed(any(Status.class), any(Metadata.class));
assertTrue(stream.isClosed());
} }
@Test @Test
public void setStatusWithErrorShouldNotOverridePreviousError() { public void setStatusWithErrorShouldNotOverridePreviousError() {
Status errorStatus = Status.INTERNAL; Status errorStatus = Status.INTERNAL;
stream().transportReportStatus(errorStatus, true, new Metadata()); stream().transportState().transportReportStatus(errorStatus, true, new Metadata());
stream().transportReportStatus(Status.fromThrowable(new RuntimeException("fake")), true, stream().transportState().transportReportStatus(
new Metadata()); Status.fromThrowable(new RuntimeException("fake")), true, new Metadata());
verify(listener).closed(any(Status.class), any(Metadata.class)); verify(listener).closed(any(Status.class), any(Metadata.class));
assertTrue(stream.isClosed());
} }
@Override @Override
@Test @Test
public void inboundMessageShouldCallListener() throws Exception { public void inboundMessageShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response. // Receive headers first so that it's a valid GRPC response.
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
super.inboundMessageShouldCallListener(); super.inboundMessageShouldCallListener();
} }
@Test @Test
public void inboundHeadersShouldCallListenerHeadersRead() throws Exception { public void inboundHeadersShouldCallListenerHeadersRead() throws Exception {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
Http2Headers headers = grpcResponseHeaders(); Http2Headers headers = grpcResponseHeaders();
stream().transportHeadersReceived(headers, false); stream().transportState().transportHeadersReceived(headers, false);
verify(listener).headersRead(any(Metadata.class)); verify(listener).headersRead(any(Metadata.class));
} }
@Test @Test
public void inboundTrailersClosesCall() throws Exception { public void inboundTrailersClosesCall() throws Exception {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
super.inboundMessageShouldCallListener(); super.inboundMessageShouldCallListener();
stream().transportHeadersReceived(grpcResponseTrailers(Status.OK), true); stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.OK), true);
} }
@Test @Test
public void inboundStatusShouldSetStatus() throws Exception { public void inboundStatusShouldSetStatus() throws Exception {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
// Receive headers first so that it's a valid GRPC response. // Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
stream().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true); stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture(), any(Metadata.class)); verify(listener).closed(captor.capture(), any(Metadata.class));
assertEquals(Status.INTERNAL.getCode(), captor.getValue().getCode()); assertEquals(Status.INTERNAL.getCode(), captor.getValue().getCode());
assertTrue(stream.isClosed());
} }
@Test @Test
public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundHeadersCancelStream() throws Exception {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
Http2Headers headers = grpcResponseHeaders(); Http2Headers headers = grpcResponseHeaders();
headers.set("random", "4"); headers.set("random", "4");
headers.remove(CONTENT_TYPE_HEADER); headers.remove(CONTENT_TYPE_HEADER);
// Remove once b/16290036 is fixed. // Remove once b/16290036 is fixed.
headers.status(new AsciiString("500")); headers.status(new AsciiString("500"));
stream().transportHeadersReceived(headers, false); stream().transportState().transportHeadersReceived(headers, false);
verify(listener, never()).closed(any(Status.class), any(Metadata.class)); verify(listener, never()).closed(any(Status.class), any(Metadata.class));
// We are now waiting for 100 bytes of error context on the stream, cancel has not yet been // We are now waiting for 100 bytes of error context on the stream, cancel has not yet been
// sent // sent
verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class)); verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(100).writeZero(100), false); stream().transportState().transportDataReceived(Unpooled.buffer(100).writeZero(100), false);
verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class)); verify(channel, never()).writeAndFlush(any(CancelClientStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false); stream().transportState().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false);
// Now verify that cancel is sent and an error is reported to the listener // Now verify that cancel is sent and an error is reported to the listener
verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true)); verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true));
@ -274,20 +270,19 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode()); assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode());
assertEquals("4", metadataCaptor.getValue() assertEquals("4", metadataCaptor.getValue()
.get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER))); .get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER)));
assertTrue(stream.isClosed());
} }
@Test @Test
public void invalidInboundContentTypeShouldCancelStream() { public void invalidInboundContentTypeShouldCancelStream() {
// Set stream id to indicate it has been created // Set stream id to indicate it has been created
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK).set(CONTENT_TYPE_HEADER, Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK).set(CONTENT_TYPE_HEADER,
new AsciiString("application/bad", UTF_8)); new AsciiString("application/bad", UTF_8));
stream().transportHeadersReceived(headers, false); stream().transportState().transportHeadersReceived(headers, false);
Http2Headers trailers = new DefaultHttp2Headers() Http2Headers trailers = new DefaultHttp2Headers()
.set(new AsciiString("grpc-status", UTF_8), new AsciiString("0", UTF_8)); .set(new AsciiString("grpc-status", UTF_8), new AsciiString("0", UTF_8));
stream().transportHeadersReceived(trailers, true); stream().transportState().transportHeadersReceived(trailers, true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(listener).closed(captor.capture(), metadataCaptor.capture()); verify(listener).closed(captor.capture(), metadataCaptor.capture());
@ -300,7 +295,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void nonGrpcResponseShouldSetStatus() throws Exception { public void nonGrpcResponseShouldSetStatus() throws Exception {
stream().transportDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true); stream().transportState().transportDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture(), any(Metadata.class)); verify(listener).closed(captor.capture(), any(Metadata.class));
assertEquals(Status.Code.INTERNAL, captor.getValue().getCode()); assertEquals(Status.Code.INTERNAL, captor.getValue().getCode());
@ -308,13 +303,13 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void deframedDataAfterCancelShouldBeIgnored() throws Exception { public void deframedDataAfterCancelShouldBeIgnored() throws Exception {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
// Receive headers first so that it's a valid GRPC response. // Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
// Receive 2 consecutive empty frames. Only one is delivered at a time to the listener. // Receive 2 consecutive empty frames. Only one is delivered at a time to the listener.
stream().transportDataReceived(simpleGrpcFrame(), false); stream().transportState().transportDataReceived(simpleGrpcFrame(), false);
stream().transportDataReceived(simpleGrpcFrame(), false); stream().transportState().transportDataReceived(simpleGrpcFrame(), false);
// Only allow the first to be delivered. // Only allow the first to be delivered.
stream().request(1); stream().request(1);
@ -323,14 +318,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
// data frames have been processed. Since cancellation will interrupt message delivery, // data frames have been processed. Since cancellation will interrupt message delivery,
// this status will never be processed and the listener will instead only see the // this status will never be processed and the listener will instead only see the
// cancellation. // cancellation.
stream().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true); stream().transportState().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
// Verify that the first was delivered. // Verify that the first was delivered.
verify(listener).messageRead(any(InputStream.class)); verify(listener).messageRead(any(InputStream.class));
// Now set the error status. // Now set the error status.
Metadata trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED)); Metadata trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED));
stream().transportReportStatus(Status.CANCELLED, true, trailers); stream().transportState().transportReportStatus(Status.CANCELLED, true, trailers);
// Now allow the delivery of the second. // Now allow the delivery of the second.
stream().request(1); stream().request(1);
@ -342,14 +337,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Test @Test
public void dataFrameWithEosShouldDeframeAndThenFail() { public void dataFrameWithEosShouldDeframeAndThenFail() {
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
stream().request(1); stream().request(1);
// Receive headers first so that it's a valid GRPC response. // Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportState().transportHeadersReceived(grpcResponseHeaders(), false);
// Receive a DATA frame with EOS set. // Receive a DATA frame with EOS set.
stream().transportDataReceived(simpleGrpcFrame(), true); stream().transportState().transportDataReceived(simpleGrpcFrame(), true);
// Verify that the message was delivered. // Verify that the message was delivered.
verify(listener).messageRead(any(InputStream.class)); verify(listener).messageRead(any(InputStream.class));
@ -363,14 +358,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
public void setHttp2StreamShouldNotifyReady() { public void setHttp2StreamShouldNotifyReady() {
listener = mock(ClientStreamListener.class); listener = mock(ClientStreamListener.class);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler, stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"), methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("agent")); AsciiString.of("http"), AsciiString.of("agent"));
stream.start(listener); stream.start(listener);
stream().id(STREAM_ID); stream().transportState().setId(STREAM_ID);
verify(listener, never()).onReady(); verify(listener, never()).onReady();
assertFalse(stream.isReady()); assertFalse(stream.isReady());
stream().setHttp2Stream(http2Stream); stream().transportState().setHttp2Stream(http2Stream);
verify(listener).onReady(); verify(listener).onReady();
assertTrue(stream.isReady()); assertTrue(stream.isReady());
} }
@ -383,9 +378,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
Mockito.reset(writeQueue); Mockito.reset(writeQueue);
when(writeQueue.enqueue(any(QueuedCommand.class), any(boolean.class))).thenReturn(future); when(writeQueue.enqueue(any(QueuedCommand.class), any(boolean.class))).thenReturn(future);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler, stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"), methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("good agent")); AsciiString.of("http"), AsciiString.of("good agent"));
stream.start(listener); stream.start(listener);
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class); ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
@ -407,14 +402,12 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
} }
}).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean()); }).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future); when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
NettyClientStream stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, NettyClientStream stream = new NettyClientStream(
handler, DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"), new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor, new Metadata(),
AsciiString.of("agent")); channel, AsciiString.of("localhost"), AsciiString.of("http"), AsciiString.of("agent"));
stream.start(listener); stream.start(listener);
assertTrue(stream.canSend()); stream.transportState().setId(STREAM_ID);
assertTrue(stream.canReceive()); stream.transportState().setHttp2Stream(http2Stream);
stream.id(STREAM_ID);
stream.setHttp2Stream(http2Stream);
reset(listener); reset(listener);
return stream; return stream;
} }
@ -447,11 +440,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
return Utils.convertTrailers(trailers, true); return Utils.convertTrailers(trailers, true);
} }
class NettyClientStreamImpl extends NettyClientStream { class TransportStateImpl extends NettyClientStream.TransportState {
NettyClientStreamImpl(MethodDescriptor<?, ?> method, Metadata headers, Channel channel, public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
NettyClientHandler handler, int maxMessageSize, AsciiString authority, AsciiString scheme, super(handler, maxMessageSize);
AsciiString userAgent) {
super(method, headers, channel, handler, maxMessageSize, authority, scheme, userAgent);
} }
@Override @Override

View File

@ -177,8 +177,8 @@ public class NettyClientTransportTest {
} catch (ExecutionException e) { } catch (ExecutionException e) {
Status status = Status.fromThrowable(e); Status status = Status.fromThrowable(e);
assertEquals(INTERNAL, status.getCode()); assertEquals(INTERNAL, status.getCode());
System.err.println(status.getDescription()); assertTrue("Missing exceeds maximum from: " + status.getDescription(),
assertTrue(status.getDescription().contains("deframing")); status.getDescription().contains("exceeds maximum"));
} }
} }

View File

@ -39,7 +39,10 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.internal.MessageFramer;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -68,6 +71,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.verification.VerificationMode; import org.mockito.verification.VerificationMode;
import java.io.ByteArrayInputStream;
/** /**
* Base class for Netty handler unit tests. * Base class for Netty handler unit tests.
*/ */
@ -144,6 +149,25 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
handler().channelRead(ctx, obj); handler().channelRead(ctx, obj);
} }
protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
final ByteBuf compressionFrame = Unpooled.buffer(content.length);
MessageFramer framer = new MessageFramer(new MessageFramer.Sink() {
@Override
public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
}
}
}, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT));
framer.writePayload(new ByteArrayInputStream(content));
framer.flush();
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
}
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
// Need to retain the content since the frameWriter releases it. // Need to retain the content since the frameWriter releases it.
content.retain(); content.retain();

View File

@ -59,20 +59,16 @@ import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.Status.Code; import io.grpc.Status.Code;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.MessageFramer;
import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.WritableBuffer;
import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2CodecUtil;
import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.codec.http2.Http2Error;
@ -91,7 +87,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.InputStream; import java.io.InputStream;
/** /**
@ -159,11 +154,12 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
stream.request(1); stream.request(1);
// Create a data frame and then trigger the handler to read it. // Create a data frame and then trigger the handler to read it.
ByteBuf frame = dataFrame(STREAM_ID, endStream); ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
channelRead(frame); channelRead(frame);
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class); ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(streamListener).messageRead(captor.capture()); verify(streamListener).messageRead(captor.capture());
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(captor.getValue())); assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(captor.getValue()));
captor.getValue().close();
if (endStream) { if (endStream) {
verify(streamListener).halfClosed(); verify(streamListener).halfClosed();
@ -333,33 +329,10 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
stream = streamCaptor.getValue(); stream = streamCaptor.getValue();
} }
private ByteBuf dataFrame(int streamId, boolean endStream) {
byte[] contentAsArray = contentAsArray();
final ByteBuf compressionFrame = Unpooled.buffer(contentAsArray.length);
MessageFramer framer = new MessageFramer(new MessageFramer.Sink() {
@Override
public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
}
}
}, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT));
framer.writePayload(new ByteArrayInputStream(contentAsArray));
framer.flush();
if (endStream) {
framer.close();
}
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
}
private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception { private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
ByteBuf buf = NettyTestUtil.messageFrame(""); ByteBuf buf = NettyTestUtil.messageFrame("");
try { try {
return super.dataFrame(streamId, endStream, buf); return dataFrame(streamId, endStream, buf);
} finally { } finally {
buf.release(); buf.release();
} }

View File

@ -138,7 +138,8 @@ public abstract class NettyStreamTestBase<T extends Stream> {
((NettyServerStream) stream).transportState() ((NettyServerStream) stream).transportState()
.inboundDataReceived(messageFrame(MESSAGE), false); .inboundDataReceived(messageFrame(MESSAGE), false);
} else { } else {
((NettyClientStream) stream).transportDataReceived(messageFrame(MESSAGE), false); ((NettyClientStream) stream).transportState()
.transportDataReceived(messageFrame(MESSAGE), false);
} }
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class); ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(listener()).messageRead(captor.capture()); verify(listener()).messageRead(captor.capture());