From cf4a38ed454f327130436e4cf156cb50b2d779ee Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Tue, 23 Jan 2018 10:35:59 -0800 Subject: [PATCH] core: retry part 2, buffer size limit Implement buffer size counting with ClientStreamTracer and buffer size limit following the spec https://github.com/grpc/proposal/blob/master/A6-client-retries.md#memory-management-buffering --- .../io/grpc/ForwardingChannelBuilder.java | 12 ++ .../java/io/grpc/ManagedChannelBuilder.java | 33 +++++ .../AbstractManagedChannelImplBuilder.java | 20 +++ .../io/grpc/internal/ManagedChannelImpl.java | 22 ++- .../io/grpc/internal/RetriableStream.java | 127 +++++++++++++++++- ...AbstractManagedChannelImplBuilderTest.java | 30 +++++ .../io/grpc/internal/RetriableStreamTest.java | 55 +++++++- 7 files changed, 287 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/io/grpc/ForwardingChannelBuilder.java b/core/src/main/java/io/grpc/ForwardingChannelBuilder.java index da9d8a3b2c..72e15ae68b 100644 --- a/core/src/main/java/io/grpc/ForwardingChannelBuilder.java +++ b/core/src/main/java/io/grpc/ForwardingChannelBuilder.java @@ -163,6 +163,18 @@ public abstract class ForwardingChannelBuilder> throw new UnsupportedOperationException(); } + /** + * Sets the retry buffer size in bytes. If the buffer limit is exceeded, no RPC + * could retry at the moment, and in hedging case all hedges but one of the same RPC will cancel. + * The implementation may only estimate the buffer size being used rather than count the + * exact physical memory allocated. The method does not have any effect if retry is disabled by + * the client. + * + *

This method may not work as expected for the current release because retry is not fully + * implemented yet. + * + * @since 1.10.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") + public T retryBufferSize(long bytes) { + throw new UnsupportedOperationException(); + } + + /** + * Sets the per RPC buffer limit in bytes used for retry. The RPC is not retriable if its buffer + * limit is exceeded. The implementation may only estimate the buffer size being used rather than + * count the exact physical memory allocated. It does not have any effect if retry is disabled by + * the client. + * + *

This method may not work as expected for the current release because retry is not fully + * implemented yet. + * + * @since 1.10.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") + public T perRpcBufferLimit(long bytes) { + throw new UnsupportedOperationException(); + } + /** * Builds a channel using the given parameters. * diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index 1c5a8cad66..90a6f2cc37 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -90,6 +90,9 @@ public abstract class AbstractManagedChannelImplBuilder private static final CompressorRegistry DEFAULT_COMPRESSOR_REGISTRY = CompressorRegistry.getDefaultInstance(); + private static final long DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES = 1L << 24; // 16M + private static final long DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES = 1L << 20; // 1M + ObjectPool executorPool = DEFAULT_EXECUTOR_POOL; private final List interceptors = new ArrayList(); @@ -119,6 +122,9 @@ public abstract class AbstractManagedChannelImplBuilder long idleTimeoutMillis = IDLE_MODE_DEFAULT_TIMEOUT_MILLIS; + long retryBufferSize = DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES; + long perRpcBufferLimit = DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES; + protected TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); private int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; @@ -273,6 +279,20 @@ public abstract class AbstractManagedChannelImplBuilder return thisT(); } + @Override + public final T retryBufferSize(long bytes) { + checkArgument(bytes > 0L, "retry buffer size must be positive"); + retryBufferSize = bytes; + return thisT(); + } + + @Override + public final T perRpcBufferLimit(long bytes) { + checkArgument(bytes > 0L, "per RPC buffer limit must be 
positive"); + perRpcBufferLimit = bytes; + return thisT(); + } + /** * Override the default stats implementation. */ diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 8631d6e4ab..1a9481da88 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -34,6 +34,7 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; +import io.grpc.ClientStreamTracer; import io.grpc.CompressorRegistry; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -53,6 +54,7 @@ import io.grpc.MethodDescriptor; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.internal.ClientCallImpl.ClientTransportProvider; +import io.grpc.internal.RetriableStream.ChannelBufferMeter; import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.ref.SoftReference; @@ -200,6 +202,12 @@ public final class ManagedChannelImpl private final CallTracer.Factory callTracerFactory; private final CallTracer channelCallTracer; + // One instance per channel. 
+ private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); + + private final long perRpcBufferLimit; + private final long channelBufferLimit; + // Called from channelExecutor private final ManagedClientTransport.Listener delayedTransportListener = new ManagedClientTransport.Listener() { @@ -424,7 +432,8 @@ public final class ManagedChannelImpl final CallOptions callOptions, final Metadata headers, final Context context) { - return new RetriableStream(method) { + return new RetriableStream( + method, channelBufferUsed, perRpcBufferLimit, channelBufferLimit) { @Override Status prestart() { return uncommittedRetriableStreamsRegistry.add(this); @@ -436,12 +445,14 @@ public final class ManagedChannelImpl } @Override - ClientStream newStream() { + ClientStream newStream(ClientStreamTracer.Factory tracerFactory) { + // TODO(zdapeng): only add tracer when retry is enabled. + CallOptions newOptions = callOptions.withStreamTracerFactory(tracerFactory); ClientTransport transport = - get(new PickSubchannelArgsImpl(method, headers, callOptions)); + get(new PickSubchannelArgsImpl(method, headers, newOptions)); Context origContext = context.attach(); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, newOptions); } finally { context.detach(origContext); } @@ -493,6 +504,9 @@ public final class ManagedChannelImpl this.userAgent = builder.userAgent; this.proxyDetector = proxyDetector; + this.channelBufferLimit = builder.retryBufferSize; + this.perRpcBufferLimit = builder.perRpcBufferLimit; + phantom = new ManagedChannelReference(this); this.callTracerFactory = callTracerFactory; channelCallTracer = callTracerFactory.create(); diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 036705c969..3fb7d310a8 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ 
b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Compressor; import io.grpc.DecompressorRegistry; import io.grpc.Metadata; @@ -32,6 +34,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -46,14 +49,27 @@ abstract class RetriableStream implements ClientStream { /** Must be held when updating state, accessing state.buffer, or certain substream attributes. */ private final Object lock = new Object(); - private volatile State state = - new State( - new ArrayList(), Collections.emptySet(), null, false, false); + private final ChannelBufferMeter channelBufferUsed; + private final long perRpcBufferLimit; + private final long channelBufferLimit; + + private volatile State state = new State( + new ArrayList(), Collections.emptySet(), null, false, false); + + // Used for recording the share of buffer used for the current call out of the channel buffer. + // This field would not be necessary if there is no channel buffer limit. 
+ @GuardedBy("lock") + private long perRpcBufferUsed; private ClientStreamListener masterListener; - RetriableStream(MethodDescriptor method) { + RetriableStream( + MethodDescriptor method, ChannelBufferMeter channelBufferUsed, + long perRpcBufferLimit, long channelBufferLimit) { this.method = method; + this.channelBufferUsed = channelBufferUsed; + this.perRpcBufferLimit = perRpcBufferLimit; + this.channelBufferLimit = channelBufferLimit; } @Nullable // null if already committed @@ -67,6 +83,9 @@ abstract class RetriableStream implements ClientStream { state = state.committed(winningSubstream); + // subtract the share of this RPC from channelBufferUsed. + channelBufferUsed.addAndGet(-perRpcBufferUsed); + class CommitTask implements Runnable { @Override public void run() { @@ -109,15 +128,23 @@ abstract class RetriableStream implements ClientStream { private Substream createSubstream() { Substream sub = new Substream(); + // one tracer per substream + final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub); + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { + return bufferSizeTracer; + } + }; // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newStream(); + sub.stream = newStream(tracerFactory); return sub; } /** * Creates a new physical ClientStream that represents a retry/hedging attempt. */ - abstract ClientStream newStream(); + abstract ClientStream newStream(ClientStreamTracer.Factory tracerFactory); private void drain(Substream substream) { int index = 0; @@ -452,6 +479,16 @@ abstract class RetriableStream implements ClientStream { state = state.substreamClosed(substream); } + // handle a race between buffer limit exceeded and closed, when setting + // substream.bufferLimitExceeded = true happens before state.substreamClosed(substream). 
+ if (substream.bufferLimitExceeded) { + commitAndRun(substream); + if (state.winningSubstream == substream) { + masterListener.closed(status, trailers); + } + return; + } + if (state.winningSubstream == null && shouldRetry()) { // The check state.winningSubstream == null, checking if is not already committed, is racy, // but is still safe b/c the retry will also handle committed/cancellation @@ -603,5 +640,83 @@ abstract class RetriableStream implements ClientStream { // GuardedBy RetriableStream.lock boolean closed; + + // setting to true must be GuardedBy RetriableStream.lock + boolean bufferLimitExceeded; + } + + + /** + * Traces the buffer used by a substream. + */ + class BufferSizeTracer extends ClientStreamTracer { + // Each buffer size tracer is dedicated to one specific substream. + private final Substream substream; + + @GuardedBy("lock") + long bufferNeeded; + + BufferSizeTracer(Substream substream) { + this.substream = substream; + } + + /** + * A message is sent to the wire, so its reference would be released if no retry or + * hedging were involved. So at this point we have to hold the reference of the message longer + * for retry, and we need to increment {@code substream.bufferNeeded}. + */ + @Override + public void outboundWireSize(long bytes) { + if (state.winningSubstream != null) { + return; + } + + Runnable postCommitTask = null; + + // TODO(zdapeng): avoid using the same lock for both in-bound and out-bound. + synchronized (lock) { + if (state.winningSubstream != null || substream.closed) { + return; + } + bufferNeeded += bytes; + if (bufferNeeded <= perRpcBufferUsed) { + return; + } + + if (bufferNeeded > perRpcBufferLimit) { + substream.bufferLimitExceeded = true; + } else { + // Only update channelBufferUsed when perRpcBufferUsed is not exceeding perRpcBufferLimit. 
+ long savedChannelBufferUsed = + channelBufferUsed.addAndGet(bufferNeeded - perRpcBufferUsed); + perRpcBufferUsed = bufferNeeded; + + if (savedChannelBufferUsed > channelBufferLimit) { + substream.bufferLimitExceeded = true; + } + } + + if (substream.bufferLimitExceeded) { + postCommitTask = commit(substream); + } + } + + if (postCommitTask != null) { + postCommitTask.run(); + } + } + } + + + /** + * Used to keep track of the total amount of memory used to buffer retryable or hedged RPCs for + * the Channel. There should be a single instance of it for each channel. + */ + static final class ChannelBufferMeter { + private final AtomicLong bufferUsed = new AtomicLong(); + + public long addAndGet(long newBytesUsed) { + return bufferUsed.addAndGet(newBytesUsed); + } } } diff --git a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java index 7d53ae348f..3eddecb570 100644 --- a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java @@ -330,6 +330,36 @@ public class AbstractManagedChannelImplBuilderTest { assertEquals(TimeUnit.SECONDS.toMillis(30), builder.getIdleTimeoutMillis()); } + @Test + public void retryBufferSize() { + Builder builder = new Builder("target"); + assertEquals(1L << 24, builder.retryBufferSize); + + builder.retryBufferSize(3456L); + assertEquals(3456L, builder.retryBufferSize); + } + + @Test + public void perRpcBufferLimit() { + Builder builder = new Builder("target"); + assertEquals(1L << 20, builder.perRpcBufferLimit); + + builder.perRpcBufferLimit(3456L); + assertEquals(3456L, builder.perRpcBufferLimit); + } + + @Test(expected = IllegalArgumentException.class) + public void retryBufferSizeInvalidArg() { + Builder builder = new Builder("target"); + builder.retryBufferSize(0L); + } + + @Test(expected = IllegalArgumentException.class) + 
public void perRpcBufferLimitInvalidArg() { + Builder builder = new Builder("target"); + builder.perRpcBufferLimit(0L); + } + static class Builder extends AbstractManagedChannelImplBuilder { Builder(String target) { super(target); diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 04519e2e14..3cd26edfa3 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -31,6 +31,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.DecompressorRegistry; @@ -39,6 +41,7 @@ import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.StringMarshaller; +import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.StreamListener.MessageProducer; import java.io.InputStream; import java.util.ArrayList; @@ -61,6 +64,8 @@ public class RetriableStreamTest { DecompressorRegistry.getDefaultInstance(); private static final int MAX_INBOUND_MESSAGE_SIZE = 1234; private static final int MAX_OUTNBOUND_MESSAGE_SIZE = 5678; + private static final long PER_RPC_BUFFER_LIMIT = 1000; + private static final long CHANNEL_BUFFER_LIMIT = 2000; private final RetriableStreamRecorder retriableStreamRecorder = mock(RetriableStreamRecorder.class); private final ClientStreamListener masterListener = mock(ClientStreamListener.class); @@ -71,15 +76,19 @@ public class RetriableStreamTest { .setRequestMarshaller(new StringMarshaller()) .setResponseMarshaller(new StringMarshaller()) .build(); + private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); private final RetriableStream retriableStream = - new 
RetriableStream(method) { + new RetriableStream( + method, channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_LIMIT) { @Override void postCommit() { retriableStreamRecorder.postCommit(); } @Override - ClientStream newStream() { + ClientStream newStream(ClientStreamTracer.Factory tracerFactory) { + bufferSizeTracer = + tracerFactory.newClientStreamTracer(CallOptions.DEFAULT, new Metadata()); return retriableStreamRecorder.newSubstream(); } @@ -94,6 +103,8 @@ public class RetriableStreamTest { } }; + private ClientStreamTracer bufferSizeTracer; + @Test public void retry_everythingDrained() { ClientStream mockStream1 = mock(ClientStream.class); @@ -728,6 +739,46 @@ public class RetriableStreamTest { verify(mockStream3).request(1); } + // TODO(zdapeng): test buffer limit exceeded during backoff + @Test + public void perRpcBufferLimitExceeded() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(); + + retriableStream.start(masterListener); + + bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT); + + assertEquals(PER_RPC_BUFFER_LIMIT, channelBufferUsed.addAndGet(0)); + + verify(retriableStreamRecorder, never()).postCommit(); + bufferSizeTracer.outboundWireSize(2); + verify(retriableStreamRecorder).postCommit(); + + // verify channel buffer is adjusted + assertEquals(0, channelBufferUsed.addAndGet(0)); + } + + @Test + public void channelBufferLimitExceeded() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(); + + retriableStream.start(masterListener); + + bufferSizeTracer.outboundWireSize(100); + + assertEquals(100, channelBufferUsed.addAndGet(0)); + + channelBufferUsed.addAndGet(CHANNEL_BUFFER_LIMIT - 200); + verify(retriableStreamRecorder, never()).postCommit(); + bufferSizeTracer.outboundWireSize(100 + 1); + verify(retriableStreamRecorder).postCommit(); + + // verify channel buffer is adjusted + 
assertEquals(CHANNEL_BUFFER_LIMIT - 200, channelBufferUsed.addAndGet(0)); + } + /** * Used to stub a retriable stream as well as to record methods of the retriable stream being * called.