diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index f121f6a30e..8c8ef855bb 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -272,6 +272,9 @@ class OkHttpClientTransport implements ClientTransport { frameWriter.flush(); } if (nextStreamId >= Integer.MAX_VALUE - 2) { + // Make sure nextStreamId greater than all used id, so that mayHaveCreatedStream() performs + // correctly. + nextStreamId = Integer.MAX_VALUE; onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Stream ids exhausted")); } else { nextStreamId += 2; diff --git a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java index b686667e06..259a5d75e7 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java @@ -119,9 +119,6 @@ class OutboundFlowController { */ void data(boolean outFinished, int streamId, Buffer source, boolean flush) { Preconditions.checkNotNull(source, "source"); - if (streamId <= 0 || !transport.mayHaveCreatedStream(streamId)) { - throw new IllegalArgumentException("Invalid streamId: " + streamId); - } OkHttpClientStream stream = transport.getStream(streamId); if (stream == null) { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 6c8159a710..2a86c28ed1 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -603,13 +603,28 @@ public class OkHttpClientTransportTest { int startId = Integer.MAX_VALUE - 2; initTransport(startId, new ConnectedCallback(false)); - MockStreamListener listener1 = new MockStreamListener(); - clientTransport.newStream(method, new Metadata.Headers(), listener1); + MockStreamListener listener = new MockStreamListener(); + clientTransport.newStream(method, new Metadata.Headers(), listener).request(1); + // New stream should be failed. assertNewStreamFail(); + // The alive stream should be functional, receives a message. + frameHandler().headers( + false, false, startId, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); + assertNotNull(listener.headers); + String message = "hello"; + Buffer buffer = createMessageFrame(message); + frameHandler().data(false, startId, buffer, (int) buffer.size()); + getStream(startId).cancel(Status.CANCELLED); - listener1.waitUntilStreamClosed(); + // Receives the second message after be cancelled. + buffer = createMessageFrame(message); + frameHandler().data(false, startId, buffer, (int) buffer.size()); + + listener.waitUntilStreamClosed(); + // Should only have the first message delivered. + assertEquals(message, listener.messages.get(0)); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(startId), eq(ErrorCode.CANCEL)); verify(transportListener).transportShutdown(isA(Status.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); @@ -627,6 +642,12 @@ public class OkHttpClientTransportTest { // The second stream should be pending. OkHttpClientStream stream2 = clientTransport.newStream(method, new Metadata.Headers(), listener2); + String sentMessage = "hello"; + InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); + assertEquals(5, input.available()); + stream2.writeMessage(input); + stream2.flush(); + stream2.halfClose(); waitForStreamPending(1); assertEquals(1, activeStreamCount()); @@ -635,10 +656,16 @@ public class OkHttpClientTransportTest { stream1.cancel(Status.CANCELLED); listener1.waitUntilStreamClosed(); - // The second stream should be active now. + // The second stream should be active now, and the pending data should be sent out. assertEquals(1, activeStreamCount()); assertEquals(0, clientTransport.getPendingStreamSize()); - stream2.cancel(Status.CANCELLED); + ArgumentCaptor captor = ArgumentCaptor.forClass(Buffer.class); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(5), captor.capture(), eq(5 + HEADER_LENGTH)); + Buffer sentFrame = captor.getValue(); + assertEquals(createMessageFrame(sentMessage), sentFrame); + verify(frameWriter, timeout(TIME_OUT_MS)).data(eq(true), eq(5), any(Buffer.class), eq(0)); + stream2.sendCancel(Status.CANCELLED); } @Test