services: Flow control for proto reflection service

This commit is contained in:
Eric Gribkoff 2016-12-14 16:14:22 -08:00 committed by GitHub
parent 221fadcbdd
commit aada0780b8
2 changed files with 239 additions and 47 deletions

View File

@ -53,6 +53,7 @@ import io.grpc.reflection.v1alpha.ServerReflectionGrpc;
import io.grpc.reflection.v1alpha.ServerReflectionRequest;
import io.grpc.reflection.v1alpha.ServerReflectionResponse;
import io.grpc.reflection.v1alpha.ServiceResponse;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.HashMap;
@ -88,18 +89,30 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
@Override
public StreamObserver<ServerReflectionRequest> serverReflectionInfo(
final StreamObserver<ServerReflectionResponse> responseObserver) {
return new ProtoReflectionStreamObserver(responseObserver);
final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver =
(ServerCallStreamObserver<ServerReflectionResponse>) responseObserver;
ProtoReflectionStreamObserver requestObserver =
new ProtoReflectionStreamObserver(serverCallStreamObserver);
serverCallStreamObserver.setOnReadyHandler(requestObserver);
serverCallStreamObserver.disableAutoInboundFlowControl();
serverCallStreamObserver.request(1);
return requestObserver;
}
private class ProtoReflectionStreamObserver implements StreamObserver<ServerReflectionRequest> {
private final StreamObserver<ServerReflectionResponse> responseObserver;
private class ProtoReflectionStreamObserver implements Runnable,
StreamObserver<ServerReflectionRequest> {
private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver;
private Set<String> serviceNames;
private Map<String, FileDescriptor> fileDescriptorsByName;
private Map<String, FileDescriptor> fileDescriptorsBySymbol;
private Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber;
ProtoReflectionStreamObserver(StreamObserver<ServerReflectionResponse> responseObserver) {
this.responseObserver = responseObserver;
private boolean closeAfterSend = false;
private ServerReflectionRequest request;
ProtoReflectionStreamObserver(
ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) {
this.serverCallStreamObserver = serverCallStreamObserver;
}
private void processExtension(FieldDescriptor extension, FileDescriptor fd) {
@ -200,45 +213,70 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
}
}
@Override
public void run() {
if (request != null) {
handleReflectionRequest();
}
}
@Override
public void onNext(ServerReflectionRequest request) {
Preconditions.checkState(this.request == null);
this.request = request;
handleReflectionRequest();
}
private void handleReflectionRequest() {
if (serverCallStreamObserver.isReady()) {
initFileDescriptorMaps();
switch (request.getMessageRequestCase()) {
case FILE_BY_FILENAME:
getFileByName(request);
return;
break;
case FILE_CONTAINING_SYMBOL:
getFileContainingSymbol(request);
return;
break;
case FILE_CONTAINING_EXTENSION:
getFileByExtension(request);
return;
break;
case ALL_EXTENSION_NUMBERS_OF_TYPE:
getAllExtensions(request);
return;
break;
case LIST_SERVICES:
listServices(request);
return;
break;
default:
sendErrorResponse(request, Status.UNIMPLEMENTED, "");
}
request = null;
if (closeAfterSend) {
serverCallStreamObserver.onCompleted();
} else {
serverCallStreamObserver.request(1);
}
}
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
if (request != null) {
closeAfterSend = true;
} else {
serverCallStreamObserver.onCompleted();
}
}
@Override
public void onError(Throwable cause) {
responseObserver.onError(cause);
serverCallStreamObserver.onError(cause);
}
private void getFileByName(ServerReflectionRequest request) {
String name = request.getFileByFilename();
FileDescriptor fd = fileDescriptorsByName.get(name);
if (fd != null) {
responseObserver.onNext(createServerReflectionResponse(request, fd));
serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
} else {
sendErrorResponse(request, Status.NOT_FOUND, "File not found.");
}
@ -248,7 +286,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
String symbol = request.getFileContainingSymbol();
if (fileDescriptorsBySymbol.containsKey(symbol)) {
FileDescriptor fd = fileDescriptorsBySymbol.get(symbol);
responseObserver.onNext(createServerReflectionResponse(request, fd));
serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
return;
}
sendErrorResponse(request, Status.NOT_FOUND, "Symbol not found.");
@ -262,7 +300,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
&& fileDescriptorsByExtensionAndNumber
.get(containingType)
.containsKey(extensionNumber)) {
responseObserver.onNext(
serverCallStreamObserver.onNext(
createServerReflectionResponse(
request,
fileDescriptorsByExtensionAndNumber.get(containingType).get(extensionNumber)));
@ -279,7 +317,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
for (int extensionNumber : fileDescriptorsByExtensionAndNumber.get(type).keySet()) {
builder.addExtensionNumber(extensionNumber);
}
responseObserver.onNext(
serverCallStreamObserver.onNext(
ServerReflectionResponse.newBuilder()
.setValidHost(request.getHost())
.setOriginalRequest(request)
@ -295,7 +333,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
for (String serviceName : serviceNames) {
builder.addService(ServiceResponse.newBuilder().setName(serviceName));
}
responseObserver.onNext(
serverCallStreamObserver.onNext(
ServerReflectionResponse.newBuilder()
.setValidHost(request.getHost())
.setOriginalRequest(request)
@ -314,7 +352,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
.setErrorCode(status.getCode().value())
.setErrorMessage(message))
.build();
responseObserver.onNext(response);
serverCallStreamObserver.onNext(response);
}
private ServerReflectionResponse createServerReflectionResponse(

View File

@ -32,12 +32,17 @@
package io.grpc.protobuf.service;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.ServerImpl;
import io.grpc.reflection.testing.DynamicServiceGrpc;
import io.grpc.reflection.testing.ReflectableServiceGrpc;
import io.grpc.reflection.testing.ReflectionTestDepthThreeProto;
@ -46,46 +51,70 @@ import io.grpc.reflection.testing.ReflectionTestDepthTwoProto;
import io.grpc.reflection.testing.ReflectionTestProto;
import io.grpc.reflection.v1alpha.ExtensionRequest;
import io.grpc.reflection.v1alpha.FileDescriptorResponse;
import io.grpc.reflection.v1alpha.ServerReflectionGrpc;
import io.grpc.reflection.v1alpha.ServerReflectionRequest;
import io.grpc.reflection.v1alpha.ServerReflectionResponse;
import io.grpc.reflection.v1alpha.ServiceResponse;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.StreamRecorder;
import io.grpc.util.MutableHandlerRegistry;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
* Tests for {@link ProtoReflectionService}.
*/
@RunWith(JUnit4.class)
public class ProtoReflectionServiceTest {
private static final String TEST_HOST = "localhost";
private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry();
private ProtoReflectionService reflectionService;
private ServerImpl server;
private Server server;
private ManagedChannel channel;
private ServerReflectionGrpc.ServerReflectionStub stub;
@Before
public void setUp() throws IOException {
public void setUp() throws Exception {
reflectionService = new ProtoReflectionService();
server = InProcessServerBuilder.forName("proto-reflection-test")
.directExecutor()
.addService(reflectionService)
.addService(new ReflectableServiceGrpc.ReflectableServiceImplBase() {})
.fallbackHandlerRegistry(handlerRegistry)
.build()
.start();
channel = InProcessChannelBuilder.forName("proto-reflection-test")
.directExecutor()
.build();
stub = ServerReflectionGrpc.newStub(channel);
// TODO(ericgribkoff) Remove after fix for https://github.com/grpc/grpc-java/issues/2444 is
// merged.
doNoOpCall();
}
@After
public void tearDown() {
if (server != null) {
server.shutdownNow();
}
if (channel != null) {
channel.shutdownNow();
}
}
@Test
@ -142,9 +171,10 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
assertEquals(goldenResponse, responseObserver.firstValue().get());
}
@ -166,7 +196,7 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
@ -198,7 +228,7 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
assertEquals(goldenResponse, responseObserver.firstValue().get());
@ -226,7 +256,7 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
@ -264,7 +294,7 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
assertEquals(goldenResponse, responseObserver.firstValue().get());
@ -282,7 +312,7 @@ public class ProtoReflectionServiceTest {
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
Set<Integer> extensionNumberResponseSet =
@ -295,13 +325,110 @@ public class ProtoReflectionServiceTest {
assertEquals(goldenResponse, extensionNumberResponseSet);
}
@Test
public void flowControl() throws Exception {
FlowControlClientResponseObserver clientResponseObserver =
new FlowControlClientResponseObserver();
ClientCallStreamObserver<ServerReflectionRequest> requestObserver =
(ClientCallStreamObserver<ServerReflectionRequest>)
stub.serverReflectionInfo(clientResponseObserver);
// ClientCalls.startCall() calls request(1) initially, so we should get an immediate response.
requestObserver.onNext(flowControlRequest);
assertEquals(1, clientResponseObserver.getResponses().size());
assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(0));
// Verify we don't receive an additional response until we request it.
requestObserver.onNext(flowControlRequest);
assertEquals(1, clientResponseObserver.getResponses().size());
requestObserver.request(1);
assertEquals(2, clientResponseObserver.getResponses().size());
assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(1));
requestObserver.onCompleted();
assertTrue(clientResponseObserver.onCompleteCalled());
}
@Test
public void flowControlOnCompleteWithPendingRequest() throws Exception {
FlowControlClientResponseObserver clientResponseObserver =
new FlowControlClientResponseObserver();
ClientCallStreamObserver<ServerReflectionRequest> requestObserver =
(ClientCallStreamObserver<ServerReflectionRequest>)
stub.serverReflectionInfo(clientResponseObserver);
// ClientCalls.startCall() calls request(1) initially, so make additional request.
requestObserver.onNext(flowControlRequest);
requestObserver.onNext(flowControlRequest);
requestObserver.onCompleted();
assertEquals(1, clientResponseObserver.getResponses().size());
assertFalse(clientResponseObserver.onCompleteCalled());
requestObserver.request(1);
assertTrue(clientResponseObserver.onCompleteCalled());
assertEquals(2, clientResponseObserver.getResponses().size());
assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(1));
}
private final ServerReflectionRequest flowControlRequest =
ServerReflectionRequest.newBuilder()
.setHost(TEST_HOST)
.setFileByFilename("io/grpc/reflection/testing/reflection_test_depth_three.proto")
.build();
private final ServerReflectionResponse flowControlGoldenResponse =
ServerReflectionResponse.newBuilder()
.setValidHost(TEST_HOST)
.setOriginalRequest(flowControlRequest)
.setFileDescriptorResponse(
FileDescriptorResponse.newBuilder()
.addFileDescriptorProto(
ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString())
.build())
.build();
private static class FlowControlClientResponseObserver implements
ClientResponseObserver<ServerReflectionRequest, ServerReflectionResponse> {
private final List<ServerReflectionResponse> responses =
new ArrayList<ServerReflectionResponse>();
private boolean onCompleteCalled = false;
@Override
public void beforeStart(
final ClientCallStreamObserver<ServerReflectionRequest> requestStream) {
requestStream.disableAutoInboundFlowControl();
}
@Override
public void onNext(ServerReflectionResponse value) {
responses.add(value);
}
@Override
public void onError(Throwable t) {
fail("onError called");
}
@Override
public void onCompleted() {
onCompleteCalled = true;
}
public List<ServerReflectionResponse> getResponses() {
return responses;
}
public boolean onCompleteCalled() {
return onCompleteCalled;
}
}
private void assertServiceResponseEquals(Set<ServiceResponse> goldenResponse) throws Exception {
ServerReflectionRequest request =
ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build();
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver =
reflectionService.serverReflectionInfo(responseObserver);
stub.serverReflectionInfo(responseObserver);
requestObserver.onNext(request);
requestObserver.onCompleted();
List<ServiceResponse> response =
@ -309,4 +436,31 @@ public class ProtoReflectionServiceTest {
assertEquals(goldenResponse.size(), response.size());
assertEquals(goldenResponse, new HashSet<ServiceResponse>(response));
}
// TODO(ericgribkoff) Remove after fix for https://github.com/grpc/grpc-java/issues/2444 is
// merged.
private void doNoOpCall() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
ServerReflectionRequest request =
ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build();
StreamObserver<ServerReflectionRequest> requestObserver =
stub.serverReflectionInfo(new StreamObserver<ServerReflectionResponse>() {
@Override
public void onNext(ServerReflectionResponse value) {
}
@Override
public void onError(Throwable t) {
latch.countDown();
}
@Override
public void onCompleted() {
latch.countDown();
}
});
requestObserver.onNext(request);
requestObserver.onCompleted();
assertTrue(latch.await(5, TimeUnit.SECONDS));
}
}