okhttp: Add maxInboundMetadataSize

Fixes #4567
This commit is contained in:
Eric Anderson 2018-10-01 16:44:14 -07:00
parent 0eefa5263b
commit eaafb997e2
4 changed files with 72 additions and 11 deletions

View File

@ -161,6 +161,7 @@ public class OkHttpChannelBuilder extends
private long keepAliveTimeoutNanos = DEFAULT_KEEPALIVE_TIMEOUT_NANOS; private long keepAliveTimeoutNanos = DEFAULT_KEEPALIVE_TIMEOUT_NANOS;
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
private boolean keepAliveWithoutCalls; private boolean keepAliveWithoutCalls;
private int maxInboundMetadataSize = Integer.MAX_VALUE;
protected OkHttpChannelBuilder(String host, int port) { protected OkHttpChannelBuilder(String host, int port) {
this(GrpcUtil.authorityFromHostAndPort(host, port)); this(GrpcUtil.authorityFromHostAndPort(host, port));
@ -405,6 +406,25 @@ public class OkHttpChannelBuilder extends
return this; return this;
} }
/**
* Sets the maximum size of metadata allowed to be received. {@code Integer.MAX_VALUE} disables
* the enforcement. Defaults to no limit ({@code Integer.MAX_VALUE}).
*
* <p>The implementation does not currently limit memory usage; this value is checked only after
* the metadata is decoded from the wire. It does prevent large metadata from being passed to the
* application.
*
* @param bytes the maximum size of received metadata
* @return this
* @throws IllegalArgumentException if bytes is non-positive
* @since 1.17.0
*/
public OkHttpChannelBuilder maxInboundMetadataSize(int bytes) {
Preconditions.checkArgument(bytes > 0, "maxInboundMetadataSize must be > 0");
this.maxInboundMetadataSize = bytes;
return this;
}
@Override @Override
@Internal @Internal
protected final ClientTransportFactory buildTransportFactory() { protected final ClientTransportFactory buildTransportFactory() {
@ -412,7 +432,7 @@ public class OkHttpChannelBuilder extends
return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService, return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService,
createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(), createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(),
enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow, enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow,
keepAliveWithoutCalls, transportTracerFactory); keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory);
} }
@Override @Override
@ -491,6 +511,7 @@ public class OkHttpChannelBuilder extends
private final long keepAliveTimeoutNanos; private final long keepAliveTimeoutNanos;
private final int flowControlWindow; private final int flowControlWindow;
private final boolean keepAliveWithoutCalls; private final boolean keepAliveWithoutCalls;
private final int maxInboundMetadataSize;
private final ScheduledExecutorService timeoutService; private final ScheduledExecutorService timeoutService;
private boolean closed; private boolean closed;
@ -505,6 +526,7 @@ public class OkHttpChannelBuilder extends
long keepAliveTimeoutNanos, long keepAliveTimeoutNanos,
int flowControlWindow, int flowControlWindow,
boolean keepAliveWithoutCalls, boolean keepAliveWithoutCalls,
int maxInboundMetadataSize,
TransportTracer.Factory transportTracerFactory) { TransportTracer.Factory transportTracerFactory) {
usingSharedScheduler = timeoutService == null; usingSharedScheduler = timeoutService == null;
this.timeoutService = usingSharedScheduler this.timeoutService = usingSharedScheduler
@ -518,6 +540,7 @@ public class OkHttpChannelBuilder extends
this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.keepAliveWithoutCalls = keepAliveWithoutCalls;
this.maxInboundMetadataSize = maxInboundMetadataSize;
usingSharedExecutor = executor == null; usingSharedExecutor = executor == null;
this.transportTracerFactory = this.transportTracerFactory =
@ -556,6 +579,7 @@ public class OkHttpChannelBuilder extends
flowControlWindow, flowControlWindow,
options.getProxyParameters(), options.getProxyParameters(),
tooManyPingsRunnable, tooManyPingsRunnable,
maxInboundMetadataSize,
transportTracerFactory.create()); transportTracerFactory.create());
if (enableKeepAlive) { if (enableKeepAlive) {
transport.enableKeepAlive( transport.enableKeepAlive(

View File

@ -192,6 +192,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private long keepAliveTimeoutNanos; private long keepAliveTimeoutNanos;
private boolean keepAliveWithoutCalls; private boolean keepAliveWithoutCalls;
private final Runnable tooManyPingsRunnable; private final Runnable tooManyPingsRunnable;
private final int maxInboundMetadataSize;
@GuardedBy("lock") @GuardedBy("lock")
private final TransportTracer transportTracer; private final TransportTracer transportTracer;
@GuardedBy("lock") @GuardedBy("lock")
@ -223,7 +224,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Executor executor, @Nullable SSLSocketFactory sslSocketFactory, Executor executor, @Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, @Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec,
int maxMessageSize, int initialWindowSize, @Nullable ProxyParameters proxy, int maxMessageSize, int initialWindowSize, @Nullable ProxyParameters proxy,
Runnable tooManyPingsRunnable, TransportTracer transportTracer) { Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) {
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority; this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -241,6 +242,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.proxy = proxy; this.proxy = proxy;
this.tooManyPingsRunnable = this.tooManyPingsRunnable =
Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable");
this.maxInboundMetadataSize = maxInboundMetadataSize;
this.transportTracer = Preconditions.checkNotNull(transportTracer); this.transportTracer = Preconditions.checkNotNull(transportTracer);
initTransportTracer(); initTransportTracer();
} }
@ -281,6 +283,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.proxy = null; this.proxy = null;
this.tooManyPingsRunnable = this.tooManyPingsRunnable =
Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable");
this.maxInboundMetadataSize = Integer.MAX_VALUE;
this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer");
initTransportTracer(); initTransportTracer();
} }
@ -1067,6 +1070,18 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
List<Header> headerBlock, List<Header> headerBlock,
HeadersMode headersMode) { HeadersMode headersMode) {
boolean unknownStream = false; boolean unknownStream = false;
Status failedStatus = null;
if (maxInboundMetadataSize != Integer.MAX_VALUE) {
int metadataSize = headerBlockSize(headerBlock);
if (metadataSize > maxInboundMetadataSize) {
failedStatus = Status.RESOURCE_EXHAUSTED.withDescription(
String.format(
"Response %s metadata larger than %d: %d",
inFinished ? "trailer" : "header",
maxInboundMetadataSize,
metadataSize));
}
}
synchronized (lock) { synchronized (lock) {
OkHttpClientStream stream = streams.get(streamId); OkHttpClientStream stream = streams.get(streamId);
if (stream == null) { if (stream == null) {
@ -1076,7 +1091,14 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
unknownStream = true; unknownStream = true;
} }
} else { } else {
stream.transportState().transportHeadersReceived(headerBlock, inFinished); if (failedStatus == null) {
stream.transportState().transportHeadersReceived(headerBlock, inFinished);
} else {
if (!inFinished) {
frameWriter.rstStream(streamId, ErrorCode.CANCEL);
}
stream.transportState().transportReportStatus(failedStatus, false, new Metadata());
}
} }
} }
if (unknownStream) { if (unknownStream) {
@ -1085,6 +1107,17 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
} }
} }
private int headerBlockSize(List<Header> headerBlock) {
// Calculate as defined for SETTINGS_MAX_HEADER_LIST_SIZE in RFC 7540 §6.5.2.
long size = 0;
for (int i = 0; i < headerBlock.size(); i++) {
Header header = headerBlock.get(i);
size += 32 + header.name.size() + header.value.size();
}
size = Math.min(size, Integer.MAX_VALUE);
return (int) size;
}
@Override @Override
public void rstStream(int streamId, ErrorCode errorCode) { public void rstStream(int streamId, ErrorCode errorCode) {
Status status = toGrpcStatus(errorCode).augmentDescription("Rst Stream"); Status status = toGrpcStatus(errorCode).augmentDescription("Rst Stream");

View File

@ -135,6 +135,7 @@ public class OkHttpClientTransportTest {
private static final String NO_USER = null; private static final String NO_USER = null;
private static final String NO_PW = null; private static final String NO_PW = null;
private static final int DEFAULT_START_STREAM_ID = 3; private static final int DEFAULT_START_STREAM_ID = 3;
private static final int DEFAULT_MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE;
@Rule public final Timeout globalTimeout = Timeout.seconds(10); @Rule public final Timeout globalTimeout = Timeout.seconds(10);
@ -245,6 +246,7 @@ public class OkHttpClientTransportTest {
INITIAL_WINDOW_SIZE, INITIAL_WINDOW_SIZE,
NO_PROXY, NO_PROXY,
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
transportTracer); transportTracer);
String s = clientTransport.toString(); String s = clientTransport.toString();
assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport"));
@ -1518,6 +1520,7 @@ public class OkHttpClientTransportTest {
INITIAL_WINDOW_SIZE, INITIAL_WINDOW_SIZE,
NO_PROXY, NO_PROXY,
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
transportTracer); transportTracer);
String host = clientTransport.getOverridenHost(); String host = clientTransport.getOverridenHost();
@ -1541,6 +1544,7 @@ public class OkHttpClientTransportTest {
INITIAL_WINDOW_SIZE, INITIAL_WINDOW_SIZE,
NO_PROXY, NO_PROXY,
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
new TransportTracer()); new TransportTracer());
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
@ -1573,6 +1577,7 @@ public class OkHttpClientTransportTest {
new ProxyParameters( new ProxyParameters(
(InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW), (InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW),
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
transportTracer); transportTracer);
clientTransport.start(transportListener); clientTransport.start(transportListener);
@ -1624,6 +1629,7 @@ public class OkHttpClientTransportTest {
new ProxyParameters( new ProxyParameters(
(InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW), (InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW),
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
transportTracer); transportTracer);
clientTransport.start(transportListener); clientTransport.start(transportListener);
@ -1674,6 +1680,7 @@ public class OkHttpClientTransportTest {
new ProxyParameters( new ProxyParameters(
(InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW), (InetSocketAddress) serverSocket.getLocalSocketAddress(), NO_USER, NO_PW),
tooManyPingsRunnable, tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
transportTracer); transportTracer);
clientTransport.start(transportListener); clientTransport.start(transportListener);

View File

@ -20,6 +20,7 @@ import io.grpc.ServerStreamTracer;
import io.grpc.internal.AccessProtectedHack; import io.grpc.internal.AccessProtectedHack;
import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.InternalServer; import io.grpc.internal.InternalServer;
import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ManagedClientTransport;
import io.grpc.internal.testing.AbstractTransportTest; import io.grpc.internal.testing.AbstractTransportTest;
@ -41,6 +42,7 @@ public class OkHttpTransportTest extends AbstractTransportTest {
.forAddress("localhost", 0) .forAddress("localhost", 0)
.usePlaintext() .usePlaintext()
.setTransportTracerFactory(fakeClockTransportTracer) .setTransportTracerFactory(fakeClockTransportTracer)
.maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)
.buildTransportFactory(); .buildTransportFactory();
@After @After
@ -99,15 +101,10 @@ public class OkHttpTransportTest extends AbstractTransportTest {
return true; return true;
} }
// not yet implemented
@Override @Override
@org.junit.Test @org.junit.Test
@org.junit.Ignore @org.junit.Ignore
public void clientChecksInboundMetadataSize_header() {} public void clientChecksInboundMetadataSize_trailer() {
// Server-side is flaky due to https://github.com/netty/netty/pull/8332
// not yet implemented }
@Override
@org.junit.Test
@org.junit.Ignore
public void clientChecksInboundMetadataSize_trailer() {}
} }