core: ServerImpl returns shared resources at termination (#2605)

Previously it does it at shutdown, which was wrong because executor may
still be used before the server is terminated.

Resolves #2034

Uses ObjectPool to make this change testable.  Cleans up test and makes
it mostly single-threaded, except for two deadlock tests that have to be
multi-threaded.
This commit is contained in:
Kun Zhang 2017-01-13 09:00:32 -08:00 committed by GitHub
parent ec7f00a272
commit c436d93f07
4 changed files with 170 additions and 137 deletions

View File

@ -165,7 +165,8 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
@Override @Override
public ServerImpl build() { public ServerImpl build() {
io.grpc.internal.InternalServer transportServer = buildTransportServer(); io.grpc.internal.InternalServer transportServer = buildTransportServer();
ServerImpl server = new ServerImpl(executor, registryBuilder.build(), ServerImpl server = new ServerImpl(getExecutorPool(),
SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE), registryBuilder.build(),
firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer, firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()), Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()), firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
@ -179,6 +180,24 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
return server; return server;
} }
private ObjectPool<? extends Executor> getExecutorPool() {
final Executor savedExecutor = executor;
if (savedExecutor == null) {
return SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR);
}
return new ObjectPool<Executor>() {
@Override
public Executor getObject() {
return savedExecutor;
}
@Override
public Executor returnObject(Object object) {
return null;
}
};
}
/** /**
* Children of AbstractServerBuilder should override this method to provide transport specific * Children of AbstractServerBuilder should override this method to provide transport specific
* information for the server. This method is mean for Transport implementors and should not be * information for the server. This method is mean for Transport implementors and should not be

View File

@ -31,6 +31,9 @@
package io.grpc.internal; package io.grpc.internal;
import javax.annotation.concurrent.ThreadSafe;
@ThreadSafe
public interface ObjectPool<T> { public interface ObjectPool<T> {
/** /**
* Get an object from the pool. * Get an object from the pool.

View File

@ -36,7 +36,6 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.grpc.Contexts.statusFromCancelled; import static io.grpc.Contexts.statusFromCancelled;
import static io.grpc.Status.DEADLINE_EXCEEDED; import static io.grpc.Status.DEADLINE_EXCEEDED;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
@ -64,7 +63,6 @@ import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -88,10 +86,9 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
private static final ServerStreamListener NOOP_LISTENER = new NoopListener(); private static final ServerStreamListener NOOP_LISTENER = new NoopListener();
private final LogId logId = LogId.allocate(getClass().getName()); private final LogId logId = LogId.allocate(getClass().getName());
private final ObjectPool<? extends Executor> executorPool;
/** Executor for application processing. Safe to read after {@link #start()}. */ /** Executor for application processing. Safe to read after {@link #start()}. */
private Executor executor; private Executor executor;
/** Safe to read after {@link #start()}. */
private boolean usingSharedExecutor;
private final InternalHandlerRegistry registry; private final InternalHandlerRegistry registry;
private final HandlerRegistry fallbackRegistry; private final HandlerRegistry fallbackRegistry;
private final List<ServerTransportFilter> transportFilters; private final List<ServerTransportFilter> transportFilters;
@ -111,7 +108,8 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
@GuardedBy("lock") private final Collection<ServerTransport> transports = @GuardedBy("lock") private final Collection<ServerTransport> transports =
new HashSet<ServerTransport>(); new HashSet<ServerTransport>();
private final ScheduledExecutorService timeoutService = SharedResourceHolder.get(TIMER_SERVICE); private final ObjectPool<ScheduledExecutorService> timeoutServicePool;
private ScheduledExecutorService timeoutService;
private final Context rootContext; private final Context rootContext;
private final DecompressorRegistry decompressorRegistry; private final DecompressorRegistry decompressorRegistry;
@ -126,12 +124,15 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
* doesn't have the method * doesn't have the method
* @param executor to call methods on behalf of remote clients * @param executor to call methods on behalf of remote clients
*/ */
ServerImpl(Executor executor, InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry, ServerImpl(ObjectPool<? extends Executor> executorPool,
ObjectPool<ScheduledExecutorService> timeoutServicePool,
InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
InternalServer transportServer, Context rootContext, InternalServer transportServer, Context rootContext,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry, DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
List<ServerTransportFilter> transportFilters, StatsContextFactory statsFactory, List<ServerTransportFilter> transportFilters, StatsContextFactory statsFactory,
Supplier<Stopwatch> stopwatchSupplier) { Supplier<Stopwatch> stopwatchSupplier) {
this.executor = executor; this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool");
this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool");
this.registry = Preconditions.checkNotNull(registry, "registry"); this.registry = Preconditions.checkNotNull(registry, "registry");
this.fallbackRegistry = Preconditions.checkNotNull(fallbackRegistry, "fallbackRegistry"); this.fallbackRegistry = Preconditions.checkNotNull(fallbackRegistry, "fallbackRegistry");
this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer"); this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer");
@ -158,12 +159,10 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
synchronized (lock) { synchronized (lock) {
checkState(!started, "Already started"); checkState(!started, "Already started");
checkState(!shutdown, "Shutting down"); checkState(!shutdown, "Shutting down");
usingSharedExecutor = executor == null;
if (usingSharedExecutor) {
executor = SharedResourceHolder.get(GrpcUtil.SHARED_CHANNEL_EXECUTOR);
}
// Start and wait for any port to actually be bound. // Start and wait for any port to actually be bound.
transportServer.start(new ServerListenerImpl()); transportServer.start(new ServerListenerImpl());
timeoutService = Preconditions.checkNotNull(timeoutServicePool.getObject(), "timeoutService");
executor = Preconditions.checkNotNull(executorPool.getObject(), "executor");
started = true; started = true;
return this; return this;
} }
@ -214,10 +213,6 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
if (shutdownTransportServer) { if (shutdownTransportServer) {
transportServer.shutdown(); transportServer.shutdown();
} }
SharedResourceHolder.release(TIMER_SERVICE, timeoutService);
if (usingSharedExecutor) {
SharedResourceHolder.release(GrpcUtil.SHARED_CHANNEL_EXECUTOR, (ExecutorService) executor);
}
return this; return this;
} }
@ -307,6 +302,12 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
throw new AssertionError("Server already terminated"); throw new AssertionError("Server already terminated");
} }
terminated = true; terminated = true;
if (timeoutService != null) {
timeoutService = timeoutServicePool.returnObject(timeoutService);
}
if (executor != null) {
executor = executorPool.returnObject(executor);
}
// TODO(carl-mastrangelo): move this outside the synchronized block. // TODO(carl-mastrangelo): move this outside the synchronized block.
lock.notifyAll(); lock.notifyAll();
} }

View File

@ -46,7 +46,7 @@ import static org.mockito.Matchers.notNull;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -99,13 +99,11 @@ import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.concurrent.BrokenBarrierException; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier; import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -133,13 +131,17 @@ public class ServerImplTest {
SERVER_CONTEXT.cancel(null); SERVER_CONTEXT.cancel(null);
} }
private ExecutorService executor = Executors.newSingleThreadExecutor(); private final FakeClock executor = new FakeClock();
private final FakeClock timer = new FakeClock();
@Mock
private ObjectPool<Executor> executorPool;
@Mock
private ObjectPool<ScheduledExecutorService> timerPool;
private InternalHandlerRegistry registry = new InternalHandlerRegistry.Builder().build(); private InternalHandlerRegistry registry = new InternalHandlerRegistry.Builder().build();
private MutableHandlerRegistry fallbackRegistry = new MutableHandlerRegistry(); private MutableHandlerRegistry mutableFallbackRegistry = new MutableHandlerRegistry();
private HandlerRegistry fallbackRegistry = mutableFallbackRegistry;
private SimpleServer transportServer = new SimpleServer(); private SimpleServer transportServer = new SimpleServer();
private ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, private ServerImpl server;
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
@Captor @Captor
private ArgumentCaptor<Status> statusCaptor; private ArgumentCaptor<Status> statusCaptor;
@ -157,14 +159,14 @@ public class ServerImplTest {
@Before @Before
public void startUp() throws IOException { public void startUp() throws IOException {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(executorPool.getObject()).thenReturn(executor.getScheduledExecutorService());
server.start(); when(timerPool.getObject()).thenReturn(timer.getScheduledExecutorService());
} }
/** Tear down after test. */
@After @After
public void tearDown() { public void noPendingTasks() {
executor.shutdownNow(); assertEquals(0, executor.numPendingTasks());
assertEquals(0, timer.numPendingTasks());
} }
@Test @Test
@ -173,10 +175,7 @@ public class ServerImplTest {
@Override @Override
public void shutdown() {} public void shutdown() {}
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertFalse(server.isTerminated()); assertFalse(server.isTerminated());
@ -185,27 +184,25 @@ public class ServerImplTest {
} }
@Test @Test
public void stopImmediate() { public void stopImmediate() throws IOException {
transportServer = new SimpleServer() { transportServer = new SimpleServer() {
@Override @Override
public void shutdown() { public void shutdown() {
throw new AssertionError("Should not be called, because wasn't started"); throw new AssertionError("Should not be called, because wasn't started");
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
verifyNoMoreInteractions(executorPool);
verifyNoMoreInteractions(timerPool);
} }
@Test @Test
public void startStopImmediateWithChildTransport() throws IOException { public void startStopImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory, verifyExecutorsAcquired();
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -221,16 +218,17 @@ public class ServerImplTest {
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertFalse(server.isTerminated()); assertFalse(server.isTerminated());
assertTrue(serverTransport.shutdown); assertTrue(serverTransport.shutdown);
verifyExecutorsNotReturned();
serverTransport.listener.transportTerminated(); serverTransport.listener.transportTerminated();
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
verifyExecutorsReturned();
} }
@Test @Test
public void startShutdownNowImmediateWithChildTransport() throws IOException { public void startShutdownNowImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory, verifyExecutorsAcquired();
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -249,16 +247,17 @@ public class ServerImplTest {
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertFalse(server.isTerminated()); assertFalse(server.isTerminated());
assertTrue(serverTransport.shutdown); assertTrue(serverTransport.shutdown);
verifyExecutorsNotReturned();
serverTransport.listener.transportTerminated(); serverTransport.listener.transportTerminated();
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
verifyExecutorsReturned();
} }
@Test @Test
public void shutdownNowAfterShutdown() throws IOException { public void shutdownNowAfterShutdown() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory, verifyExecutorsAcquired();
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -278,22 +277,23 @@ public class ServerImplTest {
server.shutdownNow(); server.shutdownNow();
assertFalse(server.isTerminated()); assertFalse(server.isTerminated());
assertTrue(serverTransport.shutdown); assertTrue(serverTransport.shutdown);
verifyExecutorsNotReturned();
serverTransport.listener.transportTerminated(); serverTransport.listener.transportTerminated();
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
verifyExecutorsReturned();
} }
@Test @Test
public void shutdownNowAfterSlowShutdown() throws IOException { public void shutdownNowAfterSlowShutdown() throws IOException {
SimpleServer transportServer = new SimpleServer() { transportServer = new SimpleServer() {
@Override @Override
public void shutdown() { public void shutdown() {
// Don't call super which calls listener.serverShutdown(). We'll call it manually. // Don't call super which calls listener.serverShutdown(). We'll call it manually.
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory, verifyExecutorsAcquired();
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -313,7 +313,10 @@ public class ServerImplTest {
transportServer.listener.serverShutdown(); transportServer.listener.serverShutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertFalse(server.isTerminated()); assertFalse(server.isTerminated());
verifyExecutorsNotReturned();
serverTransport.listener.transportTerminated(); serverTransport.listener.transportTerminated();
verifyExecutorsReturned();
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
} }
@ -327,19 +330,21 @@ public class ServerImplTest {
} }
} }
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer = new FailingStartupServer();
new FailingStartupServer(), SERVER_CONTEXT, decompressorRegistry, compressorRegistry, createServer(NO_FILTERS);
NO_FILTERS, statsCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
try { try {
server.start(); server.start();
fail("expected exception"); fail("expected exception");
} catch (IOException e) { } catch (IOException e) {
assertSame(ex, e); assertSame(ex, e);
} }
verifyNoMoreInteractions(executorPool);
verifyNoMoreInteractions(timerPool);
} }
@Test @Test
public void methodNotFound() throws Exception { public void methodNotFound() throws Exception {
createAndStartServer(NO_FILTERS);
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata(); Metadata requestHeaders = new Metadata();
@ -350,7 +355,7 @@ public class ServerImplTest {
verify(stream).setListener(isA(ServerStreamListener.class)); verify(stream).setListener(isA(ServerStreamListener.class));
verify(stream, atLeast(1)).statsTraceContext(); verify(stream, atLeast(1)).statsTraceContext();
executeBarrier(executor).await(); assertEquals(1, executor.runDueTasks());
verify(stream).close(statusCaptor.capture(), any(Metadata.class)); verify(stream).close(statusCaptor.capture(), any(Metadata.class));
Status status = statusCaptor.getValue(); Status status = statusCaptor.getValue();
assertEquals(Status.Code.UNIMPLEMENTED, status.getCode()); assertEquals(Status.Code.UNIMPLEMENTED, status.getCode());
@ -368,6 +373,7 @@ public class ServerImplTest {
@Test @Test
public void basicExchangeSuccessful() throws Exception { public void basicExchangeSuccessful() throws Exception {
createAndStartServer(NO_FILTERS);
final Metadata.Key<String> metadataKey final Metadata.Key<String> metadataKey
= Metadata.Key.of("inception", Metadata.ASCII_STRING_MARSHALLER); = Metadata.Key.of("inception", Metadata.ASCII_STRING_MARSHALLER);
final Metadata.Key<StatsContext> statsHeaderKey final Metadata.Key<StatsContext> statsHeaderKey
@ -376,7 +382,7 @@ public class ServerImplTest {
= new AtomicReference<ServerCall<String, Integer>>(); = new AtomicReference<ServerCall<String, Integer>>();
MethodDescriptor<String, Integer> method = MethodDescriptor.create( MethodDescriptor<String, Integer> method = MethodDescriptor.create(
MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER); MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER);
fallbackRegistry.addService(ServerServiceDefinition.builder( mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method)) new ServiceDescriptor("Waiter", method))
.addMethod( .addMethod(
method, method,
@ -412,13 +418,14 @@ public class ServerImplTest {
assertNotNull(streamListener); assertNotNull(streamListener);
verify(stream, atLeast(1)).statsTraceContext(); verify(stream, atLeast(1)).statsTraceContext();
executeBarrier(executor).await(); assertEquals(1, executor.runDueTasks());
ServerCall<String, Integer> call = callReference.get(); ServerCall<String, Integer> call = callReference.get();
assertNotNull(call); assertNotNull(call);
String order = "Lots of pizza, please"; String order = "Lots of pizza, please";
streamListener.messageRead(STRING_MARSHALLER.stream(order)); streamListener.messageRead(STRING_MARSHALLER.stream(order));
verify(callListener, timeout(2000)).onMessage(order); assertEquals(1, executor.runDueTasks());
verify(callListener).onMessage(order);
Metadata responseHeaders = new Metadata(); Metadata responseHeaders = new Metadata();
responseHeaders.put(metadataKey, "response value"); responseHeaders.put(metadataKey, "response value");
@ -433,7 +440,7 @@ public class ServerImplTest {
assertEquals(314, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue()); assertEquals(314, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue());
streamListener.halfClosed(); // All full; no dessert. streamListener.halfClosed(); // All full; no dessert.
executeBarrier(executor).await(); assertEquals(1, executor.runDueTasks());
verify(callListener).onHalfClose(); verify(callListener).onHalfClose();
call.sendMessage(50); call.sendMessage(50);
@ -448,7 +455,7 @@ public class ServerImplTest {
verify(stream).close(status, trailers); verify(stream).close(status, trailers);
streamListener.closed(Status.OK); streamListener.closed(Status.OK);
executeBarrier(executor).await(); assertEquals(1, executor.runDueTasks());
verify(callListener).onComplete(); verify(callListener).onComplete();
verify(stream, atLeast(1)).statsTraceContext(); verify(stream, atLeast(1)).statsTraceContext();
@ -540,10 +547,7 @@ public class ServerImplTest {
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddr) .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddr)
.build(); .build();
ServerImpl server = new ServerImpl(MoreExecutors.directExecutor(), registry, fallbackRegistry, createAndStartServer(ImmutableList.of(filter1, filter2));
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry,
ImmutableList.of(filter1, filter2), statsCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Attributes transportAttrs = transportListener.transportReady(Attributes.newBuilder() Attributes transportAttrs = transportListener.transportReady(Attributes.newBuilder()
@ -562,12 +566,12 @@ public class ServerImplTest {
@Test @Test
public void exceptionInStartCallPropagatesToStream() throws Exception { public void exceptionInStartCallPropagatesToStream() throws Exception {
CyclicBarrier barrier = executeBarrier(executor); createAndStartServer(NO_FILTERS);
final Status status = Status.ABORTED.withDescription("Oh, no!"); final Status status = Status.ABORTED.withDescription("Oh, no!");
MethodDescriptor<String, Integer> method = MethodDescriptor MethodDescriptor<String, Integer> method = MethodDescriptor
.create(MethodType.UNKNOWN, "Waiter/serve", .create(MethodType.UNKNOWN, "Waiter/serve",
STRING_MARSHALLER, INTEGER_MARSHALLER); STRING_MARSHALLER, INTEGER_MARSHALLER);
fallbackRegistry.addService(ServerServiceDefinition.builder( mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method)) new ServiceDescriptor("Waiter", method))
.addMethod(method, .addMethod(method,
new ServerCallHandler<String, Integer>() { new ServerCallHandler<String, Integer>() {
@ -594,8 +598,7 @@ public class ServerImplTest {
verify(stream, atLeast(1)).statsTraceContext(); verify(stream, atLeast(1)).statsTraceContext();
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
barrier.await(); assertEquals(1, executor.runDueTasks());
executeBarrier(executor).await();
verify(stream).close(same(status), notNull(Metadata.class)); verify(stream).close(same(status), notNull(Metadata.class));
verify(stream, atLeast(1)).statsTraceContext(); verify(stream, atLeast(1)).statsTraceContext();
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
@ -622,10 +625,7 @@ public class ServerImplTest {
} }
transportServer = new MaybeDeadlockingServer(); transportServer = new MaybeDeadlockingServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
new Thread() { new Thread() {
@Override @Override
public void run() { public void run() {
@ -645,6 +645,7 @@ public class ServerImplTest {
@Test @Test
public void testNoDeadlockOnTransportShutdown() throws Exception { public void testNoDeadlockOnTransportShutdown() throws Exception {
createAndStartServer(NO_FILTERS);
final Object lock = new Object(); final Object lock = new Object();
final CyclicBarrier barrier = new CyclicBarrier(2); final CyclicBarrier barrier = new CyclicBarrier(2);
class MaybeDeadlockingServerTransport extends SimpleServerTransport { class MaybeDeadlockingServerTransport extends SimpleServerTransport {
@ -685,13 +686,14 @@ public class ServerImplTest {
@Test @Test
public void testCallContextIsBoundInListenerCallbacks() throws Exception { public void testCallContextIsBoundInListenerCallbacks() throws Exception {
createAndStartServer(NO_FILTERS);
MethodDescriptor<String, Integer> method = MethodDescriptor.create( MethodDescriptor<String, Integer> method = MethodDescriptor.create(
MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER); MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER);
final CountDownLatch onReadyCalled = new CountDownLatch(1); final AtomicBoolean onReadyCalled = new AtomicBoolean(false);
final CountDownLatch onMessageCalled = new CountDownLatch(1); final AtomicBoolean onMessageCalled = new AtomicBoolean(false);
final CountDownLatch onHalfCloseCalled = new CountDownLatch(1); final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false);
final CountDownLatch onCancelCalled = new CountDownLatch(1); final AtomicBoolean onCancelCalled = new AtomicBoolean(false);
fallbackRegistry.addService(ServerServiceDefinition.builder( mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method)) new ServiceDescriptor("Waiter", method))
.addMethod( .addMethod(
method, method,
@ -710,25 +712,25 @@ public class ServerImplTest {
@Override @Override
public void onReady() { public void onReady() {
checkContext(); checkContext();
onReadyCalled.countDown(); onReadyCalled.set(true);
} }
@Override @Override
public void onMessage(String message) { public void onMessage(String message) {
checkContext(); checkContext();
onMessageCalled.countDown(); onMessageCalled.set(true);
} }
@Override @Override
public void onHalfClose() { public void onHalfClose() {
checkContext(); checkContext();
onHalfCloseCalled.countDown(); onHalfCloseCalled.set(true);
} }
@Override @Override
public void onCancel() { public void onCancel() {
checkContext(); checkContext();
onCancelCalled.countDown(); onCancelCalled.set(true);
} }
@Override @Override
@ -758,14 +760,20 @@ public class ServerImplTest {
assertNotNull(streamListener); assertNotNull(streamListener);
streamListener.onReady(); streamListener.onReady();
streamListener.messageRead(new ByteArrayInputStream(new byte[0])); assertEquals(1, executor.runDueTasks());
streamListener.halfClosed(); assertTrue(onReadyCalled.get());
streamListener.closed(Status.CANCELLED);
assertTrue(onReadyCalled.await(5, TimeUnit.SECONDS)); streamListener.messageRead(new ByteArrayInputStream(new byte[0]));
assertTrue(onMessageCalled.await(5, TimeUnit.SECONDS)); assertEquals(1, executor.runDueTasks());
assertTrue(onHalfCloseCalled.await(5, TimeUnit.SECONDS)); assertTrue(onMessageCalled.get());
assertTrue(onCancelCalled.await(5, TimeUnit.SECONDS));
streamListener.halfClosed();
assertEquals(1, executor.runDueTasks());
assertTrue(onHalfCloseCalled.get());
streamListener.closed(Status.CANCELLED);
assertEquals(1, executor.runDueTasks());
assertTrue(onCancelCalled.get());
// Close should never be called if asserts in listener pass. // Close should never be called if asserts in listener pass.
verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class)); verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class));
@ -773,14 +781,15 @@ public class ServerImplTest {
@Test @Test
public void testClientCancelTriggersContextCancellation() throws Exception { public void testClientCancelTriggersContextCancellation() throws Exception {
final CountDownLatch latch = new CountDownLatch(1); createAndStartServer(NO_FILTERS);
final AtomicBoolean contextCancelled = new AtomicBoolean(false);
callListener = new ServerCall.Listener<String>() { callListener = new ServerCall.Listener<String>() {
@Override @Override
public void onReady() { public void onReady() {
Context.current().addListener(new Context.CancellationListener() { Context.current().addListener(new Context.CancellationListener() {
@Override @Override
public void cancelled(Context context) { public void cancelled(Context context) {
latch.countDown(); contextCancelled.set(true);
} }
}, MoreExecutors.directExecutor()); }, MoreExecutors.directExecutor());
} }
@ -790,7 +799,7 @@ public class ServerImplTest {
= new AtomicReference<ServerCall<String, Integer>>(); = new AtomicReference<ServerCall<String, Integer>>();
MethodDescriptor<String, Integer> method = MethodDescriptor.create( MethodDescriptor<String, Integer> method = MethodDescriptor.create(
MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER); MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER);
fallbackRegistry.addService(ServerServiceDefinition.builder( mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method)) new ServiceDescriptor("Waiter", method))
.addMethod(method, .addMethod(method,
new ServerCallHandler<String, Integer>() { new ServerCallHandler<String, Integer>() {
@ -817,7 +826,8 @@ public class ServerImplTest {
streamListener.onReady(); streamListener.onReady();
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
assertTrue(latch.await(5, TimeUnit.SECONDS)); assertEquals(1, executor.runDueTasks());
assertTrue(contextCancelled.get());
} }
@Test @Test
@ -828,10 +838,7 @@ public class ServerImplTest {
return 65535; return 65535;
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
Truth.assertThat(server.getPort()).isEqualTo(65535); Truth.assertThat(server.getPort()).isEqualTo(65535);
} }
@ -839,9 +846,7 @@ public class ServerImplTest {
@Test @Test
public void getPortBeforeStartedFails() { public void getPortBeforeStartedFails() {
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
thrown.expect(IllegalStateException.class); thrown.expect(IllegalStateException.class);
thrown.expectMessage("started"); thrown.expectMessage("started");
server.getPort(); server.getPort();
@ -850,10 +855,7 @@ public class ServerImplTest {
@Test @Test
public void getPortAfterTerminationFails() throws Exception { public void getPortAfterTerminationFails() throws Exception {
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
server.shutdown(); server.shutdown();
server.awaitTermination(); server.awaitTermination();
thrown.expect(IllegalStateException.class); thrown.expect(IllegalStateException.class);
@ -863,7 +865,7 @@ public class ServerImplTest {
@Test @Test
public void handlerRegistryPriorities() throws Exception { public void handlerRegistryPriorities() throws Exception {
HandlerRegistry fallbackRegistry = mock(HandlerRegistry.class); fallbackRegistry = mock(HandlerRegistry.class);
MethodDescriptor<String, Integer> method1 = MethodDescriptor.create( MethodDescriptor<String, Integer> method1 = MethodDescriptor.create(
MethodType.UNKNOWN, "Service1/Method1", STRING_MARSHALLER, INTEGER_MARSHALLER); MethodType.UNKNOWN, "Service1/Method1", STRING_MARSHALLER, INTEGER_MARSHALLER);
registry = new InternalHandlerRegistry.Builder() registry = new InternalHandlerRegistry.Builder()
@ -871,10 +873,7 @@ public class ServerImplTest {
.addMethod(method1, callHandler).build()) .addMethod(method1, callHandler).build())
.build(); .build();
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, createAndStartServer(NO_FILTERS);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, statsCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start();
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
@ -886,37 +885,48 @@ public class ServerImplTest {
// This call will be handled by callHandler from the internal registry // This call will be handled by callHandler from the internal registry
transportListener.streamCreated(stream, "Service1/Method1", requestHeaders); transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
assertEquals(1, executor.runDueTasks());
verify(callHandler).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
Matchers.<Metadata>anyObject());
// This call will be handled by the fallbackRegistry because it's not registred in the internal // This call will be handled by the fallbackRegistry because it's not registred in the internal
// registry. // registry.
transportListener.streamCreated(stream, "Service1/Method2", requestHeaders); transportListener.streamCreated(stream, "Service1/Method2", requestHeaders);
assertEquals(1, executor.runDueTasks());
verify(fallbackRegistry).lookupMethod("Service1/Method2", null);
verify(callHandler, timeout(2000)).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
Matchers.<Metadata>anyObject());
verify(fallbackRegistry, timeout(2000)).lookupMethod("Service1/Method2", null);
verifyNoMoreInteractions(callHandler); verifyNoMoreInteractions(callHandler);
verifyNoMoreInteractions(fallbackRegistry); verifyNoMoreInteractions(fallbackRegistry);
} }
/** private void createAndStartServer(List<ServerTransportFilter> filters) throws IOException {
* Useful for plugging a single-threaded executor from processing tasks, or for waiting until a createServer(filters);
* single-threaded executor has processed queued tasks. server.start();
*/ }
private static CyclicBarrier executeBarrier(Executor executor) {
final CyclicBarrier barrier = new CyclicBarrier(2); private void createServer(List<ServerTransportFilter> filters) {
executor.execute(new Runnable() { assertNull(server);
@Override server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry,
public void run() { transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters,
try { statsCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
barrier.await(); }
} catch (InterruptedException ex) {
Thread.currentThread().interrupt(); private void verifyExecutorsAcquired() {
throw new RuntimeException(ex); verify(executorPool).getObject();
} catch (BrokenBarrierException ex) { verify(timerPool).getObject();
throw new RuntimeException(ex); verifyNoMoreInteractions(executorPool);
} verifyNoMoreInteractions(timerPool);
} }
});
return barrier; private void verifyExecutorsNotReturned() {
verify(executorPool, never()).returnObject(any(Executor.class));
verify(timerPool, never()).returnObject(any(ScheduledExecutorService.class));
}
private void verifyExecutorsReturned() {
verify(executorPool).returnObject(same(executor.getScheduledExecutorService()));
verify(timerPool).returnObject(same(timer.getScheduledExecutorService()));
verifyNoMoreInteractions(executorPool);
verifyNoMoreInteractions(timerPool);
} }
private static class SimpleServer implements io.grpc.internal.InternalServer { private static class SimpleServer implements io.grpc.internal.InternalServer {