Refine delay jitter for exponential backoff (#7206)

This commit is contained in:
Yuriy Holinko 2025-03-25 18:20:12 +02:00 committed by GitHub
parent 9e0efd4267
commit 3c12e3af1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 84 additions and 67 deletions

View File

@ -213,9 +213,11 @@ public final class JdkHttpSender implements HttpSender {
do { do {
if (attempt > 0) { if (attempt > 0) {
// Compute and sleep for backoff // Compute and sleep for backoff
long upperBoundNanos = Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos()); long currentBackoffNanos =
long backoffNanos = ThreadLocalRandom.current().nextLong(upperBoundNanos); Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
nextBackoffNanos = (long) (nextBackoffNanos * retryPolicy.getBackoffMultiplier()); long backoffNanos =
(long) (ThreadLocalRandom.current().nextDouble(0.8d, 1.2d) * currentBackoffNanos);
nextBackoffNanos = (long) (currentBackoffNanos * retryPolicy.getBackoffMultiplier());
try { try {
TimeUnit.NANOSECONDS.sleep(backoffNanos); TimeUnit.NANOSECONDS.sleep(backoffNanos);
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -227,16 +229,11 @@ public final class JdkHttpSender implements HttpSender {
break; break;
} }
} }
httpResponse = null;
attempt++; exception = null;
requestBuilder.timeout(Duration.ofNanos(timeoutNanos - (System.nanoTime() - startTimeNanos))); requestBuilder.timeout(Duration.ofNanos(timeoutNanos - (System.nanoTime() - startTimeNanos)));
try { try {
httpResponse = sendRequest(requestBuilder, byteBufferPool); httpResponse = sendRequest(requestBuilder, byteBufferPool);
} catch (IOException e) {
exception = e;
}
if (httpResponse != null) {
boolean retryable = retryableStatusCodes.contains(httpResponse.statusCode()); boolean retryable = retryableStatusCodes.contains(httpResponse.statusCode());
if (logger.isLoggable(Level.FINER)) { if (logger.isLoggable(Level.FINER)) {
logger.log( logger.log(
@ -251,8 +248,8 @@ public final class JdkHttpSender implements HttpSender {
if (!retryable) { if (!retryable) {
return httpResponse; return httpResponse;
} }
} } catch (IOException e) {
if (exception != null) { exception = e;
boolean retryable = retryExceptionPredicate.test(exception); boolean retryable = retryExceptionPredicate.test(exception);
if (logger.isLoggable(Level.FINER)) { if (logger.isLoggable(Level.FINER)) {
logger.log( logger.log(
@ -268,7 +265,7 @@ public final class JdkHttpSender implements HttpSender {
throw exception; throw exception;
} }
} }
} while (attempt < retryPolicy.getMaxAttempts()); } while (++attempt < retryPolicy.getMaxAttempts());
if (httpResponse != null) { if (httpResponse != null) {
return httpResponse; return httpResponse;

View File

@ -18,6 +18,7 @@ import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import okhttp3.Interceptor; import okhttp3.Interceptor;
@ -37,7 +38,7 @@ public final class RetryInterceptor implements Interceptor {
private final Function<Response, Boolean> isRetryable; private final Function<Response, Boolean> isRetryable;
private final Predicate<IOException> retryExceptionPredicate; private final Predicate<IOException> retryExceptionPredicate;
private final Sleeper sleeper; private final Sleeper sleeper;
private final BoundedLongGenerator randomLong; private final Supplier<Double> randomJitter;
/** Constructs a new retrier. */ /** Constructs a new retrier. */
public RetryInterceptor(RetryPolicy retryPolicy, Function<Response, Boolean> isRetryable) { public RetryInterceptor(RetryPolicy retryPolicy, Function<Response, Boolean> isRetryable) {
@ -48,7 +49,7 @@ public final class RetryInterceptor implements Interceptor {
? RetryInterceptor::isRetryableException ? RetryInterceptor::isRetryableException
: retryPolicy.getRetryExceptionPredicate(), : retryPolicy.getRetryExceptionPredicate(),
TimeUnit.NANOSECONDS::sleep, TimeUnit.NANOSECONDS::sleep,
bound -> ThreadLocalRandom.current().nextLong(bound)); () -> ThreadLocalRandom.current().nextDouble(0.8d, 1.2d));
} }
// Visible for testing // Visible for testing
@ -57,12 +58,12 @@ public final class RetryInterceptor implements Interceptor {
Function<Response, Boolean> isRetryable, Function<Response, Boolean> isRetryable,
Predicate<IOException> retryExceptionPredicate, Predicate<IOException> retryExceptionPredicate,
Sleeper sleeper, Sleeper sleeper,
BoundedLongGenerator randomLong) { Supplier<Double> randomJitter) {
this.retryPolicy = retryPolicy; this.retryPolicy = retryPolicy;
this.isRetryable = isRetryable; this.isRetryable = isRetryable;
this.retryExceptionPredicate = retryExceptionPredicate; this.retryExceptionPredicate = retryExceptionPredicate;
this.sleeper = sleeper; this.sleeper = sleeper;
this.randomLong = randomLong; this.randomJitter = randomJitter;
} }
@Override @Override
@ -75,9 +76,10 @@ public final class RetryInterceptor implements Interceptor {
if (attempt > 0) { if (attempt > 0) {
// Compute and sleep for backoff // Compute and sleep for backoff
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exponential-backoff // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exponential-backoff
long upperBoundNanos = Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos()); long currentBackoffNanos =
long backoffNanos = randomLong.get(upperBoundNanos); Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
nextBackoffNanos = (long) (nextBackoffNanos * retryPolicy.getBackoffMultiplier()); long backoffNanos = (long) (randomJitter.get() * currentBackoffNanos);
nextBackoffNanos = (long) (currentBackoffNanos * retryPolicy.getBackoffMultiplier());
try { try {
sleeper.sleep(backoffNanos); sleeper.sleep(backoffNanos);
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -88,14 +90,10 @@ public final class RetryInterceptor implements Interceptor {
if (response != null) { if (response != null) {
response.close(); response.close();
} }
exception = null;
} }
attempt++;
try { try {
response = chain.proceed(chain.request()); response = chain.proceed(chain.request());
} catch (IOException e) {
exception = e;
}
if (response != null) { if (response != null) {
boolean retryable = Boolean.TRUE.equals(isRetryable.apply(response)); boolean retryable = Boolean.TRUE.equals(isRetryable.apply(response));
if (logger.isLoggable(Level.FINER)) { if (logger.isLoggable(Level.FINER)) {
@ -111,8 +109,12 @@ public final class RetryInterceptor implements Interceptor {
if (!retryable) { if (!retryable) {
return response; return response;
} }
} else {
throw new NullPointerException("response cannot be null.");
} }
if (exception != null) { } catch (IOException e) {
exception = e;
response = null;
boolean retryable = retryExceptionPredicate.test(exception); boolean retryable = retryExceptionPredicate.test(exception);
if (logger.isLoggable(Level.FINER)) { if (logger.isLoggable(Level.FINER)) {
logger.log( logger.log(
@ -128,8 +130,7 @@ public final class RetryInterceptor implements Interceptor {
throw exception; throw exception;
} }
} }
} while (++attempt < retryPolicy.getMaxAttempts());
} while (attempt < retryPolicy.getMaxAttempts());
if (response != null) { if (response != null) {
return response; return response;
@ -172,11 +173,6 @@ public final class RetryInterceptor implements Interceptor {
return false; return false;
} }
// Visible for testing
// Seam for the random source behind the backoff delay, so tests can substitute deterministic
// values. The public constructor wires it as bound -> ThreadLocalRandom.current().nextLong(bound).
interface BoundedLongGenerator {
// Returns the value to use as the backoff delay in nanoseconds; callers pass the current
// (max-backoff-capped) upper bound as `bound`.
long get(long bound);
}
// Visible for testing // Visible for testing
interface Sleeper { interface Sleeper {
void sleep(long delayNanos) throws InterruptedException; void sleep(long delayNanos) throws InterruptedException;

View File

@ -9,8 +9,9 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -32,9 +33,11 @@ import java.net.UnknownHostException;
import java.time.Duration; import java.time.Duration;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import java.util.stream.Stream; import java.util.stream.Stream;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient; import okhttp3.OkHttpClient;
import okhttp3.Request; import okhttp3.Request;
import okhttp3.Response; import okhttp3.Response;
@ -47,7 +50,9 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class RetryInterceptorTest { class RetryInterceptorTest {
@ -55,7 +60,7 @@ class RetryInterceptorTest {
@RegisterExtension static final MockWebServerExtension server = new MockWebServerExtension(); @RegisterExtension static final MockWebServerExtension server = new MockWebServerExtension();
@Mock private RetryInterceptor.Sleeper sleeper; @Mock private RetryInterceptor.Sleeper sleeper;
@Mock private RetryInterceptor.BoundedLongGenerator random; @Mock private Supplier<Double> random;
private Predicate<IOException> retryExceptionPredicate; private Predicate<IOException> retryExceptionPredicate;
private RetryInterceptor retrier; private RetryInterceptor retrier;
@ -91,6 +96,24 @@ class RetryInterceptorTest {
client = new OkHttpClient.Builder().addInterceptor(retrier).build(); client = new OkHttpClient.Builder().addInterceptor(retrier).build();
} }
@Test
void noRetryOnNullResponse() throws IOException {
  // A chain whose proceed() yields null must make intercept() fail immediately with a
  // NullPointerException instead of consulting the retry predicate, jitter source, or sleeper.
  Request request = new Request.Builder().url(server.httpUri().toString()).build();
  Interceptor.Chain nullReturningChain = mock(Interceptor.Chain.class);
  when(nullReturningChain.request()).thenReturn(request);
  when(nullReturningChain.proceed(any())).thenReturn(null);

  assertThatThrownBy(() -> retrier.intercept(nullReturningChain))
      .isInstanceOf(NullPointerException.class)
      .hasMessage("response cannot be null.");

  // None of the retry collaborators may have been touched.
  verifyNoInteractions(retryExceptionPredicate, random, sleeper);
}
@Test @Test
void noRetry() throws Exception { void noRetry() throws Exception {
server.enqueue(HttpResponse.of(HttpStatus.OK)); server.enqueue(HttpResponse.of(HttpStatus.OK));
@ -109,17 +132,8 @@ class RetryInterceptorTest {
@ValueSource(ints = {5, 6}) @ValueSource(ints = {5, 6})
void backsOff(int attempts) throws Exception { void backsOff(int attempts) throws Exception {
succeedOnAttempt(attempts); succeedOnAttempt(attempts);
when(random.get()).thenReturn(1.0d);
// Will backoff 4 times doNothing().when(sleeper).sleep(anyLong());
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 0)))).thenReturn(100L);
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 1)))).thenReturn(50L);
// Capped
when(random.get(TimeUnit.SECONDS.toNanos(2))).thenReturn(500L).thenReturn(510L);
doNothing().when(sleeper).sleep(100);
doNothing().when(sleeper).sleep(50);
doNothing().when(sleeper).sleep(500);
doNothing().when(sleeper).sleep(510);
try (Response response = sendRequest()) { try (Response response = sendRequest()) {
if (attempts <= 5) { if (attempts <= 5) {
@ -139,16 +153,26 @@ class RetryInterceptorTest {
succeedOnAttempt(5); succeedOnAttempt(5);
// Backs off twice, second is interrupted // Backs off twice, second is interrupted
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 0)))).thenReturn(100L); when(random.get()).thenReturn(1.0d).thenReturn(1.0d);
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 1)))).thenReturn(50L); doAnswer(
new Answer<Void>() {
int counter = 0;
doNothing().when(sleeper).sleep(100); @Override
doThrow(new InterruptedException()).when(sleeper).sleep(50); public Void answer(InvocationOnMock invocation) throws Throwable {
if (counter++ == 1) {
throw new InterruptedException();
}
return null;
}
})
.when(sleeper)
.sleep(anyLong());
try (Response response = sendRequest()) { try (Response response = sendRequest()) {
assertThat(response.isSuccessful()).isFalse(); assertThat(response.isSuccessful()).isFalse();
} }
verify(sleeper, times(2)).sleep(anyLong());
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
server.takeRequest(0, TimeUnit.NANOSECONDS); server.takeRequest(0, TimeUnit.NANOSECONDS);
} }
@ -157,7 +181,7 @@ class RetryInterceptorTest {
@Test @Test
void connectTimeout() throws Exception { void connectTimeout() throws Exception {
client = connectTimeoutClient(); client = connectTimeoutClient();
when(random.get(anyLong())).thenReturn(1L); when(random.get()).thenReturn(1.0d);
doNothing().when(sleeper).sleep(anyLong()); doNothing().when(sleeper).sleep(anyLong());
// Connecting to a non-routable IP address to trigger connection error // Connecting to a non-routable IP address to trigger connection error
@ -174,7 +198,7 @@ class RetryInterceptorTest {
@Test @Test
void connectException() throws Exception { void connectException() throws Exception {
client = connectTimeoutClient(); client = connectTimeoutClient();
when(random.get(anyLong())).thenReturn(1L); when(random.get()).thenReturn(1.0d);
doNothing().when(sleeper).sleep(anyLong()); doNothing().when(sleeper).sleep(anyLong());
// Connecting to localhost on an unused port address to trigger java.net.ConnectException // Connecting to localhost on an unused port address to trigger java.net.ConnectException