From 1775ab3847cfbb523f735b67d61af28cb20f1b44 Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Fri, 29 Jul 2016 15:02:07 -0700 Subject: [PATCH] core: call newStream() and applyRequestMetadata() under context. `ClientTransport.newStream()` and `CallCredentials.applyRequestMetadata()` is now called under the context of the call. This can be used to pass any call-specific information to `CallCredentials`. --- .../main/java/io/grpc/CallCredentials.java | 10 +- .../java/io/grpc/internal/ClientCallImpl.java | 7 +- .../io/grpc/internal/ClientTransport.java | 2 + .../grpc/internal/DelayedClientTransport.java | 12 +- .../io/grpc/internal/MetadataApplierImpl.java | 12 +- .../grpc/internal/ManagedChannelImplTest.java | 110 ++++++++++++++++++ 6 files changed, 145 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/io/grpc/CallCredentials.java b/core/src/main/java/io/grpc/CallCredentials.java index 5c24a64ec2..32defb4b2d 100644 --- a/core/src/main/java/io/grpc/CallCredentials.java +++ b/core/src/main/java/io/grpc/CallCredentials.java @@ -60,11 +60,11 @@ public interface CallCredentials { * Pass the credential data to the given {@link MetadataApplier}, which will propagate it to * the request metadata. * - *

It is called for each individual RPC, before the stream is about to be created on a - * transport. Implementations should not block in this method. If metadata is not immediately - * available, e.g., needs to be fetched from network, the implementation may give the {@code - * applier} to an asynchronous task which will eventually call the {@code applier}. The RPC - * proceeds only after the {@code applier} is called. + *

It is called for each individual RPC, within the {@link Context} of the call, before the + * stream is about to be created on a transport. Implementations should not block in this + * method. If metadata is not immediately available, e.g., needs to be fetched from network, the + * implementation may give the {@code applier} to an asynchronous task which will eventually call + * the {@code applier}. The RPC proceeds only after the {@code applier} is called. * * @param method The method descriptor of this RPC * @param attrs Additional attributes from the transport, along with the keys defined in this diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 5b0d4c2de2..b6b2b53d2c 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -211,7 +211,12 @@ final class ClientCallImpl extends ClientCall updateTimeoutHeaders(effectiveDeadline, callOptions.getDeadline(), context.getDeadline(), headers); ClientTransport transport = clientTransportProvider.get(callOptions); - stream = transport.newStream(method, headers, callOptions); + Context origContext = context.attach(); + try { + stream = transport.newStream(method, headers, callOptions); + } finally { + context.detach(origContext); + } } else { stream = new FailingClientStream(DEADLINE_EXCEEDED); } diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index cd92b220d2..a6d360f83a 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -55,6 +55,8 @@ public interface ClientTransport { * the error information. Any sent messages for this stream will be buffered until creation has * completed (either successfully or unsuccessfully). * + *

This method is called under the {@link io.grpc.Context} of the {@link io.grpc.ClientCall}. + * * @param method the descriptor of the remote method to be called for this stream. * @param headers to send at the beginning of the call * @param callOptions runtime options of the call diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 7a0e157e66..a99395bc0e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -37,6 +37,7 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import io.grpc.CallOptions; +import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -370,16 +371,25 @@ class DelayedClientTransport implements ManagedClientTransport { private final MethodDescriptor method; private final Metadata headers; private final CallOptions callOptions; + private final Context context; private PendingStream(MethodDescriptor method, Metadata headers, CallOptions callOptions) { this.method = method; this.headers = headers; this.callOptions = callOptions; + this.context = Context.current(); } private void createRealStream(ClientTransport transport) { - setStream(transport.newStream(method, headers, callOptions)); + ClientStream realStream; + Context origContext = context.attach(); + try { + realStream = transport.newStream(method, headers, callOptions); + } finally { + context.detach(origContext); + } + setStream(realStream); } @Override diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 386606253c..b5d9bbcaaf 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -37,6 +37,7 @@ import static com.google.common.base.Preconditions.checkState; import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; +import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -49,6 +50,7 @@ final class MetadataApplierImpl implements MetadataApplier { private final MethodDescriptor method; private final Metadata origHeaders; private final CallOptions callOptions; + private final Context ctx; private final Object lock = new Object(); @@ -69,6 +71,7 @@ final class MetadataApplierImpl implements MetadataApplier { this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; + this.ctx = Context.current(); } @Override @@ -76,7 +79,14 @@ final class MetadataApplierImpl implements MetadataApplier { checkState(!finalized, "apply() or fail() already called"); checkNotNull(headers, "headers"); origHeaders.merge(headers); - finalizeWith(transport.newStream(method, origHeaders, callOptions)); + ClientStream realStream; + Context origCtx = ctx.attach(); + try { + realStream = transport.newStream(method, origHeaders, callOptions); + } finally { + ctx.detach(origCtx); + } + finalizeWith(realStream); } @Override diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 1a5e0ff0e0..f94380cea2 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -34,6 +34,7 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; @@ -53,12 +54,15 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import io.grpc.Attributes; +import io.grpc.CallCredentials.MetadataApplier; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.Compressor; import io.grpc.CompressorRegistry; +import io.grpc.Context; import io.grpc.DecompressorRegistry; import io.grpc.DummyLoadBalancerFactory; import io.grpc.IntegerMarshaller; @@ -67,6 +71,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.NameResolver; import io.grpc.ResolvedServerInfo; +import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.StringMarshaller; import io.grpc.TransportManager; @@ -91,8 +96,10 @@ import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -136,6 +143,8 @@ public class ManagedChannelImplTest { private ClientCall.Listener mockCallListener3; @Mock private SharedResourceHolder.Resource timerService; + @Mock + private CallCredentials creds; private ArgumentCaptor transportListenerCaptor = ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); @@ -813,6 +822,107 @@ public class ManagedChannelImplTest { assertFalse(ManagedChannelImpl.URI_PATTERN.matcher(" a:/").matches()); // space not matched } + /** + * Test that information such as the Call's context, MethodDescriptor, authority, executor are + * propagated to newStream() and applyRequestMetadata(). + */ + @Test + public void informationPropagatedToNewStreamAndCallCredentials() { + createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR); + Metadata headers = new Metadata(); + CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(creds); + final Context.Key testKey = Context.key("testing"); + Context ctx = Context.current().withValue(testKey, "testValue"); + final LinkedList credsApplyContexts = new LinkedList(); + final LinkedList newStreamContexts = new LinkedList(); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + credsApplyContexts.add(Context.current()); + return null; + } + }).when(creds).applyRequestMetadata( + any(MethodDescriptor.class), any(Attributes.class), any(Executor.class), + any(MetadataApplier.class)); + + final ConnectionClientTransport transport = mock(ConnectionClientTransport.class); + when(transport.getAttrs()).thenReturn(Attributes.EMPTY); + when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class), + any(String.class))).thenReturn(transport); + doAnswer(new Answer() { + @Override + public ClientStream answer(InvocationOnMock in) throws Throwable { + newStreamContexts.add(Context.current()); + return mock(ClientStream.class); + } + }).when(transport).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + + // First call will be on delayed transport. Only newCall() is run within the expected context, + // so that we can verify that the context is explicitly attached before calling newStream() and + // applyRequestMetadata(), which happens after we detach the context from the thread. + Context origCtx = ctx.attach(); + assertEquals("testValue", testKey.get()); + ClientCall call = channel.newCall(method, callOptions); + ctx.detach(origCtx); + assertNull(testKey.get()); + call.start(mockCallListener, new Metadata()); + + ArgumentCaptor transportListenerCaptor = + ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); + verify(mockTransportFactory).newClientTransport( + same(socketAddress), eq(authority), eq(userAgent)); + verify(transport).start(transportListenerCaptor.capture()); + verify(creds, never()).applyRequestMetadata( + any(MethodDescriptor.class), any(Attributes.class), any(Executor.class), + any(MetadataApplier.class)); + + // applyRequestMetadata() is called after the transport becomes ready. + transportListenerCaptor.getValue().transportReady(); + executor.runDueTasks(); + ArgumentCaptor attrsCaptor = ArgumentCaptor.forClass(Attributes.class); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(MetadataApplier.class); + verify(creds).applyRequestMetadata(same(method), attrsCaptor.capture(), + same(executor.scheduledExecutorService), applierCaptor.capture()); + assertEquals("testValue", testKey.get(credsApplyContexts.poll())); + assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY)); + assertEquals(SecurityLevel.NONE, + attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); + verify(transport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + + // newStream() is called after apply() is called + applierCaptor.getValue().apply(new Metadata()); + verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); + assertEquals("testValue", testKey.get(newStreamContexts.poll())); + // The context should not live beyond the scope of newStream() and applyRequestMetadata() + assertNull(testKey.get()); + + + // Second call will not be on delayed transport + origCtx = ctx.attach(); + call = channel.newCall(method, callOptions); + ctx.detach(origCtx); + call.start(mockCallListener, new Metadata()); + + verify(creds, times(2)).applyRequestMetadata(same(method), attrsCaptor.capture(), + same(executor.scheduledExecutorService), applierCaptor.capture()); + assertEquals("testValue", testKey.get(credsApplyContexts.poll())); + assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY)); + assertEquals(SecurityLevel.NONE, + attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); + // This is from the first call + verify(transport).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + + // Still, newStream() is called after apply() is called + applierCaptor.getValue().apply(new Metadata()); + verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); + assertEquals("testValue", testKey.get(newStreamContexts.poll())); + + assertNull(testKey.get()); + } + private static class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { @Override public BackoffPolicy get() {