From 7fd5f261b472a564fc4a6da0ebd20928d5f8c761 Mon Sep 17 00:00:00 2001 From: Chengyuan Zhang Date: Fri, 24 May 2019 15:12:22 -0700 Subject: [PATCH] xds: implement lb policy backend metric api (#5639) * implemented utility methods to create ClientStreamTracer.Factory with OrcaReportListener installed for retrieving per-request ORCA data * added unit tests * use delegatesTo instead of spy * implemented OrcaReportingHelper delegating to some original Helper for load balancing policies accessing OOB metric reports * added unit tests for out-of-band ORCA metric accessing API in a separate test class * rebase to master, resolve the breaking change of StreamInfo class being final with builder * trashed hashCode/equal for OrcaReportingConfig * changed log level and channel trace event level to ERROR as required by design doc * added OrcaReportingHelperWrapper layer to allow updating report interval at any time * reverse the naming of parent/child helper, child helper is the outer-most helper in the wrapping structure * changed orca listener interface to use separate listener interfaces for per-request and out-of-band cases * added more comprehensive unit tests * added test case for per-request reporting that parent creates its own stream tracer * fixed bug of directly assign reporting config, which would cause it be mutated later * separate test cases for updating reporting config at different time * fixed lint style error * polish comments * minor polish in unit tests * refactor OrcaUtil class into OrcaOobUtil and OrcaPerRequestUtil and get rid of static methods for easier user testing * hide BackoffPolicyProvider and Stopwatch supplier in OrcaOobUtil's public API * add javadoc for getInstance() methods * ensure the same Subchannel instance created by the helper that has corresponding OrcaOobReportListener registered are passed to the listener callback * removed costNames foe OrcaReportingConfig * removed redundant checks * reformated the OrcaOobUtilTest class to put helper methods in the bottom * fixed impl with changes made on Subchannel (SubchannelStateListener now ties with Subchannel) * fixed comments * added usage examples in javadoc for OrcaUtils * add method comments for OrcaUtil's listener API threading * make fields in OrcaReportingConfig final * fixed OrcaOobUtilTest for calling setOrcaReportingConfig inside syncContext * added ExperimentalApi annotation for Orca utils --- .../main/java/io/grpc/xds/OrcaOobUtil.java | 629 ++++++++++++ .../java/io/grpc/xds/OrcaPerRequestUtil.java | 270 ++++++ .../java/io/grpc/xds/OrcaOobUtilTest.java | 892 ++++++++++++++++++ .../io/grpc/xds/OrcaPerRequestUtilTest.java | 170 ++++ 4 files changed, 1961 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/OrcaOobUtil.java create mode 100644 xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java create mode 100644 xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java create mode 100644 xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java diff --git a/xds/src/main/java/io/grpc/xds/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/OrcaOobUtil.java new file mode 100644 index 0000000000..fa63f8a748 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/OrcaOobUtil.java @@ -0,0 +1,629 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; +import com.google.protobuf.util.Durations; +import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport; +import io.envoyproxy.udpa.service.orca.v1.OpenRcaServiceGrpc; +import io.envoyproxy.udpa.service.orca.v1.OrcaLoadReportRequest; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.ClientCall; +import io.grpc.ConnectivityStateInfo; +import io.grpc.ExperimentalApi; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.GrpcUtil; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.util.ForwardingSubchannel; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Utility class that provides method for {@link LoadBalancer} to install listeners to receive + * out-of-band backend cost metrics in the format of Open Request Cost Aggregation (ORCA). + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/5790") +public abstract class OrcaOobUtil { + + private static final Logger logger = Logger.getLogger(OrcaPerRequestUtil.class.getName()); + private static final OrcaOobUtil DEFAULT_INSTANCE = + new OrcaOobUtil() { + + @Override + public OrcaReportingHelperWrapper newOrcaReportingHelperWrapper( + LoadBalancer.Helper delegate, + OrcaOobReportListener listener) { + return newOrcaReportingHelperWrapper( + delegate, + listener, + new ExponentialBackoffPolicy.Provider(), + GrpcUtil.STOPWATCH_SUPPLIER); + } + }; + + /** + * Gets an {@code OrcaOobUtil} instance that provides actual implementation of + * {@link #newOrcaReportingHelperWrapper}. + */ + public static OrcaOobUtil getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates a new {@link LoadBalancer.Helper} with provided {@link OrcaOobReportListener} installed + * to receive callback when an out-of-band ORCA report is received. + * + *

Example usages: + * + *

+ * + * @param delegate the delegate helper that provides essentials for establishing subchannels to + * backends. + * @param listener contains the callback to be invoked when an out-of-band ORCA report is + * received. + */ + public abstract OrcaReportingHelperWrapper newOrcaReportingHelperWrapper( + LoadBalancer.Helper delegate, + OrcaOobReportListener listener); + + @VisibleForTesting + static OrcaReportingHelperWrapper newOrcaReportingHelperWrapper( + LoadBalancer.Helper delegate, + OrcaOobReportListener listener, + BackoffPolicy.Provider backoffPolicyProvider, + Supplier stopwatchSupplier) { + final OrcaReportingHelper orcaHelper = + new OrcaReportingHelper(delegate, listener, backoffPolicyProvider, stopwatchSupplier); + + return new OrcaReportingHelperWrapper() { + @Override + public void setReportingConfig(OrcaReportingConfig config) { + orcaHelper.setReportingConfig(config); + } + + @Override + public Helper asHelper() { + return orcaHelper; + } + }; + } + + /** + * The listener interface for receiving out-of-band ORCA reports from backends. The class that is + * interested in processing backend cost metrics implements this interface, and the object created + * with that class is registered with a component, using methods in {@link OrcaPerRequestUtil}. + * When an ORCA report is received, that object's {@code onLoadReport} method is invoked. + */ + public interface OrcaOobReportListener { + + /** + * Invoked when an out-of-band ORCA report is received. + * + *

Note this callback will be invoked from the {@link SynchronizationContext} of the + * delegated helper, implementations should not block. + * + * @param report load report in the format of ORCA protocol. + */ + void onLoadReport(OrcaLoadReport report); + } + + /** + * Blueprint for the wrapper that wraps a {@link LoadBalancer.Helper} with the capability of + * allowing {@link LoadBalancer}s interested in receiving out-of-band ORCA reports to update the + * reporting configuration such as reporting interval. + */ + public abstract static class OrcaReportingHelperWrapper { + + /** + * Sets the configuration of receiving ORCA reports, such as the interval of receiving reports. + * + *

This method needs to be called from the SynchronizationContext returned by the wrapped + * helper's {@link Helper#getSynchronizationContext()}. + * + *

Each load balancing policy must call this method to configure the backend load reporting. + * Otherwise, it will not receive ORCA reports. + * + *

If multiple load balancing policies configure reporting with different intervals, reports + * come with the minimum of those intervals. + * + * @param config the configuration to be set. + */ + public abstract void setReportingConfig(OrcaReportingConfig config); + + /** + * Returns a wrapped {@link LoadBalancer.Helper}. Subchannels created through it will retrieve + * ORCA load reports if the server supports it. + */ + public abstract LoadBalancer.Helper asHelper(); + } + + /** + * An {@link OrcaReportingHelper} wraps a delegated {@link LoadBalancer.Helper} with additional + * functionality to manage RPCs for out-of-band ORCA reporting for each backend it establishes + * connection to. + */ + private static final class OrcaReportingHelper extends ForwardingLoadBalancerHelper + implements OrcaOobReportListener { + + private static final CreateSubchannelArgs.Key ORCA_REPORTING_STATE_KEY = + CreateSubchannelArgs.Key.create("internal-orca-reporting-state"); + private final LoadBalancer.Helper delegate; + private final OrcaOobReportListener listener; + private final SynchronizationContext syncContext; + private final BackoffPolicy.Provider backoffPolicyProvider; + private final Supplier stopwatchSupplier; + private final Set orcaStates = new HashSet<>(); + @Nullable private OrcaReportingConfig orcaConfig; + + OrcaReportingHelper( + LoadBalancer.Helper delegate, + OrcaOobReportListener listener, + BackoffPolicy.Provider backoffPolicyProvider, + Supplier stopwatchSupplier) { + this.delegate = checkNotNull(delegate, "delegate"); + this.listener = checkNotNull(listener, "listener"); + this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); + this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); + syncContext = checkNotNull(delegate.getSynchronizationContext(), "syncContext"); + } + + @Override + protected Helper delegate() { + return delegate; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + syncContext.throwIfNotInThisSynchronizationContext(); + OrcaReportingState orcaState = args.getOption(ORCA_REPORTING_STATE_KEY); + boolean augmented = false; + if (orcaState == null) { + // Only the first load balancing policy requesting ORCA reports instantiates an + // OrcaReportingState. + orcaState = new OrcaReportingState(this, syncContext, + delegate().getScheduledExecutorService()); + args = args.toBuilder().addOption(ORCA_REPORTING_STATE_KEY, orcaState).build(); + augmented = true; + } + orcaStates.add(orcaState); + orcaState.listeners.add(this); + Subchannel subchannel = super.createSubchannel(args); + if (augmented) { + subchannel = new SubchannelImpl(subchannel, orcaState); + } + if (orcaConfig != null) { + orcaState.setReportingConfig(this, orcaConfig); + } + return subchannel; + } + + void setReportingConfig(final OrcaReportingConfig config) { + syncContext.throwIfNotInThisSynchronizationContext(); + orcaConfig = config; + for (OrcaReportingState state : orcaStates) { + state.setReportingConfig(OrcaReportingHelper.this, config); + } + } + + @Override + public void onLoadReport(OrcaLoadReport report) { + syncContext.throwIfNotInThisSynchronizationContext(); + if (orcaConfig != null) { + listener.onLoadReport(report); + } + } + + /** + * An {@link OrcaReportingState} is a client of ORCA service running on a single backend. + * + *

All methods are run from {@code syncContext}. + */ + private final class OrcaReportingState implements SubchannelStateListener { + + private final OrcaReportingHelper orcaHelper; + private final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; + private final List listeners = new ArrayList<>(); + private final Map configs = new HashMap<>(); + @Nullable private Subchannel subchannel; + @Nullable private ChannelLogger subchannelLogger; + @Nullable + private SubchannelStateListener stateListener; + @Nullable private BackoffPolicy backoffPolicy; + @Nullable private OrcaReportingStream orcaRpc; + @Nullable private ScheduledHandle retryTimer; + @Nullable private OrcaReportingConfig overallConfig; + private final Runnable retryTask = + new Runnable() { + @Override + public void run() { + startRpc(); + } + }; + private ConnectivityStateInfo state = ConnectivityStateInfo.forNonError(IDLE); + // True if server returned UNIMPLEMENTED. + private boolean disabled; + + OrcaReportingState( + OrcaReportingHelper orcaHelper, + SynchronizationContext syncContext, + ScheduledExecutorService timeService) { + this.orcaHelper = checkNotNull(orcaHelper, "orcaHelper"); + this.syncContext = checkNotNull(syncContext, "syncContext"); + this.timeService = checkNotNull(timeService, "timeService"); + } + + void init(Subchannel subchannel, SubchannelStateListener stateListener) { + checkState(this.subchannel == null, "init() already called"); + this.subchannel = checkNotNull(subchannel, "subchannel"); + this.subchannelLogger = checkNotNull(subchannel.getChannelLogger(), "subchannelLogger"); + this.stateListener = checkNotNull(stateListener, "stateListener"); + } + + void setReportingConfig(OrcaReportingHelper helper, OrcaReportingConfig config) { + boolean reconfigured = false; + configs.put(helper, config); + // Real reporting interval is the minimum of intervals requested by all participating + // helpers. + if (overallConfig == null) { + overallConfig = config.toBuilder().build(); + reconfigured = true; + } else { + long minInterval = Long.MAX_VALUE; + for (OrcaReportingConfig c : configs.values()) { + if (c.getReportIntervalNanos() < minInterval) { + minInterval = c.getReportIntervalNanos(); + } + } + if (overallConfig.getReportIntervalNanos() != minInterval) { + overallConfig = overallConfig.toBuilder() + .setReportInterval(minInterval, TimeUnit.NANOSECONDS).build(); + reconfigured = true; + } + } + if (reconfigured) { + stopRpc("ORCA reporting reconfigured"); + adjustOrcaReporting(); + } + } + + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + if (Objects.equal(state.getState(), READY) && !Objects.equal(newState.getState(), READY)) { + // A connection was lost. We will reset disabled flag because ORCA service + // may be available on the new connection. + disabled = false; + } + if (Objects.equal(newState.getState(), SHUTDOWN)) { + orcaHelper.orcaStates.remove(this); + } + state = newState; + adjustOrcaReporting(); + // Propagate subchannel state update to downstream listeners. + stateListener.onSubchannelState(newState); + } + + void adjustOrcaReporting() { + if (!disabled && overallConfig != null && Objects.equal(state.getState(), READY)) { + if (orcaRpc == null && !isRetryTimerPending()) { + startRpc(); + } + } else { + stopRpc("Client stops ORCA reporting"); + backoffPolicy = null; + } + } + + void startRpc() { + checkState(orcaRpc == null, "previous orca reporting RPC has not been cleaned up"); + checkState(subchannel != null, "init() not called"); + subchannelLogger.log( + ChannelLogLevel.DEBUG, "Starting ORCA reporting for {0}", subchannel.getAllAddresses()); + orcaRpc = new OrcaReportingStream(subchannel.asChannel(), stopwatchSupplier.get()); + orcaRpc.start(); + } + + void stopRpc(String msg) { + if (orcaRpc != null) { + orcaRpc.cancel(msg); + orcaRpc = null; + } + if (retryTimer != null) { + retryTimer.cancel(); + retryTimer = null; + } + } + + boolean isRetryTimerPending() { + return retryTimer != null && retryTimer.isPending(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("disabled", disabled) + .add("orcaRpc", orcaRpc) + .add("reportingConfig", overallConfig) + .add("connectivityState", state) + .toString(); + } + + private class OrcaReportingStream extends ClientCall.Listener { + + private final ClientCall call; + private final Stopwatch stopwatch; + private boolean callHasResponded; + + OrcaReportingStream(Channel channel, Stopwatch stopwatch) { + call = + checkNotNull(channel, "channel") + .newCall(OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + this.stopwatch = checkNotNull(stopwatch, "stopwatch"); + } + + void start() { + stopwatch.reset().start(); + call.start(this, new Metadata()); + call.sendMessage( + OrcaLoadReportRequest.newBuilder() + .setReportInterval(Durations.fromNanos(overallConfig.getReportIntervalNanos())) + .build()); + call.halfClose(); + call.request(1); + } + + @Override + public void onMessage(final OrcaLoadReport response) { + syncContext.execute( + new Runnable() { + @Override + public void run() { + if (orcaRpc == OrcaReportingStream.this) { + handleResponse(response); + } + } + }); + } + + @Override + public void onClose(final Status status, Metadata trailers) { + syncContext.execute( + new Runnable() { + @Override + public void run() { + if (orcaRpc == OrcaReportingStream.this) { + orcaRpc = null; + handleStreamClosed(status); + } + } + }); + } + + void handleResponse(OrcaLoadReport response) { + callHasResponded = true; + backoffPolicy = null; + subchannelLogger.log(ChannelLogLevel.DEBUG, "Received an ORCA report: {0}", response); + for (OrcaOobReportListener listener : listeners) { + listener.onLoadReport(response); + } + call.request(1); + } + + void handleStreamClosed(Status status) { + if (Objects.equal(status.getCode(), Code.UNIMPLEMENTED)) { + disabled = true; + logger.log( + Level.SEVERE, + "Backend {0} OpenRcaService is disabled. Server returned: {1}", + new Object[] {subchannel.getAllAddresses(), status}); + subchannelLogger.log(ChannelLogLevel.ERROR, "OpenRcaService disabled: {0}", status); + return; + } + long delayNanos = 0; + // Backoff only when no response has been received. + if (!callHasResponded) { + if (backoffPolicy == null) { + backoffPolicy = backoffPolicyProvider.get(); + } + delayNanos = backoffPolicy.nextBackoffNanos() - stopwatch.elapsed(TimeUnit.NANOSECONDS); + } + subchannelLogger.log( + ChannelLogLevel.DEBUG, + "ORCA reporting stream closed with {0}, backoff in {1} ns", + status, + delayNanos <= 0 ? 0 : delayNanos); + if (delayNanos <= 0) { + startRpc(); + } else { + checkState(!isRetryTimerPending(), "Retry double scheduled"); + retryTimer = + syncContext.schedule(retryTask, delayNanos, TimeUnit.NANOSECONDS, timeService); + } + } + + void cancel(String msg) { + call.cancel(msg, null); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("callStarted", call != null) + .add("callHasResponded", callHasResponded) + .toString(); + } + } + } + } + + @VisibleForTesting + static final class SubchannelImpl extends ForwardingSubchannel { + + private final Subchannel delegate; + private final OrcaReportingHelper.OrcaReportingState orcaState; + + SubchannelImpl(Subchannel delegate, OrcaReportingHelper.OrcaReportingState orcaState) { + this.delegate = checkNotNull(delegate, "delegate"); + this.orcaState = checkNotNull(orcaState, "orcaState"); + } + + @Override + protected Subchannel delegate() { + return delegate; + } + + @Override + public void start(SubchannelStateListener listener) { + orcaState.init(this, listener); + super.start(orcaState); + } + } + + /** Configuration for out-of-band ORCA reporting service RPC. */ + public static final class OrcaReportingConfig { + + private final long reportIntervalNanos; + + private OrcaReportingConfig(long reportIntervalNanos) { + this.reportIntervalNanos = reportIntervalNanos; + } + + /** Creates a new builder. */ + public static Builder newBuilder() { + return new Builder(); + } + + /** Returns the configured maximum interval of receiving out-of-band ORCA reports. */ + public long getReportIntervalNanos() { + return reportIntervalNanos; + } + + /** Returns a builder with the same initial values as this object. */ + public Builder toBuilder() { + return newBuilder().setReportInterval(reportIntervalNanos, TimeUnit.NANOSECONDS); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("reportIntervalNanos", reportIntervalNanos) + .toString(); + } + + public static final class Builder { + + private long reportIntervalNanos; + + Builder() {} + + /** + * Sets the maximum expected interval of receiving out-of-band ORCA report. The actual + * reporting interval might be smaller if there are other load balancing policies requesting + * for more frequent cost metric report. + * + * @param reportInterval the maximum expected interval of receiving periodical ORCA reports. + * @param unit time unit of {@code reportInterval} value. + */ + public Builder setReportInterval(long reportInterval, TimeUnit unit) { + reportIntervalNanos = unit.toNanos(reportInterval); + return this; + } + + /** Creates a new {@link OrcaReportingConfig} object. */ + public OrcaReportingConfig build() { + return new OrcaReportingConfig(reportIntervalNanos); + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java new file mode 100644 index 0000000000..6cac6bcfea --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java @@ -0,0 +1,270 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.ExperimentalApi; +import io.grpc.LoadBalancer; +import io.grpc.Metadata; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.util.ForwardingClientStreamTracer; +import java.util.ArrayList; +import java.util.List; + +/** + * Utility class that provides method for {@link LoadBalancer} to install listeners to receive + * per-request backend cost metrics in the format of Open Request Cost Aggregation (ORCA). + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/5790") +public abstract class OrcaPerRequestUtil { + private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER = new ClientStreamTracer() {}; + private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY = + new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return NOOP_CLIENT_STREAM_TRACER; + } + }; + private static final OrcaPerRequestUtil DEFAULT_INSTANCE = + new OrcaPerRequestUtil() { + @Override + public ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( + OrcaPerRequestReportListener listener) { + return newOrcaClientStreamTracerFactory(NOOP_CLIENT_STREAM_TRACER_FACTORY, listener); + } + + @Override + public ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( + ClientStreamTracer.Factory delegate, OrcaPerRequestReportListener listener) { + return new OrcaReportingTracerFactory(delegate, listener); + } + }; + + /** + * Gets an {@code OrcaPerRequestUtil} instance that provides actual implementation of + * {@link #newOrcaClientStreamTracerFactory}. + */ + public static OrcaPerRequestUtil getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates a new {@link ClientStreamTracer.Factory} with provided {@link + * OrcaPerRequestReportListener} installed to receive callback when a per-request ORCA report is + * received. + * + *

Example usages for leaf level policy (e.g., WRR policy) + * + *

+   *   {@code
+   *   class WrrPicker extends SubchannelPicker {
+   *
+   *     public PickResult pickSubchannel(PickSubchannelArgs args) {
+   *       Subchannel subchannel = ...  // WRR picking logic
+   *       return PickResult.withSubchannel(
+   *           subchannel,
+   *           OrcaPerRequestReportUtil.getInstance().newOrcaClientStreamTracerFactory(listener));
+   *     }
+   *   }
+   *   }
+   * 
+ * + * @param listener contains the callback to be invoked when a per-request ORCA report is received. + */ + public abstract ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( + OrcaPerRequestReportListener listener); + + /** + * Creates a new {@link ClientStreamTracer.Factory} with provided {@link + * OrcaPerRequestReportListener} installed to receive callback when a per-request ORCA report is + * received. + * + *

Example usages: + * + *

+ * + * @param delegate the delegate factory to produce other client stream tracing. + * @param listener contains the callback to be invoked when a per-request ORCA report is received. + */ + public abstract ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( + ClientStreamTracer.Factory delegate, OrcaPerRequestReportListener listener); + + /** + * The listener interface for receiving per-request ORCA reports from backends. The class that is + * interested in processing backend cost metrics implements this interface, and the object created + * with that class is registered with a component, using methods in {@link OrcaPerRequestUtil}. + * When an ORCA report is received, that object's {@code onLoadReport} method is invoked. + */ + public interface OrcaPerRequestReportListener { + + /** + * Invoked when an per-request ORCA report is received. + * + *

Note this callback will be invoked from the network thread as the RPC finishes, + * implementations should not block. + * + * @param report load report in the format of ORCA format. + */ + void onLoadReport(OrcaLoadReport report); + } + + /** + * An {@link OrcaReportingTracerFactory} wraps a delegated {@link ClientStreamTracer.Factory} with + * additional functionality to produce {@link ClientStreamTracer} instances that extract + * per-request ORCA reports and push to registered listeners for calls they trace. + */ + @VisibleForTesting + static final class OrcaReportingTracerFactory extends ClientStreamTracer.Factory { + + @VisibleForTesting + static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = + Metadata.Key.of( + "x-endpoint-load-metrics-bin", + ProtoUtils.metadataMarshaller(OrcaLoadReport.getDefaultInstance())); + + private static final CallOptions.Key ORCA_REPORT_BROKER_KEY = + CallOptions.Key.create("internal-orca-report-broker"); + private final ClientStreamTracer.Factory delegate; + private final OrcaPerRequestReportListener listener; + + OrcaReportingTracerFactory( + ClientStreamTracer.Factory delegate, OrcaPerRequestReportListener listener) { + this.delegate = checkNotNull(delegate, "delegate"); + this.listener = checkNotNull(listener, "listener"); + } + + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + OrcaReportBroker broker = info.getCallOptions().getOption(ORCA_REPORT_BROKER_KEY); + boolean augmented = false; + if (broker == null) { + broker = new OrcaReportBroker(); + info = + info.toBuilder() + .setCallOptions(info.getCallOptions().withOption(ORCA_REPORT_BROKER_KEY, broker)) + .build(); + augmented = true; + } + broker.addListener(listener); + ClientStreamTracer tracer = delegate.newClientStreamTracer(info, headers); + if (augmented) { + final ClientStreamTracer currTracer = tracer; + final OrcaReportBroker currBroker = broker; + // The actual tracer that performs ORCA report deserialization. + tracer = + new ForwardingClientStreamTracer() { + @Override + protected ClientStreamTracer delegate() { + return currTracer; + } + + @Override + public void inboundTrailers(Metadata trailers) { + OrcaLoadReport report = trailers.get(ORCA_ENDPOINT_LOAD_METRICS_KEY); + if (report != null) { + currBroker.onReport(report); + } + delegate().inboundTrailers(trailers); + } + }; + } + return tracer; + } + } + + /** + * A container class to hold registered {@link OrcaPerRequestReportListener}s and invoke all of + * them when an {@link OrcaLoadReport} is received. + */ + private static final class OrcaReportBroker { + + private final List listeners = new ArrayList<>(); + + void addListener(OrcaPerRequestReportListener listener) { + listeners.add(listener); + } + + void onReport(OrcaLoadReport report) { + for (OrcaPerRequestReportListener listener : listeners) { + listener.onLoadReport(report); + } + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java b/xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java new file mode 100644 index 0000000000..f276a1f2ce --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java @@ -0,0 +1,892 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.util.Durations; +import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport; +import io.envoyproxy.udpa.service.orca.v1.OpenRcaServiceGrpc; +import io.envoyproxy.udpa.service.orca.v1.OrcaLoadReportRequest; +import io.grpc.Attributes; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.Context; +import io.grpc.Context.CancellationListener; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.ManagedChannel; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.OrcaOobUtil.OrcaOobReportListener; +import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; +import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; +import io.grpc.xds.OrcaOobUtil.SubchannelImpl; +import java.net.SocketAddress; +import java.text.MessageFormat; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Unit tests for {@link OrcaOobUtil} class. + */ +@RunWith(JUnit4.class) +public class OrcaOobUtilTest { + + private static final int NUM_SUBCHANNELS = 2; + private static final Attributes.Key SUBCHANNEL_ATTR_KEY = + Attributes.Key.create("subchannel-attr-for-test"); + private static final OrcaReportingConfig SHORT_INTERVAL_CONFIG = + OrcaReportingConfig.newBuilder().setReportInterval(5L, TimeUnit.NANOSECONDS).build(); + private static final OrcaReportingConfig MEDIUM_INTERVAL_CONFIG = + OrcaReportingConfig.newBuilder().setReportInterval(543L, TimeUnit.MICROSECONDS).build(); + private static final OrcaReportingConfig LONG_INTERVAL_CONFIG = + OrcaReportingConfig.newBuilder().setReportInterval(1232L, TimeUnit.MILLISECONDS).build(); + @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + @SuppressWarnings({"rawtypes", "unchecked"}) + private final List[] eagLists = new List[NUM_SUBCHANNELS]; + private final SubchannelStateListener[] mockStateListeners = + new SubchannelStateListener[NUM_SUBCHANNELS]; + private final ManagedChannel[] channels = new ManagedChannel[NUM_SUBCHANNELS]; + private final OpenRcaServiceImp[] orcaServiceImps = new OpenRcaServiceImp[NUM_SUBCHANNELS]; + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + private final FakeClock fakeClock = new FakeClock(); + private final Helper origHelper = mock(Helper.class, delegatesTo(new FakeHelper())); + @Mock + private OrcaOobReportListener mockOrcaListener0; + @Mock + private OrcaOobReportListener mockOrcaListener1; + @Mock + private OrcaOobReportListener mockOrcaListener2; + @Mock private BackoffPolicy.Provider backoffPolicyProvider; + @Mock private BackoffPolicy backoffPolicy1; + @Mock private BackoffPolicy backoffPolicy2; + private FakeSubchannel[] subchannels = new FakeSubchannel[NUM_SUBCHANNELS]; + private OrcaReportingHelperWrapper orcaHelperWrapper; + private OrcaReportingHelperWrapper parentHelperWrapper; + private OrcaReportingHelperWrapper childHelperWrapper; + + private static FakeSubchannel unwrap(Subchannel s) { + return (FakeSubchannel) ((SubchannelImpl) s).delegate(); + } + + private static OrcaLoadReportRequest buildOrcaRequestFromConfig( + OrcaReportingConfig config) { + return OrcaLoadReportRequest.newBuilder() + .setReportInterval(Durations.fromNanos(config.getReportIntervalNanos())) + .build(); + } + + private static void assertLog(List logs, String expectedLog) { + assertThat(logs).containsExactly(expectedLog); + logs.clear(); + } + + @After + public void tearDown() { + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + if (subchannels[i] != null) { + subchannels[i].shutdown(); + } + } + } + + @Test + public void orcaReportingConfig_construct() { + int interval = new Random().nextInt(Integer.MAX_VALUE); + OrcaReportingConfig config = + OrcaReportingConfig.newBuilder() + .setReportInterval(interval, TimeUnit.MICROSECONDS) + .build(); + assertThat(config.getReportIntervalNanos()).isEqualTo(TimeUnit.MICROSECONDS.toNanos(interval)); + String str = config.toString(); + assertThat(str).contains("reportIntervalNanos="); + OrcaReportingConfig rebuildedConfig = config.toBuilder().build(); + assertThat(rebuildedConfig.getReportIntervalNanos()) + .isEqualTo(TimeUnit.MICROSECONDS.toNanos(interval)); + } + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + orcaServiceImps[i] = new OpenRcaServiceImp(); + cleanupRule.register( + InProcessServerBuilder.forName("orca-reporting-test-" + i) + .addService(orcaServiceImps[i]) + .directExecutor() + .build() + .start()); + ManagedChannel channel = + cleanupRule.register( + InProcessChannelBuilder.forName("orca-reporting-test-" + i).directExecutor().build()); + channels[i] = channel; + EquivalentAddressGroup eag = + new EquivalentAddressGroup(new FakeSocketAddress("address-" + i)); + List eagList = Arrays.asList(eag); + eagLists[i] = eagList; + mockStateListeners[i] = mock(SubchannelStateListener.class); + } + + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); + when(backoffPolicy1.nextBackoffNanos()).thenReturn(11L, 21L); + when(backoffPolicy2.nextBackoffNanos()).thenReturn(12L, 22L); + + orcaHelperWrapper = + OrcaOobUtil.newOrcaReportingHelperWrapper( + origHelper, + mockOrcaListener0, + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); + parentHelperWrapper = + OrcaOobUtil.newOrcaReportingHelperWrapper( + origHelper, + mockOrcaListener1, + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); + childHelperWrapper = + OrcaOobUtil.newOrcaReportingHelperWrapper( + parentHelperWrapper.asHelper(), + mockOrcaListener2, + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); + } + + @Test + @SuppressWarnings("unchecked") + public void singlePolicyTypicalWorkflow() { + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + verify(origHelper, atLeast(0)).getSynchronizationContext(); + verifyNoMoreInteractions(origHelper); + + // Calling createSubchannel() on orcaHelper correctly passes augmented CreateSubchannelArgs + // to origHelper. + ArgumentCaptor createArgsCaptor = ArgumentCaptor.forClass(null); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + String subchannelAttrValue = "eag attr " + i; + Attributes attrs = + Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); + assertThat(unwrap(createSubchannel(orcaHelperWrapper.asHelper(), i, attrs))) + .isSameInstanceAs(subchannels[i]); + verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); + assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) + .isEqualTo(subchannelAttrValue); + } + + // ORCA reporting does not start until underlying Subchannel is READY. + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + FakeSubchannel subchannel = subchannels[i]; + OpenRcaServiceImp orcaServiceImp = orcaServiceImps[i]; + SubchannelStateListener mockStateListener = mockStateListeners[i]; + InOrder inOrder = inOrder(mockStateListener); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(i, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(CONNECTING)); + + inOrder.verify(mockStateListener) + .onSubchannelState(eq(ConnectivityStateInfo.forNonError(IDLE))); + inOrder.verify(mockStateListener) + .onSubchannelState(eq(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE))); + inOrder.verify(mockStateListener) + .onSubchannelState(eq(ConnectivityStateInfo.forNonError(CONNECTING))); + verifyNoMoreInteractions(mockStateListener); + + assertThat(subchannel.logs).isEmpty(); + assertThat(orcaServiceImp.calls).isEmpty(); + verifyNoMoreInteractions(mockOrcaListener0); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertThat(orcaServiceImp.calls).hasSize(1); + ServerSideCall serverCall = orcaServiceImp.calls.peek(); + assertThat(serverCall.request).isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + // Simulate an ORCA service response. Registered listener will receive an ORCA report for + // each backend. + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); + serverCall.responseObserver.onNext(report); + assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); + verify(mockOrcaListener0, times(i + 1)).onLoadReport(eq(report)); + } + + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + FakeSubchannel subchannel = subchannels[i]; + SubchannelStateListener mockStateListener = mockStateListeners[i]; + + ServerSideCall serverCall = orcaServiceImps[i].calls.peek(); + assertThat(serverCall.cancelled).isFalse(); + verifyNoMoreInteractions(mockStateListener); + + // Shutting down the subchannel will cancel the ORCA reporting RPC. + subchannel.shutdown(); + verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); + assertThat(serverCall.cancelled).isTrue(); + assertThat(subchannel.logs).isEmpty(); + verifyNoMoreInteractions(mockOrcaListener0); + } + + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + assertThat(orcaServiceImps[i].calls).hasSize(1); + } + + verifyZeroInteractions(backoffPolicyProvider); + } + + @Test + public void twoLevelPoliciesTypicalWorkflow() { + setOrcaReportConfig(childHelperWrapper, SHORT_INTERVAL_CONFIG); + setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); + verify(origHelper, atLeast(0)).getSynchronizationContext(); + verifyNoMoreInteractions(origHelper); + + // Calling createSubchannel() on child helper correctly passes augmented CreateSubchannelArgs + // to origHelper. + ArgumentCaptor createArgsCaptor = ArgumentCaptor.forClass(null); + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + String subchannelAttrValue = "eag attr " + i; + Attributes attrs = + Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); + assertThat(unwrap(createSubchannel(childHelperWrapper.asHelper(), i, attrs))) + .isSameInstanceAs(subchannels[i]); + verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); + assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); + assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) + .isEqualTo(subchannelAttrValue); + } + + // ORCA reporting does not start until underlying Subchannel is READY. + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + FakeSubchannel subchannel = subchannels[i]; + OpenRcaServiceImp orcaServiceImp = orcaServiceImps[i]; + SubchannelStateListener mockStateListener = mockStateListeners[i]; + InOrder inOrder = inOrder(mockStateListener); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(i, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(CONNECTING)); + + inOrder + .verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(IDLE))); + inOrder + .verify(mockStateListener) + .onSubchannelState(eq(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE))); + inOrder + .verify(mockStateListener) + .onSubchannelState(eq(ConnectivityStateInfo.forNonError(CONNECTING))); + verifyNoMoreInteractions(mockStateListener); + + assertThat(subchannel.logs).isEmpty(); + assertThat(orcaServiceImp.calls).isEmpty(); + verifyNoMoreInteractions(mockOrcaListener1); + verifyNoMoreInteractions(mockOrcaListener2); + deliverSubchannelState(i, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertThat(orcaServiceImp.calls).hasSize(1); + ServerSideCall serverCall = orcaServiceImp.calls.peek(); + assertThat(serverCall.request).isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + // Simulate an ORCA service response. Registered listener will receive an ORCA report for + // each backend. + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); + serverCall.responseObserver.onNext(report); + assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); + verify(mockOrcaListener1, times(i + 1)).onLoadReport(eq(report)); + verify(mockOrcaListener2, times(i + 1)).onLoadReport(eq(report)); + } + + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + FakeSubchannel subchannel = subchannels[i]; + SubchannelStateListener mockStateListener = mockStateListeners[i]; + + ServerSideCall serverCall = orcaServiceImps[i].calls.peek(); + assertThat(serverCall.cancelled).isFalse(); + verifyNoMoreInteractions(mockStateListener); + + // Shutting down the subchannel will cancel the ORCA reporting RPC. + subchannel.shutdown(); + verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(SHUTDOWN))); + assertThat(serverCall.cancelled).isTrue(); + assertThat(subchannel.logs).isEmpty(); + verifyNoMoreInteractions(mockOrcaListener1, mockOrcaListener2); + } + + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + assertThat(orcaServiceImps[i].calls).hasSize(1); + } + + verifyZeroInteractions(backoffPolicyProvider); + } + + @Test + @SuppressWarnings("unchecked") + public void orcReportingDisabledWhenServiceNotImplemented() { + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + FakeSubchannel subchannel = subchannels[0]; + OpenRcaServiceImp orcaServiceImp = orcaServiceImps[0]; + SubchannelStateListener mockStateListener = mockStateListeners[0]; + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertThat(orcaServiceImp.calls).hasSize(1); + + ServerSideCall serverCall = orcaServiceImp.calls.poll(); + assertThat(serverCall.request).isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + subchannel.logs.clear(); + serverCall.responseObserver.onError(Status.UNIMPLEMENTED.asException()); + assertLog(subchannel.logs, + "ERROR: OpenRcaService disabled: " + Status.UNIMPLEMENTED); + verifyNoMoreInteractions(mockOrcaListener0); + + // Re-connecting on Subchannel will reset the "disabled" flag and restart ORCA reporting. + assertThat(orcaServiceImp.calls).hasSize(0); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + assertThat(orcaServiceImp.calls).hasSize(1); + serverCall = orcaServiceImp.calls.poll(); + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); + serverCall.responseObserver.onNext(report); + assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); + verify(mockOrcaListener0).onLoadReport(eq(report)); + + verifyZeroInteractions(backoffPolicyProvider); + } + + @Test + public void orcaReportingStreamClosedAndRetried() { + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + FakeSubchannel subchannel = subchannels[0]; + OpenRcaServiceImp orcaServiceImp = orcaServiceImps[0]; + SubchannelStateListener mockStateListener = mockStateListeners[0]; + InOrder inOrder = inOrder(mockStateListener, mockOrcaListener0, backoffPolicyProvider, + backoffPolicy1, backoffPolicy2); + + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + inOrder + .verify(mockStateListener).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + // Server closes the ORCA reporting RPC without any response, will start backoff + // sequence 1 (11ns). + orcaServiceImp.calls.poll().responseObserver.onCompleted(); + assertLog(subchannel.logs, + "DEBUG: ORCA reporting stream closed with " + Status.OK + ", backoff in 11" + " ns"); + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + verifyRetryAfterNanos(inOrder, orcaServiceImp, 11); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + // Server closes the ORCA reporting RPC with an error, will continue backoff sequence 1 (21ns). + orcaServiceImp.calls.poll().responseObserver.onError(Status.UNAVAILABLE.asException()); + assertLog(subchannel.logs, + "DEBUG: ORCA reporting stream closed with " + Status.UNAVAILABLE + ", backoff in 21" + + " ns"); + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + verifyRetryAfterNanos(inOrder, orcaServiceImp, 21); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + // Server responds normally. + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); + orcaServiceImp.calls.peek().responseObserver.onNext(report); + assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); + inOrder.verify(mockOrcaListener0).onLoadReport(eq(report)); + + // Server closes the ORCA reporting RPC after a response, will restart immediately. + orcaServiceImp.calls.poll().responseObserver.onCompleted(); + assertThat(subchannel.logs).containsExactly( + "DEBUG: ORCA reporting stream closed with " + Status.OK + ", backoff in 0" + " ns", + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + subchannel.logs.clear(); + + // Backoff policy is set to sequence 2 in previous retry. + // Server closes the ORCA reporting RPC with an error, will start backoff sequence 2 (12ns). + orcaServiceImp.calls.poll().responseObserver.onError(Status.UNAVAILABLE.asException()); + assertLog(subchannel.logs, + "DEBUG: ORCA reporting stream closed with " + Status.UNAVAILABLE + ", backoff in 12" + + " ns"); + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + verifyRetryAfterNanos(inOrder, orcaServiceImp, 12); + assertLog(subchannel.logs, + "DEBUG: Starting ORCA reporting for " + subchannel.getAllAddresses()); + + verifyNoMoreInteractions(mockStateListener, mockOrcaListener0, backoffPolicyProvider, + backoffPolicy1, backoffPolicy2); + } + + @Test + public void reportingNotStartedUntilConfigured() { + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]) + .onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).isEmpty(); + assertThat(subchannels[0].logs).isEmpty(); + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + } + + @Test + public void updateReportingIntervalBeforeCreatingSubchannel() { + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.poll().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + } + + @Test + public void updateReportingIntervalBeforeSubchannelReady() { + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.poll().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + } + + @Test + public void updateReportingIntervalWhenRpcActive() { + // Sets report interval before creating a Subchannel, reporting starts right after suchannel + // state becomes READY. + setOrcaReportConfig(orcaHelperWrapper, MEDIUM_INTERVAL_CONFIG); + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(MEDIUM_INTERVAL_CONFIG)); + + // Make reporting less frequent. + setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(LONG_INTERVAL_CONFIG)); + + // Configuring with the same report interval again does not restart ORCA RPC. + setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls.peek().cancelled).isFalse(); + assertThat(subchannels[0].logs).isEmpty(); + + // Make reporting more frequent. + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.poll().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + } + + @Test + public void updateReportingIntervalWhenRpcPendingRetry() { + createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + + // Server closes the RPC without response, client will retry with backoff. + assertThat(fakeClock.getPendingTasks()).isEmpty(); + orcaServiceImps[0].calls.poll().responseObserver.onCompleted(); + assertLog(subchannels[0].logs, + "DEBUG: ORCA reporting stream closed with " + Status.OK + ", backoff in 11" + + " ns"); + assertThat(fakeClock.getPendingTasks()).hasSize(1); + assertThat(orcaServiceImps[0].calls).isEmpty(); + + // Make reporting less frequent. + setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + // Retry task will be canceled and restarts new RPC immediately. + assertThat(fakeClock.getPendingTasks()).isEmpty(); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(LONG_INTERVAL_CONFIG)); + } + + @Test + public void policiesReceiveSameReportIndependently() { + createSubchannel(childHelperWrapper.asHelper(), 0, Attributes.EMPTY); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + + // No helper sets ORCA reporting interval, so load reporting is not started. + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertThat(orcaServiceImps[0].calls).isEmpty(); + assertThat(subchannels[0].logs).isEmpty(); + + // Parent helper requests ORCA reports with a certain interval, load reporting starts. + setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); + assertThat(orcaServiceImps[0].calls).hasSize(1); + orcaServiceImps[0].calls.peek().responseObserver.onNext(report); + assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report); + // Only parent helper's listener receives the report. + ArgumentCaptor parentReportCaptor = ArgumentCaptor.forClass(null); + verify(mockOrcaListener1).onLoadReport(parentReportCaptor.capture()); + assertThat(parentReportCaptor.getValue()).isEqualTo(report); + verifyNoMoreInteractions(mockOrcaListener2); + + // Now child helper also wants to receive reports. + setOrcaReportConfig(childHelperWrapper, SHORT_INTERVAL_CONFIG); + orcaServiceImps[0].calls.peek().responseObserver.onNext(report); + assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report); + // Both helper receives the same report instance. + ArgumentCaptor childReportCaptor = ArgumentCaptor.forClass(null); + verify(mockOrcaListener1, times(2)) + .onLoadReport(parentReportCaptor.capture()); + verify(mockOrcaListener2) + .onLoadReport(childReportCaptor.capture()); + assertThat(childReportCaptor.getValue()).isSameInstanceAs(parentReportCaptor.getValue()); + } + + @Test + public void reportWithMostFrequentIntervalRequested() { + setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); + setOrcaReportConfig(childHelperWrapper, LONG_INTERVAL_CONFIG); + createSubchannel(childHelperWrapper.asHelper(), 0, Attributes.EMPTY); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + + // The real report interval to be requested is the minimum of intervals requested by helpers. + assertThat(Durations.toNanos(orcaServiceImps[0].calls.peek().request.getReportInterval())) + .isEqualTo(SHORT_INTERVAL_CONFIG.getReportIntervalNanos()); + + // Child helper wants reporting to be more frequent than its current setting while it is still + // less frequent than parent helper. Nothing should happen on existing RPC. + setOrcaReportConfig(childHelperWrapper, MEDIUM_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls.peek().cancelled).isFalse(); + assertThat(subchannels[0].logs).isEmpty(); + + // Parent helper wants reporting to be less frequent. + setOrcaReportConfig(parentHelperWrapper, MEDIUM_INTERVAL_CONFIG); + assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + // ORCA reporting RPC restarts and the the real report interval is adjusted. + assertThat(Durations.toNanos(orcaServiceImps[0].calls.poll().request.getReportInterval())) + .isEqualTo(MEDIUM_INTERVAL_CONFIG.getReportIntervalNanos()); + } + + private void verifyRetryAfterNanos(InOrder inOrder, OpenRcaServiceImp orcaServiceImp, + long nanos) { + assertThat(fakeClock.getPendingTasks()).hasSize(1); + assertThat(orcaServiceImp.calls).isEmpty(); + fakeClock.forwardNanos(nanos - 1); + assertThat(orcaServiceImp.calls).isEmpty(); + inOrder.verifyNoMoreInteractions(); + fakeClock.forwardNanos(1); + assertThat(orcaServiceImp.calls).hasSize(1); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + } + + private void deliverSubchannelState(final int index, final ConnectivityStateInfo newState) { + syncContext.execute( + new Runnable() { + @Override + public void run() { + subchannels[index].stateListener.onSubchannelState(newState); + } + }); + } + + private Subchannel createSubchannel(final Helper helper, final int index, + final Attributes attrs) { + final AtomicReference newSubchannel = new AtomicReference<>(); + syncContext.execute( + new Runnable() { + @Override + public void run() { + Subchannel s = + helper.createSubchannel( + CreateSubchannelArgs.newBuilder() + .setAddresses(eagLists[index]) + .setAttributes(attrs) + .build()); + s.start(mockStateListeners[index]); + newSubchannel.set(s); + } + }); + return newSubchannel.get(); + } + + private void setOrcaReportConfig( + final OrcaReportingHelperWrapper helperWrapper, final OrcaReportingConfig config) { + syncContext.execute(new Runnable() { + @Override + public void run() { + helperWrapper.setReportingConfig(config); + } + }); + } + + private static final class OpenRcaServiceImp extends OpenRcaServiceGrpc.OpenRcaServiceImplBase { + final Queue calls = new ArrayDeque<>(); + + @Override + public void streamCoreMetrics( + OrcaLoadReportRequest request, StreamObserver responseObserver) { + final ServerSideCall call = new ServerSideCall(request, responseObserver); + Context.current() + .addListener( + new CancellationListener() { + @Override + public void cancelled(Context ctx) { + call.cancelled = true; + } + }, + MoreExecutors.directExecutor()); + calls.add(call); + } + } + + private static final class ServerSideCall { + final OrcaLoadReportRequest request; + final StreamObserver responseObserver; + boolean cancelled; + + ServerSideCall(OrcaLoadReportRequest request, StreamObserver responseObserver) { + this.request = request; + this.responseObserver = responseObserver; + } + } + + private static final class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return name; + } + } + + private final class FakeSubchannel extends Subchannel { + final List eagList; + final Attributes attrs; + final Channel channel; + final List logs = new ArrayList<>(); + final int index; + SubchannelStateListener stateListener; + private final ChannelLogger logger = + new ChannelLogger() { + @Override + public void log(ChannelLogLevel level, String msg) { + logs.add(level + ": " + msg); + } + + @Override + public void log(ChannelLogLevel level, String template, Object... args) { + log(level, MessageFormat.format(template, args)); + } + }; + + FakeSubchannel(int index, CreateSubchannelArgs args, Channel channel) { + this.index = index; + this.eagList = args.getAddresses(); + this.attrs = args.getAttributes(); + this.channel = checkNotNull(channel); + } + + @Override + public void start(SubchannelStateListener listener) { + checkState(this.stateListener == null); + this.stateListener = listener; + } + + @Override + public void shutdown() { + deliverSubchannelState(index, ConnectivityStateInfo.forNonError(SHUTDOWN)); + } + + @Override + public void requestConnection() { + throw new AssertionError("Should not be called"); + } + + @Override + public List getAllAddresses() { + return eagList; + } + + @Override + public Attributes getAttributes() { + return attrs; + } + + @Override + public Channel asChannel() { + return channel; + } + + @Override + public ChannelLogger getChannelLogger() { + return logger; + } + } + + private final class FakeHelper extends Helper { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + int index = -1; + for (int i = 0; i < NUM_SUBCHANNELS; i++) { + if (eagLists[i].equals(args.getAddresses())) { + index = i; + break; + } + } + checkState(index >= 0, "addrs " + args.getAddresses() + " not found"); + FakeSubchannel subchannel = new FakeSubchannel(index, args, channels[index]); + checkState(subchannels[index] == null, "subchannels[" + index + "] already created"); + subchannels[index] = subchannel; + return subchannel; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + throw new AssertionError("Should not be called"); + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return fakeClock.getScheduledExecutorService(); + } + + @Deprecated + @Override + public NameResolver.Factory getNameResolverFactory() { + throw new AssertionError("Should not be called"); + } + + @Override + public String getAuthority() { + throw new AssertionError("Should not be called"); + } + + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + throw new AssertionError("Should not be called"); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java b/xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java new file mode 100644 index 0000000000..6bb3fa6a5a --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import io.grpc.xds.OrcaPerRequestUtil.OrcaReportingTracerFactory; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Unit tests for {@link OrcaPerRequestUtil} class. + */ +@RunWith(JUnit4.class) +public class OrcaPerRequestUtilTest { + + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder().build(); + + @Mock + private OrcaPerRequestReportListener orcaListener1; + @Mock + private OrcaPerRequestReportListener orcaListener2; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + /** + * Tests a single load balance policy's listener receive per-request ORCA reports upon call + * trailer arrives. + */ + @Test + public void singlePolicyTypicalWorkflow() { + // Use a mocked noop stream tracer factory as the original stream tracer factory. + ClientStreamTracer.Factory fakeDelegateFactory = mock(ClientStreamTracer.Factory.class); + ClientStreamTracer fakeTracer = mock(ClientStreamTracer.class); + doNothing().when(fakeTracer).inboundTrailers(any(Metadata.class)); + when(fakeDelegateFactory.newClientStreamTracer( + any(ClientStreamTracer.StreamInfo.class), any(Metadata.class))) + .thenReturn(fakeTracer); + + // The OrcaReportingTracerFactory will augment the StreamInfo passed to its + // newClientStreamTracer method. The augmented StreamInfo's CallOptions will contain + // a OrcaReportBroker, in which has the registered listener. + ClientStreamTracer.Factory factory = + OrcaPerRequestUtil.getInstance() + .newOrcaClientStreamTracerFactory(fakeDelegateFactory, orcaListener1); + ClientStreamTracer tracer = factory.newClientStreamTracer(STREAM_INFO, new Metadata()); + ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); + verify(fakeDelegateFactory) + .newClientStreamTracer(streamInfoCaptor.capture(), any(Metadata.class)); + ClientStreamTracer.StreamInfo capturedInfo = streamInfoCaptor.getValue(); + assertThat(capturedInfo).isNotEqualTo(STREAM_INFO); + + // When the trailer does not contain ORCA report, listener callback will not be invoked. + Metadata trailer = new Metadata(); + tracer.inboundTrailers(trailer); + verifyNoMoreInteractions(orcaListener1); + + // When the trailer contains an ORCA report, listener callback will be invoked. + trailer.put( + OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.getDefaultInstance()); + tracer.inboundTrailers(trailer); + ArgumentCaptor reportCaptor = ArgumentCaptor.forClass(null); + verify(orcaListener1).onLoadReport(reportCaptor.capture()); + assertThat(reportCaptor.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance()); + } + + /** + * Tests parent-child load balance policies' listeners both receive per-request ORCA reports upon + * call trailer arrives and ORCA report deserialization happens only once. + */ + @Test + public void twoLevelPoliciesTypicalWorkflow() { + ClientStreamTracer.Factory parentFactory = + mock(ClientStreamTracer.Factory.class, + delegatesTo( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(orcaListener1))); + + ClientStreamTracer.Factory childFactory = + OrcaPerRequestUtil.getInstance() + .newOrcaClientStreamTracerFactory(parentFactory, orcaListener2); + // Child factory will augment the StreamInfo and pass it to the parent factory. + ClientStreamTracer childTracer = + childFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); + verify(parentFactory).newClientStreamTracer(streamInfoCaptor.capture(), any(Metadata.class)); + ClientStreamTracer.StreamInfo parentStreamInfo = streamInfoCaptor.getValue(); + assertThat(parentStreamInfo).isNotEqualTo(STREAM_INFO); + + // When the trailer does not contain ORCA report, no listener callback will be invoked. + Metadata trailer = new Metadata(); + childTracer.inboundTrailers(trailer); + verifyNoMoreInteractions(orcaListener1); + verifyNoMoreInteractions(orcaListener2); + + // When the trailer contains an ORCA report, callbacks for both listeners will be invoked. + // Both listener will receive the same ORCA report instance, which means deserialization + // happens only once. + trailer.put( + OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.getDefaultInstance()); + childTracer.inboundTrailers(trailer); + ArgumentCaptor parentReportCap = ArgumentCaptor.forClass(null); + ArgumentCaptor childReportCap = ArgumentCaptor.forClass(null); + verify(orcaListener1).onLoadReport(parentReportCap.capture()); + verify(orcaListener2).onLoadReport(childReportCap.capture()); + assertThat(parentReportCap.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance()); + assertThat(childReportCap.getValue()).isSameInstanceAs(parentReportCap.getValue()); + } + + /** + * Tests the case when parent policy creates its own {@link ClientStreamTracer.Factory}, ORCA + * reports are only forwarded to the parent's listener. + */ + @Test + public void onlyParentPolicyReceivesReportsIfCreatesOwnTracer() { + ClientStreamTracer.Factory parentFactory = + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(orcaListener1); + ClientStreamTracer.Factory childFactory = + mock(ClientStreamTracer.Factory.class, + delegatesTo(OrcaPerRequestUtil.getInstance() + .newOrcaClientStreamTracerFactory(parentFactory, orcaListener2))); + ClientStreamTracer parentTracer = + parentFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata trailer = new Metadata(); + trailer.put( + OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.getDefaultInstance()); + parentTracer.inboundTrailers(trailer); + verify(orcaListener1).onLoadReport(eq(OrcaLoadReport.getDefaultInstance())); + verifyZeroInteractions(childFactory); + verifyZeroInteractions(orcaListener2); + } +}