core,netty,okhttp: move user agent out of client call and into the transport

This commit is contained in:
Carl Mastrangelo 2016-05-24 16:29:26 -07:00
parent c102dd4e4f
commit 1cc76d8132
23 changed files with 207 additions and 162 deletions

View File

@ -94,7 +94,8 @@ public class InProcessChannelBuilder extends
} }
@Override @Override
public ManagedClientTransport newClientTransport(SocketAddress addr, String authority) { public ManagedClientTransport newClientTransport(
SocketAddress addr, String authority, String userAgent) {
if (closed) { if (closed) {
throw new IllegalStateException("The transport factory is closed."); throw new IllegalStateException("The transport factory is closed.");
} }

View File

@ -163,7 +163,7 @@ public abstract class AbstractManagedChannelImplBuilder
} }
@Override @Override
public final T userAgent(String userAgent) { public final T userAgent(@Nullable String userAgent) {
this.userAgent = userAgent; this.userAgent = userAgent;
return thisT(); return thisT();
} }
@ -232,8 +232,8 @@ public abstract class AbstractManagedChannelImplBuilder
@Override @Override
public ManagedClientTransport newClientTransport(SocketAddress serverAddress, public ManagedClientTransport newClientTransport(SocketAddress serverAddress,
String authority) { String authority, @Nullable String userAgent) {
return factory.newClientTransport(serverAddress, authorityOverride); return factory.newClientTransport(serverAddress, authorityOverride, userAgent);
} }
@Override @Override

View File

@ -90,7 +90,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
private boolean cancelCalled; private boolean cancelCalled;
private boolean halfCloseCalled; private boolean halfCloseCalled;
private final ClientTransportProvider clientTransportProvider; private final ClientTransportProvider clientTransportProvider;
private String userAgent;
private ScheduledExecutorService deadlineCancellationExecutor; private ScheduledExecutorService deadlineCancellationExecutor;
private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance();
private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
@ -129,11 +128,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
ClientTransport get(CallOptions callOptions); ClientTransport get(CallOptions callOptions);
} }
ClientCallImpl<ReqT, RespT> setUserAgent(String userAgent) {
this.userAgent = userAgent;
return this;
}
ClientCallImpl<ReqT, RespT> setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { ClientCallImpl<ReqT, RespT> setDecompressorRegistry(DecompressorRegistry decompressorRegistry) {
this.decompressorRegistry = decompressorRegistry; this.decompressorRegistry = decompressorRegistry;
return this; return this;
@ -145,13 +139,10 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
} }
@VisibleForTesting @VisibleForTesting
static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent, static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry,
DecompressorRegistry decompressorRegistry, Compressor compressor) { Compressor compressor) {
// Fill out the User-Agent header. // Remove user agent. Agent are added in the transport.
headers.removeAll(USER_AGENT_KEY); headers.removeAll(USER_AGENT_KEY);
if (userAgent != null) {
headers.put(USER_AGENT_KEY, userAgent);
}
headers.removeAll(MESSAGE_ENCODING_KEY); headers.removeAll(MESSAGE_ENCODING_KEY);
if (compressor != Codec.Identity.NONE) { if (compressor != Codec.Identity.NONE) {
@ -213,7 +204,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
compressor = Codec.Identity.NONE; compressor = Codec.Identity.NONE;
} }
prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor); prepareHeaders(headers, decompressorRegistry, compressor);
final boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); final boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired();
if (!deadlineExceeded) { if (!deadlineExceeded) {

View File

@ -34,6 +34,8 @@ package io.grpc.internal;
import java.io.Closeable; import java.io.Closeable;
import java.net.SocketAddress; import java.net.SocketAddress;
import javax.annotation.Nullable;
/** Pre-configured factory for creating {@link ManagedClientTransport} instances. */ /** Pre-configured factory for creating {@link ManagedClientTransport} instances. */
public interface ClientTransportFactory extends Closeable { public interface ClientTransportFactory extends Closeable {
/** /**
@ -42,13 +44,14 @@ public interface ClientTransportFactory extends Closeable {
* @param serverAddress the address that the transport is connected to * @param serverAddress the address that the transport is connected to
* @param authority the HTTP/2 authority of the server * @param authority the HTTP/2 authority of the server
*/ */
ManagedClientTransport newClientTransport(SocketAddress serverAddress, String authority); ManagedClientTransport newClientTransport(SocketAddress serverAddress, String authority,
@Nullable String userAgent);
/** /**
* Releases any resources. * Releases any resources.
* *
* <p>After this method has been called, it's no longer valid to call * <p>After this method has been called, it's no longer valid to call
* {@link #newClientTransport(SocketAddress, String)}. No guarantees about thread-safety are made. * {@link #newClientTransport}. No guarantees about thread-safety are made.
*/ */
@Override @Override
void close(); void close();

View File

@ -92,7 +92,6 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
private final ClientTransportFactory transportFactory; private final ClientTransportFactory transportFactory;
private final Executor executor; private final Executor executor;
private final boolean usingSharedExecutor; private final boolean usingSharedExecutor;
private final String userAgent;
private final Object lock = new Object(); private final Object lock = new Object();
private final DecompressorRegistry decompressorRegistry; private final DecompressorRegistry decompressorRegistry;
@ -110,6 +109,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
* any interceptors this will just be {@link RealChannel}. * any interceptors this will just be {@link RealChannel}.
*/ */
private final Channel interceptorChannel; private final Channel interceptorChannel;
@Nullable private final String userAgent;
private final NameResolver nameResolver; private final NameResolver nameResolver;
private final LoadBalancer<ClientTransport> loadBalancer; private final LoadBalancer<ClientTransport> loadBalancer;
@ -159,11 +159,11 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams); this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams);
this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm); this.loadBalancer = loadBalancerFactory.newLoadBalancer(nameResolver.getServiceAuthority(), tm);
this.transportFactory = transportFactory; this.transportFactory = transportFactory;
this.userAgent = userAgent;
this.interceptorChannel = ClientInterceptors.intercept(new RealChannel(), interceptors); this.interceptorChannel = ClientInterceptors.intercept(new RealChannel(), interceptors);
scheduledExecutor = SharedResourceHolder.get(TIMER_SERVICE); scheduledExecutor = SharedResourceHolder.get(TIMER_SERVICE);
this.decompressorRegistry = decompressorRegistry; this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry; this.compressorRegistry = compressorRegistry;
this.userAgent = userAgent;
this.nameResolver.start(new NameResolver.Listener() { this.nameResolver.start(new NameResolver.Listener() {
@Override @Override
@ -344,7 +344,6 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
callOptions, callOptions,
transportProvider, transportProvider,
scheduledExecutor) scheduledExecutor)
.setUserAgent(userAgent)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry); .setCompressorRegistry(compressorRegistry);
} }
@ -394,8 +393,9 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
} }
ts = transports.get(addressGroup); ts = transports.get(addressGroup);
if (ts == null) { if (ts == null) {
ts = new TransportSet(addressGroup, authority(), loadBalancer, backoffPolicyProvider, ts = new TransportSet(addressGroup, authority(), userAgent, loadBalancer,
transportFactory, scheduledExecutor, executor, new TransportSet.Callback() { backoffPolicyProvider, transportFactory, scheduledExecutor, executor,
new TransportSet.Callback() {
@Override @Override
public void onTerminated() { public void onTerminated() {
synchronized (lock) { synchronized (lock) {

View File

@ -67,6 +67,7 @@ final class TransportSet implements WithLogId {
private final Object lock = new Object(); private final Object lock = new Object();
private final EquivalentAddressGroup addressGroup; private final EquivalentAddressGroup addressGroup;
private final String authority; private final String authority;
private final String userAgent;
private final BackoffPolicy.Provider backoffPolicyProvider; private final BackoffPolicy.Provider backoffPolicyProvider;
private final Callback callback; private final Callback callback;
private final ClientTransportFactory transportFactory; private final ClientTransportFactory transportFactory;
@ -122,21 +123,22 @@ final class TransportSet implements WithLogId {
@Nullable @Nullable
private volatile ManagedClientTransport activeTransport; private volatile ManagedClientTransport activeTransport;
TransportSet(EquivalentAddressGroup addressGroup, String authority, TransportSet(EquivalentAddressGroup addressGroup, String authority, String userAgent,
LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider, LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Executor appExecutor, Callback callback) { Executor appExecutor, Callback callback) {
this(addressGroup, authority, loadBalancer, backoffPolicyProvider, transportFactory, this(addressGroup, authority, userAgent, loadBalancer, backoffPolicyProvider, transportFactory,
scheduledExecutor, appExecutor, callback, Stopwatch.createUnstarted()); scheduledExecutor, appExecutor, callback, Stopwatch.createUnstarted());
} }
@VisibleForTesting @VisibleForTesting
TransportSet(EquivalentAddressGroup addressGroup, String authority, TransportSet(EquivalentAddressGroup addressGroup, String authority, String userAgent,
LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider, LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Executor appExecutor, Callback callback, Stopwatch connectingTimer) { Executor appExecutor, Callback callback, Stopwatch connectingTimer) {
this.addressGroup = Preconditions.checkNotNull(addressGroup, "addressGroup"); this.addressGroup = Preconditions.checkNotNull(addressGroup, "addressGroup");
this.authority = authority; this.authority = authority;
this.userAgent = userAgent;
this.loadBalancer = loadBalancer; this.loadBalancer = loadBalancer;
this.backoffPolicyProvider = backoffPolicyProvider; this.backoffPolicyProvider = backoffPolicyProvider;
this.transportFactory = transportFactory; this.transportFactory = transportFactory;
@ -186,7 +188,8 @@ final class TransportSet implements WithLogId {
nextAddressIndex = 0; nextAddressIndex = 0;
} }
ManagedClientTransport transport = transportFactory.newClientTransport(address, authority); ManagedClientTransport transport =
transportFactory.newClientTransport(address, authority, userAgent);
if (log.isLoggable(Level.FINE)) { if (log.isLoggable(Level.FINE)) {
log.log(Level.FINE, "[{0}] Created {1} for {2}", log.log(Level.FINE, "[{0}] Created {1} for {2}",
new Object[] {getLogId(), transport.getLogId(), address}); new Object[] {getLogId(), transport.getLogId(), address});

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -195,19 +196,18 @@ public class ClientCallImplTest {
} }
@Test @Test
public void prepareHeaders_userAgentAdded() { public void prepareHeaders_userAgentRemove() {
Metadata m = new Metadata(); Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, m.put(GrpcUtil.USER_AGENT_KEY, "batmobile");
Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE);
assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent"); assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNull();
} }
@Test @Test
public void prepareHeaders_ignoreIdentityEncoding() { public void prepareHeaders_ignoreIdentityEncoding() {
Metadata m = new Metadata(); Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE);
Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
} }
@ -250,8 +250,7 @@ public class ClientCallImplTest {
} }
}, false); // not advertised }, false); // not advertised
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry, ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE);
Codec.Identity.NONE);
Iterable<String> acceptedEncodings = Iterable<String> acceptedEncodings =
ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
@ -267,8 +266,7 @@ public class ClientCallImplTest {
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, ClientCallImpl.prepareHeaders(m, DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.USER_AGENT_KEY)); assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));

View File

@ -114,6 +114,7 @@ public class ManagedChannelImplTest {
private final ExecutorService executor = Executors.newSingleThreadExecutor(); private final ExecutorService executor = Executors.newSingleThreadExecutor();
private final String serviceName = "fake.example.com"; private final String serviceName = "fake.example.com";
private final String authority = serviceName; private final String authority = serviceName;
private final String userAgent = "userAgent";
private final String target = "fake://" + serviceName; private final String target = "fake://" + serviceName;
private URI expectedUri; private URI expectedUri;
private final SocketAddress socketAddress = new SocketAddress() {}; private final SocketAddress socketAddress = new SocketAddress() {};
@ -146,14 +147,15 @@ public class ManagedChannelImplTest {
return new ManagedChannelImpl(target, new FakeBackoffPolicyProvider(), return new ManagedChannelImpl(target, new FakeBackoffPolicyProvider(),
nameResolverFactory, NAME_RESOLVER_PARAMS, loadBalancerFactory, nameResolverFactory, NAME_RESOLVER_PARAMS, loadBalancerFactory,
mockTransportFactory, DecompressorRegistry.getDefaultInstance(), mockTransportFactory, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance(), executor, null, interceptors); CompressorRegistry.getDefaultInstance(), executor, userAgent, interceptors);
} }
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
expectedUri = new URI(target); expectedUri = new URI(target);
when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
} }
@ -195,12 +197,13 @@ public class ManagedChannelImplTest {
// Create transport and call // Create transport and call
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class))) when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransportFactory, timeout(1000)) verify(mockTransportFactory, timeout(1000))
.newClientTransport(same(socketAddress), eq(authority)); .newClientTransport(same(socketAddress), eq(authority), eq(userAgent));
verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture());
ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue(); ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue();
transportListener.transportReady(); transportListener.transportReady();
@ -442,7 +445,7 @@ public class ManagedChannelImplTest {
nameResolverFactory.allResolved(); nameResolverFactory.allResolved();
verify(mockTransportFactory, never()) verify(mockTransportFactory, never())
.newClientTransport(any(SocketAddress.class), any(String.class)); .newClientTransport(any(SocketAddress.class), any(String.class), any(String.class));
} }
/** /**
@ -467,9 +470,11 @@ public class ManagedChannelImplTest {
final ManagedClientTransport badTransport = mock(ManagedClientTransport.class); final ManagedClientTransport badTransport = mock(ManagedClientTransport.class);
when(goodTransport.newStream(any(MethodDescriptor.class), any(Metadata.class))) when(goodTransport.newStream(any(MethodDescriptor.class), any(Metadata.class)))
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(mockTransportFactory.newClientTransport(same(goodAddress), any(String.class))) when(mockTransportFactory.newClientTransport(
same(goodAddress), any(String.class), any(String.class)))
.thenReturn(goodTransport); .thenReturn(goodTransport);
when(mockTransportFactory.newClientTransport(same(badAddress), any(String.class))) when(mockTransportFactory.newClientTransport(
same(badAddress), any(String.class), any(String.class)))
.thenReturn(badTransport); .thenReturn(badTransport);
FakeNameResolverFactory nameResolverFactory = FakeNameResolverFactory nameResolverFactory =
@ -483,16 +488,17 @@ public class ManagedChannelImplTest {
ArgumentCaptor<ManagedClientTransport.Listener> badTransportListenerCaptor = ArgumentCaptor<ManagedClientTransport.Listener> badTransportListenerCaptor =
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
verify(badTransport, timeout(1000)).start(badTransportListenerCaptor.capture()); verify(badTransport, timeout(1000)).start(badTransportListenerCaptor.capture());
verify(mockTransportFactory).newClientTransport(same(badAddress), any(String.class)); verify(mockTransportFactory)
.newClientTransport(same(badAddress), any(String.class), any(String.class));
verify(mockTransportFactory, times(0)) verify(mockTransportFactory, times(0))
.newClientTransport(same(goodAddress), any(String.class)); .newClientTransport(same(goodAddress), any(String.class), any(String.class));
badTransportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); badTransportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE);
// The channel then try the second address (goodAddress) // The channel then try the second address (goodAddress)
ArgumentCaptor<ManagedClientTransport.Listener> goodTransportListenerCaptor = ArgumentCaptor<ManagedClientTransport.Listener> goodTransportListenerCaptor =
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
verify(mockTransportFactory, timeout(1000)) verify(mockTransportFactory, timeout(1000))
.newClientTransport(same(goodAddress), any(String.class)); .newClientTransport(same(goodAddress), any(String.class), any(String.class));
verify(goodTransport, timeout(1000)).start(goodTransportListenerCaptor.capture()); verify(goodTransport, timeout(1000)).start(goodTransportListenerCaptor.capture());
goodTransportListenerCaptor.getValue().transportReady(); goodTransportListenerCaptor.getValue().transportReady();
verify(goodTransport, timeout(1000)).newStream(same(method), same(headers)); verify(goodTransport, timeout(1000)).newStream(same(method), same(headers));
@ -519,9 +525,9 @@ public class ManagedChannelImplTest {
final ResolvedServerInfo server2 = new ResolvedServerInfo(addr2, Attributes.EMPTY); final ResolvedServerInfo server2 = new ResolvedServerInfo(addr2, Attributes.EMPTY);
final ManagedClientTransport transport1 = mock(ManagedClientTransport.class); final ManagedClientTransport transport1 = mock(ManagedClientTransport.class);
final ManagedClientTransport transport2 = mock(ManagedClientTransport.class); final ManagedClientTransport transport2 = mock(ManagedClientTransport.class);
when(mockTransportFactory.newClientTransport(same(addr1), any(String.class))) when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class)))
.thenReturn(transport1); .thenReturn(transport1);
when(mockTransportFactory.newClientTransport(same(addr2), any(String.class))) when(mockTransportFactory.newClientTransport(same(addr2), any(String.class), any(String.class)))
.thenReturn(transport2); .thenReturn(transport2);
FakeNameResolverFactory nameResolverFactory = FakeNameResolverFactory nameResolverFactory =
@ -533,14 +539,16 @@ public class ManagedChannelImplTest {
// Start a call. The channel will starts with the first address, which will fail to connect. // Start a call. The channel will starts with the first address, which will fail to connect.
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(transport1, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport1, timeout(1000)).start(transportListenerCaptor.capture());
verify(mockTransportFactory).newClientTransport(same(addr1), any(String.class)); verify(mockTransportFactory)
.newClientTransport(same(addr1), any(String.class), any(String.class));
verify(mockTransportFactory, times(0)) verify(mockTransportFactory, times(0))
.newClientTransport(same(addr2), any(String.class)); .newClientTransport(same(addr2), any(String.class), any(String.class));
transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE);
// The channel then try the second address, which will fail to connect too. // The channel then try the second address, which will fail to connect too.
verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport2, timeout(1000)).start(transportListenerCaptor.capture());
verify(mockTransportFactory).newClientTransport(same(addr2), any(String.class)); verify(mockTransportFactory)
.newClientTransport(same(addr2), any(String.class), any(String.class));
verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport2, timeout(1000)).start(transportListenerCaptor.capture());
transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE);
@ -577,7 +585,7 @@ public class ManagedChannelImplTest {
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(transport2.newStream(any(MethodDescriptor.class), any(Metadata.class))) when(transport2.newStream(any(MethodDescriptor.class), any(Metadata.class)))
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(mockTransportFactory.newClientTransport(same(addr1), any(String.class))) when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class)))
.thenReturn(transport1, transport2); .thenReturn(transport1, transport2);
FakeNameResolverFactory nameResolverFactory = FakeNameResolverFactory nameResolverFactory =
@ -588,7 +596,8 @@ public class ManagedChannelImplTest {
// First call will use the first address // First call will use the first address
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransportFactory, timeout(1000)).newClientTransport(same(addr1), any(String.class)); verify(mockTransportFactory, timeout(1000))
.newClientTransport(same(addr1), any(String.class), any(String.class));
verify(transport1, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport1, timeout(1000)).start(transportListenerCaptor.capture());
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
verify(transport1, timeout(1000)).newStream(same(method), same(headers)); verify(transport1, timeout(1000)).newStream(same(method), same(headers));
@ -598,7 +607,8 @@ public class ManagedChannelImplTest {
ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT);
call2.start(mockCallListener, headers); call2.start(mockCallListener, headers);
verify(transport2, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport2, timeout(1000)).start(transportListenerCaptor.capture());
verify(mockTransportFactory, times(2)).newClientTransport(same(addr1), any(String.class)); verify(mockTransportFactory, times(2))
.newClientTransport(same(addr1), any(String.class), any(String.class));
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
verify(transport2, timeout(1000)).newStream(same(method), same(headers)); verify(transport2, timeout(1000)).newStream(same(method), same(headers));
} }

View File

@ -87,6 +87,7 @@ import java.util.concurrent.TimeUnit;
public class ManagedChannelImplTransportManagerTest { public class ManagedChannelImplTransportManagerTest {
private static final String authority = "fakeauthority"; private static final String authority = "fakeauthority";
private static final String userAgent = "mosaic";
private final ExecutorService executor = Executors.newSingleThreadExecutor(); private final ExecutorService executor = Executors.newSingleThreadExecutor();
private final MethodDescriptor<String, String> method = MethodDescriptor.create( private final MethodDescriptor<String, String> method = MethodDescriptor.create(
@ -127,7 +128,7 @@ public class ManagedChannelImplTransportManagerTest {
channel = new ManagedChannelImpl("fake://target", mockBackoffPolicyProvider, channel = new ManagedChannelImpl("fake://target", mockBackoffPolicyProvider,
mockNameResolverFactory, Attributes.EMPTY, mockLoadBalancerFactory, mockNameResolverFactory, Attributes.EMPTY, mockLoadBalancerFactory,
mockTransportFactory, DecompressorRegistry.getDefaultInstance(), mockTransportFactory, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance(), executor, null, CompressorRegistry.getDefaultInstance(), executor, userAgent,
Collections.<ClientInterceptor>emptyList()); Collections.<ClientInterceptor>emptyList());
ArgumentCaptor<TransportManager<ClientTransport>> tmCaptor ArgumentCaptor<TransportManager<ClientTransport>> tmCaptor
@ -150,7 +151,7 @@ public class ManagedChannelImplTransportManagerTest {
SocketAddress addr = mock(SocketAddress.class); SocketAddress addr = mock(SocketAddress.class);
EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(addr); EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(addr);
ClientTransport t1 = tm.getTransport(addressGroup); ClientTransport t1 = tm.getTransport(addressGroup);
verify(mockTransportFactory, timeout(1000)).newClientTransport(addr, authority); verify(mockTransportFactory, timeout(1000)).newClientTransport(addr, authority, userAgent);
// The real transport // The real transport
MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS);
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
@ -175,7 +176,7 @@ public class ManagedChannelImplTransportManagerTest {
// Pick the first transport // Pick the first transport
ClientTransport t1 = tm.getTransport(addressGroup); ClientTransport t1 = tm.getTransport(addressGroup);
assertNotNull(t1); assertNotNull(t1);
verify(mockTransportFactory, timeout(1000)).newClientTransport(addr1, authority); verify(mockTransportFactory, timeout(1000)).newClientTransport(addr1, authority, userAgent);
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
// Fail the first transport, without setting it to ready // Fail the first transport, without setting it to ready
MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS);
@ -187,7 +188,7 @@ public class ManagedChannelImplTransportManagerTest {
assertNotNull(t2); assertNotNull(t2);
t2.newStream(method, new Metadata()); t2.newStream(method, new Metadata());
// Will keep the previous back-off policy, and not consult back-off policy // Will keep the previous back-off policy, and not consult back-off policy
verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority); verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority, userAgent);
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
transportInfo = transports.poll(1, TimeUnit.SECONDS); transportInfo = transports.poll(1, TimeUnit.SECONDS);
ClientTransport rt2 = transportInfo.transport; ClientTransport rt2 = transportInfo.transport;
@ -203,7 +204,8 @@ public class ManagedChannelImplTransportManagerTest {
// Subsequent getTransport() will use the first address, since last attempt was successful. // Subsequent getTransport() will use the first address, since last attempt was successful.
ClientTransport t3 = tm.getTransport(addressGroup); ClientTransport t3 = tm.getTransport(addressGroup);
t3.newStream(method2, new Metadata()); t3.newStream(method2, new Metadata());
verify(mockTransportFactory, timeout(1000).times(2)).newClientTransport(addr1, authority); verify(mockTransportFactory, timeout(1000).times(2))
.newClientTransport(addr1, authority, userAgent);
// Still no back-off policy creation, because an address succeeded. // Still no back-off policy creation, because an address succeeded.
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
transportInfo = transports.poll(1, TimeUnit.SECONDS); transportInfo = transports.poll(1, TimeUnit.SECONDS);
@ -236,7 +238,7 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport t1 = tm.getTransport(addressGroup); ClientTransport t1 = tm.getTransport(addressGroup);
assertNotNull(t1); assertNotNull(t1);
verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) verify(mockTransportFactory, timeout(1000).times(++transportsAddr1))
.newClientTransport(addr1, authority); .newClientTransport(addr1, authority, userAgent);
// Back-off policy was unset initially. // Back-off policy was unset initially.
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS); MockClientTransportInfo transportInfo = transports.poll(1, TimeUnit.SECONDS);
@ -250,7 +252,7 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport t2 = tm.getTransport(addressGroup); ClientTransport t2 = tm.getTransport(addressGroup);
assertNotNull(t2); assertNotNull(t2);
verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) verify(mockTransportFactory, timeout(1000).times(++transportsAddr1))
.newClientTransport(addr1, authority); .newClientTransport(addr1, authority, userAgent);
// Back-off policy was not reset. // Back-off policy was not reset.
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE); transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE);
@ -260,7 +262,7 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport t3 = tm.getTransport(addressGroup); ClientTransport t3 = tm.getTransport(addressGroup);
assertNotNull(t3); assertNotNull(t3);
verify(mockTransportFactory, timeout(1000).times(++transportsAddr2)) verify(mockTransportFactory, timeout(1000).times(++transportsAddr2))
.newClientTransport(addr2, authority); .newClientTransport(addr2, authority, userAgent);
// Back-off policy was not reset. // Back-off policy was not reset.
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE); transports.poll(1, TimeUnit.SECONDS).listener.transportShutdown(Status.UNAVAILABLE);
@ -272,7 +274,8 @@ public class ManagedChannelImplTransportManagerTest {
// If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy. // If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy.
t4.newStream(method, new Metadata()); t4.newStream(method, new Metadata());
verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) verify(mockTransportFactory, timeout(1000).times(++transportsAddr1))
.newClientTransport(addr1, authority);
.newClientTransport(addr1, authority, userAgent);
// Back-off policy was reset and consulted. // Back-off policy was reset and consulted.
verify(mockBackoffPolicyProvider, times(++backoffReset)).get(); verify(mockBackoffPolicyProvider, times(++backoffReset)).get();
verify(mockBackoffPolicy, times(++backoffConsulted)).nextBackoffMillis(); verify(mockBackoffPolicy, times(++backoffConsulted)).nextBackoffMillis();

View File

@ -97,7 +97,8 @@ final class TestUtils {
}).when(mockTransport).start(any(ManagedClientTransport.Listener.class)); }).when(mockTransport).start(any(ManagedClientTransport.Listener.class));
return mockTransport; return mockTransport;
} }
}).when(mockTransportFactory).newClientTransport(any(SocketAddress.class), any(String.class)); }).when(mockTransportFactory)
.newClientTransport(any(SocketAddress.class), any(String.class), any(String.class));
return captor; return captor;
} }

View File

@ -38,9 +38,9 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.same;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -78,6 +78,7 @@ import java.util.concurrent.BlockingQueue;
public class TransportSetTest { public class TransportSetTest {
private static final String authority = "fakeauthority"; private static final String authority = "fakeauthority";
private static final String userAgent = "mosaic";
private FakeClock fakeClock; private FakeClock fakeClock;
private FakeClock fakeExecutor; private FakeClock fakeExecutor;
@ -131,7 +132,9 @@ public class TransportSetTest {
// First attempt // First attempt
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Fail this one // Fail this one
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed(); verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed();
@ -143,9 +146,11 @@ public class TransportSetTest {
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
// Transport creation doesn't happen until time is due // Transport creation doesn't happen until time is due
fakeClock.forwardMillis(9); fakeClock.forwardMillis(9);
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
fakeClock.forwardMillis(1); fakeClock.forwardMillis(1);
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Fail this one too // Fail this one too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed(); verify(mockTransportSetCallback, times(++onAllAddressesFailed)).onAllAddressesFailed();
@ -157,9 +162,11 @@ public class TransportSetTest {
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
// Transport creation doesn't happen until time is due // Transport creation doesn't happen until time is due
fakeClock.forwardMillis(99); fakeClock.forwardMillis(99);
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
fakeClock.forwardMillis(1); fakeClock.forwardMillis(1);
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Let this one succeed // Let this one succeed
transports.peek().listener.transportReady(); transports.peek().listener.transportReady();
fakeClock.runDueTasks(); fakeClock.runDueTasks();
@ -172,7 +179,8 @@ public class TransportSetTest {
// Back-off is reset, and the next attempt will happen immediately // Back-off is reset, and the next attempt will happen immediately
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Final checks for consultations on back-off policies // Final checks for consultations on back-off policies
verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis(); verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis();
@ -199,7 +207,8 @@ public class TransportSetTest {
DelayedClientTransport delayedTransport1 = DelayedClientTransport delayedTransport1 =
(DelayedClientTransport) transportSet.obtainActiveTransport(); (DelayedClientTransport) transportSet.obtainActiveTransport();
delayedTransport1.newStream(method, new Metadata()); delayedTransport1.newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
// Let this one fail without success // Let this one fail without success
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
assertNull(delayedTransport1.getTransportSupplier()); assertNull(delayedTransport1.getTransportSupplier());
@ -211,7 +220,8 @@ public class TransportSetTest {
assertSame(delayedTransport1, delayedTransport2); assertSame(delayedTransport1, delayedTransport2);
delayedTransport2.newStream(method, new Metadata()); delayedTransport2.newStream(method, new Metadata());
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); verify(mockTransportFactory, times(++transportsAddr2))
.newClientTransport(addr2, authority, userAgent);
// Fail this one too // Fail this one too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// All addresses have failed. Delayed transport will see an error. // All addresses have failed. Delayed transport will see an error.
@ -227,9 +237,11 @@ public class TransportSetTest {
assertNotSame(delayedTransport2, delayedTransport3); assertNotSame(delayedTransport2, delayedTransport3);
delayedTransport3.newStream(method, new Metadata()); delayedTransport3.newStream(method, new Metadata());
fakeClock.forwardMillis(9); fakeClock.forwardMillis(9);
verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
fakeClock.forwardMillis(1); fakeClock.forwardMillis(1);
verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
// Fail this one too // Fail this one too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
assertNull(delayedTransport3.getTransportSupplier()); assertNull(delayedTransport3.getTransportSupplier());
@ -241,7 +253,8 @@ public class TransportSetTest {
assertSame(delayedTransport3, delayedTransport4); assertSame(delayedTransport3, delayedTransport4);
delayedTransport4.newStream(method, new Metadata()); delayedTransport4.newStream(method, new Metadata());
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); verify(mockTransportFactory, times(++transportsAddr2))
.newClientTransport(addr2, authority, userAgent);
// Fail this one too // Fail this one too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// All addresses have failed again. Delayed transport will see an error // All addresses have failed again. Delayed transport will see an error
@ -257,9 +270,11 @@ public class TransportSetTest {
assertNotSame(delayedTransport4, delayedTransport5); assertNotSame(delayedTransport4, delayedTransport5);
delayedTransport5.newStream(method, new Metadata()); delayedTransport5.newStream(method, new Metadata());
fakeClock.forwardMillis(99); fakeClock.forwardMillis(99);
verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
fakeClock.forwardMillis(1); fakeClock.forwardMillis(1);
verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
// Let it through // Let it through
transports.peek().listener.transportReady(); transports.peek().listener.transportReady();
// Delayed transport will see the connected transport. // Delayed transport will see the connected transport.
@ -277,7 +292,8 @@ public class TransportSetTest {
assertNotSame(delayedTransport5, delayedTransport6); assertNotSame(delayedTransport5, delayedTransport6);
delayedTransport6.newStream(method, new Metadata()); delayedTransport6.newStream(method, new Metadata());
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
// Fail the transport // Fail the transport
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
assertNull(delayedTransport6.getTransportSupplier()); assertNull(delayedTransport6.getTransportSupplier());
@ -289,7 +305,8 @@ public class TransportSetTest {
assertSame(delayedTransport6, delayedTransport7); assertSame(delayedTransport6, delayedTransport7);
delayedTransport7.newStream(method, new Metadata()); delayedTransport7.newStream(method, new Metadata());
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsAddr2)).newClientTransport(addr2, authority); verify(mockTransportFactory, times(++transportsAddr2))
.newClientTransport(addr2, authority, userAgent);
// Fail this one too // Fail this one too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// All addresses have failed. Delayed transport will see an error. // All addresses have failed. Delayed transport will see an error.
@ -305,9 +322,11 @@ public class TransportSetTest {
assertNotSame(delayedTransport7, delayedTransport8); assertNotSame(delayedTransport7, delayedTransport8);
delayedTransport8.newStream(method, new Metadata()); delayedTransport8.newStream(method, new Metadata());
fakeClock.forwardMillis(9); fakeClock.forwardMillis(9);
verify(mockTransportFactory, times(transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
fakeClock.forwardMillis(1); fakeClock.forwardMillis(1);
verify(mockTransportFactory, times(++transportsAddr1)).newClientTransport(addr1, authority); verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, authority, userAgent);
// Final checks on invocations on back-off policies // Final checks on invocations on back-off policies
verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis(); verify(mockBackoffPolicy1, times(backoff1Consulted)).nextBackoffMillis();
@ -326,31 +345,37 @@ public class TransportSetTest {
int transportsCreated = 0; int transportsCreated = 0;
// Won't connect until requested // Won't connect until requested
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
// First attempt // First attempt
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Fail this one // Fail this one
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// Won't reconnect until requested, even if back-off time has expired // Won't reconnect until requested, even if back-off time has expired
fakeClock.forwardMillis(10); fakeClock.forwardMillis(10);
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Once requested, will reconnect // Once requested, will reconnect
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Fail this one, too // Fail this one, too
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// Request immediately, but will wait for back-off before reconnecting // Request immediately, but will wait for back-off before reconnecting
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
fakeClock.forwardMillis(100); fakeClock.forwardMillis(100);
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test. fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test.
} }
@ -364,7 +389,8 @@ public class TransportSetTest {
// Trigger TRANSIENT_FAILURE // Trigger TRANSIENT_FAILURE
transportSet.obtainActiveTransport().newStream(method, new Metadata()); transportSet.obtainActiveTransport().newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
// Won't reconnect without any active streams // Won't reconnect without any active streams
@ -372,11 +398,13 @@ public class TransportSetTest {
assertTrue(transientFailureTransport instanceof DelayedClientTransport); assertTrue(transientFailureTransport instanceof DelayedClientTransport);
transientFailureTransport.newStream(method, new Metadata()).cancel(Status.CANCELLED); transientFailureTransport.newStream(method, new Metadata()).cancel(Status.CANCELLED);
fakeClock.forwardMillis(10); fakeClock.forwardMillis(10);
verify(mockTransportFactory, times(transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(transportsCreated))
.newClientTransport(addr, authority, userAgent);
// Lose race (long delay between obtainActiveTransport and newStream); will now reconnect // Lose race (long delay between obtainActiveTransport and newStream); will now reconnect
transientFailureTransport.newStream(method, new Metadata()); transientFailureTransport.newStream(method, new Metadata());
verify(mockTransportFactory, times(++transportsCreated)).newClientTransport(addr, authority); verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, authority, userAgent);
fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test. fakeExecutor.runDueTasks(); // Drain new 'real' stream creation; not important to this test.
} }
@ -388,7 +416,7 @@ public class TransportSetTest {
// First transport is created immediately // First transport is created immediately
ClientTransport pick = transportSet.obtainActiveTransport(); ClientTransport pick = transportSet.obtainActiveTransport();
verify(mockTransportFactory).newClientTransport(addr, authority); verify(mockTransportFactory).newClientTransport(addr, authority, userAgent);
assertNotNull(pick); assertNotNull(pick);
// Fail this one // Fail this one
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
@ -408,11 +436,11 @@ public class TransportSetTest {
pick = transportSet.obtainActiveTransport(); pick = transportSet.obtainActiveTransport();
assertNotNull(pick); assertNotNull(pick);
assertTrue(pick instanceof FailingClientTransport); assertTrue(pick instanceof FailingClientTransport);
verify(mockTransportFactory).newClientTransport(addr, authority); verify(mockTransportFactory).newClientTransport(addr, authority, userAgent);
// Reconnect will eventually happen, even though TransportSet has been shut down // Reconnect will eventually happen, even though TransportSet has been shut down
fakeClock.forwardMillis(10); fakeClock.forwardMillis(10);
verify(mockTransportFactory, times(2)).newClientTransport(addr, authority); verify(mockTransportFactory, times(2)).newClientTransport(addr, authority, userAgent);
// The pending stream will be started on this newly started transport after it's ready. // The pending stream will be started on this newly started transport after it's ready.
// The transport is shut down by TransportSet right after the stream is created. // The transport is shut down by TransportSet right after the stream is created.
transportInfo = transports.poll(); transportInfo = transports.poll();
@ -443,7 +471,7 @@ public class TransportSetTest {
// First transport is created immediately // First transport is created immediately
ClientTransport pick = transportSet.obtainActiveTransport(); ClientTransport pick = transportSet.obtainActiveTransport();
verify(mockTransportFactory).newClientTransport(addr, authority); verify(mockTransportFactory).newClientTransport(addr, authority, userAgent);
assertNotNull(pick); assertNotNull(pick);
// Fail this one // Fail this one
MockClientTransportInfo transportInfo = transports.poll(); MockClientTransportInfo transportInfo = transports.poll();
@ -478,7 +506,7 @@ public class TransportSetTest {
transportSet.shutdown(); transportSet.shutdown();
ClientTransport pick = transportSet.obtainActiveTransport(); ClientTransport pick = transportSet.obtainActiveTransport();
assertNotNull(pick); assertNotNull(pick);
verify(mockTransportFactory, times(0)).newClientTransport(addr, authority); verify(mockTransportFactory, times(0)).newClientTransport(addr, authority, userAgent);
} }
@Test @Test
@ -490,7 +518,7 @@ public class TransportSetTest {
private void createTransportSet(SocketAddress ... addrs) { private void createTransportSet(SocketAddress ... addrs) {
addressGroup = new EquivalentAddressGroup(Arrays.asList(addrs)); addressGroup = new EquivalentAddressGroup(Arrays.asList(addrs));
transportSet = new TransportSet(addressGroup, authority, mockLoadBalancer, transportSet = new TransportSet(addressGroup, authority, userAgent, mockLoadBalancer,
mockBackoffPolicyProvider, mockTransportFactory, fakeClock.scheduledExecutorService, mockBackoffPolicyProvider, mockTransportFactory, fakeClock.scheduledExecutorService,
fakeExecutor.scheduledExecutorService, mockTransportSetCallback, fakeExecutor.scheduledExecutorService, mockTransportSetCallback,
Stopwatch.createUnstarted(fakeClock.ticker)); Stopwatch.createUnstarted(fakeClock.ticker));

View File

@ -311,23 +311,23 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder<Netty
@Override @Override
public ManagedClientTransport newClientTransport( public ManagedClientTransport newClientTransport(
SocketAddress serverAddress, String authority) { SocketAddress serverAddress, String authority, @Nullable String userAgent) {
if (closed) { if (closed) {
throw new IllegalStateException("The transport factory is closed."); throw new IllegalStateException("The transport factory is closed.");
} }
ProtocolNegotiator negotiator = protocolNegotiator != null ? protocolNegotiator : ProtocolNegotiator negotiator = protocolNegotiator != null ? protocolNegotiator :
createProtocolNegotiator(authority, negotiationType, sslContext); createProtocolNegotiator(authority, negotiationType, sslContext);
return newClientTransport(serverAddress, authority, negotiator); return newClientTransport(serverAddress, authority, userAgent, negotiator);
} }
@Internal // This is strictly for internal use. Depend on this at your own peril. @Internal // This is strictly for internal use. Depend on this at your own peril.
public ManagedClientTransport newClientTransport(SocketAddress serverAddress, public ManagedClientTransport newClientTransport(SocketAddress serverAddress,
String authority, ProtocolNegotiator negotiator) { String authority, String userAgent, ProtocolNegotiator negotiator) {
if (closed) { if (closed) {
throw new IllegalStateException("The transport factory is closed."); throw new IllegalStateException("The transport factory is closed.");
} }
return new NettyClientTransport(serverAddress, channelType, group, negotiator, return new NettyClientTransport(serverAddress, channelType, group, negotiator,
flowControlWindow, maxMessageSize, maxHeaderListSize, authority); flowControlWindow, maxMessageSize, maxHeaderListSize, authority, userAgent);
} }
@Override @Override

View File

@ -56,18 +56,18 @@ import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** /**
* A Netty-based {@link ManagedClientTransport} implementation. * A Netty-based {@link ManagedClientTransport} implementation.
*/ */
class NettyClientTransport implements ManagedClientTransport { class NettyClientTransport implements ManagedClientTransport {
private static final AsciiString DEFAULT_AGENT =
new AsciiString(GrpcUtil.getGrpcUserAgent("netty", null));
private final SocketAddress address; private final SocketAddress address;
private final Class<? extends Channel> channelType; private final Class<? extends Channel> channelType;
private final EventLoopGroup group; private final EventLoopGroup group;
private final ProtocolNegotiator negotiator; private final ProtocolNegotiator negotiator;
private final AsciiString authority; private final AsciiString authority;
private final AsciiString userAgent;
private final int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize; private final int maxMessageSize;
private final int maxHeaderListSize; private final int maxHeaderListSize;
@ -83,7 +83,7 @@ class NettyClientTransport implements ManagedClientTransport {
NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType, NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType,
EventLoopGroup group, ProtocolNegotiator negotiator, EventLoopGroup group, ProtocolNegotiator negotiator,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize, int flowControlWindow, int maxMessageSize, int maxHeaderListSize,
String authority) { String authority, @Nullable String userAgent) {
this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator");
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.group = Preconditions.checkNotNull(group, "group"); this.group = Preconditions.checkNotNull(group, "group");
@ -92,6 +92,7 @@ class NettyClientTransport implements ManagedClientTransport {
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize; this.maxHeaderListSize = maxHeaderListSize;
this.authority = new AsciiString(authority); this.authority = new AsciiString(authority);
this.userAgent = new AsciiString(GrpcUtil.getGrpcUserAgent("netty", userAgent));
} }
@Override @Override
@ -114,9 +115,6 @@ class NettyClientTransport implements ManagedClientTransport {
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
AsciiString userAgent = headers.containsKey(GrpcUtil.USER_AGENT_KEY)
? new AsciiString(GrpcUtil.getGrpcUserAgent("netty", headers.get(GrpcUtil.USER_AGENT_KEY)))
: DEFAULT_AGENT;
return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority, return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority,
negotiationHandler.scheme(), userAgent) { negotiationHandler.scheme(), userAgent) {
@Override @Override

View File

@ -132,7 +132,7 @@ public class NettyClientTransportTest {
} }
@Test @Test
public void headersShouldAddDefaultUserAgent() throws Exception { public void addDefaultUserAgent() throws Exception {
startServer(); startServer();
NettyClientTransport transport = newTransport(newNegotiator()); NettyClientTransport transport = newTransport(newNegotiator());
transport.start(clientTransportListener); transport.start(clientTransportListener);
@ -148,21 +148,18 @@ public class NettyClientTransportTest {
} }
@Test @Test
public void headersShouldOverrideDefaultUserAgent() throws Exception { public void overrideDefaultUserAgent() throws Exception {
startServer(); startServer();
NettyClientTransport transport = newTransport(newNegotiator()); NettyClientTransport transport = newTransport(newNegotiator(),
DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent");
transport.start(clientTransportListener); transport.start(clientTransportListener);
// Send a single RPC and wait for the response. new Rpc(transport, new Metadata()).halfClose().waitForResponse();
String userAgent = "testUserAgent";
Metadata sentHeaders = new Metadata();
sentHeaders.put(USER_AGENT_KEY, userAgent);
new Rpc(transport, sentHeaders).halfClose().waitForResponse();
// Verify that the received headers contained the User-Agent. // Verify that the received headers contained the User-Agent.
assertEquals(1, serverListener.streamListeners.size()); assertEquals(1, serverListener.streamListeners.size());
Metadata receivedHeaders = serverListener.streamListeners.get(0).headers; Metadata receivedHeaders = serverListener.streamListeners.get(0).headers;
assertEquals(GrpcUtil.getGrpcUserAgent("netty", userAgent), assertEquals(GrpcUtil.getGrpcUserAgent("netty", "testUserAgent"),
receivedHeaders.get(USER_AGENT_KEY)); receivedHeaders.get(USER_AGENT_KEY));
} }
@ -171,7 +168,7 @@ public class NettyClientTransportTest {
startServer(); startServer();
// Allow the response payloads of up to 1 byte. // Allow the response payloads of up to 1 byte.
NettyClientTransport transport = newTransport(newNegotiator(), NettyClientTransport transport = newTransport(newNegotiator(),
1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null);
transport.start(clientTransportListener); transport.start(clientTransportListener);
try { try {
@ -248,7 +245,8 @@ public class NettyClientTransportTest {
public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception {
startServer(); startServer();
NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1); NettyClientTransport transport =
newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1, null);
transport.start(clientTransportListener); transport.start(clientTransportListener);
try { try {
@ -298,13 +296,14 @@ public class NettyClientTransportTest {
private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
return newTransport(negotiator, return newTransport(negotiator,
DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */);
} }
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, private NettyClientTransport newTransport(ProtocolNegotiator negotiator,
int maxMsgSize, int maxHeaderListSize) { int maxMsgSize, int maxHeaderListSize, String userAgent) {
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, authority); group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, authority,
userAgent);
transports.add(transport); transports.add(transport);
return transport; return transport;
} }

View File

@ -75,12 +75,13 @@ public class NettyTransportTest extends AbstractTransportTest {
@Override @Override
protected ManagedClientTransport newClientTransport() { protected ManagedClientTransport newClientTransport() {
return clientFactory.newClientTransport( return clientFactory.newClientTransport(
new InetSocketAddress("localhost", SERVER_PORT), "localhost:" + SERVER_PORT); new InetSocketAddress("localhost", SERVER_PORT),
"localhost:" + SERVER_PORT,
null /* agent */);
} }
// TODO(ejona): Flaky
@Test @Test
@Ignore @Ignore("flaky")
@Override @Override
public void flowControlPushBack() {} public void flowControlPushBack() {}
} }

View File

@ -46,6 +46,8 @@ import okio.ByteString;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.annotation.Nullable;
/** /**
* Constants for request/response headers. * Constants for request/response headers.
*/ */
@ -63,7 +65,7 @@ public class Headers {
* application thread context. * application thread context.
*/ */
public static List<Header> createRequestHeaders(Metadata headers, String defaultPath, public static List<Header> createRequestHeaders(Metadata headers, String defaultPath,
String authority) { String authority, @Nullable String applicationUserAgent) {
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(defaultPath, "defaultPath"); Preconditions.checkNotNull(defaultPath, "defaultPath");
Preconditions.checkNotNull(authority, "authority"); Preconditions.checkNotNull(authority, "authority");
@ -79,7 +81,7 @@ public class Headers {
String path = defaultPath; String path = defaultPath;
okhttpHeaders.add(new Header(Header.TARGET_PATH, path)); okhttpHeaders.add(new Header(Header.TARGET_PATH, path));
String userAgent = GrpcUtil.getGrpcUserAgent("okhttp", headers.get(USER_AGENT_KEY)); String userAgent = GrpcUtil.getGrpcUserAgent("okhttp", applicationUserAgent);
okhttpHeaders.add(new Header(GrpcUtil.USER_AGENT_KEY.name(), userAgent)); okhttpHeaders.add(new Header(GrpcUtil.USER_AGENT_KEY.name(), userAgent));
// All non-pseudo headers must come after pseudo headers. // All non-pseudo headers must come after pseudo headers.

View File

@ -260,13 +260,14 @@ public class OkHttpChannelBuilder extends
} }
@Override @Override
public ManagedClientTransport newClientTransport(SocketAddress addr, String authority) { public ManagedClientTransport newClientTransport(
SocketAddress addr, String authority, @Nullable String userAgent) {
if (closed) { if (closed) {
throw new IllegalStateException("The transport factory is closed."); throw new IllegalStateException("The transport factory is closed.");
} }
InetSocketAddress inetSocketAddr = (InetSocketAddress) addr; InetSocketAddress inetSocketAddr = (InetSocketAddress) addr;
return new OkHttpClientTransport(inetSocketAddr, authority, executor, socketFactory, return new OkHttpClientTransport(inetSocketAddr, authority, userAgent, executor,
Utils.convertSpec(connectionSpec), maxMessageSize); socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize);
} }
@Override @Override

View File

@ -34,7 +34,6 @@ package io.grpc.okhttp;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import io.grpc.Metadata;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -73,6 +72,7 @@ class OkHttpClientStream extends Http2ClientStream {
private final OutboundFlowController outboundFlow; private final OutboundFlowController outboundFlow;
private final OkHttpClientTransport transport; private final OkHttpClientTransport transport;
private final Object lock; private final Object lock;
private final String userAgent;
private String authority; private String authority;
private Object outboundFlowState; private Object outboundFlowState;
private volatile Integer id; private volatile Integer id;
@ -95,7 +95,8 @@ class OkHttpClientStream extends Http2ClientStream {
OutboundFlowController outboundFlow, OutboundFlowController outboundFlow,
Object lock, Object lock,
int maxMessageSize, int maxMessageSize,
String authority) { String authority,
@Nullable String userAgent) {
super(new OkHttpWritableBufferAllocator(), maxMessageSize); super(new OkHttpWritableBufferAllocator(), maxMessageSize);
this.method = method; this.method = method;
this.headers = headers; this.headers = headers;
@ -104,6 +105,7 @@ class OkHttpClientStream extends Http2ClientStream {
this.outboundFlow = outboundFlow; this.outboundFlow = outboundFlow;
this.lock = lock; this.lock = lock;
this.authority = authority; this.authority = authority;
this.userAgent = userAgent;
} }
/** /**
@ -136,7 +138,8 @@ class OkHttpClientStream extends Http2ClientStream {
public void start(ClientStreamListener listener) { public void start(ClientStreamListener listener) {
super.start(listener); super.start(listener);
String defaultPath = "/" + method.getFullMethodName(); String defaultPath = "/" + method.getFullMethodName();
List<Header> requestHeaders = Headers.createRequestHeaders(headers, defaultPath, authority); List<Header> requestHeaders =
Headers.createRequestHeaders(headers, defaultPath, authority, userAgent);
headers = null; headers = null;
synchronized (lock) { synchronized (lock) {
this.requestHeaders = requestHeaders; this.requestHeaders = requestHeaders;

View File

@ -125,6 +125,7 @@ class OkHttpClientTransport implements ManagedClientTransport {
private final InetSocketAddress address; private final InetSocketAddress address;
private final String defaultAuthority; private final String defaultAuthority;
private final String userAgent;
private final Random random = new Random(); private final Random random = new Random();
private final Ticker ticker; private final Ticker ticker;
private Listener listener; private Listener listener;
@ -168,8 +169,8 @@ class OkHttpClientTransport implements ManagedClientTransport {
Runnable connectingCallback; Runnable connectingCallback;
SettableFuture<Void> connectedFuture; SettableFuture<Void> connectedFuture;
OkHttpClientTransport(InetSocketAddress address, String authority, Executor executor, OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent,
@Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec, Executor executor, @Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec,
int maxMessageSize) { int maxMessageSize) {
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority; this.defaultAuthority = authority;
@ -182,19 +183,21 @@ class OkHttpClientTransport implements ManagedClientTransport {
this.sslSocketFactory = sslSocketFactory; this.sslSocketFactory = sslSocketFactory;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
this.ticker = Ticker.systemTicker(); this.ticker = Ticker.systemTicker();
this.userAgent = userAgent;
} }
/** /**
* Create a transport connected to a fake peer for test. * Create a transport connected to a fake peer for test.
*/ */
@VisibleForTesting @VisibleForTesting
OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter, OkHttpClientTransport(String userAgent, Executor executor, FrameReader frameReader,
int nextStreamId, Socket socket, Ticker ticker, FrameWriter testFrameWriter, int nextStreamId, Socket socket, Ticker ticker,
@Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture, @Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture,
int maxMessageSize) { int maxMessageSize) {
address = null; address = null;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
defaultAuthority = "notarealauthority:80"; defaultAuthority = "notarealauthority:80";
this.userAgent = userAgent;
this.executor = Preconditions.checkNotNull(executor); this.executor = Preconditions.checkNotNull(executor);
serializingExecutor = new SerializingExecutor(executor); serializingExecutor = new SerializingExecutor(executor);
this.testFrameReader = Preconditions.checkNotNull(frameReader); this.testFrameReader = Preconditions.checkNotNull(frameReader);
@ -247,7 +250,7 @@ class OkHttpClientTransport implements ManagedClientTransport {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this, return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this,
outboundFlow, lock, maxMessageSize, defaultAuthority); outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent);
} }
@GuardedBy("lock") @GuardedBy("lock")

View File

@ -75,7 +75,7 @@ public class OkHttpClientStreamTest {
methodDescriptor = MethodDescriptor.create( methodDescriptor = MethodDescriptor.create(
MethodType.UNARY, "/testService/test", marshaller, marshaller); MethodType.UNARY, "/testService/test", marshaller, marshaller);
stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport, stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost"); flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent");
} }
@Test @Test

View File

@ -159,20 +159,20 @@ public class OkHttpClientTransportTest {
} }
private void initTransport() throws Exception { private void initTransport() throws Exception {
startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE); startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, null);
} }
private void initTransport(int startId) throws Exception { private void initTransport(int startId) throws Exception {
startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE); startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE, null);
} }
private void initTransportAndDelayConnected() throws Exception { private void initTransportAndDelayConnected() throws Exception {
delayConnectedCallback = new DelayConnectedCallback(); delayConnectedCallback = new DelayConnectedCallback();
startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE); startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE, null);
} }
private void startTransport(int startId, @Nullable Runnable connectingCallback, private void startTransport(int startId, @Nullable Runnable connectingCallback,
boolean waitingForConnected, int maxMessageSize) throws Exception { boolean waitingForConnected, int maxMessageSize, String userAgent) throws Exception {
connectedFuture = SettableFuture.create(); connectedFuture = SettableFuture.create();
Ticker ticker = new Ticker() { Ticker ticker = new Ticker() {
@Override @Override
@ -180,10 +180,9 @@ public class OkHttpClientTransportTest {
return nanoTime; return nanoTime;
} }
}; };
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(userAgent, executor, frameReader,
executor, frameReader, frameWriter, startId, frameWriter, startId, new MockSocket(frameReader), ticker, connectingCallback,
new MockSocket(frameReader), ticker, connectingCallback, connectedFuture, connectedFuture, maxMessageSize);
maxMessageSize);
clientTransport.start(transportListener); clientTransport.start(transportListener);
if (waitingForConnected) { if (waitingForConnected) {
connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS);
@ -194,7 +193,7 @@ public class OkHttpClientTransportTest {
public void testToString() throws Exception { public void testToString() throws Exception {
InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415);
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(
address, "hostname", executor, null, address, "hostname", null /* agent */, executor, null,
Utils.convertSpec(OkHttpChannelBuilder.DEFAULT_CONNECTION_SPEC), DEFAULT_MAX_MESSAGE_SIZE); Utils.convertSpec(OkHttpChannelBuilder.DEFAULT_CONNECTION_SPEC), DEFAULT_MAX_MESSAGE_SIZE);
String s = clientTransport.toString(); String s = clientTransport.toString();
assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport"));
@ -204,7 +203,7 @@ public class OkHttpClientTransportTest {
@Test @Test
public void maxMessageSizeShouldBeEnforced() throws Exception { public void maxMessageSizeShouldBeEnforced() throws Exception {
// Allow the response payloads of up to 1 byte. // Allow the response payloads of up to 1 byte.
startTransport(3, null, true, 1); startTransport(3, null, true, 1, null);
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata());
@ -405,7 +404,7 @@ public class OkHttpClientTransportTest {
} }
@Test @Test
public void headersShouldAddDefaultUserAgent() throws Exception { public void addDefaultUserAgent() throws Exception {
initTransport(); initTransport();
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata());
@ -423,19 +422,16 @@ public class OkHttpClientTransportTest {
} }
@Test @Test
public void headersShouldOverrideDefaultUserAgent() throws Exception { public void overrideDefaultUserAgent() throws Exception {
initTransport(); startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, "fakeUserAgent");
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
String userAgent = "fakeUserAgent"; OkHttpClientStream stream = clientTransport.newStream(method, new Metadata());
Metadata metadata = new Metadata();
metadata.put(GrpcUtil.USER_AGENT_KEY, userAgent);
OkHttpClientStream stream = clientTransport.newStream(method, metadata);
stream.start(listener); stream.start(listener);
List<Header> expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER, List<Header> expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"),
new Header(Header.TARGET_PATH, "/fakemethod"), new Header(Header.TARGET_PATH, "/fakemethod"),
new Header(GrpcUtil.USER_AGENT_KEY.name(), new Header(GrpcUtil.USER_AGENT_KEY.name(),
GrpcUtil.getGrpcUserAgent("okhttp", userAgent)), GrpcUtil.getGrpcUserAgent("okhttp", "fakeUserAgent")),
CONTENT_TYPE_HEADER, TE_HEADER); CONTENT_TYPE_HEADER, TE_HEADER);
verify(frameWriter, timeout(TIME_OUT_MS)) verify(frameWriter, timeout(TIME_OUT_MS))
.synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders)); .synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders));
@ -1311,6 +1307,7 @@ public class OkHttpClientTransportTest {
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(
new InetSocketAddress("host", 1234), new InetSocketAddress("host", 1234),
"invalid_authority", "invalid_authority",
"userAgent",
executor, executor,
null, null,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -1328,6 +1325,7 @@ public class OkHttpClientTransportTest {
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(
new InetSocketAddress("localhost", 0), new InetSocketAddress("localhost", 0),
"authority", "authority",
"userAgent",
executor, executor,
null, null,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,

View File

@ -73,7 +73,9 @@ public class OkHttpTransportTest extends AbstractTransportTest {
@Override @Override
protected ManagedClientTransport newClientTransport() { protected ManagedClientTransport newClientTransport() {
return clientFactory.newClientTransport( return clientFactory.newClientTransport(
new InetSocketAddress("127.0.0.1", SERVER_PORT), "127.0.0.1:" + SERVER_PORT); new InetSocketAddress("127.0.0.1", SERVER_PORT),
"127.0.0.1:" + SERVER_PORT,
null /* agent */);
} }
// TODO(ejona): Flaky/Broken // TODO(ejona): Flaky/Broken

View File

@ -58,6 +58,6 @@ public abstract class AbstractClientTransportFactoryTest {
ClientTransportFactory transportFactory = newClientTransportFactory(); ClientTransportFactory transportFactory = newClientTransportFactory();
transportFactory.close(); transportFactory.close();
transportFactory.newClientTransport( transportFactory.newClientTransport(
new InetSocketAddress("localhost", port), "localhost:" + port); new InetSocketAddress("localhost", port), "localhost:" + port, "agent");
} }
} }