core: pass transport attributes to ClientStreamTracer.Factory.newClientStreamTracer() (#5380)

This will be a new override.  The old override is now deprecated.

In order to pass new information without adding new overrides, I shoved most information
into an object called StreamInfo.  The Metadata is left out to draw attention because
it's mutable.

Motivation: this is needed for correctly supporting pick_first in GRPCLB.  GRPCLB needs to
add a token to the headers, and the token varies among servers.  With round_robin, GRPCLB
create a Subchannel for each server, thus can attach the token when the Subchannel is picked.
To implement pick_first, all server addresses will be put in a single Subchannel, we will
need to add the header in newClientStreamTracer(), by looking up the server address from
the transport attributes and deciding which token to add.
This commit is contained in:
Kun Zhang 2019-02-21 11:13:51 -08:00 committed by GitHub
parent 6c32eaf9d4
commit 83b92cfc9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 135 additions and 37 deletions

View File

@ -16,6 +16,7 @@
package io.grpc.internal;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -49,7 +50,7 @@ public class StatsTraceContextBenchmark {
@BenchmarkMode(Mode.SampleTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public StatsTraceContext newClientContext() {
return StatsTraceContext.newClientContext(CallOptions.DEFAULT, emptyMetadata);
return StatsTraceContext.newClientContext(CallOptions.DEFAULT, Attributes.EMPTY, emptyMetadata);
}
/**

View File

@ -16,6 +16,7 @@
package io.grpc;
import io.grpc.Grpc;
import javax.annotation.concurrent.ThreadSafe;
/**
@ -57,9 +58,44 @@ public abstract class ClientStreamTracer extends StreamTracer {
* @param headers the mutable headers of the stream. It can be safely mutated within this
* method. It should not be saved because it is not safe for read or write after the
* method returns.
*
* @deprecated use {@link #newClientStreamTracer(StreamInfo, Metadata)} instead
*/
@Deprecated
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
throw new UnsupportedOperationException("Not implemented");
}
/**
* Creates a {@link ClientStreamTracer} for a new client stream. This is called inside the
* transport when it's creating the stream.
*
* @param info information about the stream
* @param headers the mutable headers of the stream. It can be safely mutated within this
* method. Changes made to it will be sent by the stream. It should not be saved
* because it is not safe for read or write after the method returns.
*
* @since 1.20.0
*/
@SuppressWarnings("deprecation")
public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
return newClientStreamTracer(info.getCallOptions(), headers);
}
}
/**
* Information about a stream.
*/
public abstract static class StreamInfo {
/**
* Returns the attributes of the transport that this stream was created on.
*/
@Grpc.TransportAttr
public abstract Attributes getTransportAttrs();
/**
* Returns the effective CallOptions of the call.
*/
public abstract CallOptions getCallOptions();
}
}

View File

@ -167,7 +167,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans
final MethodDescriptor<?, ?> method, final Metadata headers, final CallOptions callOptions) {
if (shutdownStatus != null) {
return failedClientStream(
StatsTraceContext.newClientContext(callOptions, headers), shutdownStatus);
StatsTraceContext.newClientContext(callOptions, attributes, headers), shutdownStatus);
}
headers.put(GrpcUtil.USER_AGENT_KEY, userAgent);
@ -186,7 +186,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans
serverMaxInboundMetadataSize,
metadataSize));
return failedClientStream(
StatsTraceContext.newClientContext(callOptions, headers), status);
StatsTraceContext.newClientContext(callOptions, attributes, headers), status);
}
}
@ -625,7 +625,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans
InProcessClientStream(CallOptions callOptions, Metadata headers) {
this.callOptions = callOptions;
statsTraceCtx = StatsTraceContext.newClientContext(callOptions, headers);
statsTraceCtx = StatsTraceContext.newClientContext(callOptions, attributes, headers);
}
private synchronized void setListener(ServerStreamListener listener) {

View File

@ -367,7 +367,8 @@ public final class CensusStatsModule {
}
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
ClientTracer tracer = new ClientTracer(module, startCtx);
// TODO(zhangkun83): Once retry or hedging is implemented, a ClientCall may start more than
// one streams. We will need to update this file to support them.

View File

@ -242,7 +242,8 @@ final class CensusTracingModule {
}
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
if (span != BlankSpan.INSTANCE) {
headers.discardAll(tracingHeader);
headers.put(tracingHeader, span.getContext());

View File

@ -22,7 +22,6 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.Compressor;
import io.grpc.Deadline;
@ -194,7 +193,8 @@ abstract class RetriableStream<ReqT> implements ClientStream {
final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub);
ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() {
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
return bufferSizeTracer;
}
};

View File

@ -19,6 +19,7 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.Context;
@ -46,16 +47,28 @@ public final class StatsTraceContext {
/**
* Factory method for the client-side.
*/
public static StatsTraceContext newClientContext(CallOptions callOptions, Metadata headers) {
public static StatsTraceContext newClientContext(
final CallOptions callOptions, final Attributes transportAttrs, Metadata headers) {
List<ClientStreamTracer.Factory> factories = callOptions.getStreamTracerFactories();
if (factories.isEmpty()) {
return NOOP;
}
ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return transportAttrs;
}
@Override
public CallOptions getCallOptions() {
return callOptions;
}
};
// This array will be iterated multiple times per RPC. Use primitive array instead of Collection
// so that for-each doesn't create an Iterator every time.
StreamTracer[] tracers = new StreamTracer[factories.size()];
for (int i = 0; i < tracers.length; i++) {
tracers[i] = factories.get(i).newClientStreamTracer(callOptions, headers);
tracers[i] = factories.get(i).newClientStreamTracer(info, headers);
}
return new StatsTraceContext(tracers);
}

View File

@ -270,7 +270,8 @@ public class CallOptionsTest {
}
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
return new ClientStreamTracer() {};
}

View File

@ -105,6 +105,18 @@ public class CensusModulesTest {
CallOptions.Key.createWithDefault("option1", "default");
private static final CallOptions CALL_OPTIONS =
CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue");
private static final ClientStreamTracer.StreamInfo STREAM_INFO =
new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return Attributes.EMPTY;
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
private static class StringInputStream extends InputStream {
final String string;
@ -370,7 +382,7 @@ public class CensusModulesTest {
localCensusStats.newClientCallTracer(
tagger.empty(), method.getFullMethodName());
Metadata headers = new Metadata();
ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
ClientStreamTracer tracer = callTracer.newClientStreamTracer(STREAM_INFO, headers);
if (recordStarts) {
StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord();
@ -494,8 +506,7 @@ public class CensusModulesTest {
CensusTracingModule.ClientCallTracer callTracer =
censusTracing.newClientCallTracer(null, method);
Metadata headers = new Metadata();
ClientStreamTracer clientStreamTracer =
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers);
verify(tracer).spanBuilderWithExplicitParent(
eq("Sent.package1.service2.method3"), isNull(Span.class));
verify(spyClientSpan, never()).end(any(EndSpanOptions.class));
@ -655,7 +666,7 @@ public class CensusModulesTest {
CensusStatsModule.ClientCallTracer callTracer =
census.newClientCallTracer(clientCtx, method.getFullMethodName());
// This propagates clientCtx to headers if propagates==true
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
if (recordStats) {
// Client upstart record
StatsTestUtils.MetricsRecord clientRecord = statsRecorder.pollRecord();
@ -744,7 +755,7 @@ public class CensusModulesTest {
CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName());
Metadata headers = new Metadata();
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
assertFalse(headers.containsKey(censusStats.statsHeader));
// Clear recorded stats to satisfy the assertions in wrapUp()
statsRecorder.rolloverRecords();
@ -775,7 +786,7 @@ public class CensusModulesTest {
CensusTracingModule.ClientCallTracer callTracer =
censusTracing.newClientCallTracer(fakeClientParentSpan, method);
Metadata headers = new Metadata();
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
verify(mockTracingPropagationHandler).toByteArray(same(fakeClientSpanContext));
verifyNoMoreInteractions(mockTracingPropagationHandler);
@ -803,7 +814,7 @@ public class CensusModulesTest {
censusTracing.newClientCallTracer(fakeClientParentSpan, method);
Metadata headers = new Metadata();
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
assertThat(headers.keys()).isNotEmpty();
}
@ -817,7 +828,7 @@ public class CensusModulesTest {
CensusTracingModule.ClientCallTracer callTracer =
censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method);
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
assertThat(headers.keys()).isEmpty();
}
@ -834,7 +845,7 @@ public class CensusModulesTest {
CensusTracingModule.ClientCallTracer callTracer =
censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method);
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
callTracer.newClientStreamTracer(STREAM_INFO, headers);
assertThat(headers.keys()).containsExactlyElementsIn(originalHeaderKeys);
}

View File

@ -41,6 +41,7 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.Codec;
@ -90,6 +91,18 @@ public class RetriableStreamTest {
private static final long MAX_BACKOFF_IN_SECONDS = 700;
private static final double BACKOFF_MULTIPLIER = 2D;
private static final double FAKE_RANDOM = .5D;
private static final ClientStreamTracer.StreamInfo STREAM_INFO =
new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return Attributes.EMPTY;
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
static {
RetriableStream.setRandom(
@ -168,7 +181,7 @@ public class RetriableStreamTest {
@Override
ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) {
bufferSizeTracer =
tracerFactory.newClientStreamTracer(CallOptions.DEFAULT, new Metadata());
tracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata());
int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null
? 0 : Integer.valueOf(metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS));
return retriableStreamRecorder.newSubstream(actualPreviousRpcAttemptsInHeader);

View File

@ -117,7 +117,7 @@ class CronetClientTransport implements ConnectionClientTransport {
final String url = "https://" + authority + defaultPath;
final StatsTraceContext statsTraceCtx =
StatsTraceContext.newClientContext(callOptions, headers);
StatsTraceContext.newClientContext(callOptions, attrs, headers);
class StartCallback implements Runnable {
final CronetClientStream clientStream = new CronetClientStream(
url, userAgent, executor, headers, CronetClientTransport.this, this, lock, maxMessageSize,

View File

@ -19,7 +19,6 @@ package io.grpc.grpclb;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.protobuf.util.Timestamps;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.Metadata;
import io.grpc.Status;
@ -75,7 +74,8 @@ final class GrpclbClientLoadRecorder extends ClientStreamTracer.Factory {
}
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
callsStartedUpdater.getAndIncrement(this);
return new StreamTracer();
}

View File

@ -172,6 +172,19 @@ public class GrpclbLoadBalancerTest {
throw new AssertionError(e);
}
});
private static final ClientStreamTracer.StreamInfo STREAM_INFO =
new ClientStreamTracer.StreamInfo() {
@Override
public Attributes getTransportAttrs() {
return Attributes.EMPTY;
}
@Override
public CallOptions getCallOptions() {
return CallOptions.DEFAULT;
}
};
private io.grpc.Server fakeLbServer;
@Captor
private ArgumentCaptor<SubchannelPicker> pickerCaptor;
@ -467,7 +480,7 @@ public class GrpclbLoadBalancerTest {
ClientStats.newBuilder().build());
ClientStreamTracer tracer1 =
pick1.getStreamTracerFactory().newClientStreamTracer(CallOptions.DEFAULT, new Metadata());
pick1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata());
PickResult pick2 = picker.pickSubchannel(args);
assertNull(pick2.getSubchannel());
@ -490,7 +503,7 @@ public class GrpclbLoadBalancerTest {
assertSame(subchannel2, pick3.getSubchannel());
assertSame(getLoadRecorder(), pick3.getStreamTracerFactory());
ClientStreamTracer tracer3 =
pick3.getStreamTracerFactory().newClientStreamTracer(CallOptions.DEFAULT, new Metadata());
pick3.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata());
// pick3 has sent out headers
tracer3.outboundHeaders();
@ -527,7 +540,7 @@ public class GrpclbLoadBalancerTest {
assertSame(subchannel1, pick1.getSubchannel());
assertSame(getLoadRecorder(), pick5.getStreamTracerFactory());
ClientStreamTracer tracer5 =
pick5.getStreamTracerFactory().newClientStreamTracer(CallOptions.DEFAULT, new Metadata());
pick5.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata());
// pick3 ended without receiving response headers
tracer3.streamClosed(Status.DEADLINE_EXCEEDED);
@ -602,7 +615,7 @@ public class GrpclbLoadBalancerTest {
PickResult pick1p = picker.pickSubchannel(args);
assertSame(subchannel1, pick1p.getSubchannel());
assertSame(getLoadRecorder(), pick1p.getStreamTracerFactory());
pick1p.getStreamTracerFactory().newClientStreamTracer(CallOptions.DEFAULT, new Metadata());
pick1p.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata());
// The pick from the new stream will be included in the report
assertNextReport(

View File

@ -260,7 +260,8 @@ public abstract class AbstractInteropTest {
private final ClientStreamTracer.Factory clientStreamTracerFactory =
new ClientStreamTracer.Factory() {
@Override
public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) {
public ClientStreamTracer newClientStreamTracer(
ClientStreamTracer.StreamInfo info, Metadata headers) {
TestClientStreamTracer tracer = new TestClientStreamTracer();
clientStreamTracers.add(tracer);
return tracer;

View File

@ -161,7 +161,8 @@ class NettyClientTransport implements ConnectionClientTransport {
if (channel == null) {
return new FailingClientStream(statusExplainingWhyTheChannelIsNull);
}
StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(callOptions, headers);
StatsTraceContext statsTraceCtx =
StatsTraceContext.newClientContext(callOptions, getAttributes(), headers);
return new NettyClientStream(
new NettyClientStream.TransportState(
handler,

View File

@ -372,7 +372,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
final Metadata headers, CallOptions callOptions) {
Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers");
StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(callOptions, headers);
StatsTraceContext statsTraceCtx =
StatsTraceContext.newClientContext(callOptions, attributes, headers);
// FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope
synchronized (lock) { // to make @GuardedBy linter happy
return new OkHttpClientStream(

View File

@ -187,7 +187,7 @@ public abstract class AbstractTransportTest {
server = Iterables.getOnlyElement(newServer(Arrays.asList(serverStreamTracerFactory)));
OngoingStubbing<ClientStreamTracer> clientStubbing =
when(clientStreamTracerFactory
.newClientStreamTracer(any(CallOptions.class), any(Metadata.class)))
.newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
.thenReturn(clientStreamTracer1)
.thenReturn(clientStreamTracer2);
OngoingStubbing<ServerStreamTracer> serverStubbing =
@ -581,7 +581,7 @@ public abstract class AbstractTransportTest {
// Stream prevents termination
ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions);
inOrder.verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class));
any(ClientStreamTracer.StreamInfo.class), any(Metadata.class));
ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase();
stream.start(clientStreamListener);
client.shutdown(Status.UNAVAILABLE);
@ -589,7 +589,7 @@ public abstract class AbstractTransportTest {
ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions);
inOrder.verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class));
any(ClientStreamTracer.StreamInfo.class), any(Metadata.class));
ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase();
stream2.start(clientStreamListener2);
Status clientStreamStatus2 =
@ -632,7 +632,7 @@ public abstract class AbstractTransportTest {
assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS));
verify(mockClientTransportListener, never()).transportInUse(anyBoolean());
verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class));
any(ClientStreamTracer.StreamInfo.class), any(Metadata.class));
assertNull(clientStreamTracer1.getInboundTrailers());
assertSame(shutdownReason, clientStreamTracer1.getStatus());
// Assert no interactions
@ -753,8 +753,13 @@ public abstract class AbstractTransportTest {
clientHeadersCopy.merge(clientHeaders);
ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions);
ArgumentCaptor<ClientStreamTracer.StreamInfo> streamInfoCaptor = ArgumentCaptor.forClass(null);
clientInOrder.verify(clientStreamTracerFactory).newClientStreamTracer(
same(callOptions), same(clientHeaders));
streamInfoCaptor.capture(), same(clientHeaders));
ClientStreamTracer.StreamInfo streamInfo = streamInfoCaptor.getValue();
assertThat(streamInfo.getTransportAttrs()).isSameAs(
((ConnectionClientTransport) client).getAttributes());
assertThat(streamInfo.getCallOptions()).isSameAs(callOptions);
ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase();
clientStream.start(clientStreamListener);
@ -1278,7 +1283,7 @@ public abstract class AbstractTransportTest {
assertNull(clientStreamStatus.getCause());
verify(clientStreamTracerFactory).newClientStreamTracer(
any(CallOptions.class), any(Metadata.class));
any(ClientStreamTracer.StreamInfo.class), any(Metadata.class));
assertTrue(clientStreamTracer1.getOutboundHeaders());
assertNull(clientStreamTracer1.getInboundTrailers());
assertSame(clientStreamStatus, clientStreamTracer1.getStatus());