s2a: Add S2AStub cleanup handler. (#11600)

* Add S2AStub cleanup handler.

* Give TLS and Cleanup handlers name + update comment.

* Don't add TLS handler twice.

* Don't remove explicitly, since done by fireProtocolNegotiationEvent.

* plumb S2AStub close to handshake end + add integration test.

* close stub when TLS negotiation fails.
This commit is contained in:
Riya Mehta 2024-10-10 16:31:18 -07:00 committed by GitHub
parent 2129078dee
commit d628396ec7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 134 additions and 42 deletions

View File

@ -24,6 +24,7 @@ import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString; import io.netty.util.AsciiString;
import java.util.Optional;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
/** /**
@ -40,9 +41,10 @@ public final class InternalProtocolNegotiators {
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/ */
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) { ObjectPool<? extends Executor> executorPool,
Optional<Runnable> handshakeCompleteRunnable) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool); executorPool, handshakeCompleteRunnable);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override @Override
@ -70,7 +72,7 @@ public final class InternalProtocolNegotiators {
* may happen immediately, even before the TLS Handshake is complete. * may happen immediately, even before the TLS Handshake is complete.
*/ */
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null); return tls(sslContext, null, Optional.empty());
} }
/** /**
@ -167,7 +169,8 @@ public final class InternalProtocolNegotiators {
public static ChannelHandler clientTlsHandler( public static ChannelHandler clientTlsHandler(
ChannelHandler next, SslContext sslContext, String authority, ChannelHandler next, SslContext sslContext, String authority,
ChannelLogger negotiationLogger) { ChannelLogger negotiationLogger) {
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger); return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
Optional.empty());
} }
public static class ProtocolNegotiationHandler public static class ProtocolNegotiationHandler

View File

@ -63,6 +63,7 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -604,7 +605,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh
case PLAINTEXT_UPGRADE: case PLAINTEXT_UPGRADE:
return ProtocolNegotiators.plaintextUpgrade(); return ProtocolNegotiators.plaintextUpgrade();
case TLS: case TLS:
return ProtocolNegotiators.tls(sslContext, executorPool); return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
default: default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
} }

View File

@ -72,6 +72,7 @@ import java.net.URI;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.Arrays; import java.util.Arrays;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
@ -543,16 +544,18 @@ final class ProtocolNegotiators {
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
public ClientTlsProtocolNegotiator(SslContext sslContext, public ClientTlsProtocolNegotiator(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) { ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
this.sslContext = checkNotNull(sslContext, "sslContext"); this.sslContext = checkNotNull(sslContext, "sslContext");
this.executorPool = executorPool; this.executorPool = executorPool;
if (this.executorPool != null) { if (this.executorPool != null) {
this.executor = this.executorPool.getObject(); this.executor = this.executorPool.getObject();
} }
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
} }
private final SslContext sslContext; private final SslContext sslContext;
private final ObjectPool<? extends Executor> executorPool; private final ObjectPool<? extends Executor> executorPool;
private final Optional<Runnable> handshakeCompleteRunnable;
private Executor executor; private Executor executor;
@Override @Override
@ -565,7 +568,7 @@ final class ProtocolNegotiators {
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
this.executor, negotiationLogger); this.executor, negotiationLogger, handshakeCompleteRunnable);
return new WaitUntilActiveHandler(cth, negotiationLogger); return new WaitUntilActiveHandler(cth, negotiationLogger);
} }
@ -583,15 +586,18 @@ final class ProtocolNegotiators {
private final String host; private final String host;
private final int port; private final int port;
private Executor executor; private Executor executor;
private final Optional<Runnable> handshakeCompleteRunnable;
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
Executor executor, ChannelLogger negotiationLogger) { Executor executor, ChannelLogger negotiationLogger,
Optional<Runnable> handshakeCompleteRunnable) {
super(next, negotiationLogger); super(next, negotiationLogger);
this.sslContext = checkNotNull(sslContext, "sslContext"); this.sslContext = checkNotNull(sslContext, "sslContext");
HostPort hostPort = parseAuthority(authority); HostPort hostPort = parseAuthority(authority);
this.host = hostPort.host; this.host = hostPort.host;
this.port = hostPort.port; this.port = hostPort.port;
this.executor = executor; this.executor = executor;
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
} }
@Override @Override
@ -620,6 +626,9 @@ final class ProtocolNegotiators {
Exception ex = Exception ex =
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol"); unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex); logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();
}
ctx.fireExceptionCaught(ex); ctx.fireExceptionCaught(ex);
} }
} else { } else {
@ -634,6 +643,9 @@ final class ProtocolNegotiators {
.withCause(t) .withCause(t)
.asRuntimeException(); .asRuntimeException();
} }
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();
}
ctx.fireExceptionCaught(t); ctx.fireExceptionCaught(t);
} }
} else { } else {
@ -649,6 +661,9 @@ final class ProtocolNegotiators {
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.build(); .build();
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();
}
fireProtocolNegotiationEvent(ctx); fireProtocolNegotiationEvent(ctx);
} }
} }
@ -683,8 +698,8 @@ final class ProtocolNegotiators {
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/ */
public static ProtocolNegotiator tls(SslContext sslContext, public static ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) { ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool); return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
} }
/** /**
@ -693,7 +708,7 @@ final class ProtocolNegotiators {
* may happen immediately, even before the TLS Handshake is complete. * may happen immediately, even before the TLS Handshake is complete.
*/ */
public static ProtocolNegotiator tls(SslContext sslContext) { public static ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null); return tls(sslContext, null, Optional.empty());
} }
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {

View File

@ -105,6 +105,7 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -766,7 +767,8 @@ public class NettyClientTransportTest {
.trustManager(caCert) .trustManager(caCert)
.keyManager(clientCert, clientKey) .keyManager(clientCert, clientKey)
.build(); .build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
Optional.empty());
// after starting the client, the Executor in the client pool should be used // after starting the client, the Executor in the client pool should be used
assertEquals(true, clientExecutorPool.isInUse()); assertEquals(true, clientExecutorPool.isInUse());
final NettyClientTransport transport = newTransport(negotiator); final NettyClientTransport transport = newTransport(negotiator);

View File

@ -120,6 +120,7 @@ import java.util.ArrayDeque;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -876,7 +877,7 @@ public class ProtocolNegotiatorsTest {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger); "authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler); pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@ -914,7 +915,7 @@ public class ProtocolNegotiatorsTest {
.applicationProtocolConfig(apn).build(); .applicationProtocolConfig(apn).build();
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger); "authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler); pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@ -938,7 +939,7 @@ public class ProtocolNegotiatorsTest {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger); "authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler); pipeline.addLast(handler);
final AtomicReference<Throwable> error = new AtomicReference<>(); final AtomicReference<Throwable> error = new AtomicReference<>();
@ -966,7 +967,7 @@ public class ProtocolNegotiatorsTest {
@Test @Test
public void clientTlsHandler_closeDuringNegotiation() throws Exception { public void clientTlsHandler_closeDuringNegotiation() throws Exception {
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", null, noopLogger); "authority", null, noopLogger, Optional.empty());
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
@ -1228,7 +1229,8 @@ public class ProtocolNegotiatorsTest {
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build(); serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
} }
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
null, Optional.empty());
WriteBufferingAndExceptionHandler clientWbaeh = WriteBufferingAndExceptionHandler clientWbaeh =
new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); new WriteBufferingAndExceptionHandler(pn.newHandler(gh));

View File

@ -31,6 +31,7 @@ import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.internal.handshaker.S2AIdentity; import io.grpc.s2a.internal.handshaker.S2AIdentity;
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory; import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
import io.grpc.s2a.internal.handshaker.S2AStub;
import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.NotThreadSafe;
import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.Nullable;
@ -59,6 +60,7 @@ public final class S2AChannelCredentials {
private final String s2aAddress; private final String s2aAddress;
private final ChannelCredentials s2aChannelCredentials; private final ChannelCredentials s2aChannelCredentials;
private @Nullable S2AIdentity localIdentity = null; private @Nullable S2AIdentity localIdentity = null;
private @Nullable S2AStub stub = null;
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
this.s2aAddress = s2aAddress; this.s2aAddress = s2aAddress;
@ -104,6 +106,16 @@ public final class S2AChannelCredentials {
return this; return this;
} }
/**
* Sets the stub to use to communicate with S2A. This is only used for testing that the
* stream to S2A gets closed.
*/
public Builder setStub(S2AStub stub) {
checkNotNull(stub);
this.stub = stub;
return this;
}
public ChannelCredentials build() { public ChannelCredentials build() {
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
} }
@ -113,7 +125,7 @@ public final class S2AChannelCredentials {
SharedResourcePool.forResource( SharedResourcePool.forResource(
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
checkNotNull(s2aChannelPool, "s2aChannelPool"); checkNotNull(s2aChannelPool, "s2aChannelPool");
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool); return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
} }
} }

View File

@ -63,28 +63,35 @@ public final class S2AProtocolNegotiatorFactory {
* @param localIdentity the identity of the client; if none is provided, the S2A will use the * @param localIdentity the identity of the client; if none is provided, the S2A will use the
* client's default identity. * client's default identity.
* @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A.
* @param stub the stub to use to communicate with S2A. If none is provided the channelPool
* will be used to create the stub. This is exposed for verifying the stream to S2A gets
* closed in tests.
* @return a factory for creating a client-side protocol negotiator. * @return a factory for creating a client-side protocol negotiator.
*/ */
public static InternalProtocolNegotiator.ClientFactory createClientFactory( public static InternalProtocolNegotiator.ClientFactory createClientFactory(
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool) { @Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool,
@Nullable S2AStub stub) {
checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); checkNotNull(s2aChannelPool, "S2A channel pool should not be null.");
return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool); return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub);
} }
static final class S2AClientProtocolNegotiatorFactory static final class S2AClientProtocolNegotiatorFactory
implements InternalProtocolNegotiator.ClientFactory { implements InternalProtocolNegotiator.ClientFactory {
private final @Nullable S2AIdentity localIdentity; private final @Nullable S2AIdentity localIdentity;
private final ObjectPool<Channel> channelPool; private final ObjectPool<Channel> channelPool;
private final @Nullable S2AStub stub;
S2AClientProtocolNegotiatorFactory( S2AClientProtocolNegotiatorFactory(
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool) { @Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool,
@Nullable S2AStub stub) {
this.localIdentity = localIdentity; this.localIdentity = localIdentity;
this.channelPool = channelPool; this.channelPool = channelPool;
this.stub = stub;
} }
@Override @Override
public ProtocolNegotiator newNegotiator() { public ProtocolNegotiator newNegotiator() {
return S2AProtocolNegotiator.createForClient(channelPool, localIdentity); return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub);
} }
@Override @Override
@ -98,18 +105,20 @@ public final class S2AProtocolNegotiatorFactory {
static final class S2AProtocolNegotiator implements ProtocolNegotiator { static final class S2AProtocolNegotiator implements ProtocolNegotiator {
private final ObjectPool<Channel> channelPool; private final ObjectPool<Channel> channelPool;
private final Channel channel; private @Nullable Channel channel = null;
private final Optional<S2AIdentity> localIdentity; private final Optional<S2AIdentity> localIdentity;
private final @Nullable S2AStub stub;
private final ListeningExecutorService service = private final ListeningExecutorService service =
MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
static S2AProtocolNegotiator createForClient( static S2AProtocolNegotiator createForClient(
ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity) { ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity,
@Nullable S2AStub stub) {
checkNotNull(channelPool, "Channel pool should not be null."); checkNotNull(channelPool, "Channel pool should not be null.");
if (localIdentity == null) { if (localIdentity == null) {
return new S2AProtocolNegotiator(channelPool, Optional.empty()); return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub);
} else { } else {
return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity)); return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub);
} }
} }
@ -122,10 +131,13 @@ public final class S2AProtocolNegotiatorFactory {
} }
private S2AProtocolNegotiator(ObjectPool<Channel> channelPool, private S2AProtocolNegotiator(ObjectPool<Channel> channelPool,
Optional<S2AIdentity> localIdentity) { Optional<S2AIdentity> localIdentity, @Nullable S2AStub stub) {
this.channelPool = channelPool; this.channelPool = channelPool;
this.localIdentity = localIdentity; this.localIdentity = localIdentity;
this.channel = channelPool.getObject(); this.stub = stub;
if (this.stub == null) {
this.channel = channelPool.getObject();
}
} }
@Override @Override
@ -139,13 +151,15 @@ public final class S2AProtocolNegotiatorFactory {
String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); String hostname = getHostNameFromAuthority(grpcHandler.getAuthority());
checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty."); checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty.");
return new S2AProtocolNegotiationHandler( return new S2AProtocolNegotiationHandler(
grpcHandler, channel, localIdentity, hostname, service); grpcHandler, channel, localIdentity, hostname, service, stub);
} }
@Override @Override
public void close() { public void close() {
service.shutdown(); service.shutdown();
channelPool.returnObject(channel); if (channel != null) {
channelPool.returnObject(channel);
}
} }
} }
@ -180,18 +194,20 @@ public final class S2AProtocolNegotiatorFactory {
} }
private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler {
private final Channel channel; private final @Nullable Channel channel;
private final Optional<S2AIdentity> localIdentity; private final Optional<S2AIdentity> localIdentity;
private final String hostname; private final String hostname;
private final GrpcHttp2ConnectionHandler grpcHandler; private final GrpcHttp2ConnectionHandler grpcHandler;
private final ListeningExecutorService service; private final ListeningExecutorService service;
private final @Nullable S2AStub stub;
private S2AProtocolNegotiationHandler( private S2AProtocolNegotiationHandler(
GrpcHttp2ConnectionHandler grpcHandler, GrpcHttp2ConnectionHandler grpcHandler,
Channel channel, Channel channel,
Optional<S2AIdentity> localIdentity, Optional<S2AIdentity> localIdentity,
String hostname, String hostname,
ListeningExecutorService service) { ListeningExecutorService service,
@Nullable S2AStub stub) {
super( super(
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
@ -209,6 +225,7 @@ public final class S2AProtocolNegotiatorFactory {
this.hostname = hostname; this.hostname = hostname;
checkNotNull(service, "service should not be null."); checkNotNull(service, "service should not be null.");
this.service = service; this.service = service;
this.stub = stub;
} }
@Override @Override
@ -217,8 +234,13 @@ public final class S2AProtocolNegotiatorFactory {
BufferReadsHandler bufferReads = new BufferReadsHandler(); BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads);
S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(channel); S2AStub s2aStub;
S2AStub s2aStub = S2AStub.newInstance(stub); if (this.stub == null) {
checkNotNull(channel, "Channel to S2A should not be null");
s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel));
} else {
s2aStub = this.stub;
}
ListenableFuture<SslContext> sslContextFuture = ListenableFuture<SslContext> sslContextFuture =
service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity));
@ -230,11 +252,17 @@ public final class S2AProtocolNegotiatorFactory {
ChannelHandler handler = ChannelHandler handler =
InternalProtocolNegotiators.tls( InternalProtocolNegotiators.tls(
sslContext, sslContext,
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR),
Optional.of(new Runnable() {
@Override
public void run() {
s2aStub.close();
}
}))
.newHandler(grpcHandler); .newHandler(grpcHandler);
// Remove the bufferReads handler and delegate the rest of the handshake to the TLS // Delegate the rest of the handshake to the TLS handler. and remove the
// handler. // bufferReads handler.
ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler);
fireProtocolNegotiationEvent(ctx); fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads); ctx.pipeline().remove(bufferReads);

View File

@ -33,7 +33,7 @@ import javax.annotation.concurrent.NotThreadSafe;
/** Reads and writes messages to and from the S2A. */ /** Reads and writes messages to and from the S2A. */
@NotThreadSafe @NotThreadSafe
class S2AStub implements AutoCloseable { public class S2AStub implements AutoCloseable {
private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); private static final Logger logger = Logger.getLogger(S2AStub.class.getName());
private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20;
private final StreamObserver<SessionResp> reader = new Reader(); private final StreamObserver<SessionResp> reader = new Reader();
@ -42,6 +42,7 @@ class S2AStub implements AutoCloseable {
private StreamObserver<SessionReq> writer; private StreamObserver<SessionReq> writer;
private boolean doneReading = false; private boolean doneReading = false;
private boolean doneWriting = false; private boolean doneWriting = false;
private boolean isClosed = false;
static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) {
checkNotNull(serviceStub); checkNotNull(serviceStub);
@ -136,6 +137,11 @@ class S2AStub implements AutoCloseable {
if (writer != null) { if (writer != null) {
writer.onCompleted(); writer.onCompleted();
} }
isClosed = true;
}
public boolean isClosed() {
return isClosed;
} }
/** Create a new writer if the writer is null. */ /** Create a new writer if the writer is null. */

View File

@ -19,6 +19,7 @@ package io.grpc.s2a.internal.handshaker;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.concurrent.TimeUnit.SECONDS;
import io.grpc.Channel;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureChannelCredentials;
@ -29,9 +30,12 @@ import io.grpc.ServerCredentials;
import io.grpc.TlsChannelCredentials; import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials; import io.grpc.TlsServerCredentials;
import io.grpc.benchmarks.Utils; import io.grpc.benchmarks.Utils;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder; import io.grpc.netty.NettyServerBuilder;
import io.grpc.s2a.S2AChannelCredentials; import io.grpc.s2a.S2AChannelCredentials;
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.internal.handshaker.FakeS2AServer; import io.grpc.s2a.internal.handshaker.FakeS2AServer;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleRequest;
@ -141,6 +145,25 @@ public final class IntegrationTest {
assertThat(doUnaryRpc(channel)).isTrue(); assertThat(doUnaryRpc(channel)).isTrue();
} }
@Test
public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed()
throws Exception {
ObjectPool<Channel> s2aChannelPool =
SharedResourcePool.forResource(
S2AHandshakerServiceChannel.getChannelResource(s2aAddress,
InsecureChannelCredentials.create()));
Channel ch = s2aChannelPool.getObject();
S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch));
ChannelCredentials credentials =
S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create())
.setLocalSpiffeId("test-spiffe-id").setStub(stub).build();
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();
s2aChannelPool.returnObject(ch);
assertThat(doUnaryRpc(channel)).isTrue();
assertThat(stub.isClosed()).isTrue();
}
@Test @Test
public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception {
String privateKeyPath = "src/test/resources/client_key.pem"; String privateKeyPath = "src/test/resources/client_key.pem";

View File

@ -122,7 +122,7 @@ public class S2AProtocolNegotiatorFactoryTest {
@Test @Test
public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception {
InternalProtocolNegotiator.ClientFactory clientFactory = InternalProtocolNegotiator.ClientFactory clientFactory =
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT);
} }
@ -146,7 +146,7 @@ public class S2AProtocolNegotiatorFactoryTest {
public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds()
throws Exception { throws Exception {
InternalProtocolNegotiator.ClientFactory clientFactory = InternalProtocolNegotiator.ClientFactory clientFactory =
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
@ -158,7 +158,7 @@ public class S2AProtocolNegotiatorFactoryTest {
public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide()
throws Exception { throws Exception {
InternalProtocolNegotiator.ClientFactory clientFactory = InternalProtocolNegotiator.ClientFactory clientFactory =
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
clientNegotiator.close(); clientNegotiator.close();
@ -170,7 +170,7 @@ public class S2AProtocolNegotiatorFactoryTest {
public void createChannelHandler_addHandlerToMockContext() throws Exception { public void createChannelHandler_addHandlerToMockContext() throws Exception {
ProtocolNegotiator clientNegotiator = ProtocolNegotiator clientNegotiator =
S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient(
channelPool, LOCAL_IDENTITY); channelPool, LOCAL_IDENTITY, null);
ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler);