Fix bug where server wouldn't declare the negotiated compression

This commit is contained in:
Carl Mastrangelo 2015-12-09 16:49:29 -08:00
parent d238d86a3c
commit 7ac44928be
13 changed files with 667 additions and 29 deletions

View File

@ -31,8 +31,11 @@
package io.grpc;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@ -70,4 +73,15 @@ public final class CompressorRegistry {
public Compressor lookupCompressor(String compressorName) {
return compressors.get(compressorName);
}
/**
* Registers a compressor for both decompression and message encoding negotiation.
*
* @param c The compressor to register
*/
public void register(Compressor c) {
String encoding = c.getMessageEncoding();
checkArgument(!encoding.contains(","), "Comma is currently not allowed in message encoding");
compressors.put(encoding, c);
}
}

View File

@ -31,6 +31,7 @@
package io.grpc.inprocess;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
@ -326,11 +327,13 @@ class InProcessTransport implements ServerTransport, ClientTransport {
@Override
public void setMessageCompression(boolean enable) {
// noop
// noop
}
@Override
public void pickCompressor(Iterable<String> messageEncodings) {}
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}
@ -447,7 +450,9 @@ class InProcessTransport implements ServerTransport, ClientTransport {
public void setMessageCompression(boolean enable) {}
@Override
public void pickCompressor(Iterable<String> messageEncodings) {}
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}

View File

@ -32,10 +32,14 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import com.google.common.base.Preconditions;
import io.grpc.Compressor;
import io.grpc.Metadata;
import io.grpc.Status;
@ -57,6 +61,7 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
private ServerStreamListener listener;
private boolean headersSent = false;
private String messageEncoding;
/**
* Whether the stream was closed gracefully by the application (vs. a transport-level failure).
*/
@ -95,6 +100,17 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
@Override
public final void writeHeaders(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
headers.removeAll(MESSAGE_ENCODING_KEY);
if (messageEncoding != null) {
headers.put(MESSAGE_ENCODING_KEY, messageEncoding);
}
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
if (!decompressorRegistry().getAdvertisedMessageEncodings().isEmpty()) {
String acceptEncoding =
ACCEPT_ENCODING_JOINER.join(decompressorRegistry().getAdvertisedMessageEncodings());
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding);
}
outboundPhase(Phase.HEADERS);
headersSent = true;
internalSendHeaders(headers);
@ -148,9 +164,20 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
return;
}
}
// This checks to see if the client will accept any encoding. If so, a compressor is picked for
// the stream, and the decision is recorded. When the Server Call Handler writes the first
// headers, the negotiated encoding will be added in #writeHeaders(). It is safe to call
// pickCompressor multiple times before the headers have been written to the wire, though in
// practice this should never happen. There should only be one call to inboundHeadersReceived.
// Alternatively, compression could be negotiated after the server handler is invoked, but that
// would mean the inbound header would have to be stored until the first #writeHeaders call.
if (headers.containsKey(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)) {
pickCompressor(
Compressor c = pickCompressor(
ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)));
if (c != null) {
messageEncoding = c.getMessageEncoding();
}
}
inboundPhase(Phase.MESSAGE);

View File

@ -330,16 +330,22 @@ public abstract class AbstractStream<IdT> implements Stream {
}
@Override
public final void pickCompressor(Iterable<String> messageEncodings) {
public final Compressor pickCompressor(Iterable<String> messageEncodings) {
for (String messageEncoding : messageEncodings) {
Compressor c = compressorRegistry.lookupCompressor(messageEncoding);
if (c != null) {
// TODO(carl-mastrangelo): check that headers haven't already been sent. I can't find where
// the client stream changes outbound phase correctly, so I am ignoring it.
framer.setCompressor(c);
break;
return c;
}
}
return null;
}
// TODO(carl-mastrangelo): this is a hack to get around registry passing. Remove it.
protected final DecompressorRegistry decompressorRegistry() {
return decompressorRegistry;
}
/**

View File

@ -34,6 +34,7 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.addAll;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
@ -181,7 +182,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
if (!decompressorRegistry.getAdvertisedMessageEncodings().isEmpty()) {
String acceptEncoding =
Joiner.on(',').join(decompressorRegistry.getAdvertisedMessageEncodings());
ACCEPT_ENCODING_JOINER.join(decompressorRegistry.getAdvertisedMessageEncodings());
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding);
}
}

View File

@ -215,13 +215,17 @@ class DelayedStream implements ClientStream {
}
@Override
public void pickCompressor(Iterable<String> messageEncodings) {
public Compressor pickCompressor(Iterable<String> messageEncodings) {
synchronized (this) {
compressionMessageEncodings = messageEncodings;
if (realStream != null) {
realStream.pickCompressor(messageEncodings);
return realStream.pickCompressor(messageEncodings);
}
}
// ClientCall never uses this. Since the stream doesn't exist yet, it can't say what
// stream it would pick. Eventually this will need a cleaner solution.
// TODO(carl-mastrangelo): Remove this.
return null;
}
@Override

View File

@ -36,6 +36,7 @@ import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
@ -159,6 +160,8 @@ public final class GrpcUtil {
public static final Splitter ACCEPT_ENCODING_SPLITER = Splitter.on(',').trimResults();
public static final Joiner ACCEPT_ENCODING_JOINER = Joiner.on(',');
/**
* Maps HTTP error response status codes to transport codes.
*/

View File

@ -31,6 +31,7 @@
package io.grpc.internal;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Status;
@ -72,7 +73,9 @@ public class NoopClientStream implements ClientStream {
}
@Override
public void pickCompressor(Iterable<String> messageEncodings) {}
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}

View File

@ -31,11 +31,14 @@
package io.grpc.internal;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import java.io.InputStream;
import javax.annotation.Nullable;
/**
* A single stream of communication between two end-points within a transport.
*
@ -81,12 +84,15 @@ public interface Stream {
/**
* Picks a compressor for for this stream. If no message encodings are acceptable, compression is
* not used.
* not used. It is undefined if this this method is invoked multiple times.
*
*
* @param messageEncodings a group of message encoding names that the remote endpoint is known
* to support.
* @return The compressor chosen for the stream, or null if none selected.
*/
void pickCompressor(Iterable<String> messageEncodings);
@Nullable
Compressor pickCompressor(Iterable<String> messageEncodings);
/**
* Enables per-message compression, if an encoding type has been negotiated. If no message

View File

@ -31,14 +31,22 @@
package io.grpc.examples.experimental;
import com.google.common.util.concurrent.Uninterruptibles;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.examples.helloworld.HelloResponse;
import io.grpc.internal.GrpcUtil;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
@ -53,14 +61,12 @@ public class CompressingHelloWorldClient {
Logger.getLogger(CompressingHelloWorldClient.class.getName());
private final ManagedChannel channel;
private final GreeterGrpc.GreeterBlockingStub blockingStub;
/** Construct client connecting to HelloWorld server at {@code host:port}. */
public CompressingHelloWorldClient(String host, int port) {
channel = ManagedChannelBuilder.forAddress(host, port)
.usePlaintext(true)
.build();
blockingStub = GreeterGrpc.newBlockingStub(channel);
}
public void shutdown() throws InterruptedException {
@ -68,16 +74,44 @@ public class CompressingHelloWorldClient {
}
/** Say hello to server. */
public void greet(String name) {
try {
logger.info("Will try to greet " + name + " ...");
HelloRequest request = HelloRequest.newBuilder().setName(name).build();
HelloResponse response = blockingStub.sayHello(request);
logger.info("Greeting: " + response.getMessage());
} catch (RuntimeException e) {
logger.log(Level.WARNING, "RPC failed", e);
return;
}
public void greet(final String name) {
final ClientCall<HelloRequest, HelloResponse> call =
channel.newCall(GreeterGrpc.METHOD_SAY_HELLO, CallOptions.DEFAULT);
final CountDownLatch latch = new CountDownLatch(1);
call.start(new Listener<HelloResponse>() {
@Override
public void onHeaders(Metadata headers) {
super.onHeaders(headers);
String encoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
if (encoding == null) {
throw new RuntimeException("No compression selected!");
}
}
@Override
public void onMessage(HelloResponse message) {
super.onMessage(message);
logger.info("Greeting: " + message.getMessage());
latch.countDown();
}
@Override
public void onClose(Status status, Metadata trailers) {
latch.countDown();
if (!status.isOk()) {
throw status.asRuntimeException();
}
}
}, new Metadata());
call.setMessageCompression(true);
call.sendMessage(HelloRequest.newBuilder().setName(name).build());
call.request(1);
call.halfClose();
Uninterruptibles.awaitUninterruptibly(latch, 100, TimeUnit.SECONDS);
}
/**

View File

@ -49,6 +49,7 @@ import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.auth.oauth2.ServiceAccountJwtAccessCredentials;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.EmptyProtos.Empty;
@ -58,6 +59,7 @@ import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
@ -107,13 +109,19 @@ public abstract class AbstractTransportTest {
private static Server server;
private static int OPERATION_TIMEOUT = 5000;
protected static void startStaticServer(ServerBuilder<?> builder) {
protected static void startStaticServer(
ServerBuilder<?> builder, ServerInterceptor ... interceptors) {
testServiceExecutor = Executors.newScheduledThreadPool(2);
List<ServerInterceptor> allInterceptors = ImmutableList.<ServerInterceptor>builder()
.add(TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture))
.add(TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY))
.add(interceptors)
.build();
builder.addService(ServerInterceptors.intercept(
TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)),
TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture),
TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)));
allInterceptors));
try {
server = builder.build().start();
} catch (IOException ex) {
@ -584,7 +592,7 @@ public abstract class AbstractTransportTest {
Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY));
}
@Test(timeout = 10000)
@Test(timeout = 100000000)
public void sendsTimeoutHeader() {
long configuredTimeoutMinutes = 100;
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)

View File

@ -0,0 +1,300 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.testing.integration;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.testing.TestUtils;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.TestServiceGrpc.TestServiceBlockingStub;
import io.grpc.testing.integration.TransportCompressionTest.Fzip;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
/**
* Tests for compression configurations.
*
* <p>Because of the asymmetry of clients and servers, clients will not know what decompression
* methods the server supports. In cases where the client is willing to encode, and the server
* is willing to decode, a second RPC is sent to show that the client has learned and will use
* the encoding.
*
* <p> In cases where compression is negotiated, but either the client or the server doesn't
* actually want to encode, a dummy codec is used to record usage. If compression is not enabled,
* the codec will see no data pass through. This is checked on each test to ensure the code is
* doing the right thing.
*/
@RunWith(Parameterized.class)
public class CompressionTest {
private static final ScheduledExecutorService executor = Executors.newScheduledThreadPool(2);
// Ensures that both the request and response messages are more than 0 bytes. The framer/deframer
// may not use the compressor if the message is empty.
private static final SimpleRequest REQUEST = SimpleRequest.newBuilder()
.setResponseSize(1)
.build();
private Fzip clientCodec = new Fzip();
private Fzip serverCodec = new Fzip();
private DecompressorRegistry clientDecompressors = DecompressorRegistry.newEmptyInstance();
private DecompressorRegistry serverDecompressors = DecompressorRegistry.newEmptyInstance();
private CompressorRegistry clientCompressors = CompressorRegistry.newEmptyInstance();
private CompressorRegistry serverCompressors = CompressorRegistry.newEmptyInstance();
/** The headers received by the server from the client */
private volatile Metadata serverResponseHeaders;
/** The headers received by the client from the server */
private volatile Metadata clientResponseHeaders;
// Params
private final boolean enableClientMessageCompression;
private final boolean enableServerMessageCompression;
private final boolean clientAcceptEncoding;
private final boolean clientEncoding;
private final boolean serverAcceptEncoding;
private final boolean serverEncoding;
private Server server;
private ManagedChannel channel;
private TestServiceBlockingStub stub;
public CompressionTest(
boolean enableClientMessageCompression,
boolean clientAcceptEncoding,
boolean clientEncoding,
boolean enableServerMessageCompression,
boolean serverAcceptEncoding,
boolean serverEncoding) {
this.enableClientMessageCompression = enableClientMessageCompression;
this.clientAcceptEncoding = clientAcceptEncoding;
this.clientEncoding = clientEncoding;
this.enableServerMessageCompression = enableServerMessageCompression;
this.serverAcceptEncoding = serverAcceptEncoding;
this.serverEncoding = serverEncoding;
}
@Before
public void setUp() throws Exception {
int serverPort = TestUtils.pickUnusedPort();
server = ServerBuilder.forPort(serverPort)
.addService(ServerInterceptors.intercept(
TestServiceGrpc.bindService(new TestServiceImpl(executor)),
new ServerCompressorInterceptor()))
.compressorRegistry(serverCompressors)
.decompressorRegistry(serverDecompressors)
.build()
.start();
channel = ManagedChannelBuilder.forAddress("localhost", serverPort)
.decompressorRegistry(clientDecompressors)
.compressorRegistry(clientCompressors)
.intercept(new ClientCompressorInterceptor())
.usePlaintext(true)
.build();
stub = TestServiceGrpc.newBlockingStub(channel);
}
@After
public void tearDown() {
channel.shutdownNow();
server.shutdownNow();
executor.shutdownNow();
}
@Parameters
public static Collection<Object[]> params() {
Boolean[] bools = new Boolean[]{false, true};
List<Object[]> combos = new ArrayList<Object[]>(64);
for (boolean enableClientMessageCompression : bools) {
for (boolean clientAcceptEncoding : bools) {
for (boolean clientEncoding : bools) {
for (boolean enableServerMessageCompression : bools) {
for (boolean serverAcceptEncoding : bools) {
for (boolean serverEncoding : bools) {
combos.add(new Object[] {
enableClientMessageCompression, clientAcceptEncoding, clientEncoding,
enableServerMessageCompression, serverAcceptEncoding, serverEncoding});
}
}
}
}
}
}
return combos;
}
@Test
public void compression() {
if (clientAcceptEncoding) {
clientDecompressors.register(clientCodec, true);
}
if (clientEncoding) {
clientCompressors.register(clientCodec);
}
if (serverAcceptEncoding) {
serverDecompressors.register(serverCodec, true);
}
if (serverEncoding) {
serverCompressors.register(serverCodec);
}
stub.unaryCall(REQUEST);
if (clientAcceptEncoding && serverEncoding) {
assertEquals("fzip", clientResponseHeaders.get(MESSAGE_ENCODING_KEY));
if (enableServerMessageCompression) {
assertTrue(clientCodec.anyRead);
assertTrue(serverCodec.anyWritten);
} else {
assertFalse(clientCodec.anyRead);
assertFalse(serverCodec.anyWritten);
}
} else {
assertNull(clientResponseHeaders.get(MESSAGE_ENCODING_KEY));
assertFalse(clientCodec.anyRead);
assertFalse(serverCodec.anyWritten);
}
if (serverAcceptEncoding) {
assertEquals("fzip", clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
} else {
assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
}
// Must be null for the first call.
assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
assertFalse(clientCodec.anyWritten);
assertFalse(serverCodec.anyRead);
if (clientAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
} else {
assertNull(serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
}
// Second call, once the client knows what the server supports.
stub.unaryCall(REQUEST);
if (clientEncoding && serverAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
if (enableClientMessageCompression) {
assertTrue(clientCodec.anyWritten);
assertTrue(serverCodec.anyRead);
} else {
assertFalse(clientCodec.anyWritten);
assertFalse(serverCodec.anyRead);
}
} else {
assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
assertFalse(clientCodec.anyWritten);
assertFalse(serverCodec.anyRead);
}
}
private class ServerCompressorInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> io.grpc.ServerCall.Listener<ReqT> interceptCall(
MethodDescriptor<ReqT, RespT> method, ServerCall<RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
call.setMessageCompression(enableServerMessageCompression);
serverResponseHeaders = headers;
return next.startCall(method, call, headers);
}
}
private class ClientCompressorInterceptor implements ClientInterceptor {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
final ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
return new ClientCompressor<ReqT, RespT>(call);
}
}
private class ClientCompressor<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
protected ClientCompressor(ClientCall<ReqT, RespT> delegate) {
super(delegate);
}
@Override
public void start(io.grpc.ClientCall.Listener<RespT> responseListener, Metadata headers) {
super.start(new ClientHeadersCapture<RespT>(responseListener), headers);
setMessageCompression(enableClientMessageCompression);
}
}
private class ClientHeadersCapture<RespT> extends SimpleForwardingClientCallListener<RespT> {
private ClientHeadersCapture(Listener<RespT> delegate) {
super(delegate);
}
@Override
public void onHeaders(Metadata headers) {
super.onHeaders(headers);
clientResponseHeaders = headers;
}
}
}

View File

@ -0,0 +1,227 @@
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import com.google.protobuf.ByteString;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Codec;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.internal.GrpcUtil;
import io.grpc.testing.TestUtils;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.PayloadType;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
/**
* Tests that compression is turned on.
*/
@RunWith(JUnit4.class)
public class TransportCompressionTest extends AbstractTransportTest {
private static int serverPort = TestUtils.pickUnusedPort();
private static final DecompressorRegistry decompressors = DecompressorRegistry.newEmptyInstance();
private static final CompressorRegistry compressors = CompressorRegistry.newEmptyInstance();
@Before
public void beforeTests() {
Fzip.INSTANCE.anyRead = false;
Fzip.INSTANCE.anyWritten = false;
}
/** Start server. */
@BeforeClass
public static void startServer() {
decompressors.register(Fzip.INSTANCE, true);
compressors.register(Fzip.INSTANCE);
startStaticServer(
ServerBuilder.forPort(serverPort)
.compressorRegistry(compressors)
.decompressorRegistry(decompressors),
new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(MethodDescriptor<ReqT, RespT> method,
ServerCall<RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
Listener<ReqT> listener = next.startCall(method, call, headers);
// TODO(carl-mastrangelo): check that encoding was set.
call.setMessageCompression(true);
return listener;
}});
}
/** Stop server. */
@AfterClass
public static void stopServer() {
stopStaticServer();
}
@Test
public void compresses() {
final SimpleRequest request = SimpleRequest.newBuilder()
.setResponseSize(314159)
.setResponseType(PayloadType.COMPRESSABLE)
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[271828])))
.build();
final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[314159])))
.build();
assertEquals(goldenResponse, blockingStub.unaryCall(request));
// Assert that compression took place
assertTrue(Fzip.INSTANCE.anyRead);
assertTrue(Fzip.INSTANCE.anyWritten);
}
@Override
protected ManagedChannel createChannel() {
return ManagedChannelBuilder.forAddress("localhost", serverPort)
.decompressorRegistry(decompressors)
.compressorRegistry(compressors)
.intercept(new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
final ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
return new ForwardingClientCall<ReqT, RespT>() {
@Override
protected ClientCall<ReqT, RespT> delegate() {
return call;
}
@Override
public void start(
final ClientCall.Listener<RespT> responseListener, Metadata headers) {
ClientCall.Listener<RespT> listener = new ForwardingClientCallListener<RespT>() {
@Override
protected io.grpc.ClientCall.Listener<RespT> delegate() {
return responseListener;
}
@Override
public void onHeaders(Metadata headers) {
super.onHeaders(headers);
String encoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
assertEquals(encoding, Fzip.INSTANCE.getMessageEncoding());
}
};
super.start(listener, headers);
setMessageCompression(true);
}
};
}
})
.usePlaintext(true)
.build();
}
static final class Fzip implements Codec {
static final Fzip INSTANCE = new Fzip();
boolean anyRead;
boolean anyWritten;
@Override
public String getMessageEncoding() {
return "fzip";
}
@Override
public OutputStream compress(OutputStream os) throws IOException {
return new FilterOutputStream(os) {
@Override
public void write(int b) throws IOException {
super.write(b);
anyWritten = true;
}
};
}
@Override
public InputStream decompress(InputStream is) throws IOException {
return new FilterInputStream(is) {
@Override
public int read() throws IOException {
int val = super.read();
anyRead = true;
return val;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
int total = super.read(b, off, len);
anyRead = true;
return total;
}
};
}
}
}