From 859d211b6ece5827fdde56c257e9da7e3f4d56c1 Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Wed, 21 Jun 2017 08:55:16 -0700 Subject: [PATCH] core: ServerBuilder.intercept(). (#3118) This adds server-wide interceptors that applies to all call handlers. Because ServerCallHandler is acquired per request, and can be dynamicly provided by the fallback registry, the interceptors have to be installed on a per-request basis. This adds a few object allocations per request, which is acceptable. --- .../io/grpc/InternalServerInterceptors.java | 31 ++++ core/src/main/java/io/grpc/ServerBuilder.java | 14 ++ .../main/java/io/grpc/ServerInterceptors.java | 2 +- .../internal/AbstractServerImplBuilder.java | 12 +- .../java/io/grpc/internal/ServerImpl.java | 16 +- .../java/io/grpc/internal/ServerImplTest.java | 156 +++++++++++++----- 6 files changed, 182 insertions(+), 49 deletions(-) create mode 100644 core/src/main/java/io/grpc/InternalServerInterceptors.java diff --git a/core/src/main/java/io/grpc/InternalServerInterceptors.java b/core/src/main/java/io/grpc/InternalServerInterceptors.java new file mode 100644 index 0000000000..e981aa6cfd --- /dev/null +++ b/core/src/main/java/io/grpc/InternalServerInterceptors.java @@ -0,0 +1,31 @@ +/* + * Copyright 2017, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Accessor to internal methods of {@link ServerInterceptors}. + */ +@Internal +public final class InternalServerInterceptors { + public static ServerCallHandler interceptCallHandler( + ServerInterceptor interceptor, ServerCallHandler callHandler) { + return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler); + } + + private InternalServerInterceptors() { + } +} diff --git a/core/src/main/java/io/grpc/ServerBuilder.java b/core/src/main/java/io/grpc/ServerBuilder.java index 53c851d4f0..87008ac2b1 100644 --- a/core/src/main/java/io/grpc/ServerBuilder.java +++ b/core/src/main/java/io/grpc/ServerBuilder.java @@ -88,6 +88,20 @@ public abstract class ServerBuilder> { */ public abstract T addService(BindableService bindableService); + /** + * Adds a {@link ServerInterceptor} that is run for all services on the server. Interceptors + * added through this method always run before per-service interceptors added through {@link + * ServerInterceptors}. Interceptors run in the reverse order in which they are added. + * + * @param interceptor the all-service interceptor + * @return this + * @since 1.5.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3117") + public T intercept(ServerInterceptor interceptor) { + throw new UnsupportedOperationException(); + } + /** * Adds a {@link ServerTransportFilter}. The order of filters being added is the order they will * be executed. diff --git a/core/src/main/java/io/grpc/ServerInterceptors.java b/core/src/main/java/io/grpc/ServerInterceptors.java index 1c3cd5239a..7917a7d7d6 100644 --- a/core/src/main/java/io/grpc/ServerInterceptors.java +++ b/core/src/main/java/io/grpc/ServerInterceptors.java @@ -207,7 +207,7 @@ public final class ServerInterceptors { serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler)); } - private static class InterceptCallHandler implements ServerCallHandler { + static class InterceptCallHandler implements ServerCallHandler { public static InterceptCallHandler create( ServerInterceptor interceptor, ServerCallHandler callHandler) { return new InterceptCallHandler(interceptor, callHandler); diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 707f6f400f..b1bc50c8a2 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -33,6 +33,7 @@ import io.grpc.Internal; import io.grpc.InternalNotifyOnServerBuild; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.ServerInterceptor; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; @@ -70,6 +71,9 @@ public abstract class AbstractServerImplBuilder transportFilters = new ArrayList(); + private final ArrayList interceptors = + new ArrayList(); + private final List notifyOnBuildList = new ArrayList(); @@ -122,6 +126,12 @@ public abstract class AbstractServerImplBuilder transportFilters; + // This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator + // creations. + private final ServerInterceptor[] interceptors; @GuardedBy("lock") private boolean started; @GuardedBy("lock") private boolean shutdown; /** non-{@code null} if immediate shutdown has been requested. */ @@ -109,7 +115,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry, InternalServer transportServer, Context rootContext, DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry, - List transportFilters) { + List transportFilters, List interceptors) { this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool"); this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool"); this.registry = Preconditions.checkNotNull(registry, "registry"); @@ -122,6 +128,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { this.compressorRegistry = compressorRegistry; this.transportFilters = Collections.unmodifiableList( new ArrayList(transportFilters)); + this.interceptors = interceptors.toArray(new ServerInterceptor[interceptors.size()]); } /** @@ -469,9 +476,12 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { ServerCallImpl call = new ServerCallImpl( stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry, compressorRegistry); + ServerCallHandler callHandler = methodDef.getServerCallHandler(); statsTraceCtx.serverCallStarted(call); - ServerCall.Listener listener = - methodDef.getServerCallHandler().startCall(call, headers); + for (ServerInterceptor interceptor : interceptors) { + callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler); + } + ServerCall.Listener listener = callHandler.startCall(call, headers); if (listener == null) { throw new NullPointerException( "startCall() returned a null listener for method " + fullMethodName); diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index b95c7edc54..ce285916ea 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -57,6 +57,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerTransportFilter; @@ -70,6 +71,8 @@ import java.io.IOException; import java.io.InputStream; import java.net.SocketAddress; import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; @@ -96,6 +99,13 @@ import org.mockito.MockitoAnnotations; public class ServerImplTest { private static final IntegerMarshaller INTEGER_MARSHALLER = IntegerMarshaller.INSTANCE; private static final StringMarshaller STRING_MARSHALLER = StringMarshaller.INSTANCE; + private static final MethodDescriptor METHOD = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setFullMethodName("Waiter/serve") + .setRequestMarshaller(STRING_MARSHALLER) + .setResponseMarshaller(INTEGER_MARSHALLER) + .build(); private static final Context.Key SERVER_ONLY = Context.key("serverOnly"); private static final Context.Key SERVER_TRACER_ADDED_KEY = Context.key("tracer-added"); private static final Context.CancellableContext SERVER_CONTEXT = @@ -402,16 +412,10 @@ public class ServerImplTest { final AtomicReference> callReference = new AtomicReference>(); final AtomicReference callContextReference = new AtomicReference(); - MethodDescriptor method = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("Waiter/serve") - .setRequestMarshaller(STRING_MARSHALLER) - .setResponseMarshaller(INTEGER_MARSHALLER) - .build(); mutableFallbackRegistry.addService(ServerServiceDefinition.builder( - new ServiceDescriptor("Waiter", method)) + new ServiceDescriptor("Waiter", METHOD)) .addMethod( - method, + METHOD, new ServerCallHandler() { @Override public ServerCall.Listener startCall( @@ -569,19 +573,96 @@ public class ServerImplTest { assertEquals(2, terminationCallbackCalled.get()); } + @Test + public void interceptors() throws Exception { + final LinkedList capturedContexts = new LinkedList(); + final Context.Key key1 = Context.key("key1"); + final Context.Key key2 = Context.key("key2"); + final Context.Key key3 = Context.key("key3"); + ServerInterceptor intercepter1 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + Context ctx = Context.current().withValue(key1, "value1"); + Context origCtx = ctx.attach(); + try { + capturedContexts.add(ctx); + return next.startCall(call, headers); + } finally { + ctx.detach(origCtx); + } + } + }; + ServerInterceptor intercepter2 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + Context ctx = Context.current().withValue(key2, "value2"); + Context origCtx = ctx.attach(); + try { + capturedContexts.add(ctx); + return next.startCall(call, headers); + } finally { + ctx.detach(origCtx); + } + } + }; + ServerCallHandler callHandler = new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, + Metadata headers) { + capturedContexts.add(Context.current().withValue(key3, "value3")); + return callListener; + } + }; + + mutableFallbackRegistry.addService( + ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, callHandler).build()); + createServer(NO_FILTERS, Arrays.asList(intercepter2, intercepter1)); + server.start(); + + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + + transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); + assertEquals(1, executor.runDueTasks()); + + Context ctx1 = capturedContexts.poll(); + assertEquals("value1", key1.get(ctx1)); + assertNull(key2.get(ctx1)); + assertNull(key3.get(ctx1)); + + Context ctx2 = capturedContexts.poll(); + assertEquals("value1", key1.get(ctx2)); + assertEquals("value2", key2.get(ctx2)); + assertNull(key3.get(ctx2)); + + Context ctx3 = capturedContexts.poll(); + assertEquals("value1", key1.get(ctx3)); + assertEquals("value2", key2.get(ctx3)); + assertEquals("value3", key3.get(ctx3)); + + assertTrue(capturedContexts.isEmpty()); + } + @Test public void exceptionInStartCallPropagatesToStream() throws Exception { createAndStartServer(NO_FILTERS); final Status status = Status.ABORTED.withDescription("Oh, no!"); - MethodDescriptor method = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("Waiter/serve") - .setRequestMarshaller(STRING_MARSHALLER) - .setResponseMarshaller(INTEGER_MARSHALLER) - .build(); mutableFallbackRegistry.addService(ServerServiceDefinition.builder( - new ServiceDescriptor("Waiter", method)) - .addMethod(method, + new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, new ServerCallHandler() { @Override public ServerCall.Listener startCall( @@ -695,20 +776,14 @@ public class ServerImplTest { @Test public void testCallContextIsBoundInListenerCallbacks() throws Exception { createAndStartServer(NO_FILTERS); - MethodDescriptor method = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("Waiter/serve") - .setRequestMarshaller(STRING_MARSHALLER) - .setResponseMarshaller(INTEGER_MARSHALLER) - .build(); final AtomicBoolean onReadyCalled = new AtomicBoolean(false); final AtomicBoolean onMessageCalled = new AtomicBoolean(false); final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false); final AtomicBoolean onCancelCalled = new AtomicBoolean(false); mutableFallbackRegistry.addService(ServerServiceDefinition.builder( - new ServiceDescriptor("Waiter", method)) + new ServiceDescriptor("Waiter", METHOD)) .addMethod( - method, + METHOD, new ServerCallHandler() { @Override public ServerCall.Listener startCall( @@ -809,16 +884,9 @@ public class ServerImplTest { } }; - MethodDescriptor method = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("Waiter/serve") - .setRequestMarshaller(STRING_MARSHALLER) - .setResponseMarshaller(INTEGER_MARSHALLER) - .build(); - mutableFallbackRegistry.addService(ServerServiceDefinition.builder( - new ServiceDescriptor("Waiter", method)) - .addMethod(method, + new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, new ServerCallHandler() { @Override public ServerCall.Listener startCall( @@ -928,15 +996,9 @@ public class ServerImplTest { @Test public void handlerRegistryPriorities() throws Exception { fallbackRegistry = mock(HandlerRegistry.class); - MethodDescriptor method1 = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("Service1/Method1") - .setRequestMarshaller(STRING_MARSHALLER) - .setResponseMarshaller(INTEGER_MARSHALLER) - .build(); registry = new InternalHandlerRegistry.Builder() - .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Service1", method1)) - .addMethod(method1, callHandler).build()) + .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD)) + .addMethod(METHOD, callHandler).build()) .build(); transportServer = new SimpleServer(); createAndStartServer(NO_FILTERS); @@ -945,11 +1007,11 @@ public class ServerImplTest { = transportServer.registerNewServerTransport(new SimpleServerTransport()); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = - StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders); + StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders); when(stream.statsTraceContext()).thenReturn(statsTraceCtx); // This call will be handled by callHandler from the internal registry - transportListener.streamCreated(stream, "Service1/Method1", requestHeaders); + transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); assertEquals(1, executor.runDueTasks()); verify(callHandler).startCall(Matchers.>anyObject(), Matchers.anyObject()); @@ -1109,9 +1171,15 @@ public class ServerImplTest { } private void createServer(List filters) { + createServer(filters, Collections.emptyList()); + } + + private void createServer( + List filters, List interceptors) { assertNull(server); server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry, - transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters); + transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters, + interceptors); } private void verifyExecutorsAcquired() {