Use generics for LoadBalancer to avoid ClientTransport exposure

TransportManager.makeTransport() was added to remove the only other
reference to ClientTransport outside of core and the transports.
This commit is contained in:
Eric Anderson 2016-01-23 10:07:13 -08:00
parent bf42913c23
commit 1488010cb1
14 changed files with 136 additions and 137 deletions

View File

@ -33,8 +33,6 @@ package io.grpc;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.internal.ClientTransport;
import java.util.List;
import javax.annotation.Nullable;
@ -46,12 +44,14 @@ import javax.annotation.concurrent.ThreadSafe;
*
* <p>Note to implementations: all methods are expected to return quickly. Any work that may block
* should be done asynchronously.
*
* @param T the transport type to balance
*/
// TODO(zhangkun83): since it's also used for non-loadbalancing cases like pick-first,
// "RequestRouter" might be a better name.
@ExperimentalApi
@ThreadSafe
public abstract class LoadBalancer {
public abstract class LoadBalancer<T> {
/**
* Pick a transport that Channel will use for next RPC.
*
@ -61,8 +61,7 @@ public abstract class LoadBalancer {
*
* @param requestKey for affinity-based routing
*/
public abstract ListenableFuture<ClientTransport> pickTransport(
@Nullable RequestKey requestKey);
public abstract ListenableFuture<T> pickTransport(@Nullable RequestKey requestKey);
/**
* Shuts down this {@code LoadBalancer}.
@ -86,13 +85,12 @@ public abstract class LoadBalancer {
/**
* Called when a transport is fully connected and ready to accept traffic.
*/
public void transportReady(EquivalentAddressGroup addressGroup, ClientTransport transport) { }
public void transportReady(EquivalentAddressGroup addressGroup, T transport) { }
/**
* Called when a transport is shutting down.
*/
public void transportShutdown(
EquivalentAddressGroup addressGroup, ClientTransport transport, Status s) { }
public void transportShutdown(EquivalentAddressGroup addressGroup, T transport, Status s) { }
public abstract static class Factory {
/**
@ -102,6 +100,6 @@ public abstract class LoadBalancer {
* @param tm the interface where an {@code LoadBalancer} implementation gets connected
* transports from
*/
public abstract LoadBalancer newLoadBalancer(String serviceName, TransportManager tm);
public abstract <T> LoadBalancer<T> newLoadBalancer(String serviceName, TransportManager<T> tm);
}
}

View File

@ -36,7 +36,6 @@ import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.internal.BlankFutureProvider;
import io.grpc.internal.ClientTransport;
import java.net.SocketAddress;
import java.util.ArrayList;
@ -63,29 +62,28 @@ public final class SimpleLoadBalancerFactory extends LoadBalancer.Factory {
}
@Override
public LoadBalancer newLoadBalancer(String serviceName, TransportManager tm) {
return new SimpleLoadBalancer(tm);
public <T> LoadBalancer<T> newLoadBalancer(String serviceName, TransportManager<T> tm) {
return new SimpleLoadBalancer<T>(tm);
}
private static class SimpleLoadBalancer extends LoadBalancer {
private static class SimpleLoadBalancer<T> extends LoadBalancer<T> {
private final Object lock = new Object();
@GuardedBy("lock")
private EquivalentAddressGroup addresses;
@GuardedBy("lock")
private final BlankFutureProvider<ClientTransport> pendingPicks =
new BlankFutureProvider<ClientTransport>();
private final BlankFutureProvider<T> pendingPicks = new BlankFutureProvider<T>();
@GuardedBy("lock")
private StatusException nameResolutionError;
private final TransportManager tm;
private final TransportManager<T> tm;
private SimpleLoadBalancer(TransportManager tm) {
private SimpleLoadBalancer(TransportManager<T> tm) {
this.tm = tm;
}
@Override
public ListenableFuture<ClientTransport> pickTransport(@Nullable RequestKey requestKey) {
public ListenableFuture<T> pickTransport(@Nullable RequestKey requestKey) {
EquivalentAddressGroup addressesCopy;
synchronized (lock) {
addressesCopy = addresses;
@ -102,7 +100,7 @@ public final class SimpleLoadBalancerFactory extends LoadBalancer.Factory {
@Override
public void handleResolvedAddresses(
List<ResolvedServerInfo> updatedServers, Attributes config) {
BlankFutureProvider.FulfillmentBatch<ClientTransport> pendingPicksFulfillmentBatch;
BlankFutureProvider.FulfillmentBatch<T> pendingPicksFulfillmentBatch;
final EquivalentAddressGroup newAddresses;
synchronized (lock) {
ArrayList<SocketAddress> newAddressList =
@ -118,8 +116,8 @@ public final class SimpleLoadBalancerFactory extends LoadBalancer.Factory {
nameResolutionError = null;
pendingPicksFulfillmentBatch = pendingPicks.createFulfillmentBatch();
}
pendingPicksFulfillmentBatch.link(new Supplier<ListenableFuture<ClientTransport>>() {
@Override public ListenableFuture<ClientTransport> get() {
pendingPicksFulfillmentBatch.link(new Supplier<ListenableFuture<T>>() {
@Override public ListenableFuture<T> get() {
return tm.getTransport(newAddresses);
}
});
@ -127,7 +125,7 @@ public final class SimpleLoadBalancerFactory extends LoadBalancer.Factory {
@Override
public void handleNameResolutionError(Status error) {
BlankFutureProvider.FulfillmentBatch<ClientTransport> pendingPicksFulfillmentBatch;
BlankFutureProvider.FulfillmentBatch<T> pendingPicksFulfillmentBatch;
StatusException statusException =
error.augmentDescription("Name resolution failed").asException();
synchronized (lock) {

View File

@ -33,15 +33,13 @@ package io.grpc;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.internal.ClientTransport;
import java.util.Collection;
/**
* Manages transport life-cycles and provide ready-to-use transports.
*/
@ExperimentalApi
public abstract class TransportManager {
public abstract class TransportManager<T> {
/**
* Advises this {@code TransportManager} to retain transports only to these servers, for warming
* up connections and discarding unused connections.
@ -59,6 +57,10 @@ public abstract class TransportManager {
// TODO(zhangkun83): GrpcLoadBalancer will use this to get transport to connect to LB servers,
// which would have a different authority than the primary servers. We need to figure out how to
// do it.
public abstract ListenableFuture<ClientTransport> getTransport(
EquivalentAddressGroup addressGroup);
public abstract ListenableFuture<T> getTransport(EquivalentAddressGroup addressGroup);
/**
* Returns a channel that uses {@code transport}; useful for issuing RPCs on a transport.
*/
public abstract Channel makeChannel(T transport);
}

View File

@ -125,7 +125,7 @@ public final class ManagedChannelImpl extends ManagedChannel {
private final Channel interceptorChannel;
private final NameResolver nameResolver;
private final LoadBalancer loadBalancer;
private final LoadBalancer<ClientTransport> loadBalancer;
/**
* Maps EquivalentAddressGroups to transports for that server.
@ -357,7 +357,7 @@ public final class ManagedChannelImpl extends ManagedChannel {
transportFactory.release();
}
private final TransportManager tm = new TransportManager() {
private final TransportManager<ClientTransport> tm = new TransportManager<ClientTransport>() {
@Override
public void updateRetainedTransports(Collection<EquivalentAddressGroup> addrs) {
// TODO(zhangkun83): warm-up new servers and discard removed servers.
@ -396,5 +396,11 @@ public final class ManagedChannelImpl extends ManagedChannel {
}
return ts.obtainActiveTransport();
}
@Override
public Channel makeChannel(ClientTransport transport) {
return new SingleTransportChannel(
transport, executor, scheduledExecutor, authority());
}
};
}

View File

@ -50,7 +50,7 @@ import java.util.concurrent.ScheduledExecutorService;
/**
* A {@link Channel} that wraps a {@link ClientTransport}.
*/
public final class SingleTransportChannel extends Channel {
final class SingleTransportChannel extends Channel {
private final ClientTransport transport;
private final Executor executor;

View File

@ -105,7 +105,7 @@ final class TransportSet {
@GuardedBy("lock")
private final Collection<ClientTransport> transports = new ArrayList<ClientTransport>();
private final LoadBalancer loadBalancer;
private final LoadBalancer<ClientTransport> loadBalancer;
@GuardedBy("lock")
private boolean shutdown;
@ -117,17 +117,19 @@ final class TransportSet {
@Nullable
private volatile UncancellableTransportFuture activeTransportFuture;
TransportSet(EquivalentAddressGroup addressGroup, String authority, LoadBalancer loadBalancer,
BackoffPolicy.Provider backoffPolicyProvider, ClientTransportFactory transportFactory,
ScheduledExecutorService scheduledExecutor, Callback callback) {
TransportSet(EquivalentAddressGroup addressGroup, String authority,
LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Callback callback) {
this(addressGroup, authority, loadBalancer, backoffPolicyProvider, transportFactory,
scheduledExecutor, callback, Stopwatch.createUnstarted());
}
@VisibleForTesting
TransportSet(EquivalentAddressGroup addressGroup, String authority, LoadBalancer loadBalancer,
BackoffPolicy.Provider backoffPolicyProvider, ClientTransportFactory transportFactory,
ScheduledExecutorService scheduledExecutor, Callback callback, Stopwatch backoffWatch) {
TransportSet(EquivalentAddressGroup addressGroup, String authority,
LoadBalancer<ClientTransport> loadBalancer, BackoffPolicy.Provider backoffPolicyProvider,
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Callback callback, Stopwatch backoffWatch) {
this.addressGroup = Preconditions.checkNotNull(addressGroup, "addressGroup");
this.authority = authority;
this.loadBalancer = loadBalancer;

View File

@ -37,7 +37,6 @@ import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -47,8 +46,6 @@ import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.internal.ClientTransport;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -62,13 +59,13 @@ import java.util.ArrayList;
/** Unit test for {@link SimpleLoadBalancerFactory}. */
@RunWith(JUnit4.class)
public class SimpleLoadBalancerTest {
private LoadBalancer loadBalancer;
private LoadBalancer<Transport> loadBalancer;
private ArrayList<ResolvedServerInfo> servers;
private EquivalentAddressGroup addressGroup;
@Mock
private TransportManager mockTransportManager;
private TransportManager<Transport> mockTransportManager;
@Before
public void setUp() {
@ -87,12 +84,12 @@ public class SimpleLoadBalancerTest {
@Test
public void pickBeforeResolved() throws Exception {
ClientTransport mockTransport = mock(ClientTransport.class);
SettableFuture<ClientTransport> sourceFuture = SettableFuture.create();
Transport mockTransport = new Transport();
SettableFuture<Transport> sourceFuture = SettableFuture.create();
when(mockTransportManager.getTransport(eq(addressGroup)))
.thenReturn(sourceFuture);
ListenableFuture<ClientTransport> f1 = loadBalancer.pickTransport(null);
ListenableFuture<ClientTransport> f2 = loadBalancer.pickTransport(null);
ListenableFuture<Transport> f1 = loadBalancer.pickTransport(null);
ListenableFuture<Transport> f2 = loadBalancer.pickTransport(null);
assertNotNull(f1);
assertNotNull(f2);
assertNotSame(f1, f2);
@ -113,12 +110,12 @@ public class SimpleLoadBalancerTest {
@Test
public void pickAfterResolved() throws Exception {
ClientTransport mockTransport = mock(ClientTransport.class);
SettableFuture<ClientTransport> sourceFuture = SettableFuture.create();
Transport mockTransport = new Transport();
SettableFuture<Transport> sourceFuture = SettableFuture.create();
when(mockTransportManager.getTransport(eq(addressGroup)))
.thenReturn(sourceFuture);
loadBalancer.handleResolvedAddresses(servers, Attributes.EMPTY);
ListenableFuture<ClientTransport> f = loadBalancer.pickTransport(null);
ListenableFuture<Transport> f = loadBalancer.pickTransport(null);
assertSame(sourceFuture, f);
assertFalse(f.isDone());
sourceFuture.set(mockTransport);
@ -139,4 +136,5 @@ public class SimpleLoadBalancerTest {
}
}
private static class Transport {}
}

View File

@ -529,15 +529,15 @@ public class ManagedChannelImplTest {
private class SpyingLoadBalancerFactory extends LoadBalancer.Factory {
private final LoadBalancer.Factory delegate;
private final List<LoadBalancer> balancers = new ArrayList<LoadBalancer>();
private final List<LoadBalancer<?>> balancers = new ArrayList<LoadBalancer<?>>();
private SpyingLoadBalancerFactory(LoadBalancer.Factory delegate) {
this.delegate = delegate;
}
@Override
public LoadBalancer newLoadBalancer(String serviceName, TransportManager tm) {
LoadBalancer lb = spy(delegate.newLoadBalancer(serviceName, tm));
public <T> LoadBalancer<T> newLoadBalancer(String serviceName, TransportManager<T> tm) {
LoadBalancer<T> lb = spy(delegate.newLoadBalancer(serviceName, tm));
balancers.add(lb);
return lb;
}

View File

@ -62,6 +62,7 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
@ -114,15 +115,18 @@ public class ManagedChannelImplTransportManagerTest {
@Mock private BackoffPolicy.Provider mockBackoffPolicyProvider;
@Mock private BackoffPolicy mockBackoffPolicy;
private TransportManager tm;
private TransportManager<ClientTransport> tm;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy);
when(mockLoadBalancerFactory.newLoadBalancer(anyString(), any(TransportManager.class)))
.thenReturn(mock(LoadBalancer.class));
@SuppressWarnings("unchecked")
LoadBalancer<ClientTransport> loadBalancer = mock(LoadBalancer.class);
when(mockLoadBalancerFactory
.newLoadBalancer(anyString(), Matchers.<TransportManager<ClientTransport>>any()))
.thenReturn(loadBalancer);
channel = new ManagedChannelImpl("fake://target", mockBackoffPolicyProvider,
nameResolverFactory, Attributes.EMPTY, mockLoadBalancerFactory,
@ -130,7 +134,8 @@ public class ManagedChannelImplTransportManagerTest {
CompressorRegistry.getDefaultInstance(), executor, null,
Collections.<ClientInterceptor>emptyList());
ArgumentCaptor<TransportManager> tmCaptor = ArgumentCaptor.forClass(TransportManager.class);
ArgumentCaptor<TransportManager<ClientTransport>> tmCaptor
= ArgumentCaptor.forClass(null);
verify(mockLoadBalancerFactory).newLoadBalancer(anyString(), tmCaptor.capture());
tm = tmCaptor.getValue();
}
@ -259,5 +264,4 @@ public class ManagedChannelImplTransportManagerTest {
verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockBackoffPolicy, times(++backoffConsulted)).nextBackoffMillis();
}
}

View File

@ -70,7 +70,7 @@ public class TransportSetTest {
private FakeClock fakeClock;
@Mock private LoadBalancer mockLoadBalancer;
@Mock private LoadBalancer<ClientTransport> mockLoadBalancer;
@Mock private BackoffPolicy mockBackoffPolicy1;
@Mock private BackoffPolicy mockBackoffPolicy2;
@Mock private BackoffPolicy mockBackoffPolicy3;

View File

@ -49,10 +49,8 @@ import io.grpc.StatusException;
import io.grpc.TransportManager;
import io.grpc.internal.BlankFutureProvider;
import io.grpc.internal.BlankFutureProvider.FulfillmentBatch;
import io.grpc.internal.ClientTransport;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.SingleTransportChannel;
import io.grpc.stub.StreamObserver;
import java.net.InetSocketAddress;
@ -62,7 +60,6 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@ -71,17 +68,16 @@ import javax.annotation.concurrent.GuardedBy;
/**
* A {@link LoadBalancer} that uses the GRPCLB protocol.
*/
class GrpclbLoadBalancer extends LoadBalancer {
class GrpclbLoadBalancer<T> extends LoadBalancer<T> {
private static final Logger logger = Logger.getLogger(GrpclbLoadBalancer.class.getName());
private final Object lock = new Object();
private final String serviceName;
private final TransportManager tm;
private final TransportManager<T> tm;
// General states
@GuardedBy("lock")
private final BlankFutureProvider<ClientTransport> pendingPicks =
new BlankFutureProvider<ClientTransport>();
private final BlankFutureProvider<T> pendingPicks = new BlankFutureProvider<T>();
@GuardedBy("lock")
private Throwable lastError;
@ -92,9 +88,9 @@ class GrpclbLoadBalancer extends LoadBalancer {
@GuardedBy("lock")
private EquivalentAddressGroup lbAddresses;
@GuardedBy("lock")
private ClientTransport lbTransport;
private T lbTransport;
@GuardedBy("lock")
private ListenableFuture<ClientTransport> directTransport;
private ListenableFuture<T> directTransport;
@GuardedBy("lock")
private StreamObserver<LoadBalanceResponse> lbResponseObserver;
@GuardedBy("lock")
@ -105,16 +101,14 @@ class GrpclbLoadBalancer extends LoadBalancer {
private HashMap<SocketAddress, ResolvedServerInfo> servers;
@GuardedBy("lock")
@VisibleForTesting
private RoundRobinServerList roundRobinServerList;
private RoundRobinServerList<T> roundRobinServerList;
private ExecutorService executor;
private ScheduledExecutorService deadlineCancellationExecutor;
GrpclbLoadBalancer(String serviceName, TransportManager tm) {
GrpclbLoadBalancer(String serviceName, TransportManager<T> tm) {
this.serviceName = serviceName;
this.tm = tm;
executor = SharedResourceHolder.get(GrpcUtil.SHARED_CHANNEL_EXECUTOR);
deadlineCancellationExecutor = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE);
}
@VisibleForTesting
@ -125,15 +119,15 @@ class GrpclbLoadBalancer extends LoadBalancer {
}
@VisibleForTesting
RoundRobinServerList getRoundRobinServerList() {
RoundRobinServerList<T> getRoundRobinServerList() {
synchronized (lock) {
return roundRobinServerList;
}
}
@Override
public ListenableFuture<ClientTransport> pickTransport(@Nullable RequestKey requestKey) {
RoundRobinServerList serverListCopy;
public ListenableFuture<T> pickTransport(@Nullable RequestKey requestKey) {
RoundRobinServerList<T> serverListCopy;
synchronized (lock) {
Preconditions.checkState(!closed, "already closed");
if (directTransport != null) {
@ -178,11 +172,11 @@ class GrpclbLoadBalancer extends LoadBalancer {
Preconditions.checkNotNull(lbAddresses, "lbAddresses");
// TODO(zhangkun83): LB servers may use an authority different from the service's.
// getTransport() will need to add an argument for the authority.
ListenableFuture<ClientTransport> transportFuture = tm.getTransport(lbAddresses);
ListenableFuture<T> transportFuture = tm.getTransport(lbAddresses);
Futures.addCallback(
Preconditions.checkNotNull(transportFuture),
new FutureCallback<ClientTransport>() {
@Override public void onSuccess(ClientTransport transport) {
new FutureCallback<T>() {
@Override public void onSuccess(T transport) {
synchronized (lock) {
if (closed) {
return;
@ -221,9 +215,8 @@ class GrpclbLoadBalancer extends LoadBalancer {
@VisibleForTesting // to be mocked in tests
@GuardedBy("lock")
void sendLbRequest(ClientTransport transport, LoadBalanceRequest request) {
Channel channel = new SingleTransportChannel(transport, executor,
deadlineCancellationExecutor, serviceName);
void sendLbRequest(T transport, LoadBalanceRequest request) {
Channel channel = tm.makeChannel(transport);
LoadBalancerGrpc.LoadBalancerStub stub = LoadBalancerGrpc.newStub(channel);
lbRequestWriter = stub.balanceLoad(lbResponseObserver);
lbRequestWriter.onNext(request);
@ -245,14 +238,11 @@ class GrpclbLoadBalancer extends LoadBalancer {
lbRequestWriter.onCompleted();
}
executor = SharedResourceHolder.release(GrpcUtil.SHARED_CHANNEL_EXECUTOR, executor);
deadlineCancellationExecutor = SharedResourceHolder.release(
GrpcUtil.TIMER_SERVICE, deadlineCancellationExecutor);
}
}
@Override
public void transportShutdown(
EquivalentAddressGroup addressGroup, ClientTransport transport, Status status) {
public void transportShutdown(EquivalentAddressGroup addressGroup, T transport, Status status) {
handleError(status.augmentDescription("Transport to LB server closed"));
synchronized (lock) {
if (transport == lbTransport) {
@ -262,7 +252,7 @@ class GrpclbLoadBalancer extends LoadBalancer {
}
private void handleError(Status error) {
FulfillmentBatch<ClientTransport> pendingPicksFulfillmentBatch;
FulfillmentBatch<T> pendingPicksFulfillmentBatch;
StatusException statusException = error.asException();
synchronized (lock) {
lastError = statusException;
@ -291,7 +281,7 @@ class GrpclbLoadBalancer extends LoadBalancer {
logger.info("Got a LB response: " + response);
InitialLoadBalanceResponse initialResponse = response.getInitialResponse();
// TODO(zhangkun83): make use of initialResponse
RoundRobinServerList.Builder listBuilder = new RoundRobinServerList.Builder(tm);
RoundRobinServerList.Builder<T> listBuilder = new RoundRobinServerList.Builder<T>(tm);
ServerList serverList = response.getServerList();
HashMap<SocketAddress, ResolvedServerInfo> newServerMap =
new HashMap<SocketAddress, ResolvedServerInfo>();
@ -310,13 +300,13 @@ class GrpclbLoadBalancer extends LoadBalancer {
}
}
}
final RoundRobinServerList newRoundRobinServerList = listBuilder.build();
final RoundRobinServerList<T> newRoundRobinServerList = listBuilder.build();
if (newRoundRobinServerList.size() == 0) {
// initialResponse and serverList are under a oneof group. If initialResponse is set,
// serverList will be empty.
return;
}
FulfillmentBatch<ClientTransport> pendingPicksFulfillmentBatch;
FulfillmentBatch<T> pendingPicksFulfillmentBatch;
synchronized (lock) {
if (lbResponseObserver != this) {
// Make sure I am still the current stream.
@ -328,9 +318,9 @@ class GrpclbLoadBalancer extends LoadBalancer {
}
updateRetainedTransports();
pendingPicksFulfillmentBatch.link(
new Supplier<ListenableFuture<ClientTransport>>() {
new Supplier<ListenableFuture<T>>() {
@Override
public ListenableFuture<ClientTransport> get() {
public ListenableFuture<T> get() {
return newRoundRobinServerList.getTransportForNextServer();
}
});
@ -348,8 +338,8 @@ class GrpclbLoadBalancer extends LoadBalancer {
private void onStreamClosed(Status status) {
if (status.getCode() == Status.Code.UNIMPLEMENTED) {
FulfillmentBatch<ClientTransport> pendingPicksFulfillmentBatch;
final ListenableFuture<ClientTransport> transportFuture;
FulfillmentBatch<T> pendingPicksFulfillmentBatch;
final ListenableFuture<T> transportFuture;
// This LB transport doesn't seem to be an actual LB server, if the LB address comes
// directly from NameResolver, just use it to serve normal RPCs.
// TODO(zhangkun83): check if lbAddresses are from NameResolver after we start getting
@ -362,9 +352,9 @@ class GrpclbLoadBalancer extends LoadBalancer {
pendingPicksFulfillmentBatch = pendingPicks.createFulfillmentBatch();
}
pendingPicksFulfillmentBatch.link(
new Supplier<ListenableFuture<ClientTransport>>() {
new Supplier<ListenableFuture<T>>() {
@Override
public ListenableFuture<ClientTransport> get() {
public ListenableFuture<T> get() {
return transportFuture;
}
});

View File

@ -56,7 +56,7 @@ public class GrpclbLoadBalancerFactory extends LoadBalancer.Factory {
}
@Override
public LoadBalancer newLoadBalancer(String serviceName, TransportManager tm) {
return new GrpclbLoadBalancer(serviceName, tm);
public <T> LoadBalancer<T> newLoadBalancer(String serviceName, TransportManager<T> tm) {
return new GrpclbLoadBalancer<T>(serviceName, tm);
}
}

View File

@ -39,7 +39,6 @@ import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.EquivalentAddressGroup;
import io.grpc.TransportManager;
import io.grpc.internal.ClientTransport;
import java.net.InetSocketAddress;
import java.util.Iterator;
@ -55,18 +54,18 @@ import javax.annotation.concurrent.ThreadSafe;
// TODO(zhangkun83): possibly move it to io.grpc.internal, as it can also be used by the round-robin
// LoadBalancer.
@ThreadSafe
class RoundRobinServerList {
private final TransportManager tm;
class RoundRobinServerList<T> {
private final TransportManager<T> tm;
private final List<EquivalentAddressGroup> list;
private final Iterator<EquivalentAddressGroup> cyclingIter;
private RoundRobinServerList(TransportManager tm, List<EquivalentAddressGroup> list) {
private RoundRobinServerList(TransportManager<T> tm, List<EquivalentAddressGroup> list) {
this.tm = tm;
this.list = list;
this.cyclingIter = Iterables.cycle(list).iterator();
}
ListenableFuture<ClientTransport> getTransportForNextServer() {
ListenableFuture<T> getTransportForNextServer() {
EquivalentAddressGroup currentServer;
synchronized (cyclingIter) {
// TODO(zhangkun83): receive transportShutdown and transportReady events, then skip addresses
@ -92,12 +91,12 @@ class RoundRobinServerList {
}
@NotThreadSafe
static class Builder {
static class Builder<T> {
private final ImmutableList.Builder<EquivalentAddressGroup> listBuilder =
ImmutableList.builder();
private final TransportManager tm;
private final TransportManager<T> tm;
Builder(TransportManager tm) {
Builder(TransportManager<T> tm) {
this.tm = tm;
}
@ -108,8 +107,8 @@ class RoundRobinServerList {
listBuilder.add(new EquivalentAddressGroup(addr));
}
RoundRobinServerList build() {
return new RoundRobinServerList(tm, listBuilder.build());
RoundRobinServerList<T> build() {
return new RoundRobinServerList<T>(tm, listBuilder.build());
}
}
}

View File

@ -55,7 +55,6 @@ import io.grpc.EquivalentAddressGroup;
import io.grpc.ResolvedServerInfo;
import io.grpc.Status;
import io.grpc.TransportManager;
import io.grpc.internal.ClientTransport;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -77,7 +76,8 @@ import java.util.concurrent.TimeUnit;
public class GrpclbLoadBalancerTest {
private static final String serviceName = "testlbservice";
private final TransportManager mockTransportManager = mock(TransportManager.class);
@SuppressWarnings("unchecked")
private final TransportManager<Transport> mockTransportManager = mock(TransportManager.class);
// The test subject
private TestGrpclbLoadBalancer loadBalancer = new TestGrpclbLoadBalancer();
@ -85,28 +85,28 @@ public class GrpclbLoadBalancerTest {
// Current addresses of the LB server
private EquivalentAddressGroup lbAddressGroup;
// The future of the currently requested transport for an LB server
private SettableFuture<ClientTransport> lbTransportFuture;
private SettableFuture<Transport> lbTransportFuture;
@Test
public void balancing() throws Exception {
List<ResolvedServerInfo> servers = createResolvedServerInfoList(4000, 4001);
// Set up mocks
List<ClientTransport> transports = new ArrayList<ClientTransport>(servers.size());
List<SettableFuture<ClientTransport>> transportFutures =
new ArrayList<SettableFuture<ClientTransport>>(servers.size());
List<Transport> transports = new ArrayList<Transport>(servers.size());
List<SettableFuture<Transport>> transportFutures =
new ArrayList<SettableFuture<Transport>>(servers.size());
for (ResolvedServerInfo server : servers) {
transports.add(
mock(ClientTransport.class, withSettings().name("Transport for " + server.toString())));
SettableFuture<ClientTransport> future = SettableFuture.create();
mock(Transport.class, withSettings().name("Transport for " + server.toString())));
SettableFuture<Transport> future = SettableFuture.create();
transportFutures.add(future);
when(mockTransportManager.getTransport(eq(new EquivalentAddressGroup(server.getAddress()))))
.thenReturn(future);
}
ListenableFuture<ClientTransport> pick0;
ListenableFuture<ClientTransport> pick1;
ListenableFuture<Transport> pick0;
ListenableFuture<Transport> pick1;
// Pick before name resolved
pick0 = loadBalancer.pickTransport(null);
@ -118,7 +118,7 @@ public class GrpclbLoadBalancerTest {
pick1 = loadBalancer.pickTransport(null);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
SendLbRequestArgs sentLbRequest = loadBalancer.sentLbRequests.poll(1000, TimeUnit.SECONDS);
@ -173,7 +173,7 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -207,7 +207,7 @@ public class GrpclbLoadBalancerTest {
verify(mockTransportManager).getTransport(eq(lbAddressGroup));
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -221,7 +221,7 @@ public class GrpclbLoadBalancerTest {
assertNotEquals(lbAddress1, lbAddress2);
verify(mockTransportManager).updateRetainedTransports(eq(Collections.singleton(lbAddress2)));
verify(mockTransportManager).getTransport(eq(lbAddressGroup));
lbTransport = mock(ClientTransport.class);
lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// Another LB request is sent
@ -243,7 +243,7 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// Make the transport for LB server ready
lbTransportFuture.set(mock(ClientTransport.class));
lbTransportFuture.set(new Transport());
// An LB request is sent
assertNotNull(loadBalancer.sentLbRequests.poll(1000, TimeUnit.SECONDS));
@ -271,10 +271,10 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// First pick, will be pending
ListenableFuture<ClientTransport> pick = loadBalancer.pickTransport(null);
ListenableFuture<Transport> pick = loadBalancer.pickTransport(null);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -304,10 +304,10 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// First pick, will be pending
ListenableFuture<ClientTransport> pick = loadBalancer.pickTransport(null);
ListenableFuture<Transport> pick = loadBalancer.pickTransport(null);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -356,7 +356,7 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -378,7 +378,7 @@ public class GrpclbLoadBalancerTest {
verify(mockTransportManager, times(2)).getTransport(eq(lbAddressGroup));
// Make the new transport ready
lbTransportFuture.set(mock(ClientTransport.class));
lbTransportFuture.set(new Transport());
// Another LB request is sent
assertNotNull(loadBalancer.sentLbRequests.poll(1000, TimeUnit.SECONDS));
@ -389,10 +389,10 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// First pick, will be pending
ListenableFuture<ClientTransport> pick = loadBalancer.pickTransport(null);
ListenableFuture<Transport> pick = loadBalancer.pickTransport(null);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -417,13 +417,13 @@ public class GrpclbLoadBalancerTest {
}
@Test public void nameResolutionFailed() throws Exception {
ListenableFuture<ClientTransport> pick0 = loadBalancer.pickTransport(null);
ListenableFuture<Transport> pick0 = loadBalancer.pickTransport(null);
assertFalse(pick0.isDone());
loadBalancer.handleNameResolutionError(Status.UNAVAILABLE);
assertTrue(pick0.isDone());
ListenableFuture<ClientTransport> pick1 = loadBalancer.pickTransport(null);
ListenableFuture<Transport> pick1 = loadBalancer.pickTransport(null);
assertTrue(pick1.isDone());
assertFutureFailedWithError(pick0, Status.Code.UNAVAILABLE, "Name resolution failed");
assertFutureFailedWithError(pick1, Status.Code.UNAVAILABLE, "Name resolution failed");
@ -434,7 +434,7 @@ public class GrpclbLoadBalancerTest {
simulateLbAddressResolved(30001);
// Make the transport for LB server ready
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture.set(lbTransport);
// An LB request is sent
@ -469,7 +469,7 @@ public class GrpclbLoadBalancerTest {
ResolvedServerInfo lbServerInfo = new ResolvedServerInfo(
new InetSocketAddress("127.0.0.1", lbPort), Attributes.EMPTY);
lbAddressGroup = buildAddressGroup(lbServerInfo);
ClientTransport lbTransport = mock(ClientTransport.class);
Transport lbTransport = new Transport();
lbTransportFuture = SettableFuture.create();
when(mockTransportManager.getTransport(eq(lbAddressGroup))).thenReturn(lbTransportFuture);
loadBalancer.handleResolvedAddresses(Collections.singletonList(lbServerInfo), Attributes.EMPTY);
@ -490,7 +490,7 @@ public class GrpclbLoadBalancerTest {
* A slightly modified {@link GrpclbLoadBalancerTest} that saves LB requests in a queue instead of
* sending them out.
*/
private class TestGrpclbLoadBalancer extends GrpclbLoadBalancer {
private class TestGrpclbLoadBalancer extends GrpclbLoadBalancer<Transport> {
final LinkedBlockingQueue<SendLbRequestArgs> sentLbRequests =
new LinkedBlockingQueue<SendLbRequestArgs>();
@ -498,16 +498,16 @@ public class GrpclbLoadBalancerTest {
super(serviceName, mockTransportManager);
}
@Override void sendLbRequest(ClientTransport transport, LoadBalanceRequest request) {
@Override void sendLbRequest(Transport transport, LoadBalanceRequest request) {
sentLbRequests.add(new SendLbRequestArgs(transport, request));
}
}
private static class SendLbRequestArgs {
final ClientTransport transport;
final Transport transport;
final LoadBalanceRequest request;
SendLbRequestArgs(ClientTransport transport, LoadBalanceRequest request) {
SendLbRequestArgs(Transport transport, LoadBalanceRequest request) {
this.transport = transport;
this.request = request;
}
@ -563,4 +563,6 @@ public class GrpclbLoadBalancerTest {
}
}
}
public static class Transport {}
}