services: don't update reflection index mid-stream

This fix addresses https://github.com/grpc/grpc-java/issues/2689
This commit is contained in:
Eric Gribkoff 2017-02-28 20:05:26 -08:00 committed by GitHub
parent 437fafab1b
commit e345273268
2 changed files with 261 additions and 204 deletions

View File

@ -79,49 +79,104 @@ import javax.annotation.concurrent.GuardedBy;
public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase
implements InternalNotifyOnServerBuild { implements InternalNotifyOnServerBuild {
private volatile ServerReflectionIndex serverReflectionIndex; private final Object lock = new Object();
@GuardedBy("lock")
private ServerReflectionIndex serverReflectionIndex;
private Server server;
private ProtoReflectionService() {} private ProtoReflectionService() {}
public static BindableService getInstance() { public static BindableService newInstance() {
return new ProtoReflectionService(); return new ProtoReflectionService();
} }
/** /**
* Receives a reference to the server at build time. * Do not use this method.
*
* @deprecated use {@link ProtoReflectionService#newInstance()} instead.
*/ */
@Deprecated
public static BindableService getInstance() {
return newInstance();
}
/** Receives a reference to the server at build time. */
@Override @Override
public void notifyOnBuild(Server server) { public void notifyOnBuild(Server server) {
checkState(serverReflectionIndex == null); this.server = checkNotNull(server);
serverReflectionIndex = new ServerReflectionIndex(checkNotNull(server, "server")); }
/**
* Checks for updates to the server's mutable services and updates the index if any changes are
* detected. A change is any addition or removal in the set of file descriptors attached to the
* mutable services or a change in the service names.
*
* @return The (potentially updated) index.
*/
private ServerReflectionIndex updateIndexIfNecessary() {
synchronized (lock) {
if (serverReflectionIndex == null) {
serverReflectionIndex =
new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices());
return serverReflectionIndex;
}
Set<FileDescriptor> serverFileDescriptors = new HashSet<FileDescriptor>();
Set<String> serverServiceNames = new HashSet<String>();
List<ServerServiceDefinition> serverMutableServices = server.getMutableServices();
for (ServerServiceDefinition mutableService : serverMutableServices) {
io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor();
if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
String serviceName = serviceDescriptor.getName();
FileDescriptor fileDescriptor =
((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
.getFileDescriptor();
serverFileDescriptors.add(fileDescriptor);
serverServiceNames.add(serviceName);
}
}
// Replace the index if the underlying mutable services have changed. Check both the file
// descriptors and the service names, because one file descriptor can define multiple
// services.
FileDescriptorIndex mutableServicesIndex = serverReflectionIndex.getMutableServicesIndex();
if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors)
|| !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) {
serverReflectionIndex =
new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices);
}
return serverReflectionIndex;
}
} }
@Override @Override
public StreamObserver<ServerReflectionRequest> serverReflectionInfo( public StreamObserver<ServerReflectionRequest> serverReflectionInfo(
final StreamObserver<ServerReflectionResponse> responseObserver) { final StreamObserver<ServerReflectionResponse> responseObserver) {
checkState(serverReflectionIndex != null);
serverReflectionIndex.initializeImmutableServicesIndex();
final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver = final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver =
(ServerCallStreamObserver<ServerReflectionResponse>) responseObserver; (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver;
ProtoReflectionStreamObserver requestObserver = ProtoReflectionStreamObserver requestObserver =
new ProtoReflectionStreamObserver(serverCallStreamObserver); new ProtoReflectionStreamObserver(updateIndexIfNecessary(), serverCallStreamObserver);
serverCallStreamObserver.setOnReadyHandler(requestObserver); serverCallStreamObserver.setOnReadyHandler(requestObserver);
serverCallStreamObserver.disableAutoInboundFlowControl(); serverCallStreamObserver.disableAutoInboundFlowControl();
serverCallStreamObserver.request(1); serverCallStreamObserver.request(1);
return requestObserver; return requestObserver;
} }
private class ProtoReflectionStreamObserver implements Runnable, private static class ProtoReflectionStreamObserver
StreamObserver<ServerReflectionRequest> { implements Runnable, StreamObserver<ServerReflectionRequest> {
private final ServerReflectionIndex serverReflectionIndex;
private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver; private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver;
private boolean closeAfterSend = false; private boolean closeAfterSend = false;
private ServerReflectionRequest request; private ServerReflectionRequest request;
ProtoReflectionStreamObserver( ProtoReflectionStreamObserver(
ServerReflectionIndex serverReflectionIndex,
ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) { ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) {
this.serverReflectionIndex = serverReflectionIndex;
this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer"); this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer");
} }
@ -141,8 +196,6 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private void handleReflectionRequest() { private void handleReflectionRequest() {
if (serverCallStreamObserver.isReady()) { if (serverCallStreamObserver.isReady()) {
serverReflectionIndex.updateMutableIndexIfNecessary();
switch (request.getMessageRequestCase()) { switch (request.getMessageRequestCase()) {
case FILE_BY_FILENAME: case FILE_BY_FILENAME:
getFileByName(request); getFileByName(request);
@ -250,8 +303,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
.build()); .build());
} }
private void sendErrorResponse( private void sendErrorResponse(ServerReflectionRequest request, Status status, String message) {
ServerReflectionRequest request, Status status, String message) {
ServerReflectionResponse response = ServerReflectionResponse response =
ServerReflectionResponse.newBuilder() ServerReflectionResponse.newBuilder()
.setValidHost(request.getHost()) .setValidHost(request.getHost())
@ -299,71 +351,27 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
* in the immutable service index are the mutable services checked. * in the immutable service index are the mutable services checked.
*/ */
private static final class ServerReflectionIndex { private static final class ServerReflectionIndex {
private FileDescriptorIndex immutableServicesIndex; private final FileDescriptorIndex immutableServicesIndex;
private final Object lock = new Object(); private final FileDescriptorIndex mutableServicesIndex;
/**
* Tracks mutable services. Accesses must be synchronized.
*/
@GuardedBy("lock") private FileDescriptorIndex mutableServicesIndex
= new FileDescriptorIndex(Collections.<ServerServiceDefinition>emptyList());
private final Server server; public ServerReflectionIndex(
List<ServerServiceDefinition> immutableServices,
public ServerReflectionIndex(Server server) { List<ServerServiceDefinition> mutableServices) {
this.server = server; immutableServicesIndex = new FileDescriptorIndex(immutableServices);
mutableServicesIndex = new FileDescriptorIndex(mutableServices);
} }
/** private FileDescriptorIndex getMutableServicesIndex() {
* When first called, initializes the immutable services index. Subsequent calls have no effect. return mutableServicesIndex;
*
* <p>This must be called by the reflection service before returning a new
* {@link ProtoReflectionStreamObserver}.
*/
private synchronized void initializeImmutableServicesIndex() {
if (immutableServicesIndex == null) {
immutableServicesIndex = new FileDescriptorIndex(server.getImmutableServices());
}
}
/**
* Checks for updates to the server's mutable services and updates the index if any changes
* are detected. A change is any addition or removal in the set of file descriptors attached to
* the mutable services or a change in the service names.
*/
private void updateMutableIndexIfNecessary() {
Set<FileDescriptor> currentFileDescriptors = new HashSet<FileDescriptor>();
Set<String> currentServiceNames = new HashSet<String>();
synchronized (lock) {
List<ServerServiceDefinition> currentMutableServices = server.getMutableServices();
for (ServerServiceDefinition mutableService : currentMutableServices) {
io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor();
if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
String serviceName = serviceDescriptor.getName();
FileDescriptor fileDescriptor =
((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
.getFileDescriptor();
currentFileDescriptors.add(fileDescriptor);
checkState(!currentServiceNames.contains(serviceName),
"Service already defined: %s", serviceName);
currentServiceNames.add(serviceName);
}
}
// Replace the mutable index if the underlying services have changed. Check both the file
// descriptors and the service names, because one file descriptor can define multiple
// services.
if (!mutableServicesIndex.getServiceFileDescriptors().equals(currentFileDescriptors)
|| !mutableServicesIndex.getServiceNames().equals(currentServiceNames)) {
mutableServicesIndex = new FileDescriptorIndex(currentMutableServices);
}
}
} }
private Set<String> getServiceNames() { private Set<String> getServiceNames() {
Set<String> serviceNames = new HashSet<String>(immutableServicesIndex.getServiceNames()); Set<String> immutableServiceNames = immutableServicesIndex.getServiceNames();
synchronized (lock) { Set<String> mutableServiceNames = mutableServicesIndex.getServiceNames();
serviceNames.addAll(mutableServicesIndex.getServiceNames()); Set<String> serviceNames =
} new HashSet<String>(immutableServiceNames.size() + mutableServiceNames.size());
serviceNames.addAll(immutableServiceNames);
serviceNames.addAll(mutableServiceNames);
return serviceNames; return serviceNames;
} }
@ -371,9 +379,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private FileDescriptor getFileDescriptorByName(String name) { private FileDescriptor getFileDescriptorByName(String name) {
FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name); FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name);
if (fd == null) { if (fd == null) {
synchronized (lock) { fd = mutableServicesIndex.getFileDescriptorByName(name);
fd = mutableServicesIndex.getFileDescriptorByName(name);
}
} }
return fd; return fd;
} }
@ -382,21 +388,17 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private FileDescriptor getFileDescriptorBySymbol(String symbol) { private FileDescriptor getFileDescriptorBySymbol(String symbol) {
FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol); FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol);
if (fd == null) { if (fd == null) {
synchronized (lock) { fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol);
fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol);
}
} }
return fd; return fd;
} }
@Nullable @Nullable
private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) { private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) {
FileDescriptor fd FileDescriptor fd =
= immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
if (fd == null) { if (fd == null) {
synchronized (lock) { fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
}
} }
return fd; return fd;
} }
@ -405,28 +407,26 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private Set<Integer> getExtensionNumbersOfType(String type) { private Set<Integer> getExtensionNumbersOfType(String type) {
Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type); Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type);
if (extensionNumbers == null) { if (extensionNumbers == null) {
synchronized (lock) { extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type);
extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type);
}
} }
return extensionNumbers; return extensionNumbers;
} }
} }
/** /**
* Provides a set of methods for answering reflection queries for the file descriptors * Provides a set of methods for answering reflection queries for the file descriptors underlying
* underlying a set of services. Used by {@link ServerReflectionIndex} to separately index * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and
* immutable and mutable services. * mutable services.
*/ */
private static final class FileDescriptorIndex { private static final class FileDescriptorIndex {
private final Set<String> serviceNames = new HashSet<String>(); private final Set<String> serviceNames = new HashSet<String>();
private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<FileDescriptor>(); private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<FileDescriptor>();
private final Map<String, FileDescriptor> fileDescriptorsByName private final Map<String, FileDescriptor> fileDescriptorsByName =
= new HashMap<String, FileDescriptor>(); new HashMap<String, FileDescriptor>();
private final Map<String, FileDescriptor> fileDescriptorsBySymbol private final Map<String, FileDescriptor> fileDescriptorsBySymbol =
= new HashMap<String, FileDescriptor>(); new HashMap<String, FileDescriptor>();
private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber =
= new HashMap<String, Map<Integer, FileDescriptor>>(); new HashMap<String, Map<Integer, FileDescriptor>>();
FileDescriptorIndex(List<ServerServiceDefinition> services) { FileDescriptorIndex(List<ServerServiceDefinition> services) {
Queue<FileDescriptor> fileDescriptorsToProcess = new LinkedList<FileDescriptor>(); Queue<FileDescriptor> fileDescriptorsToProcess = new LinkedList<FileDescriptor>();
@ -438,8 +438,8 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
.getFileDescriptor(); .getFileDescriptor();
String serviceName = serviceDescriptor.getName(); String serviceName = serviceDescriptor.getName();
checkState(!serviceNames.contains(serviceName), checkState(
"Service already defined: %s", serviceName); !serviceNames.contains(serviceName), "Service already defined: %s", serviceName);
serviceFileDescriptors.add(fileDescriptor); serviceFileDescriptors.add(fileDescriptor);
serviceNames.add(serviceName); serviceNames.add(serviceName);
if (!seenFiles.contains(fileDescriptor.getName())) { if (!seenFiles.contains(fileDescriptor.getName())) {
@ -501,8 +501,7 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private void processFileDescriptor(FileDescriptor fd) { private void processFileDescriptor(FileDescriptor fd) {
String fdName = fd.getName(); String fdName = fd.getName();
checkState(!fileDescriptorsByName.containsKey(fdName), checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName);
"File name already used: %s", fdName);
fileDescriptorsByName.put(fdName, fd); fileDescriptorsByName.put(fdName, fd);
for (ServiceDescriptor service : fd.getServices()) { for (ServiceDescriptor service : fd.getServices()) {
processService(service, fd); processService(service, fd);
@ -517,21 +516,25 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
private void processService(ServiceDescriptor service, FileDescriptor fd) { private void processService(ServiceDescriptor service, FileDescriptor fd) {
String serviceName = service.getFullName(); String serviceName = service.getFullName();
checkState(!fileDescriptorsBySymbol.containsKey(serviceName), checkState(
"Service already defined: %s", serviceName); !fileDescriptorsBySymbol.containsKey(serviceName),
"Service already defined: %s",
serviceName);
fileDescriptorsBySymbol.put(serviceName, fd); fileDescriptorsBySymbol.put(serviceName, fd);
for (MethodDescriptor method : service.getMethods()) { for (MethodDescriptor method : service.getMethods()) {
String methodName = method.getFullName(); String methodName = method.getFullName();
checkState(!fileDescriptorsBySymbol.containsKey(methodName), checkState(
"Method already defined: %s", methodName); !fileDescriptorsBySymbol.containsKey(methodName),
"Method already defined: %s",
methodName);
fileDescriptorsBySymbol.put(methodName, fd); fileDescriptorsBySymbol.put(methodName, fd);
} }
} }
private void processType(Descriptor type, FileDescriptor fd) { private void processType(Descriptor type, FileDescriptor fd) {
String typeName = type.getFullName(); String typeName = type.getFullName();
checkState(!fileDescriptorsBySymbol.containsKey(typeName), checkState(
"Type already defined: %s", typeName); !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName);
fileDescriptorsBySymbol.put(typeName, fd); fileDescriptorsBySymbol.put(typeName, fd);
for (FieldDescriptor extension : type.getExtensions()) { for (FieldDescriptor extension : type.getExtensions()) {
processExtension(extension, fd); processExtension(extension, fd);
@ -550,7 +553,9 @@ public final class ProtoReflectionService extends ServerReflectionGrpc.ServerRef
} }
checkState( checkState(
!fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber), !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber),
"Extension name and number already defined: %s, %s", extensionName, extensionNumber); "Extension name and number already defined: %s, %s",
extensionName,
extensionNumber);
fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd); fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd);
} }
} }

View File

@ -74,9 +74,7 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
/** /** Tests for {@link ProtoReflectionService}. */
* Tests for {@link ProtoReflectionService}.
*/
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class ProtoReflectionServiceTest { public class ProtoReflectionServiceTest {
private static final String TEST_HOST = "localhost"; private static final String TEST_HOST = "localhost";
@ -92,17 +90,16 @@ public class ProtoReflectionServiceTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
reflectionService = ProtoReflectionService.getInstance(); reflectionService = ProtoReflectionService.newInstance();
server = InProcessServerBuilder.forName("proto-reflection-test") server =
.directExecutor() InProcessServerBuilder.forName("proto-reflection-test")
.addService(reflectionService) .directExecutor()
.addService(new ReflectableServiceGrpc.ReflectableServiceImplBase() {}) .addService(reflectionService)
.fallbackHandlerRegistry(handlerRegistry) .addService(new ReflectableServiceGrpc.ReflectableServiceImplBase() {})
.build() .fallbackHandlerRegistry(handlerRegistry)
.start(); .build()
channel = InProcessChannelBuilder.forName("proto-reflection-test") .start();
.directExecutor() channel = InProcessChannelBuilder.forName("proto-reflection-test").directExecutor().build();
.build();
stub = ServerReflectionGrpc.newStub(channel); stub = ServerReflectionGrpc.newStub(channel);
} }
@ -118,61 +115,61 @@ public class ProtoReflectionServiceTest {
@Test @Test
public void listServices() throws Exception { public void listServices() throws Exception {
Set<ServiceResponse> originalServices = new HashSet<ServiceResponse>( Set<ServiceResponse> originalServices =
Arrays.asList( new HashSet<ServiceResponse>(
ServiceResponse.newBuilder() Arrays.asList(
.setName("grpc.reflection.v1alpha.ServerReflection") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.v1alpha.ServerReflection")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.ReflectableService") ServiceResponse.newBuilder()
.build()) .setName("grpc.reflection.testing.ReflectableService")
); .build()));
assertServiceResponseEquals(originalServices); assertServiceResponseEquals(originalServices);
handlerRegistry.addService(dynamicService); handlerRegistry.addService(dynamicService);
assertServiceResponseEquals(new HashSet<ServiceResponse>( assertServiceResponseEquals(
Arrays.asList( new HashSet<ServiceResponse>(
ServiceResponse.newBuilder() Arrays.asList(
.setName("grpc.reflection.v1alpha.ServerReflection") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.v1alpha.ServerReflection")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.ReflectableService") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.testing.ReflectableService")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.DynamicService") ServiceResponse.newBuilder()
.build()) .setName("grpc.reflection.testing.DynamicService")
)); .build())));
handlerRegistry.addService(anotherDynamicService); handlerRegistry.addService(anotherDynamicService);
assertServiceResponseEquals(new HashSet<ServiceResponse>( assertServiceResponseEquals(
Arrays.asList( new HashSet<ServiceResponse>(
ServiceResponse.newBuilder() Arrays.asList(
.setName("grpc.reflection.v1alpha.ServerReflection") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.v1alpha.ServerReflection")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.ReflectableService") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.testing.ReflectableService")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.DynamicService") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.testing.DynamicService")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.AnotherDynamicService") ServiceResponse.newBuilder()
.build()) .setName("grpc.reflection.testing.AnotherDynamicService")
)); .build())));
handlerRegistry.removeService(dynamicService); handlerRegistry.removeService(dynamicService);
assertServiceResponseEquals(new HashSet<ServiceResponse>( assertServiceResponseEquals(
Arrays.asList( new HashSet<ServiceResponse>(
ServiceResponse.newBuilder() Arrays.asList(
.setName("grpc.reflection.v1alpha.ServerReflection") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.v1alpha.ServerReflection")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.ReflectableService") ServiceResponse.newBuilder()
.build(), .setName("grpc.reflection.testing.ReflectableService")
ServiceResponse.newBuilder() .build(),
.setName("grpc.reflection.testing.AnotherDynamicService") ServiceResponse.newBuilder()
.build()) .setName("grpc.reflection.testing.AnotherDynamicService")
)); .build())));
handlerRegistry.removeService(anotherDynamicService); handlerRegistry.removeService(anotherDynamicService);
assertServiceResponseEquals(originalServices); assertServiceResponseEquals(originalServices);
@ -207,7 +204,7 @@ public class ProtoReflectionServiceTest {
} }
@Test @Test
public void fileByFilenameForMutableServices() throws Exception { public void fileByFilenameConsistentForMutableServices() throws Exception {
ServerReflectionRequest request = ServerReflectionRequest request =
ServerReflectionRequest.newBuilder() ServerReflectionRequest.newBuilder()
.setHost(TEST_HOST) .setHost(TEST_HOST)
@ -229,13 +226,26 @@ public class ProtoReflectionServiceTest {
stub.serverReflectionInfo(responseObserver); stub.serverReflectionInfo(responseObserver);
handlerRegistry.addService(dynamicService); handlerRegistry.addService(dynamicService);
requestObserver.onNext(request); requestObserver.onNext(request);
handlerRegistry.removeService(dynamicService);
requestObserver.onNext(request);
requestObserver.onCompleted(); requestObserver.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver2 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver2 =
stub.serverReflectionInfo(responseObserver2);
handlerRegistry.removeService(dynamicService);
requestObserver2.onNext(request);
requestObserver2.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver3 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver3 =
stub.serverReflectionInfo(responseObserver3);
requestObserver3.onNext(request);
requestObserver3.onCompleted();
assertEquals(goldenResponse, responseObserver.getValues().get(0)); assertEquals(
assertEquals(ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver.getValues().get(1).getMessageResponseCase()); responseObserver.firstValue().get().getMessageResponseCase());
assertEquals(goldenResponse, responseObserver2.firstValue().get());
assertEquals(
ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver3.firstValue().get().getMessageResponseCase());
} }
@Test @Test
@ -251,8 +261,7 @@ public class ProtoReflectionServiceTest {
ReflectionTestProto.getDescriptor().toProto().toByteString(), ReflectionTestProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(), ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(), ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString() ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString());
);
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create(); StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver = StreamObserver<ServerReflectionRequest> requestObserver =
@ -261,8 +270,11 @@ public class ProtoReflectionServiceTest {
requestObserver.onCompleted(); requestObserver.onCompleted();
List<ByteString> response = List<ByteString> response =
responseObserver.firstValue().get() responseObserver
.getFileDescriptorResponse().getFileDescriptorProtoList(); .firstValue()
.get()
.getFileDescriptorResponse()
.getFileDescriptorProtoList();
assertEquals(goldenResponse.size(), response.size()); assertEquals(goldenResponse.size(), response.size());
assertEquals(new HashSet<ByteString>(goldenResponse), new HashSet<ByteString>(response)); assertEquals(new HashSet<ByteString>(goldenResponse), new HashSet<ByteString>(response));
} }
@ -317,13 +329,26 @@ public class ProtoReflectionServiceTest {
stub.serverReflectionInfo(responseObserver); stub.serverReflectionInfo(responseObserver);
handlerRegistry.addService(dynamicService); handlerRegistry.addService(dynamicService);
requestObserver.onNext(request); requestObserver.onNext(request);
handlerRegistry.removeService(dynamicService);
requestObserver.onNext(request);
requestObserver.onCompleted(); requestObserver.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver2 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver2 =
stub.serverReflectionInfo(responseObserver2);
handlerRegistry.removeService(dynamicService);
requestObserver2.onNext(request);
requestObserver2.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver3 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver3 =
stub.serverReflectionInfo(responseObserver3);
requestObserver3.onNext(request);
requestObserver3.onCompleted();
assertEquals(goldenResponse, responseObserver.getValues().get(0)); assertEquals(
assertEquals(ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver.getValues().get(1).getMessageResponseCase()); responseObserver.firstValue().get().getMessageResponseCase());
assertEquals(goldenResponse, responseObserver2.firstValue().get());
assertEquals(
ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver3.firstValue().get().getMessageResponseCase());
} }
@Test @Test
@ -343,8 +368,7 @@ public class ProtoReflectionServiceTest {
ReflectionTestProto.getDescriptor().toProto().toByteString(), ReflectionTestProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(), ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(), ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(),
ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString() ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString());
);
StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create(); StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver = StreamObserver<ServerReflectionRequest> requestObserver =
@ -353,8 +377,11 @@ public class ProtoReflectionServiceTest {
requestObserver.onCompleted(); requestObserver.onCompleted();
List<ByteString> response = List<ByteString> response =
responseObserver.firstValue().get() responseObserver
.getFileDescriptorResponse().getFileDescriptorProtoList(); .firstValue()
.get()
.getFileDescriptorResponse()
.getFileDescriptorProtoList();
assertEquals(goldenResponse.size(), response.size()); assertEquals(goldenResponse.size(), response.size());
assertEquals(new HashSet<ByteString>(goldenResponse), new HashSet<ByteString>(response)); assertEquals(new HashSet<ByteString>(goldenResponse), new HashSet<ByteString>(response));
} }
@ -419,13 +446,26 @@ public class ProtoReflectionServiceTest {
stub.serverReflectionInfo(responseObserver); stub.serverReflectionInfo(responseObserver);
handlerRegistry.addService(dynamicService); handlerRegistry.addService(dynamicService);
requestObserver.onNext(request); requestObserver.onNext(request);
handlerRegistry.removeService(dynamicService);
requestObserver.onNext(request);
requestObserver.onCompleted(); requestObserver.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver2 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver2 =
stub.serverReflectionInfo(responseObserver2);
handlerRegistry.removeService(dynamicService);
requestObserver2.onNext(request);
requestObserver2.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver3 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver3 =
stub.serverReflectionInfo(responseObserver3);
requestObserver3.onNext(request);
requestObserver3.onCompleted();
assertEquals(goldenResponse, responseObserver.getValues().get(0)); assertEquals(
assertEquals(ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver.getValues().get(1).getMessageResponseCase()); responseObserver.firstValue().get().getMessageResponseCase());
assertEquals(goldenResponse, responseObserver2.firstValue().get());
assertEquals(
ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver3.firstValue().get().getMessageResponseCase());
} }
@Test @Test
@ -477,13 +517,26 @@ public class ProtoReflectionServiceTest {
stub.serverReflectionInfo(responseObserver); stub.serverReflectionInfo(responseObserver);
handlerRegistry.addService(dynamicService); handlerRegistry.addService(dynamicService);
requestObserver.onNext(request); requestObserver.onNext(request);
handlerRegistry.removeService(dynamicService);
requestObserver.onNext(request);
requestObserver.onCompleted(); requestObserver.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver2 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver2 =
stub.serverReflectionInfo(responseObserver2);
handlerRegistry.removeService(dynamicService);
requestObserver2.onNext(request);
requestObserver2.onCompleted();
StreamRecorder<ServerReflectionResponse> responseObserver3 = StreamRecorder.create();
StreamObserver<ServerReflectionRequest> requestObserver3 =
stub.serverReflectionInfo(responseObserver3);
requestObserver3.onNext(request);
requestObserver3.onCompleted();
assertEquals(goldenResponse, responseObserver.getValues().get(0)); assertEquals(
assertEquals(ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver.getValues().get(1).getMessageResponseCase()); responseObserver.firstValue().get().getMessageResponseCase());
assertEquals(goldenResponse, responseObserver2.firstValue().get());
assertEquals(
ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE,
responseObserver3.firstValue().get().getMessageResponseCase());
} }
@Test @Test
@ -548,15 +601,14 @@ public class ProtoReflectionServiceTest {
.build()) .build())
.build(); .build();
private static class FlowControlClientResponseObserver implements private static class FlowControlClientResponseObserver
ClientResponseObserver<ServerReflectionRequest, ServerReflectionResponse> { implements ClientResponseObserver<ServerReflectionRequest, ServerReflectionResponse> {
private final List<ServerReflectionResponse> responses = private final List<ServerReflectionResponse> responses =
new ArrayList<ServerReflectionResponse>(); new ArrayList<ServerReflectionResponse>();
private boolean onCompleteCalled = false; private boolean onCompleteCalled = false;
@Override @Override
public void beforeStart( public void beforeStart(final ClientCallStreamObserver<ServerReflectionRequest> requestStream) {
final ClientCallStreamObserver<ServerReflectionRequest> requestStream) {
requestStream.disableAutoInboundFlowControl(); requestStream.disableAutoInboundFlowControl();
} }