diff --git a/core/src/main/java/io/grpc/ClientInterceptors.java b/core/src/main/java/io/grpc/ClientInterceptors.java index 9ad38fedbd..635d1374bd 100644 --- a/core/src/main/java/io/grpc/ClientInterceptors.java +++ b/core/src/main/java/io/grpc/ClientInterceptors.java @@ -32,13 +32,11 @@ package io.grpc; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import java.util.Arrays; -import java.util.Iterator; import java.util.List; /** @@ -51,7 +49,8 @@ public class ClientInterceptors { /** * Create a new {@link Channel} that will call {@code interceptors} before starting a call on the - * given channel. + * given channel. The last interceptor will have its {@link ClientInterceptor#interceptCall} + * called first. * * @param channel the underlying channel to intercept. * @param interceptors array of interceptors to bind to {@code channel}. @@ -63,7 +62,8 @@ public class ClientInterceptors { /** * Create a new {@link Channel} that will call {@code interceptors} before starting a call on the - * given channel. + * given channel. The last interceptor will have its {@link ClientInterceptor#interceptCall} + * called first. * * @param channel the underlying channel to intercept. * @param interceptors a list of interceptors to bind to {@code channel}. @@ -71,51 +71,25 @@ public class ClientInterceptors { */ public static Channel intercept(Channel channel, List interceptors) { Preconditions.checkNotNull(channel); - if (interceptors.isEmpty()) { - return channel; + for (ClientInterceptor interceptor : interceptors) { + channel = new InterceptorChannel(channel, interceptor); } - return new InterceptorChannel(channel, interceptors); + return channel; } private static class InterceptorChannel extends Channel { private final Channel channel; - private final Iterable interceptors; + private final ClientInterceptor interceptor; - private InterceptorChannel(Channel channel, List interceptors) { + private InterceptorChannel(Channel channel, ClientInterceptor interceptor) { this.channel = channel; - this.interceptors = ImmutableList.copyOf(interceptors); + this.interceptor = Preconditions.checkNotNull(interceptor, "interceptor"); } @Override public ClientCall newCall( MethodDescriptor method, CallOptions callOptions) { - return new ProcessInterceptorChannel(channel, interceptors).newCall(method, callOptions); - } - } - - private static class ProcessInterceptorChannel extends Channel { - private final Channel channel; - private Iterator interceptors; - - private ProcessInterceptorChannel(Channel channel, Iterable interceptors) { - this.channel = channel; - this.interceptors = interceptors.iterator(); - } - - @Override - public ClientCall newCall( - MethodDescriptor method, CallOptions callOptions) { - if (interceptors != null && interceptors.hasNext()) { - return interceptors.next().interceptCall(method, callOptions, this); - } else { - Preconditions.checkState(interceptors != null, - "The channel has already been called. " - + "Some interceptor must have called on \"next\" twice."); - interceptors = null; - return channel.newCall( - Preconditions.checkNotNull(method, "method"), - Preconditions.checkNotNull(callOptions, "callOptions")); - } + return interceptor.interceptCall(method, callOptions, channel); } } diff --git a/core/src/main/java/io/grpc/ServerInterceptors.java b/core/src/main/java/io/grpc/ServerInterceptors.java index a8588fc99a..0965b96a30 100644 --- a/core/src/main/java/io/grpc/ServerInterceptors.java +++ b/core/src/main/java/io/grpc/ServerInterceptors.java @@ -32,13 +32,11 @@ package io.grpc; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; import java.util.Arrays; -import java.util.Iterator; import java.util.List; /** @@ -50,7 +48,8 @@ public class ServerInterceptors { /** * Create a new {@code ServerServiceDefinition} whose {@link ServerCallHandler}s will call - * {@code interceptors} before calling the pre-existing {@code ServerCallHandler}. + * {@code interceptors} before calling the pre-existing {@code ServerCallHandler}. The last + * interceptor will have its {@link ServerInterceptor#interceptCall} called first. * * @param serviceDef the service definition for which to intercept all its methods. * @param interceptors array of interceptors to apply to the service. @@ -63,7 +62,8 @@ public class ServerInterceptors { /** * Create a new {@code ServerServiceDefinition} whose {@link ServerCallHandler}s will call - * {@code interceptors} before calling the pre-existing {@code ServerCallHandler}. + * {@code interceptors} before calling the pre-existing {@code ServerCallHandler}. The last + * interceptor will have its {@link ServerInterceptor#interceptCall} called first. * * @param serviceDef the service definition for which to intercept all its methods. * @param interceptors list of interceptors to apply to the service. @@ -72,14 +72,13 @@ public class ServerInterceptors { public static ServerServiceDefinition intercept(ServerServiceDefinition serviceDef, List interceptors) { Preconditions.checkNotNull(serviceDef); - List immutableInterceptors = ImmutableList.copyOf(interceptors); - if (immutableInterceptors.isEmpty()) { + if (interceptors.isEmpty()) { return serviceDef; } ServerServiceDefinition.Builder serviceDefBuilder = ServerServiceDefinition.builder(serviceDef.getName()); for (ServerMethodDefinition method : serviceDef.getMethods()) { - wrapAndAddMethod(serviceDefBuilder, method, immutableInterceptors); + wrapAndAddMethod(serviceDefBuilder, method, interceptors); } return serviceDefBuilder.build(); } @@ -87,62 +86,32 @@ public class ServerInterceptors { private static void wrapAndAddMethod( ServerServiceDefinition.Builder serviceDefBuilder, ServerMethodDefinition method, List interceptors) { - ServerCallHandler callHandler - = InterceptCallHandler.create(interceptors, method.getServerCallHandler()); + ServerCallHandler callHandler = method.getServerCallHandler(); + for (ServerInterceptor interceptor : interceptors) { + callHandler = InterceptCallHandler.create(interceptor, callHandler); + } serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler)); } private static class InterceptCallHandler implements ServerCallHandler { public static InterceptCallHandler create( - List interceptors, ServerCallHandler callHandler) { - return new InterceptCallHandler(interceptors, callHandler); + ServerInterceptor interceptor, ServerCallHandler callHandler) { + return new InterceptCallHandler(interceptor, callHandler); } - private final List interceptors; + private final ServerInterceptor interceptor; private final ServerCallHandler callHandler; - private InterceptCallHandler(List interceptors, + private InterceptCallHandler(ServerInterceptor interceptor, ServerCallHandler callHandler) { - this.interceptors = interceptors; + this.interceptor = Preconditions.checkNotNull(interceptor, "interceptor"); this.callHandler = callHandler; } @Override public ServerCall.Listener startCall(String method, ServerCall call, Metadata.Headers headers) { - return ProcessInterceptorsCallHandler.create(interceptors.iterator(), callHandler) - .startCall(method, call, headers); - } - } - - private static class ProcessInterceptorsCallHandler - implements ServerCallHandler { - public static ProcessInterceptorsCallHandler create( - Iterator interceptors, ServerCallHandler callHandler) { - return new ProcessInterceptorsCallHandler(interceptors, callHandler); - } - - private Iterator interceptors; - private final ServerCallHandler callHandler; - - private ProcessInterceptorsCallHandler(Iterator interceptors, - ServerCallHandler callHandler) { - this.interceptors = interceptors; - this.callHandler = callHandler; - } - - @Override - public ServerCall.Listener startCall(String method, ServerCall call, - Metadata.Headers headers) { - if (interceptors != null && interceptors.hasNext()) { - return interceptors.next().interceptCall(method, call, headers, this); - } else { - Preconditions.checkState(interceptors != null, - "The call handler has already been called. " - + "Some interceptor must have called on \"next\" twice."); - interceptors = null; - return callHandler.startCall(method, call, headers); - } + return interceptor.interceptCall(method, call, headers, callHandler); } } diff --git a/core/src/test/java/io/grpc/ClientInterceptorsTest.java b/core/src/test/java/io/grpc/ClientInterceptorsTest.java index 84c4813634..abd0abdcef 100644 --- a/core/src/test/java/io/grpc/ClientInterceptorsTest.java +++ b/core/src/test/java/io/grpc/ClientInterceptorsTest.java @@ -139,7 +139,7 @@ public class ClientInterceptorsTest { verifyNoMoreInteractions(channel, interceptor); } - @Test(expected = IllegalStateException.class) + @Test public void callNextTwice() { ClientInterceptor interceptor = new ClientInterceptor() { @Override @@ -147,12 +147,15 @@ public class ClientInterceptorsTest { MethodDescriptor method, CallOptions callOptions, Channel next) { - next.newCall(method, callOptions); + // Calling next twice is permitted, although should only rarely be useful. + assertSame(call, next.newCall(method, callOptions)); return next.newCall(method, callOptions); } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); - intercepted.newCall(method, CallOptions.DEFAULT); + assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); + verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT)); + verifyNoMoreInteractions(channel); } @Test @@ -189,7 +192,7 @@ public class ClientInterceptorsTest { }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor1, interceptor2); assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); - assertEquals(Arrays.asList("i1", "i2", "channel"), order); + assertEquals(Arrays.asList("i2", "i1", "channel"), order); } @Test diff --git a/core/src/test/java/io/grpc/ServerInterceptorsTest.java b/core/src/test/java/io/grpc/ServerInterceptorsTest.java index 95d6eec4a1..80b27e9b95 100644 --- a/core/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/core/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -158,19 +158,23 @@ public class ServerInterceptorsTest { verifyNoMoreInteractions(handler2); } - @Test(expected = IllegalStateException.class) + @Test public void callNextTwice() { ServerInterceptor interceptor = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall(String method, ServerCall call, Headers headers, ServerCallHandler next) { - next.startCall(method, call, headers); + // Calling next twice is permitted, although should only rarely be useful. + assertSame(listener, next.startCall(method, call, headers)); return next.startCall(method, call, headers); } }; ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition, interceptor); - getSoleMethod(intercepted).getServerCallHandler().startCall(methodName, call, headers); + assertSame(listener, + getSoleMethod(intercepted).getServerCallHandler().startCall(methodName, call, headers)); + verify(handler, times(2)).startCall(same(methodName), same(call), same(headers)); + verifyNoMoreInteractions(handler); } @Test @@ -207,7 +211,7 @@ public class ServerInterceptorsTest { serviceDefinition, Arrays.asList(interceptor1, interceptor2)); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(methodName, call, headers)); - assertEquals(Arrays.asList("i1", "i2", "handler"), order); + assertEquals(Arrays.asList("i2", "i1", "handler"), order); } @Test