diff --git a/core/src/main/java/io/grpc/transport/MessageFramer.java b/core/src/main/java/io/grpc/transport/MessageFramer.java index e17eee3162..10c5a726b1 100644 --- a/core/src/main/java/io/grpc/transport/MessageFramer.java +++ b/core/src/main/java/io/grpc/transport/MessageFramer.java @@ -219,6 +219,7 @@ public class MessageFramer { */ public void close() { if (!isClosed()) { + closed = true; // With the current code we don't expect readableBytes > 0 to be possible here, added // defensively to prevent buffer leak issues if the framer code changes later. if (buffer != null && buffer.readableBytes() == 0) { @@ -226,7 +227,6 @@ public class MessageFramer { buffer = null; } commitToSink(true, true); - closed = true; } } diff --git a/core/src/test/java/io/grpc/transport/MessageFramerTest.java b/core/src/test/java/io/grpc/transport/MessageFramerTest.java index c83737eb0e..d12a843679 100644 --- a/core/src/test/java/io/grpc/transport/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/transport/MessageFramerTest.java @@ -34,6 +34,7 @@ package io.grpc.transport; import static io.grpc.transport.MessageFramer.Compression; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -49,6 +50,7 @@ import org.mockito.MockitoAnnotations; import java.io.ByteArrayInputStream; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; /** * Tests for {@link MessageFramer}. @@ -76,7 +78,7 @@ public class MessageFramerTest { writePayload(framer, new byte[] {3, 14}); verifyNoMoreInteractions(sink); framer.flush(); - verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); + verify(sink).deliverFrame(toWriteBuffer(new byte[]{0, 0, 0, 0, 2, 3, 14}), false, true); assertEquals(1, allocator.allocCount); verifyNoMoreInteractions(sink); } @@ -87,11 +89,11 @@ public class MessageFramerTest { framer = new MessageFramer(sink, allocator); writePayload(framer, new byte[] {3}); verifyNoMoreInteractions(sink); - writePayload(framer, new byte[] {14}); + writePayload(framer, new byte[]{14}); verifyNoMoreInteractions(sink); framer.flush(); verify(sink).deliverFrame( - toWriteBuffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 1, 14}), false, true); + toWriteBuffer(new byte[]{0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 1, 14}), false, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); } @@ -102,7 +104,7 @@ public class MessageFramerTest { verifyNoMoreInteractions(sink); framer.close(); verify(sink).deliverFrame( - toWriteBuffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9, 2, 6}), true, true); + toWriteBuffer(new byte[]{0, 0, 0, 0, 7, 3, 14, 1, 5, 9, 2, 6}), true, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); } @@ -123,7 +125,7 @@ public class MessageFramerTest { verifyNoMoreInteractions(sink); framer.flush(); - verify(sink).deliverFrame(toWriteBuffer(new byte[] {5}), false, true); + verify(sink).deliverFrame(toWriteBuffer(new byte[]{5}), false, true); verifyNoMoreInteractions(sink); assertEquals(2, allocator.allocCount); } @@ -139,7 +141,7 @@ public class MessageFramerTest { verifyNoMoreInteractions(sink); framer.flush(); - verify(sink).deliverFrame(toWriteBufferWithMinSize(new byte[] {1, 3}, 12), false, true); + verify(sink).deliverFrame(toWriteBufferWithMinSize(new byte[]{1, 3}, 12), false, true); verifyNoMoreInteractions(sink); assertEquals(2, allocator.allocCount); } @@ -157,7 +159,7 @@ public class MessageFramerTest { writePayload(framer, new byte[] {3, 14}); framer.flush(); framer.flush(); - verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); + verify(sink).deliverFrame(toWriteBuffer(new byte[]{0, 0, 0, 0, 2, 3, 14}), false, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); } @@ -201,6 +203,26 @@ public class MessageFramerTest { assertEquals(1, allocator.allocCount); } + @Test + public void closeIsRentrantSafe() throws Exception { + MessageFramer.Sink reentrant = new MessageFramer.Sink() { + int count = 0; + @Override + public void deliverFrame(WritableBuffer frame, boolean endOfStream, boolean flush) { + if (count == 0) { + framer.close(); + count++; + } else { + fail("received event from reentrant call to close"); + } + } + }; + framer = new MessageFramer(reentrant, allocator, Compression.NONE); + writePayload(framer, new byte[]{3, 14}); + framer.close(); + } + + private static WritableBuffer toWriteBuffer(byte[] data) { return toWriteBufferWithMinSize(data, 0); }