Move multiple-port ServerImpl to NettyServer (#7674)

Change InternalServer to handle multiple addresses and implemented in NettyServer.
It makes ServerImpl to have a single transport server, and this single transport server (NettyServer) will bind to all listening addresses during bootstrap. (#7674)
This commit is contained in:
Yifei Zhuang 2021-01-05 13:24:16 -08:00 committed by GitHub
parent ccef406f89
commit 53da588dd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 542 additions and 181 deletions

View File

@ -82,11 +82,21 @@ final class InProcessServer implements InternalServer {
return new InProcessSocketAddress(name); return new InProcessSocketAddress(name);
} }
@Override
public List<? extends SocketAddress> getListenSocketAddresses() {
return Collections.singletonList(getListenSocketAddress());
}
@Override @Override
public InternalInstrumented<SocketStats> getListenSocketStats() { public InternalInstrumented<SocketStats> getListenSocketStats() {
return null; return null;
} }
@Override
public List<InternalInstrumented<SocketStats>> getListenSocketStatsList() {
return null;
}
@Override @Override
public void shutdown() { public void shutdown() {
if (!registry.remove(name, this)) { if (!registry.remove(name, this)) {

View File

@ -33,7 +33,6 @@ import io.grpc.internal.ServerImplBuilder;
import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder; import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder;
import io.grpc.internal.SharedResourcePool; import io.grpc.internal.SharedResourcePool;
import java.io.File; import java.io.File;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -109,7 +108,7 @@ public final class InProcessServerBuilder extends
final class InProcessClientTransportServersBuilder implements ClientTransportServersBuilder { final class InProcessClientTransportServersBuilder implements ClientTransportServersBuilder {
@Override @Override
public List<? extends InternalServer> buildClientTransportServers( public InternalServer buildClientTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
return buildTransportServers(streamTracerFactories); return buildTransportServers(streamTracerFactories);
} }
@ -187,9 +186,9 @@ public final class InProcessServerBuilder extends
return this; return this;
} }
List<InProcessServer> buildTransportServers( InProcessServer buildTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
return Collections.singletonList(new InProcessServer(this, streamTracerFactories)); return new InProcessServer(this, streamTracerFactories);
} }
@Override @Override

View File

@ -20,12 +20,13 @@ import io.grpc.InternalChannelz.SocketStats;
import io.grpc.InternalInstrumented; import io.grpc.InternalInstrumented;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.List;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
/** /**
* An object that accepts new incoming connections. This would commonly encapsulate a bound socket * An object that accepts new incoming connections on one or more listening socket addresses.
* that {@code accept()}s new connections. * This would commonly encapsulate a bound socket that {@code accept()}s new connections.
*/ */
@ThreadSafe @ThreadSafe
public interface InternalServer { public interface InternalServer {
@ -49,13 +50,25 @@ public interface InternalServer {
void shutdown(); void shutdown();
/** /**
* Returns the listening socket address. May change after {@link start(ServerListener)} is * Returns the first listening socket address. May change after {@link start(ServerListener)} is
* called. * called.
*/ */
SocketAddress getListenSocketAddress(); SocketAddress getListenSocketAddress();
/** /**
* Returns the listen socket stats of this server. May return {@code null}. * Returns the first listen socket stats of this server. May return {@code null}.
*/ */
@Nullable InternalInstrumented<SocketStats> getListenSocketStats(); @Nullable InternalInstrumented<SocketStats> getListenSocketStats();
/**
* Returns a list of listening socket addresses. May change after {@link start(ServerListener)}
* is called.
*/
List<? extends SocketAddress> getListenSocketAddresses();
/**
* Returns a list of listen socket stats of this server. May return {@code null}.
*/
@Nullable List<InternalInstrumented<SocketStats>> getListenSocketStatsList();
} }

View File

@ -110,12 +110,11 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
@GuardedBy("lock") private boolean serverShutdownCallbackInvoked; @GuardedBy("lock") private boolean serverShutdownCallbackInvoked;
@GuardedBy("lock") private boolean terminated; @GuardedBy("lock") private boolean terminated;
/** Service encapsulating something similar to an accept() socket. */ /** Service encapsulating something similar to an accept() socket. */
private final List<? extends InternalServer> transportServers; private final InternalServer transportServer;
private final Object lock = new Object(); private final Object lock = new Object();
@GuardedBy("lock") private boolean transportServersTerminated; @GuardedBy("lock") private boolean transportServersTerminated;
/** {@code transportServer} and services encapsulating something similar to a TCP connection. */ /** {@code transportServer} and services encapsulating something similar to a TCP connection. */
@GuardedBy("lock") private final Set<ServerTransport> transports = new HashSet<>(); @GuardedBy("lock") private final Set<ServerTransport> transports = new HashSet<>();
@GuardedBy("lock") private int activeTransportServers;
private final Context rootContext; private final Context rootContext;
@ -131,20 +130,18 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
* Construct a server. * Construct a server.
* *
* @param builder builder with configuration for server * @param builder builder with configuration for server
* @param transportServers transport servers that will create new incoming transports * @param transportServer transport servers that will create new incoming transports
* @param rootContext context that callbacks for new RPCs should be derived from * @param rootContext context that callbacks for new RPCs should be derived from
*/ */
ServerImpl( ServerImpl(
ServerImplBuilder builder, ServerImplBuilder builder,
List<? extends InternalServer> transportServers, InternalServer transportServer,
Context rootContext) { Context rootContext) {
this.executorPool = Preconditions.checkNotNull(builder.executorPool, "executorPool"); this.executorPool = Preconditions.checkNotNull(builder.executorPool, "executorPool");
this.registry = Preconditions.checkNotNull(builder.registryBuilder.build(), "registryBuilder"); this.registry = Preconditions.checkNotNull(builder.registryBuilder.build(), "registryBuilder");
this.fallbackRegistry = this.fallbackRegistry =
Preconditions.checkNotNull(builder.fallbackRegistry, "fallbackRegistry"); Preconditions.checkNotNull(builder.fallbackRegistry, "fallbackRegistry");
Preconditions.checkNotNull(transportServers, "transportServers"); this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer");
Preconditions.checkArgument(!transportServers.isEmpty(), "no servers provided");
this.transportServers = new ArrayList<>(transportServers);
this.logId = this.logId =
InternalLogId.allocate("Server", String.valueOf(getListenSocketsIgnoringLifecycle())); InternalLogId.allocate("Server", String.valueOf(getListenSocketsIgnoringLifecycle()));
// Fork from the passed in context so that it does not propagate cancellation, it only // Fork from the passed in context so that it does not propagate cancellation, it only
@ -179,10 +176,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
// Start and wait for any ports to actually be bound. // Start and wait for any ports to actually be bound.
ServerListenerImpl listener = new ServerListenerImpl(); ServerListenerImpl listener = new ServerListenerImpl();
for (InternalServer ts : transportServers) { transportServer.start(listener);
ts.start(listener);
activeTransportServers++;
}
executor = Preconditions.checkNotNull(executorPool.getObject(), "executor"); executor = Preconditions.checkNotNull(executorPool.getObject(), "executor");
started = true; started = true;
return this; return this;
@ -195,8 +189,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
synchronized (lock) { synchronized (lock) {
checkState(started, "Not started"); checkState(started, "Not started");
checkState(!terminated, "Already terminated"); checkState(!terminated, "Already terminated");
for (InternalServer ts : transportServers) { for (SocketAddress addr: transportServer.getListenSocketAddresses()) {
SocketAddress addr = ts.getListenSocketAddress();
if (addr instanceof InetSocketAddress) { if (addr instanceof InetSocketAddress) {
return ((InetSocketAddress) addr).getPort(); return ((InetSocketAddress) addr).getPort();
} }
@ -216,11 +209,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
private List<SocketAddress> getListenSocketsIgnoringLifecycle() { private List<SocketAddress> getListenSocketsIgnoringLifecycle() {
synchronized (lock) { synchronized (lock) {
List<SocketAddress> addrs = new ArrayList<>(transportServers.size()); return Collections.unmodifiableList(transportServer.getListenSocketAddresses());
for (InternalServer ts : transportServers) {
addrs.add(ts.getListenSocketAddress());
}
return Collections.unmodifiableList(addrs);
} }
} }
@ -268,9 +257,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
} }
} }
if (shutdownTransportServers) { if (shutdownTransportServers) {
for (InternalServer ts : transportServers) { transportServer.shutdown();
ts.shutdown();
}
} }
return this; return this;
} }
@ -388,8 +375,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
ArrayList<ServerTransport> copiedTransports; ArrayList<ServerTransport> copiedTransports;
Status shutdownNowStatusCopy; Status shutdownNowStatusCopy;
synchronized (lock) { synchronized (lock) {
activeTransportServers--; if (serverShutdownCallbackInvoked) {
if (activeTransportServers != 0) {
return; return;
} }
@ -662,12 +648,9 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
@Override @Override
public ListenableFuture<ServerStats> getStats() { public ListenableFuture<ServerStats> getStats() {
ServerStats.Builder builder = new ServerStats.Builder(); ServerStats.Builder builder = new ServerStats.Builder();
for (InternalServer ts : transportServers) { List<InternalInstrumented<SocketStats>> stats = transportServer.getListenSocketStatsList();
// TODO(carl-mastrangelo): remove the list and just add directly. if (stats != null ) {
InternalInstrumented<SocketStats> stats = ts.getListenSocketStats(); builder.addListenSockets(stats);
if (stats != null ) {
builder.addListenSockets(Collections.singletonList(stats));
}
} }
serverCallTracer.updateBuilder(builder); serverCallTracer.updateBuilder(builder);
SettableFuture<ServerStats> ret = SettableFuture.create(); SettableFuture<ServerStats> ret = SettableFuture.create();
@ -679,7 +662,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
public String toString() { public String toString() {
return MoreObjects.toStringHelper(this) return MoreObjects.toStringHelper(this)
.add("logId", logId.getId()) .add("logId", logId.getId())
.add("transportServers", transportServers) .add("transportServer", transportServer)
.toString(); .toString();
} }

View File

@ -97,7 +97,7 @@ public final class ServerImplBuilder extends ServerBuilder<ServerImplBuilder> {
* is meant for Transport implementors and should not be used by normal users. * is meant for Transport implementors and should not be used by normal users.
*/ */
public interface ClientTransportServersBuilder { public interface ClientTransportServersBuilder {
List<? extends InternalServer> buildClientTransportServers( InternalServer buildClientTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories); List<? extends ServerStreamTracer.Factory> streamTracerFactories);
} }

View File

@ -22,7 +22,6 @@ import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import com.google.common.collect.Iterables;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
@ -55,8 +54,8 @@ public class InProcessServerBuilderTest {
@Test @Test
public void scheduledExecutorService_default() { public void scheduledExecutorService_default() {
InProcessServerBuilder builder = InProcessServerBuilder.forName("foo"); InProcessServerBuilder builder = InProcessServerBuilder.forName("foo");
InProcessServer server = Iterables.getOnlyElement( InProcessServer server =
builder.buildTransportServers(new ArrayList<ServerStreamTracer.Factory>())); builder.buildTransportServers(new ArrayList<ServerStreamTracer.Factory>());
ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool = ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool =
server.getScheduledExecutorServicePool(); server.getScheduledExecutorServicePool();
@ -80,8 +79,8 @@ public class InProcessServerBuilderTest {
InProcessServerBuilder builder1 = builder.scheduledExecutorService(scheduledExecutorService); InProcessServerBuilder builder1 = builder.scheduledExecutorService(scheduledExecutorService);
assertSame(builder, builder1); assertSame(builder, builder1);
InProcessServer server = Iterables.getOnlyElement( InProcessServer server =
builder1.buildTransportServers(new ArrayList<ServerStreamTracer.Factory>())); builder1.buildTransportServers(new ArrayList<ServerStreamTracer.Factory>());
ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool = ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool =
server.getScheduledExecutorServicePool(); server.getScheduledExecutorServicePool();

View File

@ -19,7 +19,6 @@ package io.grpc.inprocess;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -55,16 +54,16 @@ public class InProcessTransportTest extends AbstractTransportTest {
public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) { List<ServerStreamTracer.Factory> streamTracerFactories) {
InProcessServerBuilder builder = InProcessServerBuilder InProcessServerBuilder builder = InProcessServerBuilder
.forName(TRANSPORT_NAME) .forName(TRANSPORT_NAME)
.maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); .maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE);
return ImmutableList.of(new InProcessServer(builder, streamTracerFactories)); return new InProcessServer(builder, streamTracerFactories);
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) { int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
return newServer(streamTracerFactories); return newServer(streamTracerFactories);
} }

View File

@ -16,7 +16,6 @@
package io.grpc.inprocess; package io.grpc.inprocess;
import com.google.common.collect.ImmutableList;
import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.SocketStats;
import io.grpc.InternalInstrumented; import io.grpc.InternalInstrumented;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
@ -31,6 +30,7 @@ import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.SharedResourcePool; import io.grpc.internal.SharedResourcePool;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -52,13 +52,13 @@ public final class StandaloneInProcessTransportTest extends AbstractTransportTes
private TestServer currentServer; private TestServer currentServer;
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) { List<ServerStreamTracer.Factory> streamTracerFactories) {
return ImmutableList.of(new TestServer(streamTracerFactories)); return new TestServer(streamTracerFactories);
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) { int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
return newServer(streamTracerFactories); return newServer(streamTracerFactories);
} }
@ -126,11 +126,22 @@ public final class StandaloneInProcessTransportTest extends AbstractTransportTes
return new SocketAddress() {}; return new SocketAddress() {};
} }
@Override
public List<SocketAddress> getListenSocketAddresses() {
return Collections.singletonList(getListenSocketAddress());
}
@Override @Override
@Nullable @Nullable
public InternalInstrumented<SocketStats> getListenSocketStats() { public InternalInstrumented<SocketStats> getListenSocketStats() {
return null; return null;
} }
@Override
@Nullable
public List<InternalInstrumented<SocketStats>> getListenSocketStatsList() {
return null;
}
} }
/** Wraps the server listener to ensure we don't accept new transports after shutdown. */ /** Wraps the server listener to ensure we don't accept new transports after shutdown. */

View File

@ -40,7 +40,6 @@ import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
@ -118,13 +117,13 @@ public abstract class AbstractTransportTest {
* Returns a new server that when started will be able to be connected to from the client. Each * Returns a new server that when started will be able to be connected to from the client. Each
* returned instance should be new and yet be accessible by new client transports. * returned instance should be new and yet be accessible by new client transports.
*/ */
protected abstract List<? extends InternalServer> newServer( protected abstract InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories); List<ServerStreamTracer.Factory> streamTracerFactories);
/** /**
* Builds a new server that is listening on the same port as the given server instance does. * Builds a new server that is listening on the same port as the given server instance does.
*/ */
protected abstract List<? extends InternalServer> newServer( protected abstract InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories); int port, List<ServerStreamTracer.Factory> streamTracerFactories);
/** /**
@ -230,7 +229,7 @@ public abstract class AbstractTransportTest {
@Before @Before
public void setUp() { public void setUp() {
server = Iterables.getOnlyElement(newServer(Arrays.asList(serverStreamTracerFactory))); server = newServer(Arrays.asList(serverStreamTracerFactory));
callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory);
} }
@ -401,8 +400,7 @@ public abstract class AbstractTransportTest {
if (addr instanceof InetSocketAddress) { if (addr instanceof InetSocketAddress) {
port = ((InetSocketAddress) addr).getPort(); port = ((InetSocketAddress) addr).getPort();
} }
InternalServer server2 = InternalServer server2 = newServer(port, Arrays.asList(serverStreamTracerFactory));
Iterables.getOnlyElement(newServer(port, Arrays.asList(serverStreamTracerFactory)));
thrown.expect(IOException.class); thrown.expect(IOException.class);
server2.start(new MockServerListener()); server2.start(new MockServerListener());
} }
@ -421,7 +419,7 @@ public abstract class AbstractTransportTest {
assumeTrue("transport is not using InetSocketAddress", port != -1); assumeTrue("transport is not using InetSocketAddress", port != -1);
server.shutdown(); server.shutdown();
server = Iterables.getOnlyElement(newServer(port, Arrays.asList(serverStreamTracerFactory))); server = newServer(port, Arrays.asList(serverStreamTracerFactory));
boolean success; boolean success;
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
try { try {
@ -473,7 +471,7 @@ public abstract class AbstractTransportTest {
// resources. There may be cases this is impossible in the future, but for now it is a useful // resources. There may be cases this is impossible in the future, but for now it is a useful
// property. // property.
serverListener = new MockServerListener(); serverListener = new MockServerListener();
server = Iterables.getOnlyElement(newServer(port, Arrays.asList(serverStreamTracerFactory))); server = newServer(port, Arrays.asList(serverStreamTracerFactory));
server.start(serverListener); server.start(serverListener);
// Try to "flush" out any listener notifications on client and server. This also ensures that // Try to "flush" out any listener notifications on client and server. This also ensures that

View File

@ -46,7 +46,7 @@ public class ServerImplBuilderTest {
builder = new ServerImplBuilder( builder = new ServerImplBuilder(
new ClientTransportServersBuilder() { new ClientTransportServersBuilder() {
@Override @Override
public List<? extends InternalServer> buildClientTransportServers( public InternalServer buildClientTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }

View File

@ -44,7 +44,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
@ -90,7 +89,6 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier; import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -205,7 +203,7 @@ public class ServerImplTest {
builder = new ServerImplBuilder( builder = new ServerImplBuilder(
new ClientTransportServersBuilder() { new ClientTransportServersBuilder() {
@Override @Override
public List<? extends InternalServer> buildClientTransportServers( public InternalServer buildClientTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -226,39 +224,19 @@ public class ServerImplTest {
} }
@Test @Test
public void multiport() throws Exception { public void getListenSockets() throws Exception {
final CountDownLatch starts = new CountDownLatch(2); int port = 800;
final CountDownLatch shutdowns = new CountDownLatch(2); final List<InetSocketAddress> addresses =
Collections.singletonList(new InetSocketAddress(800));
final class Serv extends SimpleServer { transportServer = new SimpleServer() {
@Override @Override
public void start(ServerListener listener) throws IOException { public List<InetSocketAddress> getListenSocketAddresses() {
super.start(listener); return addresses;
starts.countDown();
} }
};
@Override createAndStartServer();
public void shutdown() { assertEquals(port, server.getPort());
super.shutdown(); assertThat(server.getListenSockets()).isEqualTo(addresses);
shutdowns.countDown();
}
}
SimpleServer transportServer1 = new Serv();
SimpleServer transportServer2 = new Serv();
assertNull(server);
builder.fallbackHandlerRegistry(fallbackRegistry);
builder.executorPool = executorPool;
server = new ServerImpl(
builder, ImmutableList.of(transportServer1, transportServer2), SERVER_CONTEXT);
server.start();
assertTrue(starts.await(1, TimeUnit.SECONDS));
assertEquals(2, shutdowns.getCount());
server.shutdown();
assertTrue(shutdowns.await(1, TimeUnit.SECONDS));
assertTrue(server.awaitTermination(1, TimeUnit.SECONDS));
} }
@Test @Test
@ -1131,15 +1109,22 @@ public class ServerImplTest {
@Test @Test
public void getPort() throws Exception { public void getPort() throws Exception {
final InetSocketAddress addr = new InetSocketAddress(65535); final InetSocketAddress addr = new InetSocketAddress(65535);
final List<InetSocketAddress> addrs = Collections.singletonList(addr);
transportServer = new SimpleServer() { transportServer = new SimpleServer() {
@Override @Override
public SocketAddress getListenSocketAddress() { public InetSocketAddress getListenSocketAddress() {
return addr; return addr;
} }
@Override
public List<InetSocketAddress> getListenSocketAddresses() {
return addrs;
}
}; };
createAndStartServer(); createAndStartServer();
assertThat(server.getPort()).isEqualTo(addr.getPort()); assertThat(server.getPort()).isEqualTo(addr.getPort());
assertThat(server.getListenSockets()).isEqualTo(addrs);
} }
@Test @Test
@ -1431,7 +1416,7 @@ public class ServerImplTest {
builder.fallbackHandlerRegistry(fallbackRegistry); builder.fallbackHandlerRegistry(fallbackRegistry);
builder.executorPool = executorPool; builder.executorPool = executorPool;
server = new ServerImpl(builder, Collections.singletonList(transportServer), SERVER_CONTEXT); server = new ServerImpl(builder, transportServer, SERVER_CONTEXT);
} }
private void verifyExecutorsAcquired() { private void verifyExecutorsAcquired() {
@ -1469,11 +1454,21 @@ public class ServerImplTest {
return new InetSocketAddress(12345); return new InetSocketAddress(12345);
} }
@Override
public List<InetSocketAddress> getListenSocketAddresses() {
return Collections.singletonList(new InetSocketAddress(12345));
}
@Override @Override
public InternalInstrumented<SocketStats> getListenSocketStats() { public InternalInstrumented<SocketStats> getListenSocketStats() {
return null; return null;
} }
@Override
public List<InternalInstrumented<SocketStats>> getListenSocketStatsList() {
return null;
}
@Override @Override
public void shutdown() { public void shutdown() {
listener.serverShutdown(); listener.serverShutdown();

View File

@ -30,7 +30,7 @@ import java.util.List;
*/ */
@Internal @Internal
public final class InternalNettyServerBuilder { public final class InternalNettyServerBuilder {
public static List<NettyServer> buildTransportServers(NettyServerBuilder builder, public static NettyServer buildTransportServers(NettyServerBuilder builder,
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
return builder.buildTransportServers(streamTracerFactories); return builder.buildTransportServers(streamTracerFactories);
} }

View File

@ -45,17 +45,26 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel; import io.netty.channel.ServerChannel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.ChannelGroupFutureListener;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.AbstractReferenceCounted; import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Callable;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -66,7 +75,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
private static final Logger log = Logger.getLogger(InternalServer.class.getName()); private static final Logger log = Logger.getLogger(InternalServer.class.getName());
private final InternalLogId logId; private final InternalLogId logId;
private final SocketAddress address; private final List<? extends SocketAddress> addresses;
private final ChannelFactory<? extends ServerChannel> channelFactory; private final ChannelFactory<? extends ServerChannel> channelFactory;
private final Map<ChannelOption<?>, ?> channelOptions; private final Map<ChannelOption<?>, ?> channelOptions;
private final Map<ChannelOption<?>, ?> childChannelOptions; private final Map<ChannelOption<?>, ?> childChannelOptions;
@ -78,7 +87,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
private EventLoopGroup bossGroup; private EventLoopGroup bossGroup;
private EventLoopGroup workerGroup; private EventLoopGroup workerGroup;
private ServerListener listener; private ServerListener listener;
private Channel channel; private final ChannelGroup channelGroup;
private final boolean autoFlowControl; private final boolean autoFlowControl;
private final int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize; private final int maxMessageSize;
@ -96,11 +105,14 @@ class NettyServer implements InternalServer, InternalWithLogId {
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories; private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer.Factory transportTracerFactory; private final TransportTracer.Factory transportTracerFactory;
private final InternalChannelz channelz; private final InternalChannelz channelz;
// Only modified in event loop but safe to read any time. private volatile List<InternalInstrumented<SocketStats>> listenSocketStatsList =
private volatile InternalInstrumented<SocketStats> listenSocketStats; Collections.emptyList();
private volatile boolean terminated;
private final EventLoop bossExecutor;
NettyServer( NettyServer(
SocketAddress address, ChannelFactory<? extends ServerChannel> channelFactory, List<? extends SocketAddress> addresses,
ChannelFactory<? extends ServerChannel> channelFactory,
Map<ChannelOption<?>, ?> channelOptions, Map<ChannelOption<?>, ?> channelOptions,
Map<ChannelOption<?>, ?> childChannelOptions, Map<ChannelOption<?>, ?> childChannelOptions,
ObjectPool<? extends EventLoopGroup> bossGroupPool, ObjectPool<? extends EventLoopGroup> bossGroupPool,
@ -116,7 +128,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos, long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos,
boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos,
Attributes eagAttributes, InternalChannelz channelz) { Attributes eagAttributes, InternalChannelz channelz) {
this.address = address; this.addresses = checkNotNull(addresses, "addresses");
this.channelFactory = checkNotNull(channelFactory, "channelFactory"); this.channelFactory = checkNotNull(channelFactory, "channelFactory");
checkNotNull(channelOptions, "channelOptions"); checkNotNull(channelOptions, "channelOptions");
this.channelOptions = new HashMap<ChannelOption<?>, Object>(channelOptions); this.channelOptions = new HashMap<ChannelOption<?>, Object>(channelOptions);
@ -126,6 +138,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
this.workerGroupPool = checkNotNull(workerGroupPool, "workerGroupPool"); this.workerGroupPool = checkNotNull(workerGroupPool, "workerGroupPool");
this.forceHeapBuffer = forceHeapBuffer; this.forceHeapBuffer = forceHeapBuffer;
this.bossGroup = bossGroupPool.getObject(); this.bossGroup = bossGroupPool.getObject();
this.bossExecutor = bossGroup.next();
this.channelGroup = new DefaultChannelGroup(this.bossExecutor);
this.workerGroup = workerGroupPool.getObject(); this.workerGroup = workerGroupPool.getObject();
this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator"); this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator");
this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories");
@ -144,32 +158,53 @@ class NettyServer implements InternalServer, InternalWithLogId {
this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos; this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos;
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
this.channelz = Preconditions.checkNotNull(channelz); this.channelz = Preconditions.checkNotNull(channelz);
this.logId = this.logId = InternalLogId.allocate(getClass(), addresses.isEmpty() ? "No address" :
InternalLogId.allocate(getClass(), address != null ? address.toString() : "No address"); String.valueOf(addresses));
} }
@Override @Override
public SocketAddress getListenSocketAddress() { public SocketAddress getListenSocketAddress() {
if (channel == null) { Iterator<Channel> it = channelGroup.iterator();
if (it.hasNext()) {
return it.next().localAddress();
} else {
// server is not listening/bound yet, just return the original port. // server is not listening/bound yet, just return the original port.
return address; return addresses.isEmpty() ? null : addresses.get(0);
} }
return channel.localAddress(); }
@Override
public List<SocketAddress> getListenSocketAddresses() {
List<SocketAddress> listenSocketAddresses = new ArrayList<>();
for (Channel c: channelGroup) {
listenSocketAddresses.add(c.localAddress());
}
// server is not listening/bound yet, just return the original ports.
if (listenSocketAddresses.isEmpty()) {
listenSocketAddresses.addAll(addresses);
}
return listenSocketAddresses;
} }
@Override @Override
public InternalInstrumented<SocketStats> getListenSocketStats() { public InternalInstrumented<SocketStats> getListenSocketStats() {
return listenSocketStats; List<InternalInstrumented<SocketStats>> savedListenSocketStatsList = listenSocketStatsList;
return savedListenSocketStatsList.isEmpty() ? null : savedListenSocketStatsList.get(0);
}
@Override
public List<InternalInstrumented<SocketStats>> getListenSocketStatsList() {
return listenSocketStatsList;
} }
@Override @Override
public void start(ServerListener serverListener) throws IOException { public void start(ServerListener serverListener) throws IOException {
listener = checkNotNull(serverListener, "serverListener"); listener = checkNotNull(serverListener, "serverListener");
ServerBootstrap b = new ServerBootstrap(); final ServerBootstrap b = new ServerBootstrap();
b.option(ALLOCATOR, Utils.getByteBufAllocator(forceHeapBuffer)); b.option(ALLOCATOR, Utils.getByteBufAllocator(forceHeapBuffer));
b.childOption(ALLOCATOR, Utils.getByteBufAllocator(forceHeapBuffer)); b.childOption(ALLOCATOR, Utils.getByteBufAllocator(forceHeapBuffer));
b.group(bossGroup, workerGroup); b.group(bossExecutor, workerGroup);
b.channelFactory(channelFactory); b.channelFactory(channelFactory);
// For non-socket based channel, the option will be ignored. // For non-socket based channel, the option will be ignored.
b.childOption(SO_KEEPALIVE, true); b.childOption(SO_KEEPALIVE, true);
@ -226,8 +261,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
ServerTransportListener transportListener; ServerTransportListener transportListener;
// This is to order callbacks on the listener, not to guard access to channel. // This is to order callbacks on the listener, not to guard access to channel.
synchronized (NettyServer.this) { synchronized (NettyServer.this) {
if (channel != null && !channel.isOpen()) { if (terminated) {
// Server already shutdown. // Server already terminated.
ch.close(); ch.close();
return; return;
} }
@ -258,51 +293,77 @@ class NettyServer implements InternalServer, InternalWithLogId {
ch.closeFuture().addListener(loopReleaser); ch.closeFuture().addListener(loopReleaser);
} }
}); });
// Bind and start to accept incoming connections. Future<Map<ChannelFuture, SocketAddress>> bindCallFuture =
ChannelFuture future = b.bind(address); bossExecutor.submit(
// We'd love to observe interruption, but if interrupted we will need to close the channel, new Callable<Map<ChannelFuture, SocketAddress>>() {
// which itself would need an await() to guarantee the port is not used when the method returns. @Override
// See #6850 public Map<ChannelFuture, SocketAddress> call() {
future.awaitUninterruptibly(); Map<ChannelFuture, SocketAddress> bindFutures = new HashMap<>();
if (!future.isSuccess()) { for (SocketAddress address: addresses) {
throw new IOException(String.format("Failed to bind to address %s", address), future.cause()); ChannelFuture future = b.bind(address);
channelGroup.add(future.channel());
bindFutures.put(future, address);
}
return bindFutures;
}
}
);
Map<ChannelFuture, SocketAddress> channelFutures =
bindCallFuture.awaitUninterruptibly().getNow();
if (!bindCallFuture.isSuccess()) {
channelGroup.close().awaitUninterruptibly();
throw new IOException(String.format("Failed to bind to addresses %s",
addresses), bindCallFuture.cause());
} }
channel = future.channel(); final List<InternalInstrumented<SocketStats>> socketStats = new ArrayList<>();
channel.eventLoop().execute(new Runnable() { for (Map.Entry<ChannelFuture, SocketAddress> entry: channelFutures.entrySet()) {
@Override // We'd love to observe interruption, but if interrupted we will need to close the channel,
public void run() { // which itself would need an await() to guarantee the port is not used when the method
listenSocketStats = new ListenSocket(channel); // returns. See #6850
channelz.addListenSocket(listenSocketStats); final ChannelFuture future = entry.getKey();
if (!future.awaitUninterruptibly().isSuccess()) {
channelGroup.close().awaitUninterruptibly();
throw new IOException(String.format("Failed to bind to address %s",
entry.getValue()), future.cause());
} }
}); final InternalInstrumented<SocketStats> listenSocketStats =
new ListenSocket(future.channel());
channelz.addListenSocket(listenSocketStats);
socketStats.add(listenSocketStats);
future.channel().closeFuture().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
channelz.removeListenSocket(listenSocketStats);
}
});
}
listenSocketStatsList = Collections.unmodifiableList(socketStats);
} }
@Override @Override
public void shutdown() { public void shutdown() {
if (channel == null || !channel.isOpen()) { if (terminated) {
// Already closed.
return; return;
} }
channel.close().addListener(new ChannelFutureListener() { ChannelGroupFuture groupFuture = channelGroup.close()
@Override .addListener(new ChannelGroupFutureListener() {
public void operationComplete(ChannelFuture future) throws Exception { @Override
if (!future.isSuccess()) { public void operationComplete(ChannelGroupFuture future) throws Exception {
log.log(Level.WARNING, "Error shutting down server", future.cause()); if (!future.isSuccess()) {
} log.log(Level.WARNING, "Error closing server channel group", future.cause());
InternalInstrumented<SocketStats> stats = listenSocketStats; }
listenSocketStats = null; sharedResourceReferenceCounter.release();
if (stats != null) { protocolNegotiator.close();
channelz.removeListenSocket(stats); listenSocketStatsList = Collections.emptyList();
} synchronized (NettyServer.this) {
sharedResourceReferenceCounter.release(); listener.serverShutdown();
protocolNegotiator.close(); terminated = true;
synchronized (NettyServer.this) { }
listener.serverShutdown(); }
} });
}
});
try { try {
channel.closeFuture().await(); groupFuture.await();
} catch (InterruptedException e) { } catch (InterruptedException e) {
log.log(Level.FINE, "Interrupted while shutting down", e); log.log(Level.FINE, "Interrupted while shutting down", e);
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
@ -318,7 +379,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
public String toString() { public String toString() {
return MoreObjects.toStringHelper(this) return MoreObjects.toStringHelper(this)
.add("logId", logId.getId()) .add("logId", logId.getId())
.add("address", address) .add("addresses", addresses)
.toString(); .toString();
} }

View File

@ -54,7 +54,6 @@ import java.io.InputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -165,7 +164,7 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
private final class NettyClientTransportServersBuilder implements ClientTransportServersBuilder { private final class NettyClientTransportServersBuilder implements ClientTransportServersBuilder {
@Override @Override
public List<? extends InternalServer> buildClientTransportServers( public InternalServer buildClientTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
return buildTransportServers(streamTracerFactories); return buildTransportServers(streamTracerFactories);
} }
@ -623,27 +622,22 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
} }
@CheckReturnValue @CheckReturnValue
List<NettyServer> buildTransportServers( NettyServer buildTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) { List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
assertEventLoopsAndChannelType(); assertEventLoopsAndChannelType();
ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator( ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator(
this.serverImplBuilder.getExecutorPool()); this.serverImplBuilder.getExecutorPool());
List<NettyServer> transportServers = new ArrayList<>(listenAddresses.size()); return new NettyServer(
for (SocketAddress listenAddress : listenAddresses) { listenAddresses, channelFactory, channelOptions, childChannelOptions,
NettyServer transportServer = new NettyServer( bossEventLoopGroupPool, workerEventLoopGroupPool, forceHeapBuffer, negotiator,
listenAddress, channelFactory, channelOptions, childChannelOptions, streamTracerFactories, transportTracerFactory, maxConcurrentCallsPerConnection,
bossEventLoopGroupPool, workerEventLoopGroupPool, forceHeapBuffer, negotiator, autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize,
streamTracerFactories, transportTracerFactory, maxConcurrentCallsPerConnection, keepAliveTimeInNanos, keepAliveTimeoutInNanos,
autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, maxConnectionIdleInNanos, maxConnectionAgeInNanos,
keepAliveTimeInNanos, keepAliveTimeoutInNanos, maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos,
maxConnectionIdleInNanos, maxConnectionAgeInNanos, eagAttributes, this.serverImplBuilder.getChannelz());
maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos,
eagAttributes, this.serverImplBuilder.getChannelz());
transportServers.add(transportServer);
}
return Collections.unmodifiableList(transportServers);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -773,7 +773,7 @@ public class NettyClientTransportTest {
private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException { private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException {
server = new NettyServer( server = new NettyServer(
TestUtils.testServerAddress(new InetSocketAddress(0)), TestUtils.testServerAddresses(new InetSocketAddress(0)),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),

View File

@ -26,7 +26,6 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -46,12 +45,12 @@ public class NettyServerBuilderTest {
private NettyServerBuilder builder = NettyServerBuilder.forPort(8080); private NettyServerBuilder builder = NettyServerBuilder.forPort(8080);
@Test @Test
public void createMultipleServers() { public void addMultipleListenAddresses() {
builder.addListenAddress(new InetSocketAddress(8081)); builder.addListenAddress(new InetSocketAddress(8081));
List<NettyServer> servers = NettyServer server =
builder.buildTransportServers(ImmutableList.<ServerStreamTracer.Factory>of()); builder.buildTransportServers(ImmutableList.<ServerStreamTracer.Factory>of());
Truth.assertThat(servers).hasSize(2); Truth.assertThat(server.getListenSocketAddresses()).hasSize(2);
} }
@Test @Test

View File

@ -19,9 +19,18 @@ package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static io.grpc.InternalChannelz.id; import static io.grpc.InternalChannelz.id;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.InternalChannelz; import io.grpc.InternalChannelz;
@ -36,31 +45,63 @@ import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.TransportTracer; import io.grpc.internal.TransportTracer;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AsciiString; import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.After; import org.junit.After;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class NettyServerTest { public class NettyServerTest {
private final InternalChannelz channelz = new InternalChannelz(); private final InternalChannelz channelz = new InternalChannelz();
private final NioEventLoopGroup eventLoop = new NioEventLoopGroup(1); private final NioEventLoopGroup eventLoop = new NioEventLoopGroup(1);
private final ChannelFactory<NioServerSocketChannel> channelFactory =
new ReflectiveChannelFactory<>(NioServerSocketChannel.class);
@Mock
EventLoopGroup mockEventLoopGroup;
@Mock
EventLoop mockEventLoop;
@Mock
Future<Map<ChannelFuture, SocketAddress>> bindFuture;
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
when(mockEventLoopGroup.next()).thenReturn(mockEventLoop);
when(mockEventLoop
.submit(ArgumentMatchers.<Callable<Map<ChannelFuture, SocketAddress>>>any()))
.thenReturn(bindFuture);
}
@After @After
public void tearDown() throws Exception { public void tearDown() throws Exception {
@ -90,7 +131,7 @@ public class NettyServerTest {
NoHandlerProtocolNegotiator protocolNegotiator = new NoHandlerProtocolNegotiator(); NoHandlerProtocolNegotiator protocolNegotiator = new NoHandlerProtocolNegotiator();
NettyServer ns = new NettyServer( NettyServer ns = new NettyServer(
addr, Arrays.asList(addr),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
@ -134,11 +175,147 @@ public class NettyServerTest {
assertThat(protocolNegotiator.closed).isTrue(); assertThat(protocolNegotiator.closed).isTrue();
} }
@Test
public void multiPortStartStopGet() throws Exception {
InetSocketAddress addr1 = new InetSocketAddress(0);
InetSocketAddress addr2 = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
Arrays.asList(addr1, addr2),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
false,
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
1, // ignore
false, // ignore
1, // ignore
1, // ignore
1, // ignore
1, // ignore
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
ns.start(new ServerListener() {
@Override
public ServerTransportListener transportCreated(ServerTransport transport) {
return new NoopServerTransportListener();
}
@Override
public void serverShutdown() {
shutdownCompleted.set(null);
}
});
// SocketStats won't be available until the event loop task of adding SocketStats created by
// ns.start() complete. So submit a noop task and await until it's drained.
eventLoop.submit(new Runnable() {
@Override
public void run() {}
}).await(5, TimeUnit.SECONDS);
assertEquals(2, ns.getListenSocketAddresses().size());
for (SocketAddress address: ns.getListenSocketAddresses()) {
assertThat(((InetSocketAddress) address).getPort()).isGreaterThan(0);
}
List<InternalInstrumented<SocketStats>> stats = ns.getListenSocketStatsList();
assertEquals(2, ns.getListenSocketStatsList().size());
for (InternalInstrumented<SocketStats> listenSocket : stats) {
assertSame(listenSocket, channelz.getSocket(id(listenSocket)));
// very basic sanity check of the contents
SocketStats socketStats = listenSocket.getStats().get();
assertThat(ns.getListenSocketAddresses()).contains(socketStats.local);
assertNull(socketStats.remote);
}
// Cleanup
ns.shutdown();
shutdownCompleted.get();
// listen socket is removed
for (InternalInstrumented<SocketStats> listenSocket : stats) {
assertNull(channelz.getSocket(id(listenSocket)));
}
}
@Test(timeout = 60000)
public void multiPortConnections() throws Exception {
InetSocketAddress addr1 = new InetSocketAddress(0);
InetSocketAddress addr2 = new InetSocketAddress(0);
final CountDownLatch allPortsConnectedCountDown = new CountDownLatch(2);
NettyServer ns = new NettyServer(
Arrays.asList(addr1, addr2),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
false,
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
1, // ignore
false, // ignore
1, // ignore
1, // ignore
1, // ignore
1, // ignore
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
ns.start(new ServerListener() {
@Override
public ServerTransportListener transportCreated(ServerTransport transport) {
allPortsConnectedCountDown.countDown();
return new NoopServerTransportListener();
}
@Override
public void serverShutdown() {
shutdownCompleted.set(null);
}
});
// SocketStats won't be available until the event loop task of adding SocketStats created by
// ns.start() complete. So submit a noop task and await until it's drained.
eventLoop.submit(new Runnable() {
@Override
public void run() {}
}).await(5, TimeUnit.SECONDS);
List<SocketAddress> serverSockets = ns.getListenSocketAddresses();
assertEquals(2, serverSockets.size());
for (int i = 0; i < 2; i++) {
Socket socket = new Socket();
socket.connect(serverSockets.get(i), /* timeout= */ 8000);
socket.close();
}
allPortsConnectedCountDown.await();
// Cleanup
ns.shutdown();
shutdownCompleted.get();
}
@Test @Test
public void getPort_notStarted() { public void getPort_notStarted() {
InetSocketAddress addr = new InetSocketAddress(0); InetSocketAddress addr = new InetSocketAddress(0);
List<InetSocketAddress> addresses = Collections.singletonList(addr);
NettyServer ns = new NettyServer( NettyServer ns = new NettyServer(
addr, addresses,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
@ -161,6 +338,7 @@ public class NettyServerTest {
channelz); channelz);
assertThat(ns.getListenSocketAddress()).isEqualTo(addr); assertThat(ns.getListenSocketAddress()).isEqualTo(addr);
assertThat(ns.getListenSocketAddresses()).isEqualTo(addresses);
} }
@Test(timeout = 60000) @Test(timeout = 60000)
@ -211,7 +389,7 @@ public class NettyServerTest {
TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator(); TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator();
InetSocketAddress addr = new InetSocketAddress(0); InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer( NettyServer ns = new NettyServer(
addr, Arrays.asList(addr),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
childChannelOptions, childChannelOptions,
@ -258,7 +436,7 @@ public class NettyServerTest {
public void channelzListenSocket() throws Exception { public void channelzListenSocket() throws Exception {
InetSocketAddress addr = new InetSocketAddress(0); InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer( NettyServer ns = new NettyServer(
addr, Arrays.asList(addr),
new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(), new HashMap<ChannelOption<?>, Object>(),
@ -320,6 +498,110 @@ public class NettyServerTest {
assertNull(channelz.getSocket(id(listenSocket))); assertNull(channelz.getSocket(id(listenSocket)));
} }
@Test
@SuppressWarnings("unchecked")
public void testBindScheduleFailure() throws Exception {
when(bindFuture.awaitUninterruptibly()).thenReturn(bindFuture);
when(bindFuture.isSuccess()).thenReturn(false);
when(bindFuture.getNow()).thenReturn(null);
Throwable mockCause = mock(Throwable.class);
when(bindFuture.cause()).thenReturn(mockCause);
Future<Void> mockFuture = (Future<Void>) mock(Future.class);
doReturn(mockFuture).when(mockEventLoopGroup).submit(any(Runnable.class));
SocketAddress addr = new InetSocketAddress(0);
verifyServerNotStart(Collections.singletonList(addr), mockEventLoopGroup,
IOException.class, "Failed to bind to addresses " + Arrays.asList(addr));
}
@Test
@SuppressWarnings("unchecked")
public void testBindFailure() throws Exception {
when(bindFuture.awaitUninterruptibly()).thenReturn(bindFuture);
ChannelFuture future = mock(ChannelFuture.class);
when(future.awaitUninterruptibly()).thenReturn(future);
when(future.isSuccess()).thenReturn(false);
Channel channel = channelFactory.newChannel();
eventLoop.register(channel);
when(future.channel()).thenReturn(channel);
Throwable mockCause = mock(Throwable.class);
when(future.cause()).thenReturn(mockCause);
SocketAddress addr = new InetSocketAddress(0);
Map<ChannelFuture, SocketAddress> map = ImmutableMap.of(future, addr);
when(bindFuture.getNow()).thenReturn(map);
when(bindFuture.isSuccess()).thenReturn(true);
Future<Void> mockFuture = (Future<Void>) mock(Future.class);
doReturn(mockFuture).when(mockEventLoopGroup).submit(any(Runnable.class));
verifyServerNotStart(Collections.singletonList(addr), mockEventLoopGroup,
IOException.class, "Failed to bind to address " + addr);
}
@Test
public void testBindPartialFailure() throws Exception {
SocketAddress add1 = new InetSocketAddress(0);
SocketAddress add2 = new InetSocketAddress(2);
SocketAddress add3 = new InetSocketAddress(2);
verifyServerNotStart(ImmutableList.of(add1, add2, add3), eventLoop,
IOException.class, "Failed to bind to address " + add3);
}
private void verifyServerNotStart(List<SocketAddress> addr, EventLoopGroup ev,
Class<?> expectedException, String expectedMessage)
throws Exception {
NettyServer ns = getServer(addr, ev);
try {
ns.start(new ServerListener() {
@Override
public ServerTransportListener transportCreated(ServerTransport transport) {
return new NoopServerTransportListener();
}
@Override
public void serverShutdown() {
}
});
} catch (Exception ex) {
assertTrue(expectedException.isInstance(ex));
assertThat(ex.getMessage()).isEqualTo(expectedMessage);
assertFalse(addr.isEmpty());
// Listener tasks are executed on the event loop, so await until noop task is drained.
ev.submit(new Runnable() {
@Override
public void run() {}
}).await(5, TimeUnit.SECONDS);
assertThat(ns.getListenSocketAddress()).isEqualTo(addr.get(0));
assertThat(ns.getListenSocketAddresses()).isEqualTo(addr);
assertTrue(ns.getListenSocketStatsList().isEmpty());
assertNull(ns.getListenSocketStats());
return;
}
fail();
}
private NettyServer getServer(List<SocketAddress> addr, EventLoopGroup ev) {
return new NettyServer(
addr,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
new HashMap<ChannelOption<?>, Object>(),
new FixedObjectPool<>(ev),
new FixedObjectPool<>(ev),
false,
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
1, // ignore
false, // ignore
1, // ignore
1, // ignore
1, // ignore
1, // ignore
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
Attributes.EMPTY,
channelz);
}
private static class NoopServerTransportListener implements ServerTransportListener { private static class NoopServerTransportListener implements ServerTransportListener {
@Override public void streamCreated(ServerStream stream, String method, Metadata headers) {} @Override public void streamCreated(ServerStream stream, String method, Metadata headers) {}

View File

@ -63,7 +63,7 @@ public class NettyTransportTest extends AbstractTransportTest {
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) { List<ServerStreamTracer.Factory> streamTracerFactories) {
return NettyServerBuilder return NettyServerBuilder
.forAddress(new InetSocketAddress("localhost", 0)) .forAddress(new InetSocketAddress("localhost", 0))
@ -73,7 +73,7 @@ public class NettyTransportTest extends AbstractTransportTest {
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) { int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
return NettyServerBuilder return NettyServerBuilder
.forAddress(new InetSocketAddress("localhost", port)) .forAddress(new InetSocketAddress("localhost", port))

View File

@ -51,7 +51,7 @@ public class OkHttpTransportTest extends AbstractTransportTest {
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) { List<ServerStreamTracer.Factory> streamTracerFactories) {
NettyServerBuilder builder = NettyServerBuilder NettyServerBuilder builder = NettyServerBuilder
.forPort(0) .forPort(0)
@ -61,7 +61,7 @@ public class OkHttpTransportTest extends AbstractTransportTest {
} }
@Override @Override
protected List<? extends InternalServer> newServer( protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) { int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
NettyServerBuilder builder = NettyServerBuilder NettyServerBuilder builder = NettyServerBuilder
.forAddress(new InetSocketAddress(port)) .forAddress(new InetSocketAddress(port))

View File

@ -77,6 +77,24 @@ public class TestUtils {
} }
} }
/**
* Creates a new list of {@link InetSocketAddress} on localhost that overrides the host with
* {@link #TEST_SERVER_HOST}.
*/
public static List<InetSocketAddress> testServerAddresses(InetSocketAddress... originalSockAddr) {
try {
InetAddress inetAddress = InetAddress.getByName("localhost");
inetAddress = InetAddress.getByAddress(TEST_SERVER_HOST, inetAddress.getAddress());
List<InetSocketAddress> addresses = new ArrayList<>();
for (InetSocketAddress orig: originalSockAddr) {
addresses.add(new InetSocketAddress(inetAddress, orig.getPort()));
}
return addresses;
} catch (UnknownHostException e) {
throw new RuntimeException(e);
}
}
/** /**
* Returns the ciphers preferred to use during tests. They may be chosen because they are widely * Returns the ciphers preferred to use during tests. They may be chosen because they are widely
* available or because they are fast. There is no requirement that they provide confidentiality * available or because they are fast. There is no requirement that they provide confidentiality