From 1cc76d8132b99eba004fec909f3afad43a691def Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Tue, 24 May 2016 16:29:26 -0700 Subject: [PATCH] core,netty,okhttp: move user agent out of client call and into the transport --- .../inprocess/InProcessChannelBuilder.java | 3 +- .../AbstractManagedChannelImplBuilder.java | 6 +- .../java/io/grpc/internal/ClientCallImpl.java | 19 +--- .../grpc/internal/ClientTransportFactory.java | 7 +- .../io/grpc/internal/ManagedChannelImpl.java | 10 +- .../java/io/grpc/internal/TransportSet.java | 11 ++- .../io/grpc/internal/ClientCallImplTest.java | 18 ++-- .../grpc/internal/ManagedChannelImplTest.java | 46 +++++---- ...anagedChannelImplTransportManagerTest.java | 21 +++-- .../test/java/io/grpc/internal/TestUtils.java | 3 +- .../io/grpc/internal/TransportSetTest.java | 94 ++++++++++++------- .../io/grpc/netty/NettyChannelBuilder.java | 8 +- .../io/grpc/netty/NettyClientTransport.java | 12 +-- .../grpc/netty/NettyClientTransportTest.java | 27 +++--- .../io/grpc/netty/NettyTransportTest.java | 7 +- .../src/main/java/io/grpc/okhttp/Headers.java | 6 +- .../io/grpc/okhttp/OkHttpChannelBuilder.java | 7 +- .../io/grpc/okhttp/OkHttpClientStream.java | 9 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 13 ++- .../grpc/okhttp/OkHttpClientStreamTest.java | 2 +- .../okhttp/OkHttpClientTransportTest.java | 34 ++++--- .../io/grpc/okhttp/OkHttpTransportTest.java | 4 +- .../AbstractClientTransportFactoryTest.java | 2 +- 23 files changed, 207 insertions(+), 162 deletions(-) diff --git a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index 7134a22196..5cd04f268b 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -94,7 +94,8 @@ public class InProcessChannelBuilder extends } @Override - public ManagedClientTransport newClientTransport(SocketAddress addr, String authority) { + public ManagedClientTransport newClientTransport( + SocketAddress addr, String authority, String userAgent) { if (closed) { throw new IllegalStateException("The transport factory is closed."); } diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index ae10d4e63e..d14b09f1c7 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -163,7 +163,7 @@ public abstract class AbstractManagedChannelImplBuilder } @Override - public final T userAgent(String userAgent) { + public final T userAgent(@Nullable String userAgent) { this.userAgent = userAgent; return thisT(); } @@ -232,8 +232,8 @@ public abstract class AbstractManagedChannelImplBuilder @Override public ManagedClientTransport newClientTransport(SocketAddress serverAddress, - String authority) { - return factory.newClientTransport(serverAddress, authorityOverride); + String authority, @Nullable String userAgent) { + return factory.newClientTransport(serverAddress, authorityOverride, userAgent); } @Override diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index d9083c3619..6f86ff6956 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -90,7 +90,6 @@ final class ClientCallImpl extends ClientCall private boolean cancelCalled; private boolean halfCloseCalled; private final ClientTransportProvider clientTransportProvider; - private String userAgent; private ScheduledExecutorService deadlineCancellationExecutor; private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); @@ -129,11 +128,6 @@ final class ClientCallImpl extends ClientCall ClientTransport get(CallOptions callOptions); } - ClientCallImpl setUserAgent(String userAgent) { - this.userAgent = userAgent; - return this; - } - ClientCallImpl setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { this.decompressorRegistry = decompressorRegistry; return this; @@ -145,13 +139,10 @@ final class ClientCallImpl extends ClientCall } @VisibleForTesting - static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent, - DecompressorRegistry decompressorRegistry, Compressor compressor) { - // Fill out the User-Agent header. + static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry, + Compressor compressor) { + // Remove user agent. Agent are added in the transport. headers.removeAll(USER_AGENT_KEY); - if (userAgent != null) { - headers.put(USER_AGENT_KEY, userAgent); - } headers.removeAll(MESSAGE_ENCODING_KEY); if (compressor != Codec.Identity.NONE) { @@ -213,7 +204,7 @@ final class ClientCallImpl extends ClientCall compressor = Codec.Identity.NONE; } - prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor); + prepareHeaders(headers, decompressorRegistry, compressor); final boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); if (!deadlineExceeded) { @@ -265,7 +256,7 @@ final class ClientCallImpl extends ClientCall logIfContextNarrowedTimeout(effectiveTimeout, effectiveDeadline, outerCallDeadline, callDeadline); } - + private static void logIfContextNarrowedTimeout(long effectiveTimeout, Deadline effectiveDeadline, @Nullable Deadline outerCallDeadline, @Nullable Deadline callDeadline) { diff --git a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java index 73f62588a3..d61f0999af 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java @@ -34,6 +34,8 @@ package io.grpc.internal; import java.io.Closeable; import java.net.SocketAddress; +import javax.annotation.Nullable; + /** Pre-configured factory for creating {@link ManagedClientTransport} instances. */ public interface ClientTransportFactory extends Closeable { /** @@ -42,13 +44,14 @@ public interface ClientTransportFactory extends Closeable { * @param serverAddress the address that the transport is connected to * @param authority the HTTP/2 authority of the server */ - ManagedClientTransport newClientTransport(SocketAddress serverAddress, String authority); + ManagedClientTransport newClientTransport(SocketAddress serverAddress, String authority, + @Nullable String userAgent); /** * Releases any resources. * *

After this method has been called, it's no longer valid to call - * {@link #newClientTransport(SocketAddress, String)}. No guarantees about thread-safety are made. + * {@link #newClientTransport}. No guarantees about thread-safety are made. */ @Override void close(); diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index e50c4c69b7..617bd7947b 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -92,7 +92,6 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI private final ClientTransportFactory transportFactory; private final Executor executor; private final boolean usingSharedExecutor; - private final String userAgent; private final Object lock = new Object(); private final DecompressorRegistry decompressorRegistry; @@ -110,6 +109,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI * any interceptors this will just be {@link RealChannel}. */ private final Channel interceptorChannel; + @Nullable private final String userAgent; private final NameResolver nameResolver; private final LoadBalancer loadBalancer; @@ -159,11 +159,11 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams); this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm); this.transportFactory = transportFactory; - this.userAgent = userAgent; this.interceptorChannel = ClientInterceptors.intercept(new RealChannel(), interceptors); scheduledExecutor = SharedResourceHolder.get(TIMER_SERVICE); this.decompressorRegistry = decompressorRegistry; this.compressorRegistry = compressorRegistry; + this.userAgent = userAgent; this.nameResolver.start(new NameResolver.Listener() { @Override @@ -344,7 +344,6 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI callOptions, transportProvider, scheduledExecutor) - .setUserAgent(userAgent) .setDecompressorRegistry(decompressorRegistry) .setCompressorRegistry(compressorRegistry); } @@ -394,8 +393,9 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI } ts = transports.get(addressGroup); if (ts == null) { - ts = new TransportSet(addressGroup, authority(), loadBalancer, backoffPolicyProvider, - transportFactory, scheduledExecutor, executor, new TransportSet.Callback() { + ts = new TransportSet(addressGroup, authority(), userAgent, loadBalancer, + backoffPolicyProvider, transportFactory, scheduledExecutor, executor, + new TransportSet.Callback() { @Override public void onTerminated() { synchronized (lock) { diff --git a/core/src/main/java/io/grpc/internal/TransportSet.java b/core/src/main/java/io/grpc/internal/TransportSet.java index 7732837409..0d7355a5ac 100644 --- a/core/src/main/java/io/grpc/internal/TransportSet.java +++ b/core/src/main/java/io/grpc/internal/TransportSet.java @@ -67,6 +67,7 @@ final class TransportSet implements WithLogId { private final Object lock = new Object(); private final EquivalentAddressGroup addressGroup; private final String authority; + private final String userAgent; private final BackoffPolicy.Provider backoffPolicyProvider; private final Callback callback; private final ClientTransportFactory transportFactory; @@ -122,21 +123,22 @@ final class TransportSet implements WithLogId { @Nullable private volatile ManagedClientTransport activeTransport; - TransportSet(EquivalentAddressGroup addressGroup, String authority, + TransportSet(EquivalentAddressGroup addressGroup, String authority, String userAgent, LoadBalancer loadBalancer, BackoffPolicy.Provider backoffPolicyProvider, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, Executor appExecutor, Callback callback) { - this(addressGroup, authority, loadBalancer, backoffPolicyProvider, transportFactory, + this(addressGroup, authority, userAgent, loadBalancer, backoffPolicyProvider, transportFactory, scheduledExecutor, appExecutor, callback, Stopwatch.createUnstarted()); } @VisibleForTesting - TransportSet(EquivalentAddressGroup addressGroup, String authority, + TransportSet(EquivalentAddressGroup addressGroup, String authority, String userAgent, LoadBalancer loadBalancer, BackoffPolicy.Provider backoffPolicyProvider, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, Executor appExecutor, Callback callback, Stopwatch connectingTimer) { this.addressGroup = Preconditions.checkNotNull(addressGroup, "addressGroup"); this.authority = authority; + this.userAgent = userAgent; this.loadBalancer = loadBalancer; this.backoffPolicyProvider = backoffPolicyProvider; this.transportFactory = transportFactory; @@ -186,7 +188,8 @@ final class TransportSet implements WithLogId { nextAddressIndex = 0; } - ManagedClientTransport transport = transportFactory.newClientTransport(address, authority); + ManagedClientTransport transport = + transportFactory.newClientTransport(address, authority, userAgent); if (log.isLoggable(Level.FINE)) { log.log(Level.FINE, "[{0}] Created {1} for {2}", new Object[] {getLogId(), transport.getLogId(), address}); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 33d1f3e6a7..745f2436b7 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -195,19 +196,18 @@ public class ClientCallImplTest { } @Test - public void prepareHeaders_userAgentAdded() { + public void prepareHeaders_userAgentRemove() { Metadata m = new Metadata(); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, - Codec.Identity.NONE); + m.put(GrpcUtil.USER_AGENT_KEY, "batmobile"); + ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); - assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent"); + assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNull(); } @Test public void prepareHeaders_ignoreIdentityEncoding() { Metadata m = new Metadata(); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, - Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); } @@ -250,8 +250,7 @@ public class ClientCallImplTest { } }, false); // not advertised - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry, - Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE); Iterable acceptedEncodings = ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); @@ -267,8 +266,7 @@ public class ClientCallImplTest { m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, - DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE); assertNull(m.get(GrpcUtil.USER_AGENT_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 6be7f36528..bf96b2ce2f 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -114,6 +114,7 @@ public class ManagedChannelImplTest { private final ExecutorService executor = Executors.newSingleThreadExecutor(); private final String serviceName = "fake.example.com"; private final String authority = serviceName; + private final String userAgent = "userAgent"; private final String target = "fake://" + serviceName; private URI expectedUri; private final SocketAddress socketAddress = new SocketAddress() {}; @@ -146,14 +147,15 @@ public class ManagedChannelImplTest { return new ManagedChannelImpl(target, new FakeBackoffPolicyProvider(), nameResolverFactory, NAME_RESOLVER_PARAMS, loadBalancerFactory, mockTransportFactory, DecompressorRegistry.getDefaultInstance(), - CompressorRegistry.getDefaultInstance(), executor, null, interceptors); + CompressorRegistry.getDefaultInstance(), executor, userAgent, interceptors); } @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); expectedUri = new URI(target); - when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) + when(mockTransportFactory.newClientTransport( + any(SocketAddress.class), any(String.class), any(String.class))) .thenReturn(mockTransport); } @@ -195,12 +197,13 @@ public class ManagedChannelImplTest { // Create transport and call ClientStream mockStream = mock(ClientStream.class); Metadata headers = new Metadata(); - when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) + when(mockTransportFactory.newClientTransport( + any(SocketAddress.class), any(String.class), any(String.class))) .thenReturn(mockTransport); when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); call.start(mockCallListener, headers); verify(mockTransportFactory, timeout(1000)) - .newClientTransport(same(socketAddress), eq(authority)); + .newClientTransport(same(socketAddress), eq(authority), eq(userAgent)); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue(); transportListener.transportReady(); @@ -442,7 +445,7 @@ public class ManagedChannelImplTest { nameResolverFactory.allResolved(); verify(mockTransportFactory, never()) - .newClientTransport(any(SocketAddress.class), any(String.class)); + .newClientTransport(any(SocketAddress.class), any(String.class), any(String.class)); } /** @@ -467,9 +470,11 @@ public class ManagedChannelImplTest { final ManagedClientTransport badTransport = mock(ManagedClientTransport.class); when(goodTransport.newStream(any(MethodDescriptor.class), any(Metadata.class))) .thenReturn(mock(ClientStream.class)); - when(mockTransportFactory.newClientTransport(same(goodAddress), any(String.class))) + when(mockTransportFactory.newClientTransport( + same(goodAddress), any(String.class), any(String.class))) .thenReturn(goodTransport); - when(mockTransportFactory.newClientTransport(same(badAddress), any(String.class))) + when(mockTransportFactory.newClientTransport( + same(badAddress), any(String.class), any(String.class))) .thenReturn(badTransport); FakeNameResolverFactory nameResolverFactory = @@ -483,16 +488,17 @@ public class ManagedChannelImplTest { ArgumentCaptor badTransportListenerCaptor = ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); verify(badTransport, timeout(1000)).start(badTransportListenerCaptor.capture()); - verify(mockTransportFactory).newClientTransport(same(badAddress), any(String.class)); + verify(mockTransportFactory) + .newClientTransport(same(badAddress), any(String.class), any(String.class)); verify(mockTransportFactory, times(0)) - .newClientTransport(same(goodAddress), any(String.class)); + .newClientTransport(same(goodAddress), any(String.class), any(String.class)); badTransportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); // The channel then try the second address (goodAddress) ArgumentCaptor goodTransportListenerCaptor = ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); verify(mockTransportFactory, timeout(1000)) - .newClientTransport(same(goodAddress), any(String.class)); + .newClientTransport(same(goodAddress), any(String.class), any(String.class)); verify(goodTransport, timeout(1000)).start(goodTransportListenerCaptor.capture()); goodTransportListenerCaptor.getValue().transportReady(); verify(goodTransport, timeout(1000)).newStream(same(method), same(headers)); @@ -519,9 +525,9 @@ public class ManagedChannelImplTest { final ResolvedServerInfo server2 = new ResolvedServerInfo(addr2, Attributes.EMPTY); final ManagedClientTransport transport1 = mock(ManagedClientTransport.class); final ManagedClientTransport transport2 = mock(ManagedClientTransport.class); - when(mockTransportFactory.newClientTransport(same(addr1), any(String.class))) + when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class))) .thenReturn(transport1); - when(mockTransportFactory.newClientTransport(same(addr2), any(String.class))) + when(mockTransportFactory.newClientTransport(same(addr2), any(String.class), any(String.class))) .thenReturn(transport2); FakeNameResolverFactory nameResolverFactory = @@ -533,14 +539,16 @@ public class ManagedChannelImplTest { // Start a call. The channel will starts with the first address, which will fail to connect. call.start(mockCallListener, headers); verify(transport1, timeout(1000)).start(transportListenerCaptor.capture()); - verify(mockTransportFactory).newClientTransport(same(addr1), any(String.class)); + verify(mockTransportFactory) + .newClientTransport(same(addr1), any(String.class), any(String.class)); verify(mockTransportFactory, times(0)) - .newClientTransport(same(addr2), any(String.class)); + .newClientTransport(same(addr2), any(String.class), any(String.class)); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); // The channel then try the second address, which will fail to connect too. verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); - verify(mockTransportFactory).newClientTransport(same(addr2), any(String.class)); + verify(mockTransportFactory) + .newClientTransport(same(addr2), any(String.class), any(String.class)); verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); @@ -577,7 +585,7 @@ public class ManagedChannelImplTest { .thenReturn(mock(ClientStream.class)); when(transport2.newStream(any(MethodDescriptor.class), any(Metadata.class))) .thenReturn(mock(ClientStream.class)); - when(mockTransportFactory.newClientTransport(same(addr1), any(String.class))) + when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class))) .thenReturn(transport1, transport2); FakeNameResolverFactory nameResolverFactory = @@ -588,7 +596,8 @@ public class ManagedChannelImplTest { // First call will use the first address call.start(mockCallListener, headers); - verify(mockTransportFactory, timeout(1000)).newClientTransport(same(addr1), any(String.class)); + verify(mockTransportFactory, timeout(1000)) + .newClientTransport(same(addr1), any(String.class), any(String.class)); verify(transport1, timeout(1000)).start(transportListenerCaptor.capture()); transportListenerCaptor.getValue().transportReady(); verify(transport1, timeout(1000)).newStream(same(method), same(headers)); @@ -598,7 +607,8 @@ public class ManagedChannelImplTest { ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener, headers); verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); - verify(mockTransportFactory, times(2)).newClientTransport(same(addr1), any(String.class)); + verify(mockTransportFactory, times(2)) + .newClientTransport(same(addr1), any(String.class), any(String.class)); transportListenerCaptor.getValue().transportReady(); verify(transport2, timeout(1000)).newStream(same(method), same(headers)); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java index d9cc8b6c84..274b50da58 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java @@ -87,6 +87,7 @@ import java.util.concurrent.TimeUnit; public class ManagedChannelImplTransportManagerTest { private static final String authority = "fakeauthority"; + private static final String userAgent = "mosaic"; private final ExecutorService executor = Executors.newSingleThreadExecutor(); private final MethodDescriptor method = MethodDescriptor.create( @@ -127,7 +128,7 @@ public class ManagedChannelImplTransportManagerTest { channel = new ManagedChannelImpl("fake://target", mockBackoffPolicyProvider, mockNameResolverFactory, Attributes.EMPTY, mockLoadBalancerFactory, mockTransportFactory, DecompressorRegistry.getDefaultInstance(), - CompressorRegistry.getDefaultInstance(), executor, null, + CompressorRegistry.getDefaultInstance(), executor, userAgent, Collections.emptyList()); ArgumentCaptor> tmCaptor @@ -150,7 +151,7 @@ public class ManagedChannelImplTransportManagerTest { SocketAddress addr = mock(SocketAddress.class); EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(addr); ClientTransport t1 = tm.getTransport(addressGroup); - verify(mockTransportFactory, timeout(1000)).newClientTransport(addr, authority); + verify(mockTransportFactory, timeout(1000)).newClientTransport(addr, authority, userAgent); // The real transport MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); transportInfo.listener.transportReady(); @@ -175,7 +176,7 @@ public class ManagedChannelImplTransportManagerTest { // Pick the first transport ClientTransport t1 = tm.getTransport(addressGroup); assertNotNull(t1); - verify(mockTransportFactory, timeout(1000)).newClientTransport(addr1, authority); + verify(mockTransportFactory, timeout(1000)).newClientTransport(addr1, authority, userAgent); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); // Fail the first transport, without setting it to ready MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); @@ -187,7 +188,7 @@ public class ManagedChannelImplTransportManagerTest { assertNotNull(t2); t2.newStream(method, new Metadata()); // Will keep the previous back-off policy, and not consult back-off policy - verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority); + verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority, userAgent); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); transportInfo = transports.poll(1, TimeUnit.SECONDS); ClientTransport rt2 = transportInfo.transport; @@ -203,7 +204,8 @@ public class ManagedChannelImplTransportManagerTest { // Subsequent getTransport() will use the first address, since last attempt was successful. ClientTransport t3 = tm.getTransport(addressGroup); t3.newStream(method2, new Metadata()); - verify(mockTransportFactory, timeout(1000).times(2)).newClientTransport(addr1, authority); + verify(mockTransportFactory, timeout(1000).times(2)) + .newClientTransport(addr1, authority, userAgent); // Still no back-off policy creation, because an address succeeded. verify(mockBackoffPolicyProvider, times(backoffReset)).get(); transportInfo = transports.poll(1, TimeUnit.SECONDS); @@ -236,7 +238,7 @@ public class ManagedChannelImplTransportManagerTest { ClientTransport t1 = tm.getTransport(addressGroup); assertNotNull(t1); verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) - .newClientTransport(addr1, authority); + .newClientTransport(addr1, authority, userAgent); // Back-off policy was unset initially. verify(mockBackoffPolicyProvider, times(backoffReset)).get(); MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); @@ -250,7 +252,7 @@ public class ManagedChannelImplTransportManagerTest { ClientTransport t2 = tm.getTransport(addressGroup); assertNotNull(t2); verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) - .newClientTransport(addr1, authority); + .newClientTransport(addr1, authority, userAgent); // Back-off policy was not reset. verify(mockBackoffPolicyProvider, times(backoffReset)).get(); transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE); @@ -260,7 +262,7 @@ public class ManagedChannelImplTransportManagerTest { ClientTransport t3 = tm.getTransport(addressGroup); assertNotNull(t3); verify(mockTransportFactory, timeout(1000).times(++transportsAddr2)) - .newClientTransport(addr2, authority); + .newClientTransport(addr2, authority, userAgent); // Back-off policy was not reset. verify(mockBackoffPolicyProvider, times(backoffReset)).get(); transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE); @@ -272,7 +274,8 @@ public class ManagedChannelImplTransportManagerTest { // If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy. t4.newStream(method, new Metadata()); verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) - .newClientTransport(addr1, authority); + + .newClientTransport(addr1, authority, userAgent); // Back-off policy was reset and consulted. verify(mockBackoffPolicyProvider, times(++backoffReset)).get(); verify(mockBackoffPolicy, times(++backoffConsulted)).nextBackoffMillis(); diff --git a/core/src/test/java/io/grpc/internal/TestUtils.java b/core/src/test/java/io/grpc/internal/TestUtils.java index ccf8c3c4c6..a23474f75b 100644 --- a/core/src/test/java/io/grpc/internal/TestUtils.java +++ b/core/src/test/java/io/grpc/internal/TestUtils.java @@ -97,7 +97,8 @@ final class TestUtils { }).when(mockTransport).start(any(ManagedClientTransport.Listener.class)); return mockTransport; } - }).when(mockTransportFactory).newClientTransport(any(SocketAddress.class), any(String.class)); + }).when(mockTransportFactory) + .newClientTransport(any(SocketAddress.class), any(String.class), any(String.class)); return captor; } diff --git a/core/src/test/java/io/grpc/internal/TransportSetTest.java b/core/src/test/java/io/grpc/internal/TransportSetTest.java index 4201636caf..93a3a2a741 100644 --- a/core/src/test/java/io/grpc/internal/TransportSetTest.java +++ b/core/src/test/java/io/grpc/internal/TransportSetTest.java @@ -38,9 +38,9 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.same; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -78,6 +78,7 @@ import java.util.concurrent.BlockingQueue; public class TransportSetTest { private static final String authority = "fakeauthority"; + private static final String userAgent = "mosaic"; private FakeClock fakeClock; private FakeClock fakeExecutor; @@ -131,7 +132,9 @@ public class TransportSetTest { // First attempt transportSet.obtainActiveTransport().newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); + // Fail this one transports.poll().listener.transportShutdown(Status.UNAVAILABLE); verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed(); @@ -143,9 +146,11 @@ public class TransportSetTest { transportSet.obtainActiveTransport().newStream(method, new Metadata()); // Transport creation doesn't happen until time is due fakeClock.forwardMillis(9); - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); fakeClock.forwardMillis(1); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Fail this one too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed(); @@ -157,9 +162,11 @@ public class TransportSetTest { transportSet.obtainActiveTransport().newStream(method, new Metadata()); // Transport creation doesn't happen until time is due fakeClock.forwardMillis(99); - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); fakeClock.forwardMillis(1); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Let this one succeed transports.peek().listener.transportReady(); fakeClock.runDueTasks(); @@ -172,7 +179,8 @@ public class TransportSetTest { // Back-off is reset, and the next attempt will happen immediately transportSet.obtainActiveTransport().newStream(method, new Metadata()); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Final checks for consultations on back-off policies verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis(); @@ -199,7 +207,8 @@ public class TransportSetTest { DelayedClientTransport delayedTransport1 = (DelayedClientTransport) transportSet.obtainActiveTransport(); delayedTransport1.newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(++transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); // Let this one fail without success transports.poll().listener.transportShutdown(Status.UNAVAILABLE); assertNull(delayedTransport1.getTransportSupplier()); @@ -211,7 +220,8 @@ public class TransportSetTest { assertSame(delayedTransport1, delayedTransport2); delayedTransport2.newStream(method, new Metadata()); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); - verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); + verify(mockTransportFactory, times(++transportsAddr2)) + .newClientTransport(addr2, authority, userAgent); // Fail this one too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // All addresses have failed. Delayed transport will see an error. @@ -227,9 +237,11 @@ public class TransportSetTest { assertNotSame(delayedTransport2, delayedTransport3); delayedTransport3.newStream(method, new Metadata()); fakeClock.forwardMillis(9); - verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); fakeClock.forwardMillis(1); - verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(++transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); // Fail this one too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); assertNull(delayedTransport3.getTransportSupplier()); @@ -241,7 +253,8 @@ public class TransportSetTest { assertSame(delayedTransport3, delayedTransport4); delayedTransport4.newStream(method, new Metadata()); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); - verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); + verify(mockTransportFactory, times(++transportsAddr2)) + .newClientTransport(addr2, authority, userAgent); // Fail this one too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // All addresses have failed again. Delayed transport will see an error @@ -257,9 +270,11 @@ public class TransportSetTest { assertNotSame(delayedTransport4, delayedTransport5); delayedTransport5.newStream(method, new Metadata()); fakeClock.forwardMillis(99); - verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); fakeClock.forwardMillis(1); - verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(++transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); // Let it through transports.peek().listener.transportReady(); // Delayed transport will see the connected transport. @@ -277,7 +292,8 @@ public class TransportSetTest { assertNotSame(delayedTransport5, delayedTransport6); delayedTransport6.newStream(method, new Metadata()); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); - verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(++transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); // Fail the transport transports.poll().listener.transportShutdown(Status.UNAVAILABLE); assertNull(delayedTransport6.getTransportSupplier()); @@ -289,7 +305,8 @@ public class TransportSetTest { assertSame(delayedTransport6, delayedTransport7); delayedTransport7.newStream(method, new Metadata()); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); - verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); + verify(mockTransportFactory, times(++transportsAddr2)) + .newClientTransport(addr2, authority, userAgent); // Fail this one too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // All addresses have failed. Delayed transport will see an error. @@ -305,9 +322,11 @@ public class TransportSetTest { assertNotSame(delayedTransport7, delayedTransport8); delayedTransport8.newStream(method, new Metadata()); fakeClock.forwardMillis(9); - verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); fakeClock.forwardMillis(1); - verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); + verify(mockTransportFactory, times(++transportsAddr1)) + .newClientTransport(addr1, authority, userAgent); // Final checks on invocations on back-off policies verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis(); @@ -326,31 +345,37 @@ public class TransportSetTest { int transportsCreated = 0; // Won't connect until requested - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); // First attempt transportSet.obtainActiveTransport().newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Fail this one transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // Won't reconnect until requested, even if back-off time has expired fakeClock.forwardMillis(10); - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Once requested, will reconnect transportSet.obtainActiveTransport().newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Fail this one, too transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // Request immediately, but will wait for back-off before reconnecting transportSet.obtainActiveTransport().newStream(method, new Metadata()); - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); fakeClock.forwardMillis(100); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test. } @@ -364,7 +389,8 @@ public class TransportSetTest { // Trigger TRANSIENT_FAILURE transportSet.obtainActiveTransport().newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); transports.poll().listener.transportShutdown(Status.UNAVAILABLE); // Won't reconnect without any active streams @@ -372,11 +398,13 @@ public class TransportSetTest { assertTrue(transientFailureTransport instanceof DelayedClientTransport); transientFailureTransport.newStream(method, new Metadata()).cancel(Status.CANCELLED); fakeClock.forwardMillis(10); - verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(transportsCreated)) + .newClientTransport(addr, authority, userAgent); // Lose race (long delay between obtainActiveTransport and newStream); will now reconnect transientFailureTransport.newStream(method, new Metadata()); - verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(++transportsCreated)) + .newClientTransport(addr, authority, userAgent); fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test. } @@ -388,7 +416,7 @@ public class TransportSetTest { // First transport is created immediately ClientTransport pick = transportSet.obtainActiveTransport(); - verify(mockTransportFactory).newClientTransport(addr, authority); + verify(mockTransportFactory).newClientTransport(addr, authority, userAgent); assertNotNull(pick); // Fail this one MockClientTransportInfo transportInfo = transports.poll(); @@ -408,11 +436,11 @@ public class TransportSetTest { pick = transportSet.obtainActiveTransport(); assertNotNull(pick); assertTrue(pick instanceof FailingClientTransport); - verify(mockTransportFactory).newClientTransport(addr, authority); + verify(mockTransportFactory).newClientTransport(addr, authority, userAgent); // Reconnect will eventually happen, even though TransportSet has been shut down fakeClock.forwardMillis(10); - verify(mockTransportFactory, times(2)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(2)).newClientTransport(addr, authority, userAgent); // The pending stream will be started on this newly started transport after it's ready. // The transport is shut down by TransportSet right after the stream is created. transportInfo = transports.poll(); @@ -443,7 +471,7 @@ public class TransportSetTest { // First transport is created immediately ClientTransport pick = transportSet.obtainActiveTransport(); - verify(mockTransportFactory).newClientTransport(addr, authority); + verify(mockTransportFactory).newClientTransport(addr, authority, userAgent); assertNotNull(pick); // Fail this one MockClientTransportInfo transportInfo = transports.poll(); @@ -478,7 +506,7 @@ public class TransportSetTest { transportSet.shutdown(); ClientTransport pick = transportSet.obtainActiveTransport(); assertNotNull(pick); - verify(mockTransportFactory, times(0)).newClientTransport(addr, authority); + verify(mockTransportFactory, times(0)).newClientTransport(addr, authority, userAgent); } @Test @@ -490,7 +518,7 @@ public class TransportSetTest { private void createTransportSet(SocketAddress ... addrs) { addressGroup = new EquivalentAddressGroup(Arrays.asList(addrs)); - transportSet = new TransportSet(addressGroup, authority, mockLoadBalancer, + transportSet = new TransportSet(addressGroup, authority, userAgent, mockLoadBalancer, mockBackoffPolicyProvider, mockTransportFactory, fakeClock.scheduledExecutorService, fakeExecutor.scheduledExecutorService, mockTransportSetCallback, Stopwatch.createUnstarted(fakeClock.ticker)); diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index c1239e4fea..12b8fecfc9 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -311,23 +311,23 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder channelType; private final EventLoopGroup group; private final ProtocolNegotiator negotiator; private final AsciiString authority; + private final AsciiString userAgent; private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; @@ -83,7 +83,7 @@ class NettyClientTransport implements ManagedClientTransport { NettyClientTransport(SocketAddress address, Class channelType, EventLoopGroup group, ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, - String authority) { + String authority, @Nullable String userAgent) { this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.address = Preconditions.checkNotNull(address, "address"); this.group = Preconditions.checkNotNull(group, "group"); @@ -92,6 +92,7 @@ class NettyClientTransport implements ManagedClientTransport { this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; this.authority = new AsciiString(authority); + this.userAgent = new AsciiString(GrpcUtil.getGrpcUserAgent("netty", userAgent)); } @Override @@ -114,9 +115,6 @@ class NettyClientTransport implements ManagedClientTransport { public ClientStream newStream(MethodDescriptor method, Metadata headers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); - AsciiString userAgent = headers.containsKey(GrpcUtil.USER_AGENT_KEY) - ? new AsciiString(GrpcUtil.getGrpcUserAgent("netty", headers.get(GrpcUtil.USER_AGENT_KEY))) - : DEFAULT_AGENT; return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority, negotiationHandler.scheme(), userAgent) { @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 06c93e9171..28ad4a143c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -132,7 +132,7 @@ public class NettyClientTransportTest { } @Test - public void headersShouldAddDefaultUserAgent() throws Exception { + public void addDefaultUserAgent() throws Exception { startServer(); NettyClientTransport transport = newTransport(newNegotiator()); transport.start(clientTransportListener); @@ -148,21 +148,18 @@ public class NettyClientTransportTest { } @Test - public void headersShouldOverrideDefaultUserAgent() throws Exception { + public void overrideDefaultUserAgent() throws Exception { startServer(); - NettyClientTransport transport = newTransport(newNegotiator()); + NettyClientTransport transport = newTransport(newNegotiator(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent"); transport.start(clientTransportListener); - // Send a single RPC and wait for the response. - String userAgent = "testUserAgent"; - Metadata sentHeaders = new Metadata(); - sentHeaders.put(USER_AGENT_KEY, userAgent); - new Rpc(transport, sentHeaders).halfClose().waitForResponse(); + new Rpc(transport, new Metadata()).halfClose().waitForResponse(); // Verify that the received headers contained the User-Agent. assertEquals(1, serverListener.streamListeners.size()); Metadata receivedHeaders = serverListener.streamListeners.get(0).headers; - assertEquals(GrpcUtil.getGrpcUserAgent("netty", userAgent), + assertEquals(GrpcUtil.getGrpcUserAgent("netty", "testUserAgent"), receivedHeaders.get(USER_AGENT_KEY)); } @@ -171,7 +168,7 @@ public class NettyClientTransportTest { startServer(); // Allow the response payloads of up to 1 byte. NettyClientTransport transport = newTransport(newNegotiator(), - 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null); transport.start(clientTransportListener); try { @@ -248,7 +245,8 @@ public class NettyClientTransportTest { public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { startServer(); - NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1); + NettyClientTransport transport = + newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1, null); transport.start(clientTransportListener); try { @@ -298,13 +296,14 @@ public class NettyClientTransportTest { private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { return newTransport(negotiator, - DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */); } private NettyClientTransport newTransport(ProtocolNegotiator negotiator, - int maxMsgSize, int maxHeaderListSize) { + int maxMsgSize, int maxHeaderListSize, String userAgent) { NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, - group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, authority); + group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, authority, + userAgent); transports.add(transport); return transport; } diff --git a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java index b50cab77a7..ab5c652742 100644 --- a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java @@ -75,12 +75,13 @@ public class NettyTransportTest extends AbstractTransportTest { @Override protected ManagedClientTransport newClientTransport() { return clientFactory.newClientTransport( - new InetSocketAddress("localhost", SERVER_PORT), "localhost:" + SERVER_PORT); + new InetSocketAddress("localhost", SERVER_PORT), + "localhost:" + SERVER_PORT, + null /* agent */); } - // TODO(ejona): Flaky @Test - @Ignore + @Ignore("flaky") @Override public void flowControlPushBack() {} } diff --git a/okhttp/src/main/java/io/grpc/okhttp/Headers.java b/okhttp/src/main/java/io/grpc/okhttp/Headers.java index 573c4982e6..544647ca06 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Headers.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Headers.java @@ -46,6 +46,8 @@ import okio.ByteString; import java.util.ArrayList; import java.util.List; +import javax.annotation.Nullable; + /** * Constants for request/response headers. */ @@ -63,7 +65,7 @@ public class Headers { * application thread context. */ public static List

createRequestHeaders(Metadata headers, String defaultPath, - String authority) { + String authority, @Nullable String applicationUserAgent) { Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(defaultPath, "defaultPath"); Preconditions.checkNotNull(authority, "authority"); @@ -79,7 +81,7 @@ public class Headers { String path = defaultPath; okhttpHeaders.add(new Header(Header.TARGET_PATH, path)); - String userAgent = GrpcUtil.getGrpcUserAgent("okhttp", headers.get(USER_AGENT_KEY)); + String userAgent = GrpcUtil.getGrpcUserAgent("okhttp", applicationUserAgent); okhttpHeaders.add(new Header(GrpcUtil.USER_AGENT_KEY.name(), userAgent)); // All non-pseudo headers must come after pseudo headers. diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index f51f2b2dfc..696cc24b50 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -260,13 +260,14 @@ public class OkHttpChannelBuilder extends } @Override - public ManagedClientTransport newClientTransport(SocketAddress addr, String authority) { + public ManagedClientTransport newClientTransport( + SocketAddress addr, String authority, @Nullable String userAgent) { if (closed) { throw new IllegalStateException("The transport factory is closed."); } InetSocketAddress inetSocketAddr = (InetSocketAddress) addr; - return new OkHttpClientTransport(inetSocketAddr, authority, executor, socketFactory, - Utils.convertSpec(connectionSpec), maxMessageSize); + return new OkHttpClientTransport(inetSocketAddr, authority, userAgent, executor, + socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize); } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 7e55c73c51..76af89ef9e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -34,7 +34,6 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import io.grpc.Metadata; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -73,6 +72,7 @@ class OkHttpClientStream extends Http2ClientStream { private final OutboundFlowController outboundFlow; private final OkHttpClientTransport transport; private final Object lock; + private final String userAgent; private String authority; private Object outboundFlowState; private volatile Integer id; @@ -95,7 +95,8 @@ class OkHttpClientStream extends Http2ClientStream { OutboundFlowController outboundFlow, Object lock, int maxMessageSize, - String authority) { + String authority, + @Nullable String userAgent) { super(new OkHttpWritableBufferAllocator(), maxMessageSize); this.method = method; this.headers = headers; @@ -104,6 +105,7 @@ class OkHttpClientStream extends Http2ClientStream { this.outboundFlow = outboundFlow; this.lock = lock; this.authority = authority; + this.userAgent = userAgent; } /** @@ -136,7 +138,8 @@ class OkHttpClientStream extends Http2ClientStream { public void start(ClientStreamListener listener) { super.start(listener); String defaultPath = "/" + method.getFullMethodName(); - List
requestHeaders = Headers.createRequestHeaders(headers, defaultPath, authority); + List
requestHeaders = + Headers.createRequestHeaders(headers, defaultPath, authority, userAgent); headers = null; synchronized (lock) { this.requestHeaders = requestHeaders; diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 2b1c9b8df0..465ff6e04e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -125,6 +125,7 @@ class OkHttpClientTransport implements ManagedClientTransport { private final InetSocketAddress address; private final String defaultAuthority; + private final String userAgent; private final Random random = new Random(); private final Ticker ticker; private Listener listener; @@ -168,8 +169,8 @@ class OkHttpClientTransport implements ManagedClientTransport { Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport(InetSocketAddress address, String authority, Executor executor, - @Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec, + OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, + Executor executor, @Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec, int maxMessageSize) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; @@ -182,19 +183,21 @@ class OkHttpClientTransport implements ManagedClientTransport { this.sslSocketFactory = sslSocketFactory; this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); this.ticker = Ticker.systemTicker(); + this.userAgent = userAgent; } /** * Create a transport connected to a fake peer for test. */ @VisibleForTesting - OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter, - int nextStreamId, Socket socket, Ticker ticker, + OkHttpClientTransport(String userAgent, Executor executor, FrameReader frameReader, + FrameWriter testFrameWriter, int nextStreamId, Socket socket, Ticker ticker, @Nullable Runnable connectingCallback, SettableFuture connectedFuture, int maxMessageSize) { address = null; this.maxMessageSize = maxMessageSize; defaultAuthority = "notarealauthority:80"; + this.userAgent = userAgent; this.executor = Preconditions.checkNotNull(executor); serializingExecutor = new SerializingExecutor(executor); this.testFrameReader = Preconditions.checkNotNull(frameReader); @@ -247,7 +250,7 @@ class OkHttpClientTransport implements ManagedClientTransport { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this, - outboundFlow, lock, maxMessageSize, defaultAuthority); + outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent); } @GuardedBy("lock") diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 5c34758f5f..5ca6fe5d86 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -75,7 +75,7 @@ public class OkHttpClientStreamTest { methodDescriptor = MethodDescriptor.create( MethodType.UNARY, "/testService/test", marshaller, marshaller); stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport, - flowController, lock, MAX_MESSAGE_SIZE, "localhost"); + flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent"); } @Test diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index ed429e676d..d9f353ffa0 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -159,20 +159,20 @@ public class OkHttpClientTransportTest { } private void initTransport() throws Exception { - startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE); + startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE); + startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE, null); } private void initTransportAndDelayConnected() throws Exception { delayConnectedCallback = new DelayConnectedCallback(); - startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE); + startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, int maxMessageSize) throws Exception { + boolean waitingForConnected, int maxMessageSize, String userAgent) throws Exception { connectedFuture = SettableFuture.create(); Ticker ticker = new Ticker() { @Override @@ -180,10 +180,9 @@ public class OkHttpClientTransportTest { return nanoTime; } }; - clientTransport = new OkHttpClientTransport( - executor, frameReader, frameWriter, startId, - new MockSocket(frameReader), ticker, connectingCallback, connectedFuture, - maxMessageSize); + clientTransport = new OkHttpClientTransport(userAgent, executor, frameReader, + frameWriter, startId, new MockSocket(frameReader), ticker, connectingCallback, + connectedFuture, maxMessageSize); clientTransport.start(transportListener); if (waitingForConnected) { connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); @@ -194,7 +193,7 @@ public class OkHttpClientTransportTest { public void testToString() throws Exception { InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); clientTransport = new OkHttpClientTransport( - address, "hostname", executor, null, + address, "hostname", null /* agent */, executor, null, Utils.convertSpec(OkHttpChannelBuilder.DEFAULT_CONNECTION_SPEC), DEFAULT_MAX_MESSAGE_SIZE); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); @@ -204,7 +203,7 @@ public class OkHttpClientTransportTest { @Test public void maxMessageSizeShouldBeEnforced() throws Exception { // Allow the response payloads of up to 1 byte. - startTransport(3, null, true, 1); + startTransport(3, null, true, 1, null); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); @@ -405,7 +404,7 @@ public class OkHttpClientTransportTest { } @Test - public void headersShouldAddDefaultUserAgent() throws Exception { + public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); @@ -423,19 +422,16 @@ public class OkHttpClientTransportTest { } @Test - public void headersShouldOverrideDefaultUserAgent() throws Exception { - initTransport(); + public void overrideDefaultUserAgent() throws Exception { + startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, "fakeUserAgent"); MockStreamListener listener = new MockStreamListener(); - String userAgent = "fakeUserAgent"; - Metadata metadata = new Metadata(); - metadata.put(GrpcUtil.USER_AGENT_KEY, userAgent); - OkHttpClientStream stream = clientTransport.newStream(method, metadata); + OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); stream.start(listener); List
expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER, new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), new Header(Header.TARGET_PATH, "/fakemethod"), new Header(GrpcUtil.USER_AGENT_KEY.name(), - GrpcUtil.getGrpcUserAgent("okhttp", userAgent)), + GrpcUtil.getGrpcUserAgent("okhttp", "fakeUserAgent")), CONTENT_TYPE_HEADER, TE_HEADER); verify(frameWriter, timeout(TIME_OUT_MS)) .synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); @@ -1311,6 +1307,7 @@ public class OkHttpClientTransportTest { clientTransport = new OkHttpClientTransport( new InetSocketAddress("host", 1234), "invalid_authority", + "userAgent", executor, null, ConnectionSpec.CLEARTEXT, @@ -1328,6 +1325,7 @@ public class OkHttpClientTransportTest { clientTransport = new OkHttpClientTransport( new InetSocketAddress("localhost", 0), "authority", + "userAgent", executor, null, ConnectionSpec.CLEARTEXT, diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java index 65d482ab1d..6c4ce99bd8 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java @@ -73,7 +73,9 @@ public class OkHttpTransportTest extends AbstractTransportTest { @Override protected ManagedClientTransport newClientTransport() { return clientFactory.newClientTransport( - new InetSocketAddress("127.0.0.1", SERVER_PORT), "127.0.0.1:" + SERVER_PORT); + new InetSocketAddress("127.0.0.1", SERVER_PORT), + "127.0.0.1:" + SERVER_PORT, + null /* agent */); } // TODO(ejona): Flaky/Broken diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractClientTransportFactoryTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractClientTransportFactoryTest.java index a6b52b6c7a..893523f173 100644 --- a/testing/src/main/java/io/grpc/internal/testing/AbstractClientTransportFactoryTest.java +++ b/testing/src/main/java/io/grpc/internal/testing/AbstractClientTransportFactoryTest.java @@ -58,6 +58,6 @@ public abstract class AbstractClientTransportFactoryTest { ClientTransportFactory transportFactory = newClientTransportFactory(); transportFactory.close(); transportFactory.newClientTransport( - new InetSocketAddress("localhost", port), "localhost:" + port); + new InetSocketAddress("localhost", port), "localhost:" + port, "agent"); } }