core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized (#7813)

Improve the CallCredentialsApplyingTransport shutdown lifecycle management. Right now CallCredentialsApplyingTransport shutdown the delegated real transport too early. It should be waiting for the metadataAppliers to finish because they may execute asynchronously. In addition, there is no shutdown check on CallCredentialsApplyingTransport for newStream(). The degraded lifecycle implementation may cause RejectionExecutionException, or accepting new RPCs after the underlying transport is already closed during channel shutdown.

We added listener on metadataApplier to notify completion, a magic counter to track the pending metadataApplier for delaying shutdown, also added shutdown check for newStream().
This commit is contained in:
yifeizhuang 2021-01-26 12:01:16 -08:00 committed by GitHub
parent dbd903c018
commit ac2ead70b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 229 additions and 3 deletions

View File

@ -29,9 +29,12 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;
final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
@ -66,6 +69,21 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
// Negative value means transport active, non-negative value indicates shutdown invoked.
private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1);
private volatile Status shutdownStatus;
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;
private final MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};
CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
@ -89,7 +107,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
}
if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
if (pendingApplier.incrementAndGet() > 0) {
applierListener.onComplete();
return new FailingClientStream(shutdownStatus);
}
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
@ -123,8 +145,69 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
}
return applier.returnStream();
} else {
if (pendingApplier.get() >= 0) {
return new FailingClientStream(shutdownStatus);
}
return delegate.newStream(method, headers, callOptions);
}
}
@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
return;
}
}
super.shutdown(status);
}
// TODO(zivy): cancel pending applier here.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else if (savedShutdownNowStatus != null) {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
// TODO(zivy): propagate shutdownNow to the delegate immediately.
return;
}
}
super.shutdownNow(status);
}
private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
if (pendingApplier.get() != 0) {
return;
}
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
if (maybeShutdown != null) {
super.shutdown(maybeShutdown);
}
if (maybeShutdownNow != null) {
super.shutdownNow(maybeShutdownNow);
}
}
}
}

View File

@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;
private final Object lock = new Object();
@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {
MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}
@Override
@ -84,14 +86,19 @@ final class MetadataApplierImpl extends MetadataApplier {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
@ -100,6 +107,7 @@ final class MetadataApplierImpl extends MetadataApplier {
// TODO(ejona): run this on a separate thread
slow.run();
}
listener.onComplete();
}
/**
@ -116,4 +124,11 @@ final class MetadataApplierImpl extends MetadataApplier {
}
}
}
public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}

View File

@ -19,6 +19,7 @@ package io.grpc.internal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
@ -203,6 +204,10 @@ public class CallCredentials2ApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -227,6 +232,10 @@ public class CallCredentials2ApplyingTest {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -249,6 +258,10 @@ public class CallCredentials2ApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}
@Test
@ -263,6 +276,9 @@ public class CallCredentials2ApplyingTest {
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
@ -271,6 +287,9 @@ public class CallCredentials2ApplyingTest {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -290,6 +309,10 @@ public class CallCredentials2ApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -301,5 +324,9 @@ public class CallCredentials2ApplyingTest {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}

View File

@ -19,12 +19,14 @@ package io.grpc.internal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -179,6 +181,11 @@ public class CallCredentialsApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -192,6 +199,10 @@ public class CallCredentialsApplyingTest {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -214,6 +225,10 @@ public class CallCredentialsApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}
@Test
@ -228,6 +243,11 @@ public class CallCredentialsApplyingTest {
same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
@ -236,6 +256,79 @@ public class CallCredentialsApplyingTest {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
public void delayedShutdown_shutdownShutdownNowThenApply() {
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(any(Status.class));
verify(mockTransport, never()).shutdownNow(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdownNow(Status.ABORTED);
}
@Test
public void delayedShutdown_shutdownThenApplyThenShutdownNow() {
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
verify(mockTransport).shutdownNow(Status.ABORTED);
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
verify(mockTransport, times(2)).shutdownNow(Status.ABORTED);
}
@Test
public void delayedShutdown_shutdownMulti() {
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
transport.newStream(method, origHeaders, callOptions);
transport.newStream(method, origHeaders, callOptions);
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
applierCaptor.getAllValues().get(1).apply(headers);
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);
applierCaptor.getAllValues().get(0).apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);
applierCaptor.getAllValues().get(2).apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -255,6 +348,10 @@ public class CallCredentialsApplyingTest {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test
@ -266,6 +363,10 @@ public class CallCredentialsApplyingTest {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
@Test