diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java index 2efe4b3951..5cd294b4fb 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java @@ -27,6 +27,7 @@ import io.grpc.health.v1.HealthCheckRequest; import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.HashMap; import java.util.IdentityHashMap; @@ -83,6 +84,11 @@ final class HealthServiceImpl extends HealthGrpc.HealthImplBase { final StreamObserver responseObserver) { final String service = request.getService(); synchronized (watchLock) { + if (responseObserver instanceof ServerCallStreamObserver) { + ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> { + removeWatcher(service, responseObserver); + }); + } ServingStatus status = statusMap.get(service); responseObserver.onNext(getResponseForWatch(status)); IdentityHashMap, Boolean> serviceWatchers = @@ -98,21 +104,25 @@ final class HealthServiceImpl extends HealthGrpc.HealthImplBase { @Override // Called when the client has closed the stream public void cancelled(Context context) { - synchronized (watchLock) { - IdentityHashMap, Boolean> serviceWatchers = - watchers.get(service); - if (serviceWatchers != null) { - serviceWatchers.remove(responseObserver); - if (serviceWatchers.isEmpty()) { - watchers.remove(service); - } - } - } + removeWatcher(service, responseObserver); } }, MoreExecutors.directExecutor()); } + void removeWatcher(String service, StreamObserver responseObserver) { + synchronized (watchLock) { + IdentityHashMap, Boolean> serviceWatchers = + watchers.get(service); + if (serviceWatchers != null) { + serviceWatchers.remove(responseObserver); + if (serviceWatchers.isEmpty()) { + watchers.remove(service); + } + } + } + } + void setStatus(String service, ServingStatus status) { synchronized (watchLock) { if (terminal) { diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java index 87d4ac29be..b2652e9277 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java @@ -18,6 +18,11 @@ package io.grpc.protobuf.services; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import io.grpc.BindableService; import io.grpc.Context; @@ -28,6 +33,7 @@ import io.grpc.health.v1.HealthCheckRequest; import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; import java.util.ArrayDeque; @@ -109,6 +115,18 @@ public class HealthStatusManagerTest { assertThat(obs.responses).isEmpty(); } + @Test + @SuppressWarnings("unchecked") + public void serverCallStreamObserver_watch() throws Exception { + manager.setStatus(SERVICE1, ServingStatus.SERVING); + ServerCallStreamObserver observer = mock(ServerCallStreamObserver.class); + service.watch(HealthCheckRequest.newBuilder().setService(SERVICE1).build(), observer); + + verify(observer, times(1)) + .onNext(eq(HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build())); + verify(observer, times(1)).setOnCancelHandler(any(Runnable.class)); + } + @Test public void enterTerminalState_ignoreClear() throws Exception { manager.setStatus(SERVICE1, ServingStatus.SERVING);