From a3c79e87ae85205bbdb2ed4de03813760ce755ee Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Wed, 27 Jan 2016 16:49:41 -0800 Subject: [PATCH] Add a simple compression API --- core/src/main/java/io/grpc/CallOptions.java | 25 ++++++ .../java/io/grpc/ForwardingServerCall.java | 6 ++ core/src/main/java/io/grpc/ServerCall.java | 14 ++++ .../io/grpc/inprocess/InProcessTransport.java | 29 ++----- .../grpc/internal/AbstractClientStream.java | 15 ---- .../internal/AbstractServerImplBuilder.java | 14 ++-- .../grpc/internal/AbstractServerStream.java | 44 ---------- .../java/io/grpc/internal/AbstractStream.java | 50 ++--------- .../java/io/grpc/internal/ClientCallImpl.java | 74 +++++++++-------- .../java/io/grpc/internal/DelayedStream.java | 51 ++++-------- .../io/grpc/internal/ManagedChannelImpl.java | 18 +--- .../io/grpc/internal/NoopClientStream.java | 12 +-- .../java/io/grpc/internal/ServerCallImpl.java | 82 ++++++++++++++++++- .../java/io/grpc/internal/ServerImpl.java | 15 +++- .../grpc/internal/SingleTransportChannel.java | 2 +- .../main/java/io/grpc/internal/Stream.java | 46 +++-------- .../internal/AbstractClientStreamTest.java | 17 ---- .../io/grpc/internal/ClientCallImplTest.java | 51 ++++-------- .../io/grpc/internal/DelayedStreamTest.java | 13 ++- .../grpc/internal/ManagedChannelImplTest.java | 4 +- .../io/grpc/internal/ServerCallImplTest.java | 5 +- .../java/io/grpc/internal/ServerImplTest.java | 24 ++++-- .../testing/integration/CompressionTest.java | 12 ++- .../main/java/io/grpc/netty/NettyServer.java | 15 +--- .../io/grpc/netty/NettyServerBuilder.java | 12 +-- .../io/grpc/netty/NettyServerHandler.java | 24 +----- .../io/grpc/netty/NettyServerTransport.java | 18 +--- .../grpc/netty/NettyClientTransportTest.java | 8 +- .../io/grpc/netty/NettyServerHandlerTest.java | 3 - .../io/grpc/netty/NettyServerStreamTest.java | 3 - .../main/java/io/grpc/stub/AbstractStub.java | 15 ++++ 31 files changed, 305 insertions(+), 416 deletions(-) diff --git a/core/src/main/java/io/grpc/CallOptions.java b/core/src/main/java/io/grpc/CallOptions.java index 7beb177c23..3775fc80a0 100644 --- a/core/src/main/java/io/grpc/CallOptions.java +++ b/core/src/main/java/io/grpc/CallOptions.java @@ -63,6 +63,9 @@ public final class CallOptions { @Nullable private RequestKey requestKey; + @Nullable + private String compressorName; + /** * Override the HTTP/2 authority the channel claims to be connecting to. This is not * generally safe. Overriding allows advanced users to re-use a single Channel for multiple @@ -79,6 +82,17 @@ public final class CallOptions { return newOptions; } + /** + * Sets the compression to use for the call. The compressor must be a valid name known in the + * {@link CompressorRegistry}. + */ + @ExperimentalApi + public CallOptions withCompression(@Nullable String compressorName) { + CallOptions newOptions = new CallOptions(this); + newOptions.compressorName = compressorName; + return newOptions; + } + /** * Returns a new {@code CallOptions} with the given absolute deadline in nanoseconds in the clock * as per {@link System#nanoTime()}. @@ -131,6 +145,16 @@ public final class CallOptions { return requestKey; } + + /** + * Returns the compressor's name. + */ + @ExperimentalApi + @Nullable + public String getCompressor() { + return compressorName; + } + /** * Override the HTTP/2 authority the channel claims to be connecting to. This is not * generally safe. Overriding allows advanced users to re-use a single Channel for multiple @@ -172,6 +196,7 @@ public final class CallOptions { authority = other.authority; requestKey = other.requestKey; executor = other.executor; + compressorName = other.compressorName; } @Override diff --git a/core/src/main/java/io/grpc/ForwardingServerCall.java b/core/src/main/java/io/grpc/ForwardingServerCall.java index 42dd3b3dd0..3e0e434fe3 100644 --- a/core/src/main/java/io/grpc/ForwardingServerCall.java +++ b/core/src/main/java/io/grpc/ForwardingServerCall.java @@ -76,6 +76,12 @@ public abstract class ForwardingServerCall extends ServerCall { delegate().setMessageCompression(enabled); } + @Override + @ExperimentalApi + public void setCompression(String compressor) { + delegate().setCompression(compressor); + } + /** * A simplified version of {@link ForwardingServerCall} where subclasses can pass in a {@link * ServerCall} as the delegate. diff --git a/core/src/main/java/io/grpc/ServerCall.java b/core/src/main/java/io/grpc/ServerCall.java index 5ec52d0174..6e2b194003 100644 --- a/core/src/main/java/io/grpc/ServerCall.java +++ b/core/src/main/java/io/grpc/ServerCall.java @@ -176,4 +176,18 @@ public abstract class ServerCall { public void setMessageCompression(boolean enabled) { // noop } + + /** + * Sets the compression algorithm for this call. If the server does not support the compression + * algorithm, the call will fail. This method may only be called before {@link #sendHeaders}. + * The compressor to use will be looked up in the {@link CompressorRegistry}. Default gRPC + * servers support the "gzip" compressor. + * + * @param compressor the name of the compressor to use. + * @throws IllegalArgumentException if the compressor name can not be found. + */ + @ExperimentalApi + public void setCompression(String compressor) { + // noop + } } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 938cf2b016..3498173d61 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -34,8 +34,7 @@ package io.grpc.inprocess; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.Compressor; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; +import io.grpc.Decompressor; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -230,9 +229,6 @@ class InProcessTransport implements ServerTransport, ClientTransport { clientStreamListener = listener; } - @Override - public void setDecompressionRegistry(DecompressorRegistry registry) {} - @Override public void request(int numMessages) { clientStream.serverRequested(numMessages); @@ -345,12 +341,10 @@ class InProcessTransport implements ServerTransport, ClientTransport { } @Override - public Compressor pickCompressor(Iterable messageEncodings) { - return null; - } + public void setCompressor(Compressor compressor) {} @Override - public void setCompressionRegistry(CompressorRegistry registry) {} + public void setDecompressor(Decompressor decompressor) {} } private class InProcessClientStream implements ClientStream { @@ -457,20 +451,9 @@ class InProcessTransport implements ServerTransport, ClientTransport { } } - @Override - public void setDecompressionRegistry(DecompressorRegistry registry) {} - @Override public void setMessageCompression(boolean enable) {} - @Override - public Compressor pickCompressor(Iterable messageEncodings) { - return null; - } - - @Override - public void setCompressionRegistry(CompressorRegistry registry) {} - @Override public void start(ClientStreamListener listener) { serverStream.setListener(listener); @@ -483,6 +466,12 @@ class InProcessTransport implements ServerTransport, ClientTransport { streams.add(InProcessTransport.InProcessStream.this); } } + + @Override + public void setCompressor(Compressor compressor) {} + + @Override + public void setDecompressor(Decompressor decompressor) {} } } } diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 4a93c1b1d1..563cf6c40f 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -122,21 +122,6 @@ public abstract class AbstractClientStream extends AbstractStream log.log(Level.INFO, "Received headers on closed stream {0} {1}", new Object[]{id(), headers}); } - if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) { - String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY); - try { - setDecompressor(messageEncoding); - } catch (IllegalArgumentException e) { - // Don't use INVALID_ARGUMENT since that is for servers to send clients. - Status status = Status.INTERNAL.withDescription("Unable to decompress message from server.") - .withCause(e); - // TODO(carl-mastrangelo): look back into tearing down this stream. sendCancel() can be - // buffered. - inboundTransportError(status, headers); - sendCancel(status); - return; - } - } inboundPhase(Phase.MESSAGE); listener.headersRead(headers); diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 95df85ec42..132813a8fb 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -31,6 +31,8 @@ package io.grpc.internal; +import static com.google.common.base.MoreObjects.firstNonNull; + import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; @@ -112,24 +114,18 @@ public abstract class AbstractServerImplBuilder extends AbstractStream private ServerStreamListener listener; private boolean headersSent = false; - private String messageEncoding; /** * Whether the stream was closed gracefully by the application (vs. a transport-level failure). */ @@ -100,16 +94,6 @@ public abstract class AbstractServerStream extends AbstractStream @Override public final void writeHeaders(Metadata headers) { Preconditions.checkNotNull(headers, "headers"); - headers.removeAll(MESSAGE_ENCODING_KEY); - if (messageEncoding != null) { - headers.put(MESSAGE_ENCODING_KEY, messageEncoding); - } - headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); - if (!decompressorRegistry().getAdvertisedMessageEncodings().isEmpty()) { - String acceptEncoding = - ACCEPT_ENCODING_JOINER.join(decompressorRegistry().getAdvertisedMessageEncodings()); - headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding); - } outboundPhase(Phase.HEADERS); headersSent = true; @@ -152,34 +136,6 @@ public abstract class AbstractServerStream extends AbstractStream * @param headers the parsed headers */ protected void inboundHeadersReceived(Metadata headers) { - if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) { - String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY); - try { - setDecompressor(messageEncoding); - } catch (IllegalArgumentException e) { - Status status = Status.INVALID_ARGUMENT - .withDescription("Unable to decompress encoding " + messageEncoding) - .withCause(e); - abortStream(status, true); - return; - } - } - // This checks to see if the client will accept any encoding. If so, a compressor is picked for - // the stream, and the decision is recorded. When the Server Call Handler writes the first - // headers, the negotiated encoding will be added in #writeHeaders(). It is safe to call - // pickCompressor multiple times before the headers have been written to the wire, though in - // practice this should never happen. There should only be one call to inboundHeadersReceived. - - // Alternatively, compression could be negotiated after the server handler is invoked, but that - // would mean the inbound header would have to be stored until the first #writeHeaders call. - if (headers.containsKey(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)) { - Compressor c = pickCompressor( - ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY))); - if (c != null) { - messageEncoding = c.getMessageEncoding(); - } - } - inboundPhase(Phase.MESSAGE); } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index cd0ccdf8b3..c4d4b49651 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -31,7 +31,6 @@ package io.grpc.internal; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -40,9 +39,7 @@ import com.google.common.base.MoreObjects; import io.grpc.Codec; import io.grpc.Compressor; -import io.grpc.CompressorRegistry; import io.grpc.Decompressor; -import io.grpc.DecompressorRegistry; import java.io.InputStream; @@ -101,10 +98,6 @@ public abstract class AbstractStream implements Stream { private boolean allocated; private final Object onReadyLock = new Object(); - private volatile DecompressorRegistry decompressorRegistry = - DecompressorRegistry.getDefaultInstance(); - private volatile CompressorRegistry compressorRegistry = - CompressorRegistry.getDefaultInstance(); @VisibleForTesting class FramerSink implements MessageFramer.Sink { @@ -305,47 +298,14 @@ public abstract class AbstractStream implements Stream { } } - /** - * Looks up the decompressor by its message encoding name, and sets it for this stream. - * Decompressors are registered with {@link DecompressorRegistry#register}. - * - * @param messageEncoding the name of the encoding provided by the remote host - * @throws IllegalArgumentException if the provided message encoding cannot be found. - */ - protected final void setDecompressor(String messageEncoding) { - Decompressor d = decompressorRegistry.lookupDecompressor(messageEncoding); - checkArgument(d != null, - "Unable to find decompressor for message encoding %s", messageEncoding); - deframer.setDecompressor(d); + @Override + public final void setCompressor(Compressor compressor) { + framer.setCompressor(checkNotNull(compressor, "compressor")); } @Override - public final void setDecompressionRegistry(DecompressorRegistry registry) { - decompressorRegistry = checkNotNull(registry); - } - - @Override - public final void setCompressionRegistry(CompressorRegistry registry) { - compressorRegistry = checkNotNull(registry); - } - - @Override - public final Compressor pickCompressor(Iterable messageEncodings) { - for (String messageEncoding : messageEncodings) { - Compressor c = compressorRegistry.lookupCompressor(messageEncoding); - if (c != null) { - // TODO(carl-mastrangelo): check that headers haven't already been sent. I can't find where - // the client stream changes outbound phase correctly, so I am ignoring it. - framer.setCompressor(c); - return c; - } - } - return null; - } - - // TODO(carl-mastrangelo): this is a hack to get around registry passing. Remove it. - protected final DecompressorRegistry decompressorRegistry() { - return decompressorRegistry; + public final void setDecompressor(Decompressor decompressor) { + deframer.setDecompressor(checkNotNull(decompressor, "decompressor")); } /** diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 2698476cc9..02ccffc2a7 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -33,10 +33,8 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.Iterables.addAll; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER; -import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; @@ -57,6 +55,7 @@ import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -64,8 +63,6 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import java.io.InputStream; -import java.util.Collections; -import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -91,7 +88,7 @@ final class ClientCallImpl extends ClientCall private final ClientTransportProvider clientTransportProvider; private String userAgent; private ScheduledExecutorService deadlineCancellationExecutor; - private Set knownMessageEncodingRegistry; + private Compressor compressor; private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); @@ -146,19 +143,9 @@ final class ClientCallImpl extends ClientCall return this; } - /** - * Sets encodings known to be supported by the server. This set MUST be thread safe, and MAY be - * modified by any code as it learns about new supported encodings. - */ - ClientCallImpl setKnownMessageEncodingRegistry(Set knownMessageEncodings) { - this.knownMessageEncodingRegistry = knownMessageEncodings; - return this; - } - @VisibleForTesting static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent, - Set knownMessageEncodings, DecompressorRegistry decompressorRegistry, - CompressorRegistry compressorRegistry) { + DecompressorRegistry decompressorRegistry, Compressor compressor) { // Hack to propagate authority. This should be properly pass to the transport.newStream // somehow. headers.removeAll(AUTHORITY_KEY); @@ -173,12 +160,8 @@ final class ClientCallImpl extends ClientCall } headers.removeAll(MESSAGE_ENCODING_KEY); - for (String messageEncoding : knownMessageEncodings) { - Compressor compressor = compressorRegistry.lookupCompressor(messageEncoding); - if (compressor != null && compressor != Codec.Identity.NONE) { - headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); - break; - } + if (compressor != Codec.Identity.NONE) { + headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); } headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); @@ -207,8 +190,27 @@ final class ClientCallImpl extends ClientCall }); return; } - prepareHeaders(headers, callOptions, userAgent, - knownMessageEncodingRegistry, decompressorRegistry, compressorRegistry); + final String compressorName = callOptions.getCompressor(); + if (compressorName != null) { + compressor = compressorRegistry.lookupCompressor(compressorName); + if (compressor == null) { + stream = NoopClientStream.INSTANCE; + callExecutor.execute(new ContextRunnable(context) { + @Override + public void runInContext() { + observer.onClose( + Status.INTERNAL.withDescription( + String.format("Unable to find compressor by name %s", compressorName)), + new Metadata()); + } + }); + return; + } + } else { + compressor = Codec.Identity.NONE; + } + + prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor); ListenableFuture transportFuture = clientTransportProvider.get(callOptions); @@ -236,11 +238,8 @@ final class ClientCallImpl extends ClientCall transportFuture.isDone() ? directExecutor() : callExecutor); } - stream.setDecompressionRegistry(decompressorRegistry); - stream.setCompressionRegistry(compressorRegistry); - if (headers.containsKey(MESSAGE_ENCODING_KEY)) { - stream.pickCompressor(Collections.singleton(headers.get(MESSAGE_ENCODING_KEY))); - // TODO(carl-mastrangelo): move this to ClientCall. + stream.setCompressor(compressor); + if (compressor != Codec.Identity.NONE) { stream.setMessageCompression(true); } @@ -387,13 +386,18 @@ final class ClientCallImpl extends ClientCall @Override public void headersRead(final Metadata headers) { - if (headers.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) { - // TODO(carl-mastrangelo): after the first time we contact the server, it almost certainly - // won't change. It might be possible to recover performance by not adding to the known - // encodings if it isn't empty. - String serverAcceptEncodings = headers.get(MESSAGE_ACCEPT_ENCODING_KEY); - addAll(knownMessageEncodingRegistry, ACCEPT_ENCODING_SPLITER.split(serverAcceptEncodings)); + Decompressor decompressor = Codec.Identity.NONE; + if (headers.containsKey(MESSAGE_ENCODING_KEY)) { + String encoding = headers.get(MESSAGE_ENCODING_KEY); + decompressor = decompressorRegistry.lookupDecompressor(encoding); + if (decompressor == null) { + stream.cancel(Status.INTERNAL.withDescription( + String.format("Can't find decompressor for %s", encoding))); + return; + } } + stream.setDecompressor(decompressor); + callExecutor.execute(new ContextRunnable(context) { @Override public final void runInContext() { diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index d218dbdc35..a56f53249a 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -35,8 +35,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import io.grpc.Compressor; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; +import io.grpc.Decompressor; import io.grpc.Metadata; import io.grpc.Status; @@ -67,10 +66,6 @@ class DelayedStream implements ClientStream { @GuardedBy("this") private Status error; - @GuardedBy("this") - private Iterable compressionMessageEncodings; - @GuardedBy("this") - private DecompressorRegistry decompressionRegistry; @GuardedBy("this") private final List pendingMessages = new LinkedList(); private boolean messageCompressionEnabled; @@ -81,7 +76,9 @@ class DelayedStream implements ClientStream { @GuardedBy("this") private boolean pendingFlush; @GuardedBy("this") - private CompressorRegistry compressionRegistry; + private Compressor compressor; + @GuardedBy("this") + private Decompressor decompressor; static final class PendingMessage { final InputStream message; @@ -118,15 +115,13 @@ class DelayedStream implements ClientStream { checkState(listener != null, "listener"); realStream.start(listener); - if (compressionMessageEncodings != null) { - realStream.pickCompressor(compressionMessageEncodings); + if (decompressor != null) { + realStream.setDecompressor(decompressor); } - if (this.decompressionRegistry != null) { - realStream.setDecompressionRegistry(this.decompressionRegistry); - } - if (this.compressionRegistry != null) { - realStream.setCompressionRegistry(this.compressionRegistry); + if (compressor != null) { + realStream.setCompressor(compressor); } + for (PendingMessage message : pendingMessages) { realStream.setMessageCompression(message.shouldBeCompressed); realStream.writeMessage(message.message); @@ -246,45 +241,29 @@ class DelayedStream implements ClientStream { } @Override - public Compressor pickCompressor(Iterable messageEncodings) { + public void setCompressor(Compressor compressor) { if (startedRealStream == null) { synchronized (this) { if (startedRealStream == null) { - compressionMessageEncodings = messageEncodings; - // ClientCall never uses this. Since the stream doesn't exist yet, it can't say what - // stream it would pick. Eventually this will need a cleaner solution. - // TODO(carl-mastrangelo): Remove this. - return null; - } - } - } - return startedRealStream.pickCompressor(messageEncodings); - } - - @Override - public void setCompressionRegistry(CompressorRegistry registry) { - if (startedRealStream == null) { - synchronized (this) { - if (startedRealStream == null) { - compressionRegistry = registry; + this.compressor = compressor; return; } } } - startedRealStream.setCompressionRegistry(registry); + startedRealStream.setCompressor(compressor); } @Override - public void setDecompressionRegistry(DecompressorRegistry registry) { + public void setDecompressor(Decompressor decompressor) { if (startedRealStream == null) { synchronized (this) { if (startedRealStream == null) { - decompressionRegistry = registry; + this.decompressor = decompressor; return; } } } - startedRealStream.setDecompressionRegistry(registry); + startedRealStream.setDecompressor(decompressor); } @Override diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 8835f27647..137eedce9e 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -60,12 +60,9 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -96,18 +93,6 @@ public final class ManagedChannelImpl extends ManagedChannel { private final String userAgent; private final Object lock = new Object(); - /* Compression related */ - /** - * When a client connects to a server, it does not know what encodings are supported. This set - * is the union of all accept-encoding headers the server has sent. It is used to pick an - * encoding when contacting the server again. One problem with the gRPC protocol is that if - * there is only one RPC made (perhaps streaming, or otherwise long lived) an encoding will not - * be selected. To combat this you can preflight a request to the server to fill in the mapping - * for the next one. A better solution is if you have prior knowledge that the server supports - * an encoding, and fill this structure before the request. - */ - private final Set knownAcceptEncodingRegistry = - Collections.newSetFromMap(new ConcurrentHashMap()); private final DecompressorRegistry decompressorRegistry; private final CompressorRegistry compressorRegistry; @@ -335,8 +320,7 @@ public final class ManagedChannelImpl extends ManagedChannel { scheduledExecutor) .setUserAgent(userAgent) .setDecompressorRegistry(decompressorRegistry) - .setCompressorRegistry(compressorRegistry) - .setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry); + .setCompressorRegistry(compressorRegistry); } @Override diff --git a/core/src/main/java/io/grpc/internal/NoopClientStream.java b/core/src/main/java/io/grpc/internal/NoopClientStream.java index 25c3d81660..af25b19e6a 100644 --- a/core/src/main/java/io/grpc/internal/NoopClientStream.java +++ b/core/src/main/java/io/grpc/internal/NoopClientStream.java @@ -32,8 +32,7 @@ package io.grpc.internal; import io.grpc.Compressor; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; +import io.grpc.Decompressor; import io.grpc.Status; import java.io.InputStream; @@ -67,19 +66,14 @@ public class NoopClientStream implements ClientStream { @Override public void halfClose() {} - @Override - public void setDecompressionRegistry(DecompressorRegistry registry) {} - @Override public void setMessageCompression(boolean enable) { // noop } @Override - public Compressor pickCompressor(Iterable messageEncodings) { - return null; - } + public void setCompressor(Compressor compressor) {} @Override - public void setCompressionRegistry(CompressorRegistry registry) {} + public void setDecompressor(Decompressor decompressor) {} } diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 2c9322c8ed..4fff4b31bf 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -31,13 +31,23 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER; +import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER; +import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; +import io.grpc.Codec; +import io.grpc.Compressor; +import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.Decompressor; +import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; @@ -46,22 +56,44 @@ import io.grpc.Status; import java.io.IOException; import java.io.InputStream; +import java.util.List; +import java.util.Set; import java.util.concurrent.Future; final class ServerCallImpl extends ServerCall { private final ServerStream stream; private final MethodDescriptor method; private final Context.CancellableContext context; + private Metadata inboundHeaders; + private final DecompressorRegistry decompressorRegistry; + private final CompressorRegistry compressorRegistry; + // state private volatile boolean cancelled; private boolean sendHeadersCalled; private boolean closeCalled; + private Compressor compressor; ServerCallImpl(ServerStream stream, MethodDescriptor method, - Context.CancellableContext context) { + Metadata inboundHeaders, Context.CancellableContext context, + DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) { this.stream = stream; this.method = method; this.context = context; + this.inboundHeaders = inboundHeaders; + this.decompressorRegistry = decompressorRegistry; + this.compressorRegistry = compressorRegistry; + + if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) { + String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY); + Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding); + if (decompressor == null) { + throw Status.INTERNAL + .withDescription(String.format("Can't find decompressor for %s", encoding)) + .asRuntimeException(); + } + stream.setDecompressor(decompressor); + } } @Override @@ -73,6 +105,44 @@ final class ServerCallImpl extends ServerCall { public void sendHeaders(Metadata headers) { checkState(!sendHeadersCalled, "sendHeaders has already been called"); checkState(!closeCalled, "call is closed"); + + headers.removeAll(MESSAGE_ENCODING_KEY); + if (compressor == null) { + compressor = Codec.Identity.NONE; + if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) { + String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY); + for (String acceptEncoding : ACCEPT_ENCODING_SPLITER.split(acceptEncodings)) { + Compressor c = compressorRegistry.lookupCompressor(acceptEncoding); + if (c != null) { + compressor = c; + break; + } + } + } + } else { + if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) { + String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY); + List acceptedEncodingsList = ACCEPT_ENCODING_SPLITER.splitToList(acceptEncodings); + if (!acceptedEncodingsList.contains(compressor.getMessageEncoding())) { + // resort to using no compression. + compressor = Codec.Identity.NONE; + } + } else { + compressor = Codec.Identity.NONE; + } + } + inboundHeaders = null; + if (compressor != Codec.Identity.NONE) { + headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); + } + stream.setCompressor(compressor); + + headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); + Set acceptEncodings = decompressorRegistry.getAdvertisedMessageEncodings(); + if (!acceptEncodings.isEmpty()) { + headers.put(MESSAGE_ACCEPT_ENCODING_KEY, ACCEPT_ENCODING_JOINER.join(acceptEncodings)); + } + // Don't check if sendMessage has been called, since it requires that sendHeaders was already // called. sendHeadersCalled = true; @@ -98,6 +168,15 @@ final class ServerCallImpl extends ServerCall { stream.setMessageCompression(enable); } + @Override + public void setCompression(String compressorName) { + // Added here to give a better error message. + checkState(!sendHeadersCalled, "sendHeaders has been called"); + + compressor = compressorRegistry.lookupCompressor(compressorName); + checkArgument(compressor != null, "Unable to find compressor by name %s", compressorName); + } + @Override public boolean isReady() { return stream.isReady(); @@ -108,6 +187,7 @@ final class ServerCallImpl extends ServerCall { try { checkState(!closeCalled, "call already closed"); closeCalled = true; + inboundHeaders = null; stream.close(status, trailers); } finally { if (status.getCode() == Status.Code.OK) { diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 6fc1c4dc6d..38240cfade 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -40,7 +40,9 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.DecompressorRegistry; import io.grpc.HandlerRegistry; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -104,20 +106,26 @@ public final class ServerImpl extends io.grpc.Server { private final ScheduledExecutorService timeoutService = SharedResourceHolder.get(TIMER_SERVICE); private final Context rootContext; + private final DecompressorRegistry decompressorRegistry; + private final CompressorRegistry compressorRegistry; + /** * Construct a server. * * @param executor to call methods on behalf of remote clients * @param registry of methods to expose to remote clients. */ - ServerImpl(Executor executor, HandlerRegistry registry, - io.grpc.internal.Server transportServer, Context rootContext) { + ServerImpl(Executor executor, HandlerRegistry registry, io.grpc.internal.Server transportServer, + Context rootContext, DecompressorRegistry decompressorRegistry, + CompressorRegistry compressorRegistry) { this.executor = executor; this.registry = Preconditions.checkNotNull(registry, "registry"); this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer"); // Fork from the passed in context so that it does not propagate cancellation, it only // inherits values. this.rootContext = Preconditions.checkNotNull(rootContext).fork(); + this.decompressorRegistry = decompressorRegistry; + this.compressorRegistry = compressorRegistry; } /** @@ -355,7 +363,8 @@ public final class ServerImpl extends io.grpc.Server { Metadata headers, Context.CancellableContext context) { // TODO(ejona86): should we update fullMethodName to have the canonical path of the method? ServerCallImpl call = new ServerCallImpl( - stream, methodDef.getMethodDescriptor(), context); + stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry, + compressorRegistry); ServerCall.Listener listener = methodDef.getServerCallHandler() .startCall(methodDef.getMethodDescriptor(), call, headers); if (listener == null) { diff --git a/core/src/main/java/io/grpc/internal/SingleTransportChannel.java b/core/src/main/java/io/grpc/internal/SingleTransportChannel.java index b5163c416c..a1919f58ab 100644 --- a/core/src/main/java/io/grpc/internal/SingleTransportChannel.java +++ b/core/src/main/java/io/grpc/internal/SingleTransportChannel.java @@ -83,7 +83,7 @@ final class SingleTransportChannel extends Channel { MethodDescriptor methodDescriptor, CallOptions callOptions) { return new ClientCallImpl(methodDescriptor, new SerializingExecutor(executor), callOptions, transportProvider, - deadlineCancellationExecutor).setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry); + deadlineCancellationExecutor); } @Override diff --git a/core/src/main/java/io/grpc/internal/Stream.java b/core/src/main/java/io/grpc/internal/Stream.java index b52d9b5efc..35f4c77c46 100644 --- a/core/src/main/java/io/grpc/internal/Stream.java +++ b/core/src/main/java/io/grpc/internal/Stream.java @@ -32,13 +32,10 @@ package io.grpc.internal; import io.grpc.Compressor; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; +import io.grpc.Decompressor; import java.io.InputStream; -import javax.annotation.Nullable; - /** * A single stream of communication between two end-points within a transport. * @@ -85,43 +82,22 @@ public interface Stream { boolean isReady(); /** - * Picks a compressor for for this stream. If no message encodings are acceptable, compression is - * not used. It is undefined if this this method is invoked multiple times. If the stream has - * a {@code start()} method, pickCompressor must be called prior to start. + * Sets the compressor on the framer. * - * - * @param messageEncodings a group of message encoding names that the remote endpoint is known - * to support. - * @return The compressor chosen for the stream, or null if none selected. + * @param compressor the compressor to use */ - @Nullable - Compressor pickCompressor(Iterable messageEncodings); + void setCompressor(Compressor compressor); + + /** + * Sets the decompressor on the deframer. + * + * @param decompressor the decompressor to use. + */ + void setDecompressor(Decompressor decompressor); /** * Enables per-message compression, if an encoding type has been negotiated. If no message * encoding has been negotiated, this is a no-op. */ void setMessageCompression(boolean enable); - - /** - * Sets the decompressor registry to use when resolving {@code #setDecompressor(String)}. If - * unset, the default DecompressorRegistry will be used. If the stream has a {@code start()} - * method, setDecompressionRegistry must be called prior to start. - * - * @see DecompressorRegistry#getDefaultInstance() - * - * @param registry the decompressors to use. - */ - void setDecompressionRegistry(DecompressorRegistry registry); - - /** - * Sets the compressor registry to use when resolving {@link #pickCompressor}. If unset, the - * default CompressorRegistry will be used. If the stream has a {@code start()} method, - * setCompressionRegistry must be called prior to start. - * - * @see CompressorRegistry#getDefaultInstance() - * - * @param registry the compressors to use. - */ - void setCompressionRegistry(CompressorRegistry registry); } diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index ed6209102b..76d9b070d0 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -239,23 +239,6 @@ public class AbstractClientStreamTest { verify(mockListener).headersRead(headers); } - @Test - public void inboundHeadersReceived_notifiesListenerOnBadEncoding() { - AbstractClientStream stream = new BaseAbstractClientStream(allocator); - stream.start(mockListener); - Metadata headers = new Metadata(); - headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "bad"); - Metadata.Key randomKey = Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER); - headers.put(randomKey, "4"); - - stream.inboundHeadersReceived(headers); - - ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); - verify(mockListener).closed(statusCaptor.capture(), metadataCaptor.capture()); - assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); - assertEquals("4", metadataCaptor.getValue().get(randomKey)); - } - @Test public void rstStreamClosesStream() { AbstractClientStream stream = new BaseAbstractClientStream(allocator); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 66e4763235..d85b89665c 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -55,7 +55,6 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.Codec; -import io.grpc.CompressorRegistry; import io.grpc.Context; import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; @@ -80,7 +79,6 @@ import org.mockito.MockitoAnnotations; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.util.HashSet; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; @@ -103,8 +101,6 @@ public class ClientCallImplTest { private final FakeClock fakeClock = new FakeClock(); private final ScheduledExecutorService deadlineCancellationExecutor = fakeClock.scheduledExecutorService; - private final Set knownMessageEncodings = new HashSet(); - private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private final DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private final MethodDescriptor method = MethodDescriptor.create( @@ -159,9 +155,7 @@ public class ClientCallImplTest { CallOptions.DEFAULT, provider, deadlineCancellationExecutor) - .setDecompressorRegistry(decompressorRegistry) - .setCompressorRegistry(compressorRegistry) - .setKnownMessageEncodingRegistry(knownMessageEncodings); + .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); @@ -178,8 +172,8 @@ public class ClientCallImplTest { public void prepareHeaders_authorityAdded() { Metadata m = new Metadata(); CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth"); - ClientCallImpl.prepareHeaders(m, callOptions, "user agent", knownMessageEncodings, - decompressorRegistry, compressorRegistry); + ClientCallImpl.prepareHeaders(m, callOptions, "user agent", decompressorRegistry, + Codec.Identity.NONE); assertEquals(m.get(GrpcUtil.AUTHORITY_KEY), "auth"); } @@ -187,28 +181,17 @@ public class ClientCallImplTest { @Test public void prepareHeaders_userAgentAdded() { Metadata m = new Metadata(); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, - decompressorRegistry, compressorRegistry); + ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, + Codec.Identity.NONE); assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent"); } - @Test - public void prepareHeaders_messageEncodingAdded() { - Metadata m = new Metadata(); - knownMessageEncodings.add(new Codec.Gzip().getMessageEncoding()); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, - decompressorRegistry, compressorRegistry); - - assertEquals(m.get(GrpcUtil.MESSAGE_ENCODING_KEY), new Codec.Gzip().getMessageEncoding()); - } - @Test public void prepareHeaders_ignoreIdentityEncoding() { Metadata m = new Metadata(); - knownMessageEncodings.add(Codec.Identity.NONE.getMessageEncoding()); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, - decompressorRegistry, compressorRegistry); + ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry, + Codec.Identity.NONE); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); } @@ -251,8 +234,8 @@ public class ClientCallImplTest { } }, false); // not advertised - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, - customRegistry, compressorRegistry); + ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry, + Codec.Identity.NONE); Iterable acceptedEncodings = ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); @@ -269,8 +252,8 @@ public class ClientCallImplTest { m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); - ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, knownMessageEncodings, - DecompressorRegistry.newEmptyInstance(), compressorRegistry); + ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, + DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE); assertNull(m.get(GrpcUtil.AUTHORITY_KEY)); assertNull(m.get(GrpcUtil.USER_AGENT_KEY)); @@ -294,9 +277,7 @@ public class ClientCallImplTest { CallOptions.DEFAULT, provider, deadlineCancellationExecutor) - .setDecompressorRegistry(decompressorRegistry) - .setCompressorRegistry(compressorRegistry) - .setKnownMessageEncodingRegistry(knownMessageEncodings); + .setDecompressorRegistry(decompressorRegistry); Context.ROOT.attach(); @@ -372,9 +353,7 @@ public class ClientCallImplTest { CallOptions.DEFAULT, provider, deadlineCancellationExecutor) - .setDecompressorRegistry(decompressorRegistry) - .setCompressorRegistry(compressorRegistry) - .setKnownMessageEncodingRegistry(knownMessageEncodings); + .setDecompressorRegistry(decompressorRegistry); previous.attach(); @@ -454,9 +433,7 @@ public class ClientCallImplTest { callOptions, provider, deadlineCancellationExecutor) - .setDecompressorRegistry(decompressorRegistry) - .setCompressorRegistry(compressorRegistry) - .setKnownMessageEncodingRegistry(knownMessageEncodings); + .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); assertFalse(future.isDone()); fakeClock.forwardTime(1, TimeUnit.SECONDS); diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 641eb05005..3b08447900 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -36,8 +36,7 @@ import static org.mockito.Matchers.isA; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; +import io.grpc.Codec; import io.grpc.Metadata; import io.grpc.Status; @@ -77,10 +76,8 @@ public class DelayedStreamTest { @Test public void setStream_sendsAllMessages() { stream.start(listener); - DecompressorRegistry decompressors = DecompressorRegistry.newEmptyInstance(); - CompressorRegistry compressors = CompressorRegistry.newEmptyInstance(); - stream.setDecompressionRegistry(decompressors); - stream.setCompressionRegistry(compressors); + stream.setCompressor(Codec.Identity.NONE); + stream.setDecompressor(Codec.Identity.NONE); stream.setMessageCompression(true); InputStream message = new ByteArrayInputStream(new byte[]{'a'}); @@ -90,8 +87,8 @@ public class DelayedStreamTest { stream.setStream(realStream); - verify(realStream).setDecompressionRegistry(decompressors); - verify(realStream).setCompressionRegistry(compressors); + verify(realStream).setCompressor(Codec.Identity.NONE); + verify(realStream).setDecompressor(Codec.Identity.NONE); // Verify that the order was correct, even though they should be interleaved with the // writeMessage calls diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 78276acd99..cfb4b545f9 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -54,6 +54,7 @@ import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import io.grpc.IntegerMarshaller; @@ -195,8 +196,7 @@ public class ManagedChannelImplTest { ClientTransport.Listener transportListener = transportListenerCaptor.getValue(); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); verify(mockStream).start(streamListenerCaptor.capture()); - verify(mockStream).setDecompressionRegistry(isA(DecompressorRegistry.class)); - verify(mockStream).setCompressionRegistry(isA(CompressorRegistry.class)); + verify(mockStream).setCompressor(isA(Compressor.class)); ClientStreamListener streamListener = streamListenerCaptor.getValue(); // Second call diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 2742ee28c3..2560c5099e 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -44,7 +44,9 @@ import static org.mockito.Mockito.when; import com.google.common.io.CharStreams; import com.google.common.util.concurrent.Futures; +import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; @@ -89,7 +91,8 @@ public class ServerCallImplTest { public void setUp() { MockitoAnnotations.initMocks(this); context = Context.ROOT.withCancellation(); - call = new ServerCallImpl(stream, method, context); + call = new ServerCallImpl(stream, method, new Metadata(), context, + DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance()); } @Test diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 3c323777c4..c2a437f765 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -49,7 +49,10 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Compressor; +import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.DecompressorRegistry; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -91,6 +94,9 @@ public class ServerImplTest { private static final Context.Key SERVER_ONLY = Context.key("serverOnly"); private static final Context.CancellableContext SERVER_CONTEXT = Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation(); + private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); + private final DecompressorRegistry decompressorRegistry = + DecompressorRegistry.getDefaultInstance(); static { // Cancel the root context. Server will fork it so the per-call context should not @@ -101,7 +107,8 @@ public class ServerImplTest { private ExecutorService executor = Executors.newSingleThreadExecutor(); private MutableHandlerRegistry registry = new MutableHandlerRegistryImpl(); private SimpleServer transportServer = new SimpleServer(); - private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); + private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT, + decompressorRegistry, compressorRegistry); @Mock private ServerStream stream; @@ -129,7 +136,8 @@ public class ServerImplTest { @Override public void shutdown() {} }; - ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); + ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT, + decompressorRegistry, compressorRegistry); server.start(); server.shutdown(); assertTrue(server.isShutdown()); @@ -146,7 +154,8 @@ public class ServerImplTest { throw new AssertionError("Should not be called, because wasn't started"); } }; - ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); + ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT, + decompressorRegistry, compressorRegistry); server.shutdown(); assertTrue(server.isShutdown()); assertTrue(server.isTerminated()); @@ -154,7 +163,8 @@ public class ServerImplTest { @Test public void startStopImmediateWithChildTransport() throws IOException { - ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); + ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT, + decompressorRegistry, compressorRegistry); server.start(); class DelayedShutdownServerTransport extends SimpleServerTransport { boolean shutdown; @@ -186,7 +196,7 @@ public class ServerImplTest { } ServerImpl server = new ServerImpl(executor, registry, new FailingStartupServer(), - SERVER_CONTEXT); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry); try { server.start(); fail("expected exception"); @@ -240,6 +250,7 @@ public class ServerImplTest { responseHeaders.put(metadataKey, "response value"); call.sendHeaders(responseHeaders); verify(stream).writeHeaders(responseHeaders); + verify(stream).setCompressor(isA(Compressor.class)); call.sendMessage(314); ArgumentCaptor inputCaptor = ArgumentCaptor.forClass(InputStream.class); @@ -322,7 +333,8 @@ public class ServerImplTest { } transportServer = new MaybeDeadlockingServer(); - ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); + ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT, + decompressorRegistry, compressorRegistry); server.start(); new Thread() { @Override diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java index d7b70528a1..161bb58dba 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java @@ -230,11 +230,6 @@ public class CompressionTest { assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); } - // Must be null for the first call. - assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); - assertFalse(clientCodec.anyWritten); - assertFalse(serverCodec.anyRead); - if (clientAcceptEncoding) { assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); } else { @@ -242,7 +237,6 @@ public class CompressionTest { } // Second call, once the client knows what the server supports. - stub.unaryCall(REQUEST); if (clientEncoding && serverAcceptEncoding) { assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); if (enableClientMessageCompression) { @@ -274,7 +268,11 @@ public class CompressionTest { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { - final ClientCall call = next.newCall(method, callOptions); + if (clientEncoding && serverAcceptEncoding) { + callOptions = callOptions.withCompression("fzip"); + } + ClientCall call = next.newCall(method, callOptions); + return new ClientCompressor(call); } } diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index f422dc213b..0ce9cf83eb 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -35,8 +35,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_KEEPALIVE; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.internal.Server; import io.grpc.internal.ServerListener; import io.grpc.internal.SharedResourceHolder; @@ -74,8 +72,6 @@ class NettyServer implements Server { private EventLoopGroup workerGroup; private ServerListener listener; private Channel channel; - private final DecompressorRegistry decompressorRegistry; - private final CompressorRegistry compressorRegistry; private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; @@ -83,8 +79,7 @@ class NettyServer implements Server { NettyServer(SocketAddress address, Class channelType, @Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup, - ProtocolNegotiator protocolNegotiator, DecompressorRegistry decompressorRegistry, - CompressorRegistry compressorRegistry, int maxStreamsPerConnection, + ProtocolNegotiator protocolNegotiator, int maxStreamsPerConnection, int flowControlWindow, int maxMessageSize, int maxHeaderListSize) { this.address = address; this.channelType = checkNotNull(channelType, "channelType"); @@ -97,8 +92,6 @@ class NettyServer implements Server { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; - this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry"); - this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry"); } @Override @@ -125,10 +118,8 @@ class NettyServer implements Server { eventLoopReferenceCounter.release(); } }); - NettyServerTransport transport - = new NettyServerTransport(ch, protocolNegotiator, decompressorRegistry, - compressorRegistry, maxStreamsPerConnection, flowControlWindow, maxMessageSize, - maxHeaderListSize); + NettyServerTransport transport = new NettyServerTransport(ch, protocolNegotiator, + maxStreamsPerConnection, flowControlWindow, maxMessageSize, maxHeaderListSize); transport.start(listener.transportCreated(transport)); } }); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 449f115a00..0e160c7d0c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -31,14 +31,11 @@ package io.grpc.netty; -import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import com.google.common.base.Preconditions; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.ExperimentalApi; import io.grpc.HandlerRegistry; import io.grpc.Internal; @@ -244,12 +241,9 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder= 0, "maxMessageSize must be >= 0"); this.maxMessageSize = maxMessageSize; - this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry"); - this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry"); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); @@ -206,11 +193,6 @@ class NettyServerHandler extends AbstractNettyHandler { NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this, maxMessageSize); - // These must be called before inboundHeadersReceived, because the framers depend on knowing - // the compression algorithms available before negotiation. - stream.setDecompressionRegistry(decompressorRegistry); - stream.setCompressionRegistry(compressorRegistry); - Metadata metadata = Utils.convertHeaders(headers); stream.inboundHeadersReceived(metadata); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 7275c3c5b4..275670350e 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -31,12 +31,8 @@ package io.grpc.netty; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.common.base.Preconditions; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; import io.netty.channel.Channel; @@ -55,8 +51,6 @@ class NettyServerTransport implements ServerTransport { private final Channel channel; private final ProtocolNegotiator protocolNegotiator; - private final DecompressorRegistry decompressorRegistry; - private final CompressorRegistry compressorRegistry; private final int maxStreams; private ServerTransportListener listener; private boolean terminated; @@ -64,18 +58,14 @@ class NettyServerTransport implements ServerTransport { private final int maxMessageSize; private final int maxHeaderListSize; - NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, - DecompressorRegistry decompressorRegistry, - CompressorRegistry compressorRegistry, int maxStreams, int flowControlWindow, - int maxMessageSize, int maxHeaderListSize) { + NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, int maxStreams, + int flowControlWindow, int maxMessageSize, int maxHeaderListSize) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator"); this.maxStreams = maxStreams; this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; - this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry"); - this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry"); } public void start(ServerTransportListener listener) { @@ -125,7 +115,7 @@ class NettyServerTransport implements ServerTransport { * Creates the Netty handler to be used in the channel pipeline. */ private NettyServerHandler createHandler(ServerTransportListener transportListener) { - return NettyServerHandler.newHandler(transportListener, decompressorRegistry, - compressorRegistry, maxStreams, flowControlWindow, maxHeaderListSize, maxMessageSize); + return NettyServerHandler.newHandler(transportListener, maxStreams, flowControlWindow, + maxHeaderListSize, maxMessageSize); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 03e21b5099..6898eaaf71 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -43,8 +43,6 @@ import static org.junit.Assert.fail; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; @@ -321,10 +319,8 @@ public class NettyClientTransportTest { SslContext serverContext = GrpcSslContexts.forServer(serverCert, key) .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(serverContext); - server = new NettyServer(address, NioServerSocketChannel.class, - group, group, negotiator, DecompressorRegistry.getDefaultInstance(), - CompressorRegistry.getDefaultInstance(), maxStreamsPerConnection, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize); + server = new NettyServer(address, NioServerSocketChannel.class, group, group, negotiator, + maxStreamsPerConnection, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize); server.start(serverListener); } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 9e4c0bbfd8..52ec3fd402 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -54,8 +54,6 @@ import static org.mockito.Mockito.when; import com.google.common.io.ByteStreams; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.Status.Code; @@ -347,7 +345,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase> { return build(channel, callOptions.withDeadlineAfter(duration, unit)); } + /** + * Set's the compressor name to use for the call. It is the responsibility of the application + * to make sure the server supports decoding the compressor picked by the client. To be clear, + * this is the compressor used by the stub to compress messages to the server. To get + * compressed responses from the server, set the appropriate {@link io.grpc.DecompressorRegistry} + * on the {@link io.grpc.ManagedChannelBuilder}. + * + * @param compressorName the name (e.g. "gzip") of the compressor to use. + */ + @ExperimentalApi + public final S withCompression(String compressorName) { + return build(channel, callOptions.withCompression(compressorName)); + } + /** * Returns a new stub that uses the given channel. */