diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index 49e97ccd55..a71fd58d87 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -31,256 +31,219 @@ package io.grpc.internal; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.common.base.Preconditions; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; -import java.io.InputStream; -import java.util.logging.Level; -import java.util.logging.Logger; +import javax.annotation.Nullable; /** - * Abstract base class for {@link ServerStream} implementations. - * - * @param the type of the stream identifier + * Abstract base class for {@link ServerStream} implementations. Extending classes only need to + * implement {@link #transportState()} and {@link #abstractServerStreamSink()}. Must only be called + * from the sending application thread. + */ -public abstract class AbstractServerStream extends AbstractStream - implements ServerStream { - private static final Logger log = Logger.getLogger(AbstractServerStream.class.getName()); - - /** Whether listener.closed() has been called. */ - private boolean listenerClosed; - private ServerStreamListener listener; - - private boolean headersSent = false; +public abstract class AbstractServerStream extends AbstractStream2 + implements ServerStream, MessageFramer.Sink { /** - * Whether the stream was closed gracefully by the application (vs. a transport-level failure). + * A sink for outbound operations, separated from the stream simply to avoid name + * collisions/confusion. Only called from application thread. */ - private boolean gracefulClose; - /** Saved trailers from close() that need to be sent once the framer has sent all messages. */ - private Metadata stashedTrailers; + protected interface Sink { + /** + * Sends response headers to the remote end point. + * + * @param headers the headers to be sent to client. + */ + void writeHeaders(Metadata headers); - protected AbstractServerStream(WritableBufferAllocator bufferAllocator, - int maxMessageSize) { - super(bufferAllocator, maxMessageSize); + /** + * Sends an outbound frame to the remote end point. + * + * @param frame a buffer containing the chunk of data to be sent. + * @param flush {@code true} if more data may not be arriving soon + */ + void writeFrame(@Nullable WritableBuffer frame, boolean flush); + + /** + * Sends trailers to the remote end point. This call implies end of stream. + * + * @param trailers metadata to be sent to the end point + * @param headersSent {@code true} if response headers have already been sent. + */ + void writeTrailers(Metadata trailers, boolean headersSent); + + /** + * Requests up to the given number of messages from the call to be delivered. This should end up + * triggering {@link TransportState#requestMessagesFromDeframer(int)} on the transport thread. + */ + void request(int numMessages); + + /** + * Tears down the stream, typically in the event of a timeout. This method may be called + * multiple times and from any thread. + * + *

This is a clone of {@link ServerStream#cancel()}. + */ + void cancel(Status status); } - /** - * Sets the listener to receive notifications. Must be called in the context of the transport - * thread. - */ - public final void setListener(ServerStreamListener listener) { - this.listener = checkNotNull(listener); + private final MessageFramer framer; + private boolean outboundClosed; + private boolean headersSent; - // Now that the stream has actually been initialized, call the listener's onReady callback if - // appropriate. - onStreamAllocated(); + protected AbstractServerStream(WritableBufferAllocator bufferAllocator) { + framer = new MessageFramer(this, bufferAllocator); } @Override - protected ServerStreamListener listener() { - return listener; + protected abstract TransportState transportState(); + + /** + * Sink for transport to be called to perform outbound operations. Each stream must have its own + * unique sink. + */ + protected abstract Sink abstractServerStreamSink(); + + @Override + protected final MessageFramer framer() { + return framer; } @Override - protected void receiveMessage(InputStream is) { - inboundPhase(Phase.MESSAGE); - listener().messageRead(is); + public final void request(int numMessages) { + abstractServerStreamSink().request(numMessages); } @Override public final void writeHeaders(Metadata headers) { Preconditions.checkNotNull(headers, "headers"); - outboundPhase(Phase.HEADERS); headersSent = true; - internalSendHeaders(headers); - outboundPhase(Phase.MESSAGE); + abstractServerStreamSink().writeHeaders(headers); } @Override - public final void writeMessage(InputStream message) { - if (outboundPhase() != Phase.MESSAGE) { - throw new IllegalStateException("Messages are only permitted after headers and before close"); - } - super.writeMessage(message); + public final void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { + // Since endOfStream is triggered by the sending of trailers, avoid flush here and just flush + // after the trailers. + abstractServerStreamSink().writeFrame(frame, endOfStream ? false : flush); } @Override public final void close(Status status, Metadata trailers) { Preconditions.checkNotNull(status, "status"); Preconditions.checkNotNull(trailers, "trailers"); - if (outboundPhase(Phase.STATUS) != Phase.STATUS) { - gracefulClose = true; - stashedTrailers = trailers; - writeStatusToTrailers(status); - closeFramer(); + if (!outboundClosed) { + outboundClosed = true; + endOfMessages(); + addStatusToTrailers(trailers, status); + abstractServerStreamSink().writeTrailers(trailers, headersSent); } } - private void writeStatusToTrailers(Status status) { - stashedTrailers.removeAll(Status.CODE_KEY); - stashedTrailers.removeAll(Status.MESSAGE_KEY); - stashedTrailers.put(Status.CODE_KEY, status); + private void addStatusToTrailers(Metadata trailers, Status status) { + trailers.removeAll(Status.CODE_KEY); + trailers.removeAll(Status.MESSAGE_KEY); + trailers.put(Status.CODE_KEY, status); if (status.getDescription() != null) { - stashedTrailers.put(Status.MESSAGE_KEY, status.getDescription()); - } - } - - /** - * Called by transport implementations when they receive headers. - * - * @param headers the parsed headers - */ - protected void inboundHeadersReceived(Metadata headers) { - inboundPhase(Phase.MESSAGE); - } - - /** - * Called in the network thread to process the content of an inbound DATA frame from the client. - * - * @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must - * be retained. - * @param endOfStream {@code true} if no more data will be received on the stream. - */ - public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) { - if (inboundPhase() == Phase.STATUS) { - frame.close(); - return; - } - // Deframe the message. If a failure occurs, deframeFailed will be called. - deframe(frame, endOfStream); - } - - @Override - protected final void deframeFailed(Throwable cause) { - log.log(Level.WARNING, "Exception processing message", cause); - abortStream(Status.fromThrowable(cause), true); - } - - @Override - protected final void internalSendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { - if (frame != null) { - sendFrame(frame, false, endOfStream ? false : flush); - } - if (endOfStream) { - sendTrailers(stashedTrailers, headersSent); - headersSent = true; - stashedTrailers = null; - } - } - - /** - * Sends response headers to the remote end points. - * - * @param headers the headers to be sent to client. - */ - protected abstract void internalSendHeaders(Metadata headers); - - /** - * Sends an outbound frame to the remote end point. - * - * @param frame a buffer containing the chunk of data to be sent. - * @param endOfStream if {@code true} indicates that no more data will be sent on the stream by - * this endpoint. - * @param flush {@code true} if more data may not be arriving soon - */ - protected abstract void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush); - - /** - * Sends trailers to the remote end point. This call implies end of stream. - * - * @param trailers metadata to be sent to end point - * @param headersSent {@code true} if response headers have already been sent. - */ - protected abstract void sendTrailers(Metadata trailers, boolean headersSent); - - /** - * Indicates the stream is considered completely closed and there is no further opportunity for - * error. It calls the listener's {@code closed()} if it was not already done by {@link - * #abortStream}. Note that it is expected that either {@code closed()} or {@code abortStream()} - * was previously called, since {@code closed()} is required for a normal stream closure and - * {@code abortStream()} for abnormal. - */ - public void complete() { - if (!gracefulClose) { - closeListener(Status.INTERNAL.withDescription("successful complete() without close()")); - throw new IllegalStateException("successful complete() without close()"); - } - closeListener(Status.OK); - } - - /** - * Called when the remote end half-closes the stream. - */ - @Override - protected final void remoteEndClosed() { - halfCloseListener(); - } - - /** - * Aborts the stream with an error status, cleans up resources and notifies the listener if - * necessary. - * - *

Unlike {@link #close(Status, Metadata)}, this method is only called from the - * transport. The transport should use this method instead of {@code close(Status)} for internal - * errors to prevent exposing unexpected states and exceptions to the application. - * - * @param status the error status. Must not be {@link Status#OK}. - * @param notifyClient {@code true} if the stream is still writable and you want to notify the - * client about stream closure and send the status - */ - public final void abortStream(Status status, boolean notifyClient) { - // TODO(louiscryan): Investigate whether we can remove the notification to the client - // and rely on a transport layer stream reset instead. - Preconditions.checkArgument(!status.isOk(), "status must not be OK"); - closeListener(status); - if (notifyClient) { - // TODO(louiscryan): Remove - if (stashedTrailers == null) { - stashedTrailers = new Metadata(); - } - writeStatusToTrailers(status); - sendStreamAbortToClient(status, stashedTrailers); + trailers.put(Status.MESSAGE_KEY, status.getDescription()); } } @Override - public boolean isClosed() { - return super.isClosed() || listenerClosed; - } - - /** - * Notifies the remote client that this stream has aborted. - */ - protected abstract void sendStreamAbortToClient(Status status, Metadata trailers); - - /** - * Fires a half-closed event to the listener and frees inbound resources. - */ - private void halfCloseListener() { - if (inboundPhase(Phase.STATUS) != Phase.STATUS && !listenerClosed) { - closeDeframer(); - listener().halfClosed(); - } - } - - /** - * Closes the listener if not previously closed and frees resources. - */ - private void closeListener(Status newStatus) { - if (!listenerClosed) { - listenerClosed = true; - closeDeframer(); - listener().closed(newStatus); - } + public final void cancel(Status status) { + abstractServerStreamSink().cancel(status); } @Override public Attributes attributes() { return Attributes.EMPTY; } + + /** This should only called from the transport thread. */ + protected abstract static class TransportState extends AbstractStream2.TransportState { + /** Whether listener.closed() has been called. */ + private boolean listenerClosed; + private ServerStreamListener listener; + + protected TransportState(int maxMessageSize) { + super(maxMessageSize); + } + + /** + * Sets the listener to receive notifications. Must be called in the context of the transport + * thread. + */ + public final void setListener(ServerStreamListener listener) { + this.listener = Preconditions.checkNotNull(listener); + + // Now that the stream has actually been initialized, call the listener's onReady callback if + // appropriate. + onStreamAllocated(); + } + + @Override + public void deliveryStalled() {} + + @Override + public void endOfStream() { + closeDeframer(); + listener().halfClosed(); + } + + @Override + protected ServerStreamListener listener() { + return listener; + } + + /** + * Called in the transport thread to process the content of an inbound DATA frame from the + * client. + * + * @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must + * be retained. + * @param endOfStream {@code true} if no more data will be received on the stream. + */ + public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) { + // Deframe the message. If a failure occurs, deframeFailed will be called. + deframe(frame, endOfStream); + } + + /** + * Notifies failure to the listener of the stream. The transport is responsible for notifying + * the client of the failure independent of this method. + * + *

Unlike {@link #close(Status, Metadata)}, this method is only called from the + * transport. The transport should use this method instead of {@code close(Status)} for internal + * errors to prevent exposing unexpected states and exceptions to the application. + * + * @param status the error status. Must not be {@link Status#OK}. + */ + public final void transportReportStatus(Status status) { + Preconditions.checkArgument(!status.isOk(), "status must not be OK"); + closeListener(status); + } + + /** + * Indicates the stream is considered completely closed and there is no further opportunity for + * error. It calls the listener's {@code closed()} if it was not already done by {@link + * #transportReportStatus}. + */ + public void complete() { + closeListener(Status.OK); + } + + /** + * Closes the listener if not previously closed and frees resources. + */ + private void closeListener(Status newStatus) { + if (!listenerClosed) { + listenerClosed = true; + closeDeframer(); + listener().closed(newStatus); + } + } + } } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream2.java b/core/src/main/java/io/grpc/internal/AbstractStream2.java new file mode 100644 index 0000000000..c346f99845 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/AbstractStream2.java @@ -0,0 +1,279 @@ +/* + * Copyright 2014, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; + +import io.grpc.Codec; +import io.grpc.Compressor; +import io.grpc.Decompressor; + +import java.io.InputStream; + +import javax.annotation.concurrent.GuardedBy; + +/** + * The stream and stream state as used by the application. Must only be called from the sending + * application thread. + */ +public abstract class AbstractStream2 implements Stream { + /** The framer to use for sending messages. */ + protected abstract MessageFramer framer(); + + /** + * Obtain the transport state corresponding to this stream. Each stream must have its own unique + * transport state. + */ + protected abstract TransportState transportState(); + + @Override + public final void setMessageCompression(boolean enable) { + framer().setMessageCompression(enable); + } + + @Override + public final void writeMessage(InputStream message) { + checkNotNull(message); + if (!framer().isClosed()) { + framer().writePayload(message); + } + } + + @Override + public final void flush() { + if (!framer().isClosed()) { + framer().flush(); + } + } + + /** + * Closes the underlying framer. Should be called when the outgoing stream is gracefully closed + * (half closure on client; closure on server). + */ + protected final void endOfMessages() { + framer().close(); + } + + @Override + public final void setCompressor(Compressor compressor) { + framer().setCompressor(checkNotNull(compressor, "compressor")); + } + + @Override + public final void setDecompressor(Decompressor decompressor) { + transportState().setDecompressor(checkNotNull(decompressor, "decompressor")); + } + + @Override + public final boolean isReady() { + if (framer().isClosed()) { + return false; + } + return transportState().isReady(); + } + + /** + * Event handler to be called by the subclass when a number of bytes are being queued for sending + * to the remote endpoint. + * + * @param numBytes the number of bytes being sent. + */ + protected final void onSendingBytes(int numBytes) { + transportState().onSendingBytes(numBytes); + } + + /** + * Stream state as used by the transport. This should only called from the transport thread + * (except for private interactions with {@code AbstractStream2}). + */ + public abstract static class TransportState implements MessageDeframer.Listener { + /** + * The default number of queued bytes for a given stream, below which + * {@link StreamListener#onReady()} will be called. + */ + private static final int DEFAULT_ONREADY_THRESHOLD = 32 * 1024; + + private final MessageDeframer deframer; + private final Object onReadyLock = new Object(); + /** + * The number of bytes currently queued, waiting to be sent. When this falls below + * DEFAULT_ONREADY_THRESHOLD, {@link StreamListener#onReady()} will be called. + */ + @GuardedBy("onReadyLock") + private int numSentBytesQueued; + /** + * Indicates the stream has been created on the connection. This implies that the stream is no + * longer limited by MAX_CONCURRENT_STREAMS. + */ + @GuardedBy("onReadyLock") + private boolean allocated; + + protected TransportState(int maxMessageSize) { + deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize); + } + + @VisibleForTesting + TransportState(MessageDeframer deframer) { + this.deframer = deframer; + } + + /** + * Override this method to provide a stream listener. + */ + protected abstract StreamListener listener(); + + @Override + public void messageRead(InputStream is) { + listener().messageRead(is); + } + + /** + * Called when a {@link #deframe(ReadableBuffer, boolean)} operation failed. + * + * @param cause the actual failure + */ + protected abstract void deframeFailed(Throwable cause); + + /** + * Closes this deframer and frees any resources. After this method is called, additional calls + * will have no effect. + */ + protected final void closeDeframer() { + deframer.close(); + } + + /** + * Called to parse a received frame and attempt delivery of any completed + * messages. Must be called from the transport thread. + */ + protected final void deframe(ReadableBuffer frame, boolean endOfStream) { + if (deframer.isClosed()) { + frame.close(); + return; + } + try { + deframer.deframe(frame, endOfStream); + } catch (Throwable t) { + deframeFailed(t); + } + } + + /** + * Called to request the given number of messages from the deframer. Must be called + * from the transport thread. + */ + public final void requestMessagesFromDeframer(int numMessages) { + if (deframer.isClosed()) { + return; + } + try { + deframer.request(numMessages); + } catch (Throwable t) { + deframeFailed(t); + } + } + + private void setDecompressor(Decompressor decompressor) { + if (deframer.isClosed()) { + return; + } + deframer.setDecompressor(decompressor); + } + + private boolean isReady() { + synchronized (onReadyLock) { + return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; + } + } + + /** + * Event handler to be called by the subclass when the stream's headers have passed any + * connection flow control (i.e., MAX_CONCURRENT_STREAMS). It may call the listener's {@link + * StreamListener#onReady()} handler if appropriate. This must be called from the transport + * thread, since the listener may be called back directly. + */ + protected final void onStreamAllocated() { + checkState(listener() != null); + synchronized (onReadyLock) { + checkState(!allocated, "Already allocated"); + allocated = true; + } + notifyIfReady(); + } + + /** + * Event handler to be called by the subclass when a number of bytes are being queued for + * sending to the remote endpoint. + * + * @param numBytes the number of bytes being sent. + */ + private void onSendingBytes(int numBytes) { + synchronized (onReadyLock) { + numSentBytesQueued += numBytes; + } + } + + /** + * Event handler to be called by the subclass when a number of bytes has been sent to the remote + * endpoint. May call back the listener's {@link StreamListener#onReady()} handler if + * appropriate. This must be called from the transport thread, since the listener may be called + * back directly. + * + * @param numBytes the number of bytes that were sent. + */ + public final void onSentBytes(int numBytes) { + boolean doNotify; + synchronized (onReadyLock) { + boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; + numSentBytesQueued -= numBytes; + boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; + doNotify = !belowThresholdBefore && belowThresholdAfter; + } + if (doNotify) { + notifyIfReady(); + } + } + + private void notifyIfReady() { + boolean doNotify; + synchronized (onReadyLock) { + doNotify = isReady(); + } + if (doNotify) { + listener().onReady(); + } + } + } +} diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index f7c206e8d3..8a073e9b1c 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -51,6 +51,8 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import javax.annotation.Nullable; + /** * Encodes gRPC messages to be delivered via the transport layer which implements {@link * MessageFramer.Sink}. @@ -69,7 +71,7 @@ public class MessageFramer { * @param endOfStream whether the frame is the last one for the GRPC stream * @param flush {@code true} if more data may not be arriving soon */ - void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush); + void deliverFrame(@Nullable WritableBuffer frame, boolean endOfStream, boolean flush); } private static final int HEADER_LENGTH = 5; diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 15611c6a87..16e308a46b 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -32,17 +32,17 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import io.grpc.Metadata; import io.grpc.Status; -import io.grpc.internal.AbstractStream.Phase; import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; import org.junit.Rule; @@ -50,13 +50,10 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import java.io.ByteArrayInputStream; import java.io.InputStream; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; - -import javax.annotation.Nullable; /** * Tests for {@link AbstractServerStream}. @@ -74,22 +71,24 @@ public class AbstractServerStreamTest { } }; - private final AbstractServerStreamBase defaultStream = - new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE); + private AbstractServerStream.Sink sink = mock(AbstractServerStream.Sink.class); + private AbstractServerStreamBase stream = new AbstractServerStreamBase( + allocator, sink, new AbstractServerStreamBase.TransportState(MAX_MESSAGE_SIZE)); + private final ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); @Test public void setListener_setOnlyOnce() { - defaultStream.setListener(new ServerStreamListenerBase()); + stream.transportState().setListener(new ServerStreamListenerBase()); thrown.expect(IllegalStateException.class); - defaultStream.setListener(new ServerStreamListenerBase()); + stream.transportState().setListener(new ServerStreamListenerBase()); } @Test public void setListener_readyCalled() { ServerStreamListener streamListener = mock(ServerStreamListener.class); - defaultStream.setListener(streamListener); + stream.transportState().setListener(streamListener); verify(streamListener).onReady(); } @@ -98,179 +97,87 @@ public class AbstractServerStreamTest { public void setListener_failsOnNull() { thrown.expect(NullPointerException.class); - defaultStream.setListener(null); + stream.transportState().setListener(null); } @Test - public void receiveMessage_listenerCalled() { + public void messageRead_listenerCalled() { final ServerStreamListener streamListener = mock(ServerStreamListener.class); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected ServerStreamListener listener() { - return streamListener; - } - }; + stream.transportState().setListener(streamListener); // Normally called by a deframe event. - stream.receiveMessage(new ByteArrayInputStream(new byte[]{})); + stream.transportState().messageRead(new ByteArrayInputStream(new byte[]{})); verify(streamListener).messageRead(isA(InputStream.class)); } - @Test - public void receiveMessage_failsIfHalfClosed() { - // Simulate being closed, without invoking the listener - defaultStream.inboundPhase(Phase.STATUS); - - thrown.expect(IllegalStateException.class); - - // Normally called by a deframe event. - defaultStream.receiveMessage(new ByteArrayInputStream(new byte[]{})); - } - @Test public void writeHeaders_failsOnNullHeaders() { - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE); thrown.expect(NullPointerException.class); stream.writeHeaders(null); } - @Test - public void writeHeaders_failsIfAlreadySent() { - defaultStream.writeHeaders(new Metadata()); - thrown.expect(IllegalStateException.class); - - defaultStream.writeHeaders(new Metadata()); - } - @Test public void writeHeaders() { - final AtomicReference capturedHeaders = new AtomicReference(null); Metadata headers = new Metadata(); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void internalSendHeaders(Metadata captured) { - capturedHeaders.set(captured); - } - }; - stream.writeHeaders(headers); - - assertEquals(headers, capturedHeaders.get()); - assertEquals(Phase.MESSAGE, stream.outboundPhase()); - } - - @Test - public void writeMessage_writeHeadersIfNeeded() { - final AtomicReference capturedHeaders = new AtomicReference(null); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void internalSendHeaders(Metadata captured) { - capturedHeaders.set(captured); - } - }; - stream.writeHeaders(new Metadata()); - - stream.writeMessage(new ByteArrayInputStream(new byte[]{})); - - assertNotNull(capturedHeaders.get()); + verify(sink).writeHeaders(same(headers)); } @Test public void writeMessage_dontWriteDuplicateHeaders() { - final AtomicReference capturedHeaders = new AtomicReference(null); - Metadata headers = new Metadata(); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void internalSendHeaders(Metadata captured) { - capturedHeaders.set(captured); - } - }; - stream.writeHeaders(headers); - + stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); - // Make sure it wasn't called twice, by checking that the exact headers sent are the ones - // returned. - assertSame(headers, capturedHeaders.get()); + // Make sure it wasn't called twice + verify(sink).writeHeaders(any(Metadata.class)); } @Test public void writeMessage_ignoreIfFramerClosed() { - final AtomicBoolean sendCalled = new AtomicBoolean(); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { - sendCalled.set(true); - } - }; stream.writeHeaders(new Metadata()); - stream.closeFramer(); + stream.endOfMessages(); + reset(sink); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); - assertFalse(sendCalled.get()); + verify(sink, never()).writeFrame(any(WritableBuffer.class), any(Boolean.class)); } @Test public void writeMessage() { - final AtomicBoolean sendCalled = new AtomicBoolean(); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { - sendCalled.set(true); - } - }; stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); - // Force the message to be flushed - stream.closeFramer(); + stream.flush(); - assertTrue(sendCalled.get()); - assertEquals(Phase.MESSAGE, stream.outboundPhase()); + verify(sink).writeFrame(any(WritableBuffer.class), eq(true)); } @Test public void close_failsOnNullStatus() { thrown.expect(NullPointerException.class); - defaultStream.close(null, new Metadata()); + stream.close(null, new Metadata()); } @Test public void close_failsOnNullMetadata() { thrown.expect(NullPointerException.class); - defaultStream.close(Status.INTERNAL, null); + stream.close(Status.INTERNAL, null); } @Test public void close_sendsTrailers() { - final AtomicReference capturedTrailers = new AtomicReference(null); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void sendTrailers(Metadata trailers, boolean headersSent) { - capturedTrailers.set(trailers); - } - }; Metadata trailers = new Metadata(); - stream.close(Status.INTERNAL, trailers); - - assertSame(trailers, capturedTrailers.get()); + verify(sink).writeTrailers(any(Metadata.class), eq(false)); } @Test public void close_sendTrailersClearsReservedFields() { - final AtomicReference capturedTrailers = new AtomicReference(null); - AbstractServerStreamBase stream = new AbstractServerStreamBase(allocator, MAX_MESSAGE_SIZE) { - @Override - protected void sendTrailers(Metadata trailers, boolean headersSent) { - capturedTrailers.set(trailers); - } - }; // stream actually mutates trailers, so we can't check that the fields here are the same as // the captured ones. Metadata trailers = new Metadata(); @@ -279,8 +186,9 @@ public class AbstractServerStreamTest { stream.close(Status.INTERNAL.withDescription("bad"), trailers); - assertEquals(Status.Code.INTERNAL, capturedTrailers.get().get(Status.CODE_KEY).getCode()); - assertEquals("bad", capturedTrailers.get().get(Status.MESSAGE_KEY)); + verify(sink).writeTrailers(metadataCaptor.capture(), eq(false)); + assertEquals(Status.Code.INTERNAL, metadataCaptor.getValue().get(Status.CODE_KEY).getCode()); + assertEquals("bad", metadataCaptor.getValue().get(Status.MESSAGE_KEY)); } private static class ServerStreamListenerBase implements ServerStreamListener { @@ -297,41 +205,38 @@ public class AbstractServerStreamTest { public void closed(Status status) {} } - private static class AbstractServerStreamBase extends AbstractServerStream { - protected AbstractServerStreamBase( - WritableBufferAllocator bufferAllocator, int maxMessageSize) { - super(bufferAllocator, maxMessageSize); + private static class AbstractServerStreamBase extends AbstractServerStream { + private final Sink sink; + private final AbstractServerStream.TransportState state; + + protected AbstractServerStreamBase(WritableBufferAllocator bufferAllocator, Sink sink, + AbstractServerStream.TransportState state) { + super(bufferAllocator); + this.sink = sink; + this.state = state; } @Override - public void cancel(Status status) {} - - @Override - public void request(int numMessages) {} - - @Override - protected void internalSendHeaders(Metadata headers) {} - - @Override - protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {} - - @Override - protected void sendTrailers(Metadata trailers, boolean headersSent) {} - - @Override - @Nullable - public Void id() { - return null; + protected Sink abstractServerStreamSink() { + return sink; } @Override - protected void inboundDeliveryPaused() {} + protected AbstractServerStream.TransportState transportState() { + return state; + } - @Override - protected void returnProcessedBytes(int processedBytes) {} + static class TransportState extends AbstractServerStream.TransportState { + protected TransportState(int maxMessageSize) { + super(maxMessageSize); + } - @Override - protected void sendStreamAbortToClient(Status status, Metadata trailers) {} + @Override + protected void deframeFailed(Throwable cause) {} + + @Override + public void bytesRead(int processedBytes) {} + } } } diff --git a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java index 6efffbbb0d..5d42b7727c 100644 --- a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java @@ -41,15 +41,15 @@ import io.grpc.Status; * Command sent from a Netty server stream to the handler to cancel the stream. */ class CancelServerStreamCommand { - private final NettyServerStream stream; + private final NettyServerStream.TransportState stream; private final Status reason; - CancelServerStreamCommand(NettyServerStream stream, Status reason) { + CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) { this.stream = Preconditions.checkNotNull(stream); this.reason = Preconditions.checkNotNull(reason); } - NettyServerStream stream() { + NettyServerStream.TransportState stream() { return stream; } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index 32ca863cc4..ff78d85499 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -54,7 +54,7 @@ import javax.annotation.Nullable; /** * Client stream for a Netty transport. */ -abstract class NettyClientStream extends Http2ClientStream { +abstract class NettyClientStream extends Http2ClientStream implements StreamIdHolder { private final MethodDescriptor method; /** {@code null} after start. */ private Metadata headers; diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index ded0a6fb55..fa602c8d3b 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -188,16 +188,19 @@ class NettyServerHandler extends AbstractNettyHandler { // method. Http2Stream http2Stream = requireHttp2Stream(streamId); - NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this, - maxMessageSize); + NettyServerStream.TransportState state = + new NettyServerStream.TransportState(this, http2Stream, maxMessageSize); + NettyServerStream stream = new NettyServerStream(ctx.channel(), state); Metadata metadata = Utils.convertHeaders(headers); - stream.inboundHeadersReceived(metadata); ServerStreamListener listener = transportListener.streamCreated(stream, method, metadata); - stream.setListener(listener); - http2Stream.setProperty(streamKey, stream); + // TODO(ejona): this could be racy since stream could have been used before getting here. All + // cases appear to be fine, but some are almost only by happenstance and it is difficult to + // audit. It would be good to improve the API to be less prone to races. + state.setListener(listener); + http2Stream.setProperty(streamKey, state); } catch (Http2Exception e) { throw e; @@ -210,7 +213,7 @@ class NettyServerHandler extends AbstractNettyHandler { private void onDataRead(int streamId, ByteBuf data, boolean endOfStream) throws Http2Exception { try { - NettyServerStream stream = serverStream(requireHttp2Stream(streamId)); + NettyServerStream.TransportState stream = serverStream(requireHttp2Stream(streamId)); stream.inboundDataReceived(data, endOfStream); } catch (Throwable e) { logger.log(Level.WARNING, "Exception in onDataRead()", e); @@ -221,9 +224,9 @@ class NettyServerHandler extends AbstractNettyHandler { private void onRstStreamRead(int streamId) throws Http2Exception { try { - NettyServerStream stream = serverStream(connection().stream(streamId)); + NettyServerStream.TransportState stream = serverStream(connection().stream(streamId)); if (stream != null) { - stream.abortStream(Status.CANCELLED, false); + stream.transportReportStatus(Status.CANCELLED); } } catch (Throwable e) { logger.log(Level.WARNING, "Exception in onRstStreamRead()", e); @@ -244,15 +247,14 @@ class NettyServerHandler extends AbstractNettyHandler { protected void onStreamError(ChannelHandlerContext ctx, Throwable cause, StreamException http2Ex) { logger.log(Level.WARNING, "Stream Error", cause); - NettyServerStream serverStream = serverStream( + NettyServerStream.TransportState serverStream = serverStream( connection().stream(Http2Exception.streamId(http2Ex))); if (serverStream != null) { - // Abort the stream with a status to help the client with debugging. - serverStream.abortStream(Utils.statusFromThrowable(cause), true); - } else { - // Delegate to the base class to send a RST_STREAM. - super.onStreamError(ctx, cause, http2Ex); + serverStream.transportReportStatus(Utils.statusFromThrowable(cause)); } + // TODO(ejona): Abort the stream by sending headers to help the client with debugging. + // Delegate to the base class to send a RST_STREAM. + super.onStreamError(ctx, cause, http2Ex); } /** @@ -267,9 +269,9 @@ class NettyServerHandler extends AbstractNettyHandler { connection().forEachActiveStream(new Http2StreamVisitor() { @Override public boolean visit(Http2Stream stream) throws Http2Exception { - NettyServerStream serverStream = serverStream(stream); + NettyServerStream.TransportState serverStream = serverStream(stream); if (serverStream != null) { - serverStream.abortStream(status, false); + serverStream.transportReportStatus(status); } return true; } @@ -318,7 +320,7 @@ class NettyServerHandler extends AbstractNettyHandler { } private void closeStreamWhenDone(ChannelPromise promise, int streamId) throws Http2Exception { - final NettyServerStream stream = serverStream(requireHttp2Stream(streamId)); + final NettyServerStream.TransportState stream = serverStream(requireHttp2Stream(streamId)); promise.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) { @@ -345,15 +347,15 @@ class NettyServerHandler extends AbstractNettyHandler { private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersCommand cmd, ChannelPromise promise) throws Http2Exception { if (cmd.endOfStream()) { - closeStreamWhenDone(promise, cmd.streamId()); + closeStreamWhenDone(promise, cmd.stream().id()); } - encoder().writeHeaders(ctx, cmd.streamId(), cmd.headers(), 0, cmd.endOfStream(), promise); + encoder().writeHeaders(ctx, cmd.stream().id(), cmd.headers(), 0, cmd.endOfStream(), promise); } private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand cmd, ChannelPromise promise) { // Notify the listener if we haven't already. - cmd.stream().abortStream(cmd.reason(), false); + cmd.stream().transportReportStatus(cmd.reason()); // Terminate the stream. encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); } @@ -408,8 +410,8 @@ class NettyServerHandler extends AbstractNettyHandler { /** * Returns the server stream associated to the given HTTP/2 stream object. */ - private NettyServerStream serverStream(Http2Stream stream) { - return stream == null ? null : (NettyServerStream) stream.getProperty(streamKey); + private NettyServerStream.TransportState serverStream(Http2Stream stream) { + return stream == null ? null : (NettyServerStream.TransportState) stream.getProperty(streamKey); } private Http2Exception newStreamException(int streamId, Throwable cause) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index ad00c63ba5..901b5b0521 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -46,105 +46,40 @@ import io.netty.channel.ChannelFutureListener; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2Stream; +import java.util.logging.Level; +import java.util.logging.Logger; + import javax.net.ssl.SSLSession; /** - * Server stream for a Netty HTTP2 transport. + * Server stream for a Netty HTTP2 transport. Must only be called from the sending application + * thread. */ -class NettyServerStream extends AbstractServerStream { +class NettyServerStream extends AbstractServerStream { + private static final Logger log = Logger.getLogger(NettyServerStream.class.getName()); + private final Sink sink = new Sink(); + private final TransportState state; private final Channel channel; - private final NettyServerHandler handler; - private final Http2Stream http2Stream; private final WriteQueue writeQueue; private final Attributes attributes; - NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler, - int maxMessageSize) { - super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize); - this.writeQueue = handler.getWriteQueue(); + public NettyServerStream(Channel channel, TransportState state) { + super(new NettyWritableBufferAllocator(channel.alloc())); + this.state = checkNotNull(state, "transportState"); this.channel = checkNotNull(channel, "channel"); - this.http2Stream = checkNotNull(http2Stream, "http2Stream"); - this.handler = checkNotNull(handler, "handler"); + this.writeQueue = state.handler.getWriteQueue(); this.attributes = buildAttributes(channel); } @Override - public Integer id() { - return http2Stream.id(); + protected TransportState transportState() { + return state; } @Override - protected void inboundHeadersReceived(Metadata headers) { - super.inboundHeadersReceived(headers); - } - - void inboundDataReceived(ByteBuf frame, boolean endOfStream) { - super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream); - } - - @Override - public void request(final int numMessages) { - if (channel.eventLoop().inEventLoop()) { - // Processing data read in the event loop so can call into the deframer immediately - requestMessagesFromDeframer(numMessages); - } else { - writeQueue.enqueue(new RequestMessagesCommand(this, numMessages), true); - } - } - - @Override - protected void inboundDeliveryPaused() { - // Do nothing. - } - - @Override - protected void internalSendHeaders(Metadata headers) { - writeQueue.enqueue(new SendResponseHeadersCommand(id(), - Utils.convertServerHeaders(headers), false), - true); - } - - @Override - protected void sendFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { - ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); - final int numBytes = bytebuf.readableBytes(); - // Add the bytes to outbound flow control. - onSendingBytes(numBytes); - writeQueue.enqueue( - new SendGrpcFrameCommand(this, bytebuf, endOfStream), - channel.newPromise().addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - onSentBytes(numBytes); - } - }), flush); - } - - @Override - protected void sendTrailers(Metadata trailers, boolean headersSent) { - Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent); - writeQueue.enqueue(new SendResponseHeadersCommand(id(), http2Trailers, true), true); - } - - @Override - protected void returnProcessedBytes(int processedBytes) { - handler.returnProcessedBytes(http2Stream, processedBytes); - writeQueue.scheduleFlush(); - } - - @Override - protected void sendStreamAbortToClient(Status status, Metadata trailers) { - // Cancel the stream. - // TODO(nmittler): Consider sending trailers. - cancel(status); - } - - @Override - public void cancel(Status status) { - writeQueue.enqueue(new CancelServerStreamCommand(this, status), true); + protected Sink abstractServerStreamSink() { + return sink; } @Override public Attributes attributes() { @@ -163,4 +98,92 @@ class NettyServerStream extends AbstractServerStream { .set(ServerCall.SSL_SESSION_KEY, sslSession) .build(); } + + private class Sink implements AbstractServerStream.Sink { + @Override + public void request(final int numMessages) { + if (channel.eventLoop().inEventLoop()) { + // Processing data read in the event loop so can call into the deframer immediately + transportState().requestMessagesFromDeframer(numMessages); + } else { + writeQueue.enqueue(new RequestMessagesCommand(transportState(), numMessages), true); + } + } + + @Override + public void writeHeaders(Metadata headers) { + writeQueue.enqueue(new SendResponseHeadersCommand(transportState(), + Utils.convertServerHeaders(headers), false), + true); + } + + @Override + public void writeFrame(WritableBuffer frame, boolean flush) { + if (frame == null) { + writeQueue.scheduleFlush(); + return; + } + ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); + final int numBytes = bytebuf.readableBytes(); + // Add the bytes to outbound flow control. + onSendingBytes(numBytes); + writeQueue.enqueue( + new SendGrpcFrameCommand(transportState(), bytebuf, false), + channel.newPromise().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + transportState().onSentBytes(numBytes); + } + }), flush); + } + + @Override + public void writeTrailers(Metadata trailers, boolean headersSent) { + Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent); + writeQueue.enqueue( + new SendResponseHeadersCommand(transportState(), http2Trailers, true), true); + } + + @Override + public void cancel(Status status) { + writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true); + } + } + + /** This should only called from the transport thread. */ + public static class TransportState extends AbstractServerStream.TransportState + implements StreamIdHolder { + private final Http2Stream http2Stream; + private final NettyServerHandler handler; + + public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize) { + super(maxMessageSize); + this.http2Stream = checkNotNull(http2Stream, "http2Stream"); + this.handler = checkNotNull(handler, "handler"); + } + + @Override + public void bytesRead(int processedBytes) { + handler.returnProcessedBytes(http2Stream, processedBytes); + handler.getWriteQueue().scheduleFlush(); + } + + @Override + protected void deframeFailed(Throwable cause) { + log.log(Level.WARNING, "Exception processing message", cause); + Status status = Status.fromThrowable(cause); + transportReportStatus(status); + handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + } + + void inboundDataReceived(ByteBuf frame, boolean endOfStream) { + super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream); + } + + public Integer id() { + return http2Stream.id(); + } + } } diff --git a/netty/src/main/java/io/grpc/netty/RequestMessagesCommand.java b/netty/src/main/java/io/grpc/netty/RequestMessagesCommand.java index b707490842..404139923d 100644 --- a/netty/src/main/java/io/grpc/netty/RequestMessagesCommand.java +++ b/netty/src/main/java/io/grpc/netty/RequestMessagesCommand.java @@ -31,7 +31,8 @@ package io.grpc.netty; -import io.grpc.internal.AbstractStream; +import io.grpc.internal.AbstractStream2; +import io.grpc.internal.Stream; /** * Command which requests messages from the deframer. @@ -39,14 +40,26 @@ import io.grpc.internal.AbstractStream; class RequestMessagesCommand { private final int numMessages; - private final AbstractStream stream; + private final Stream stream; + private final AbstractStream2.TransportState state; - public RequestMessagesCommand(AbstractStream stream, int numMessages) { + public RequestMessagesCommand(Stream stream, int numMessages) { + this.state = null; this.numMessages = numMessages; this.stream = stream; } + public RequestMessagesCommand(AbstractStream2.TransportState state, int numMessages) { + this.state = state; + this.numMessages = numMessages; + this.stream = null; + } + void requestMessages() { - stream.request(numMessages); + if (stream != null) { + stream.request(numMessages); + } else { + state.requestMessagesFromDeframer(numMessages); + } } } diff --git a/netty/src/main/java/io/grpc/netty/SendGrpcFrameCommand.java b/netty/src/main/java/io/grpc/netty/SendGrpcFrameCommand.java index c06150aba7..b9258a29e3 100644 --- a/netty/src/main/java/io/grpc/netty/SendGrpcFrameCommand.java +++ b/netty/src/main/java/io/grpc/netty/SendGrpcFrameCommand.java @@ -31,7 +31,6 @@ package io.grpc.netty; -import io.grpc.internal.AbstractStream; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufHolder; import io.netty.buffer.DefaultByteBufHolder; @@ -40,11 +39,10 @@ import io.netty.buffer.DefaultByteBufHolder; * Command sent from the transport to the Netty channel to send a GRPC frame to the remote endpoint. */ class SendGrpcFrameCommand extends DefaultByteBufHolder { - private final AbstractStream stream; + private final StreamIdHolder stream; private final boolean endStream; - SendGrpcFrameCommand(AbstractStream stream, ByteBuf content, - boolean endStream) { + SendGrpcFrameCommand(StreamIdHolder stream, ByteBuf content, boolean endStream) { super(content); this.stream = stream; this.endStream = endStream; diff --git a/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java b/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java index 46bb0e59e7..f07aa79e59 100644 --- a/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java +++ b/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java @@ -39,18 +39,18 @@ import io.netty.handler.codec.http2.Http2Headers; * Command sent from the transport to the Netty channel to send response headers to the client. */ class SendResponseHeadersCommand { - private final int streamId; + private final StreamIdHolder stream; private final Http2Headers headers; private final boolean endOfStream; - SendResponseHeadersCommand(int streamId, Http2Headers headers, boolean endOfStream) { - this.streamId = streamId; + SendResponseHeadersCommand(StreamIdHolder stream, Http2Headers headers, boolean endOfStream) { + this.stream = Preconditions.checkNotNull(stream); this.headers = Preconditions.checkNotNull(headers); this.endOfStream = endOfStream; } - int streamId() { - return streamId; + StreamIdHolder stream() { + return stream; } Http2Headers headers() { @@ -67,19 +67,19 @@ class SendResponseHeadersCommand { return false; } SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that; - return thatCmd.streamId == streamId + return thatCmd.stream.equals(stream) && thatCmd.headers.equals(headers) && thatCmd.endOfStream == endOfStream; } @Override public String toString() { - return getClass().getSimpleName() + "(streamId=" + streamId + ", headers=" + headers + return getClass().getSimpleName() + "(stream=" + stream.id() + ", headers=" + headers + ", endOfStream=" + endOfStream + ")"; } @Override public int hashCode() { - return streamId; + return stream.hashCode(); } } diff --git a/netty/src/main/java/io/grpc/netty/StreamIdHolder.java b/netty/src/main/java/io/grpc/netty/StreamIdHolder.java new file mode 100644 index 0000000000..b3f677f6a4 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/StreamIdHolder.java @@ -0,0 +1,37 @@ +/* + * Copyright 2016, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.netty; + +/** Container for stream ids. */ +interface StreamIdHolder { + public Integer id(); +} diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 52ec3fd402..92e2bed36c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -133,7 +133,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase> { +public abstract class NettyStreamTestBase { protected static final String MESSAGE = "hello world"; protected static final int STREAM_ID = 1; @@ -135,7 +135,8 @@ public abstract class NettyStreamTestBase> { stream.request(1); if (stream instanceof NettyServerStream) { - ((NettyServerStream) stream).inboundDataReceived(messageFrame(MESSAGE), false); + ((NettyServerStream) stream).transportState() + .inboundDataReceived(messageFrame(MESSAGE), false); } else { ((NettyClientStream) stream).transportDataReceived(messageFrame(MESSAGE), false); }