diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 9a46543ea9..1227510575 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -145,6 +145,11 @@ public final class GrpcUtil { */ public static final int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + /** + * The default maximum size (in bytes) for inbound header/trailer. + */ + public static final int DEFAULT_MAX_HEADER_LIST_SIZE = 8192; + /** * The set of valid status codes for client cancellation. */ diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index a520d86a09..cbfd240028 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -71,6 +71,7 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder 0, "flowControlWindow must be positive"); + checkArgument(flowControlWindow > 0, "flowControlWindow must be positive"); this.flowControlWindow = flowControlWindow; return this; } /** * Sets the maximum message size allowed to be received on the channel. If not called, - * defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}. + * defaults to {@link GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}. */ public final NettyChannelBuilder maxMessageSize(int maxMessageSize) { checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0"); @@ -175,6 +176,16 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder 0, "maxHeaderListSize must be > 0"); + this.maxHeaderListSize = maxHeaderListSize; + return this; + } + /** * Equivalent to using {@link #negotiationType(NegotiationType)} with {@code PLAINTEXT} or * {@code PLAINTEXT_UPGRADE}. @@ -192,7 +203,7 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder channelType, NegotiationType negotiationType, SslContext sslContext, EventLoopGroup group, int flowControlWindow, - int maxMessageSize) { + int maxMessageSize, + int maxHeaderListSize) { this.channelType = channelType; this.negotiationType = negotiationType; this.sslContext = sslContext; this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; + this.maxHeaderListSize = maxHeaderListSize; usingSharedGroup = group == null; if (usingSharedGroup) { // The group was unspecified, using the shared group. @@ -272,7 +286,7 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder channelType, EventLoopGroup group, ProtocolNegotiator negotiator, - int flowControlWindow, int maxMessageSize, String authority) { + int flowControlWindow, int maxMessageSize, int maxHeaderListSize, + String authority) { Preconditions.checkNotNull(negotiator, "negotiator"); this.address = Preconditions.checkNotNull(address, "address"); this.group = Preconditions.checkNotNull(group, "group"); this.channelType = Preconditions.checkNotNull(channelType, "channelType"); this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; + this.maxHeaderListSize = maxHeaderListSize; this.authority = new AsciiString(authority); handler = newHandler(); @@ -244,7 +250,9 @@ class NettyClientTransport implements ClientTransport { private NettyClientHandler newHandler() { Http2Connection connection = new DefaultHttp2Connection(false); - Http2FrameReader frameReader = new DefaultHttp2FrameReader(); + Http2HeadersDecoder headersDecoder = + new DefaultHttp2HeadersDecoder(maxHeaderListSize, Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE); + Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); Http2FrameLogger frameLogger = new Http2FrameLogger(LogLevel.DEBUG, getClass()); diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index f1ba4df279..b59264796d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -76,12 +76,13 @@ public class NettyServer implements Server { private Channel channel; private final int flowControlWindow; private final int maxMessageSize; + private final int maxHeaderListSize; private final ReferenceCounted eventLoopReferenceCounter = new EventLoopReferenceCounter(); NettyServer(SocketAddress address, Class channelType, @Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup, @Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow, - int maxMessageSize) { + int maxMessageSize, int maxHeaderListSize) { this.address = address; this.channelType = checkNotNull(channelType, "channelType"); this.bossGroup = bossGroup; @@ -92,6 +93,7 @@ public class NettyServer implements Server { this.maxStreamsPerConnection = maxStreamsPerConnection; this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; + this.maxHeaderListSize = maxHeaderListSize; } @Override @@ -119,7 +121,7 @@ public class NettyServer implements Server { }); NettyServerTransport transport = new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow, - maxMessageSize); + maxMessageSize, maxHeaderListSize); transport.start(listener.transportCreated(transport)); } }); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index cc1a3591a7..198916e68c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -39,6 +39,7 @@ import com.google.common.base.Preconditions; import io.grpc.ExperimentalApi; import io.grpc.HandlerRegistry; import io.grpc.internal.AbstractServerImplBuilder; +import io.grpc.internal.GrpcUtil; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; @@ -68,6 +69,7 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder 0, "max must be positive: %s", maxCalls); + checkArgument(maxCalls > 0, "max must be positive: %s", maxCalls); this.maxConcurrentCallsPerConnection = maxCalls; return this; } @@ -192,7 +194,7 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder 0, "flowControlWindow must be positive"); + checkArgument(flowControlWindow > 0, "flowControlWindow must be positive"); this.flowControlWindow = flowControlWindow; return this; } @@ -207,11 +209,21 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder 0, "maxHeaderListSize must be > 0"); + this.maxHeaderListSize = maxHeaderListSize; + return this; + } + @Override protected NettyServer buildTransportServer() { return new NettyServer(address, channelType, bossEventLoopGroup, workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow, - maxMessageSize); + maxMessageSize, maxHeaderListSize); } @Override diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 7fb37a6b36..407603ff4d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -42,10 +42,13 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.codec.http2.DefaultHttp2Connection; import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.DefaultHttp2HeadersDecoder; +import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; import io.netty.handler.codec.http2.Http2FrameWriter; +import io.netty.handler.codec.http2.Http2HeadersDecoder; import io.netty.handler.codec.http2.Http2InboundFrameLogger; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.logging.LogLevel; @@ -70,14 +73,16 @@ class NettyServerTransport implements ServerTransport { private boolean terminated; private final int flowControlWindow; private final int maxMessageSize; + private final int maxHeaderListSize; NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams, - int flowControlWindow, int maxMessageSize) { + int flowControlWindow, int maxMessageSize, int maxHeaderListSize) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.sslContext = sslContext; this.maxStreams = maxStreams; this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; + this.maxHeaderListSize = maxHeaderListSize; } public void start(ServerTransportListener listener) { @@ -133,8 +138,10 @@ class NettyServerTransport implements ServerTransport { private NettyServerHandler createHandler(ServerTransportListener transportListener) { Http2Connection connection = new DefaultHttp2Connection(true); Http2FrameLogger frameLogger = new Http2FrameLogger(LogLevel.DEBUG, getClass()); - Http2FrameReader frameReader = - new Http2InboundFrameLogger(new DefaultHttp2FrameReader(), frameLogger); + Http2HeadersDecoder headersDecoder = + new DefaultHttp2HeadersDecoder(maxHeaderListSize, Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE); + Http2FrameReader frameReader = new Http2InboundFrameLogger( + new DefaultHttp2FrameReader(headersDecoder), frameLogger); Http2FrameWriter frameWriter = new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 40de685847..0dac683d70 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -61,6 +61,7 @@ import io.grpc.testing.TestUtils; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SupportedCipherSuiteFilter; @@ -154,14 +155,15 @@ public class NettyClientTransportTest { assertEquals(1, serverListener.streamListeners.size()); Metadata receivedHeaders = serverListener.streamListeners.get(0).headers; assertEquals(GrpcUtil.getGrpcUserAgent("netty", userAgent), - receivedHeaders.get(USER_AGENT_KEY)); + receivedHeaders.get(USER_AGENT_KEY)); } @Test public void maxMessageSizeShouldBeEnforced() throws Throwable { startServer(); // Allow the response payloads of up to 1 byte. - NettyClientTransport transport = newTransport(newNegotiator(), 1); + NettyClientTransport transport = newTransport(newNegotiator(), + 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); transport.start(clientTransportListener); try { @@ -205,7 +207,7 @@ public class NettyClientTransportTest { @Test public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Exception { // Only allow a single stream active at a time. - startServer(1); + startServer(1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); NettyClientTransport transport = newTransport(newNegotiator()); transport.start(clientTransportListener); @@ -234,6 +236,51 @@ public class NettyClientTransportTest { } } + @Test + public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { + startServer(); + + NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1); + transport.start(clientTransportListener); + + try { + // Send a single RPC and wait for the response. + new Rpc(transport, new Metadata()).halfClose().waitForResponse(); + fail("The stream should have been failed due to client received header exceeds header list" + + " size limit!"); + } catch (Exception e) { + Throwable rootCause = getRootCause(e); + assertTrue(rootCause instanceof Http2Exception); + assertEquals("Header size exceeded max allowed size (1)", rootCause.getMessage()); + } + } + + @Test + public void maxHeaderListSizeShouldBeEnforcedOnServer() throws Exception { + startServer(100, 1); + + NettyClientTransport transport = newTransport(newNegotiator()); + transport.start(clientTransportListener); + + try { + // Send a single RPC and wait for the response. + new Rpc(transport, new Metadata()).halfClose().waitForResponse(); + fail("The stream should have been failed due to server received header exceeds header list" + + " size limit!"); + } catch (Exception e) { + Throwable rootCause = getRootCause(e); + assertTrue(rootCause.getMessage(), + rootCause.getMessage().contains("Header size exceeded max allowed size (1)")); + } + } + + private Throwable getRootCause(Throwable t) { + if (t.getCause() == null) { + return t; + } + return getRootCause(t.getCause()); + } + private ProtocolNegotiator newNegotiator() throws IOException { File clientCert = TestUtils.loadCert("ca.pem"); SslContext clientContext = GrpcSslContexts.forClient().trustManager(clientCert) @@ -242,28 +289,30 @@ public class NettyClientTransportTest { } private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { - return newTransport(negotiator, DEFAULT_MAX_MESSAGE_SIZE); + return newTransport(negotiator, + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); } - private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize) { + private NettyClientTransport newTransport(ProtocolNegotiator negotiator, + int maxMsgSize, int maxHeaderListSize) { NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, - group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, authority); + group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, authority); transports.add(transport); return transport; } private void startServer() throws IOException { - startServer(100); + startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); } - private void startServer(int maxStreamsPerConnection) throws IOException { + private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException { File serverCert = TestUtils.loadCert("server1.pem"); File key = TestUtils.loadCert("server1.key"); SslContext serverContext = GrpcSslContexts.forServer(serverCert, key) .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); server = new NettyServer(address, NioServerSocketChannel.class, group, group, serverContext, maxStreamsPerConnection, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE); + DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize); server.start(serverListener); }