mirror of https://github.com/grpc/grpc-java.git
core: ServerBuilder.intercept(). (#3118)
This adds server-wide interceptors that applies to all call handlers. Because ServerCallHandler is acquired per request, and can be dynamicly provided by the fallback registry, the interceptors have to be installed on a per-request basis. This adds a few object allocations per request, which is acceptable.
This commit is contained in:
parent
3dce2ee84b
commit
859d211b6e
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright 2017, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc;
|
||||
|
||||
/**
|
||||
* Accessor to internal methods of {@link ServerInterceptors}.
|
||||
*/
|
||||
@Internal
|
||||
public final class InternalServerInterceptors {
|
||||
public static <ReqT, RespT> ServerCallHandler<ReqT, RespT> interceptCallHandler(
|
||||
ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
|
||||
return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler);
|
||||
}
|
||||
|
||||
private InternalServerInterceptors() {
|
||||
}
|
||||
}
|
||||
|
|
@ -88,6 +88,20 @@ public abstract class ServerBuilder<T extends ServerBuilder<T>> {
|
|||
*/
|
||||
public abstract T addService(BindableService bindableService);
|
||||
|
||||
/**
|
||||
* Adds a {@link ServerInterceptor} that is run for all services on the server. Interceptors
|
||||
* added through this method always run before per-service interceptors added through {@link
|
||||
* ServerInterceptors}. Interceptors run in the reverse order in which they are added.
|
||||
*
|
||||
* @param interceptor the all-service interceptor
|
||||
* @return this
|
||||
* @since 1.5.0
|
||||
*/
|
||||
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/3117")
|
||||
public T intercept(ServerInterceptor interceptor) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a {@link ServerTransportFilter}. The order of filters being added is the order they will
|
||||
* be executed.
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ public final class ServerInterceptors {
|
|||
serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler));
|
||||
}
|
||||
|
||||
private static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
|
||||
static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
|
||||
public static <ReqT, RespT> InterceptCallHandler<ReqT, RespT> create(
|
||||
ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
|
||||
return new InterceptCallHandler<ReqT, RespT>(interceptor, callHandler);
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ import io.grpc.Internal;
|
|||
import io.grpc.InternalNotifyOnServerBuild;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.ServerMethodDefinition;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServerStreamTracer;
|
||||
|
|
@ -70,6 +71,9 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
|
|||
private final ArrayList<ServerTransportFilter> transportFilters =
|
||||
new ArrayList<ServerTransportFilter>();
|
||||
|
||||
private final ArrayList<ServerInterceptor> interceptors =
|
||||
new ArrayList<ServerInterceptor>();
|
||||
|
||||
private final List<InternalNotifyOnServerBuild> notifyOnBuildList =
|
||||
new ArrayList<InternalNotifyOnServerBuild>();
|
||||
|
||||
|
|
@ -122,6 +126,12 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
|
|||
return thisT();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final T intercept(ServerInterceptor interceptor) {
|
||||
interceptors.add(interceptor);
|
||||
return thisT();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
|
||||
streamTracerFactories.add(checkNotNull(factory, "factory"));
|
||||
|
|
@ -179,7 +189,7 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
|
|||
firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
|
||||
Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
|
||||
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
|
||||
transportFilters);
|
||||
transportFilters, interceptors);
|
||||
for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
|
||||
notifyTarget.notifyOnBuild(server);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,8 +32,11 @@ import io.grpc.Context;
|
|||
import io.grpc.Decompressor;
|
||||
import io.grpc.DecompressorRegistry;
|
||||
import io.grpc.HandlerRegistry;
|
||||
import io.grpc.InternalServerInterceptors;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.ServerMethodDefinition;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServerTransportFilter;
|
||||
|
|
@ -74,6 +77,9 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
|
|||
private final InternalHandlerRegistry registry;
|
||||
private final HandlerRegistry fallbackRegistry;
|
||||
private final List<ServerTransportFilter> transportFilters;
|
||||
// This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
|
||||
// creations.
|
||||
private final ServerInterceptor[] interceptors;
|
||||
@GuardedBy("lock") private boolean started;
|
||||
@GuardedBy("lock") private boolean shutdown;
|
||||
/** non-{@code null} if immediate shutdown has been requested. */
|
||||
|
|
@ -109,7 +115,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
|
|||
InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
|
||||
InternalServer transportServer, Context rootContext,
|
||||
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
|
||||
List<ServerTransportFilter> transportFilters) {
|
||||
List<ServerTransportFilter> transportFilters, List<ServerInterceptor> interceptors) {
|
||||
this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool");
|
||||
this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool");
|
||||
this.registry = Preconditions.checkNotNull(registry, "registry");
|
||||
|
|
@ -122,6 +128,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
|
|||
this.compressorRegistry = compressorRegistry;
|
||||
this.transportFilters = Collections.unmodifiableList(
|
||||
new ArrayList<ServerTransportFilter>(transportFilters));
|
||||
this.interceptors = interceptors.toArray(new ServerInterceptor[interceptors.size()]);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -469,9 +476,12 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
|
|||
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
|
||||
stream, methodDef.getMethodDescriptor(), headers, context,
|
||||
decompressorRegistry, compressorRegistry);
|
||||
ServerCallHandler<ReqT, RespT> callHandler = methodDef.getServerCallHandler();
|
||||
statsTraceCtx.serverCallStarted(call);
|
||||
ServerCall.Listener<ReqT> listener =
|
||||
methodDef.getServerCallHandler().startCall(call, headers);
|
||||
for (ServerInterceptor interceptor : interceptors) {
|
||||
callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler);
|
||||
}
|
||||
ServerCall.Listener<ReqT> listener = callHandler.startCall(call, headers);
|
||||
if (listener == null) {
|
||||
throw new NullPointerException(
|
||||
"startCall() returned a null listener for method " + fullMethodName);
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ import io.grpc.Metadata;
|
|||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServerStreamTracer;
|
||||
import io.grpc.ServerTransportFilter;
|
||||
|
|
@ -70,6 +71,8 @@ import java.io.IOException;
|
|||
import java.io.InputStream;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CyclicBarrier;
|
||||
import java.util.concurrent.Executor;
|
||||
|
|
@ -96,6 +99,13 @@ import org.mockito.MockitoAnnotations;
|
|||
public class ServerImplTest {
|
||||
private static final IntegerMarshaller INTEGER_MARSHALLER = IntegerMarshaller.INSTANCE;
|
||||
private static final StringMarshaller STRING_MARSHALLER = StringMarshaller.INSTANCE;
|
||||
private static final MethodDescriptor<String, Integer> METHOD =
|
||||
MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Waiter/serve")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
|
||||
private static final Context.Key<String> SERVER_TRACER_ADDED_KEY = Context.key("tracer-added");
|
||||
private static final Context.CancellableContext SERVER_CONTEXT =
|
||||
|
|
@ -402,16 +412,10 @@ public class ServerImplTest {
|
|||
final AtomicReference<ServerCall<String, Integer>> callReference
|
||||
= new AtomicReference<ServerCall<String, Integer>>();
|
||||
final AtomicReference<Context> callContextReference = new AtomicReference<Context>();
|
||||
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Waiter/serve")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
|
||||
new ServiceDescriptor("Waiter", method))
|
||||
new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(
|
||||
method,
|
||||
METHOD,
|
||||
new ServerCallHandler<String, Integer>() {
|
||||
@Override
|
||||
public ServerCall.Listener<String> startCall(
|
||||
|
|
@ -569,19 +573,96 @@ public class ServerImplTest {
|
|||
assertEquals(2, terminationCallbackCalled.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interceptors() throws Exception {
|
||||
final LinkedList<Context> capturedContexts = new LinkedList<Context>();
|
||||
final Context.Key<String> key1 = Context.key("key1");
|
||||
final Context.Key<String> key2 = Context.key("key2");
|
||||
final Context.Key<String> key3 = Context.key("key3");
|
||||
ServerInterceptor intercepter1 = new ServerInterceptor() {
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
ServerCall<ReqT, RespT> call,
|
||||
Metadata headers,
|
||||
ServerCallHandler<ReqT, RespT> next) {
|
||||
Context ctx = Context.current().withValue(key1, "value1");
|
||||
Context origCtx = ctx.attach();
|
||||
try {
|
||||
capturedContexts.add(ctx);
|
||||
return next.startCall(call, headers);
|
||||
} finally {
|
||||
ctx.detach(origCtx);
|
||||
}
|
||||
}
|
||||
};
|
||||
ServerInterceptor intercepter2 = new ServerInterceptor() {
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
ServerCall<ReqT, RespT> call,
|
||||
Metadata headers,
|
||||
ServerCallHandler<ReqT, RespT> next) {
|
||||
Context ctx = Context.current().withValue(key2, "value2");
|
||||
Context origCtx = ctx.attach();
|
||||
try {
|
||||
capturedContexts.add(ctx);
|
||||
return next.startCall(call, headers);
|
||||
} finally {
|
||||
ctx.detach(origCtx);
|
||||
}
|
||||
}
|
||||
};
|
||||
ServerCallHandler<String, Integer> callHandler = new ServerCallHandler<String, Integer>() {
|
||||
@Override
|
||||
public ServerCall.Listener<String> startCall(
|
||||
ServerCall<String, Integer> call,
|
||||
Metadata headers) {
|
||||
capturedContexts.add(Context.current().withValue(key3, "value3"));
|
||||
return callListener;
|
||||
}
|
||||
};
|
||||
|
||||
mutableFallbackRegistry.addService(
|
||||
ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(METHOD, callHandler).build());
|
||||
createServer(NO_FILTERS, Arrays.asList(intercepter2, intercepter1));
|
||||
server.start();
|
||||
|
||||
ServerTransportListener transportListener
|
||||
= transportServer.registerNewServerTransport(new SimpleServerTransport());
|
||||
|
||||
Metadata requestHeaders = new Metadata();
|
||||
StatsTraceContext statsTraceCtx =
|
||||
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
|
||||
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
|
||||
|
||||
transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
|
||||
assertEquals(1, executor.runDueTasks());
|
||||
|
||||
Context ctx1 = capturedContexts.poll();
|
||||
assertEquals("value1", key1.get(ctx1));
|
||||
assertNull(key2.get(ctx1));
|
||||
assertNull(key3.get(ctx1));
|
||||
|
||||
Context ctx2 = capturedContexts.poll();
|
||||
assertEquals("value1", key1.get(ctx2));
|
||||
assertEquals("value2", key2.get(ctx2));
|
||||
assertNull(key3.get(ctx2));
|
||||
|
||||
Context ctx3 = capturedContexts.poll();
|
||||
assertEquals("value1", key1.get(ctx3));
|
||||
assertEquals("value2", key2.get(ctx3));
|
||||
assertEquals("value3", key3.get(ctx3));
|
||||
|
||||
assertTrue(capturedContexts.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void exceptionInStartCallPropagatesToStream() throws Exception {
|
||||
createAndStartServer(NO_FILTERS);
|
||||
final Status status = Status.ABORTED.withDescription("Oh, no!");
|
||||
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Waiter/serve")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
|
||||
new ServiceDescriptor("Waiter", method))
|
||||
.addMethod(method,
|
||||
new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(METHOD,
|
||||
new ServerCallHandler<String, Integer>() {
|
||||
@Override
|
||||
public ServerCall.Listener<String> startCall(
|
||||
|
|
@ -695,20 +776,14 @@ public class ServerImplTest {
|
|||
@Test
|
||||
public void testCallContextIsBoundInListenerCallbacks() throws Exception {
|
||||
createAndStartServer(NO_FILTERS);
|
||||
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Waiter/serve")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
final AtomicBoolean onReadyCalled = new AtomicBoolean(false);
|
||||
final AtomicBoolean onMessageCalled = new AtomicBoolean(false);
|
||||
final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false);
|
||||
final AtomicBoolean onCancelCalled = new AtomicBoolean(false);
|
||||
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
|
||||
new ServiceDescriptor("Waiter", method))
|
||||
new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(
|
||||
method,
|
||||
METHOD,
|
||||
new ServerCallHandler<String, Integer>() {
|
||||
@Override
|
||||
public ServerCall.Listener<String> startCall(
|
||||
|
|
@ -809,16 +884,9 @@ public class ServerImplTest {
|
|||
}
|
||||
};
|
||||
|
||||
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Waiter/serve")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
|
||||
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
|
||||
new ServiceDescriptor("Waiter", method))
|
||||
.addMethod(method,
|
||||
new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(METHOD,
|
||||
new ServerCallHandler<String, Integer>() {
|
||||
@Override
|
||||
public ServerCall.Listener<String> startCall(
|
||||
|
|
@ -928,15 +996,9 @@ public class ServerImplTest {
|
|||
@Test
|
||||
public void handlerRegistryPriorities() throws Exception {
|
||||
fallbackRegistry = mock(HandlerRegistry.class);
|
||||
MethodDescriptor<String, Integer> method1 = MethodDescriptor.<String, Integer>newBuilder()
|
||||
.setType(MethodDescriptor.MethodType.UNKNOWN)
|
||||
.setFullMethodName("Service1/Method1")
|
||||
.setRequestMarshaller(STRING_MARSHALLER)
|
||||
.setResponseMarshaller(INTEGER_MARSHALLER)
|
||||
.build();
|
||||
registry = new InternalHandlerRegistry.Builder()
|
||||
.addService(ServerServiceDefinition.builder(new ServiceDescriptor("Service1", method1))
|
||||
.addMethod(method1, callHandler).build())
|
||||
.addService(ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
|
||||
.addMethod(METHOD, callHandler).build())
|
||||
.build();
|
||||
transportServer = new SimpleServer();
|
||||
createAndStartServer(NO_FILTERS);
|
||||
|
|
@ -945,11 +1007,11 @@ public class ServerImplTest {
|
|||
= transportServer.registerNewServerTransport(new SimpleServerTransport());
|
||||
Metadata requestHeaders = new Metadata();
|
||||
StatsTraceContext statsTraceCtx =
|
||||
StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
|
||||
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
|
||||
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
|
||||
|
||||
// This call will be handled by callHandler from the internal registry
|
||||
transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
|
||||
transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
|
||||
assertEquals(1, executor.runDueTasks());
|
||||
verify(callHandler).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
|
||||
Matchers.<Metadata>anyObject());
|
||||
|
|
@ -1109,9 +1171,15 @@ public class ServerImplTest {
|
|||
}
|
||||
|
||||
private void createServer(List<ServerTransportFilter> filters) {
|
||||
createServer(filters, Collections.<ServerInterceptor>emptyList());
|
||||
}
|
||||
|
||||
private void createServer(
|
||||
List<ServerTransportFilter> filters, List<ServerInterceptor> interceptors) {
|
||||
assertNull(server);
|
||||
server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry,
|
||||
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters);
|
||||
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters,
|
||||
interceptors);
|
||||
}
|
||||
|
||||
private void verifyExecutorsAcquired() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue