mirror of https://github.com/grpc/grpc-java.git
s2a: Add S2AStub cleanup handler. (#11600)
* Add S2AStub cleanup handler. * Give TLS and Cleanup handlers name + update comment. * Don't add TLS handler twice. * Don't remove explicitly, since done by fireProtocolNegotiationEvent. * plumb S2AStub close to handshake end + add integration test. * close stub when TLS negotiation fails.
This commit is contained in:
parent
2129078dee
commit
d628396ec7
|
|
@ -24,6 +24,7 @@ import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
|
||||||
import io.netty.channel.ChannelHandler;
|
import io.netty.channel.ChannelHandler;
|
||||||
import io.netty.handler.ssl.SslContext;
|
import io.netty.handler.ssl.SslContext;
|
||||||
import io.netty.util.AsciiString;
|
import io.netty.util.AsciiString;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -40,9 +41,10 @@ public final class InternalProtocolNegotiators {
|
||||||
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
||||||
*/
|
*/
|
||||||
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
|
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
|
||||||
ObjectPool<? extends Executor> executorPool) {
|
ObjectPool<? extends Executor> executorPool,
|
||||||
|
Optional<Runnable> handshakeCompleteRunnable) {
|
||||||
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
|
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
|
||||||
executorPool);
|
executorPool, handshakeCompleteRunnable);
|
||||||
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
|
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -70,7 +72,7 @@ public final class InternalProtocolNegotiators {
|
||||||
* may happen immediately, even before the TLS Handshake is complete.
|
* may happen immediately, even before the TLS Handshake is complete.
|
||||||
*/
|
*/
|
||||||
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
|
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
|
||||||
return tls(sslContext, null);
|
return tls(sslContext, null, Optional.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -167,7 +169,8 @@ public final class InternalProtocolNegotiators {
|
||||||
public static ChannelHandler clientTlsHandler(
|
public static ChannelHandler clientTlsHandler(
|
||||||
ChannelHandler next, SslContext sslContext, String authority,
|
ChannelHandler next, SslContext sslContext, String authority,
|
||||||
ChannelLogger negotiationLogger) {
|
ChannelLogger negotiationLogger) {
|
||||||
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
|
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
|
||||||
|
Optional.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class ProtocolNegotiationHandler
|
public static class ProtocolNegotiationHandler
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,7 @@ import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
@ -604,7 +605,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh
|
||||||
case PLAINTEXT_UPGRADE:
|
case PLAINTEXT_UPGRADE:
|
||||||
return ProtocolNegotiators.plaintextUpgrade();
|
return ProtocolNegotiators.plaintextUpgrade();
|
||||||
case TLS:
|
case TLS:
|
||||||
return ProtocolNegotiators.tls(sslContext, executorPool);
|
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
|
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ import java.net.URI;
|
||||||
import java.nio.channels.ClosedChannelException;
|
import java.nio.channels.ClosedChannelException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.EnumSet;
|
import java.util.EnumSet;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
import java.util.logging.Level;
|
import java.util.logging.Level;
|
||||||
|
|
@ -543,16 +544,18 @@ final class ProtocolNegotiators {
|
||||||
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
|
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
|
||||||
|
|
||||||
public ClientTlsProtocolNegotiator(SslContext sslContext,
|
public ClientTlsProtocolNegotiator(SslContext sslContext,
|
||||||
ObjectPool<? extends Executor> executorPool) {
|
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
|
||||||
this.sslContext = checkNotNull(sslContext, "sslContext");
|
this.sslContext = checkNotNull(sslContext, "sslContext");
|
||||||
this.executorPool = executorPool;
|
this.executorPool = executorPool;
|
||||||
if (this.executorPool != null) {
|
if (this.executorPool != null) {
|
||||||
this.executor = this.executorPool.getObject();
|
this.executor = this.executorPool.getObject();
|
||||||
}
|
}
|
||||||
|
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
|
||||||
}
|
}
|
||||||
|
|
||||||
private final SslContext sslContext;
|
private final SslContext sslContext;
|
||||||
private final ObjectPool<? extends Executor> executorPool;
|
private final ObjectPool<? extends Executor> executorPool;
|
||||||
|
private final Optional<Runnable> handshakeCompleteRunnable;
|
||||||
private Executor executor;
|
private Executor executor;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -565,7 +568,7 @@ final class ProtocolNegotiators {
|
||||||
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
|
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
|
||||||
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
|
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
|
||||||
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
|
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
|
||||||
this.executor, negotiationLogger);
|
this.executor, negotiationLogger, handshakeCompleteRunnable);
|
||||||
return new WaitUntilActiveHandler(cth, negotiationLogger);
|
return new WaitUntilActiveHandler(cth, negotiationLogger);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -583,15 +586,18 @@ final class ProtocolNegotiators {
|
||||||
private final String host;
|
private final String host;
|
||||||
private final int port;
|
private final int port;
|
||||||
private Executor executor;
|
private Executor executor;
|
||||||
|
private final Optional<Runnable> handshakeCompleteRunnable;
|
||||||
|
|
||||||
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
|
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
|
||||||
Executor executor, ChannelLogger negotiationLogger) {
|
Executor executor, ChannelLogger negotiationLogger,
|
||||||
|
Optional<Runnable> handshakeCompleteRunnable) {
|
||||||
super(next, negotiationLogger);
|
super(next, negotiationLogger);
|
||||||
this.sslContext = checkNotNull(sslContext, "sslContext");
|
this.sslContext = checkNotNull(sslContext, "sslContext");
|
||||||
HostPort hostPort = parseAuthority(authority);
|
HostPort hostPort = parseAuthority(authority);
|
||||||
this.host = hostPort.host;
|
this.host = hostPort.host;
|
||||||
this.port = hostPort.port;
|
this.port = hostPort.port;
|
||||||
this.executor = executor;
|
this.executor = executor;
|
||||||
|
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -620,6 +626,9 @@ final class ProtocolNegotiators {
|
||||||
Exception ex =
|
Exception ex =
|
||||||
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
|
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
|
||||||
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
|
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
|
||||||
|
if (handshakeCompleteRunnable.isPresent()) {
|
||||||
|
handshakeCompleteRunnable.get().run();
|
||||||
|
}
|
||||||
ctx.fireExceptionCaught(ex);
|
ctx.fireExceptionCaught(ex);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -634,6 +643,9 @@ final class ProtocolNegotiators {
|
||||||
.withCause(t)
|
.withCause(t)
|
||||||
.asRuntimeException();
|
.asRuntimeException();
|
||||||
}
|
}
|
||||||
|
if (handshakeCompleteRunnable.isPresent()) {
|
||||||
|
handshakeCompleteRunnable.get().run();
|
||||||
|
}
|
||||||
ctx.fireExceptionCaught(t);
|
ctx.fireExceptionCaught(t);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -649,6 +661,9 @@ final class ProtocolNegotiators {
|
||||||
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
|
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
|
||||||
.build();
|
.build();
|
||||||
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
|
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
|
||||||
|
if (handshakeCompleteRunnable.isPresent()) {
|
||||||
|
handshakeCompleteRunnable.get().run();
|
||||||
|
}
|
||||||
fireProtocolNegotiationEvent(ctx);
|
fireProtocolNegotiationEvent(ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -683,8 +698,8 @@ final class ProtocolNegotiators {
|
||||||
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
||||||
*/
|
*/
|
||||||
public static ProtocolNegotiator tls(SslContext sslContext,
|
public static ProtocolNegotiator tls(SslContext sslContext,
|
||||||
ObjectPool<? extends Executor> executorPool) {
|
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
|
||||||
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
|
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -693,7 +708,7 @@ final class ProtocolNegotiators {
|
||||||
* may happen immediately, even before the TLS Handshake is complete.
|
* may happen immediately, even before the TLS Handshake is complete.
|
||||||
*/
|
*/
|
||||||
public static ProtocolNegotiator tls(SslContext sslContext) {
|
public static ProtocolNegotiator tls(SslContext sslContext) {
|
||||||
return tls(sslContext, null);
|
return tls(sslContext, null, Optional.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
|
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
|
||||||
|
|
|
||||||
|
|
@ -105,6 +105,7 @@ import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.LinkedBlockingQueue;
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
@ -766,7 +767,8 @@ public class NettyClientTransportTest {
|
||||||
.trustManager(caCert)
|
.trustManager(caCert)
|
||||||
.keyManager(clientCert, clientKey)
|
.keyManager(clientCert, clientKey)
|
||||||
.build();
|
.build();
|
||||||
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
|
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
|
||||||
|
Optional.empty());
|
||||||
// after starting the client, the Executor in the client pool should be used
|
// after starting the client, the Executor in the client pool should be used
|
||||||
assertEquals(true, clientExecutorPool.isInUse());
|
assertEquals(true, clientExecutorPool.isInUse());
|
||||||
final NettyClientTransport transport = newTransport(negotiator);
|
final NettyClientTransport transport = newTransport(negotiator);
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,7 @@ import java.util.ArrayDeque;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.Queue;
|
import java.util.Queue;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
@ -876,7 +877,7 @@ public class ProtocolNegotiatorsTest {
|
||||||
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
||||||
|
|
||||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||||
"authority", elg, noopLogger);
|
"authority", elg, noopLogger, Optional.empty());
|
||||||
pipeline.addLast(handler);
|
pipeline.addLast(handler);
|
||||||
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
||||||
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
||||||
|
|
@ -914,7 +915,7 @@ public class ProtocolNegotiatorsTest {
|
||||||
.applicationProtocolConfig(apn).build();
|
.applicationProtocolConfig(apn).build();
|
||||||
|
|
||||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||||
"authority", elg, noopLogger);
|
"authority", elg, noopLogger, Optional.empty());
|
||||||
pipeline.addLast(handler);
|
pipeline.addLast(handler);
|
||||||
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
||||||
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
||||||
|
|
@ -938,7 +939,7 @@ public class ProtocolNegotiatorsTest {
|
||||||
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
||||||
|
|
||||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||||
"authority", elg, noopLogger);
|
"authority", elg, noopLogger, Optional.empty());
|
||||||
pipeline.addLast(handler);
|
pipeline.addLast(handler);
|
||||||
|
|
||||||
final AtomicReference<Throwable> error = new AtomicReference<>();
|
final AtomicReference<Throwable> error = new AtomicReference<>();
|
||||||
|
|
@ -966,7 +967,7 @@ public class ProtocolNegotiatorsTest {
|
||||||
@Test
|
@Test
|
||||||
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
|
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
|
||||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||||
"authority", null, noopLogger);
|
"authority", null, noopLogger, Optional.empty());
|
||||||
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
|
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
|
||||||
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
|
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
|
||||||
|
|
||||||
|
|
@ -1228,7 +1229,8 @@ public class ProtocolNegotiatorsTest {
|
||||||
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
|
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
|
||||||
}
|
}
|
||||||
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
|
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
|
||||||
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
|
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
|
||||||
|
null, Optional.empty());
|
||||||
WriteBufferingAndExceptionHandler clientWbaeh =
|
WriteBufferingAndExceptionHandler clientWbaeh =
|
||||||
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
|
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ import io.grpc.netty.InternalProtocolNegotiator;
|
||||||
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
|
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
|
||||||
import io.grpc.s2a.internal.handshaker.S2AIdentity;
|
import io.grpc.s2a.internal.handshaker.S2AIdentity;
|
||||||
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
|
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
|
||||||
|
import io.grpc.s2a.internal.handshaker.S2AStub;
|
||||||
import javax.annotation.concurrent.NotThreadSafe;
|
import javax.annotation.concurrent.NotThreadSafe;
|
||||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||||
|
|
||||||
|
|
@ -59,6 +60,7 @@ public final class S2AChannelCredentials {
|
||||||
private final String s2aAddress;
|
private final String s2aAddress;
|
||||||
private final ChannelCredentials s2aChannelCredentials;
|
private final ChannelCredentials s2aChannelCredentials;
|
||||||
private @Nullable S2AIdentity localIdentity = null;
|
private @Nullable S2AIdentity localIdentity = null;
|
||||||
|
private @Nullable S2AStub stub = null;
|
||||||
|
|
||||||
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
|
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
|
||||||
this.s2aAddress = s2aAddress;
|
this.s2aAddress = s2aAddress;
|
||||||
|
|
@ -104,6 +106,16 @@ public final class S2AChannelCredentials {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the stub to use to communicate with S2A. This is only used for testing that the
|
||||||
|
* stream to S2A gets closed.
|
||||||
|
*/
|
||||||
|
public Builder setStub(S2AStub stub) {
|
||||||
|
checkNotNull(stub);
|
||||||
|
this.stub = stub;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public ChannelCredentials build() {
|
public ChannelCredentials build() {
|
||||||
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
|
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
|
||||||
}
|
}
|
||||||
|
|
@ -113,7 +125,7 @@ public final class S2AChannelCredentials {
|
||||||
SharedResourcePool.forResource(
|
SharedResourcePool.forResource(
|
||||||
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
|
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
|
||||||
checkNotNull(s2aChannelPool, "s2aChannelPool");
|
checkNotNull(s2aChannelPool, "s2aChannelPool");
|
||||||
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
|
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -63,28 +63,35 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
* @param localIdentity the identity of the client; if none is provided, the S2A will use the
|
* @param localIdentity the identity of the client; if none is provided, the S2A will use the
|
||||||
* client's default identity.
|
* client's default identity.
|
||||||
* @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A.
|
* @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A.
|
||||||
|
* @param stub the stub to use to communicate with S2A. If none is provided the channelPool
|
||||||
|
* will be used to create the stub. This is exposed for verifying the stream to S2A gets
|
||||||
|
* closed in tests.
|
||||||
* @return a factory for creating a client-side protocol negotiator.
|
* @return a factory for creating a client-side protocol negotiator.
|
||||||
*/
|
*/
|
||||||
public static InternalProtocolNegotiator.ClientFactory createClientFactory(
|
public static InternalProtocolNegotiator.ClientFactory createClientFactory(
|
||||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool) {
|
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool,
|
||||||
|
@Nullable S2AStub stub) {
|
||||||
checkNotNull(s2aChannelPool, "S2A channel pool should not be null.");
|
checkNotNull(s2aChannelPool, "S2A channel pool should not be null.");
|
||||||
return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool);
|
return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub);
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class S2AClientProtocolNegotiatorFactory
|
static final class S2AClientProtocolNegotiatorFactory
|
||||||
implements InternalProtocolNegotiator.ClientFactory {
|
implements InternalProtocolNegotiator.ClientFactory {
|
||||||
private final @Nullable S2AIdentity localIdentity;
|
private final @Nullable S2AIdentity localIdentity;
|
||||||
private final ObjectPool<Channel> channelPool;
|
private final ObjectPool<Channel> channelPool;
|
||||||
|
private final @Nullable S2AStub stub;
|
||||||
|
|
||||||
S2AClientProtocolNegotiatorFactory(
|
S2AClientProtocolNegotiatorFactory(
|
||||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool) {
|
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool,
|
||||||
|
@Nullable S2AStub stub) {
|
||||||
this.localIdentity = localIdentity;
|
this.localIdentity = localIdentity;
|
||||||
this.channelPool = channelPool;
|
this.channelPool = channelPool;
|
||||||
|
this.stub = stub;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ProtocolNegotiator newNegotiator() {
|
public ProtocolNegotiator newNegotiator() {
|
||||||
return S2AProtocolNegotiator.createForClient(channelPool, localIdentity);
|
return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -98,18 +105,20 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
static final class S2AProtocolNegotiator implements ProtocolNegotiator {
|
static final class S2AProtocolNegotiator implements ProtocolNegotiator {
|
||||||
|
|
||||||
private final ObjectPool<Channel> channelPool;
|
private final ObjectPool<Channel> channelPool;
|
||||||
private final Channel channel;
|
private @Nullable Channel channel = null;
|
||||||
private final Optional<S2AIdentity> localIdentity;
|
private final Optional<S2AIdentity> localIdentity;
|
||||||
|
private final @Nullable S2AStub stub;
|
||||||
private final ListeningExecutorService service =
|
private final ListeningExecutorService service =
|
||||||
MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
|
MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
|
||||||
|
|
||||||
static S2AProtocolNegotiator createForClient(
|
static S2AProtocolNegotiator createForClient(
|
||||||
ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity) {
|
ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity,
|
||||||
|
@Nullable S2AStub stub) {
|
||||||
checkNotNull(channelPool, "Channel pool should not be null.");
|
checkNotNull(channelPool, "Channel pool should not be null.");
|
||||||
if (localIdentity == null) {
|
if (localIdentity == null) {
|
||||||
return new S2AProtocolNegotiator(channelPool, Optional.empty());
|
return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub);
|
||||||
} else {
|
} else {
|
||||||
return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity));
|
return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,10 +131,13 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
private S2AProtocolNegotiator(ObjectPool<Channel> channelPool,
|
private S2AProtocolNegotiator(ObjectPool<Channel> channelPool,
|
||||||
Optional<S2AIdentity> localIdentity) {
|
Optional<S2AIdentity> localIdentity, @Nullable S2AStub stub) {
|
||||||
this.channelPool = channelPool;
|
this.channelPool = channelPool;
|
||||||
this.localIdentity = localIdentity;
|
this.localIdentity = localIdentity;
|
||||||
this.channel = channelPool.getObject();
|
this.stub = stub;
|
||||||
|
if (this.stub == null) {
|
||||||
|
this.channel = channelPool.getObject();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -139,13 +151,15 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
String hostname = getHostNameFromAuthority(grpcHandler.getAuthority());
|
String hostname = getHostNameFromAuthority(grpcHandler.getAuthority());
|
||||||
checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty.");
|
checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty.");
|
||||||
return new S2AProtocolNegotiationHandler(
|
return new S2AProtocolNegotiationHandler(
|
||||||
grpcHandler, channel, localIdentity, hostname, service);
|
grpcHandler, channel, localIdentity, hostname, service, stub);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() {
|
public void close() {
|
||||||
service.shutdown();
|
service.shutdown();
|
||||||
channelPool.returnObject(channel);
|
if (channel != null) {
|
||||||
|
channelPool.returnObject(channel);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -180,18 +194,20 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler {
|
private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler {
|
||||||
private final Channel channel;
|
private final @Nullable Channel channel;
|
||||||
private final Optional<S2AIdentity> localIdentity;
|
private final Optional<S2AIdentity> localIdentity;
|
||||||
private final String hostname;
|
private final String hostname;
|
||||||
private final GrpcHttp2ConnectionHandler grpcHandler;
|
private final GrpcHttp2ConnectionHandler grpcHandler;
|
||||||
private final ListeningExecutorService service;
|
private final ListeningExecutorService service;
|
||||||
|
private final @Nullable S2AStub stub;
|
||||||
|
|
||||||
private S2AProtocolNegotiationHandler(
|
private S2AProtocolNegotiationHandler(
|
||||||
GrpcHttp2ConnectionHandler grpcHandler,
|
GrpcHttp2ConnectionHandler grpcHandler,
|
||||||
Channel channel,
|
Channel channel,
|
||||||
Optional<S2AIdentity> localIdentity,
|
Optional<S2AIdentity> localIdentity,
|
||||||
String hostname,
|
String hostname,
|
||||||
ListeningExecutorService service) {
|
ListeningExecutorService service,
|
||||||
|
@Nullable S2AStub stub) {
|
||||||
super(
|
super(
|
||||||
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
|
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
|
||||||
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
|
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
|
||||||
|
|
@ -209,6 +225,7 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
this.hostname = hostname;
|
this.hostname = hostname;
|
||||||
checkNotNull(service, "service should not be null.");
|
checkNotNull(service, "service should not be null.");
|
||||||
this.service = service;
|
this.service = service;
|
||||||
|
this.stub = stub;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
@ -217,8 +234,13 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
BufferReadsHandler bufferReads = new BufferReadsHandler();
|
BufferReadsHandler bufferReads = new BufferReadsHandler();
|
||||||
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads);
|
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads);
|
||||||
|
|
||||||
S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(channel);
|
S2AStub s2aStub;
|
||||||
S2AStub s2aStub = S2AStub.newInstance(stub);
|
if (this.stub == null) {
|
||||||
|
checkNotNull(channel, "Channel to S2A should not be null");
|
||||||
|
s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel));
|
||||||
|
} else {
|
||||||
|
s2aStub = this.stub;
|
||||||
|
}
|
||||||
|
|
||||||
ListenableFuture<SslContext> sslContextFuture =
|
ListenableFuture<SslContext> sslContextFuture =
|
||||||
service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity));
|
service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity));
|
||||||
|
|
@ -230,11 +252,17 @@ public final class S2AProtocolNegotiatorFactory {
|
||||||
ChannelHandler handler =
|
ChannelHandler handler =
|
||||||
InternalProtocolNegotiators.tls(
|
InternalProtocolNegotiators.tls(
|
||||||
sslContext,
|
sslContext,
|
||||||
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR))
|
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR),
|
||||||
|
Optional.of(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
s2aStub.close();
|
||||||
|
}
|
||||||
|
}))
|
||||||
.newHandler(grpcHandler);
|
.newHandler(grpcHandler);
|
||||||
|
|
||||||
// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
|
// Delegate the rest of the handshake to the TLS handler. and remove the
|
||||||
// handler.
|
// bufferReads handler.
|
||||||
ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler);
|
ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler);
|
||||||
fireProtocolNegotiationEvent(ctx);
|
fireProtocolNegotiationEvent(ctx);
|
||||||
ctx.pipeline().remove(bufferReads);
|
ctx.pipeline().remove(bufferReads);
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ import javax.annotation.concurrent.NotThreadSafe;
|
||||||
|
|
||||||
/** Reads and writes messages to and from the S2A. */
|
/** Reads and writes messages to and from the S2A. */
|
||||||
@NotThreadSafe
|
@NotThreadSafe
|
||||||
class S2AStub implements AutoCloseable {
|
public class S2AStub implements AutoCloseable {
|
||||||
private static final Logger logger = Logger.getLogger(S2AStub.class.getName());
|
private static final Logger logger = Logger.getLogger(S2AStub.class.getName());
|
||||||
private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20;
|
private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20;
|
||||||
private final StreamObserver<SessionResp> reader = new Reader();
|
private final StreamObserver<SessionResp> reader = new Reader();
|
||||||
|
|
@ -42,6 +42,7 @@ class S2AStub implements AutoCloseable {
|
||||||
private StreamObserver<SessionReq> writer;
|
private StreamObserver<SessionReq> writer;
|
||||||
private boolean doneReading = false;
|
private boolean doneReading = false;
|
||||||
private boolean doneWriting = false;
|
private boolean doneWriting = false;
|
||||||
|
private boolean isClosed = false;
|
||||||
|
|
||||||
static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) {
|
static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) {
|
||||||
checkNotNull(serviceStub);
|
checkNotNull(serviceStub);
|
||||||
|
|
@ -136,6 +137,11 @@ class S2AStub implements AutoCloseable {
|
||||||
if (writer != null) {
|
if (writer != null) {
|
||||||
writer.onCompleted();
|
writer.onCompleted();
|
||||||
}
|
}
|
||||||
|
isClosed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isClosed() {
|
||||||
|
return isClosed;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new writer if the writer is null. */
|
/** Create a new writer if the writer is null. */
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ package io.grpc.s2a.internal.handshaker;
|
||||||
import static com.google.common.truth.Truth.assertThat;
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||||
|
|
||||||
|
import io.grpc.Channel;
|
||||||
import io.grpc.ChannelCredentials;
|
import io.grpc.ChannelCredentials;
|
||||||
import io.grpc.Grpc;
|
import io.grpc.Grpc;
|
||||||
import io.grpc.InsecureChannelCredentials;
|
import io.grpc.InsecureChannelCredentials;
|
||||||
|
|
@ -29,9 +30,12 @@ import io.grpc.ServerCredentials;
|
||||||
import io.grpc.TlsChannelCredentials;
|
import io.grpc.TlsChannelCredentials;
|
||||||
import io.grpc.TlsServerCredentials;
|
import io.grpc.TlsServerCredentials;
|
||||||
import io.grpc.benchmarks.Utils;
|
import io.grpc.benchmarks.Utils;
|
||||||
|
import io.grpc.internal.ObjectPool;
|
||||||
|
import io.grpc.internal.SharedResourcePool;
|
||||||
import io.grpc.netty.GrpcSslContexts;
|
import io.grpc.netty.GrpcSslContexts;
|
||||||
import io.grpc.netty.NettyServerBuilder;
|
import io.grpc.netty.NettyServerBuilder;
|
||||||
import io.grpc.s2a.S2AChannelCredentials;
|
import io.grpc.s2a.S2AChannelCredentials;
|
||||||
|
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
|
||||||
import io.grpc.s2a.internal.handshaker.FakeS2AServer;
|
import io.grpc.s2a.internal.handshaker.FakeS2AServer;
|
||||||
import io.grpc.stub.StreamObserver;
|
import io.grpc.stub.StreamObserver;
|
||||||
import io.grpc.testing.protobuf.SimpleRequest;
|
import io.grpc.testing.protobuf.SimpleRequest;
|
||||||
|
|
@ -141,6 +145,25 @@ public final class IntegrationTest {
|
||||||
assertThat(doUnaryRpc(channel)).isTrue();
|
assertThat(doUnaryRpc(channel)).isTrue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed()
|
||||||
|
throws Exception {
|
||||||
|
ObjectPool<Channel> s2aChannelPool =
|
||||||
|
SharedResourcePool.forResource(
|
||||||
|
S2AHandshakerServiceChannel.getChannelResource(s2aAddress,
|
||||||
|
InsecureChannelCredentials.create()));
|
||||||
|
Channel ch = s2aChannelPool.getObject();
|
||||||
|
S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch));
|
||||||
|
ChannelCredentials credentials =
|
||||||
|
S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create())
|
||||||
|
.setLocalSpiffeId("test-spiffe-id").setStub(stub).build();
|
||||||
|
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();
|
||||||
|
|
||||||
|
s2aChannelPool.returnObject(ch);
|
||||||
|
assertThat(doUnaryRpc(channel)).isTrue();
|
||||||
|
assertThat(stub.isClosed()).isTrue();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception {
|
public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception {
|
||||||
String privateKeyPath = "src/test/resources/client_key.pem";
|
String privateKeyPath = "src/test/resources/client_key.pem";
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
||||||
@Test
|
@Test
|
||||||
public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception {
|
public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception {
|
||||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||||
|
|
||||||
assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT);
|
assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT);
|
||||||
}
|
}
|
||||||
|
|
@ -146,7 +146,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
||||||
public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds()
|
public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds()
|
||||||
throws Exception {
|
throws Exception {
|
||||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||||
|
|
||||||
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
||||||
|
|
||||||
|
|
@ -158,7 +158,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
||||||
public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide()
|
public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide()
|
||||||
throws Exception {
|
throws Exception {
|
||||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||||
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
||||||
|
|
||||||
clientNegotiator.close();
|
clientNegotiator.close();
|
||||||
|
|
@ -170,7 +170,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
||||||
public void createChannelHandler_addHandlerToMockContext() throws Exception {
|
public void createChannelHandler_addHandlerToMockContext() throws Exception {
|
||||||
ProtocolNegotiator clientNegotiator =
|
ProtocolNegotiator clientNegotiator =
|
||||||
S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient(
|
S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient(
|
||||||
channelPool, LOCAL_IDENTITY);
|
channelPool, LOCAL_IDENTITY, null);
|
||||||
|
|
||||||
ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler);
|
ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue