From ac2ead70b4e6b77268a0affd2656503e130e3d37 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Tue, 26 Jan 2021 12:01:16 -0800 Subject: [PATCH] core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized (#7813) Improve the CallCredentialsApplyingTransport shutdown lifecycle management. Right now CallCredentialsApplyingTransport shutdown the delegated real transport too early. It should be waiting for the metadataAppliers to finish because they may execute asynchronously. In addition, there is no shutdown check on CallCredentialsApplyingTransport for newStream(). The degraded lifecycle implementation may cause RejectionExecutionException, or accepting new RPCs after the underlying transport is already closed during channel shutdown. We added listener on metadataApplier to notify completion, a magic counter to track the pending metadataApplier for delaying shutdown, also added shutdown check for newStream(). --- ...llCredentialsApplyingTransportFactory.java | 85 ++++++++++++++- .../io/grpc/internal/MetadataApplierImpl.java | 19 +++- .../CallCredentials2ApplyingTest.java | 27 +++++ .../internal/CallCredentialsApplyingTest.java | 101 ++++++++++++++++++ 4 files changed, 229 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index ab2a36d949..de96a3306b 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -29,9 +29,12 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; import io.grpc.Status; +import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener; import java.net.SocketAddress; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; @@ -66,6 +69,21 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport { private final ConnectionClientTransport delegate; private final String authority; + // Negative value means transport active, non-negative value indicates shutdown invoked. + private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1); + private volatile Status shutdownStatus; + @GuardedBy("this") + private Status savedShutdownStatus; + @GuardedBy("this") + private Status savedShutdownNowStatus; + private final MetadataApplierListener applierListener = new MetadataApplierListener() { + @Override + public void onComplete() { + if (pendingApplier.decrementAndGet() == 0) { + maybeShutdown(); + } + } + }; CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) { this.delegate = checkNotNull(delegate, "delegate"); @@ -89,7 +107,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions); + delegate, method, headers, callOptions, applierListener); + if (pendingApplier.incrementAndGet() > 0) { + applierListener.onComplete(); + return new FailingClientStream(shutdownStatus); + } RequestInfo requestInfo = new RequestInfo() { @Override public MethodDescriptor getMethodDescriptor() { @@ -123,8 +145,69 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa } return applier.returnStream(); } else { + if (pendingApplier.get() >= 0) { + return new FailingClientStream(shutdownStatus); + } return delegate.newStream(method, headers, callOptions); } } + + @Override + public void shutdown(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownStatus = status; + return; + } + } + super.shutdown(status); + } + + // TODO(zivy): cancel pending applier here. + @Override + public void shutdownNow(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else if (savedShutdownNowStatus != null) { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownNowStatus = status; + // TODO(zivy): propagate shutdownNow to the delegate immediately. + return; + } + } + super.shutdownNow(status); + } + + private void maybeShutdown() { + Status maybeShutdown; + Status maybeShutdownNow; + synchronized (this) { + if (pendingApplier.get() != 0) { + return; + } + maybeShutdown = savedShutdownStatus; + maybeShutdownNow = savedShutdownNowStatus; + savedShutdownStatus = null; + savedShutdownNowStatus = null; + } + if (maybeShutdown != null) { + super.shutdown(maybeShutdown); + } + if (maybeShutdownNow != null) { + super.shutdownNow(maybeShutdownNow); + } + } } } diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 4c49a14a06..76d280b2d0 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final Metadata origHeaders; private final CallOptions callOptions; private final Context ctx; + private final MetadataApplierListener listener; private final Object lock = new Object(); @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions) { + CallOptions callOptions, MetadataApplierListener listener) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); + this.listener = listener; } @Override @@ -84,14 +86,19 @@ final class MetadataApplierImpl extends MetadataApplier { private void finalizeWith(ClientStream stream) { checkState(!finalized, "already finalized"); finalized = true; + boolean directStream = false; synchronized (lock) { if (returnedStream == null) { // Fast path: returnStream() hasn't been called, the call will use the // real stream directly. returnedStream = stream; - return; + directStream = true; } } + if (directStream) { + listener.onComplete(); + return; + } // returnStream() has been called before me, thus delayedStream must have been // created. checkState(delayedStream != null, "delayedStream is null"); @@ -100,6 +107,7 @@ final class MetadataApplierImpl extends MetadataApplier { // TODO(ejona): run this on a separate thread slow.run(); } + listener.onComplete(); } /** @@ -116,4 +124,11 @@ final class MetadataApplierImpl extends MetadataApplier { } } } + + public interface MetadataApplierListener { + /** + * Notify that the metadata has been successfully applied, or failed. + * */ + void onComplete(); + } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index c26944c16b..7725c46726 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -19,6 +19,7 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -203,6 +204,10 @@ public class CallCredentials2ApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -227,6 +232,10 @@ public class CallCredentials2ApplyingTest { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -249,6 +258,10 @@ public class CallCredentials2ApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -263,6 +276,9 @@ public class CallCredentials2ApplyingTest { any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -271,6 +287,9 @@ public class CallCredentials2ApplyingTest { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -290,6 +309,10 @@ public class CallCredentials2ApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -301,5 +324,9 @@ public class CallCredentials2ApplyingTest { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 6949ab7c31..61a221f73d 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -19,12 +19,14 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -179,6 +181,11 @@ public class CallCredentialsApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -192,6 +199,10 @@ public class CallCredentialsApplyingTest { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -214,6 +225,10 @@ public class CallCredentialsApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -228,6 +243,11 @@ public class CallCredentialsApplyingTest { same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -236,6 +256,79 @@ public class CallCredentialsApplyingTest { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void delayedShutdown_shutdownShutdownNowThenApply() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + verify(mockTransport, never()).shutdownNow(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownThenApplyThenShutdownNow() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport).shutdownNow(Status.ABORTED); + + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport, times(2)).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownMulti() { + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + applierCaptor.getAllValues().get(1).apply(headers); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(0).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(2).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -255,6 +348,10 @@ public class CallCredentialsApplyingTest { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -266,6 +363,10 @@ public class CallCredentialsApplyingTest { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test