Add a simple compression API

This commit is contained in:
Carl Mastrangelo 2016-01-27 16:49:41 -08:00
parent 6af2ddafe5
commit a3c79e87ae
31 changed files with 305 additions and 416 deletions

View File

@ -63,6 +63,9 @@ public final class CallOptions {
@Nullable
private RequestKey requestKey;
@Nullable
private String compressorName;
/**
* Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not
* generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple
@ -79,6 +82,17 @@ public final class CallOptions {
return newOptions;
}
/**
* Sets the compression to use for the call. The compressor must be a valid name known in the
* {@link CompressorRegistry}.
*/
@ExperimentalApi
public CallOptions withCompression(@Nullable String compressorName) {
CallOptions newOptions = new CallOptions(this);
newOptions.compressorName = compressorName;
return newOptions;
}
/**
* Returns a new {@code CallOptions} with the given absolute deadline in nanoseconds in the clock
* as per {@link System#nanoTime()}.
@ -131,6 +145,16 @@ public final class CallOptions {
return requestKey;
}
/**
* Returns the compressor's name.
*/
@ExperimentalApi
@Nullable
public String getCompressor() {
return compressorName;
}
/**
* Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not
* generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple
@ -172,6 +196,7 @@ public final class CallOptions {
authority = other.authority;
requestKey = other.requestKey;
executor = other.executor;
compressorName = other.compressorName;
}
@Override

View File

@ -76,6 +76,12 @@ public abstract class ForwardingServerCall<RespT> extends ServerCall<RespT> {
delegate().setMessageCompression(enabled);
}
@Override
@ExperimentalApi
public void setCompression(String compressor) {
delegate().setCompression(compressor);
}
/**
* A simplified version of {@link ForwardingServerCall} where subclasses can pass in a {@link
* ServerCall} as the delegate.

View File

@ -176,4 +176,18 @@ public abstract class ServerCall<RespT> {
public void setMessageCompression(boolean enabled) {
// noop
}
/**
* Sets the compression algorithm for this call. If the server does not support the compression
* algorithm, the call will fail. This method may only be called before {@link #sendHeaders}.
* The compressor to use will be looked up in the {@link CompressorRegistry}. Default gRPC
* servers support the "gzip" compressor.
*
* @param compressor the name of the compressor to use.
* @throws IllegalArgumentException if the compressor name can not be found.
*/
@ExperimentalApi
public void setCompression(String compressor) {
// noop
}
}

View File

@ -34,8 +34,7 @@ package io.grpc.inprocess;
import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Decompressor;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -230,9 +229,6 @@ class InProcessTransport implements ServerTransport, ClientTransport {
clientStreamListener = listener;
}
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override
public void request(int numMessages) {
clientStream.serverRequested(numMessages);
@ -345,12 +341,10 @@ class InProcessTransport implements ServerTransport, ClientTransport {
}
@Override
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
public void setCompressor(Compressor compressor) {}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}
public void setDecompressor(Decompressor decompressor) {}
}
private class InProcessClientStream implements ClientStream {
@ -457,20 +451,9 @@ class InProcessTransport implements ServerTransport, ClientTransport {
}
}
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override
public void setMessageCompression(boolean enable) {}
@Override
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}
@Override
public void start(ClientStreamListener listener) {
serverStream.setListener(listener);
@ -483,6 +466,12 @@ class InProcessTransport implements ServerTransport, ClientTransport {
streams.add(InProcessTransport.InProcessStream.this);
}
}
@Override
public void setCompressor(Compressor compressor) {}
@Override
public void setDecompressor(Decompressor decompressor) {}
}
}
}

View File

@ -122,21 +122,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
log.log(Level.INFO, "Received headers on closed stream {0} {1}",
new Object[]{id(), headers});
}
if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) {
String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
try {
setDecompressor(messageEncoding);
} catch (IllegalArgumentException e) {
// Don't use INVALID_ARGUMENT since that is for servers to send clients.
Status status = Status.INTERNAL.withDescription("Unable to decompress message from server.")
.withCause(e);
// TODO(carl-mastrangelo): look back into tearing down this stream. sendCancel() can be
// buffered.
inboundTransportError(status, headers);
sendCancel(status);
return;
}
}
inboundPhase(Phase.MESSAGE);
listener.headersRead(headers);

View File

@ -31,6 +31,8 @@
package io.grpc.internal;
import static com.google.common.base.MoreObjects.firstNonNull;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
@ -112,24 +114,18 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
return thisT();
}
protected final DecompressorRegistry decompressorRegistry() {
return decompressorRegistry;
}
@Override
public final T compressorRegistry(CompressorRegistry registry) {
compressorRegistry = registry;
return thisT();
}
protected final CompressorRegistry compressorRegistry() {
return compressorRegistry;
}
@Override
public ServerImpl build() {
io.grpc.internal.Server transportServer = buildTransportServer();
return new ServerImpl(executor, registry, transportServer, Context.ROOT);
return new ServerImpl(executor, registry, transportServer, Context.ROOT,
firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()));
}
/**

View File

@ -32,14 +32,9 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import com.google.common.base.Preconditions;
import io.grpc.Compressor;
import io.grpc.Metadata;
import io.grpc.Status;
@ -61,7 +56,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
private ServerStreamListener listener;
private boolean headersSent = false;
private String messageEncoding;
/**
* Whether the stream was closed gracefully by the application (vs. a transport-level failure).
*/
@ -100,16 +94,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
@Override
public final void writeHeaders(Metadata headers) {
Preconditions.checkNotNull(headers, "headers");
headers.removeAll(MESSAGE_ENCODING_KEY);
if (messageEncoding != null) {
headers.put(MESSAGE_ENCODING_KEY, messageEncoding);
}
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
if (!decompressorRegistry().getAdvertisedMessageEncodings().isEmpty()) {
String acceptEncoding =
ACCEPT_ENCODING_JOINER.join(decompressorRegistry().getAdvertisedMessageEncodings());
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding);
}
outboundPhase(Phase.HEADERS);
headersSent = true;
@ -152,34 +136,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
* @param headers the parsed headers
*/
protected void inboundHeadersReceived(Metadata headers) {
if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) {
String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
try {
setDecompressor(messageEncoding);
} catch (IllegalArgumentException e) {
Status status = Status.INVALID_ARGUMENT
.withDescription("Unable to decompress encoding " + messageEncoding)
.withCause(e);
abortStream(status, true);
return;
}
}
// This checks to see if the client will accept any encoding. If so, a compressor is picked for
// the stream, and the decision is recorded. When the Server Call Handler writes the first
// headers, the negotiated encoding will be added in #writeHeaders(). It is safe to call
// pickCompressor multiple times before the headers have been written to the wire, though in
// practice this should never happen. There should only be one call to inboundHeadersReceived.
// Alternatively, compression could be negotiated after the server handler is invoked, but that
// would mean the inbound header would have to be stored until the first #writeHeaders call.
if (headers.containsKey(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)) {
Compressor c = pickCompressor(
ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)));
if (c != null) {
messageEncoding = c.getMessageEncoding();
}
}
inboundPhase(Phase.MESSAGE);
}

View File

@ -31,7 +31,6 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
@ -40,9 +39,7 @@ import com.google.common.base.MoreObjects;
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import java.io.InputStream;
@ -101,10 +98,6 @@ public abstract class AbstractStream<IdT> implements Stream {
private boolean allocated;
private final Object onReadyLock = new Object();
private volatile DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance();
private volatile CompressorRegistry compressorRegistry =
CompressorRegistry.getDefaultInstance();
@VisibleForTesting
class FramerSink implements MessageFramer.Sink {
@ -305,47 +298,14 @@ public abstract class AbstractStream<IdT> implements Stream {
}
}
/**
* Looks up the decompressor by its message encoding name, and sets it for this stream.
* Decompressors are registered with {@link DecompressorRegistry#register}.
*
* @param messageEncoding the name of the encoding provided by the remote host
* @throws IllegalArgumentException if the provided message encoding cannot be found.
*/
protected final void setDecompressor(String messageEncoding) {
Decompressor d = decompressorRegistry.lookupDecompressor(messageEncoding);
checkArgument(d != null,
"Unable to find decompressor for message encoding %s", messageEncoding);
deframer.setDecompressor(d);
@Override
public final void setCompressor(Compressor compressor) {
framer.setCompressor(checkNotNull(compressor, "compressor"));
}
@Override
public final void setDecompressionRegistry(DecompressorRegistry registry) {
decompressorRegistry = checkNotNull(registry);
}
@Override
public final void setCompressionRegistry(CompressorRegistry registry) {
compressorRegistry = checkNotNull(registry);
}
@Override
public final Compressor pickCompressor(Iterable<String> messageEncodings) {
for (String messageEncoding : messageEncodings) {
Compressor c = compressorRegistry.lookupCompressor(messageEncoding);
if (c != null) {
// TODO(carl-mastrangelo): check that headers haven't already been sent. I can't find where
// the client stream changes outbound phase correctly, so I am ignoring it.
framer.setCompressor(c);
return c;
}
}
return null;
}
// TODO(carl-mastrangelo): this is a hack to get around registry passing. Remove it.
protected final DecompressorRegistry decompressorRegistry() {
return decompressorRegistry;
public final void setDecompressor(Decompressor decompressor) {
deframer.setDecompressor(checkNotNull(decompressor, "decompressor"));
}
/**

View File

@ -33,10 +33,8 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.addAll;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
@ -57,6 +55,7 @@ import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -64,8 +63,6 @@ import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;
import java.io.InputStream;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
@ -91,7 +88,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
private final ClientTransportProvider clientTransportProvider;
private String userAgent;
private ScheduledExecutorService deadlineCancellationExecutor;
private Set<String> knownMessageEncodingRegistry;
private Compressor compressor;
private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance();
private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
@ -146,19 +143,9 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
return this;
}
/**
* Sets encodings known to be supported by the server. This set MUST be thread safe, and MAY be
* modified by any code as it learns about new supported encodings.
*/
ClientCallImpl<ReqT, RespT> setKnownMessageEncodingRegistry(Set<String> knownMessageEncodings) {
this.knownMessageEncodingRegistry = knownMessageEncodings;
return this;
}
@VisibleForTesting
static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent,
Set<String> knownMessageEncodings, DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry) {
DecompressorRegistry decompressorRegistry, Compressor compressor) {
// Hack to propagate authority. This should be properly pass to the transport.newStream
// somehow.
headers.removeAll(AUTHORITY_KEY);
@ -173,12 +160,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
}
headers.removeAll(MESSAGE_ENCODING_KEY);
for (String messageEncoding : knownMessageEncodings) {
Compressor compressor = compressorRegistry.lookupCompressor(messageEncoding);
if (compressor != null && compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
break;
}
if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
}
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
@ -207,8 +190,27 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
});
return;
}
prepareHeaders(headers, callOptions, userAgent,
knownMessageEncodingRegistry, decompressorRegistry, compressorRegistry);
final String compressorName = callOptions.getCompressor();
if (compressorName != null) {
compressor = compressorRegistry.lookupCompressor(compressorName);
if (compressor == null) {
stream = NoopClientStream.INSTANCE;
callExecutor.execute(new ContextRunnable(context) {
@Override
public void runInContext() {
observer.onClose(
Status.INTERNAL.withDescription(
String.format("Unable to find compressor by name %s", compressorName)),
new Metadata());
}
});
return;
}
} else {
compressor = Codec.Identity.NONE;
}
prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor);
ListenableFuture<ClientTransport> transportFuture = clientTransportProvider.get(callOptions);
@ -236,11 +238,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
transportFuture.isDone() ? directExecutor() : callExecutor);
}
stream.setDecompressionRegistry(decompressorRegistry);
stream.setCompressionRegistry(compressorRegistry);
if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
stream.pickCompressor(Collections.singleton(headers.get(MESSAGE_ENCODING_KEY)));
// TODO(carl-mastrangelo): move this to ClientCall.
stream.setCompressor(compressor);
if (compressor != Codec.Identity.NONE) {
stream.setMessageCompression(true);
}
@ -387,13 +386,18 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override
public void headersRead(final Metadata headers) {
if (headers.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) {
// TODO(carl-mastrangelo): after the first time we contact the server, it almost certainly
// won't change. It might be possible to recover performance by not adding to the known
// encodings if it isn't empty.
String serverAcceptEncodings = headers.get(MESSAGE_ACCEPT_ENCODING_KEY);
addAll(knownMessageEncodingRegistry, ACCEPT_ENCODING_SPLITER.split(serverAcceptEncodings));
Decompressor decompressor = Codec.Identity.NONE;
if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = headers.get(MESSAGE_ENCODING_KEY);
decompressor = decompressorRegistry.lookupDecompressor(encoding);
if (decompressor == null) {
stream.cancel(Status.INTERNAL.withDescription(
String.format("Can't find decompressor for %s", encoding)));
return;
}
}
stream.setDecompressor(decompressor);
callExecutor.execute(new ContextRunnable(context) {
@Override
public final void runInContext() {

View File

@ -35,8 +35,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Decompressor;
import io.grpc.Metadata;
import io.grpc.Status;
@ -67,10 +66,6 @@ class DelayedStream implements ClientStream {
@GuardedBy("this")
private Status error;
@GuardedBy("this")
private Iterable<String> compressionMessageEncodings;
@GuardedBy("this")
private DecompressorRegistry decompressionRegistry;
@GuardedBy("this")
private final List<PendingMessage> pendingMessages = new LinkedList<PendingMessage>();
private boolean messageCompressionEnabled;
@ -81,7 +76,9 @@ class DelayedStream implements ClientStream {
@GuardedBy("this")
private boolean pendingFlush;
@GuardedBy("this")
private CompressorRegistry compressionRegistry;
private Compressor compressor;
@GuardedBy("this")
private Decompressor decompressor;
static final class PendingMessage {
final InputStream message;
@ -118,15 +115,13 @@ class DelayedStream implements ClientStream {
checkState(listener != null, "listener");
realStream.start(listener);
if (compressionMessageEncodings != null) {
realStream.pickCompressor(compressionMessageEncodings);
if (decompressor != null) {
realStream.setDecompressor(decompressor);
}
if (this.decompressionRegistry != null) {
realStream.setDecompressionRegistry(this.decompressionRegistry);
}
if (this.compressionRegistry != null) {
realStream.setCompressionRegistry(this.compressionRegistry);
if (compressor != null) {
realStream.setCompressor(compressor);
}
for (PendingMessage message : pendingMessages) {
realStream.setMessageCompression(message.shouldBeCompressed);
realStream.writeMessage(message.message);
@ -246,45 +241,29 @@ class DelayedStream implements ClientStream {
}
@Override
public Compressor pickCompressor(Iterable<String> messageEncodings) {
public void setCompressor(Compressor compressor) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
compressionMessageEncodings = messageEncodings;
// ClientCall never uses this. Since the stream doesn't exist yet, it can't say what
// stream it would pick. Eventually this will need a cleaner solution.
// TODO(carl-mastrangelo): Remove this.
return null;
}
}
}
return startedRealStream.pickCompressor(messageEncodings);
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
compressionRegistry = registry;
this.compressor = compressor;
return;
}
}
}
startedRealStream.setCompressionRegistry(registry);
startedRealStream.setCompressor(compressor);
}
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {
public void setDecompressor(Decompressor decompressor) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
decompressionRegistry = registry;
this.decompressor = decompressor;
return;
}
}
}
startedRealStream.setDecompressionRegistry(registry);
startedRealStream.setDecompressor(decompressor);
}
@Override

View File

@ -60,12 +60,9 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
@ -96,18 +93,6 @@ public final class ManagedChannelImpl extends ManagedChannel {
private final String userAgent;
private final Object lock = new Object();
/* Compression related */
/**
* When a client connects to a server, it does not know what encodings are supported. This set
* is the union of all accept-encoding headers the server has sent. It is used to pick an
* encoding when contacting the server again. One problem with the gRPC protocol is that if
* there is only one RPC made (perhaps streaming, or otherwise long lived) an encoding will not
* be selected. To combat this you can preflight a request to the server to fill in the mapping
* for the next one. A better solution is if you have prior knowledge that the server supports
* an encoding, and fill this structure before the request.
*/
private final Set<String> knownAcceptEncodingRegistry =
Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
@ -335,8 +320,7 @@ public final class ManagedChannelImpl extends ManagedChannel {
scheduledExecutor)
.setUserAgent(userAgent)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry);
.setCompressorRegistry(compressorRegistry);
}
@Override

View File

@ -32,8 +32,7 @@
package io.grpc.internal;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Decompressor;
import io.grpc.Status;
import java.io.InputStream;
@ -67,19 +66,14 @@ public class NoopClientStream implements ClientStream {
@Override
public void halfClose() {}
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override
public void setMessageCompression(boolean enable) {
// noop
}
@Override
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
public void setCompressor(Compressor compressor) {}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}
public void setDecompressor(Decompressor decompressor) {}
}

View File

@ -31,13 +31,23 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
@ -46,22 +56,44 @@ import io.grpc.Status;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
private final ServerStream stream;
private final MethodDescriptor<ReqT, RespT> method;
private final Context.CancellableContext context;
private Metadata inboundHeaders;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
// state
private volatile boolean cancelled;
private boolean sendHeadersCalled;
private boolean closeCalled;
private Compressor compressor;
ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
Context.CancellableContext context) {
Metadata inboundHeaders, Context.CancellableContext context,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) {
this.stream = stream;
this.method = method;
this.context = context;
this.inboundHeaders = inboundHeaders;
this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry;
if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY);
Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding);
if (decompressor == null) {
throw Status.INTERNAL
.withDescription(String.format("Can't find decompressor for %s", encoding))
.asRuntimeException();
}
stream.setDecompressor(decompressor);
}
}
@Override
@ -73,6 +105,44 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
public void sendHeaders(Metadata headers) {
checkState(!sendHeadersCalled, "sendHeaders has already been called");
checkState(!closeCalled, "call is closed");
headers.removeAll(MESSAGE_ENCODING_KEY);
if (compressor == null) {
compressor = Codec.Identity.NONE;
if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) {
String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
for (String acceptEncoding : ACCEPT_ENCODING_SPLITER.split(acceptEncodings)) {
Compressor c = compressorRegistry.lookupCompressor(acceptEncoding);
if (c != null) {
compressor = c;
break;
}
}
}
} else {
if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) {
String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
List<String> acceptedEncodingsList = ACCEPT_ENCODING_SPLITER.splitToList(acceptEncodings);
if (!acceptedEncodingsList.contains(compressor.getMessageEncoding())) {
// resort to using no compression.
compressor = Codec.Identity.NONE;
}
} else {
compressor = Codec.Identity.NONE;
}
}
inboundHeaders = null;
if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
}
stream.setCompressor(compressor);
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
Set<String> acceptEncodings = decompressorRegistry.getAdvertisedMessageEncodings();
if (!acceptEncodings.isEmpty()) {
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, ACCEPT_ENCODING_JOINER.join(acceptEncodings));
}
// Don't check if sendMessage has been called, since it requires that sendHeaders was already
// called.
sendHeadersCalled = true;
@ -98,6 +168,15 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
stream.setMessageCompression(enable);
}
@Override
public void setCompression(String compressorName) {
// Added here to give a better error message.
checkState(!sendHeadersCalled, "sendHeaders has been called");
compressor = compressorRegistry.lookupCompressor(compressorName);
checkArgument(compressor != null, "Unable to find compressor by name %s", compressorName);
}
@Override
public boolean isReady() {
return stream.isReady();
@ -108,6 +187,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
try {
checkState(!closeCalled, "call already closed");
closeCalled = true;
inboundHeaders = null;
stream.close(status, trailers);
} finally {
if (status.getCode() == Status.Code.OK) {

View File

@ -40,7 +40,9 @@ import com.google.common.base.Throwables;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.Metadata;
import io.grpc.ServerCall;
@ -104,20 +106,26 @@ public final class ServerImpl extends io.grpc.Server {
private final ScheduledExecutorService timeoutService = SharedResourceHolder.get(TIMER_SERVICE);
private final Context rootContext;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
/**
* Construct a server.
*
* @param executor to call methods on behalf of remote clients
* @param registry of methods to expose to remote clients.
*/
ServerImpl(Executor executor, HandlerRegistry registry,
io.grpc.internal.Server transportServer, Context rootContext) {
ServerImpl(Executor executor, HandlerRegistry registry, io.grpc.internal.Server transportServer,
Context rootContext, DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry) {
this.executor = executor;
this.registry = Preconditions.checkNotNull(registry, "registry");
this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer");
// Fork from the passed in context so that it does not propagate cancellation, it only
// inherits values.
this.rootContext = Preconditions.checkNotNull(rootContext).fork();
this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry;
}
/**
@ -355,7 +363,8 @@ public final class ServerImpl extends io.grpc.Server {
Metadata headers, Context.CancellableContext context) {
// TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
stream, methodDef.getMethodDescriptor(), context);
stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry,
compressorRegistry);
ServerCall.Listener<ReqT> listener = methodDef.getServerCallHandler()
.startCall(methodDef.getMethodDescriptor(), call, headers);
if (listener == null) {

View File

@ -83,7 +83,7 @@ final class SingleTransportChannel extends Channel {
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor,
new SerializingExecutor(executor), callOptions, transportProvider,
deadlineCancellationExecutor).setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry);
deadlineCancellationExecutor);
}
@Override

View File

@ -32,13 +32,10 @@
package io.grpc.internal;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Decompressor;
import java.io.InputStream;
import javax.annotation.Nullable;
/**
* A single stream of communication between two end-points within a transport.
*
@ -85,43 +82,22 @@ public interface Stream {
boolean isReady();
/**
* Picks a compressor for for this stream. If no message encodings are acceptable, compression is
* not used. It is undefined if this this method is invoked multiple times. If the stream has
* a {@code start()} method, pickCompressor must be called prior to start.
* Sets the compressor on the framer.
*
*
* @param messageEncodings a group of message encoding names that the remote endpoint is known
* to support.
* @return The compressor chosen for the stream, or null if none selected.
* @param compressor the compressor to use
*/
@Nullable
Compressor pickCompressor(Iterable<String> messageEncodings);
void setCompressor(Compressor compressor);
/**
* Sets the decompressor on the deframer.
*
* @param decompressor the decompressor to use.
*/
void setDecompressor(Decompressor decompressor);
/**
* Enables per-message compression, if an encoding type has been negotiated. If no message
* encoding has been negotiated, this is a no-op.
*/
void setMessageCompression(boolean enable);
/**
* Sets the decompressor registry to use when resolving {@code #setDecompressor(String)}. If
* unset, the default DecompressorRegistry will be used. If the stream has a {@code start()}
* method, setDecompressionRegistry must be called prior to start.
*
* @see DecompressorRegistry#getDefaultInstance()
*
* @param registry the decompressors to use.
*/
void setDecompressionRegistry(DecompressorRegistry registry);
/**
* Sets the compressor registry to use when resolving {@link #pickCompressor}. If unset, the
* default CompressorRegistry will be used. If the stream has a {@code start()} method,
* setCompressionRegistry must be called prior to start.
*
* @see CompressorRegistry#getDefaultInstance()
*
* @param registry the compressors to use.
*/
void setCompressionRegistry(CompressorRegistry registry);
}

View File

@ -239,23 +239,6 @@ public class AbstractClientStreamTest {
verify(mockListener).headersRead(headers);
}
@Test
public void inboundHeadersReceived_notifiesListenerOnBadEncoding() {
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "bad");
Metadata.Key<String> randomKey = Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER);
headers.put(randomKey, "4");
stream.inboundHeadersReceived(headers);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(mockListener).closed(statusCaptor.capture(), metadataCaptor.capture());
assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals("4", metadataCaptor.getValue().get(randomKey));
}
@Test
public void rstStreamClosesStream() {
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);

View File

@ -55,7 +55,6 @@ import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.Codec;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
@ -80,7 +79,6 @@ import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
@ -103,8 +101,6 @@ public class ClientCallImplTest {
private final FakeClock fakeClock = new FakeClock();
private final ScheduledExecutorService deadlineCancellationExecutor =
fakeClock.scheduledExecutorService;
private final Set<String> knownMessageEncodings = new HashSet<String>();
private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance();
private final MethodDescriptor<Void, Void> method = MethodDescriptor.create(
@ -159,9 +155,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT,
provider,
deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
.setDecompressorRegistry(decompressorRegistry);
call.start(callListener, new Metadata());
@ -178,8 +172,8 @@ public class ClientCallImplTest {
public void prepareHeaders_authorityAdded() {
Metadata m = new Metadata();
CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth");
ClientCallImpl.prepareHeaders(m, callOptions, "user agent", knownMessageEncodings,
decompressorRegistry, compressorRegistry);
ClientCallImpl.prepareHeaders(m, callOptions, "user agent", decompressorRegistry,
Codec.Identity.NONE);
assertEquals(m.get(GrpcUtil.AUTHORITY_KEY), "auth");
}
@ -187,28 +181,17 @@ public class ClientCallImplTest {
@Test
public void prepareHeaders_userAgentAdded() {
Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings,
decompressorRegistry, compressorRegistry);
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry,
Codec.Identity.NONE);
assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent");
}
@Test
public void prepareHeaders_messageEncodingAdded() {
Metadata m = new Metadata();
knownMessageEncodings.add(new Codec.Gzip().getMessageEncoding());
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings,
decompressorRegistry, compressorRegistry);
assertEquals(m.get(GrpcUtil.MESSAGE_ENCODING_KEY), new Codec.Gzip().getMessageEncoding());
}
@Test
public void prepareHeaders_ignoreIdentityEncoding() {
Metadata m = new Metadata();
knownMessageEncodings.add(Codec.Identity.NONE.getMessageEncoding());
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings,
decompressorRegistry, compressorRegistry);
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry,
Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
}
@ -251,8 +234,8 @@ public class ClientCallImplTest {
}
}, false); // not advertised
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings,
customRegistry, compressorRegistry);
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry,
Codec.Identity.NONE);
Iterable<String> acceptedEncodings =
ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
@ -269,8 +252,8 @@ public class ClientCallImplTest {
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, knownMessageEncodings,
DecompressorRegistry.newEmptyInstance(), compressorRegistry);
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null,
DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.AUTHORITY_KEY));
assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
@ -294,9 +277,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT,
provider,
deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
.setDecompressorRegistry(decompressorRegistry);
Context.ROOT.attach();
@ -372,9 +353,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT,
provider,
deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
.setDecompressorRegistry(decompressorRegistry);
previous.attach();
@ -454,9 +433,7 @@ public class ClientCallImplTest {
callOptions,
provider,
deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
.setDecompressorRegistry(decompressorRegistry);
call.start(callListener, new Metadata());
assertFalse(future.isDone());
fakeClock.forwardTime(1, TimeUnit.SECONDS);

View File

@ -36,8 +36,7 @@ import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
@ -77,10 +76,8 @@ public class DelayedStreamTest {
@Test
public void setStream_sendsAllMessages() {
stream.start(listener);
DecompressorRegistry decompressors = DecompressorRegistry.newEmptyInstance();
CompressorRegistry compressors = CompressorRegistry.newEmptyInstance();
stream.setDecompressionRegistry(decompressors);
stream.setCompressionRegistry(compressors);
stream.setCompressor(Codec.Identity.NONE);
stream.setDecompressor(Codec.Identity.NONE);
stream.setMessageCompression(true);
InputStream message = new ByteArrayInputStream(new byte[]{'a'});
@ -90,8 +87,8 @@ public class DelayedStreamTest {
stream.setStream(realStream);
verify(realStream).setDecompressionRegistry(decompressors);
verify(realStream).setCompressionRegistry(compressors);
verify(realStream).setCompressor(Codec.Identity.NONE);
verify(realStream).setDecompressor(Codec.Identity.NONE);
// Verify that the order was correct, even though they should be interleaved with the
// writeMessage calls

View File

@ -54,6 +54,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.IntegerMarshaller;
@ -195,8 +196,7 @@ public class ManagedChannelImplTest {
ClientTransport.Listener transportListener = transportListenerCaptor.getValue();
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers));
verify(mockStream).start(streamListenerCaptor.capture());
verify(mockStream).setDecompressionRegistry(isA(DecompressorRegistry.class));
verify(mockStream).setCompressionRegistry(isA(CompressorRegistry.class));
verify(mockStream).setCompressor(isA(Compressor.class));
ClientStreamListener streamListener = streamListenerCaptor.getValue();
// Second call

View File

@ -44,7 +44,9 @@ import static org.mockito.Mockito.when;
import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.Futures;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
@ -89,7 +91,8 @@ public class ServerCallImplTest {
public void setUp() {
MockitoAnnotations.initMocks(this);
context = Context.ROOT.withCancellation();
call = new ServerCallImpl<Long, Long>(stream, method, context);
call = new ServerCallImpl<Long, Long>(stream, method, new Metadata(), context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance());
}
@Test

View File

@ -49,7 +49,10 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.IntegerMarshaller;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -91,6 +94,9 @@ public class ServerImplTest {
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
private static final Context.CancellableContext SERVER_CONTEXT =
Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation();
private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance();
static {
// Cancel the root context. Server will fork it so the per-call context should not
@ -101,7 +107,8 @@ public class ServerImplTest {
private ExecutorService executor = Executors.newSingleThreadExecutor();
private MutableHandlerRegistry registry = new MutableHandlerRegistryImpl();
private SimpleServer transportServer = new SimpleServer();
private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT);
private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
@Mock
private ServerStream stream;
@ -129,7 +136,8 @@ public class ServerImplTest {
@Override
public void shutdown() {}
};
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT);
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start();
server.shutdown();
assertTrue(server.isShutdown());
@ -146,7 +154,8 @@ public class ServerImplTest {
throw new AssertionError("Should not be called, because wasn't started");
}
};
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT);
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.shutdown();
assertTrue(server.isShutdown());
assertTrue(server.isTerminated());
@ -154,7 +163,8 @@ public class ServerImplTest {
@Test
public void startStopImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT);
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown;
@ -186,7 +196,7 @@ public class ServerImplTest {
}
ServerImpl server = new ServerImpl(executor, registry, new FailingStartupServer(),
SERVER_CONTEXT);
SERVER_CONTEXT, decompressorRegistry, compressorRegistry);
try {
server.start();
fail("expected exception");
@ -240,6 +250,7 @@ public class ServerImplTest {
responseHeaders.put(metadataKey, "response value");
call.sendHeaders(responseHeaders);
verify(stream).writeHeaders(responseHeaders);
verify(stream).setCompressor(isA(Compressor.class));
call.sendMessage(314);
ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class);
@ -322,7 +333,8 @@ public class ServerImplTest {
}
transportServer = new MaybeDeadlockingServer();
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT);
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start();
new Thread() {
@Override

View File

@ -230,11 +230,6 @@ public class CompressionTest {
assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
}
// Must be null for the first call.
assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
assertFalse(clientCodec.anyWritten);
assertFalse(serverCodec.anyRead);
if (clientAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
} else {
@ -242,7 +237,6 @@ public class CompressionTest {
}
// Second call, once the client knows what the server supports.
stub.unaryCall(REQUEST);
if (clientEncoding && serverAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
if (enableClientMessageCompression) {
@ -274,7 +268,11 @@ public class CompressionTest {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
final ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
if (clientEncoding && serverAcceptEncoding) {
callOptions = callOptions.withCompression("fzip");
}
ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
return new ClientCompressor<ReqT, RespT>(call);
}
}

View File

@ -35,8 +35,6 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static io.netty.channel.ChannelOption.SO_BACKLOG;
import static io.netty.channel.ChannelOption.SO_KEEPALIVE;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.internal.Server;
import io.grpc.internal.ServerListener;
import io.grpc.internal.SharedResourceHolder;
@ -74,8 +72,6 @@ class NettyServer implements Server {
private EventLoopGroup workerGroup;
private ServerListener listener;
private Channel channel;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final int flowControlWindow;
private final int maxMessageSize;
private final int maxHeaderListSize;
@ -83,8 +79,7 @@ class NettyServer implements Server {
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
ProtocolNegotiator protocolNegotiator, DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry, int maxStreamsPerConnection,
ProtocolNegotiator protocolNegotiator, int maxStreamsPerConnection,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
this.address = address;
this.channelType = checkNotNull(channelType, "channelType");
@ -97,8 +92,6 @@ class NettyServer implements Server {
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
}
@Override
@ -125,10 +118,8 @@ class NettyServer implements Server {
eventLoopReferenceCounter.release();
}
});
NettyServerTransport transport
= new NettyServerTransport(ch, protocolNegotiator, decompressorRegistry,
compressorRegistry, maxStreamsPerConnection, flowControlWindow, maxMessageSize,
maxHeaderListSize);
NettyServerTransport transport = new NettyServerTransport(ch, protocolNegotiator,
maxStreamsPerConnection, flowControlWindow, maxMessageSize, maxHeaderListSize);
transport.start(listener.transportCreated(transport));
}
});

View File

@ -31,14 +31,11 @@
package io.grpc.netty;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.ExperimentalApi;
import io.grpc.HandlerRegistry;
import io.grpc.Internal;
@ -244,12 +241,9 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
negotiator = sslContext != null ? ProtocolNegotiators.serverTls(sslContext) :
ProtocolNegotiators.serverPlaintext();
}
return new NettyServer(address, channelType, bossEventLoopGroup,
workerEventLoopGroup, negotiator,
firstNonNull(decompressorRegistry(), DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry(), CompressorRegistry.getDefaultInstance()),
maxConcurrentCallsPerConnection, flowControlWindow,
maxMessageSize, maxHeaderListSize);
return new NettyServer(address, channelType, bossEventLoopGroup, workerEventLoopGroup,
negotiator, maxConcurrentCallsPerConnection, flowControlWindow, maxMessageSize,
maxHeaderListSize);
}
@Override

View File

@ -43,8 +43,6 @@ import static io.netty.handler.codec.http2.DefaultHttp2LocalFlowController.DEFAU
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
@ -97,8 +95,6 @@ class NettyServerHandler extends AbstractNettyHandler {
private static final Status GOAWAY_STATUS = Status.UNAVAILABLE;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final Http2Connection.PropertyKey streamKey;
private final ServerTransportListener transportListener;
private final int maxMessageSize;
@ -107,8 +103,6 @@ class NettyServerHandler extends AbstractNettyHandler {
private WriteQueue serverWriteQueue;
static NettyServerHandler newHandler(ServerTransportListener transportListener,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxStreams,
int flowControlWindow,
int maxHeaderListSize,
@ -121,15 +115,13 @@ class NettyServerHandler extends AbstractNettyHandler {
new DefaultHttp2FrameReader(headersDecoder), frameLogger);
Http2FrameWriter frameWriter =
new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger);
return newHandler(frameReader, frameWriter, transportListener, decompressorRegistry,
compressorRegistry, maxStreams, flowControlWindow, maxMessageSize);
return newHandler(frameReader, frameWriter, transportListener, maxStreams, flowControlWindow,
maxMessageSize);
}
@VisibleForTesting
static NettyServerHandler newHandler(Http2FrameReader frameReader, Http2FrameWriter frameWriter,
ServerTransportListener transportListener,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxStreams,
int flowControlWindow,
int maxMessageSize) {
@ -151,21 +143,16 @@ class NettyServerHandler extends AbstractNettyHandler {
settings.initialWindowSize(flowControlWindow);
settings.maxConcurrentStreams(maxStreams);
return new NettyServerHandler(transportListener, decoder, encoder, settings,
decompressorRegistry, compressorRegistry, maxMessageSize);
return new NettyServerHandler(transportListener, decoder, encoder, settings, maxMessageSize);
}
private NettyServerHandler(ServerTransportListener transportListener,
Http2ConnectionDecoder decoder,
Http2ConnectionEncoder encoder, Http2Settings settings,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxMessageSize) {
super(decoder, encoder, settings);
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
streamKey = encoder.connection().newKey();
this.transportListener = checkNotNull(transportListener, "transportListener");
@ -206,11 +193,6 @@ class NettyServerHandler extends AbstractNettyHandler {
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this,
maxMessageSize);
// These must be called before inboundHeadersReceived, because the framers depend on knowing
// the compression algorithms available before negotiation.
stream.setDecompressionRegistry(decompressorRegistry);
stream.setCompressionRegistry(compressorRegistry);
Metadata metadata = Utils.convertHeaders(headers);
stream.inboundHeadersReceived(metadata);

View File

@ -31,12 +31,8 @@
package io.grpc.netty;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener;
import io.netty.channel.Channel;
@ -55,8 +51,6 @@ class NettyServerTransport implements ServerTransport {
private final Channel channel;
private final ProtocolNegotiator protocolNegotiator;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final int maxStreams;
private ServerTransportListener listener;
private boolean terminated;
@ -64,18 +58,14 @@ class NettyServerTransport implements ServerTransport {
private final int maxMessageSize;
private final int maxHeaderListSize;
NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry, int maxStreams, int flowControlWindow,
int maxMessageSize, int maxHeaderListSize) {
NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, int maxStreams,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
this.channel = Preconditions.checkNotNull(channel, "channel");
this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator");
this.maxStreams = maxStreams;
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
}
public void start(ServerTransportListener listener) {
@ -125,7 +115,7 @@ class NettyServerTransport implements ServerTransport {
* Creates the Netty handler to be used in the channel pipeline.
*/
private NettyServerHandler createHandler(ServerTransportListener transportListener) {
return NettyServerHandler.newHandler(transportListener, decompressorRegistry,
compressorRegistry, maxStreams, flowControlWindow, maxHeaderListSize, maxMessageSize);
return NettyServerHandler.newHandler(transportListener, maxStreams, flowControlWindow,
maxHeaderListSize, maxMessageSize);
}
}

View File

@ -43,8 +43,6 @@ import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
@ -321,10 +319,8 @@ public class NettyClientTransportTest {
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key)
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(serverContext);
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, negotiator, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance(), maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize);
server = new NettyServer(address, NioServerSocketChannel.class, group, group, negotiator,
maxStreamsPerConnection, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize);
server.start(serverListener);
}

View File

@ -54,8 +54,6 @@ import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
@ -347,7 +345,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
@Override
protected NettyServerHandler newHandler() {
return NettyServerHandler.newHandler(frameReader(), frameWriter(), transportListener,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
maxConcurrentStreams, flowControlWindow, DEFAULT_MAX_MESSAGE_SIZE);
}

View File

@ -50,9 +50,7 @@ import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ServerStreamListener;
import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelPromise;
@ -96,7 +94,6 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
stream.writeHeaders(new Metadata());
Http2Headers headers = new DefaultHttp2Headers()
.status(Utils.STATUS_OK)
.set(GrpcUtil.MESSAGE_ACCEPT_ENCODING, AsciiString.of("gzip"))
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
byte[] msg = smallMessage();

View File

@ -35,6 +35,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.ExperimentalApi;
import java.util.concurrent.TimeUnit;
@ -116,6 +117,20 @@ public abstract class AbstractStub<S extends AbstractStub<S>> {
return build(channel, callOptions.withDeadlineAfter(duration, unit));
}
/**
* Set's the compressor name to use for the call. It is the responsibility of the application
* to make sure the server supports decoding the compressor picked by the client. To be clear,
* this is the compressor used by the stub to compress messages to the server. To get
* compressed responses from the server, set the appropriate {@link io.grpc.DecompressorRegistry}
* on the {@link io.grpc.ManagedChannelBuilder}.
*
* @param compressorName the name (e.g. "gzip") of the compressor to use.
*/
@ExperimentalApi
public final S withCompression(String compressorName) {
return build(channel, callOptions.withCompression(compressorName));
}
/**
* Returns a new stub that uses the given channel.
*/