xds: add null reference checks in SslContextProviderSupplier (#8169)

This commit is contained in:
sanjaypujare 2021-05-12 10:27:44 -07:00 committed by GitHub
parent e08b9db208
commit e59604b7ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 7 deletions

View File

@ -41,8 +41,8 @@ public final class SslContextProviderSupplier implements Closeable {
public SslContextProviderSupplier(
BaseTlsContext tlsContext, TlsContextManager tlsContextManager) {
this.tlsContext = tlsContext;
this.tlsContextManager = tlsContextManager;
this.tlsContext = checkNotNull(tlsContext, "tlsContext");
this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager");
}
public BaseTlsContext getTlsContext() {
@ -92,11 +92,13 @@ public final class SslContextProviderSupplier implements Closeable {
/** Called by consumer when tlsContext changes. */
@Override
public synchronized void close() {
if (sslContextProvider != null) {
if (tlsContext instanceof UpstreamTlsContext) {
tlsContextManager.releaseClientSslContextProvider(sslContextProvider);
} else {
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
}
}
shutdown = true;
}
}

View File

@ -23,7 +23,9 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -65,16 +67,20 @@ public class SslContextProviderSupplierTest {
doReturn(mockSslContextProvider)
.when(mockTlsContextManager)
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
}
private void callUpdateSslContext() {
mockCallback = mock(SslContextProvider.Callback.class);
Executor mockExecutor = mock(Executor.class);
doReturn(mockExecutor).when(mockCallback).getExecutor();
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
supplier.updateSslContext(mockCallback);
}
@Test
public void get_updateSecret() {
prepareSupplier();
callUpdateSslContext();
verify(mockTlsContextManager, times(2))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
verify(mockTlsContextManager, times(0))
@ -97,6 +103,7 @@ public class SslContextProviderSupplierTest {
@Test
public void get_onException() {
prepareSupplier();
callUpdateSslContext();
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor = ArgumentCaptor.forClass(null);
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
@ -109,6 +116,7 @@ public class SslContextProviderSupplierTest {
@Test
public void testClose() {
prepareSupplier();
callUpdateSslContext();
supplier.close();
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
@ -120,4 +128,21 @@ public class SslContextProviderSupplierTest {
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!");
}
}
@Test
public void testClose_nullSslContextProvider() {
prepareSupplier();
doThrow(new NullPointerException()).when(mockTlsContextManager)
.releaseClientSslContextProvider(null);
supplier.close();
verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
try {
supplier.updateSslContext(mockCallback);
Assert.fail("no exception thrown");
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!");
}
}
}