diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index d4053cd291..ca7a33e474 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -55,6 +55,7 @@ import io.grpc.internal.Channelz.ChannelStats; import io.grpc.internal.ClientCallImpl.ClientTransportProvider; import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.RetryPolicy; +import io.grpc.internal.RetriableStream.Throttle; import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.ref.SoftReference; @@ -205,6 +206,8 @@ public final class ManagedChannelImpl extends ManagedChannel implements Instrume // One instance per channel. private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); + private Throttle throttle; + private final int maxRetryAttempts; private final int maxHedgedAttempts; private final long perRpcBufferLimit; @@ -447,7 +450,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements Instrume return new RetriableStream( method, headers, channelBufferUsed, perRpcBufferLimit, channelBufferLimit, getCallExecutor(callOptions), transportFactory.getScheduledExecutorService(), - retryPolicy) { + retryPolicy, throttle) { @Override Status prestart() { return uncommittedRetriableStreamsRegistry.add(this); @@ -1072,6 +1075,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements Instrume try { if (retryEnabled) { retryPolicies = getRetryPolicies(config); + throttle = getThrottle(config); } } catch (RuntimeException re) { logger.log( @@ -1126,6 +1130,12 @@ public final class ManagedChannelImpl extends ManagedChannel implements Instrume }; } + // TODO(zdapeng): implement it once the Gson dependency issue is resolved. 
/**
 * Produces the channel-wide retry {@link Throttle} for the given service config.
 *
 * <p>Placeholder: the config is not inspected yet and {@code null} is always returned. The
 * result is stored in the channel's {@code throttle} field and handed to every new
 * {@code RetriableStream}, which skips all throttling logic when it receives {@code null}.
 *
 * @param config the service config attributes (currently unused)
 * @return always {@code null} for now
 */
@Nullable
private static Throttle getThrottle(Attributes config) {
  return null;
}
*/ @@ -74,6 +74,8 @@ abstract class RetriableStream implements ClientStream { private final ChannelBufferMeter channelBufferUsed; private final long perRpcBufferLimit; private final long channelBufferLimit; + @Nullable + private final Throttle throttle; private volatile State state = new State( new ArrayList(), Collections.emptySet(), null, false, false); @@ -91,7 +93,7 @@ abstract class RetriableStream implements ClientStream { MethodDescriptor method, Metadata headers, ChannelBufferMeter channelBufferUsed, long perRpcBufferLimit, long channelBufferLimit, Executor callExecutor, ScheduledExecutorService scheduledExecutorService, - RetryPolicy retryPolicy) { + RetryPolicy retryPolicy, @Nullable Throttle throttle) { this.method = method; this.channelBufferUsed = channelBufferUsed; this.perRpcBufferLimit = perRpcBufferLimit; @@ -101,6 +103,7 @@ abstract class RetriableStream implements ClientStream { this.headers = headers; this.retryPolicy = checkNotNull(retryPolicy, "retryPolicy"); nextBackoffIntervalInSeconds = retryPolicy.initialBackoffInSeconds; + this.throttle = throttle; } @Nullable // null if already committed @@ -519,6 +522,9 @@ abstract class RetriableStream implements ClientStream { commitAndRun(substream); if (state.winningSubstream == substream) { masterListener.headersRead(headers); + if (throttle != null) { + throttle.onSuccess(); + } } } @@ -582,35 +588,43 @@ abstract class RetriableStream implements ClientStream { private RetryPlan makeRetryDecision(RetryPolicy retryPolicy, Status status, Metadata trailer) { boolean shouldRetry = false; long backoffInMillis = 0L; + boolean isRetryableStatusCode = retryPolicy.retryableStatusCodes.contains(status.getCode()); - if (retryPolicy.maxAttempts > substream.previousAttempts + 1) { - String pushbackStr = trailer.get(GRPC_RETRY_PUSHBACK_MS); - if (pushbackStr == null) { - if (retryPolicy.retryableStatusCodes.contains(status.getCode())) { + String pushbackStr = trailer.get(GRPC_RETRY_PUSHBACK_MS); + Integer 
pushback = null; + if (pushbackStr != null) { + try { + pushback = Integer.valueOf(pushbackStr); + } catch (NumberFormatException e) { + pushback = -1; + } + } + + boolean isThrottled = false; + if (throttle != null) { + if (isRetryableStatusCode || (pushback != null && pushback < 0)) { + isThrottled = !throttle.onQualifiedFailureThenCheckIsAboveThreshold(); + } + } + + if (retryPolicy.maxAttempts > substream.previousAttempts + 1 && !isThrottled) { + if (pushback == null) { + if (isRetryableStatusCode) { shouldRetry = true; backoffInMillis = (long) (nextBackoffIntervalInSeconds * 1000D * random.nextDouble()); nextBackoffIntervalInSeconds = Math.min( nextBackoffIntervalInSeconds * retryPolicy.backoffMultiplier, retryPolicy.maxBackoffInSeconds); } // else no retry - } else { - int pushback; - try { - pushback = Integer.parseInt(pushbackStr); - } catch (NumberFormatException e) { - pushback = -1; - } - if (pushback >= 0) { - shouldRetry = true; - backoffInMillis = pushback; - nextBackoffIntervalInSeconds = retryPolicy.initialBackoffInSeconds; - } // else no retry - } - } + } else if (pushback >= 0) { + shouldRetry = true; + backoffInMillis = pushback; + nextBackoffIntervalInSeconds = retryPolicy.initialBackoffInSeconds; + } // else no retry + } // else no retry // TODO(zdapeng): transparent retry // TODO(zdapeng): hedging - // TODO(zdapeng): throttling return new RetryPlan(shouldRetry, backoffInMillis); } @@ -831,11 +845,87 @@ abstract class RetriableStream implements ClientStream { static final class ChannelBufferMeter { private final AtomicLong bufferUsed = new AtomicLong(); - public long addAndGet(long newBytesUsed) { + @VisibleForTesting + long addAndGet(long newBytesUsed) { return bufferUsed.addAndGet(newBytesUsed); } } + /** + * Used for retry throttling. + */ + static final class Throttle { + + private static final int THREE_DECIMAL_PLACES_SCALE_UP = 1000; + + /** + * 1000 times the maxTokens field of the retryThrottling policy in service config. 
+ * The number of tokens starts at maxTokens. The token_count will always be between 0 and + * maxTokens. + */ + final int maxTokens; + + /** + * Half of {@code maxTokens}. + */ + final int threshold; + + /** + * 1000 times the tokenRatio field of the retryThrottling policy in service config. + */ + final int tokenRatio; + + final AtomicInteger tokenCount = new AtomicInteger(); + + Throttle(float maxTokens, float tokenRatio) { + // tokenRatio is up to 3 decimal places + this.tokenRatio = (int) (tokenRatio * THREE_DECIMAL_PLACES_SCALE_UP); + this.maxTokens = (int) (maxTokens * THREE_DECIMAL_PLACES_SCALE_UP); + this.threshold = this.maxTokens / 2; + tokenCount.set(this.maxTokens); + } + + @VisibleForTesting + boolean isAboveThreshold() { + return tokenCount.get() > threshold; + } + + /** + * Counts down the token on qualified failure and checks if it is above the threshold + * atomically. Qualified failure is a failure with a retryable or non-fatal status code or with + * a not-to-retry pushback. 
+ */ + @VisibleForTesting + boolean onQualifiedFailureThenCheckIsAboveThreshold() { + while (true) { + int currentCount = tokenCount.get(); + if (currentCount == 0) { + return false; + } + int decremented = currentCount - (1 * THREE_DECIMAL_PLACES_SCALE_UP); + boolean updated = tokenCount.compareAndSet(currentCount, Math.max(decremented, 0)); + if (updated) { + return decremented > threshold; + } + } + } + + @VisibleForTesting + void onSuccess() { + while (true) { + int currentCount = tokenCount.get(); + if (currentCount == maxTokens) { + break; + } + int incremented = currentCount + tokenRatio; + boolean updated = tokenCount.compareAndSet(currentCount, Math.min(incremented, maxTokens)); + if (updated) { + break; + } + } + } + } + @Immutable static final class RetryPolicy { private final int maxAttempts; diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 9b25de4eb7..62042a4750 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -49,14 +49,18 @@ import io.grpc.Status.Code; import io.grpc.StringMarshaller; import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.RetryPolicy; +import io.grpc.internal.RetriableStream.Throttle; import io.grpc.internal.StreamListener.MessageProducer; import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; @@ -114,32 +118,50 @@ public class RetriableStreamTest { .build(); private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); 
private final FakeClock fakeClock = new FakeClock(); + + private final class RecordedRetriableStream extends RetriableStream { + RecordedRetriableStream(MethodDescriptor method, Metadata headers, + ChannelBufferMeter channelBufferUsed, long perRpcBufferLimit, long channelBufferLimit, + Executor callExecutor, + ScheduledExecutorService scheduledExecutorService, + RetryPolicy retryPolicy, + @Nullable Throttle throttle) { + super(method, headers, channelBufferUsed, perRpcBufferLimit, channelBufferLimit, callExecutor, + scheduledExecutorService, retryPolicy, throttle); + } + + @Override + void postCommit() { + retriableStreamRecorder.postCommit(); + } + + @Override + ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) { + bufferSizeTracer = + tracerFactory.newClientStreamTracer(CallOptions.DEFAULT, new Metadata()); + int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null + ? 0 : Integer.valueOf(metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); + return retriableStreamRecorder.newSubstream(actualPreviousRpcAttemptsInHeader); + } + + @Override + Status prestart() { + return retriableStreamRecorder.prestart(); + } + } + private final RetriableStream retriableStream = - new RetriableStream( - method, new Metadata(),channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_LIMIT, - MoreExecutors.directExecutor(), fakeClock.getScheduledExecutorService(), RETRY_POLICY) { - @Override - void postCommit() { - retriableStreamRecorder.postCommit(); - } - - @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) { - bufferSizeTracer = - tracerFactory.newClientStreamTracer(CallOptions.DEFAULT, new Metadata()); - int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null - ? 
/**
 * Exercises {@link Throttle} on its own with maxTokens = 4 and tokenRatio = 0.8, so the
 * threshold is 2 tokens: drains the counter to zero, refills it to the cap, then drains it
 * again, checking {@code isAboveThreshold()} along the way. The trailing comments give the
 * token count after each call.
 */
@Test
public void throttle() {
  Throttle throttle = new Throttle(4f, 0.8f);
  assertTrue(throttle.isAboveThreshold());
  assertTrue(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 3
  assertTrue(throttle.isAboveThreshold());
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 2
  assertFalse(throttle.isAboveThreshold());
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 1
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 0
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 0 (clamped)
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 0 (clamped)
  assertFalse(throttle.isAboveThreshold());

  throttle.onSuccess(); // token = 0.8
  assertFalse(throttle.isAboveThreshold());
  throttle.onSuccess(); // token = 1.6
  assertFalse(throttle.isAboveThreshold());
  throttle.onSuccess(); // token = 2.4
  assertTrue(throttle.isAboveThreshold());
  throttle.onSuccess(); // token = 3.2
  assertTrue(throttle.isAboveThreshold());
  throttle.onSuccess(); // token = 4
  assertTrue(throttle.isAboveThreshold());
  throttle.onSuccess(); // token = 4 (capped at maxTokens)
  assertTrue(throttle.isAboveThreshold());

  assertTrue(throttle.isAboveThreshold());
  assertTrue(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 3
  assertTrue(throttle.isAboveThreshold());
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // token = 2
  assertFalse(throttle.isAboveThreshold());
}
/**
 * A retryable status code counts as a qualified failure even when the server also sends a
 * non-negative pushback. Here that failure takes the count from 3 down to the threshold (2),
 * so the stream is throttled: it commits instead of scheduling the pushback retry.
 */
@Test
public void throttledStream_FailWithRetriableStatusCode_WithRetriablePushback() {
  Throttle throttle = new Throttle(4f, 0.8f);
  RetriableStream retriableStream = newThrottledRetriableStream(throttle);

  ClientStream mockStream = mock(ClientStream.class);
  doReturn(mockStream).when(retriableStreamRecorder).newSubstream(anyInt());
  retriableStream.start(masterListener);
  ArgumentCaptor sublistenerCaptor =
      ArgumentCaptor.forClass(ClientStreamListener.class);
  verify(mockStream).start(sublistenerCaptor.capture());

  // mimic some other call in the channel triggers a throttle countdown
  assertTrue(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // count = 3

  int pushbackInMillis = 123;
  Metadata headers = new Metadata();
  headers.put(RetriableStream.GRPC_RETRY_PUSHBACK_MS, "" + pushbackInMillis);
  sublistenerCaptor.getValue().closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), headers);
  // Committed right away: the throttle vetoed the retry.
  verify(retriableStreamRecorder).postCommit();
  assertFalse(throttle.isAboveThreshold()); // count = 2
}
/**
 * A retryable status code together with an unparsable pushback value (the empty string, which
 * the stream treats as a negative, do-not-retry pushback) is a qualified failure: the count
 * drops from 3 to 2 and the stream commits without retrying.
 */
@Test
public void throttledStream_FailWithRetriableStatusCode_WithNonRetriablePushback() {
  Throttle throttle = new Throttle(4f, 0.8f);
  RetriableStream retriableStream = newThrottledRetriableStream(throttle);

  ClientStream mockStream = mock(ClientStream.class);
  doReturn(mockStream).when(retriableStreamRecorder).newSubstream(anyInt());
  retriableStream.start(masterListener);
  ArgumentCaptor sublistenerCaptor =
      ArgumentCaptor.forClass(ClientStreamListener.class);
  verify(mockStream).start(sublistenerCaptor.capture());

  // mimic some other call in the channel triggers a throttle countdown
  assertTrue(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // count = 3

  Metadata headers = new Metadata();
  // An empty value cannot be parsed as an integer.
  headers.put(RetriableStream.GRPC_RETRY_PUSHBACK_MS, "");
  sublistenerCaptor.getValue().closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), headers);
  verify(retriableStreamRecorder).postCommit();
  assertFalse(throttle.isAboveThreshold()); // count = 2
}
/**
 * Receiving response headers on the winning substream commits the stream and credits the
 * throttle with one success (tokenRatio = 0.8 tokens).
 */
@Test
public void throttleStream_Succeed() {
  Throttle throttle = new Throttle(4f, 0.8f);
  RetriableStream retriableStream = newThrottledRetriableStream(throttle);

  ClientStream mockStream = mock(ClientStream.class);
  doReturn(mockStream).when(retriableStreamRecorder).newSubstream(anyInt());
  retriableStream.start(masterListener);
  ArgumentCaptor sublistenerCaptor =
      ArgumentCaptor.forClass(ClientStreamListener.class);
  verify(mockStream).start(sublistenerCaptor.capture());

  // mimic some other calls in the channel trigger throttle countdowns
  assertTrue(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // count = 3
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // count = 2
  assertFalse(throttle.onQualifiedFailureThenCheckIsAboveThreshold()); // count = 1

  sublistenerCaptor.getValue().headersRead(new Metadata());
  verify(retriableStreamRecorder).postCommit();
  assertFalse(throttle.isAboveThreshold()); // count = 1.8

  // mimic some other call in the channel triggers a success
  throttle.onSuccess();
  assertTrue(throttle.isAboveThreshold()); // count = 2.6
}