internal: Split-state AbstractStream; sending and receiving

This introduces an AbstractStream2 that is intended to replace the
current AbstractStream. Only server-side is implemented in this commit
which is why AbstractStream remains. This is mostly a reorganization of
AbstractStream and children, but minor internal behavioral changes were
required which makes it appear more like a reimplementation.

A strong focus was on splitting state that is maintained on the
application's thread (with Stream) and state that is maintained by the
transport (and used for StreamListener). By splitting the state it makes
it much easier to verify thread-safety and to reason about interactions.

I consider this a stepping stone for making even more changes to
simplify the Stream implementations and do not think some of the changes
are yet at their logical conclusion. Some of the changes may also
immediately be replaced with something better. The focus was to improve
readability and comprehesibility to more easily make more interesting
changes.

The only thing really removed is some state checking during sending
which is already occurring in ServerCallImpl.
This commit is contained in:
Eric Anderson 2016-04-09 00:30:26 -07:00
parent 8090effa2d
commit 6382015f9d
15 changed files with 732 additions and 538 deletions

View File

@ -31,42 +31,145 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.Metadata;
import io.grpc.Status;
import java.io.InputStream;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
* Abstract base class for {@link ServerStream} implementations.
*
* @param <IdT> the type of the stream identifier
*/
public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
implements ServerStream {
private static final Logger log = Logger.getLogger(AbstractServerStream.class.getName());
* Abstract base class for {@link ServerStream} implementations. Extending classes only need to
* implement {@link #transportState()} and {@link #abstractServerStreamSink()}. Must only be called
* from the sending application thread.
*/
public abstract class AbstractServerStream extends AbstractStream2
implements ServerStream, MessageFramer.Sink {
/**
* A sink for outbound operations, separated from the stream simply to avoid name
* collisions/confusion. Only called from application thread.
*/
protected interface Sink {
/**
* Sends response headers to the remote end point.
*
* @param headers the headers to be sent to client.
*/
void writeHeaders(Metadata headers);
/**
* Sends an outbound frame to the remote end point.
*
* @param frame a buffer containing the chunk of data to be sent.
* @param flush {@code true} if more data may not be arriving soon
*/
void writeFrame(@Nullable WritableBuffer frame, boolean flush);
/**
* Sends trailers to the remote end point. This call implies end of stream.
*
* @param trailers metadata to be sent to the end point
* @param headersSent {@code true} if response headers have already been sent.
*/
void writeTrailers(Metadata trailers, boolean headersSent);
/**
* Requests up to the given number of messages from the call to be delivered. This should end up
* triggering {@link TransportState#requestMessagesFromDeframer(int)} on the transport thread.
*/
void request(int numMessages);
/**
* Tears down the stream, typically in the event of a timeout. This method may be called
* multiple times and from any thread.
*
* <p>This is a clone of {@link ServerStream#cancel()}.
*/
void cancel(Status status);
}
private final MessageFramer framer;
private boolean outboundClosed;
private boolean headersSent;
protected AbstractServerStream(WritableBufferAllocator bufferAllocator) {
framer = new MessageFramer(this, bufferAllocator);
}
@Override
protected abstract TransportState transportState();
/**
* Sink for transport to be called to perform outbound operations. Each stream must have its own
* unique sink.
*/
protected abstract Sink abstractServerStreamSink();
@Override
protected final MessageFramer framer() {
return framer;
}
@Override
public final void request(int numMessages) {
abstractServerStreamSink().request(numMessages);
}
@Override
public final void writeHeaders(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
headersSent = true;
abstractServerStreamSink().writeHeaders(headers);
}
@Override
public final void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
// Since endOfStream is triggered by the sending of trailers, avoid flush here and just flush
// after the trailers.
abstractServerStreamSink().writeFrame(frame, endOfStream ? false : flush);
}
@Override
public final void close(Status status, Metadata trailers) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
if (!outboundClosed) {
outboundClosed = true;
endOfMessages();
addStatusToTrailers(trailers, status);
abstractServerStreamSink().writeTrailers(trailers, headersSent);
}
}
private void addStatusToTrailers(Metadata trailers, Status status) {
trailers.removeAll(Status.CODE_KEY);
trailers.removeAll(Status.MESSAGE_KEY);
trailers.put(Status.CODE_KEY, status);
if (status.getDescription() != null) {
trailers.put(Status.MESSAGE_KEY, status.getDescription());
}
}
@Override
public final void cancel(Status status) {
abstractServerStreamSink().cancel(status);
}
@Override public Attributes attributes() {
return Attributes.EMPTY;
}
/** This should only called from the transport thread. */
protected abstract static class TransportState extends AbstractStream2.TransportState {
/** Whether listener.closed() has been called. */
private boolean listenerClosed;
private ServerStreamListener listener;
private boolean headersSent = false;
/**
* Whether the stream was closed gracefully by the application (vs. a transport-level failure).
*/
private boolean gracefulClose;
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */
private Metadata stashedTrailers;
protected AbstractServerStream(WritableBufferAllocator bufferAllocator,
int maxMessageSize) {
super(bufferAllocator, maxMessageSize);
protected TransportState(int maxMessageSize) {
super(maxMessageSize);
}
/**
@ -74,199 +177,62 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
* thread.
*/
public final void setListener(ServerStreamListener listener) {
this.listener = checkNotNull(listener);
this.listener = Preconditions.checkNotNull(listener);
// Now that the stream has actually been initialized, call the listener's onReady callback if
// appropriate.
onStreamAllocated();
}
@Override
public void deliveryStalled() {}
@Override
public void endOfStream() {
closeDeframer();
listener().halfClosed();
}
@Override
protected ServerStreamListener listener() {
return listener;
}
@Override
protected void receiveMessage(InputStream is) {
inboundPhase(Phase.MESSAGE);
listener().messageRead(is);
}
@Override
public final void writeHeaders(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
outboundPhase(Phase.HEADERS);
headersSent = true;
internalSendHeaders(headers);
outboundPhase(Phase.MESSAGE);
}
@Override
public final void writeMessage(InputStream message) {
if (outboundPhase() != Phase.MESSAGE) {
throw new IllegalStateException("Messages are only permitted after headers and before close");
}
super.writeMessage(message);
}
@Override
public final void close(Status status, Metadata trailers) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkNotNull(trailers, "trailers");
if (outboundPhase(Phase.STATUS) != Phase.STATUS) {
gracefulClose = true;
stashedTrailers = trailers;
writeStatusToTrailers(status);
closeFramer();
}
}
private void writeStatusToTrailers(Status status) {
stashedTrailers.removeAll(Status.CODE_KEY);
stashedTrailers.removeAll(Status.MESSAGE_KEY);
stashedTrailers.put(Status.CODE_KEY, status);
if (status.getDescription() != null) {
stashedTrailers.put(Status.MESSAGE_KEY, status.getDescription());
}
}
/**
* Called by transport implementations when they receive headers.
*
* @param headers the parsed headers
*/
protected void inboundHeadersReceived(Metadata headers) {
inboundPhase(Phase.MESSAGE);
}
/**
* Called in the network thread to process the content of an inbound DATA frame from the client.
* Called in the transport thread to process the content of an inbound DATA frame from the
* client.
*
* @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must
* be retained.
* @param endOfStream {@code true} if no more data will be received on the stream.
*/
public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) {
if (inboundPhase() == Phase.STATUS) {
frame.close();
return;
}
// Deframe the message. If a failure occurs, deframeFailed will be called.
deframe(frame, endOfStream);
}
@Override
protected final void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause);
abortStream(Status.fromThrowable(cause), true);
}
@Override
protected final void internalSendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
if (frame != null) {
sendFrame(frame, false, endOfStream ? false : flush);
}
if (endOfStream) {
sendTrailers(stashedTrailers, headersSent);
headersSent = true;
stashedTrailers = null;
}
}
/**
* Sends response headers to the remote end points.
*
* @param headers the headers to be sent to client.
*/
protected abstract void internalSendHeaders(Metadata headers);
/**
* Sends an outbound frame to the remote end point.
*
* @param frame a buffer containing the chunk of data to be sent.
* @param endOfStream if {@code true} indicates that no more data will be sent on the stream by
* this endpoint.
* @param flush {@code true} if more data may not be arriving soon
*/
protected abstract void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush);
/**
* Sends trailers to the remote end point. This call implies end of stream.
*
* @param trailers metadata to be sent to end point
* @param headersSent {@code true} if response headers have already been sent.
*/
protected abstract void sendTrailers(Metadata trailers, boolean headersSent);
/**
* Indicates the stream is considered completely closed and there is no further opportunity for
* error. It calls the listener's {@code closed()} if it was not already done by {@link
* #abortStream}. Note that it is expected that either {@code closed()} or {@code abortStream()}
* was previously called, since {@code closed()} is required for a normal stream closure and
* {@code abortStream()} for abnormal.
*/
public void complete() {
if (!gracefulClose) {
closeListener(Status.INTERNAL.withDescription("successful complete() without close()"));
throw new IllegalStateException("successful complete() without close()");
}
closeListener(Status.OK);
}
/**
* Called when the remote end half-closes the stream.
*/
@Override
protected final void remoteEndClosed() {
halfCloseListener();
}
/**
* Aborts the stream with an error status, cleans up resources and notifies the listener if
* necessary.
* Notifies failure to the listener of the stream. The transport is responsible for notifying
* the client of the failure independent of this method.
*
* <p>Unlike {@link #close(Status, Metadata)}, this method is only called from the
* transport. The transport should use this method instead of {@code close(Status)} for internal
* errors to prevent exposing unexpected states and exceptions to the application.
*
* @param status the error status. Must not be {@link Status#OK}.
* @param notifyClient {@code true} if the stream is still writable and you want to notify the
* client about stream closure and send the status
*/
public final void abortStream(Status status, boolean notifyClient) {
// TODO(louiscryan): Investigate whether we can remove the notification to the client
// and rely on a transport layer stream reset instead.
public final void transportReportStatus(Status status) {
Preconditions.checkArgument(!status.isOk(), "status must not be OK");
closeListener(status);
if (notifyClient) {
// TODO(louiscryan): Remove
if (stashedTrailers == null) {
stashedTrailers = new Metadata();
}
writeStatusToTrailers(status);
sendStreamAbortToClient(status, stashedTrailers);
}
}
@Override
public boolean isClosed() {
return super.isClosed() || listenerClosed;
}
/**
* Notifies the remote client that this stream has aborted.
* Indicates the stream is considered completely closed and there is no further opportunity for
* error. It calls the listener's {@code closed()} if it was not already done by {@link
* #transportReportStatus}.
*/
protected abstract void sendStreamAbortToClient(Status status, Metadata trailers);
/**
* Fires a half-closed event to the listener and frees inbound resources.
*/
private void halfCloseListener() {
if (inboundPhase(Phase.STATUS) != Phase.STATUS && !listenerClosed) {
closeDeframer();
listener().halfClosed();
}
public void complete() {
closeListener(Status.OK);
}
/**
@ -279,8 +245,5 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
listener().closed(newStatus);
}
}
@Override public Attributes attributes() {
return Attributes.EMPTY;
}
}

View File

@ -0,0 +1,279 @@
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.Decompressor;
import java.io.InputStream;
import javax.annotation.concurrent.GuardedBy;
/**
* The stream and stream state as used by the application. Must only be called from the sending
* application thread.
*/
public abstract class AbstractStream2 implements Stream {
/** The framer to use for sending messages. */
protected abstract MessageFramer framer();
/**
* Obtain the transport state corresponding to this stream. Each stream must have its own unique
* transport state.
*/
protected abstract TransportState transportState();
@Override
public final void setMessageCompression(boolean enable) {
framer().setMessageCompression(enable);
}
@Override
public final void writeMessage(InputStream message) {
checkNotNull(message);
if (!framer().isClosed()) {
framer().writePayload(message);
}
}
@Override
public final void flush() {
if (!framer().isClosed()) {
framer().flush();
}
}
/**
* Closes the underlying framer. Should be called when the outgoing stream is gracefully closed
* (half closure on client; closure on server).
*/
protected final void endOfMessages() {
framer().close();
}
@Override
public final void setCompressor(Compressor compressor) {
framer().setCompressor(checkNotNull(compressor, "compressor"));
}
@Override
public final void setDecompressor(Decompressor decompressor) {
transportState().setDecompressor(checkNotNull(decompressor, "decompressor"));
}
@Override
public final boolean isReady() {
if (framer().isClosed()) {
return false;
}
return transportState().isReady();
}
/**
* Event handler to be called by the subclass when a number of bytes are being queued for sending
* to the remote endpoint.
*
* @param numBytes the number of bytes being sent.
*/
protected final void onSendingBytes(int numBytes) {
transportState().onSendingBytes(numBytes);
}
/**
* Stream state as used by the transport. This should only called from the transport thread
* (except for private interactions with {@code AbstractStream2}).
*/
public abstract static class TransportState implements MessageDeframer.Listener {
/**
* The default number of queued bytes for a given stream, below which
* {@link StreamListener#onReady()} will be called.
*/
private static final int DEFAULT_ONREADY_THRESHOLD = 32 * 1024;
private final MessageDeframer deframer;
private final Object onReadyLock = new Object();
/**
* The number of bytes currently queued, waiting to be sent. When this falls below
* DEFAULT_ONREADY_THRESHOLD, {@link StreamListener#onReady()} will be called.
*/
@GuardedBy("onReadyLock")
private int numSentBytesQueued;
/**
* Indicates the stream has been created on the connection. This implies that the stream is no
* longer limited by MAX_CONCURRENT_STREAMS.
*/
@GuardedBy("onReadyLock")
private boolean allocated;
protected TransportState(int maxMessageSize) {
deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize);
}
@VisibleForTesting
TransportState(MessageDeframer deframer) {
this.deframer = deframer;
}
/**
* Override this method to provide a stream listener.
*/
protected abstract StreamListener listener();
@Override
public void messageRead(InputStream is) {
listener().messageRead(is);
}
/**
* Called when a {@link #deframe(ReadableBuffer, boolean)} operation failed.
*
* @param cause the actual failure
*/
protected abstract void deframeFailed(Throwable cause);
/**
* Closes this deframer and frees any resources. After this method is called, additional calls
* will have no effect.
*/
protected final void closeDeframer() {
deframer.close();
}
/**
* Called to parse a received frame and attempt delivery of any completed
* messages. Must be called from the transport thread.
*/
protected final void deframe(ReadableBuffer frame, boolean endOfStream) {
if (deframer.isClosed()) {
frame.close();
return;
}
try {
deframer.deframe(frame, endOfStream);
} catch (Throwable t) {
deframeFailed(t);
}
}
/**
* Called to request the given number of messages from the deframer. Must be called
* from the transport thread.
*/
public final void requestMessagesFromDeframer(int numMessages) {
if (deframer.isClosed()) {
return;
}
try {
deframer.request(numMessages);
} catch (Throwable t) {
deframeFailed(t);
}
}
private void setDecompressor(Decompressor decompressor) {
if (deframer.isClosed()) {
return;
}
deframer.setDecompressor(decompressor);
}
private boolean isReady() {
synchronized (onReadyLock) {
return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
}
}
/**
* Event handler to be called by the subclass when the stream's headers have passed any
* connection flow control (i.e., MAX_CONCURRENT_STREAMS). It may call the listener's {@link
* StreamListener#onReady()} handler if appropriate. This must be called from the transport
* thread, since the listener may be called back directly.
*/
protected final void onStreamAllocated() {
checkState(listener() != null);
synchronized (onReadyLock) {
checkState(!allocated, "Already allocated");
allocated = true;
}
notifyIfReady();
}
/**
* Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint.
*
* @param numBytes the number of bytes being sent.
*/
private void onSendingBytes(int numBytes) {
synchronized (onReadyLock) {
numSentBytesQueued += numBytes;
}
}
/**
* Event handler to be called by the subclass when a number of bytes has been sent to the remote
* endpoint. May call back the listener's {@link StreamListener#onReady()} handler if
* appropriate. This must be called from the transport thread, since the listener may be called
* back directly.
*
* @param numBytes the number of bytes that were sent.
*/
public final void onSentBytes(int numBytes) {
boolean doNotify;
synchronized (onReadyLock) {
boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
numSentBytesQueued -= numBytes;
boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD;
doNotify = !belowThresholdBefore && belowThresholdAfter;
}
if (doNotify) {
notifyIfReady();
}
}
private void notifyIfReady() {
boolean doNotify;
synchronized (onReadyLock) {
doNotify = isReady();
}
if (doNotify) {
listener().onReady();
}
}
}
}

View File

@ -51,6 +51,8 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
/**
* Encodes gRPC messages to be delivered via the transport layer which implements {@link
* MessageFramer.Sink}.
@ -69,7 +71,7 @@ public class MessageFramer {
* @param endOfStream whether the frame is the last one for the GRPC stream
* @param flush {@code true} if more data may not be arriving soon
*/
void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush);
void deliverFrame(@Nullable WritableBuffer frame, boolean endOfStream, boolean flush);
}
private static final int HEADER_LENGTH = 5;

View File

@ -32,17 +32,17 @@
package io.grpc.internal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractStream.Phase;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import org.junit.Rule;
@ -50,13 +50,10 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
/**
* Tests for {@link AbstractServerStream}.
@ -74,22 +71,24 @@ public class AbstractServerStreamTest {
}
};
private final AbstractServerStreamBase defaultStream =
new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE);
private AbstractServerStream.Sink sink = mock(AbstractServerStream.Sink.class);
private AbstractServerStreamBase stream = new AbstractServerStreamBase(
allocator, sink, new AbstractServerStreamBase.TransportState(MAX_MESSAGE_SIZE));
private final ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
@Test
public void setListener_setOnlyOnce() {
defaultStream.setListener(new ServerStreamListenerBase());
stream.transportState().setListener(new ServerStreamListenerBase());
thrown.expect(IllegalStateException.class);
defaultStream.setListener(new ServerStreamListenerBase());
stream.transportState().setListener(new ServerStreamListenerBase());
}
@Test
public void setListener_readyCalled() {
ServerStreamListener streamListener = mock(ServerStreamListener.class);
defaultStream.setListener(streamListener);
stream.transportState().setListener(streamListener);
verify(streamListener).onReady();
}
@ -98,179 +97,87 @@ public class AbstractServerStreamTest {
public void setListener_failsOnNull() {
thrown.expect(NullPointerException.class);
defaultStream.setListener(null);
stream.transportState().setListener(null);
}
@Test
public void receiveMessage_listenerCalled() {
public void messageRead_listenerCalled() {
final ServerStreamListener streamListener = mock(ServerStreamListener.class);
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected ServerStreamListener listener() {
return streamListener;
}
};
stream.transportState().setListener(streamListener);
// Normally called by a deframe event.
stream.receiveMessage(new ByteArrayInputStream(new byte[]{}));
stream.transportState().messageRead(new ByteArrayInputStream(new byte[]{}));
verify(streamListener).messageRead(isA(InputStream.class));
}
@Test
public void receiveMessage_failsIfHalfClosed() {
// Simulate being closed, without invoking the listener
defaultStream.inboundPhase(Phase.STATUS);
thrown.expect(IllegalStateException.class);
// Normally called by a deframe event.
defaultStream.receiveMessage(new ByteArrayInputStream(new byte[]{}));
}
@Test
public void writeHeaders_failsOnNullHeaders() {
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE);
thrown.expect(NullPointerException.class);
stream.writeHeaders(null);
}
@Test
public void writeHeaders_failsIfAlreadySent() {
defaultStream.writeHeaders(new Metadata());
thrown.expect(IllegalStateException.class);
defaultStream.writeHeaders(new Metadata());
}
@Test
public void writeHeaders() {
final AtomicReference<Metadata> capturedHeaders = new AtomicReference<Metadata>(null);
Metadata headers = new Metadata();
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void internalSendHeaders(Metadata captured) {
capturedHeaders.set(captured);
}
};
stream.writeHeaders(headers);
assertEquals(headers, capturedHeaders.get());
assertEquals(Phase.MESSAGE, stream.outboundPhase());
}
@Test
public void writeMessage_writeHeadersIfNeeded() {
final AtomicReference<Metadata> capturedHeaders = new AtomicReference<Metadata>(null);
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void internalSendHeaders(Metadata captured) {
capturedHeaders.set(captured);
}
};
stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
assertNotNull(capturedHeaders.get());
verify(sink).writeHeaders(same(headers));
}
@Test
public void writeMessage_dontWriteDuplicateHeaders() {
final AtomicReference<Metadata> capturedHeaders = new AtomicReference<Metadata>(null);
Metadata headers = new Metadata();
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void internalSendHeaders(Metadata captured) {
capturedHeaders.set(captured);
}
};
stream.writeHeaders(headers);
stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
// Make sure it wasn't called twice, by checking that the exact headers sent are the ones
// returned.
assertSame(headers, capturedHeaders.get());
// Make sure it wasn't called twice
verify(sink).writeHeaders(any(Metadata.class));
}
@Test
public void writeMessage_ignoreIfFramerClosed() {
final AtomicBoolean sendCalled = new AtomicBoolean();
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
sendCalled.set(true);
}
};
stream.writeHeaders(new Metadata());
stream.closeFramer();
stream.endOfMessages();
reset(sink);
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
assertFalse(sendCalled.get());
verify(sink, never()).writeFrame(any(WritableBuffer.class), any(Boolean.class));
}
@Test
public void writeMessage() {
final AtomicBoolean sendCalled = new AtomicBoolean();
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
sendCalled.set(true);
}
};
stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
// Force the message to be flushed
stream.closeFramer();
stream.flush();
assertTrue(sendCalled.get());
assertEquals(Phase.MESSAGE, stream.outboundPhase());
verify(sink).writeFrame(any(WritableBuffer.class), eq(true));
}
@Test
public void close_failsOnNullStatus() {
thrown.expect(NullPointerException.class);
defaultStream.close(null, new Metadata());
stream.close(null, new Metadata());
}
@Test
public void close_failsOnNullMetadata() {
thrown.expect(NullPointerException.class);
defaultStream.close(Status.INTERNAL, null);
stream.close(Status.INTERNAL, null);
}
@Test
public void close_sendsTrailers() {
final AtomicReference<Metadata> capturedTrailers = new AtomicReference<Metadata>(null);
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void sendTrailers(Metadata trailers, boolean headersSent) {
capturedTrailers.set(trailers);
}
};
Metadata trailers = new Metadata();
stream.close(Status.INTERNAL, trailers);
assertSame(trailers, capturedTrailers.get());
verify(sink).writeTrailers(any(Metadata.class), eq(false));
}
@Test
public void close_sendTrailersClearsReservedFields() {
final AtomicReference<Metadata> capturedTrailers = new AtomicReference<Metadata>(null);
AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) {
@Override
protected void sendTrailers(Metadata trailers, boolean headersSent) {
capturedTrailers.set(trailers);
}
};
// stream actually mutates trailers, so we can't check that the fields here are the same as
// the captured ones.
Metadata trailers = new Metadata();
@ -279,8 +186,9 @@ public class AbstractServerStreamTest {
stream.close(Status.INTERNAL.withDescription("bad"), trailers);
assertEquals(Status.Code.INTERNAL, capturedTrailers.get().get(Status.CODE_KEY).getCode());
assertEquals("bad", capturedTrailers.get().get(Status.MESSAGE_KEY));
verify(sink).writeTrailers(metadataCaptor.capture(), eq(false));
assertEquals(Status.Code.INTERNAL, metadataCaptor.getValue().get(Status.CODE_KEY).getCode());
assertEquals("bad", metadataCaptor.getValue().get(Status.MESSAGE_KEY));
}
private static class ServerStreamListenerBase implements ServerStreamListener {
@ -297,41 +205,38 @@ public class AbstractServerStreamTest {
public void closed(Status status) {}
}
private static class AbstractServerStreamBase extends AbstractServerStream<Void> {
protected AbstractServerStreamBase(
WritableBufferAllocator bufferAllocator, int maxMessageSize) {
super(bufferAllocator, maxMessageSize);
private static class AbstractServerStreamBase extends AbstractServerStream {
private final Sink sink;
private final AbstractServerStream.TransportState state;
protected AbstractServerStreamBase(WritableBufferAllocator bufferAllocator, Sink sink,
AbstractServerStream.TransportState state) {
super(bufferAllocator);
this.sink = sink;
this.state = state;
}
@Override
public void cancel(Status status) {}
@Override
public void request(int numMessages) {}
@Override
protected void internalSendHeaders(Metadata headers) {}
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {}
@Override
protected void sendTrailers(Metadata trailers, boolean headersSent) {}
@Override
@Nullable
public Void id() {
return null;
protected Sink abstractServerStreamSink() {
return sink;
}
@Override
protected void inboundDeliveryPaused() {}
protected AbstractServerStream.TransportState transportState() {
return state;
}
static class TransportState extends AbstractServerStream.TransportState {
protected TransportState(int maxMessageSize) {
super(maxMessageSize);
}
@Override
protected void returnProcessedBytes(int processedBytes) {}
protected void deframeFailed(Throwable cause) {}
@Override
protected void sendStreamAbortToClient(Status status, Metadata trailers) {}
public void bytesRead(int processedBytes) {}
}
}
}

View File

@ -41,15 +41,15 @@ import io.grpc.Status;
* Command sent from a Netty server stream to the handler to cancel the stream.
*/
class CancelServerStreamCommand {
private final NettyServerStream stream;
private final NettyServerStream.TransportState stream;
private final Status reason;
CancelServerStreamCommand(NettyServerStream stream, Status reason) {
CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) {
this.stream = Preconditions.checkNotNull(stream);
this.reason = Preconditions.checkNotNull(reason);
}
NettyServerStream stream() {
NettyServerStream.TransportState stream() {
return stream;
}

View File

@ -54,7 +54,7 @@ import javax.annotation.Nullable;
/**
* Client stream for a Netty transport.
*/
abstract class NettyClientStream extends Http2ClientStream {
abstract class NettyClientStream extends Http2ClientStream implements StreamIdHolder {
private final MethodDescriptor<?, ?> method;
/** {@code null} after start. */
private Metadata headers;

View File

@ -188,16 +188,19 @@ class NettyServerHandler extends AbstractNettyHandler {
// method.
Http2Stream http2Stream = requireHttp2Stream(streamId);
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this,
maxMessageSize);
NettyServerStream.TransportState state =
new NettyServerStream.TransportState(this, http2Stream, maxMessageSize);
NettyServerStream stream = new NettyServerStream(ctx.channel(), state);
Metadata metadata = Utils.convertHeaders(headers);
stream.inboundHeadersReceived(metadata);
ServerStreamListener listener =
transportListener.streamCreated(stream, method, metadata);
stream.setListener(listener);
http2Stream.setProperty(streamKey, stream);
// TODO(ejona): this could be racy since stream could have been used before getting here. All
// cases appear to be fine, but some are almost only by happenstance and it is difficult to
// audit. It would be good to improve the API to be less prone to races.
state.setListener(listener);
http2Stream.setProperty(streamKey, state);
} catch (Http2Exception e) {
throw e;
@ -210,7 +213,7 @@ class NettyServerHandler extends AbstractNettyHandler {
private void onDataRead(int streamId, ByteBuf data, boolean endOfStream) throws Http2Exception {
try {
NettyServerStream stream = serverStream(requireHttp2Stream(streamId));
NettyServerStream.TransportState stream = serverStream(requireHttp2Stream(streamId));
stream.inboundDataReceived(data, endOfStream);
} catch (Throwable e) {
logger.log(Level.WARNING, "Exception in onDataRead()", e);
@ -221,9 +224,9 @@ class NettyServerHandler extends AbstractNettyHandler {
private void onRstStreamRead(int streamId) throws Http2Exception {
try {
NettyServerStream stream = serverStream(connection().stream(streamId));
NettyServerStream.TransportState stream = serverStream(connection().stream(streamId));
if (stream != null) {
stream.abortStream(Status.CANCELLED, false);
stream.transportReportStatus(Status.CANCELLED);
}
} catch (Throwable e) {
logger.log(Level.WARNING, "Exception in onRstStreamRead()", e);
@ -244,16 +247,15 @@ class NettyServerHandler extends AbstractNettyHandler {
protected void onStreamError(ChannelHandlerContext ctx, Throwable cause,
StreamException http2Ex) {
logger.log(Level.WARNING, "Stream Error", cause);
NettyServerStream serverStream = serverStream(
NettyServerStream.TransportState serverStream = serverStream(
connection().stream(Http2Exception.streamId(http2Ex)));
if (serverStream != null) {
// Abort the stream with a status to help the client with debugging.
serverStream.abortStream(Utils.statusFromThrowable(cause), true);
} else {
serverStream.transportReportStatus(Utils.statusFromThrowable(cause));
}
// TODO(ejona): Abort the stream by sending headers to help the client with debugging.
// Delegate to the base class to send a RST_STREAM.
super.onStreamError(ctx, cause, http2Ex);
}
}
/**
* Handler for the Channel shutting down.
@ -267,9 +269,9 @@ class NettyServerHandler extends AbstractNettyHandler {
connection().forEachActiveStream(new Http2StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
NettyServerStream serverStream = serverStream(stream);
NettyServerStream.TransportState serverStream = serverStream(stream);
if (serverStream != null) {
serverStream.abortStream(status, false);
serverStream.transportReportStatus(status);
}
return true;
}
@ -318,7 +320,7 @@ class NettyServerHandler extends AbstractNettyHandler {
}
private void closeStreamWhenDone(ChannelPromise promise, int streamId) throws Http2Exception {
final NettyServerStream stream = serverStream(requireHttp2Stream(streamId));
final NettyServerStream.TransportState stream = serverStream(requireHttp2Stream(streamId));
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
@ -345,15 +347,15 @@ class NettyServerHandler extends AbstractNettyHandler {
private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersCommand cmd,
ChannelPromise promise) throws Http2Exception {
if (cmd.endOfStream()) {
closeStreamWhenDone(promise, cmd.streamId());
closeStreamWhenDone(promise, cmd.stream().id());
}
encoder().writeHeaders(ctx, cmd.streamId(), cmd.headers(), 0, cmd.endOfStream(), promise);
encoder().writeHeaders(ctx, cmd.stream().id(), cmd.headers(), 0, cmd.endOfStream(), promise);
}
private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand cmd,
ChannelPromise promise) {
// Notify the listener if we haven't already.
cmd.stream().abortStream(cmd.reason(), false);
cmd.stream().transportReportStatus(cmd.reason());
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);
}
@ -408,8 +410,8 @@ class NettyServerHandler extends AbstractNettyHandler {
/**
* Returns the server stream associated to the given HTTP/2 stream object.
*/
private NettyServerStream serverStream(Http2Stream stream) {
return stream == null ? null : (NettyServerStream) stream.getProperty(streamKey);
private NettyServerStream.TransportState serverStream(Http2Stream stream) {
return stream == null ? null : (NettyServerStream.TransportState) stream.getProperty(streamKey);
}
private Http2Exception newStreamException(int streamId, Throwable cause) {

View File

@ -46,105 +46,40 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2Stream;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLSession;
/**
* Server stream for a Netty HTTP2 transport.
* Server stream for a Netty HTTP2 transport. Must only be called from the sending application
* thread.
*/
class NettyServerStream extends AbstractServerStream<Integer> {
class NettyServerStream extends AbstractServerStream {
private static final Logger log = Logger.getLogger(NettyServerStream.class.getName());
private final Sink sink = new Sink();
private final TransportState state;
private final Channel channel;
private final NettyServerHandler handler;
private final Http2Stream http2Stream;
private final WriteQueue writeQueue;
private final Attributes attributes;
NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler,
int maxMessageSize) {
super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize);
this.writeQueue = handler.getWriteQueue();
public NettyServerStream(Channel channel, TransportState state) {
super(new NettyWritableBufferAllocator(channel.alloc()));
this.state = checkNotNull(state, "transportState");
this.channel = checkNotNull(channel, "channel");
this.http2Stream = checkNotNull(http2Stream, "http2Stream");
this.handler = checkNotNull(handler, "handler");
this.writeQueue = state.handler.getWriteQueue();
this.attributes = buildAttributes(channel);
}
@Override
public Integer id() {
return http2Stream.id();
protected TransportState transportState() {
return state;
}
@Override
protected void inboundHeadersReceived(Metadata headers) {
super.inboundHeadersReceived(headers);
}
void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
@Override
public void request(final int numMessages) {
if (channel.eventLoop().inEventLoop()) {
// Processing data read in the event loop so can call into the deframer immediately
requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(this, numMessages), true);
}
}
@Override
protected void inboundDeliveryPaused() {
// Do nothing.
}
@Override
protected void internalSendHeaders(Metadata headers) {
writeQueue.enqueue(new SendResponseHeadersCommand(id(),
Utils.convertServerHeaders(headers), false),
true);
}
@Override
protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(this, bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
}
}), flush);
}
@Override
protected void sendTrailers(Metadata trailers, boolean headersSent) {
Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent);
writeQueue.enqueue(new SendResponseHeadersCommand(id(), http2Trailers, true), true);
}
@Override
protected void returnProcessedBytes(int processedBytes) {
handler.returnProcessedBytes(http2Stream, processedBytes);
writeQueue.scheduleFlush();
}
@Override
protected void sendStreamAbortToClient(Status status, Metadata trailers) {
// Cancel the stream.
// TODO(nmittler): Consider sending trailers.
cancel(status);
}
@Override
public void cancel(Status status) {
writeQueue.enqueue(new CancelServerStreamCommand(this, status), true);
protected Sink abstractServerStreamSink() {
return sink;
}
@Override public Attributes attributes() {
@ -163,4 +98,92 @@ class NettyServerStream extends AbstractServerStream<Integer> {
.set(ServerCall.SSL_SESSION_KEY, sslSession)
.build();
}
private class Sink implements AbstractServerStream.Sink {
@Override
public void request(final int numMessages) {
if (channel.eventLoop().inEventLoop()) {
// Processing data read in the event loop so can call into the deframer immediately
transportState().requestMessagesFromDeframer(numMessages);
} else {
writeQueue.enqueue(new RequestMessagesCommand(transportState(), numMessages), true);
}
}
@Override
public void writeHeaders(Metadata headers) {
writeQueue.enqueue(new SendResponseHeadersCommand(transportState(),
Utils.convertServerHeaders(headers), false),
true);
}
@Override
public void writeFrame(WritableBuffer frame, boolean flush) {
if (frame == null) {
writeQueue.scheduleFlush();
return;
}
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(transportState(), bytebuf, false),
channel.newPromise().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
}
}), flush);
}
@Override
public void writeTrailers(Metadata trailers, boolean headersSent) {
Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent);
writeQueue.enqueue(
new SendResponseHeadersCommand(transportState(), http2Trailers, true), true);
}
@Override
public void cancel(Status status) {
writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true);
}
}
/** This should only called from the transport thread. */
public static class TransportState extends AbstractServerStream.TransportState
implements StreamIdHolder {
private final Http2Stream http2Stream;
private final NettyServerHandler handler;
public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize) {
super(maxMessageSize);
this.http2Stream = checkNotNull(http2Stream, "http2Stream");
this.handler = checkNotNull(handler, "handler");
}
@Override
public void bytesRead(int processedBytes) {
handler.returnProcessedBytes(http2Stream, processedBytes);
handler.getWriteQueue().scheduleFlush();
}
@Override
protected void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause);
Status status = Status.fromThrowable(cause);
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}
void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
public Integer id() {
return http2Stream.id();
}
}
}

View File

@ -31,7 +31,8 @@
package io.grpc.netty;
import io.grpc.internal.AbstractStream;
import io.grpc.internal.AbstractStream2;
import io.grpc.internal.Stream;
/**
* Command which requests messages from the deframer.
@ -39,14 +40,26 @@ import io.grpc.internal.AbstractStream;
class RequestMessagesCommand {
private final int numMessages;
private final AbstractStream<Integer> stream;
private final Stream stream;
private final AbstractStream2.TransportState state;
public RequestMessagesCommand(AbstractStream<Integer> stream, int numMessages) {
public RequestMessagesCommand(Stream stream, int numMessages) {
this.state = null;
this.numMessages = numMessages;
this.stream = stream;
}
public RequestMessagesCommand(AbstractStream2.TransportState state, int numMessages) {
this.state = state;
this.numMessages = numMessages;
this.stream = null;
}
void requestMessages() {
if (stream != null) {
stream.request(numMessages);
} else {
state.requestMessagesFromDeframer(numMessages);
}
}
}

View File

@ -31,7 +31,6 @@
package io.grpc.netty;
import io.grpc.internal.AbstractStream;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.DefaultByteBufHolder;
@ -40,11 +39,10 @@ import io.netty.buffer.DefaultByteBufHolder;
* Command sent from the transport to the Netty channel to send a GRPC frame to the remote endpoint.
*/
class SendGrpcFrameCommand extends DefaultByteBufHolder {
private final AbstractStream<Integer> stream;
private final StreamIdHolder stream;
private final boolean endStream;
SendGrpcFrameCommand(AbstractStream<Integer> stream, ByteBuf content,
boolean endStream) {
SendGrpcFrameCommand(StreamIdHolder stream, ByteBuf content, boolean endStream) {
super(content);
this.stream = stream;
this.endStream = endStream;

View File

@ -39,18 +39,18 @@ import io.netty.handler.codec.http2.Http2Headers;
* Command sent from the transport to the Netty channel to send response headers to the client.
*/
class SendResponseHeadersCommand {
private final int streamId;
private final StreamIdHolder stream;
private final Http2Headers headers;
private final boolean endOfStream;
SendResponseHeadersCommand(int streamId, Http2Headers headers, boolean endOfStream) {
this.streamId = streamId;
SendResponseHeadersCommand(StreamIdHolder stream, Http2Headers headers, boolean endOfStream) {
this.stream = Preconditions.checkNotNull(stream);
this.headers = Preconditions.checkNotNull(headers);
this.endOfStream = endOfStream;
}
int streamId() {
return streamId;
StreamIdHolder stream() {
return stream;
}
Http2Headers headers() {
@ -67,19 +67,19 @@ class SendResponseHeadersCommand {
return false;
}
SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that;
return thatCmd.streamId == streamId
return thatCmd.stream.equals(stream)
&& thatCmd.headers.equals(headers)
&& thatCmd.endOfStream == endOfStream;
}
@Override
public String toString() {
return getClass().getSimpleName() + "(streamId=" + streamId + ", headers=" + headers
return getClass().getSimpleName() + "(stream=" + stream.id() + ", headers=" + headers
+ ", endOfStream=" + endOfStream + ")";
}
@Override
public int hashCode() {
return streamId;
return stream.hashCode();
}
}

View File

@ -0,0 +1,37 @@
/*
* Copyright 2016, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.netty;
/** Container for stream ids. */
interface StreamIdHolder {
public Integer id();
}

View File

@ -133,7 +133,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
createStream();
// Send a frame and verify that it was written.
ChannelFuture future = enqueue(new SendGrpcFrameCommand(stream, content(), false));
ChannelFuture future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content(), false));
assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false),
any(ChannelPromise.class));
@ -275,9 +276,9 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
@Test
public void cancelShouldSendRstStream() throws Exception {
createStream();
enqueue(new CancelServerStreamCommand(stream, Status.DEADLINE_EXCEEDED));
verifyWrite().writeRstStream(eq(ctx()), eq(stream.id()), eq(Http2Error.CANCEL.code()),
any(ChannelPromise.class));
enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED));
verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()),
eq(Http2Error.CANCEL.code()), any(ChannelPromise.class));
}
@Test

View File

@ -33,8 +33,6 @@ package io.grpc.netty;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq;
@ -95,11 +93,13 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
Http2Headers headers = new DefaultHttp2Headers()
.status(Utils.STATUS_OK)
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
verify(writeQueue).enqueue(
new SendResponseHeadersCommand(stream.transportState(), headers, false), true);
byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)),
any(ChannelPromise.class),
eq(true));
}
@ -108,38 +108,23 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
public void writeHeadersShouldSendHeaders() throws Exception {
Metadata headers = new Metadata();
stream().writeHeaders(headers);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID,
verify(writeQueue).enqueue(new SendResponseHeadersCommand(stream.transportState(),
Utils.convertServerHeaders(headers), false), true);
}
@Test
public void duplicateWriteHeadersShouldFail() throws Exception {
Metadata headers = new Metadata();
stream().writeHeaders(headers);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID,
Utils.convertServerHeaders(headers), false), true);
try {
stream().writeHeaders(headers);
fail("Can only write response headers once");
} catch (IllegalStateException ise) {
// Success
}
}
@Test
public void closeBeforeClientHalfCloseShouldSucceed() throws Exception {
stream().close(Status.OK, new Metadata());
verify(writeQueue).enqueue(
new SendResponseHeadersCommand(STREAM_ID, new DefaultHttp2Headers()
new SendResponseHeadersCommand(stream.transportState(), new DefaultHttp2Headers()
.status(new AsciiString("200"))
.set(new AsciiString("content-type"), new AsciiString("application/grpc"))
.set(new AsciiString("grpc-status"), new AsciiString("0")), true),
true);
verifyZeroInteractions(serverListener);
// Sending complete. Listener gets closed()
stream().complete();
stream().transportState().complete();
verify(serverListener).closed(Status.OK);
assertTrue(stream().isClosed());
verifyZeroInteractions(serverListener);
}
@ -148,56 +133,43 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
// Error is sent on wire and ends the stream
stream().close(Status.CANCELLED, trailers);
verify(writeQueue).enqueue(
new SendResponseHeadersCommand(STREAM_ID, new DefaultHttp2Headers()
new SendResponseHeadersCommand(stream.transportState(), new DefaultHttp2Headers()
.status(new AsciiString("200"))
.set(new AsciiString("content-type"), new AsciiString("application/grpc"))
.set(new AsciiString("grpc-status"), new AsciiString("1")), true),
true);
verifyZeroInteractions(serverListener);
// Sending complete. Listener gets closed()
stream().complete();
stream().transportState().complete();
verify(serverListener).closed(Status.OK);
assertTrue(stream().isClosed());
verifyZeroInteractions(serverListener);
}
@Test
public void closeAfterClientHalfCloseShouldSucceed() throws Exception {
// Client half-closes. Listener gets halfClosed()
stream().inboundDataReceived(new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);
assertTrue(stream().canSend());
stream().transportState()
.inboundDataReceived(new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);
verify(serverListener).halfClosed();
// Server closes. Status sent
stream().close(Status.OK, trailers);
assertTrue(stream().isClosed());
verifyNoMoreInteractions(serverListener);
verify(writeQueue).enqueue(
new SendResponseHeadersCommand(STREAM_ID, new DefaultHttp2Headers()
new SendResponseHeadersCommand(stream.transportState(), new DefaultHttp2Headers()
.status(new AsciiString("200"))
.set(new AsciiString("content-type"), new AsciiString("application/grpc"))
.set(new AsciiString("grpc-status"), new AsciiString("0")), true),
true);
// Sending and receiving complete. Listener gets closed()
stream().complete();
stream().transportState().complete();
verify(serverListener).closed(Status.OK);
verifyNoMoreInteractions(serverListener);
}
@Test
public void abortStreamAndSendStatus() throws Exception {
Status status = Status.INTERNAL.withCause(new Throwable());
stream().abortStream(status, true);
assertTrue(stream().isClosed());
verify(serverListener).closed(same(status));
verify(writeQueue).enqueue(new CancelServerStreamCommand(stream(), status), true);
verifyNoMoreInteractions(serverListener);
}
@Test
public void abortStreamAndNotSendStatus() throws Exception {
Status status = Status.INTERNAL.withCause(new Throwable());
stream().abortStream(status, false);
assertTrue(stream().isClosed());
stream().transportState().transportReportStatus(status);
verify(serverListener).closed(same(status));
verify(channel, never()).writeAndFlush(any(SendResponseHeadersCommand.class));
verify(channel, never()).writeAndFlush(any(SendGrpcFrameCommand.class));
@ -208,21 +180,20 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
public void abortStreamAfterClientHalfCloseShouldCallClose() {
Status status = Status.INTERNAL.withCause(new Throwable());
// Client half-closes. Listener gets halfClosed()
stream().inboundDataReceived(new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);
assertTrue(stream().canSend());
stream().transportState().inboundDataReceived(
new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT), true);
verify(serverListener).halfClosed();
// Abort from the transport layer
stream().abortStream(status, true);
stream().transportState().transportReportStatus(status);
verify(serverListener).closed(same(status));
verifyNoMoreInteractions(serverListener);
assertTrue(stream().isClosed());
}
@Test
public void emptyFramerShouldSendNoPayload() throws Exception {
stream().close(Status.OK, new Metadata());
verify(writeQueue).enqueue(
new SendResponseHeadersCommand(STREAM_ID, new DefaultHttp2Headers()
new SendResponseHeadersCommand(stream.transportState(), new DefaultHttp2Headers()
.status(new AsciiString("200"))
.set(new AsciiString("content-type"), new AsciiString("application/grpc"))
.set(new AsciiString("grpc-status"), new AsciiString("0")), true),
@ -233,7 +204,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
public void cancelStreamShouldSucceed() {
stream().cancel(Status.DEADLINE_EXCEEDED);
verify(writeQueue).enqueue(
new CancelServerStreamCommand(stream(), Status.DEADLINE_EXCEEDED),
new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED),
true);
}
@ -250,11 +221,10 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
}
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
NettyServerStream stream = new NettyServerStream(channel, http2Stream, handler,
DEFAULT_MAX_MESSAGE_SIZE);
stream.setListener(serverListener);
assertTrue(stream.canReceive());
assertTrue(stream.canSend());
NettyServerStream.TransportState state =
new NettyServerStream.TransportState(handler, http2Stream, DEFAULT_MAX_MESSAGE_SIZE);
NettyServerStream stream = new NettyServerStream(channel, state);
stream.transportState().setListener(serverListener);
verify(serverListener, atLeastOnce()).onReady();
verifyNoMoreInteractions(serverListener);
return stream;

View File

@ -43,7 +43,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.internal.AbstractStream;
import io.grpc.internal.Stream;
import io.grpc.internal.StreamListener;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
@ -71,7 +71,7 @@ import java.util.concurrent.TimeUnit;
/**
* Base class for Netty stream unit tests.
*/
public abstract class NettyStreamTestBase<T extends AbstractStream<Integer>> {
public abstract class NettyStreamTestBase<T extends Stream> {
protected static final String MESSAGE = "hello world";
protected static final int STREAM_ID = 1;
@ -135,7 +135,8 @@ public abstract class NettyStreamTestBase<T extends AbstractStream<Integer>> {
stream.request(1);
if (stream instanceof NettyServerStream) {
((NettyServerStream) stream).inboundDataReceived(messageFrame(MESSAGE), false);
((NettyServerStream) stream).transportState()
.inboundDataReceived(messageFrame(MESSAGE), false);
} else {
((NettyClientStream) stream).transportDataReceived(messageFrame(MESSAGE), false);
}