mirror of https://github.com/grpc/grpc-java.git
xds: fix the race condition in SslContextProviderSupplier's updateSslContext and close (#8294)
This commit is contained in:
parent
3965315039
commit
629748da61
|
|
@ -17,7 +17,6 @@
|
|||
package io.grpc.xds.internal.sds;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.MoreObjects;
|
||||
|
|
@ -56,13 +55,14 @@ public final class SslContextProviderSupplier implements Closeable {
|
|||
public synchronized void updateSslContext(final SslContextProvider.Callback callback) {
|
||||
checkNotNull(callback, "callback");
|
||||
try {
|
||||
checkState(!shutdown, "Supplier is shutdown!");
|
||||
if (!shutdown) {
|
||||
if (sslContextProvider == null) {
|
||||
sslContextProvider = getSslContextProvider();
|
||||
}
|
||||
}
|
||||
// we want to increment the ref-count so call findOrCreate again...
|
||||
final SslContextProvider toRelease = getSslContextProvider();
|
||||
sslContextProvider.addCallback(
|
||||
toRelease.addCallback(
|
||||
new SslContextProvider.Callback(callback.getExecutor()) {
|
||||
|
||||
@Override
|
||||
|
|
@ -115,6 +115,7 @@ public final class SslContextProviderSupplier implements Closeable {
|
|||
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
|
||||
}
|
||||
}
|
||||
sslContextProvider = null;
|
||||
shutdown = true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,16 +23,13 @@ import static org.mockito.Mockito.doReturn;
|
|||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
import com.google.common.util.concurrent.MoreExecutors;
|
||||
import io.grpc.xds.EnvoyServerProtoData;
|
||||
import io.grpc.xds.TlsContextManager;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import java.util.concurrent.Executor;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -106,7 +103,9 @@ public class SslContextProviderSupplierTest {
|
|||
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
|
||||
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
|
||||
assertThat(capturedCallback).isNotNull();
|
||||
capturedCallback.onException(new Exception("test"));
|
||||
Exception exception = new Exception("test");
|
||||
capturedCallback.onException(exception);
|
||||
verify(mockCallback, times(1)).onException(eq(exception));
|
||||
verify(mockTlsContextManager, times(1))
|
||||
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
||||
}
|
||||
|
|
@ -118,20 +117,11 @@ public class SslContextProviderSupplierTest {
|
|||
supplier.close();
|
||||
verify(mockTlsContextManager, times(1))
|
||||
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
||||
SslContextProvider.Callback mockCallback = spy(
|
||||
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
|
||||
@Override
|
||||
public void updateSecret(SslContext sslContext) {
|
||||
Assert.fail("unexpected call");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onException(Throwable argument) {
|
||||
assertThat(argument).isInstanceOf(IllegalStateException.class);
|
||||
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
|
||||
}
|
||||
});
|
||||
supplier.updateSslContext(mockCallback);
|
||||
verify(mockTlsContextManager, times(3))
|
||||
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
||||
verify(mockTlsContextManager, times(1))
|
||||
.releaseClientSslContextProvider(any(SslContextProvider.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -142,19 +132,8 @@ public class SslContextProviderSupplierTest {
|
|||
supplier.close();
|
||||
verify(mockTlsContextManager, never())
|
||||
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
||||
SslContextProvider.Callback mockCallback = spy(
|
||||
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
|
||||
@Override
|
||||
public void updateSecret(SslContext sslContext) {
|
||||
Assert.fail("unexpected call");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onException(Throwable argument) {
|
||||
assertThat(argument).isInstanceOf(IllegalStateException.class);
|
||||
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
|
||||
}
|
||||
});
|
||||
supplier.updateSslContext(mockCallback);
|
||||
callUpdateSslContext();
|
||||
verify(mockTlsContextManager, times(1))
|
||||
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue