diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java index bd1d4762a2..23c5a76c74 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -85,19 +85,15 @@ final class AsyncSink implements Sink { } writeEnqueued = true; } - serializingExecutor.execute(new Runnable() { + serializingExecutor.execute(new WriteRunnable() { @Override - public void run() { + public void doRun() throws IOException { Buffer buf = new Buffer(); synchronized (lock) { buf.write(buffer, buffer.completeSegmentByteCount()); writeEnqueued = false; } - try { - sink.write(buf, buf.size()); - } catch (IOException e) { - transportExceptionHandler.onException(e); - } + sink.write(buf, buf.size()); } }); } @@ -113,20 +109,16 @@ final class AsyncSink implements Sink { } flushEnqueued = true; } - serializingExecutor.execute(new Runnable() { + serializingExecutor.execute(new WriteRunnable() { @Override - public void run() { + public void doRun() throws IOException { Buffer buf = new Buffer(); synchronized (lock) { buf.write(buffer, buffer.size()); flushEnqueued = false; } - try { - sink.write(buf, buf.size()); - sink.flush(); - } catch (IOException e) { - transportExceptionHandler.onException(e); - } + sink.write(buf, buf.size()); + sink.flush(); } }); } @@ -147,16 +139,36 @@ final class AsyncSink implements Sink { public void run() { buffer.close(); try { - sink.close(); + if (sink != null) { + sink.close(); + } } catch (IOException e) { transportExceptionHandler.onException(e); } try { - socket.close(); + if (socket != null) { + socket.close(); + } } catch (IOException e) { transportExceptionHandler.onException(e); } } }); } + + private abstract class WriteRunnable implements Runnable { + @Override + public final void run() { + try { + if (sink == null) { + throw new IOException("Unable to perform write due to unavailable sink."); + } + doRun(); + } catch (Exception e) { + transportExceptionHandler.onException(e); + } + } + + public abstract void doRun() throws IOException; + } } \ No newline at end of file diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java index 5a0e848518..b7cb774579 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java @@ -41,10 +41,10 @@ import java.util.concurrent.Executor; import okio.Buffer; import okio.Sink; import okio.Timeout; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.InOrder; /** Tests for {@link AsyncSink}. */ @@ -58,14 +58,10 @@ public class AsyncSinkTest { private final AsyncSink sink = AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler); - @Before - public void setUp() throws Exception { - sink.becomeConnected(mockedSink, socket); - } - @Test public void noCoalesceRequired() throws IOException { Buffer buffer = new Buffer(); + sink.becomeConnected(mockedSink, socket); sink.write(buffer.writeUtf8("hello"), buffer.size()); sink.flush(); queueingExecutor.runAll(); @@ -80,6 +76,7 @@ public class AsyncSinkTest { byte[] firstData = "a string".getBytes(Charsets.UTF_8); byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + sink.becomeConnected(mockedSink, socket); Buffer buffer = new Buffer(); sink.write(buffer.write(firstData), buffer.size()); sink.flush(); @@ -101,6 +98,7 @@ public class AsyncSinkTest { byte[] firstData = "a string".getBytes(Charsets.UTF_8); byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); Buffer buffer = new Buffer().write(firstData); + sink.becomeConnected(mockedSink, socket); sink.write(buffer, buffer.size()); sink.flush(); buffer = new Buffer().write(secondData); @@ -120,6 +118,7 @@ public class AsyncSinkTest { byte[] firstData = "a string".getBytes(Charsets.UTF_8); byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); Buffer buffer = new Buffer(); + sink.becomeConnected(mockedSink, socket); sink.write(buffer.write(firstData), buffer.size()); sink.write(buffer.write(secondData), buffer.size()); sink.flush(); @@ -138,6 +137,7 @@ public class AsyncSinkTest { .when(mockedSink).write(any(Buffer.class), anyLong()); Buffer buffer = new Buffer(); buffer.writeUtf8("any message"); + sink.becomeConnected(mockedSink, socket); sink.write(buffer, buffer.size()); sink.flush(); queueingExecutor.runAll(); @@ -166,6 +166,7 @@ public class AsyncSinkTest { .when(mockedSink).write(any(Buffer.class), anyLong()); Buffer buffer = new Buffer(); buffer.writeUtf8("any message"); + sink.becomeConnected(mockedSink, socket); sink.write(buffer, buffer.size()); sink.close(); queueingExecutor.runAll(); @@ -180,6 +181,7 @@ public class AsyncSinkTest { @Test public void close_flushShouldThrowException() throws IOException { + sink.becomeConnected(mockedSink, socket); sink.close(); queueingExecutor.runAll(); try { @@ -195,6 +197,7 @@ public class AsyncSinkTest { public void flush_shouldThrowIfAlreadyClosed() throws IOException { Buffer buffer = new Buffer(); buffer.writeUtf8("any message"); + sink.becomeConnected(mockedSink, socket); sink.write(buffer, buffer.size()); sink.close(); queueingExecutor.runAll(); @@ -210,6 +213,7 @@ public class AsyncSinkTest { @Test public void write_callSinkIfBufferIsLargerThanSegmentSize() throws IOException { Buffer buffer = new Buffer(); + sink.becomeConnected(mockedSink, socket); // OkHttp is using 8192 as segment size. int payloadSize = 8192 * 2 - 1; int padding = 10; @@ -240,6 +244,40 @@ public class AsyncSinkTest { verify(mockedSink).flush(); } + @Test + public void writeAndFlush_beforeConnected() throws IOException { + Buffer buffer = new Buffer(); + sink.write(buffer.writeUtf8("hello"), buffer.size()); + sink.flush(); + queueingExecutor.runAll(); + + verify(mockedSink, never()).write(any(Buffer.class), anyLong()); + verify(mockedSink, never()).flush(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Throwable.class); + + verify(exceptionHandler).onException(captor.capture()); + + Throwable t = captor.getValue(); + assertThat(t).isInstanceOf(IOException.class); + assertThat(t).hasMessageThat().contains("unavailable sink"); + } + + @Test + public void close_multipleCloseShouldNotThrow() throws IOException { + sink.becomeConnected(mockedSink, socket); + + sink.close(); + queueingExecutor.runAll(); + + verify(exceptionHandler, never()).onException(any(Throwable.class)); + + sink.close(); + queueingExecutor.runAll(); + + verify(exceptionHandler, never()).onException(any(Throwable.class)); + } + /** * Executor queues incoming runnables instead of running it. Runnables can be invoked via {@link * QueueingExecutor#runAll} in serial order.