xds: remove remaining occurences of SDS in the security code (#10219)

This commit is contained in:
sanjaypujare 2023-05-25 10:28:37 -07:00 committed by GitHub
parent e172ea7efc
commit e875d1b01c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 61 deletions

View File

@ -44,7 +44,7 @@ import java.util.logging.Logger;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /**
* Provides client and server side gRPC {@link ProtocolNegotiator}s that use SDS to provide the SSL * Provides client and server side gRPC {@link ProtocolNegotiator}s to provide the SSL
* context. * context.
*/ */
@VisibleForTesting @VisibleForTesting
@ -61,7 +61,7 @@ public final class SecurityProtocolNegotiators {
public static final Attributes.Key<SslContextProviderSupplier> public static final Attributes.Key<SslContextProviderSupplier>
ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER =
Attributes.Key.create("io.grpc.xds.internal.sds.server.sslContextProviderSupplier"); Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier");
/** /**
* Returns a {@link InternalProtocolNegotiator.ClientFactory}. * Returns a {@link InternalProtocolNegotiator.ClientFactory}.
@ -88,7 +88,7 @@ public final class SecurityProtocolNegotiators {
@Override @Override
public ProtocolNegotiator newNegotiator(ObjectPool<? extends Executor> offloadExecutorPool) { public ProtocolNegotiator newNegotiator(ObjectPool<? extends Executor> offloadExecutorPool) {
return new ServerSdsProtocolNegotiator( return new ServerSecurityProtocolNegotiator(
fallbackProtocolNegotiator.newNegotiator(offloadExecutorPool)); fallbackProtocolNegotiator.newNegotiator(offloadExecutorPool));
} }
} }
@ -103,7 +103,7 @@ public final class SecurityProtocolNegotiators {
@Override @Override
public ProtocolNegotiator newNegotiator() { public ProtocolNegotiator newNegotiator() {
return new ClientSdsProtocolNegotiator(fallbackProtocolNegotiator.newNegotiator()); return new ClientSecurityProtocolNegotiator(fallbackProtocolNegotiator.newNegotiator());
} }
@Override @Override
@ -113,11 +113,11 @@ public final class SecurityProtocolNegotiators {
} }
@VisibleForTesting @VisibleForTesting
static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator { static final class ClientSecurityProtocolNegotiator implements ProtocolNegotiator {
@Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator;
ClientSdsProtocolNegotiator(@Nullable ProtocolNegotiator fallbackProtocolNegotiator) { ClientSecurityProtocolNegotiator(@Nullable ProtocolNegotiator fallbackProtocolNegotiator) {
this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; this.fallbackProtocolNegotiator = fallbackProtocolNegotiator;
} }
@ -137,7 +137,7 @@ public final class SecurityProtocolNegotiators {
fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!"); fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!");
return fallbackProtocolNegotiator.newHandler(grpcHandler); return fallbackProtocolNegotiator.newHandler(grpcHandler);
} }
return new ClientSdsHandler(grpcHandler, localSslContextProviderSupplier); return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier);
} }
@Override @Override
@ -176,12 +176,12 @@ public final class SecurityProtocolNegotiators {
} }
@VisibleForTesting @VisibleForTesting
static final class ClientSdsHandler static final class ClientSecurityHandler
extends InternalProtocolNegotiators.ProtocolNegotiationHandler { extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
private final GrpcHttp2ConnectionHandler grpcHandler; private final GrpcHttp2ConnectionHandler grpcHandler;
private final SslContextProviderSupplier sslContextProviderSupplier; private final SslContextProviderSupplier sslContextProviderSupplier;
ClientSdsHandler( ClientSecurityHandler(
GrpcHttp2ConnectionHandler grpcHandler, GrpcHttp2ConnectionHandler grpcHandler,
SslContextProviderSupplier sslContextProviderSupplier) { SslContextProviderSupplier sslContextProviderSupplier) {
super( super(
@ -214,7 +214,7 @@ public final class SecurityProtocolNegotiators {
} }
logger.log( logger.log(
Level.FINEST, Level.FINEST,
"ClientSdsHandler.updateSslContext authority={0}, ctx.name={1}", "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}",
new Object[]{grpcHandler.getAuthority(), ctx.name()}); new Object[]{grpcHandler.getAuthority(), ctx.name()});
ChannelHandler handler = ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
@ -241,13 +241,14 @@ public final class SecurityProtocolNegotiators {
} }
} }
private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator { private static final class ServerSecurityProtocolNegotiator implements ProtocolNegotiator {
@Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator;
/** Constructor. */ /** Constructor. */
@VisibleForTesting @VisibleForTesting
public ServerSdsProtocolNegotiator(@Nullable ProtocolNegotiator fallbackProtocolNegotiator) { public ServerSecurityProtocolNegotiator(
@Nullable ProtocolNegotiator fallbackProtocolNegotiator) {
this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; this.fallbackProtocolNegotiator = fallbackProtocolNegotiator;
} }
@ -306,7 +307,7 @@ public final class SecurityProtocolNegotiators {
.replace( .replace(
this, this,
null, null,
new ServerSdsHandler( new ServerSecurityHandler(
grpcHandler, sslContextProviderSupplier)); grpcHandler, sslContextProviderSupplier));
ctx.fireUserEventTriggered(pne); ctx.fireUserEventTriggered(pne);
return; return;
@ -318,12 +319,12 @@ public final class SecurityProtocolNegotiators {
} }
@VisibleForTesting @VisibleForTesting
static final class ServerSdsHandler static final class ServerSecurityHandler
extends InternalProtocolNegotiators.ProtocolNegotiationHandler { extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
private final GrpcHttp2ConnectionHandler grpcHandler; private final GrpcHttp2ConnectionHandler grpcHandler;
private final SslContextProviderSupplier sslContextProviderSupplier; private final SslContextProviderSupplier sslContextProviderSupplier;
ServerSdsHandler( ServerSecurityHandler(
GrpcHttp2ConnectionHandler grpcHandler, GrpcHttp2ConnectionHandler grpcHandler,
SslContextProviderSupplier sslContextProviderSupplier) { SslContextProviderSupplier sslContextProviderSupplier) {
super( super(

View File

@ -73,7 +73,7 @@ public final class XdsTrustManagerFactory extends SimpleTrustManagerFactory {
certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(),
"only static certificateValidationContext expected"); "only static certificateValidationContext expected");
} }
xdsX509TrustManager = createSdsX509TrustManager(certs, certificateValidationContext); xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext);
} }
private static X509Certificate[] getTrustedCaFromCertContext( private static X509Certificate[] getTrustedCaFromCertContext(
@ -98,7 +98,7 @@ public final class XdsTrustManagerFactory extends SimpleTrustManagerFactory {
} }
@VisibleForTesting @VisibleForTesting
static XdsX509TrustManager createSdsX509TrustManager( static XdsX509TrustManager createX509TrustManager(
X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException {
TrustManagerFactory tmf = null; TrustManagerFactory tmf = null;
try { try {
@ -115,7 +115,7 @@ public final class XdsTrustManagerFactory extends SimpleTrustManagerFactory {
} }
tmf.init(ks); tmf.init(ks);
} catch (NoSuchAlgorithmException | KeyStoreException | IOException | CertificateException e) { } catch (NoSuchAlgorithmException | KeyStoreException | IOException | CertificateException e) {
logger.log(Level.SEVERE, "createSdsX509TrustManager", e); logger.log(Level.SEVERE, "createX509TrustManager", e);
throw new CertStoreException(e); throw new CertStoreException(e);
} }
TrustManager[] tms = tmf.getTrustManagers(); TrustManager[] tms = tmf.getTrustManagers();

View File

@ -88,7 +88,7 @@ import org.junit.runners.JUnit4;
* modes. * modes.
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class XdsSdsClientServerTest { public class XdsSecurityClientServerTest {
@Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private int port; private int port;
@ -356,7 +356,7 @@ public class XdsSdsClientServerTest {
xdsClient.deliverLdsUpdate(listenerUpdate); xdsClient.deliverLdsUpdate(listenerUpdate);
startFuture.get(10, TimeUnit.SECONDS); startFuture.get(10, TimeUnit.SECONDS);
port = xdsServer.getPort(); port = xdsServer.getPort();
URI expectedUri = new URI("sdstest://localhost:" + port); URI expectedUri = new URI("sectest://localhost:" + port);
fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build();
NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory);
} }
@ -400,7 +400,7 @@ public class XdsSdsClientServerTest {
throws URISyntaxException { throws URISyntaxException {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder<?> channelBuilder =
Grpc.newChannelBuilder( Grpc.newChannelBuilder(
"sdstest://localhost:" + port, "sectest://localhost:" + port,
XdsChannelCredentials.create(InsecureChannelCredentials.create())); XdsChannelCredentials.create(InsecureChannelCredentials.create()));
if (overrideAuthority != null) { if (overrideAuthority != null) {
@ -486,7 +486,7 @@ public class XdsSdsClientServerTest {
@Override @Override
public String getDefaultScheme() { public String getDefaultScheme() {
return "sdstest"; return "sectest";
} }
@Override @Override

View File

@ -54,7 +54,7 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
// TODO (zivy@): move certain tests down to XdsServerWrapperTest, or up to XdsSdsClientServerTest. // TODO (zivy@): move certain tests down to XdsServerWrapperTest or to XdsSecurityClientServerTest
/** /**
* Unit tests for {@link XdsServerBuilder}. * Unit tests for {@link XdsServerBuilder}.
*/ */

View File

@ -51,8 +51,8 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.InternalXdsAttributes;
import io.grpc.xds.TlsContextManager; import io.grpc.xds.TlsContextManager;
import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsHandler; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityHandler;
import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsProtocolNegotiator; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityProtocolNegotiator;
import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -95,20 +95,21 @@ public class SecurityProtocolNegotiatorsTest {
private ChannelHandlerContext channelHandlerCtx; private ChannelHandlerContext channelHandlerCtx;
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_noTlsContextAttribute() { public void clientSecurityProtocolNegotiatorNewHandler_noTlsContextAttribute() {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class); ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class); ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler); when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler);
ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(mockProtocolNegotiator); ClientSecurityProtocolNegotiator pn
= new ClientSecurityProtocolNegotiator(mockProtocolNegotiator);
ChannelHandler newHandler = pn.newHandler(grpcHandler); ChannelHandler newHandler = pn.newHandler(grpcHandler);
assertThat(newHandler).isNotNull(); assertThat(newHandler).isNotNull();
assertThat(newHandler).isSameInstanceAs(mockChannelHandler); assertThat(newHandler).isSameInstanceAs(mockChannelHandler);
} }
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_noFallback_expectException() { public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectException() {
ClientSdsProtocolNegotiator pn = ClientSecurityProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(/* fallbackProtocolNegotiator= */ null); new ClientSecurityProtocolNegotiator(/* fallbackProtocolNegotiator= */ null);
try { try {
pn.newHandler(grpcHandler); pn.newHandler(grpcHandler);
fail("exception expected!"); fail("exception expected!");
@ -120,11 +121,11 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_withTlsContextAttribute() { public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() {
UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build()); CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build());
ClientSdsProtocolNegotiator pn = ClientSecurityProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(InternalProtocolNegotiators.plaintext()); new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext());
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
ChannelLogger logger = mock(ChannelLogger.class); ChannelLogger logger = mock(ChannelLogger.class);
doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); doNothing().when(logger).log(any(ChannelLogLevel.class), anyString());
@ -138,11 +139,11 @@ public class SecurityProtocolNegotiatorsTest {
.build()); .build());
ChannelHandler newHandler = pn.newHandler(mockHandler); ChannelHandler newHandler = pn.newHandler(mockHandler);
assertThat(newHandler).isNotNull(); assertThat(newHandler).isNotNull();
assertThat(newHandler).isInstanceOf(ClientSdsHandler.class); assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class);
} }
@Test @Test
public void clientSdsHandler_addLast() public void clientSecurityHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException { throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock(); FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor); CommonCertProviderTestUtils.register(executor);
@ -156,11 +157,11 @@ public class SecurityProtocolNegotiatorsTest {
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext, new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient)); new TlsContextManagerImpl(bootstrapInfoForClient));
SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler = ClientSecurityHandler clientSecurityHandler =
new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler); pipeline.addLast(clientSecurityHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler);
assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it assertNotNull(channelHandlerCtx);
// kick off protocol negotiation. // kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
@ -182,7 +183,7 @@ public class SecurityProtocolNegotiatorsTest {
Object fromFuture = future.get(2, TimeUnit.SECONDS); Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class); assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks(); channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
// pipeline should have SslHandler and ClientTlsHandler // pipeline should have SslHandler and ClientTlsHandler
@ -195,7 +196,7 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void serverSdsHandler_addLast() public void serverSecurityHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException { throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock(); FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor); CommonCertProviderTestUtils.register(executor);
@ -228,7 +229,7 @@ public class SecurityProtocolNegotiatorsTest {
channelHandlerCtx = pipeline.context(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSecurityHandler
ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault();
Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event)
.toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER,
@ -236,7 +237,7 @@ public class SecurityProtocolNegotiatorsTest {
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr));
channelHandlerCtx = pipeline.context(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class);
assertThat(channelHandlerCtx).isNotNull(); assertThat(channelHandlerCtx).isNotNull();
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
@ -259,7 +260,7 @@ public class SecurityProtocolNegotiatorsTest {
Object fromFuture = future.get(2, TimeUnit.SECONDS); Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class); assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks(); channel.runPendingTasks();
channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
// pipeline should only have SslHandler and ServerTlsHandler // pipeline should only have SslHandler and ServerTlsHandler
@ -272,7 +273,7 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void serverSdsHandler_defaultDownstreamTlsContext_expectFallbackProtocolNegotiator() public void serverSecurityHandler_defaultDownstreamTlsContext_expectFallbackProtocolNegotiator()
throws IOException { throws IOException {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class); ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class); ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
@ -294,7 +295,7 @@ public class SecurityProtocolNegotiatorsTest {
channelHandlerCtx = pipeline.context(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSecurityHandler
ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault();
Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event)
.toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, null).build(); .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, null).build();
@ -309,7 +310,7 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { public void serverSecurityHandler_nullTlsContext_expectFallbackProtocolNegotiator() {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class); ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class); ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler); when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler);
@ -354,7 +355,7 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
throws InterruptedException, TimeoutException, ExecutionException { throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock(); FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor); CommonCertProviderTestUtils.register(executor);
@ -368,11 +369,11 @@ public class SecurityProtocolNegotiatorsTest {
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext, new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient)); new TlsContextManagerImpl(bootstrapInfoForClient));
SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler = ClientSecurityHandler clientSecurityHandler =
new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler); pipeline.addLast(clientSecurityHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler);
assertNotNull(channelHandlerCtx); // non-null since we just added it assertNotNull(channelHandlerCtx); // non-null since we just added it
// kick off protocol negotiation. // kick off protocol negotiation.
@ -395,7 +396,7 @@ public class SecurityProtocolNegotiatorsTest {
Object fromFuture = future.get(5, TimeUnit.SECONDS); Object fromFuture = future.get(5, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class); assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks(); channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler);
assertThat(channelHandlerCtx).isNull(); assertThat(channelHandlerCtx).isNull();
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
@ -406,7 +407,7 @@ public class SecurityProtocolNegotiatorsTest {
} }
@Test @Test
public void clientSdsProtocolNegotiatorNewHandler_handleHandlerRemoved() { public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() {
FakeClock executor = new FakeClock(); FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor); CommonCertProviderTestUtils.register(executor);
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
@ -419,17 +420,17 @@ public class SecurityProtocolNegotiatorsTest {
SslContextProviderSupplier sslContextProviderSupplier = SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext, new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient)); new TlsContextManagerImpl(bootstrapInfoForClient));
SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler = ClientSecurityHandler clientSecurityHandler =
new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler); pipeline.addLast(clientSecurityHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler);
// kick off protocol negotiation. // kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
executor.runDueTasks(); executor.runDueTasks();
pipeline.remove(clientSdsHandler); pipeline.remove(clientSecurityHandler);
channel.runPendingTasks(); channel.runPendingTasks();
channel.checkException(); channel.checkException();
CommonCertProviderTestUtils.register0(); CommonCertProviderTestUtils.register0();

View File

@ -629,7 +629,7 @@ public class XdsX509TrustManagerTest {
throws CertificateException, IOException, CertStoreException { throws CertificateException, IOException, CertStoreException {
X509Certificate[] caCerts = X509Certificate[] caCerts =
CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE));
trustManager = XdsTrustManagerFactory.createSdsX509TrustManager(caCerts, trustManager = XdsTrustManagerFactory.createX509TrustManager(caCerts,
null); null);
when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getProtocol()).thenReturn("TLSv1.2");
when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock");