diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index da1a296154..5ab96085a4 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -468,7 +468,7 @@ abstract class Inbound implements StreamListener.Messa if (firstMessage != null) { stream = firstMessage; firstMessage = null; - } else if (messageAvailable()) { + } else if (numRequestedMessages > 0 && messageAvailable()) { stream = assembleNextMessage(); } if (stream != null) { diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index 06dfef5daa..d70d8bc6b8 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -1503,6 +1503,36 @@ public abstract class AbstractTransportTest { return count; } + @Test + public void messageProducerOnlyProducesRequestedMessages() throws Exception { + server.start(serverListener); + client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + serverTransport = serverTransportListener.transport; + + // Start an RPC. + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); + StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); + + // Have the client send two messages. + clientStream.writeMessage(methodDescriptor.streamRequest("MESSAGE")); + clientStream.writeMessage(methodDescriptor.streamRequest("MESSAGE")); + clientStream.flush(); + + doPingPong(serverListener); + + // Verify server only receives one message if that's all it requests. + serverStreamCreation.stream.request(1); + verifyMessageCountAndClose(serverStreamCreation.listener.messageQueue, 1); + } + @Test public void interactionsAfterServerStreamCloseAreNoops() throws Exception { server.start(serverListener);