core: call newStream() and applyRequestMetadata() under context.

`ClientTransport.newStream()` and
`CallCredentials.applyRequestMetadata()` is now called under the context
of the call.  This can be used to pass any call-specific information to
`CallCredentials`.
This commit is contained in:
Kun Zhang 2016-07-29 15:02:07 -07:00
parent aa33c59f0d
commit 1775ab3847
6 changed files with 145 additions and 8 deletions

View File

@ -60,11 +60,11 @@ public interface CallCredentials {
* Pass the credential data to the given {@link MetadataApplier}, which will propagate it to
* the request metadata.
*
* <p>It is called for each individual RPC, before the stream is about to be created on a
* transport. Implementations should not block in this method. If metadata is not immediately
* available, e.g., needs to be fetched from network, the implementation may give the {@code
* applier} to an asynchronous task which will eventually call the {@code applier}. The RPC
* proceeds only after the {@code applier} is called.
* <p>It is called for each individual RPC, within the {@link Context} of the call, before the
* stream is about to be created on a transport. Implementations should not block in this
* method. If metadata is not immediately available, e.g., needs to be fetched from network, the
* implementation may give the {@code applier} to an asynchronous task which will eventually call
* the {@code applier}. The RPC proceeds only after the {@code applier} is called.
*
* @param method The method descriptor of this RPC
* @param attrs Additional attributes from the transport, along with the keys defined in this

View File

@ -211,7 +211,12 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
updateTimeoutHeaders(effectiveDeadline, callOptions.getDeadline(),
context.getDeadline(), headers);
ClientTransport transport = clientTransportProvider.get(callOptions);
stream = transport.newStream(method, headers, callOptions);
Context origContext = context.attach();
try {
stream = transport.newStream(method, headers, callOptions);
} finally {
context.detach(origContext);
}
} else {
stream = new FailingClientStream(DEADLINE_EXCEEDED);
}

View File

@ -55,6 +55,8 @@ public interface ClientTransport {
* the error information. Any sent messages for this stream will be buffered until creation has
* completed (either successfully or unsuccessfully).
*
* <p>This method is called under the {@link io.grpc.Context} of the {@link io.grpc.ClientCall}.
*
* @param method the descriptor of the remote method to be called for this stream.
* @param headers to send at the beginning of the call
* @param callOptions runtime options of the call

View File

@ -37,6 +37,7 @@ import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import io.grpc.CallOptions;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -370,16 +371,25 @@ class DelayedClientTransport implements ManagedClientTransport {
private final MethodDescriptor<?, ?> method;
private final Metadata headers;
private final CallOptions callOptions;
private final Context context;
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers,
CallOptions callOptions) {
this.method = method;
this.headers = headers;
this.callOptions = callOptions;
this.context = Context.current();
}
private void createRealStream(ClientTransport transport) {
setStream(transport.newStream(method, headers, callOptions));
ClientStream realStream;
Context origContext = context.attach();
try {
realStream = transport.newStream(method, headers, callOptions);
} finally {
context.detach(origContext);
}
setStream(realStream);
}
@Override

View File

@ -37,6 +37,7 @@ import static com.google.common.base.Preconditions.checkState;
import io.grpc.CallCredentials.MetadataApplier;
import io.grpc.CallOptions;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -49,6 +50,7 @@ final class MetadataApplierImpl implements MetadataApplier {
private final MethodDescriptor<?, ?> method;
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final Object lock = new Object();
@ -69,6 +71,7 @@ final class MetadataApplierImpl implements MetadataApplier {
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
}
@Override
@ -76,7 +79,14 @@ final class MetadataApplierImpl implements MetadataApplier {
checkState(!finalized, "apply() or fail() already called");
checkNotNull(headers, "headers");
origHeaders.merge(headers);
finalizeWith(transport.newStream(method, origHeaders, callOptions));
ClientStream realStream;
Context origCtx = ctx.attach();
try {
realStream = transport.newStream(method, origHeaders, callOptions);
} finally {
ctx.detach(origCtx);
}
finalizeWith(realStream);
}
@Override

View File

@ -34,6 +34,7 @@ package io.grpc.internal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
@ -53,12 +54,15 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.grpc.Attributes;
import io.grpc.CallCredentials.MetadataApplier;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.DummyLoadBalancerFactory;
import io.grpc.IntegerMarshaller;
@ -67,6 +71,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.NameResolver;
import io.grpc.ResolvedServerInfo;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.StringMarshaller;
import io.grpc.TransportManager;
@ -91,8 +96,10 @@ import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
@ -136,6 +143,8 @@ public class ManagedChannelImplTest {
private ClientCall.Listener<Integer> mockCallListener3;
@Mock
private SharedResourceHolder.Resource<ScheduledExecutorService> timerService;
@Mock
private CallCredentials creds;
private ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor =
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
@ -813,6 +822,107 @@ public class ManagedChannelImplTest {
assertFalse(ManagedChannelImpl.URI_PATTERN.matcher(" a:/").matches()); // space not matched
}
/**
* Test that information such as the Call's context, MethodDescriptor, authority, executor are
* propagated to newStream() and applyRequestMetadata().
*/
@Test
public void informationPropagatedToNewStreamAndCallCredentials() {
createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR);
Metadata headers = new Metadata();
CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(creds);
final Context.Key<String> testKey = Context.key("testing");
Context ctx = Context.current().withValue(testKey, "testValue");
final LinkedList<Context> credsApplyContexts = new LinkedList<Context>();
final LinkedList<Context> newStreamContexts = new LinkedList<Context>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
credsApplyContexts.add(Context.current());
return null;
}
}).when(creds).applyRequestMetadata(
any(MethodDescriptor.class), any(Attributes.class), any(Executor.class),
any(MetadataApplier.class));
final ConnectionClientTransport transport = mock(ConnectionClientTransport.class);
when(transport.getAttrs()).thenReturn(Attributes.EMPTY);
when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class),
any(String.class))).thenReturn(transport);
doAnswer(new Answer<ClientStream>() {
@Override
public ClientStream answer(InvocationOnMock in) throws Throwable {
newStreamContexts.add(Context.current());
return mock(ClientStream.class);
}
}).when(transport).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
// First call will be on delayed transport. Only newCall() is run within the expected context,
// so that we can verify that the context is explicitly attached before calling newStream() and
// applyRequestMetadata(), which happens after we detach the context from the thread.
Context origCtx = ctx.attach();
assertEquals("testValue", testKey.get());
ClientCall<String, Integer> call = channel.newCall(method, callOptions);
ctx.detach(origCtx);
assertNull(testKey.get());
call.start(mockCallListener, new Metadata());
ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor =
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
verify(mockTransportFactory).newClientTransport(
same(socketAddress), eq(authority), eq(userAgent));
verify(transport).start(transportListenerCaptor.capture());
verify(creds, never()).applyRequestMetadata(
any(MethodDescriptor.class), any(Attributes.class), any(Executor.class),
any(MetadataApplier.class));
// applyRequestMetadata() is called after the transport becomes ready.
transportListenerCaptor.getValue().transportReady();
executor.runDueTasks();
ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(Attributes.class);
ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(MetadataApplier.class);
verify(creds).applyRequestMetadata(same(method), attrsCaptor.capture(),
same(executor.scheduledExecutorService), applierCaptor.capture());
assertEquals("testValue", testKey.get(credsApplyContexts.poll()));
assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY));
assertEquals(SecurityLevel.NONE,
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
verify(transport, never()).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
// newStream() is called after apply() is called
applierCaptor.getValue().apply(new Metadata());
verify(transport).newStream(same(method), any(Metadata.class), same(callOptions));
assertEquals("testValue", testKey.get(newStreamContexts.poll()));
// The context should not live beyond the scope of newStream() and applyRequestMetadata()
assertNull(testKey.get());
// Second call will not be on delayed transport
origCtx = ctx.attach();
call = channel.newCall(method, callOptions);
ctx.detach(origCtx);
call.start(mockCallListener, new Metadata());
verify(creds, times(2)).applyRequestMetadata(same(method), attrsCaptor.capture(),
same(executor.scheduledExecutorService), applierCaptor.capture());
assertEquals("testValue", testKey.get(credsApplyContexts.poll()));
assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY));
assertEquals(SecurityLevel.NONE,
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
// This is from the first call
verify(transport).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
// Still, newStream() is called after apply() is called
applierCaptor.getValue().apply(new Metadata());
verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions));
assertEquals("testValue", testKey.get(newStreamContexts.poll()));
assertNull(testKey.get());
}
private static class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
@Override
public BackoffPolicy get() {