api,core: interceptor-based config selector (#7610)

Interceptor-based config selector will be needed for fault injection.

Add `interceptor` field to `InternalConfigSelector.Result`. Keep `callOptions` and `committedCallback` fields for the moment, because it needs a refactoring to migrate the existing xds config selector implementation to the new API.
This commit is contained in:
ZHANG Dapeng 2020-11-30 09:16:22 -08:00 committed by GitHub
parent 3811ef3d22
commit ac2327deb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 524 additions and 205 deletions

View File

@ -44,6 +44,10 @@ public abstract class InternalConfigSelector {
private final CallOptions callOptions;
@Nullable
private final Runnable committedCallback;
// TODO(zdapeng): delete callOptions and committedCallback fields, and migrate to use
// interceptor only.
@Nullable
public ClientInterceptor interceptor;
private Result(
Status status, Object config, CallOptions callOptions, Runnable committedCallback) {
@ -51,6 +55,16 @@ public abstract class InternalConfigSelector {
this.config = config;
this.callOptions = callOptions;
this.committedCallback = committedCallback;
this.interceptor = null;
}
private Result(
Status status, Object config, ClientInterceptor interceptor) {
this.status = checkNotNull(status, "status");
this.config = config;
this.callOptions = null;
this.committedCallback = null;
this.interceptor = interceptor;
}
/**
@ -92,6 +106,14 @@ public abstract class InternalConfigSelector {
return committedCallback;
}
/**
* Returns an interceptor that will be applies to calls.
*/
@Nullable
public ClientInterceptor getInterceptor() {
return interceptor;
}
public static Builder newBuilder() {
return new Builder();
}
@ -100,6 +122,7 @@ public abstract class InternalConfigSelector {
private Object config;
private CallOptions callOptions;
private Runnable committedCallback;
private ClientInterceptor interceptor;
private Builder() {}
@ -124,7 +147,7 @@ public abstract class InternalConfigSelector {
}
/**
* Sets the interceptor. This field is optional.
* Sets the committed callback. This field is optional.
*
* @return this
*/
@ -133,13 +156,27 @@ public abstract class InternalConfigSelector {
return this;
}
/**
* Sets the interceptor. This field is optional.
*
* @return this
*/
public Builder setInterceptor(ClientInterceptor interceptor) {
this.interceptor = checkNotNull(interceptor, "interceptor");
return this;
}
/**
* Build this {@link Result}.
*/
public Result build() {
checkState(config != null, "config is not set");
checkState(callOptions != null, "callOptions is not set");
return new Result(Status.OK, config, callOptions, committedCallback);
if (interceptor == null) {
checkState(callOptions != null, "callOptions is not set");
return new Result(Status.OK, config, callOptions, committedCallback);
} else {
return new Result(Status.OK, config, interceptor);
}
}
}
}

View File

@ -17,6 +17,7 @@
package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import io.grpc.InternalConfigSelector.Result;
import io.grpc.Status.Code;
@ -55,6 +56,18 @@ public class InternalConfigSelectorTest {
assertThat(result.getCommittedCallback()).isSameInstanceAs(committedCallback);
}
@Test
public void resultBuilder_interceptorBased() {
Object config = "fake_config";
InternalConfigSelector.Result.Builder builder = InternalConfigSelector.Result.newBuilder();
ClientInterceptor interceptor = mock(ClientInterceptor.class);
InternalConfigSelector.Result result =
builder.setConfig(config).setInterceptor(interceptor).build();
assertThat(result.getStatus().isOk()).isTrue();
assertThat(result.getConfig()).isEqualTo(config);
assertThat(result.getInterceptor()).isSameInstanceAs(interceptor);
}
@Test
public void errorResult() {
Result result = Result.forError(Status.INTERNAL.withDescription("failed"));

View File

@ -40,10 +40,8 @@ import io.grpc.Context;
import io.grpc.Context.CancellationListener;
import io.grpc.Deadline;
import io.grpc.DecompressorRegistry;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.InternalConfigSelector;
import io.grpc.InternalDecompressorRegistry;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
@ -90,8 +88,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
private final ContextCancellationListener cancellationListener =
new ContextCancellationListener();
private final ScheduledExecutorService deadlineCancellationExecutor;
@Nullable
private final InternalConfigSelector configSelector;
private boolean fullStreamDecompression;
private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance();
private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
@ -101,6 +97,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
ClientStreamProvider clientStreamProvider,
ScheduledExecutorService deadlineCancellationExecutor,
CallTracer channelCallsTracer,
// TODO(zdapeng): remove this arg
@Nullable InternalConfigSelector configSelector) {
this.method = method;
// TODO(carl-mastrangelo): consider moving this construction to ManagedChannelImpl.
@ -123,7 +120,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
this.callOptions = callOptions;
this.clientStreamProvider = clientStreamProvider;
this.deadlineCancellationExecutor = deadlineCancellationExecutor;
this.configSelector = configSelector;
PerfMark.event("ClientCall.<init>", tag);
}
@ -220,25 +216,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
callExecutor.execute(new ClosedByContext());
return;
}
if (configSelector != null) {
PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions);
InternalConfigSelector.Result result = configSelector.selectConfig(args);
Status status = result.getStatus();
if (!status.isOk()) {
executeCloseObserverInContext(observer, status);
return;
}
callOptions = result.getCallOptions();
Runnable committedCallback = result.getCommittedCallback();
if (committedCallback != null) {
observer = new CommittedCallbackListener(observer, committedCallback);
}
ManagedChannelServiceConfig config = (ManagedChannelServiceConfig) result.getConfig();
MethodInfo methodInfo = config.getMethodConfig(method);
applyMethodConfig(methodInfo);
}
applyMethodConfig();
final String compressorName = callOptions.getCompressor();
Compressor compressor;
if (compressorName != null) {
@ -325,38 +303,11 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
}
}
private final class CommittedCallbackListener extends
SimpleForwardingClientCallListener<RespT> {
final Runnable committedCallback;
boolean committed;
CommittedCallbackListener(Listener<RespT> delegate, Runnable committedCallback) {
super(delegate);
this.committedCallback = committedCallback;
}
@Override
public void onHeaders(Metadata headers) {
committed = true;
committedCallback.run();
delegate().onHeaders(headers);
}
@Override
public void onClose(Status status, Metadata trailers) {
if (!committed) {
committed = true;
committedCallback.run();
}
delegate().onClose(status, trailers);
}
}
private void applyMethodConfig(MethodInfo info) {
private void applyMethodConfig() {
MethodInfo info = callOptions.getOption(MethodInfo.KEY);
if (info == null) {
return;
}
callOptions = callOptions.withOption(MethodInfo.KEY, info);
if (info.timeoutNanos != null) {
Deadline newDeadline = Deadline.after(info.timeoutNanos, TimeUnit.NANOSECONDS);
Deadline existingDeadline = callOptions.getDeadline();
@ -456,21 +407,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> {
new DeadlineTimer(remainingNanos)), remainingNanos, TimeUnit.NANOSECONDS);
}
private void executeCloseObserverInContext(final Listener<RespT> observer, final Status status) {
class CloseInContext extends ContextRunnable {
CloseInContext() {
super(context);
}
@Override
public void runInContext() {
closeObserver(observer, status, new Metadata());
}
}
callExecutor.execute(new CloseInContext());
}
@Nullable
private Deadline effectiveDeadline() {
// Call options and context are immutable, so we don't need to cache the deadline.

View File

@ -0,0 +1,164 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Context;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.InternalConfigSelector;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/**
* A client call for a given channel that applies a given config selector when it starts.
*/
final class ConfigSelectingClientCall<ReqT, RespT> extends ForwardingClientCall<ReqT, RespT> {
private final InternalConfigSelector configSelector;
private final Channel channel;
private final Executor callExecutor;
private final MethodDescriptor<ReqT, RespT> method;
private final Context context;
private CallOptions callOptions;
private ClientCall<ReqT, RespT> delegate;
ConfigSelectingClientCall(
InternalConfigSelector configSelector, Channel channel, Executor channelExecutor,
MethodDescriptor<ReqT, RespT> method,
CallOptions callOptions) {
this.configSelector = configSelector;
this.channel = channel;
this.method = method;
this.callOptions = callOptions;
this.callExecutor =
callOptions.getExecutor() == null ? channelExecutor : callOptions.getExecutor();
this.context = Context.current();
}
@Override
protected ClientCall<ReqT, RespT> delegate() {
return delegate;
}
@Override
public void start(Listener<RespT> observer, Metadata headers) {
PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions);
InternalConfigSelector.Result result = configSelector.selectConfig(args);
Status status = result.getStatus();
if (!status.isOk()) {
executeCloseObserverInContext(observer, status);
return;
}
ClientInterceptor interceptor = result.getInterceptor();
ManagedChannelServiceConfig config = (ManagedChannelServiceConfig) result.getConfig();
MethodInfo methodInfo = config.getMethodConfig(method);
if (methodInfo != null) {
callOptions = callOptions.withOption(MethodInfo.KEY, methodInfo);
}
if (interceptor != null) {
delegate = interceptor.interceptCall(method, callOptions, channel);
} else if (result.getCallOptions() != null) {
// TODO(zdapeng): Delete this when migrating to use interceptor-based config selector only.
callOptions = result.getCallOptions();
if (methodInfo != null) {
callOptions = callOptions.withOption(MethodInfo.KEY, methodInfo);
}
Runnable callback = result.getCommittedCallback();
if (callback != null) {
delegate =
new CommittedCallbackInterceptor(callback).interceptCall(method, callOptions, channel);
} else {
delegate = channel.newCall(method, callOptions);
}
} else {
delegate = channel.newCall(method, callOptions);
}
delegate.start(observer, headers);
}
private void executeCloseObserverInContext(
final Listener<RespT> observer, final Status status) {
class CloseInContext extends ContextRunnable {
CloseInContext() {
super(context);
}
@Override
public void runInContext() {
observer.onClose(status, new Metadata());
}
}
callExecutor.execute(new CloseInContext());
}
@Override
public void cancel(@Nullable String message, @Nullable Throwable cause) {
if (delegate != null) {
delegate.cancel(message, cause);
}
}
// TODO(zdapeng): Delete this when migrating to use interceptor-based config selector only.
private final class CommittedCallbackInterceptor implements ClientInterceptor {
final Runnable callback;
CommittedCallbackInterceptor(Runnable callback) {
this.callback = callback;
}
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> interceptCall(
MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions, Channel next) {
return new SimpleForwardingClientCall<RequestT, ResponseT>(
next.newCall(method, callOptions)) {
@Override
public void start(Listener<ResponseT> listener, Metadata headers) {
listener = new SimpleForwardingClientCallListener<ResponseT>(listener) {
boolean committed;
@Override
public void onHeaders(Metadata headers) {
committed = true;
callback.run();
delegate().onHeaders(headers);
}
@Override
public void onClose(Status status, Metadata trailers) {
if (!committed) {
callback.run();
}
delegate().onClose(status, trailers);
}
};
delegate().start(listener, headers);
}
};
}
}
}

View File

@ -76,6 +76,7 @@ import io.grpc.internal.ClientCallImpl.ClientStreamProvider;
import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider;
import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo;
import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector;
import io.grpc.internal.RetriableStream.ChannelBufferMeter;
import io.grpc.internal.RetriableStream.Throttle;
import java.net.URI;
@ -897,6 +898,29 @@ final class ManagedChannelImpl extends ManagedChannel implements
// same target, the new instance must have the same value.
private final String authority;
private final Channel clientCallImplChannel = new Channel() {
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions) {
return new ClientCallImpl<>(
method,
getCallExecutor(callOptions),
callOptions,
transportProvider,
terminated ? null : transportFactory.getScheduledExecutorService(),
channelCallTracer,
null)
.setFullStreamDecompression(fullStreamDecompression)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry);
}
@Override
public String authority() {
return authority;
}
};
private RealChannel(String authority) {
this.authority = checkNotNull(authority, "authority");
}
@ -1071,17 +1095,20 @@ final class ManagedChannelImpl extends ManagedChannel implements
private <ReqT, RespT> ClientCall<ReqT, RespT> newClientCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
return new ClientCallImpl<>(
method,
getCallExecutor(callOptions),
callOptions,
transportProvider,
terminated ? null : transportFactory.getScheduledExecutorService(),
channelCallTracer,
configSelector.get())
.setFullStreamDecompression(fullStreamDecompression)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry);
InternalConfigSelector selector = configSelector.get();
if (selector == null) {
return clientCallImplChannel.newCall(method, callOptions);
}
if (selector instanceof ServiceConfigConvertedSelector) {
MethodInfo methodInfo =
((ServiceConfigConvertedSelector) selector).config.getMethodConfig(method);
if (methodInfo != null) {
callOptions = callOptions.withOption(MethodInfo.KEY, methodInfo);
}
return clientCallImplChannel.newCall(method, callOptions);
}
return new ConfigSelectingClientCall<>(
selector, clientCallImplChannel, executor, method, callOptions);
}
}

View File

@ -180,15 +180,7 @@ final class ManagedChannelServiceConfig {
if (serviceMap.isEmpty() && serviceMethodMap.isEmpty() && defaultMethodConfig == null) {
return null;
}
return new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
return Result.newBuilder()
.setConfig(ManagedChannelServiceConfig.this)
.setCallOptions(args.getCallOptions())
.build();
}
};
return new ServiceConfigConvertedSelector(this);
}
@VisibleForTesting
@ -386,4 +378,22 @@ final class ManagedChannelServiceConfig {
ServiceConfigUtil.getNonFatalStatusCodesFromHedgingPolicy(hedgingPolicy));
}
}
static final class ServiceConfigConvertedSelector extends InternalConfigSelector {
final ManagedChannelServiceConfig config;
/** Converts the service config to config selector. */
private ServiceConfigConvertedSelector(ManagedChannelServiceConfig config) {
this.config = config;
}
@Override
public Result selectConfig(PickSubchannelArgs args) {
return Result.newBuilder()
.setConfig(config)
.setCallOptions(args.getCallOptions())
.build();
}
}
}

View File

@ -18,7 +18,6 @@ package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITTER;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@ -39,7 +38,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
@ -54,7 +52,6 @@ import io.grpc.Deadline;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.InternalConfigSelector;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
@ -392,100 +389,14 @@ public class ClientCallImplTest {
}
@Test
public void configSelectorCallOptionsPropagatedToStream() {
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(null);
configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
return Result.newBuilder()
.setConfig(ManagedChannelServiceConfig.empty())
.setCallOptions(args.getCallOptions().withAuthority("dummy_value"))
.build();
}
};
ClientCallImpl<Void, Void> call = new ClientCallImpl<>(
method,
MoreExecutors.directExecutor(),
baseCallOptions,
clientStreamProvider,
deadlineCancellationExecutor,
channelCallTracer, configSelector)
.setDecompressorRegistry(decompressorRegistry);
call.start(callListener, new Metadata());
verify(clientStreamProvider).newStream(
same(method), callOptionsCaptor.capture(), any(Metadata.class), any(Context.class));
assertThat(callOptionsCaptor.getValue().getAuthority()).isEqualTo("dummy_value");
}
@Test
public void methodConfigPropagatedToStream() {
Map<String, ?> rawMethodConfig = ImmutableMap.of(
"retryPolicy",
ImmutableMap.of(
"maxAttempts", 3.0D,
"initialBackoff", "1s",
"maxBackoff", "10s",
"backoffMultiplier", 1.5D,
"retryableStatusCodes", ImmutableList.of("UNAVAILABLE")
));
final MethodInfo methodInfo = new MethodInfo(rawMethodConfig, true, 4, 4);
configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
ManagedChannelServiceConfig config = new ManagedChannelServiceConfig(
methodInfo,
ImmutableMap.<String, MethodInfo>of(),
ImmutableMap.<String, MethodInfo>of(),
null,
null,
null);
return Result.newBuilder()
.setConfig(config)
.setCallOptions(args.getCallOptions())
.build();
}
};
ClientCallImpl<Void, Void> call = new ClientCallImpl<>(
method,
MoreExecutors.directExecutor(),
baseCallOptions,
clientStreamProvider,
deadlineCancellationExecutor,
channelCallTracer, configSelector)
.setDecompressorRegistry(decompressorRegistry);
call.start(callListener, new Metadata());
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(null);
verify(clientStreamProvider).newStream(
same(method), callOptionsCaptor.capture(), any(Metadata.class), any(Context.class));
assertThat(callOptionsCaptor.getValue().getOption(MethodInfo.KEY)).isEqualTo(methodInfo);
}
@Test
public void configDeadlinePropagatedToStream() {
public void methodInfoDeadlinePropagatedToStream() {
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(null);
CallOptions callOptions = baseCallOptions.withDeadline(Deadline.after(2000, SECONDS));
// Case: config Deadline expires later than CallOptions Deadline
configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
Map<String, ?> rawMethodConfig = ImmutableMap.of(
"timeout",
"3000s");
MethodInfo methodInfo = new MethodInfo(rawMethodConfig, false, 0, 0);
ManagedChannelServiceConfig config = new ManagedChannelServiceConfig(
methodInfo,
ImmutableMap.<String, MethodInfo>of(),
ImmutableMap.<String, MethodInfo>of(),
null,
null,
null);
return Result.newBuilder()
.setConfig(config)
.setCallOptions(args.getCallOptions())
.build();
}
};
Map<String, ?> rawMethodConfig = ImmutableMap.of("timeout", "3000s");
MethodInfo methodInfo = new MethodInfo(rawMethodConfig, false, 0, 0);
callOptions = callOptions.withOption(MethodInfo.KEY, methodInfo);
ClientCallImpl<Void, Void> call = new ClientCallImpl<>(
method,
MoreExecutors.directExecutor(),
@ -501,26 +412,9 @@ public class ClientCallImplTest {
assertThat(actualDeadline).isLessThan(Deadline.after(2001, SECONDS));
// Case: config Deadline expires earlier than CallOptions Deadline
configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
Map<String, ?> rawMethodConfig = ImmutableMap.of(
"timeout",
"1000s");
MethodInfo methodInfo = new MethodInfo(rawMethodConfig, false, 0, 0);
ManagedChannelServiceConfig config = new ManagedChannelServiceConfig(
methodInfo,
ImmutableMap.<String, MethodInfo>of(),
ImmutableMap.<String, MethodInfo>of(),
null,
null,
null);
return Result.newBuilder()
.setConfig(config)
.setCallOptions(args.getCallOptions())
.build();
}
};
rawMethodConfig = ImmutableMap.of("timeout", "1000s");
methodInfo = new MethodInfo(rawMethodConfig, false, 0, 0);
callOptions = callOptions.withOption(MethodInfo.KEY, methodInfo);
call = new ClientCallImpl<>(
method,
MoreExecutors.directExecutor(),
@ -533,7 +427,7 @@ public class ClientCallImplTest {
verify(clientStreamProvider, times(2)).newStream(
same(method), callOptionsCaptor.capture(), any(Metadata.class), any(Context.class));
actualDeadline = callOptionsCaptor.getValue().getDeadline();
assertThat(actualDeadline).isLessThan(Deadline.after(1001, MINUTES));
assertThat(actualDeadline).isLessThan(Deadline.after(1001, SECONDS));
}
@Test

View File

@ -0,0 +1,163 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.InternalConfigSelector;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo;
import io.grpc.testing.TestMethodDescriptors;
import java.util.Map;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
/** Tests for {@link ConfigSelectingClientCall}. */
@RunWith(JUnit4.class)
public class ConfigSelectingClientCallTest {
@Rule
public MockitoRule mockitoRule = MockitoJUnit.rule();
private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();
private TestChannel channel = new TestChannel();
// The underlying call directly created by the channel.
private TestCall<?, ?> call;
@Mock
private ClientCall.Listener<Void> callListener;
@Test
public void configSelectorInterceptsCall() {
Map<String, ?> rawMethodConfig = ImmutableMap.of(
"retryPolicy",
ImmutableMap.of(
"maxAttempts", 3.0D,
"initialBackoff", "1s",
"maxBackoff", "10s",
"backoffMultiplier", 1.5D,
"retryableStatusCodes", ImmutableList.of("UNAVAILABLE")
));
final MethodInfo methodInfo = new MethodInfo(rawMethodConfig, true, 4, 4);
final Metadata.Key<String> metadataKey =
Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER);
final CallOptions.Key<String> callOptionsKey = CallOptions.Key.create("test");
InternalConfigSelector configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(final PickSubchannelArgs args) {
ManagedChannelServiceConfig config = new ManagedChannelServiceConfig(
methodInfo,
ImmutableMap.<String, MethodInfo>of(),
ImmutableMap.<String, MethodInfo>of(),
null,
null,
null);
return Result.newBuilder()
.setConfig(config)
.setInterceptor(
// An interceptor that mutates CallOptions based on headers value.
new ClientInterceptor() {
String value = args.getHeaders().get(metadataKey);
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
callOptions = callOptions.withOption(callOptionsKey, value);
return next.newCall(method, callOptions);
}
})
.build();
}
};
ClientCall<Void, Void> configSelectingClientCall = new ConfigSelectingClientCall<>(
configSelector,
channel,
MoreExecutors.directExecutor(),
method,
CallOptions.DEFAULT.withAuthority("bar.authority"));
Metadata metadata = new Metadata();
metadata.put(metadataKey, "fooValue");
configSelectingClientCall.start(callListener, metadata);
assertThat(call.callOptions.getAuthority()).isEqualTo("bar.authority");
assertThat(call.callOptions.getOption(MethodInfo.KEY)).isEqualTo(methodInfo);
assertThat(call.callOptions.getOption(callOptionsKey)).isEqualTo("fooValue");
}
@Test
public void selectionErrorPropagatedToListener() {
InternalConfigSelector configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(PickSubchannelArgs args) {
return Result.forError(Status.FAILED_PRECONDITION);
}
};
ClientCall<Void, Void> configSelectingClientCall = new ConfigSelectingClientCall<>(
configSelector,
channel,
MoreExecutors.directExecutor(),
method,
CallOptions.DEFAULT);
configSelectingClientCall.start(callListener, new Metadata());
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(null);
verify(callListener).onClose(statusCaptor.capture(), any(Metadata.class));
assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.FAILED_PRECONDITION);
}
private final class TestChannel extends Channel {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions) {
TestCall<ReqT, RespT> clientCall = new TestCall<>(callOptions);
call = clientCall;
return clientCall;
}
@Override
public String authority() {
return "foo.authority";
}
}
private static final class TestCall<ReqT, RespT> extends NoopClientCall<ReqT, RespT> {
// CallOptions actually received from the channel when the call is created.
final CallOptions callOptions;
TestCall(CallOptions callOptions) {
this.callOptions = callOptions;
}
}
}

View File

@ -606,6 +606,81 @@ public class ManagedChannelImplTest {
TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS));
}
@Test
public void newCallWithConfigSelector_interceptorBased() {
FakeNameResolverFactory nameResolverFactory =
new FakeNameResolverFactory.Builder(expectedUri)
.setServers(ImmutableList.of(addressGroup)).build();
channelBuilder.nameResolverFactory(nameResolverFactory);
channel = new ManagedChannelImpl(
channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(),
balancerRpcExecutorPool, timer.getStopwatchSupplier(),
Collections.<ClientInterceptor>emptyList(), timer.getTimeProvider());
nameResolverFactory.nextConfigOrError.set(
ConfigOrError.fromConfig(ManagedChannelServiceConfig.empty()));
final Metadata.Key<String> metadataKey =
Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER);
final CallOptions.Key<String> callOptionsKey = CallOptions.Key.create("test");
InternalConfigSelector configSelector = new InternalConfigSelector() {
@Override
public Result selectConfig(final PickSubchannelArgs args) {
return Result.newBuilder()
.setConfig(ManagedChannelServiceConfig.empty())
.setInterceptor(
// An interceptor that mutates CallOptions based on headers value.
new ClientInterceptor() {
String value = args.getHeaders().get(metadataKey);
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
callOptions = callOptions.withOption(callOptionsKey, value);
return next.newCall(method, callOptions);
}
})
.build();
}
};
nameResolverFactory.nextAttributes.set(
Attributes.newBuilder().set(InternalConfigSelector.KEY, configSelector).build());
channel.getState(true);
Metadata headers = new Metadata();
headers.put(metadataKey, "fooValue");
ClientStream mockStream = mock(ClientStream.class);
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT);
call.start(mockCallListener, headers);
ArgumentCaptor<Helper> helperCaptor = ArgumentCaptor.forClass(null);
verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture());
helper = helperCaptor.getValue();
// Make the transport available
Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory)
.newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
MockClientTransportInfo transportInfo = transports.poll();
ConnectionClientTransport mockTransport = transportInfo.transport;
ManagedClientTransport.Listener transportListener = transportInfo.listener;
when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class)))
.thenReturn(mockStream);
transportListener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(helper, READY, mockPicker);
executor.runDueTasks();
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture());
assertThat(callOptionsCaptor.getValue().getOption(callOptionsKey)).isEqualTo("fooValue");
verify(mockStream).start(streamListenerCaptor.capture());
// Clean up as much as possible to allow the channel to terminate.
shutdownSafely(helper, subchannel);
timer.forwardNanos(
TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS));
}
@Test
public void shutdownWithNoTransportsEverCreated() {
channelBuilder.nameResolverFactory(