From 814e36b54136fca39d86549c6c501ec49319bd0d Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 1 Dec 2020 17:30:03 -0800 Subject: [PATCH] alts: Limit number of concurrent handshakes to 32 --- .../alts/internal/TsiHandshakeHandler.java | 78 ++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java index a4123a7a53..1b4737f3cd 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java @@ -31,12 +31,16 @@ import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.ByteToMessageDecoder; import java.security.GeneralSecurityException; +import java.util.LinkedList; import java.util.List; +import java.util.Queue; import javax.annotation.Nullable; /** @@ -78,12 +82,17 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { } private static final int HANDSHAKE_FRAME_SIZE = 1024; + // Avoid performing too many handshakes in parallel, as it may cause queuing in the handshake + // server and cause unbounded blocking on the event loop (b/168808426). This is a workaround until + // there is an async TSI handshaking API to avoid the blocking. + private static final AsyncSemaphore semaphore = new AsyncSemaphore(32); private final NettyTsiHandshaker handshaker; private final HandshakeValidator handshakeValidator; private final ChannelHandler next; private ProtocolNegotiationEvent pne; + private boolean semaphoreAcquired; /** * Constructs a TsiHandshakeHandler. @@ -137,13 +146,37 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { } @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { checkState(pne == null, "negotiation already started"); pne = (ProtocolNegotiationEvent) evt; InternalProtocolNegotiators.negotiationLogger(ctx) .log(ChannelLogLevel.INFO, "TsiHandshake started"); - sendHandshake(ctx); + ChannelFuture acquire = semaphore.acquire(ctx); + if (acquire.isSuccess()) { + semaphoreAcquired = true; + sendHandshake(ctx); + } else { + acquire.addListener(new ChannelFutureListener() { + @Override public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + ctx.fireExceptionCaught(future.cause()); + return; + } + if (ctx.isRemoved()) { + semaphore.release(); + return; + } + semaphoreAcquired = true; + try { + sendHandshake(ctx); + } catch (Exception ex) { + ctx.fireExceptionCaught(ex); + } + ctx.flush(); + } + }); + } } else { super.userEventTriggered(ctx, evt); } @@ -188,6 +221,45 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { @Override protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + if (semaphoreAcquired) { + semaphore.release(); + semaphoreAcquired = false; + } handshaker.close(); } -} \ No newline at end of file + + private static class AsyncSemaphore { + private final Object lock = new Object(); + @SuppressWarnings("JdkObsolete") // LinkedList avoids high watermark memory issues + private final Queue queue = new LinkedList<>(); + private int permits; + + public AsyncSemaphore(int permits) { + this.permits = permits; + } + + public ChannelFuture acquire(ChannelHandlerContext ctx) { + synchronized (lock) { + if (permits > 0) { + permits--; + return ctx.newSucceededFuture(); + } + ChannelPromise promise = ctx.newPromise(); + queue.add(promise); + return promise; + } + } + + public void release() { + ChannelPromise next; + synchronized (lock) { + next = queue.poll(); + if (next == null) { + permits++; + return; + } + } + next.setSuccess(); + } + } +}