Moving decompression to the channel thread.

This commit is contained in:
nmittler 2015-01-27 12:53:34 -08:00
parent 2049e0d618
commit 89b8d7ff47
14 changed files with 146 additions and 93 deletions

View File

@ -349,7 +349,7 @@ public final class ChannelImpl implements Channel {
}
@Override
public void messageRead(final InputStream message, final int length) {
public void messageRead(final InputStream message) {
callExecutor.execute(new Runnable() {
@Override
public void run() {

View File

@ -325,7 +325,7 @@ public class ServerImpl implements Server {
private static class NoopListener implements ServerStreamListener {
@Override
public void messageRead(InputStream value, int length) {
public void messageRead(InputStream value) {
try {
value.close();
} catch (IOException e) {
@ -378,12 +378,12 @@ public class ServerImpl implements Server {
}
@Override
public void messageRead(final InputStream message, final int length) {
public void messageRead(final InputStream message) {
callExecutor.execute(new Runnable() {
@Override
public void run() {
try {
getListener().messageRead(message, length);
getListener().messageRead(message);
} catch (Throwable t) {
internalClose(Status.fromThrowable(t), new Metadata.Trailers());
throw Throwables.propagate(t);
@ -476,7 +476,7 @@ public class ServerImpl implements Server {
}
@Override
public void messageRead(final InputStream message, int length) {
public void messageRead(final InputStream message) {
if (cancelled) {
return;
}

View File

@ -67,9 +67,9 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
}
@Override
protected void receiveMessage(InputStream is, int length) {
protected void receiveMessage(InputStream is) {
if (!listenerClosed) {
listener.messageRead(is, length);
listener.messageRead(is);
}
}
@ -203,12 +203,11 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
closeListenerTask = null;
// Determine if the deframer is stalled (i.e. currently has no complete messages to deliver).
boolean deliveryStalled = deframer.isStalled();
boolean deliveryStalled = isDeframerStalled();
if (stopDelivery || deliveryStalled) {
// Close the listener immediately.
listenerClosed = true;
listener.closed(newStatus, trailers);
closeListener(newStatus, trailers);
} else {
// Delay close until inboundDeliveryStalled()
closeListenerTask = newCloseListenerTask(newStatus, trailers);
@ -222,15 +221,22 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
return new Runnable() {
@Override
public void run() {
if (!listenerClosed) {
// Status has not been reported to the application layer
listenerClosed = true;
listener.closed(status, trailers);
}
closeListener(status, trailers);
}
};
}
/**
* Closes the listener if not previously closed.
*/
private void closeListener(Status newStatus, Metadata.Trailers trailers) {
if (!listenerClosed) {
listenerClosed = true;
closeDeframer();
listener.closed(newStatus, trailers);
}
}
/**
* Executes the pending listener close task, if one exists.
*/

View File

@ -71,9 +71,9 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
}
@Override
protected void receiveMessage(InputStream is, int length) {
protected void receiveMessage(InputStream is) {
inboundPhase(Phase.MESSAGE);
listener.messageRead(is, length);
listener.messageRead(is);
}
@Override
@ -180,12 +180,11 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
* abortStream()} for abnormal.
*/
public void complete() {
listenerClosed = true;
if (!gracefulClose) {
listener.closed(Status.INTERNAL.withDescription("successful complete() without close()"));
closeListener(Status.INTERNAL.withDescription("successful complete() without close()"));
throw new IllegalStateException("successful complete() without close()");
}
listener.closed(Status.OK);
closeListener(Status.OK);
}
/**
@ -193,9 +192,7 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
*/
@Override
protected final void remoteEndClosed() {
if (inboundPhase(Phase.STATUS) != Phase.STATUS) {
listener.halfClosed();
}
halfCloseListener();
}
/**
@ -214,10 +211,7 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
// TODO(lryan): Investigate whether we can remove the notification to the client
// and rely on a transport layer stream reset instead.
Preconditions.checkArgument(!status.isOk(), "status must not be OK");
if (!listenerClosed) {
listenerClosed = true;
listener.closed(status);
}
closeListener(status);
if (notifyClient) {
// TODO(lryan): Remove
if (stashedTrailers == null) {
@ -234,4 +228,25 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
public boolean isClosed() {
return super.isClosed() || listenerClosed;
}
/**
* Fires a half-closed event to the listener and frees inbound resources.
*/
private void halfCloseListener() {
if (inboundPhase(Phase.STATUS) != Phase.STATUS && !listenerClosed) {
closeDeframer();
listener.halfClosed();
}
}
/**
* Closes the listener if not previously closed and frees resources.
*/
private void closeListener(Status newStatus) {
if (!listenerClosed) {
listenerClosed = true;
closeDeframer();
listener.closed(newStatus);
}
}
}

View File

@ -53,8 +53,7 @@ public abstract class AbstractStream<IdT> implements Stream {
private volatile IdT id;
private final MessageFramer framer;
final MessageDeframer deframer;
private final MessageDeframer deframer;
/**
* Inbound phase is exclusively written to by the transport thread.
@ -74,8 +73,8 @@ public abstract class AbstractStream<IdT> implements Stream {
}
@Override
public void messageRead(InputStream input, final int length) {
receiveMessage(input, length);
public void messageRead(InputStream input) {
receiveMessage(input);
}
@Override
@ -173,7 +172,7 @@ public abstract class AbstractStream<IdT> implements Stream {
protected abstract void internalSendFrame(ByteBuffer frame, boolean endOfStream);
/** A message was deframed. */
protected abstract void receiveMessage(InputStream is, int length);
protected abstract void receiveMessage(InputStream is);
/** Deframer has no pending deliveries. */
protected abstract void inboundDeliveryPaused();
@ -192,6 +191,14 @@ public abstract class AbstractStream<IdT> implements Stream {
*/
protected abstract void deframeFailed(Throwable cause);
/**
* Closes this deframer and frees any resources. After this method is called, additional calls
* will have no effect.
*/
protected final void closeDeframer() {
deframer.close();
}
/**
* Called to parse a received frame and attempt delivery of any completed
* messages. Must be called from the transport thread.
@ -204,6 +211,13 @@ public abstract class AbstractStream<IdT> implements Stream {
}
}
/**
* Indicates whether delivery is currently stalled, pending receipt of more data.
*/
protected final boolean isDeframerStalled() {
return deframer.isStalled();
}
/**
* Called to request the given number of messages from the deframer. Must be called
* from the transport thread.

View File

@ -32,11 +32,9 @@
package io.grpc.transport;
import com.google.common.base.Preconditions;
import com.google.common.io.ByteStreams;
import io.grpc.Status;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
@ -76,9 +74,8 @@ public class MessageDeframer implements Closeable {
* Called to deliver the next complete message.
*
* @param is stream containing the message.
* @param length the length in bytes of the message.
*/
void messageRead(InputStream is, int length);
void messageRead(InputStream is);
/**
* Called when end-of-stream has not yet been reached but there are no complete messages
@ -135,6 +132,7 @@ public class MessageDeframer implements Closeable {
* @param numMessages the requested number of messages to be delivered to the listener.
*/
public void request(int numMessages) {
checkNotClosed();
Preconditions.checkArgument(numMessages > 0, "numMessages must be > 0");
pendingDeliveries += numMessages;
deliver();
@ -144,6 +142,7 @@ public class MessageDeframer implements Closeable {
* Adds the given data to this deframer and attempts delivery to the sink.
*/
public void deframe(Buffer data, boolean endOfStream) {
checkNotClosed();
Preconditions.checkNotNull(data, "data");
Preconditions.checkState(!this.endOfStream, "Past end of stream");
unprocessed.addBuffer(data);
@ -160,14 +159,32 @@ public class MessageDeframer implements Closeable {
return deliveryStalled;
}
/**
* Closes this deframer and frees any resources. After this method is called, additional
* calls will have no effect.
*/
@Override
public void close() {
unprocessed.close();
if (nextFrame != null) {
nextFrame.close();
try {
if (unprocessed != null) {
unprocessed.close();
}
if (nextFrame != null) {
nextFrame.close();
}
} finally {
unprocessed = null;
nextFrame = null;
}
}
/**
* Throws if this deframer has already been closed.
*/
private void checkNotClosed() {
Preconditions.checkState(unprocessed != null, "MessageDeframer is already closed");
}
/**
* Reads and delivers as many messages to the sink as possible.
*/
@ -256,8 +273,7 @@ public class MessageDeframer implements Closeable {
private void processHeader() {
int type = nextFrame.readUnsignedByte();
if ((type & RESERVED_MASK) != 0) {
throw Status.INTERNAL
.withDescription("Frame header malformed: reserved bits not zero")
throw Status.INTERNAL.withDescription("Frame header malformed: reserved bits not zero")
.asRuntimeException();
}
compressedFlag = (type & COMPRESSED_FLAG_MASK) != 0;
@ -274,31 +290,33 @@ public class MessageDeframer implements Closeable {
* several GRPC messages within it.
*/
private void processBody() {
if (compressedFlag) {
if (compression == Compression.NONE) {
throw Status.INTERNAL.withDescription(
"Can't decode compressed frame as compression not configured.").asRuntimeException();
} else if (compression == Compression.GZIP) {
// Fully drain frame.
byte[] bytes;
try {
bytes =
ByteStreams.toByteArray(new GZIPInputStream(Buffers.openStream(nextFrame, false)));
} catch (IOException ex) {
throw new RuntimeException(ex);
}
listener.messageRead(new ByteArrayInputStream(bytes), bytes.length);
} else {
throw new AssertionError("Unknown compression type");
}
} else {
// Don't close the frame, since the sink is now responsible for the life-cycle.
listener.messageRead(Buffers.openStream(nextFrame, true), nextFrame.readableBytes());
nextFrame = null;
}
InputStream stream = compressedFlag ? getCompressedBody() : getUncompressedBody();
nextFrame = null;
listener.messageRead(stream);
// Done with this frame, begin processing the next header.
state = State.HEADER;
requiredLength = HEADER_LENGTH;
}
private InputStream getUncompressedBody() {
return Buffers.openStream(nextFrame, true);
}
private InputStream getCompressedBody() {
if (compression == Compression.NONE) {
throw Status.INTERNAL.withDescription(
"Can't decode compressed frame as compression not configured.").asRuntimeException();
}
if (compression != Compression.GZIP) {
throw new AssertionError("Unknown compression type");
}
try {
return new GZIPInputStream(Buffers.openStream(nextFrame, true));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -47,7 +47,6 @@ public interface StreamListener {
* <p>This method should return quickly, as the same thread may be used to process other streams.
*
* @param message the bytes of the message.
* @param length the length of the message {@link InputStream}.
*/
void messageRead(InputStream message, int length);
void messageRead(InputStream message);
}

View File

@ -221,7 +221,7 @@ public class ServerImplTest {
assertNotNull(call);
String order = "Lots of pizza, please";
streamListener.messageRead(STRING_MARSHALLER.stream(order), 1);
streamListener.messageRead(STRING_MARSHALLER.stream(order));
verify(callListener, timeout(2000)).onPayload(order);
call.sendPayload(314);

View File

@ -34,9 +34,9 @@ package io.grpc.transport;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -69,7 +69,7 @@ public class MessageDeframerTest {
public void simplePayload() {
deframer.request(1);
deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 2, 3, 14}), false);
verify(listener).messageRead(messages.capture(), eq(2));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
@ -79,11 +79,12 @@ public class MessageDeframerTest {
public void smallCombinedPayloads() {
deframer.request(2);
deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false);
verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener).messageRead(messages.capture(), eq(2));
verify(listener, times(2)).messageRead(messages.capture());
List<InputStream> streams = messages.getAllValues();
assertEquals(2, streams.size());
assertEquals(Bytes.asList(new byte[] {3}), bytes(streams.get(0)));
verify(listener, atLeastOnce()).bytesRead(anyInt());
assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(messages));
assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(streams.get(1)));
verifyNoMoreInteractions(listener);
}
@ -91,7 +92,7 @@ public class MessageDeframerTest {
public void endOfStreamWithPayloadShouldNotifyEndOfStream() {
deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true);
verify(listener).messageRead(messages.capture(), eq(1));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener).endOfStream();
verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -112,7 +113,7 @@ public class MessageDeframerTest {
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
deframer.deframe(buffer(new byte[] {2, 6}), false);
verify(listener).messageRead(messages.capture(), eq(7));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[] {3, 14, 1, 5, 9, 2, 6}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
@ -126,7 +127,7 @@ public class MessageDeframerTest {
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
deframer.deframe(buffer(new byte[] {0, 0, 1, 3}), false);
verify(listener).messageRead(messages.capture(), eq(1));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
@ -136,7 +137,7 @@ public class MessageDeframerTest {
public void emptyPayload() {
deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 0}), false);
verify(listener).messageRead(messages.capture(), eq(0));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
@ -147,7 +148,7 @@ public class MessageDeframerTest {
deframer.request(1);
deframer.deframe(
Buffers.wrap(Bytes.concat(new byte[] {0, 0, 0, 3, (byte) 232}, new byte[1000])), false);
verify(listener).messageRead(messages.capture(), eq(1000));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[1000]), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
@ -159,7 +160,7 @@ public class MessageDeframerTest {
verifyNoMoreInteractions(listener);
deframer.request(1);
verify(listener).messageRead(messages.capture(), eq(1));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener).endOfStream();
verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -175,15 +176,19 @@ public class MessageDeframerTest {
assertTrue(payload.length < 100);
byte[] header = new byte[] {1, 0, 0, 0, (byte) payload.length};
deframer.deframe(buffer(Bytes.concat(header, payload)), false);
verify(listener).messageRead(messages.capture(), eq(1000));
verify(listener).messageRead(messages.capture());
assertEquals(Bytes.asList(new byte[1000]), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
}
private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) {
return bytes(captor.getValue());
}
private static List<Byte> bytes(InputStream in) {
try {
return Bytes.asList(ByteStreams.toByteArray(captor.getValue()));
return Bytes.asList(ByteStreams.toByteArray(in));
} catch (IOException ex) {
throw new AssertionError(ex);
}

View File

@ -191,7 +191,7 @@ class NettyServerHandler extends Http2ConnectionHandler {
protected void onStreamError(ChannelHandlerContext ctx, Throwable cause,
StreamException http2Ex) {
logger.log(Level.WARNING, "Stream Error", cause);
Http2Stream stream = connection().stream(http2Ex.streamId(http2Ex));
Http2Stream stream = connection().stream(Http2Exception.streamId(http2Ex));
if (stream != null) {
// Abort the stream with a status to help the client with debugging.
// Don't need to send a RST_STREAM since the end-of-stream flag will

View File

@ -40,7 +40,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.never;
@ -240,7 +239,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
stream().transportHeadersReceived(grpcResponseTrailers(Status.INTERNAL), true);
// Verify that the first was delivered.
verify(listener).messageRead(any(InputStream.class), anyInt());
verify(listener).messageRead(any(InputStream.class));
// Now set the error status.
Metadata.Trailers trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED));
@ -250,7 +249,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
stream().request(1);
// Verify that the listener was only notified of the first message, not the second.
verify(listener).messageRead(any(InputStream.class), anyInt());
verify(listener).messageRead(any(InputStream.class));
verify(listener).closed(eq(Status.CANCELLED), eq(trailers));
}

View File

@ -42,7 +42,6 @@ import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
@ -168,7 +167,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
ByteBuf frame = dataFrame(STREAM_ID, endStream);
handler.channelRead(ctx, frame);
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(streamListener).messageRead(captor.capture(), eq(CONTENT.length));
verify(streamListener).messageRead(captor.capture());
assertArrayEquals(CONTENT, ByteStreams.toByteArray(captor.getValue()));
if (endStream) {
@ -184,7 +183,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true));
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(streamListener).messageRead(captor.capture(), anyInt());
verify(streamListener).messageRead(captor.capture());
assertArrayEquals(new byte[0], ByteStreams.toByteArray(captor.getValue()));
verify(streamListener).halfClosed();
verifyNoMoreInteractions(streamListener);
@ -195,7 +194,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
createStream();
handler.channelRead(ctx, rstStreamFrame(STREAM_ID, (int) Http2Error.CANCEL.code()));
verify(streamListener, never()).messageRead(any(InputStream.class), anyInt());
verify(streamListener, never()).messageRead(any(InputStream.class));
verify(streamListener).closed(Status.CANCELLED);
verifyNoMoreInteractions(streamListener);
}
@ -208,7 +207,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
// When a DATA frame is read, throw an exception. It will be converted into an
// Http2StreamException.
RuntimeException e = new RuntimeException("Fake Exception");
doThrow(e).when(streamListener).messageRead(any(InputStream.class), anyInt());
doThrow(e).when(streamListener).messageRead(any(InputStream.class));
// Read a DATA frame to trigger the exception.
handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true));
@ -266,7 +265,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
ArgumentCaptor<NettyServerStream> streamCaptor =
ArgumentCaptor.forClass(NettyServerStream.class);
@SuppressWarnings("rawtypes")
ArgumentCaptor<String> methodCaptor = ArgumentCaptor.forClass(String.class);
verify(transportListener).streamCreated(streamCaptor.capture(), methodCaptor.capture(),
any(Metadata.Headers.class));

View File

@ -36,7 +36,6 @@ import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -132,7 +131,7 @@ public abstract class NettyStreamTestBase {
((NettyClientStream) stream).transportDataReceived(messageFrame(MESSAGE), false);
}
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(listener()).messageRead(captor.capture(), eq(MESSAGE.length()));
verify(listener()).messageRead(captor.capture());
// Verify that inbound flow control window update has been disabled for the stream.
assertEquals(MESSAGE, NettyTestUtil.toString(captor.getValue()));

View File

@ -498,7 +498,7 @@ public class OkHttpClientTransportTest {
}
@Override
public void messageRead(InputStream message, int length) {
public void messageRead(InputStream message) {
String msg = getContent(message);
if (msg != null) {
messages.add(msg);