xds: extend SslContextProviderSupplier to DowmstreamTlsContext for server side (#8146)

This commit is contained in:
sanjaypujare 2021-05-04 22:19:15 -07:00 committed by GitHub
parent 27b1641653
commit c9e327d42f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 20 deletions

View File

@ -264,7 +264,7 @@ final class ClusterImplLoadBalancer extends LoadBalancer {
private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) { private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) {
UpstreamTlsContext currentTlsContext = UpstreamTlsContext currentTlsContext =
sslContextProviderSupplier != null sslContextProviderSupplier != null
? sslContextProviderSupplier.getUpstreamTlsContext() ? (UpstreamTlsContext)sslContextProviderSupplier.getTlsContext()
: null; : null;
if (Objects.equals(currentTlsContext, tlsContext)) { if (Objects.equals(currentTlsContext, tlsContext)) {
return; return;

View File

@ -19,31 +19,33 @@ package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
/** /**
* Enables the CDS policy to initialize this object with the received {@link UpstreamTlsContext} & * Enables Client or server side to initialize this object with the received {@link BaseTlsContext}
* communicate it to the consumer i.e. {@link SdsProtocolNegotiators.ClientSdsProtocolNegotiator} * and communicate it to the consumer i.e. {@link SdsProtocolNegotiators}
* to lazily evaluate the {@link SslContextProvider}. The supplier prevents credentials leakage in * to lazily evaluate the {@link SslContextProvider}. The supplier prevents credentials leakage in
* cases where the user is not using xDS credentials but the CDS policy contains a non-default * cases where the user is not using xDS credentials but the client/server contains a non-default
* {@link UpstreamTlsContext}. * {@link BaseTlsContext}.
*/ */
public final class SslContextProviderSupplier implements Closeable { public final class SslContextProviderSupplier implements Closeable {
private final UpstreamTlsContext upstreamTlsContext; private final BaseTlsContext tlsContext;
private final TlsContextManager tlsContextManager; private final TlsContextManager tlsContextManager;
private SslContextProvider sslContextProvider; private SslContextProvider sslContextProvider;
private boolean shutdown; private boolean shutdown;
public SslContextProviderSupplier( public SslContextProviderSupplier(
UpstreamTlsContext upstreamTlsContext, TlsContextManager tlsContextManager) { BaseTlsContext tlsContext, TlsContextManager tlsContextManager) {
this.upstreamTlsContext = upstreamTlsContext; this.tlsContext = tlsContext;
this.tlsContextManager = tlsContextManager; this.tlsContextManager = tlsContextManager;
} }
public UpstreamTlsContext getUpstreamTlsContext() { public BaseTlsContext getTlsContext() {
return upstreamTlsContext; return tlsContext;
} }
/** Updates SslContext via the passed callback. */ /** Updates SslContext via the passed callback. */
@ -51,34 +53,48 @@ public final class SslContextProviderSupplier implements Closeable {
checkNotNull(callback, "callback"); checkNotNull(callback, "callback");
checkState(!shutdown, "Supplier is shutdown!"); checkState(!shutdown, "Supplier is shutdown!");
if (sslContextProvider == null) { if (sslContextProvider == null) {
sslContextProvider = sslContextProvider = getSslContextProvider();
tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext);
} }
// we want to increment the ref-count so call findOrCreate again... // we want to increment the ref-count so call findOrCreate again...
final SslContextProvider toRelease = final SslContextProvider toRelease = getSslContextProvider();
tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext);
sslContextProvider.addCallback( sslContextProvider.addCallback(
new SslContextProvider.Callback(callback.getExecutor()) { new SslContextProvider.Callback(callback.getExecutor()) {
@Override @Override
public void updateSecret(SslContext sslContext) { public void updateSecret(SslContext sslContext) {
callback.updateSecret(sslContext); callback.updateSecret(sslContext);
tlsContextManager.releaseClientSslContextProvider(toRelease); releaseSslContextProvider(toRelease);
} }
@Override @Override
public void onException(Throwable throwable) { public void onException(Throwable throwable) {
callback.onException(throwable); callback.onException(throwable);
tlsContextManager.releaseClientSslContextProvider(toRelease); releaseSslContextProvider(toRelease);
} }
}); });
} }
/** Called by {@link io.grpc.xds.CdsLoadBalancer} when upstreamTlsContext changes. */ private void releaseSslContextProvider(SslContextProvider toRelease) {
if (tlsContext instanceof UpstreamTlsContext) {
tlsContextManager.releaseClientSslContextProvider(toRelease);
} else {
tlsContextManager.releaseServerSslContextProvider(toRelease);
}
}
private SslContextProvider getSslContextProvider() {
return tlsContext instanceof UpstreamTlsContext
? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext)
: tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext);
}
/** Called by consumer when tlsContext changes. */
@Override @Override
public synchronized void close() { public synchronized void close() {
if (sslContextProvider != null) { if (tlsContext instanceof UpstreamTlsContext) {
tlsContextManager.releaseClientSslContextProvider(sslContextProvider); tlsContextManager.releaseClientSslContextProvider(sslContextProvider);
} else {
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
} }
shutdown = true; shutdown = true;
} }

View File

@ -520,7 +520,7 @@ public class ClusterImplLoadBalancerTest {
SslContextProviderSupplier supplier = SslContextProviderSupplier supplier =
eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER);
if (enableSecurity) { if (enableSecurity) {
assertThat(supplier.getUpstreamTlsContext()).isEqualTo(upstreamTlsContext); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext);
} else { } else {
assertThat(supplier).isNull(); assertThat(supplier).isNull();
} }
@ -554,7 +554,7 @@ public class ClusterImplLoadBalancerTest {
SslContextProviderSupplier supplier = SslContextProviderSupplier supplier =
eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER);
if (enableSecurity) { if (enableSecurity) {
assertThat(supplier.getUpstreamTlsContext()).isEqualTo(upstreamTlsContext); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext);
} else { } else {
assertThat(supplier).isNull(); assertThat(supplier).isNull();
} }