core: refactor flags in CensusStatsModule. (#5095)

There are currently three boolean flags, and there will be one more
soon.  Put them all in the top-level class instead of passing them as
arguments on lower levels.
This commit is contained in:
Kun Zhang 2018-11-28 14:20:40 -08:00 committed by GitHub
parent 81121fd8e4
commit 2961857451
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 60 additions and 90 deletions

View File

@ -432,12 +432,12 @@ public abstract class AbstractManagedChannelImplBuilder
temporarilyDisableRetry = true; temporarilyDisableRetry = true;
CensusStatsModule censusStats = this.censusStatsOverride; CensusStatsModule censusStats = this.censusStatsOverride;
if (censusStats == null) { if (censusStats == null) {
censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true); censusStats = new CensusStatsModule(
GrpcUtil.STOPWATCH_SUPPLIER, true, recordStartedRpcs, recordFinishedRpcs);
} }
// First interceptor runs last (see ClientInterceptors.intercept()), so that no // First interceptor runs last (see ClientInterceptors.intercept()), so that no
// other interceptor can override the tracer factory we set in CallOptions. // other interceptor can override the tracer factory we set in CallOptions.
effectiveInterceptors.add( effectiveInterceptors.add(0, censusStats.getClientInterceptor());
0, censusStats.getClientInterceptor(recordStartedRpcs, recordFinishedRpcs));
} }
if (tracingEnabled) { if (tracingEnabled) {
temporarilyDisableRetry = true; temporarilyDisableRetry = true;

View File

@ -265,10 +265,10 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
if (statsEnabled) { if (statsEnabled) {
CensusStatsModule censusStats = this.censusStatsOverride; CensusStatsModule censusStats = this.censusStatsOverride;
if (censusStats == null) { if (censusStats == null) {
censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true); censusStats = new CensusStatsModule(
GrpcUtil.STOPWATCH_SUPPLIER, true, recordStartedRpcs, recordFinishedRpcs);
} }
tracerFactories.add( tracerFactories.add(censusStats.getServerTracerFactory());
censusStats.getServerTracerFactory(recordStartedRpcs, recordFinishedRpcs));
} }
if (tracingEnabled) { if (tracingEnabled) {
CensusTracingModule censusTracing = CensusTracingModule censusTracing =

View File

@ -75,17 +75,20 @@ public final class CensusStatsModule {
@VisibleForTesting @VisibleForTesting
final Metadata.Key<TagContext> statsHeader; final Metadata.Key<TagContext> statsHeader;
private final boolean propagateTags; private final boolean propagateTags;
private final boolean recordStartedRpcs;
private final boolean recordFinishedRpcs;
/** /**
* Creates a {@link CensusStatsModule} with the default OpenCensus implementation. * Creates a {@link CensusStatsModule} with the default OpenCensus implementation.
*/ */
CensusStatsModule(Supplier<Stopwatch> stopwatchSupplier, boolean propagateTags) { CensusStatsModule(Supplier<Stopwatch> stopwatchSupplier,
boolean propagateTags, boolean recordStartedRpcs, boolean recordFinishedRpcs) {
this( this(
Tags.getTagger(), Tags.getTagger(),
Tags.getTagPropagationComponent().getBinarySerializer(), Tags.getTagPropagationComponent().getBinarySerializer(),
Stats.getStatsRecorder(), Stats.getStatsRecorder(),
stopwatchSupplier, stopwatchSupplier,
propagateTags); propagateTags, recordStartedRpcs, recordFinishedRpcs);
} }
/** /**
@ -95,12 +98,14 @@ public final class CensusStatsModule {
final Tagger tagger, final Tagger tagger,
final TagContextBinarySerializer tagCtxSerializer, final TagContextBinarySerializer tagCtxSerializer,
StatsRecorder statsRecorder, Supplier<Stopwatch> stopwatchSupplier, StatsRecorder statsRecorder, Supplier<Stopwatch> stopwatchSupplier,
boolean propagateTags) { boolean propagateTags, boolean recordStartedRpcs, boolean recordFinishedRpcs) {
this.tagger = checkNotNull(tagger, "tagger"); this.tagger = checkNotNull(tagger, "tagger");
this.statsRecorder = checkNotNull(statsRecorder, "statsRecorder"); this.statsRecorder = checkNotNull(statsRecorder, "statsRecorder");
checkNotNull(tagCtxSerializer, "tagCtxSerializer"); checkNotNull(tagCtxSerializer, "tagCtxSerializer");
this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
this.propagateTags = propagateTags; this.propagateTags = propagateTags;
this.recordStartedRpcs = recordStartedRpcs;
this.recordFinishedRpcs = recordFinishedRpcs;
this.statsHeader = this.statsHeader =
Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller<TagContext>() { Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller<TagContext>() {
@Override @Override
@ -131,25 +136,22 @@ public final class CensusStatsModule {
*/ */
@VisibleForTesting @VisibleForTesting
ClientCallTracer newClientCallTracer( ClientCallTracer newClientCallTracer(
TagContext parentCtx, String fullMethodName, TagContext parentCtx, String fullMethodName) {
boolean recordStartedRpcs, boolean recordFinishedRpcs) { return new ClientCallTracer(this, parentCtx, fullMethodName);
return new ClientCallTracer(
this, parentCtx, fullMethodName, recordStartedRpcs, recordFinishedRpcs);
} }
/** /**
* Returns the server tracer factory. * Returns the server tracer factory.
*/ */
ServerStreamTracer.Factory getServerTracerFactory( ServerStreamTracer.Factory getServerTracerFactory() {
boolean recordStartedRpcs, boolean recordFinishedRpcs) { return new ServerTracerFactory();
return new ServerTracerFactory(recordStartedRpcs, recordFinishedRpcs);
} }
/** /**
* Returns the client interceptor that facilitates Census-based stats reporting. * Returns the client interceptor that facilitates Census-based stats reporting.
*/ */
ClientInterceptor getClientInterceptor(boolean recordStartedRpcs, boolean recordFinishedRpcs) { ClientInterceptor getClientInterceptor() {
return new StatsClientInterceptor(recordStartedRpcs, recordFinishedRpcs); return new StatsClientInterceptor();
} }
private static final class ClientTracer extends ClientStreamTracer { private static final class ClientTracer extends ClientStreamTracer {
@ -275,8 +277,6 @@ public final class CensusStatsModule {
} }
} }
@VisibleForTesting @VisibleForTesting
static final class ClientCallTracer extends ClientStreamTracer.Factory { static final class ClientCallTracer extends ClientStreamTracer.Factory {
@Nullable @Nullable
@ -314,24 +314,16 @@ public final class CensusStatsModule {
private volatile int callEnded; private volatile int callEnded;
private final TagContext parentCtx; private final TagContext parentCtx;
private final TagContext startCtx; private final TagContext startCtx;
private final boolean recordFinishedRpcs;
ClientCallTracer( ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) {
CensusStatsModule module, this.module = checkNotNull(module);
TagContext parentCtx,
String fullMethodName,
boolean recordStartedRpcs,
boolean recordFinishedRpcs) {
this.module = module;
this.parentCtx = checkNotNull(parentCtx); this.parentCtx = checkNotNull(parentCtx);
TagValue methodTag = TagValue.create(fullMethodName); TagValue methodTag = TagValue.create(fullMethodName);
this.startCtx = this.startCtx = module.tagger.toBuilder(parentCtx)
module.tagger.toBuilder(parentCtx)
.put(DeprecatedCensusConstants.RPC_METHOD, methodTag) .put(DeprecatedCensusConstants.RPC_METHOD, methodTag)
.build(); .build();
this.stopwatch = module.stopwatchSupplier.get().start(); this.stopwatch = module.stopwatchSupplier.get().start();
this.recordFinishedRpcs = recordFinishedRpcs; if (module.recordStartedRpcs) {
if (recordStartedRpcs) {
module.statsRecorder.newMeasureMap() module.statsRecorder.newMeasureMap()
.put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1)
.record(startCtx); .record(startCtx);
@ -379,7 +371,7 @@ public final class CensusStatsModule {
} }
callEnded = 1; callEnded = 1;
} }
if (!recordFinishedRpcs) { if (!module.recordFinishedRpcs) {
return; return;
} }
stopwatch.stop(); stopwatch.stop();
@ -482,8 +474,6 @@ public final class CensusStatsModule {
private final TagContext parentCtx; private final TagContext parentCtx;
private volatile int streamClosed; private volatile int streamClosed;
private final Stopwatch stopwatch; private final Stopwatch stopwatch;
private final Tagger tagger;
private final boolean recordFinishedRpcs;
private volatile long outboundMessageCount; private volatile long outboundMessageCount;
private volatile long inboundMessageCount; private volatile long inboundMessageCount;
private volatile long outboundWireSize; private volatile long outboundWireSize;
@ -493,17 +483,11 @@ public final class CensusStatsModule {
ServerTracer( ServerTracer(
CensusStatsModule module, CensusStatsModule module,
TagContext parentCtx, TagContext parentCtx) {
Supplier<Stopwatch> stopwatchSupplier, this.module = checkNotNull(module, "module");
Tagger tagger,
boolean recordStartedRpcs,
boolean recordFinishedRpcs) {
this.module = module;
this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.parentCtx = checkNotNull(parentCtx, "parentCtx");
this.stopwatch = stopwatchSupplier.get().start(); this.stopwatch = module.stopwatchSupplier.get().start();
this.tagger = tagger; if (module.recordStartedRpcs) {
this.recordFinishedRpcs = recordFinishedRpcs;
if (recordStartedRpcs) {
module.statsRecorder.newMeasureMap() module.statsRecorder.newMeasureMap()
.put(DeprecatedCensusConstants.RPC_SERVER_STARTED_COUNT, 1) .put(DeprecatedCensusConstants.RPC_SERVER_STARTED_COUNT, 1)
.record(parentCtx); .record(parentCtx);
@ -588,7 +572,7 @@ public final class CensusStatsModule {
} }
streamClosed = 1; streamClosed = 1;
} }
if (!recordFinishedRpcs) { if (!module.recordFinishedRpcs) {
return; return;
} }
stopwatch.stop(); stopwatch.stop();
@ -624,7 +608,7 @@ public final class CensusStatsModule {
@Override @Override
public Context filterContext(Context context) { public Context filterContext(Context context) {
if (!tagger.empty().equals(parentCtx)) { if (!module.tagger.empty().equals(parentCtx)) {
return context.withValue(TAG_CONTEXT_KEY, parentCtx); return context.withValue(TAG_CONTEXT_KEY, parentCtx);
} }
return context; return context;
@ -633,14 +617,6 @@ public final class CensusStatsModule {
@VisibleForTesting @VisibleForTesting
final class ServerTracerFactory extends ServerStreamTracer.Factory { final class ServerTracerFactory extends ServerStreamTracer.Factory {
private final boolean recordStartedRpcs;
private final boolean recordFinishedRpcs;
ServerTracerFactory(boolean recordStartedRpcs, boolean recordFinishedRpcs) {
this.recordStartedRpcs = recordStartedRpcs;
this.recordFinishedRpcs = recordFinishedRpcs;
}
@Override @Override
public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
TagContext parentCtx = headers.get(statsHeader); TagContext parentCtx = headers.get(statsHeader);
@ -653,34 +629,19 @@ public final class CensusStatsModule {
.toBuilder(parentCtx) .toBuilder(parentCtx)
.put(DeprecatedCensusConstants.RPC_METHOD, methodTag) .put(DeprecatedCensusConstants.RPC_METHOD, methodTag)
.build(); .build();
return new ServerTracer( return new ServerTracer(CensusStatsModule.this, parentCtx);
CensusStatsModule.this,
parentCtx,
stopwatchSupplier,
tagger,
recordStartedRpcs,
recordFinishedRpcs);
} }
} }
@VisibleForTesting @VisibleForTesting
final class StatsClientInterceptor implements ClientInterceptor { final class StatsClientInterceptor implements ClientInterceptor {
private final boolean recordStartedRpcs;
private final boolean recordFinishedRpcs;
StatsClientInterceptor(boolean recordStartedRpcs, boolean recordFinishedRpcs) {
this.recordStartedRpcs = recordStartedRpcs;
this.recordFinishedRpcs = recordFinishedRpcs;
}
@Override @Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
// New RPCs on client-side inherit the tag context from the current Context. // New RPCs on client-side inherit the tag context from the current Context.
TagContext parentCtx = tagger.getCurrentTagContext(); TagContext parentCtx = tagger.getCurrentTagContext();
final ClientCallTracer tracerFactory = final ClientCallTracer tracerFactory =
newClientCallTracer(parentCtx, method.getFullMethodName(), newClientCallTracer(parentCtx, method.getFullMethodName());
recordStartedRpcs, recordFinishedRpcs);
ClientCall<ReqT, RespT> call = ClientCall<ReqT, RespT> call =
next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory));
return new SimpleForwardingClientCall<ReqT, RespT>(call) { return new SimpleForwardingClientCall<ReqT, RespT>(call) {

View File

@ -414,7 +414,7 @@ public class AbstractManagedChannelImplBuilderTest {
new FakeTagContextBinarySerializer(), new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(), new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER, GrpcUtil.STOPWATCH_SUPPLIER,
true)); true, true, true));
} }
Builder(SocketAddress directServerAddress, String authority) { Builder(SocketAddress directServerAddress, String authority) {
@ -425,7 +425,7 @@ public class AbstractManagedChannelImplBuilderTest {
new FakeTagContextBinarySerializer(), new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(), new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER, GrpcUtil.STOPWATCH_SUPPLIER,
true)); true, true, true));
} }
@Override @Override

View File

@ -91,7 +91,7 @@ public class AbstractServerImplBuilderTest {
new FakeTagContextBinarySerializer(), new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(), new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER, GrpcUtil.STOPWATCH_SUPPLIER,
true)); true, true, true));
} }
@Override @Override

View File

@ -191,7 +191,8 @@ public class CensusModulesTest {
.thenReturn(fakeClientSpanContext); .thenReturn(fakeClientSpanContext);
censusStats = censusStats =
new CensusStatsModule( new CensusStatsModule(
tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true); tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(),
true, true, true);
censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler); censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler);
} }
@ -240,7 +241,7 @@ public class CensusModulesTest {
Channel interceptedChannel = Channel interceptedChannel =
ClientInterceptors.intercept( ClientInterceptors.intercept(
grpcServerRule.getChannel(), callOptionsCaptureInterceptor, grpcServerRule.getChannel(), callOptionsCaptureInterceptor,
censusStats.getClientInterceptor(true, true), censusTracing.getClientInterceptor()); censusStats.getClientInterceptor(), censusTracing.getClientInterceptor());
ClientCall<String, String> call; ClientCall<String, String> call;
if (nonDefaultContext) { if (nonDefaultContext) {
Context ctx = Context ctx =
@ -353,9 +354,13 @@ public class CensusModulesTest {
} }
private void subtestClientBasicStatsDefaultContext(boolean recordStarts, boolean recordFinishes) { private void subtestClientBasicStatsDefaultContext(boolean recordStarts, boolean recordFinishes) {
CensusStatsModule localCensusStats =
new CensusStatsModule(
tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(),
true, recordStarts, recordFinishes);
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer( localCensusStats.newClientCallTracer(
tagger.empty(), method.getFullMethodName(), recordStarts, recordFinishes); tagger.empty(), method.getFullMethodName());
Metadata headers = new Metadata(); Metadata headers = new Metadata();
ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
@ -490,8 +495,7 @@ public class CensusModulesTest {
@Test @Test
public void clientStreamNeverCreatedStillRecordStats() { public void clientStreamNeverCreatedStillRecordStats() {
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer( censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName());
tagger.empty(), method.getFullMethodName(), true, true);
fakeClock.forwardTime(3000, MILLISECONDS); fakeClock.forwardTime(3000, MILLISECONDS);
callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds"));
@ -593,10 +597,10 @@ public class CensusModulesTest {
tagCtxSerializer, tagCtxSerializer,
statsRecorder, statsRecorder,
fakeClock.getStopwatchSupplier(), fakeClock.getStopwatchSupplier(),
propagate); propagate, recordStats, recordStats);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
census.newClientCallTracer(clientCtx, method.getFullMethodName(), recordStats, recordStats); census.newClientCallTracer(clientCtx, method.getFullMethodName());
// This propagates clientCtx to headers if propagates==true // This propagates clientCtx to headers if propagates==true
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
if (recordStats) { if (recordStats) {
@ -619,8 +623,7 @@ public class CensusModulesTest {
} }
ServerStreamTracer serverTracer = ServerStreamTracer serverTracer =
census.getServerTracerFactory(recordStats, recordStats).newServerStreamTracer( census.getServerTracerFactory().newServerStreamTracer(method.getFullMethodName(), headers);
method.getFullMethodName(), headers);
// Server tracer deserializes clientCtx from the headers, so that it records stats with the // Server tracer deserializes clientCtx from the headers, so that it records stats with the
// propagated tags. // propagated tags.
Context serverContext = serverTracer.filterContext(Context.ROOT); Context serverContext = serverTracer.filterContext(Context.ROOT);
@ -686,10 +689,12 @@ public class CensusModulesTest {
@Test @Test
public void statsHeadersNotPropagateDefaultContext() { public void statsHeadersNotPropagateDefaultContext() {
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName(), false, false); censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName());
Metadata headers = new Metadata(); Metadata headers = new Metadata();
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
assertFalse(headers.containsKey(censusStats.statsHeader)); assertFalse(headers.containsKey(censusStats.statsHeader));
// Clear recorded stats to satisfy the assertions in wrapUp()
statsRecorder.rolloverRecords();
} }
@Test @Test
@ -828,8 +833,11 @@ public class CensusModulesTest {
} }
private void subtestServerBasicStatsNoHeaders(boolean recordStarts, boolean recordFinishes) { private void subtestServerBasicStatsNoHeaders(boolean recordStarts, boolean recordFinishes) {
ServerStreamTracer.Factory tracerFactory = CensusStatsModule localCensusStats =
censusStats.getServerTracerFactory(recordStarts, recordFinishes); new CensusStatsModule(
tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(),
true, recordStarts, recordFinishes);
ServerStreamTracer.Factory tracerFactory = localCensusStats.getServerTracerFactory();
ServerStreamTracer tracer = ServerStreamTracer tracer =
tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata());

View File

@ -228,7 +228,7 @@ public abstract class AbstractInteropTest {
tagContextBinarySerializer, tagContextBinarySerializer,
serverStatsRecorder, serverStatsRecorder,
GrpcUtil.STOPWATCH_SUPPLIER, GrpcUtil.STOPWATCH_SUPPLIER,
true)); true, true, true));
try { try {
server = builder.build().start(); server = builder.build().start();
} catch (IOException ex) { } catch (IOException ex) {
@ -330,7 +330,8 @@ public abstract class AbstractInteropTest {
protected final CensusStatsModule createClientCensusStatsModule() { protected final CensusStatsModule createClientCensusStatsModule() {
return new CensusStatsModule( return new CensusStatsModule(
tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, true); tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER,
true, true, true);
} }
/** /**