mirror of https://github.com/grpc/grpc-java.git
s2a: Add S2AStub cleanup handler. (#11600)
* Add S2AStub cleanup handler. * Give TLS and Cleanup handlers name + update comment. * Don't add TLS handler twice. * Don't remove explicitly, since done by fireProtocolNegotiationEvent. * plumb S2AStub close to handshake end + add integration test. * close stub when TLS negotiation fails.
This commit is contained in:
parent
2129078dee
commit
d628396ec7
|
|
@ -24,6 +24,7 @@ import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
|
|||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.netty.util.AsciiString;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Executor;
|
||||
|
||||
/**
|
||||
|
|
@ -40,9 +41,10 @@ public final class InternalProtocolNegotiators {
|
|||
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
||||
*/
|
||||
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
|
||||
ObjectPool<? extends Executor> executorPool) {
|
||||
ObjectPool<? extends Executor> executorPool,
|
||||
Optional<Runnable> handshakeCompleteRunnable) {
|
||||
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
|
||||
executorPool);
|
||||
executorPool, handshakeCompleteRunnable);
|
||||
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
|
||||
|
||||
@Override
|
||||
|
|
@ -70,7 +72,7 @@ public final class InternalProtocolNegotiators {
|
|||
* may happen immediately, even before the TLS Handshake is complete.
|
||||
*/
|
||||
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
|
||||
return tls(sslContext, null);
|
||||
return tls(sslContext, null, Optional.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -167,7 +169,8 @@ public final class InternalProtocolNegotiators {
|
|||
public static ChannelHandler clientTlsHandler(
|
||||
ChannelHandler next, SslContext sslContext, String authority,
|
||||
ChannelLogger negotiationLogger) {
|
||||
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
|
||||
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
|
||||
Optional.empty());
|
||||
}
|
||||
|
||||
public static class ProtocolNegotiationHandler
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ import java.util.Collection;
|
|||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
|
@ -604,7 +605,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh
|
|||
case PLAINTEXT_UPGRADE:
|
||||
return ProtocolNegotiators.plaintextUpgrade();
|
||||
case TLS:
|
||||
return ProtocolNegotiators.tls(sslContext, executorPool);
|
||||
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
|
||||
default:
|
||||
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ import java.net.URI;
|
|||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.logging.Level;
|
||||
|
|
@ -543,16 +544,18 @@ final class ProtocolNegotiators {
|
|||
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
|
||||
|
||||
public ClientTlsProtocolNegotiator(SslContext sslContext,
|
||||
ObjectPool<? extends Executor> executorPool) {
|
||||
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
|
||||
this.sslContext = checkNotNull(sslContext, "sslContext");
|
||||
this.executorPool = executorPool;
|
||||
if (this.executorPool != null) {
|
||||
this.executor = this.executorPool.getObject();
|
||||
}
|
||||
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
|
||||
}
|
||||
|
||||
private final SslContext sslContext;
|
||||
private final ObjectPool<? extends Executor> executorPool;
|
||||
private final Optional<Runnable> handshakeCompleteRunnable;
|
||||
private Executor executor;
|
||||
|
||||
@Override
|
||||
|
|
@ -565,7 +568,7 @@ final class ProtocolNegotiators {
|
|||
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
|
||||
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
|
||||
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
|
||||
this.executor, negotiationLogger);
|
||||
this.executor, negotiationLogger, handshakeCompleteRunnable);
|
||||
return new WaitUntilActiveHandler(cth, negotiationLogger);
|
||||
}
|
||||
|
||||
|
|
@ -583,15 +586,18 @@ final class ProtocolNegotiators {
|
|||
private final String host;
|
||||
private final int port;
|
||||
private Executor executor;
|
||||
private final Optional<Runnable> handshakeCompleteRunnable;
|
||||
|
||||
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
|
||||
Executor executor, ChannelLogger negotiationLogger) {
|
||||
Executor executor, ChannelLogger negotiationLogger,
|
||||
Optional<Runnable> handshakeCompleteRunnable) {
|
||||
super(next, negotiationLogger);
|
||||
this.sslContext = checkNotNull(sslContext, "sslContext");
|
||||
HostPort hostPort = parseAuthority(authority);
|
||||
this.host = hostPort.host;
|
||||
this.port = hostPort.port;
|
||||
this.executor = executor;
|
||||
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -620,6 +626,9 @@ final class ProtocolNegotiators {
|
|||
Exception ex =
|
||||
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
|
||||
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
|
||||
if (handshakeCompleteRunnable.isPresent()) {
|
||||
handshakeCompleteRunnable.get().run();
|
||||
}
|
||||
ctx.fireExceptionCaught(ex);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -634,6 +643,9 @@ final class ProtocolNegotiators {
|
|||
.withCause(t)
|
||||
.asRuntimeException();
|
||||
}
|
||||
if (handshakeCompleteRunnable.isPresent()) {
|
||||
handshakeCompleteRunnable.get().run();
|
||||
}
|
||||
ctx.fireExceptionCaught(t);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -649,6 +661,9 @@ final class ProtocolNegotiators {
|
|||
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
|
||||
.build();
|
||||
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
|
||||
if (handshakeCompleteRunnable.isPresent()) {
|
||||
handshakeCompleteRunnable.get().run();
|
||||
}
|
||||
fireProtocolNegotiationEvent(ctx);
|
||||
}
|
||||
}
|
||||
|
|
@ -683,8 +698,8 @@ final class ProtocolNegotiators {
|
|||
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
|
||||
*/
|
||||
public static ProtocolNegotiator tls(SslContext sslContext,
|
||||
ObjectPool<? extends Executor> executorPool) {
|
||||
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
|
||||
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
|
||||
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -693,7 +708,7 @@ final class ProtocolNegotiators {
|
|||
* may happen immediately, even before the TLS Handshake is complete.
|
||||
*/
|
||||
public static ProtocolNegotiator tls(SslContext sslContext) {
|
||||
return tls(sslContext, null);
|
||||
return tls(sslContext, null, Optional.empty());
|
||||
}
|
||||
|
||||
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
|
@ -766,7 +767,8 @@ public class NettyClientTransportTest {
|
|||
.trustManager(caCert)
|
||||
.keyManager(clientCert, clientKey)
|
||||
.build();
|
||||
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
|
||||
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
|
||||
Optional.empty());
|
||||
// after starting the client, the Executor in the client pool should be used
|
||||
assertEquals(true, clientExecutorPool.isInUse());
|
||||
final NettyClientTransport transport = newTransport(negotiator);
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ import java.util.ArrayDeque;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
|
@ -876,7 +877,7 @@ public class ProtocolNegotiatorsTest {
|
|||
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
||||
|
||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||
"authority", elg, noopLogger);
|
||||
"authority", elg, noopLogger, Optional.empty());
|
||||
pipeline.addLast(handler);
|
||||
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
||||
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
||||
|
|
@ -914,7 +915,7 @@ public class ProtocolNegotiatorsTest {
|
|||
.applicationProtocolConfig(apn).build();
|
||||
|
||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||
"authority", elg, noopLogger);
|
||||
"authority", elg, noopLogger, Optional.empty());
|
||||
pipeline.addLast(handler);
|
||||
pipeline.replace(SslHandler.class, null, goodSslHandler);
|
||||
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
|
||||
|
|
@ -938,7 +939,7 @@ public class ProtocolNegotiatorsTest {
|
|||
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
|
||||
|
||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||
"authority", elg, noopLogger);
|
||||
"authority", elg, noopLogger, Optional.empty());
|
||||
pipeline.addLast(handler);
|
||||
|
||||
final AtomicReference<Throwable> error = new AtomicReference<>();
|
||||
|
|
@ -966,7 +967,7 @@ public class ProtocolNegotiatorsTest {
|
|||
@Test
|
||||
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
|
||||
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
|
||||
"authority", null, noopLogger);
|
||||
"authority", null, noopLogger, Optional.empty());
|
||||
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
|
||||
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
|
||||
|
||||
|
|
@ -1228,7 +1229,8 @@ public class ProtocolNegotiatorsTest {
|
|||
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
|
||||
}
|
||||
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
|
||||
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
|
||||
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
|
||||
null, Optional.empty());
|
||||
WriteBufferingAndExceptionHandler clientWbaeh =
|
||||
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import io.grpc.netty.InternalProtocolNegotiator;
|
|||
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
|
||||
import io.grpc.s2a.internal.handshaker.S2AIdentity;
|
||||
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
|
||||
import io.grpc.s2a.internal.handshaker.S2AStub;
|
||||
import javax.annotation.concurrent.NotThreadSafe;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
|
||||
|
|
@ -59,6 +60,7 @@ public final class S2AChannelCredentials {
|
|||
private final String s2aAddress;
|
||||
private final ChannelCredentials s2aChannelCredentials;
|
||||
private @Nullable S2AIdentity localIdentity = null;
|
||||
private @Nullable S2AStub stub = null;
|
||||
|
||||
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
|
||||
this.s2aAddress = s2aAddress;
|
||||
|
|
@ -104,6 +106,16 @@ public final class S2AChannelCredentials {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the stub to use to communicate with S2A. This is only used for testing that the
|
||||
* stream to S2A gets closed.
|
||||
*/
|
||||
public Builder setStub(S2AStub stub) {
|
||||
checkNotNull(stub);
|
||||
this.stub = stub;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChannelCredentials build() {
|
||||
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
|
||||
}
|
||||
|
|
@ -113,7 +125,7 @@ public final class S2AChannelCredentials {
|
|||
SharedResourcePool.forResource(
|
||||
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
|
||||
checkNotNull(s2aChannelPool, "s2aChannelPool");
|
||||
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
|
||||
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -63,28 +63,35 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
* @param localIdentity the identity of the client; if none is provided, the S2A will use the
|
||||
* client's default identity.
|
||||
* @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A.
|
||||
* @param stub the stub to use to communicate with S2A. If none is provided the channelPool
|
||||
* will be used to create the stub. This is exposed for verifying the stream to S2A gets
|
||||
* closed in tests.
|
||||
* @return a factory for creating a client-side protocol negotiator.
|
||||
*/
|
||||
public static InternalProtocolNegotiator.ClientFactory createClientFactory(
|
||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool) {
|
||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> s2aChannelPool,
|
||||
@Nullable S2AStub stub) {
|
||||
checkNotNull(s2aChannelPool, "S2A channel pool should not be null.");
|
||||
return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool);
|
||||
return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub);
|
||||
}
|
||||
|
||||
static final class S2AClientProtocolNegotiatorFactory
|
||||
implements InternalProtocolNegotiator.ClientFactory {
|
||||
private final @Nullable S2AIdentity localIdentity;
|
||||
private final ObjectPool<Channel> channelPool;
|
||||
private final @Nullable S2AStub stub;
|
||||
|
||||
S2AClientProtocolNegotiatorFactory(
|
||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool) {
|
||||
@Nullable S2AIdentity localIdentity, ObjectPool<Channel> channelPool,
|
||||
@Nullable S2AStub stub) {
|
||||
this.localIdentity = localIdentity;
|
||||
this.channelPool = channelPool;
|
||||
this.stub = stub;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ProtocolNegotiator newNegotiator() {
|
||||
return S2AProtocolNegotiator.createForClient(channelPool, localIdentity);
|
||||
return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -98,18 +105,20 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
static final class S2AProtocolNegotiator implements ProtocolNegotiator {
|
||||
|
||||
private final ObjectPool<Channel> channelPool;
|
||||
private final Channel channel;
|
||||
private @Nullable Channel channel = null;
|
||||
private final Optional<S2AIdentity> localIdentity;
|
||||
private final @Nullable S2AStub stub;
|
||||
private final ListeningExecutorService service =
|
||||
MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
|
||||
|
||||
static S2AProtocolNegotiator createForClient(
|
||||
ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity) {
|
||||
ObjectPool<Channel> channelPool, @Nullable S2AIdentity localIdentity,
|
||||
@Nullable S2AStub stub) {
|
||||
checkNotNull(channelPool, "Channel pool should not be null.");
|
||||
if (localIdentity == null) {
|
||||
return new S2AProtocolNegotiator(channelPool, Optional.empty());
|
||||
return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub);
|
||||
} else {
|
||||
return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity));
|
||||
return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -122,10 +131,13 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
}
|
||||
|
||||
private S2AProtocolNegotiator(ObjectPool<Channel> channelPool,
|
||||
Optional<S2AIdentity> localIdentity) {
|
||||
Optional<S2AIdentity> localIdentity, @Nullable S2AStub stub) {
|
||||
this.channelPool = channelPool;
|
||||
this.localIdentity = localIdentity;
|
||||
this.channel = channelPool.getObject();
|
||||
this.stub = stub;
|
||||
if (this.stub == null) {
|
||||
this.channel = channelPool.getObject();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -139,13 +151,15 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
String hostname = getHostNameFromAuthority(grpcHandler.getAuthority());
|
||||
checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty.");
|
||||
return new S2AProtocolNegotiationHandler(
|
||||
grpcHandler, channel, localIdentity, hostname, service);
|
||||
grpcHandler, channel, localIdentity, hostname, service, stub);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
service.shutdown();
|
||||
channelPool.returnObject(channel);
|
||||
if (channel != null) {
|
||||
channelPool.returnObject(channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -180,18 +194,20 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
}
|
||||
|
||||
private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler {
|
||||
private final Channel channel;
|
||||
private final @Nullable Channel channel;
|
||||
private final Optional<S2AIdentity> localIdentity;
|
||||
private final String hostname;
|
||||
private final GrpcHttp2ConnectionHandler grpcHandler;
|
||||
private final ListeningExecutorService service;
|
||||
private final @Nullable S2AStub stub;
|
||||
|
||||
private S2AProtocolNegotiationHandler(
|
||||
GrpcHttp2ConnectionHandler grpcHandler,
|
||||
Channel channel,
|
||||
Optional<S2AIdentity> localIdentity,
|
||||
String hostname,
|
||||
ListeningExecutorService service) {
|
||||
ListeningExecutorService service,
|
||||
@Nullable S2AStub stub) {
|
||||
super(
|
||||
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
|
||||
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
|
||||
|
|
@ -209,6 +225,7 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
this.hostname = hostname;
|
||||
checkNotNull(service, "service should not be null.");
|
||||
this.service = service;
|
||||
this.stub = stub;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -217,8 +234,13 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
BufferReadsHandler bufferReads = new BufferReadsHandler();
|
||||
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads);
|
||||
|
||||
S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(channel);
|
||||
S2AStub s2aStub = S2AStub.newInstance(stub);
|
||||
S2AStub s2aStub;
|
||||
if (this.stub == null) {
|
||||
checkNotNull(channel, "Channel to S2A should not be null");
|
||||
s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel));
|
||||
} else {
|
||||
s2aStub = this.stub;
|
||||
}
|
||||
|
||||
ListenableFuture<SslContext> sslContextFuture =
|
||||
service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity));
|
||||
|
|
@ -230,11 +252,17 @@ public final class S2AProtocolNegotiatorFactory {
|
|||
ChannelHandler handler =
|
||||
InternalProtocolNegotiators.tls(
|
||||
sslContext,
|
||||
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR))
|
||||
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR),
|
||||
Optional.of(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
s2aStub.close();
|
||||
}
|
||||
}))
|
||||
.newHandler(grpcHandler);
|
||||
|
||||
// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
|
||||
// handler.
|
||||
// Delegate the rest of the handshake to the TLS handler. and remove the
|
||||
// bufferReads handler.
|
||||
ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler);
|
||||
fireProtocolNegotiationEvent(ctx);
|
||||
ctx.pipeline().remove(bufferReads);
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ import javax.annotation.concurrent.NotThreadSafe;
|
|||
|
||||
/** Reads and writes messages to and from the S2A. */
|
||||
@NotThreadSafe
|
||||
class S2AStub implements AutoCloseable {
|
||||
public class S2AStub implements AutoCloseable {
|
||||
private static final Logger logger = Logger.getLogger(S2AStub.class.getName());
|
||||
private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20;
|
||||
private final StreamObserver<SessionResp> reader = new Reader();
|
||||
|
|
@ -42,6 +42,7 @@ class S2AStub implements AutoCloseable {
|
|||
private StreamObserver<SessionReq> writer;
|
||||
private boolean doneReading = false;
|
||||
private boolean doneWriting = false;
|
||||
private boolean isClosed = false;
|
||||
|
||||
static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) {
|
||||
checkNotNull(serviceStub);
|
||||
|
|
@ -136,6 +137,11 @@ class S2AStub implements AutoCloseable {
|
|||
if (writer != null) {
|
||||
writer.onCompleted();
|
||||
}
|
||||
isClosed = true;
|
||||
}
|
||||
|
||||
public boolean isClosed() {
|
||||
return isClosed;
|
||||
}
|
||||
|
||||
/** Create a new writer if the writer is null. */
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ package io.grpc.s2a.internal.handshaker;
|
|||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.ChannelCredentials;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.InsecureChannelCredentials;
|
||||
|
|
@ -29,9 +30,12 @@ import io.grpc.ServerCredentials;
|
|||
import io.grpc.TlsChannelCredentials;
|
||||
import io.grpc.TlsServerCredentials;
|
||||
import io.grpc.benchmarks.Utils;
|
||||
import io.grpc.internal.ObjectPool;
|
||||
import io.grpc.internal.SharedResourcePool;
|
||||
import io.grpc.netty.GrpcSslContexts;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.grpc.s2a.S2AChannelCredentials;
|
||||
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
|
||||
import io.grpc.s2a.internal.handshaker.FakeS2AServer;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import io.grpc.testing.protobuf.SimpleRequest;
|
||||
|
|
@ -141,6 +145,25 @@ public final class IntegrationTest {
|
|||
assertThat(doUnaryRpc(channel)).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed()
|
||||
throws Exception {
|
||||
ObjectPool<Channel> s2aChannelPool =
|
||||
SharedResourcePool.forResource(
|
||||
S2AHandshakerServiceChannel.getChannelResource(s2aAddress,
|
||||
InsecureChannelCredentials.create()));
|
||||
Channel ch = s2aChannelPool.getObject();
|
||||
S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch));
|
||||
ChannelCredentials credentials =
|
||||
S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create())
|
||||
.setLocalSpiffeId("test-spiffe-id").setStub(stub).build();
|
||||
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();
|
||||
|
||||
s2aChannelPool.returnObject(ch);
|
||||
assertThat(doUnaryRpc(channel)).isTrue();
|
||||
assertThat(stub.isClosed()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception {
|
||||
String privateKeyPath = "src/test/resources/client_key.pem";
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
|||
@Test
|
||||
public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception {
|
||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||
|
||||
assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT);
|
||||
}
|
||||
|
|
@ -146,7 +146,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
|||
public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds()
|
||||
throws Exception {
|
||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||
|
||||
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
||||
|
||||
|
|
@ -158,7 +158,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
|||
public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide()
|
||||
throws Exception {
|
||||
InternalProtocolNegotiator.ClientFactory clientFactory =
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool);
|
||||
S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null);
|
||||
ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator();
|
||||
|
||||
clientNegotiator.close();
|
||||
|
|
@ -170,7 +170,7 @@ public class S2AProtocolNegotiatorFactoryTest {
|
|||
public void createChannelHandler_addHandlerToMockContext() throws Exception {
|
||||
ProtocolNegotiator clientNegotiator =
|
||||
S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient(
|
||||
channelPool, LOCAL_IDENTITY);
|
||||
channelPool, LOCAL_IDENTITY, null);
|
||||
|
||||
ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue