From 0987dc401ca55885b99e98a857705396ab47828e Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 17 Nov 2023 15:10:52 -0800 Subject: [PATCH] netty: Add option to limit RST_STREAM rate The behavior purposefully mirrors that of Netty's AbstractHttp2ConnectionHandlerBuilder.decoderEnforceMaxRstFramesPerWindow(). That API is not available to our code as we extend the Http2ConnectionHandler, but we want our API to be able to delegate to Netty's in the future if that ever becomes possible. --- .../main/java/io/grpc/netty/NettyServer.java | 7 +++ .../io/grpc/netty/NettyServerBuilder.java | 32 ++++++++++++- .../io/grpc/netty/NettyServerHandler.java | 40 +++++++++++++++++ .../io/grpc/netty/NettyServerTransport.java | 8 ++++ .../grpc/netty/NettyClientTransportTest.java | 3 +- .../io/grpc/netty/NettyServerHandlerTest.java | 45 +++++++++++++++++++ .../java/io/grpc/netty/NettyServerTest.java | 7 +++ 7 files changed, 140 insertions(+), 2 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index fe7913870f..2960604e5b 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -99,6 +99,8 @@ class NettyServer implements InternalServer, InternalWithLogId { private final long maxConnectionAgeGraceInNanos; private final boolean permitKeepAliveWithoutCalls; private final long permitKeepAliveTimeInNanos; + private final int maxRstCount; + private final long maxRstPeriodNanos; private final Attributes eagAttributes; private final ReferenceCounted sharedResourceReferenceCounter = new SharedResourceReferenceCounter(); @@ -127,6 +129,7 @@ class NettyServer implements InternalServer, InternalWithLogId { long maxConnectionIdleInNanos, long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, long maxRstPeriodNanos, Attributes eagAttributes, InternalChannelz channelz) { this.addresses = checkNotNull(addresses, "addresses"); this.channelFactory = checkNotNull(channelFactory, "channelFactory"); @@ -156,6 +159,8 @@ class NettyServer implements InternalServer, InternalWithLogId { this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls; this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos; + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); this.channelz = Preconditions.checkNotNull(channelz); this.logId = InternalLogId.allocate(getClass(), addresses.isEmpty() ? "No address" : @@ -257,6 +262,8 @@ class NettyServer implements InternalServer, InternalWithLogId { maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes); ServerTransportListener transportListener; // This is to order callbacks on the listener, not to guard access to channel. diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 9411a979ed..525f8953e0 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -75,6 +75,7 @@ public final class NettyServerBuilder extends ForwardingServerBuildergRPC clients send RST_STREAM when they cancel RPCs, so some RST_STREAMs are normal and + * setting this too low can cause errors for legimitate clients. + * + *

By default there is no limit. + * + * @param maxRstStream the positive limit of RST_STREAM frames per connection per period, or + * {@code Integer.MAX_VALUE} for unlimited + * @param secondsPerWindow the positive number of seconds per period + */ + @CanIgnoreReturnValue + public NettyServerBuilder maxRstFramesPerWindow(int maxRstStream, int secondsPerWindow) { + checkArgument(maxRstStream > 0, "maxRstStream must be positive"); + checkArgument(secondsPerWindow > 0, "secondsPerWindow must be positive"); + if (maxRstStream == Integer.MAX_VALUE) { + maxRstStream = MAX_RST_COUNT_DISABLED; + } + this.maxRstCount = maxRstStream; + this.maxRstPeriodNanos = TimeUnit.SECONDS.toNanos(secondsPerWindow); + return this; + } + /** Sets the EAG attributes available to protocol negotiators. Not for general use. */ void eagAttributes(Attributes eagAttributes) { this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); @@ -664,7 +694,7 @@ public final class NettyServerBuilder extends ForwardingServerBuilder streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; private final Attributes eagAttributes; + private final Ticker ticker; /** Incomplete attributes produced by negotiator. */ private Attributes negotiationAttributes; private InternalChannelz.Security securityInfo; @@ -146,6 +149,9 @@ class NettyServerHandler extends AbstractNettyHandler { private ScheduledFuture maxConnectionAgeMonitor; @CheckForNull private GracefulShutdown gracefulShutdown; + private int rstCount; + private long lastRstNanoTime; + static NettyServerHandler newHandler( ServerTransportListener transportListener, @@ -164,6 +170,8 @@ class NettyServerHandler extends AbstractNettyHandler { long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); @@ -192,6 +200,8 @@ class NettyServerHandler extends AbstractNettyHandler { maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes, Ticker.systemTicker()); } @@ -215,6 +225,8 @@ class NettyServerHandler extends AbstractNettyHandler { long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes, Ticker ticker) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); @@ -266,6 +278,8 @@ class NettyServerHandler extends AbstractNettyHandler { maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, + maxRstCount, + maxRstPeriodNanos, eagAttributes, ticker); } @@ -286,6 +300,8 @@ class NettyServerHandler extends AbstractNettyHandler { long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes, Ticker ticker) { super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), @@ -328,8 +344,12 @@ class NettyServerHandler extends AbstractNettyHandler { this.maxConnectionAgeInNanos = maxConnectionAgeInNanos; this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer"); + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); + this.ticker = checkNotNull(ticker, "ticker"); + this.lastRstNanoTime = ticker.read(); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); @@ -527,6 +547,26 @@ class NettyServerHandler extends AbstractNettyHandler { } private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception { + if (maxRstCount > 0) { + long now = ticker.read(); + if (now - lastRstNanoTime > maxRstPeriodNanos) { + lastRstNanoTime = now; + rstCount = 1; + } else { + rstCount++; + if (rstCount > maxRstCount) { + throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses + @Override + public Throwable fillInStackTrace() { + // Avoid the CPU cycles, since the resets may be a CPU consumption attack + return this; + } + }; + } + } + } + try { NettyServerStream.TransportState stream = serverStream(connection().stream(streamId)); if (stream != null) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 46ddeb27c9..9511927a09 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -77,6 +77,8 @@ class NettyServerTransport implements ServerTransport { private final long maxConnectionAgeGraceInNanos; private final boolean permitKeepAliveWithoutCalls; private final long permitKeepAliveTimeInNanos; + private final int maxRstCount; + private final long maxRstPeriodNanos; private final Attributes eagAttributes; private final List streamTracerFactories; private final TransportTracer transportTracer; @@ -99,6 +101,8 @@ class NettyServerTransport implements ServerTransport { long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.channelUnused = channelUnused; @@ -118,6 +122,8 @@ class NettyServerTransport implements ServerTransport { this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls; this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos; + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes"); SocketAddress remote = channel.remoteAddress(); this.logId = InternalLogId.allocate(getClass(), remote != null ? remote.toString() : null); @@ -277,6 +283,8 @@ class NettyServerTransport implements ServerTransport { maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index eabbbda318..39e6718a24 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -27,6 +27,7 @@ import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; +import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -826,7 +827,7 @@ public class NettyClientTransportTest { DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, MAX_CONNECTION_IDLE_NANOS_DISABLED, MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0, - Attributes.EMPTY, + MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY, channelz); server.start(serverListener); address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress()); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 368b0600f9..281ff3b17d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -23,6 +23,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; +import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.HTTP_METHOD; @@ -33,6 +34,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -85,6 +87,7 @@ import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import java.io.InputStream; +import java.nio.channels.ClosedChannelException; import java.util.Arrays; import java.util.LinkedList; import java.util.List; @@ -143,6 +146,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); + } + + private void rapidReset(int burstSize) throws Exception { + Http2Headers headers = new DefaultHttp2Headers() + .method(HTTP_METHOD) + .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) + .set(TE_HEADER, TE_TRAILERS) + .path(new AsciiString("/foo/bar")); + int streamId = 1; + long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize; + for (int period = 0; period < 3; period++) { + for (int i = 0; i < burstSize; i++) { + channelRead(headersFrame(streamId, headers)); + channelRead(rstStreamFrame(streamId, (int) Http2Error.CANCEL.code())); + streamId += 2; + fakeClock().forwardNanos(rpcTimeNanos); + } + while (channel().readOutbound() != null) {} + fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1); + } + } + private void createStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) @@ -1296,6 +1339,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase serverShutdownCalled = SettableFuture.create(); @@ -203,6 +204,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -276,6 +278,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -337,6 +340,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); @@ -411,6 +415,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore eagAttributes, channelz); ns.start(new ServerListener() { @@ -458,6 +463,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -600,6 +606,7 @@ public class NettyServerTest { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); }