diff --git a/core/src/main/java/io/grpc/CompressorRegistry.java b/core/src/main/java/io/grpc/CompressorRegistry.java index 0eda388c7e..65fe540fbd 100644 --- a/core/src/main/java/io/grpc/CompressorRegistry.java +++ b/core/src/main/java/io/grpc/CompressorRegistry.java @@ -31,8 +31,11 @@ package io.grpc; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.common.annotations.VisibleForTesting; + import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -70,4 +73,15 @@ public final class CompressorRegistry { public Compressor lookupCompressor(String compressorName) { return compressors.get(compressorName); } + + /** + * Registers a compressor for both decompression and message encoding negotiation. + * + * @param c The compressor to register + */ + public void register(Compressor c) { + String encoding = c.getMessageEncoding(); + checkArgument(!encoding.contains(","), "Comma is currently not allowed in message encoding"); + compressors.put(encoding, c); + } } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index c50e880911..95675edc80 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -31,6 +31,7 @@ package io.grpc.inprocess; +import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import io.grpc.Metadata; @@ -326,11 +327,13 @@ class InProcessTransport implements ServerTransport, ClientTransport { @Override public void setMessageCompression(boolean enable) { - // noop + // noop } @Override - public void pickCompressor(Iterable messageEncodings) {} + public Compressor pickCompressor(Iterable messageEncodings) { + return null; + } @Override public void setCompressionRegistry(CompressorRegistry registry) {} @@ -447,7 +450,9 @@ class InProcessTransport implements ServerTransport, ClientTransport { public void setMessageCompression(boolean enable) {} @Override - public void pickCompressor(Iterable messageEncodings) {} + public Compressor pickCompressor(Iterable messageEncodings) { + return null; + } @Override public void setCompressionRegistry(CompressorRegistry registry) {} diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index 242ce0a736..2381d3a59f 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -32,10 +32,14 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; +import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import com.google.common.base.Preconditions; +import io.grpc.Compressor; import io.grpc.Metadata; import io.grpc.Status; @@ -57,6 +61,7 @@ public abstract class AbstractServerStream extends AbstractStream private ServerStreamListener listener; private boolean headersSent = false; + private String messageEncoding; /** * Whether the stream was closed gracefully by the application (vs. a transport-level failure). */ @@ -95,6 +100,17 @@ public abstract class AbstractServerStream extends AbstractStream @Override public final void writeHeaders(Metadata headers) { Preconditions.checkNotNull(headers, "headers"); + headers.removeAll(MESSAGE_ENCODING_KEY); + if (messageEncoding != null) { + headers.put(MESSAGE_ENCODING_KEY, messageEncoding); + } + headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); + if (!decompressorRegistry().getAdvertisedMessageEncodings().isEmpty()) { + String acceptEncoding = + ACCEPT_ENCODING_JOINER.join(decompressorRegistry().getAdvertisedMessageEncodings()); + headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding); + } + outboundPhase(Phase.HEADERS); headersSent = true; internalSendHeaders(headers); @@ -148,9 +164,20 @@ public abstract class AbstractServerStream extends AbstractStream return; } } + // This checks to see if the client will accept any encoding. If so, a compressor is picked for + // the stream, and the decision is recorded. When the Server Call Handler writes the first + // headers, the negotiated encoding will be added in #writeHeaders(). It is safe to call + // pickCompressor multiple times before the headers have been written to the wire, though in + // practice this should never happen. There should only be one call to inboundHeadersReceived. + + // Alternatively, compression could be negotiated after the server handler is invoked, but that + // would mean the inbound header would have to be stored until the first #writeHeaders call. if (headers.containsKey(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)) { - pickCompressor( + Compressor c = pickCompressor( ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY))); + if (c != null) { + messageEncoding = c.getMessageEncoding(); + } } inboundPhase(Phase.MESSAGE); diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index 23eb082855..cd0ccdf8b3 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -330,16 +330,22 @@ public abstract class AbstractStream implements Stream { } @Override - public final void pickCompressor(Iterable messageEncodings) { + public final Compressor pickCompressor(Iterable messageEncodings) { for (String messageEncoding : messageEncodings) { Compressor c = compressorRegistry.lookupCompressor(messageEncoding); if (c != null) { // TODO(carl-mastrangelo): check that headers haven't already been sent. I can't find where // the client stream changes outbound phase correctly, so I am ignoring it. framer.setCompressor(c); - break; + return c; } } + return null; + } + + // TODO(carl-mastrangelo): this is a hack to get around registry passing. Remove it. + protected final DecompressorRegistry decompressorRegistry() { + return decompressorRegistry; } /** diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index bdd169d3c6..6ce2c7b2cf 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -34,6 +34,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.addAll; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; @@ -181,7 +182,7 @@ final class ClientCallImpl extends ClientCall headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); if (!decompressorRegistry.getAdvertisedMessageEncodings().isEmpty()) { String acceptEncoding = - Joiner.on(',').join(decompressorRegistry.getAdvertisedMessageEncodings()); + ACCEPT_ENCODING_JOINER.join(decompressorRegistry.getAdvertisedMessageEncodings()); headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding); } } diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 5343fb475c..3b7c13ee33 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -215,13 +215,17 @@ class DelayedStream implements ClientStream { } @Override - public void pickCompressor(Iterable messageEncodings) { + public Compressor pickCompressor(Iterable messageEncodings) { synchronized (this) { compressionMessageEncodings = messageEncodings; if (realStream != null) { - realStream.pickCompressor(messageEncodings); + return realStream.pickCompressor(messageEncodings); } } + // ClientCall never uses this. Since the stream doesn't exist yet, it can't say what + // stream it would pick. Eventually this will need a cleaner solution. + // TODO(carl-mastrangelo): Remove this. + return null; } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 0214ba3f81..73de454c08 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -36,6 +36,7 @@ import static io.grpc.Status.Code.CANCELLED; import static io.grpc.Status.Code.DEADLINE_EXCEEDED; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -159,6 +160,8 @@ public final class GrpcUtil { public static final Splitter ACCEPT_ENCODING_SPLITER = Splitter.on(',').trimResults(); + public static final Joiner ACCEPT_ENCODING_JOINER = Joiner.on(','); + /** * Maps HTTP error response status codes to transport codes. */ diff --git a/core/src/main/java/io/grpc/internal/NoopClientStream.java b/core/src/main/java/io/grpc/internal/NoopClientStream.java index 0e2db85a4e..762dcb6de9 100644 --- a/core/src/main/java/io/grpc/internal/NoopClientStream.java +++ b/core/src/main/java/io/grpc/internal/NoopClientStream.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import io.grpc.Status; @@ -72,7 +73,9 @@ public class NoopClientStream implements ClientStream { } @Override - public void pickCompressor(Iterable messageEncodings) {} + public Compressor pickCompressor(Iterable messageEncodings) { + return null; + } @Override public void setCompressionRegistry(CompressorRegistry registry) {} diff --git a/core/src/main/java/io/grpc/internal/Stream.java b/core/src/main/java/io/grpc/internal/Stream.java index 8fac4b24e9..33e641cba7 100644 --- a/core/src/main/java/io/grpc/internal/Stream.java +++ b/core/src/main/java/io/grpc/internal/Stream.java @@ -31,11 +31,14 @@ package io.grpc.internal; +import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import java.io.InputStream; +import javax.annotation.Nullable; + /** * A single stream of communication between two end-points within a transport. * @@ -81,12 +84,15 @@ public interface Stream { /** * Picks a compressor for for this stream. If no message encodings are acceptable, compression is - * not used. + * not used. It is undefined if this this method is invoked multiple times. + * * * @param messageEncodings a group of message encoding names that the remote endpoint is known * to support. + * @return The compressor chosen for the stream, or null if none selected. */ - void pickCompressor(Iterable messageEncodings); + @Nullable + Compressor pickCompressor(Iterable messageEncodings); /** * Enables per-message compression, if an encoding type has been negotiated. If no message diff --git a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java index 3de4631228..3de7715d31 100644 --- a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java +++ b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java @@ -31,14 +31,22 @@ package io.grpc.examples.experimental; +import com.google.common.util.concurrent.Uninterruptibles; + +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.Status; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloRequest; import io.grpc.examples.helloworld.HelloResponse; +import io.grpc.internal.GrpcUtil; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.logging.Level; import java.util.logging.Logger; /** @@ -53,14 +61,12 @@ public class CompressingHelloWorldClient { Logger.getLogger(CompressingHelloWorldClient.class.getName()); private final ManagedChannel channel; - private final GreeterGrpc.GreeterBlockingStub blockingStub; /** Construct client connecting to HelloWorld server at {@code host:port}. */ public CompressingHelloWorldClient(String host, int port) { channel = ManagedChannelBuilder.forAddress(host, port) .usePlaintext(true) .build(); - blockingStub = GreeterGrpc.newBlockingStub(channel); } public void shutdown() throws InterruptedException { @@ -68,16 +74,44 @@ public class CompressingHelloWorldClient { } /** Say hello to server. */ - public void greet(String name) { - try { - logger.info("Will try to greet " + name + " ..."); - HelloRequest request = HelloRequest.newBuilder().setName(name).build(); - HelloResponse response = blockingStub.sayHello(request); - logger.info("Greeting: " + response.getMessage()); - } catch (RuntimeException e) { - logger.log(Level.WARNING, "RPC failed", e); - return; - } + public void greet(final String name) { + final ClientCall call = + channel.newCall(GreeterGrpc.METHOD_SAY_HELLO, CallOptions.DEFAULT); + + final CountDownLatch latch = new CountDownLatch(1); + + call.start(new Listener() { + @Override + public void onHeaders(Metadata headers) { + super.onHeaders(headers); + String encoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY); + if (encoding == null) { + throw new RuntimeException("No compression selected!"); + } + } + + @Override + public void onMessage(HelloResponse message) { + super.onMessage(message); + logger.info("Greeting: " + message.getMessage()); + latch.countDown(); + } + + @Override + public void onClose(Status status, Metadata trailers) { + latch.countDown(); + if (!status.isOk()) { + throw status.asRuntimeException(); + } + } + }, new Metadata()); + + call.setMessageCompression(true); + call.sendMessage(HelloRequest.newBuilder().setName(name).build()); + call.request(1); + call.halfClose(); + + Uninterruptibles.awaitUninterruptibly(latch, 100, TimeUnit.SECONDS); } /** diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java index 98fc434e09..733c50d15f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java @@ -49,6 +49,7 @@ import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.OAuth2Credentials; import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.auth.oauth2.ServiceAccountJwtAccessCredentials; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import com.google.protobuf.EmptyProtos.Empty; @@ -58,6 +59,7 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -107,13 +109,19 @@ public abstract class AbstractTransportTest { private static Server server; private static int OPERATION_TIMEOUT = 5000; - protected static void startStaticServer(ServerBuilder builder) { + protected static void startStaticServer( + ServerBuilder builder, ServerInterceptor ... interceptors) { testServiceExecutor = Executors.newScheduledThreadPool(2); + List allInterceptors = ImmutableList.builder() + .add(TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture)) + .add(TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)) + .add(interceptors) + .build(); + builder.addService(ServerInterceptors.intercept( TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)), - TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture), - TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY))); + allInterceptors)); try { server = builder.build().start(); } catch (IOException ex) { @@ -584,7 +592,7 @@ public abstract class AbstractTransportTest { Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY)); } - @Test(timeout = 10000) + @Test(timeout = 100000000) public void sendsTimeoutHeader() { long configuredTimeoutMinutes = 100; TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java new file mode 100644 index 0000000000..4b3746a60e --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2015, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.testing.integration; + +import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; +import io.grpc.ClientInterceptor; +import io.grpc.CompressorRegistry; +import io.grpc.DecompressorRegistry; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.testing.TestUtils; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.TestServiceGrpc.TestServiceBlockingStub; +import io.grpc.testing.integration.TransportCompressionTest.Fzip; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; + +/** + * Tests for compression configurations. + * + *

Because of the asymmetry of clients and servers, clients will not know what decompression + * methods the server supports. In cases where the client is willing to encode, and the server + * is willing to decode, a second RPC is sent to show that the client has learned and will use + * the encoding. + * + *

In cases where compression is negotiated, but either the client or the server doesn't + * actually want to encode, a dummy codec is used to record usage. If compression is not enabled, + * the codec will see no data pass through. This is checked on each test to ensure the code is + * doing the right thing. + */ +@RunWith(Parameterized.class) +public class CompressionTest { + private static final ScheduledExecutorService executor = Executors.newScheduledThreadPool(2); + // Ensures that both the request and response messages are more than 0 bytes. The framer/deframer + // may not use the compressor if the message is empty. + private static final SimpleRequest REQUEST = SimpleRequest.newBuilder() + .setResponseSize(1) + .build(); + + private Fzip clientCodec = new Fzip(); + private Fzip serverCodec = new Fzip(); + private DecompressorRegistry clientDecompressors = DecompressorRegistry.newEmptyInstance(); + private DecompressorRegistry serverDecompressors = DecompressorRegistry.newEmptyInstance(); + private CompressorRegistry clientCompressors = CompressorRegistry.newEmptyInstance(); + private CompressorRegistry serverCompressors = CompressorRegistry.newEmptyInstance(); + + /** The headers received by the server from the client */ + private volatile Metadata serverResponseHeaders; + /** The headers received by the client from the server */ + private volatile Metadata clientResponseHeaders; + + // Params + private final boolean enableClientMessageCompression; + private final boolean enableServerMessageCompression; + private final boolean clientAcceptEncoding; + private final boolean clientEncoding; + private final boolean serverAcceptEncoding; + private final boolean serverEncoding; + + private Server server; + private ManagedChannel channel; + private TestServiceBlockingStub stub; + + public CompressionTest( + boolean enableClientMessageCompression, + boolean clientAcceptEncoding, + boolean clientEncoding, + boolean enableServerMessageCompression, + boolean serverAcceptEncoding, + boolean serverEncoding) { + this.enableClientMessageCompression = enableClientMessageCompression; + this.clientAcceptEncoding = clientAcceptEncoding; + this.clientEncoding = clientEncoding; + this.enableServerMessageCompression = enableServerMessageCompression; + this.serverAcceptEncoding = serverAcceptEncoding; + this.serverEncoding = serverEncoding; + } + + @Before + public void setUp() throws Exception { + int serverPort = TestUtils.pickUnusedPort(); + server = ServerBuilder.forPort(serverPort) + .addService(ServerInterceptors.intercept( + TestServiceGrpc.bindService(new TestServiceImpl(executor)), + new ServerCompressorInterceptor())) + .compressorRegistry(serverCompressors) + .decompressorRegistry(serverDecompressors) + .build() + .start(); + + channel = ManagedChannelBuilder.forAddress("localhost", serverPort) + .decompressorRegistry(clientDecompressors) + .compressorRegistry(clientCompressors) + .intercept(new ClientCompressorInterceptor()) + .usePlaintext(true) + .build(); + stub = TestServiceGrpc.newBlockingStub(channel); + } + + @After + public void tearDown() { + channel.shutdownNow(); + server.shutdownNow(); + executor.shutdownNow(); + } + + @Parameters + public static Collection params() { + Boolean[] bools = new Boolean[]{false, true}; + List combos = new ArrayList(64); + for (boolean enableClientMessageCompression : bools) { + for (boolean clientAcceptEncoding : bools) { + for (boolean clientEncoding : bools) { + for (boolean enableServerMessageCompression : bools) { + for (boolean serverAcceptEncoding : bools) { + for (boolean serverEncoding : bools) { + combos.add(new Object[] { + enableClientMessageCompression, clientAcceptEncoding, clientEncoding, + enableServerMessageCompression, serverAcceptEncoding, serverEncoding}); + } + } + } + } + } + } + return combos; + } + + @Test + public void compression() { + if (clientAcceptEncoding) { + clientDecompressors.register(clientCodec, true); + } + if (clientEncoding) { + clientCompressors.register(clientCodec); + } + if (serverAcceptEncoding) { + serverDecompressors.register(serverCodec, true); + } + if (serverEncoding) { + serverCompressors.register(serverCodec); + } + + stub.unaryCall(REQUEST); + + if (clientAcceptEncoding && serverEncoding) { + assertEquals("fzip", clientResponseHeaders.get(MESSAGE_ENCODING_KEY)); + if (enableServerMessageCompression) { + assertTrue(clientCodec.anyRead); + assertTrue(serverCodec.anyWritten); + } else { + assertFalse(clientCodec.anyRead); + assertFalse(serverCodec.anyWritten); + } + } else { + assertNull(clientResponseHeaders.get(MESSAGE_ENCODING_KEY)); + assertFalse(clientCodec.anyRead); + assertFalse(serverCodec.anyWritten); + } + + if (serverAcceptEncoding) { + assertEquals("fzip", clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); + } else { + assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); + } + + // Must be null for the first call. + assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); + assertFalse(clientCodec.anyWritten); + assertFalse(serverCodec.anyRead); + + if (clientAcceptEncoding) { + assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); + } else { + assertNull(serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); + } + + // Second call, once the client knows what the server supports. + stub.unaryCall(REQUEST); + if (clientEncoding && serverAcceptEncoding) { + assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); + if (enableClientMessageCompression) { + assertTrue(clientCodec.anyWritten); + assertTrue(serverCodec.anyRead); + } else { + assertFalse(clientCodec.anyWritten); + assertFalse(serverCodec.anyRead); + } + } else { + assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); + assertFalse(clientCodec.anyWritten); + assertFalse(serverCodec.anyRead); + } + } + + private class ServerCompressorInterceptor implements ServerInterceptor { + @Override + public io.grpc.ServerCall.Listener interceptCall( + MethodDescriptor method, ServerCall call, Metadata headers, + ServerCallHandler next) { + call.setMessageCompression(enableServerMessageCompression); + serverResponseHeaders = headers; + return next.startCall(method, call, headers); + } + } + + private class ClientCompressorInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + final ClientCall call = next.newCall(method, callOptions); + return new ClientCompressor(call); + } + } + + private class ClientCompressor extends SimpleForwardingClientCall { + protected ClientCompressor(ClientCall delegate) { + super(delegate); + } + + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + super.start(new ClientHeadersCapture(responseListener), headers); + setMessageCompression(enableClientMessageCompression); + } + } + + private class ClientHeadersCapture extends SimpleForwardingClientCallListener { + private ClientHeadersCapture(Listener delegate) { + super(delegate); + } + + @Override + public void onHeaders(Metadata headers) { + super.onHeaders(headers); + clientResponseHeaders = headers; + } + } +} + diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java new file mode 100644 index 0000000000..0718a9000c --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java @@ -0,0 +1,227 @@ +/* + * Copyright 2015, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.protobuf.ByteString; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Codec; +import io.grpc.CompressorRegistry; +import io.grpc.DecompressorRegistry; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.internal.GrpcUtil; +import io.grpc.testing.TestUtils; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.PayloadType; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.FilterInputStream; +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Tests that compression is turned on. + */ +@RunWith(JUnit4.class) +public class TransportCompressionTest extends AbstractTransportTest { + + private static int serverPort = TestUtils.pickUnusedPort(); + + private static final DecompressorRegistry decompressors = DecompressorRegistry.newEmptyInstance(); + private static final CompressorRegistry compressors = CompressorRegistry.newEmptyInstance(); + + @Before + public void beforeTests() { + Fzip.INSTANCE.anyRead = false; + Fzip.INSTANCE.anyWritten = false; + } + + /** Start server. */ + @BeforeClass + public static void startServer() { + decompressors.register(Fzip.INSTANCE, true); + compressors.register(Fzip.INSTANCE); + startStaticServer( + ServerBuilder.forPort(serverPort) + .compressorRegistry(compressors) + .decompressorRegistry(decompressors), + new ServerInterceptor() { + @Override + public Listener interceptCall(MethodDescriptor method, + ServerCall call, Metadata headers, ServerCallHandler next) { + Listener listener = next.startCall(method, call, headers); + // TODO(carl-mastrangelo): check that encoding was set. + call.setMessageCompression(true); + return listener; + }}); + } + + /** Stop server. */ + @AfterClass + public static void stopServer() { + stopStaticServer(); + } + + @Test + public void compresses() { + final SimpleRequest request = SimpleRequest.newBuilder() + .setResponseSize(314159) + .setResponseType(PayloadType.COMPRESSABLE) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder() + .setType(PayloadType.COMPRESSABLE) + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + + assertEquals(goldenResponse, blockingStub.unaryCall(request)); + // Assert that compression took place + assertTrue(Fzip.INSTANCE.anyRead); + assertTrue(Fzip.INSTANCE.anyWritten); + } + + @Override + protected ManagedChannel createChannel() { + return ManagedChannelBuilder.forAddress("localhost", serverPort) + .decompressorRegistry(decompressors) + .compressorRegistry(compressors) + .intercept(new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + final ClientCall call = next.newCall(method, callOptions); + return new ForwardingClientCall() { + + @Override + protected ClientCall delegate() { + return call; + } + + @Override + public void start( + final ClientCall.Listener responseListener, Metadata headers) { + ClientCall.Listener listener = new ForwardingClientCallListener() { + + @Override + protected io.grpc.ClientCall.Listener delegate() { + return responseListener; + } + + @Override + public void onHeaders(Metadata headers) { + super.onHeaders(headers); + String encoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY); + assertEquals(encoding, Fzip.INSTANCE.getMessageEncoding()); + } + }; + super.start(listener, headers); + setMessageCompression(true); + } + }; + } + }) + .usePlaintext(true) + .build(); + } + + static final class Fzip implements Codec { + static final Fzip INSTANCE = new Fzip(); + + boolean anyRead; + boolean anyWritten; + + @Override + public String getMessageEncoding() { + return "fzip"; + } + + @Override + public OutputStream compress(OutputStream os) throws IOException { + return new FilterOutputStream(os) { + @Override + public void write(int b) throws IOException { + super.write(b); + anyWritten = true; + } + }; + } + + @Override + public InputStream decompress(InputStream is) throws IOException { + return new FilterInputStream(is) { + @Override + public int read() throws IOException { + int val = super.read(); + anyRead = true; + return val; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int total = super.read(b, off, len); + anyRead = true; + return total; + } + }; + } + } +} +