core: ServerBuilder.intercept(). (#3118)

This adds server-wide interceptors that applies to all call handlers.
Because ServerCallHandler is acquired per request, and can be dynamicly
provided by the fallback registry, the interceptors have to be installed
on a per-request basis. This adds a few object allocations per request,
which is acceptable.
This commit is contained in:
Kun Zhang 2017-06-21 08:55:16 -07:00 committed by GitHub
parent 3dce2ee84b
commit 859d211b6e
6 changed files with 182 additions and 49 deletions

View File

@ -0,0 +1,31 @@
/*
* Copyright 2017, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc;
/**
* Accessor to internal methods of {@link ServerInterceptors}.
*/
@Internal
public final class InternalServerInterceptors {
public static <ReqT, RespT> ServerCallHandler<ReqT, RespT> interceptCallHandler(
ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler);
}
private InternalServerInterceptors() {
}
}

View File

@ -88,6 +88,20 @@ public abstract class ServerBuilder<T extends ServerBuilder<T>> {
*/
public abstract T addService(BindableService bindableService);
/**
* Adds a {@link ServerInterceptor} that is run for all services on the server. Interceptors
* added through this method always run before per-service interceptors added through {@link
* ServerInterceptors}. Interceptors run in the reverse order in which they are added.
*
* @param interceptor the all-service interceptor
* @return this
* @since 1.5.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/3117")
public T intercept(ServerInterceptor interceptor) {
throw new UnsupportedOperationException();
}
/**
* Adds a {@link ServerTransportFilter}. The order of filters being added is the order they will
* be executed.

View File

@ -207,7 +207,7 @@ public final class ServerInterceptors {
serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler));
}
private static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
public static <ReqT, RespT> InterceptCallHandler<ReqT, RespT> create(
ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
return new InterceptCallHandler<ReqT, RespT>(interceptor, callHandler);

View File

@ -33,6 +33,7 @@ import io.grpc.Internal;
import io.grpc.InternalNotifyOnServerBuild;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
@ -70,6 +71,9 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
private final ArrayList<ServerTransportFilter> transportFilters =
new ArrayList<ServerTransportFilter>();
private final ArrayList<ServerInterceptor> interceptors =
new ArrayList<ServerInterceptor>();
private final List<InternalNotifyOnServerBuild> notifyOnBuildList =
new ArrayList<InternalNotifyOnServerBuild>();
@ -122,6 +126,12 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
return thisT();
}
@Override
public final T intercept(ServerInterceptor interceptor) {
interceptors.add(interceptor);
return thisT();
}
@Override
public final T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
streamTracerFactories.add(checkNotNull(factory, "factory"));
@ -179,7 +189,7 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
transportFilters);
transportFilters, interceptors);
for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
notifyTarget.notifyOnBuild(server);
}

View File

@ -32,8 +32,11 @@ import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.InternalServerInterceptors;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerTransportFilter;
@ -74,6 +77,9 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
private final InternalHandlerRegistry registry;
private final HandlerRegistry fallbackRegistry;
private final List<ServerTransportFilter> transportFilters;
// This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
// creations.
private final ServerInterceptor[] interceptors;
@GuardedBy("lock") private boolean started;
@GuardedBy("lock") private boolean shutdown;
/** non-{@code null} if immediate shutdown has been requested. */
@ -109,7 +115,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
InternalServer transportServer, Context rootContext,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
List<ServerTransportFilter> transportFilters) {
List<ServerTransportFilter> transportFilters, List<ServerInterceptor> interceptors) {
this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool");
this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool");
this.registry = Preconditions.checkNotNull(registry, "registry");
@ -122,6 +128,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
this.compressorRegistry = compressorRegistry;
this.transportFilters = Collections.unmodifiableList(
new ArrayList<ServerTransportFilter>(transportFilters));
this.interceptors = interceptors.toArray(new ServerInterceptor[interceptors.size()]);
}
/**
@ -469,9 +476,12 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
stream, methodDef.getMethodDescriptor(), headers, context,
decompressorRegistry, compressorRegistry);
ServerCallHandler<ReqT, RespT> callHandler = methodDef.getServerCallHandler();
statsTraceCtx.serverCallStarted(call);
ServerCall.Listener<ReqT> listener =
methodDef.getServerCallHandler().startCall(call, headers);
for (ServerInterceptor interceptor : interceptors) {
callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler);
}
ServerCall.Listener<ReqT> listener = callHandler.startCall(call, headers);
if (listener == null) {
throw new NullPointerException(
"startCall() returned a null listener for method " + fullMethodName);

View File

@ -57,6 +57,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
import io.grpc.ServerTransportFilter;
@ -70,6 +71,8 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
@ -96,6 +99,13 @@ import org.mockito.MockitoAnnotations;
public class ServerImplTest {
private static final IntegerMarshaller INTEGER_MARSHALLER = IntegerMarshaller.INSTANCE;
private static final StringMarshaller STRING_MARSHALLER = StringMarshaller.INSTANCE;
private static final MethodDescriptor<String, Integer> METHOD =
MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
private static final Context.Key<String> SERVER_TRACER_ADDED_KEY = Context.key("tracer-added");
private static final Context.CancellableContext SERVER_CONTEXT =
@ -402,16 +412,10 @@ public class ServerImplTest {
final AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>();
final AtomicReference<Context> callContextReference = new AtomicReference<Context>();
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method))
new ServiceDescriptor("Waiter", METHOD))
.addMethod(
method,
METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@ -569,19 +573,96 @@ public class ServerImplTest {
assertEquals(2, terminationCallbackCalled.get());
}
@Test
public void interceptors() throws Exception {
final LinkedList<Context> capturedContexts = new LinkedList<Context>();
final Context.Key<String> key1 = Context.key("key1");
final Context.Key<String> key2 = Context.key("key2");
final Context.Key<String> key3 = Context.key("key3");
ServerInterceptor intercepter1 = new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
Context ctx = Context.current().withValue(key1, "value1");
Context origCtx = ctx.attach();
try {
capturedContexts.add(ctx);
return next.startCall(call, headers);
} finally {
ctx.detach(origCtx);
}
}
};
ServerInterceptor intercepter2 = new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
Context ctx = Context.current().withValue(key2, "value2");
Context origCtx = ctx.attach();
try {
capturedContexts.add(ctx);
return next.startCall(call, headers);
} finally {
ctx.detach(origCtx);
}
}
};
ServerCallHandler<String, Integer> callHandler = new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
ServerCall<String, Integer> call,
Metadata headers) {
capturedContexts.add(Context.current().withValue(key3, "value3"));
return callListener;
}
};
mutableFallbackRegistry.addService(
ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
.addMethod(METHOD, callHandler).build());
createServer(NO_FILTERS, Arrays.asList(intercepter2, intercepter1));
server.start();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertEquals(1, executor.runDueTasks());
Context ctx1 = capturedContexts.poll();
assertEquals("value1", key1.get(ctx1));
assertNull(key2.get(ctx1));
assertNull(key3.get(ctx1));
Context ctx2 = capturedContexts.poll();
assertEquals("value1", key1.get(ctx2));
assertEquals("value2", key2.get(ctx2));
assertNull(key3.get(ctx2));
Context ctx3 = capturedContexts.poll();
assertEquals("value1", key1.get(ctx3));
assertEquals("value2", key2.get(ctx3));
assertEquals("value3", key3.get(ctx3));
assertTrue(capturedContexts.isEmpty());
}
@Test
public void exceptionInStartCallPropagatesToStream() throws Exception {
createAndStartServer(NO_FILTERS);
final Status status = Status.ABORTED.withDescription("Oh, no!");
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method))
.addMethod(method,
new ServiceDescriptor("Waiter", METHOD))
.addMethod(METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@ -695,20 +776,14 @@ public class ServerImplTest {
@Test
public void testCallContextIsBoundInListenerCallbacks() throws Exception {
createAndStartServer(NO_FILTERS);
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
final AtomicBoolean onReadyCalled = new AtomicBoolean(false);
final AtomicBoolean onMessageCalled = new AtomicBoolean(false);
final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false);
final AtomicBoolean onCancelCalled = new AtomicBoolean(false);
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method))
new ServiceDescriptor("Waiter", METHOD))
.addMethod(
method,
METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@ -809,16 +884,9 @@ public class ServerImplTest {
}
};
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method))
.addMethod(method,
new ServiceDescriptor("Waiter", METHOD))
.addMethod(METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@ -928,15 +996,9 @@ public class ServerImplTest {
@Test
public void handlerRegistryPriorities() throws Exception {
fallbackRegistry = mock(HandlerRegistry.class);
MethodDescriptor<String, Integer> method1 = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Service1/Method1")
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
registry = new InternalHandlerRegistry.Builder()
.addService(ServerServiceDefinition.builder(new ServiceDescriptor("Service1", method1))
.addMethod(method1, callHandler).build())
.addService(ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
.addMethod(METHOD, callHandler).build())
.build();
transportServer = new SimpleServer();
createAndStartServer(NO_FILTERS);
@ -945,11 +1007,11 @@ public class ServerImplTest {
= transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
// This call will be handled by callHandler from the internal registry
transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertEquals(1, executor.runDueTasks());
verify(callHandler).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
Matchers.<Metadata>anyObject());
@ -1109,9 +1171,15 @@ public class ServerImplTest {
}
private void createServer(List<ServerTransportFilter> filters) {
createServer(filters, Collections.<ServerInterceptor>emptyList());
}
private void createServer(
List<ServerTransportFilter> filters, List<ServerInterceptor> interceptors) {
assertNull(server);
server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry,
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters);
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters,
interceptors);
}
private void verifyExecutorsAcquired() {