mirror of https://github.com/grpc/grpc-java.git
xds: add null reference checks in SslContextProviderSupplier (#8169)
This commit is contained in:
parent
e08b9db208
commit
e59604b7ce
|
|
@ -41,8 +41,8 @@ public final class SslContextProviderSupplier implements Closeable {
|
||||||
|
|
||||||
public SslContextProviderSupplier(
|
public SslContextProviderSupplier(
|
||||||
BaseTlsContext tlsContext, TlsContextManager tlsContextManager) {
|
BaseTlsContext tlsContext, TlsContextManager tlsContextManager) {
|
||||||
this.tlsContext = tlsContext;
|
this.tlsContext = checkNotNull(tlsContext, "tlsContext");
|
||||||
this.tlsContextManager = tlsContextManager;
|
this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager");
|
||||||
}
|
}
|
||||||
|
|
||||||
public BaseTlsContext getTlsContext() {
|
public BaseTlsContext getTlsContext() {
|
||||||
|
|
@ -92,11 +92,13 @@ public final class SslContextProviderSupplier implements Closeable {
|
||||||
/** Called by consumer when tlsContext changes. */
|
/** Called by consumer when tlsContext changes. */
|
||||||
@Override
|
@Override
|
||||||
public synchronized void close() {
|
public synchronized void close() {
|
||||||
|
if (sslContextProvider != null) {
|
||||||
if (tlsContext instanceof UpstreamTlsContext) {
|
if (tlsContext instanceof UpstreamTlsContext) {
|
||||||
tlsContextManager.releaseClientSslContextProvider(sslContextProvider);
|
tlsContextManager.releaseClientSslContextProvider(sslContextProvider);
|
||||||
} else {
|
} else {
|
||||||
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
|
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
shutdown = true;
|
shutdown = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,9 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.any;
|
import static org.mockito.Mockito.any;
|
||||||
import static org.mockito.Mockito.doReturn;
|
import static org.mockito.Mockito.doReturn;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
|
@ -65,16 +67,20 @@ public class SslContextProviderSupplierTest {
|
||||||
doReturn(mockSslContextProvider)
|
doReturn(mockSslContextProvider)
|
||||||
.when(mockTlsContextManager)
|
.when(mockTlsContextManager)
|
||||||
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
||||||
|
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void callUpdateSslContext() {
|
||||||
mockCallback = mock(SslContextProvider.Callback.class);
|
mockCallback = mock(SslContextProvider.Callback.class);
|
||||||
Executor mockExecutor = mock(Executor.class);
|
Executor mockExecutor = mock(Executor.class);
|
||||||
doReturn(mockExecutor).when(mockCallback).getExecutor();
|
doReturn(mockExecutor).when(mockCallback).getExecutor();
|
||||||
supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager);
|
|
||||||
supplier.updateSslContext(mockCallback);
|
supplier.updateSslContext(mockCallback);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void get_updateSecret() {
|
public void get_updateSecret() {
|
||||||
prepareSupplier();
|
prepareSupplier();
|
||||||
|
callUpdateSslContext();
|
||||||
verify(mockTlsContextManager, times(2))
|
verify(mockTlsContextManager, times(2))
|
||||||
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
|
||||||
verify(mockTlsContextManager, times(0))
|
verify(mockTlsContextManager, times(0))
|
||||||
|
|
@ -97,6 +103,7 @@ public class SslContextProviderSupplierTest {
|
||||||
@Test
|
@Test
|
||||||
public void get_onException() {
|
public void get_onException() {
|
||||||
prepareSupplier();
|
prepareSupplier();
|
||||||
|
callUpdateSslContext();
|
||||||
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor = ArgumentCaptor.forClass(null);
|
ArgumentCaptor<SslContextProvider.Callback> callbackCaptor = ArgumentCaptor.forClass(null);
|
||||||
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
|
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
|
||||||
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
|
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
|
||||||
|
|
@ -109,6 +116,7 @@ public class SslContextProviderSupplierTest {
|
||||||
@Test
|
@Test
|
||||||
public void testClose() {
|
public void testClose() {
|
||||||
prepareSupplier();
|
prepareSupplier();
|
||||||
|
callUpdateSslContext();
|
||||||
supplier.close();
|
supplier.close();
|
||||||
verify(mockTlsContextManager, times(1))
|
verify(mockTlsContextManager, times(1))
|
||||||
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
||||||
|
|
@ -120,4 +128,21 @@ public class SslContextProviderSupplierTest {
|
||||||
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!");
|
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testClose_nullSslContextProvider() {
|
||||||
|
prepareSupplier();
|
||||||
|
doThrow(new NullPointerException()).when(mockTlsContextManager)
|
||||||
|
.releaseClientSslContextProvider(null);
|
||||||
|
supplier.close();
|
||||||
|
verify(mockTlsContextManager, never())
|
||||||
|
.releaseClientSslContextProvider(eq(mockSslContextProvider));
|
||||||
|
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
|
||||||
|
try {
|
||||||
|
supplier.updateSslContext(mockCallback);
|
||||||
|
Assert.fail("no exception thrown");
|
||||||
|
} catch (IllegalStateException expected) {
|
||||||
|
assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue