From aada0780b83079cc5857406647997f6dee0c7cd1 Mon Sep 17 00:00:00 2001 From: Eric Gribkoff Date: Wed, 14 Dec 2016 16:14:22 -0800 Subject: [PATCH] services: Flow control for proto reflection service --- .../service/ProtoReflectionService.java | 102 +++++++--- .../service/ProtoReflectionServiceTest.java | 184 ++++++++++++++++-- 2 files changed, 239 insertions(+), 47 deletions(-) diff --git a/services/src/main/java/io/grpc/protobuf/service/ProtoReflectionService.java b/services/src/main/java/io/grpc/protobuf/service/ProtoReflectionService.java index e208b2b082..52d0450d99 100644 --- a/services/src/main/java/io/grpc/protobuf/service/ProtoReflectionService.java +++ b/services/src/main/java/io/grpc/protobuf/service/ProtoReflectionService.java @@ -53,6 +53,7 @@ import io.grpc.reflection.v1alpha.ServerReflectionGrpc; import io.grpc.reflection.v1alpha.ServerReflectionRequest; import io.grpc.reflection.v1alpha.ServerReflectionResponse; import io.grpc.reflection.v1alpha.ServiceResponse; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.HashMap; @@ -88,18 +89,30 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef @Override public StreamObserver serverReflectionInfo( final StreamObserver responseObserver) { - return new ProtoReflectionStreamObserver(responseObserver); + final ServerCallStreamObserver serverCallStreamObserver = + (ServerCallStreamObserver) responseObserver; + ProtoReflectionStreamObserver requestObserver = + new ProtoReflectionStreamObserver(serverCallStreamObserver); + serverCallStreamObserver.setOnReadyHandler(requestObserver); + serverCallStreamObserver.disableAutoInboundFlowControl(); + serverCallStreamObserver.request(1); + return requestObserver; } - private class ProtoReflectionStreamObserver implements StreamObserver { - private final StreamObserver responseObserver; + private class ProtoReflectionStreamObserver implements Runnable, + StreamObserver { + private final ServerCallStreamObserver serverCallStreamObserver; private Set serviceNames; private Map fileDescriptorsByName; private Map fileDescriptorsBySymbol; private Map> fileDescriptorsByExtensionAndNumber; - ProtoReflectionStreamObserver(StreamObserver responseObserver) { - this.responseObserver = responseObserver; + private boolean closeAfterSend = false; + private ServerReflectionRequest request; + + ProtoReflectionStreamObserver( + ServerCallStreamObserver serverCallStreamObserver) { + this.serverCallStreamObserver = serverCallStreamObserver; } private void processExtension(FieldDescriptor extension, FileDescriptor fd) { @@ -200,45 +213,70 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef } } + @Override + public void run() { + if (request != null) { + handleReflectionRequest(); + } + } + @Override public void onNext(ServerReflectionRequest request) { - initFileDescriptorMaps(); - switch (request.getMessageRequestCase()) { - case FILE_BY_FILENAME: - getFileByName(request); - return; - case FILE_CONTAINING_SYMBOL: - getFileContainingSymbol(request); - return; - case FILE_CONTAINING_EXTENSION: - getFileByExtension(request); - return; - case ALL_EXTENSION_NUMBERS_OF_TYPE: - getAllExtensions(request); - return; - case LIST_SERVICES: - listServices(request); - return; - default: - sendErrorResponse(request, Status.UNIMPLEMENTED, ""); + Preconditions.checkState(this.request == null); + this.request = request; + handleReflectionRequest(); + } + + private void handleReflectionRequest() { + if (serverCallStreamObserver.isReady()) { + initFileDescriptorMaps(); + switch (request.getMessageRequestCase()) { + case FILE_BY_FILENAME: + getFileByName(request); + break; + case FILE_CONTAINING_SYMBOL: + getFileContainingSymbol(request); + break; + case FILE_CONTAINING_EXTENSION: + getFileByExtension(request); + break; + case ALL_EXTENSION_NUMBERS_OF_TYPE: + getAllExtensions(request); + break; + case LIST_SERVICES: + listServices(request); + break; + default: + sendErrorResponse(request, Status.UNIMPLEMENTED, ""); + } + request = null; + if (closeAfterSend) { + serverCallStreamObserver.onCompleted(); + } else { + serverCallStreamObserver.request(1); + } } } @Override public void onCompleted() { - responseObserver.onCompleted(); + if (request != null) { + closeAfterSend = true; + } else { + serverCallStreamObserver.onCompleted(); + } } @Override public void onError(Throwable cause) { - responseObserver.onError(cause); + serverCallStreamObserver.onError(cause); } private void getFileByName(ServerReflectionRequest request) { String name = request.getFileByFilename(); FileDescriptor fd = fileDescriptorsByName.get(name); if (fd != null) { - responseObserver.onNext(createServerReflectionResponse(request, fd)); + serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); } else { sendErrorResponse(request, Status.NOT_FOUND, "File not found."); } @@ -248,7 +286,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef String symbol = request.getFileContainingSymbol(); if (fileDescriptorsBySymbol.containsKey(symbol)) { FileDescriptor fd = fileDescriptorsBySymbol.get(symbol); - responseObserver.onNext(createServerReflectionResponse(request, fd)); + serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); return; } sendErrorResponse(request, Status.NOT_FOUND, "Symbol not found."); @@ -262,7 +300,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef && fileDescriptorsByExtensionAndNumber .get(containingType) .containsKey(extensionNumber)) { - responseObserver.onNext( + serverCallStreamObserver.onNext( createServerReflectionResponse( request, fileDescriptorsByExtensionAndNumber.get(containingType).get(extensionNumber))); @@ -279,7 +317,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef for (int extensionNumber : fileDescriptorsByExtensionAndNumber.get(type).keySet()) { builder.addExtensionNumber(extensionNumber); } - responseObserver.onNext( + serverCallStreamObserver.onNext( ServerReflectionResponse.newBuilder() .setValidHost(request.getHost()) .setOriginalRequest(request) @@ -295,7 +333,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef for (String serviceName : serviceNames) { builder.addService(ServiceResponse.newBuilder().setName(serviceName)); } - responseObserver.onNext( + serverCallStreamObserver.onNext( ServerReflectionResponse.newBuilder() .setValidHost(request.getHost()) .setOriginalRequest(request) @@ -314,7 +352,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef .setErrorCode(status.getCode().value()) .setErrorMessage(message)) .build(); - responseObserver.onNext(response); + serverCallStreamObserver.onNext(response); } private ServerReflectionResponse createServerReflectionResponse( diff --git a/services/src/test/java/io/grpc/protobuf/service/ProtoReflectionServiceTest.java b/services/src/test/java/io/grpc/protobuf/service/ProtoReflectionServiceTest.java index 1f3cf48089..2705e7fb88 100644 --- a/services/src/test/java/io/grpc/protobuf/service/ProtoReflectionServiceTest.java +++ b/services/src/test/java/io/grpc/protobuf/service/ProtoReflectionServiceTest.java @@ -32,12 +32,17 @@ package io.grpc.protobuf.service; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.Server; import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; -import io.grpc.internal.ServerImpl; import io.grpc.reflection.testing.DynamicServiceGrpc; import io.grpc.reflection.testing.ReflectableServiceGrpc; import io.grpc.reflection.testing.ReflectionTestDepthThreeProto; @@ -46,46 +51,70 @@ import io.grpc.reflection.testing.ReflectionTestDepthTwoProto; import io.grpc.reflection.testing.ReflectionTestProto; import io.grpc.reflection.v1alpha.ExtensionRequest; import io.grpc.reflection.v1alpha.FileDescriptorResponse; +import io.grpc.reflection.v1alpha.ServerReflectionGrpc; import io.grpc.reflection.v1alpha.ServerReflectionRequest; import io.grpc.reflection.v1alpha.ServerReflectionResponse; import io.grpc.reflection.v1alpha.ServiceResponse; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.StreamRecorder; import io.grpc.util.MutableHandlerRegistry; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; /** * Tests for {@link ProtoReflectionService}. */ @RunWith(JUnit4.class) public class ProtoReflectionServiceTest { - private static final String TEST_HOST = "localhost"; - private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry(); - private ProtoReflectionService reflectionService; - - private ServerImpl server; + private Server server; + private ManagedChannel channel; + private ServerReflectionGrpc.ServerReflectionStub stub; @Before - public void setUp() throws IOException { + public void setUp() throws Exception { reflectionService = new ProtoReflectionService(); server = InProcessServerBuilder.forName("proto-reflection-test") + .directExecutor() .addService(reflectionService) .addService(new ReflectableServiceGrpc.ReflectableServiceImplBase() {}) .fallbackHandlerRegistry(handlerRegistry) + .build() + .start(); + channel = InProcessChannelBuilder.forName("proto-reflection-test") + .directExecutor() .build(); + stub = ServerReflectionGrpc.newStub(channel); + + // TODO(ericgribkoff) Remove after fix for https://github.com/grpc/grpc-java/issues/2444 is + // merged. + doNoOpCall(); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdownNow(); + } + if (channel != null) { + channel.shutdownNow(); + } } @Test @@ -142,9 +171,10 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); + assertEquals(goldenResponse, responseObserver.firstValue().get()); } @@ -166,7 +196,7 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); @@ -198,7 +228,7 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); assertEquals(goldenResponse, responseObserver.firstValue().get()); @@ -226,7 +256,7 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); @@ -264,7 +294,7 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); assertEquals(goldenResponse, responseObserver.firstValue().get()); @@ -282,7 +312,7 @@ public class ProtoReflectionServiceTest { StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); Set extensionNumberResponseSet = @@ -295,13 +325,110 @@ public class ProtoReflectionServiceTest { assertEquals(goldenResponse, extensionNumberResponseSet); } + @Test + public void flowControl() throws Exception { + FlowControlClientResponseObserver clientResponseObserver = + new FlowControlClientResponseObserver(); + ClientCallStreamObserver requestObserver = + (ClientCallStreamObserver) + stub.serverReflectionInfo(clientResponseObserver); + + // ClientCalls.startCall() calls request(1) initially, so we should get an immediate response. + requestObserver.onNext(flowControlRequest); + assertEquals(1, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(0)); + + // Verify we don't receive an additional response until we request it. + requestObserver.onNext(flowControlRequest); + assertEquals(1, clientResponseObserver.getResponses().size()); + + requestObserver.request(1); + assertEquals(2, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(1)); + + requestObserver.onCompleted(); + assertTrue(clientResponseObserver.onCompleteCalled()); + } + + @Test + public void flowControlOnCompleteWithPendingRequest() throws Exception { + FlowControlClientResponseObserver clientResponseObserver = + new FlowControlClientResponseObserver(); + ClientCallStreamObserver requestObserver = + (ClientCallStreamObserver) + stub.serverReflectionInfo(clientResponseObserver); + + // ClientCalls.startCall() calls request(1) initially, so make additional request. + requestObserver.onNext(flowControlRequest); + requestObserver.onNext(flowControlRequest); + requestObserver.onCompleted(); + assertEquals(1, clientResponseObserver.getResponses().size()); + assertFalse(clientResponseObserver.onCompleteCalled()); + + requestObserver.request(1); + assertTrue(clientResponseObserver.onCompleteCalled()); + assertEquals(2, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(1)); + } + + private final ServerReflectionRequest flowControlRequest = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileByFilename("io/grpc/reflection/testing/reflection_test_depth_three.proto") + .build(); + private final ServerReflectionResponse flowControlGoldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(flowControlRequest) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + private static class FlowControlClientResponseObserver implements + ClientResponseObserver { + private final List responses = + new ArrayList(); + private boolean onCompleteCalled = false; + + @Override + public void beforeStart( + final ClientCallStreamObserver requestStream) { + requestStream.disableAutoInboundFlowControl(); + } + + @Override + public void onNext(ServerReflectionResponse value) { + responses.add(value); + } + + @Override + public void onError(Throwable t) { + fail("onError called"); + } + + @Override + public void onCompleted() { + onCompleteCalled = true; + } + + public List getResponses() { + return responses; + } + + public boolean onCompleteCalled() { + return onCompleteCalled; + } + } private void assertServiceResponseEquals(Set goldenResponse) throws Exception { ServerReflectionRequest request = ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build(); StreamRecorder responseObserver = StreamRecorder.create(); StreamObserver requestObserver = - reflectionService.serverReflectionInfo(responseObserver); + stub.serverReflectionInfo(responseObserver); requestObserver.onNext(request); requestObserver.onCompleted(); List response = @@ -309,4 +436,31 @@ public class ProtoReflectionServiceTest { assertEquals(goldenResponse.size(), response.size()); assertEquals(goldenResponse, new HashSet(response)); } + + // TODO(ericgribkoff) Remove after fix for https://github.com/grpc/grpc-java/issues/2444 is + // merged. + private void doNoOpCall() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build(); + StreamObserver requestObserver = + stub.serverReflectionInfo(new StreamObserver() { + @Override + public void onNext(ServerReflectionResponse value) { + } + + @Override + public void onError(Throwable t) { + latch.countDown(); + } + + @Override + public void onCompleted() { + latch.countDown(); + } + }); + requestObserver.onNext(request); + requestObserver.onCompleted(); + assertTrue(latch.await(5, TimeUnit.SECONDS)); + } }