Add a simple compression API

This commit is contained in:
Carl Mastrangelo 2016-01-27 16:49:41 -08:00
parent 6af2ddafe5
commit a3c79e87ae
31 changed files with 305 additions and 416 deletions

View File

@ -63,6 +63,9 @@ public final class CallOptions {
@Nullable @Nullable
private RequestKey requestKey; private RequestKey requestKey;
@Nullable
private String compressorName;
/** /**
* Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not * Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not
* generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple * generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple
@ -79,6 +82,17 @@ public final class CallOptions {
return newOptions; return newOptions;
} }
/**
* Sets the compression to use for the call. The compressor must be a valid name known in the
* {@link CompressorRegistry}.
*/
@ExperimentalApi
public CallOptions withCompression(@Nullable String compressorName) {
CallOptions newOptions = new CallOptions(this);
newOptions.compressorName = compressorName;
return newOptions;
}
/** /**
* Returns a new {@code CallOptions} with the given absolute deadline in nanoseconds in the clock * Returns a new {@code CallOptions} with the given absolute deadline in nanoseconds in the clock
* as per {@link System#nanoTime()}. * as per {@link System#nanoTime()}.
@ -131,6 +145,16 @@ public final class CallOptions {
return requestKey; return requestKey;
} }
/**
* Returns the compressor's name.
*/
@ExperimentalApi
@Nullable
public String getCompressor() {
return compressorName;
}
/** /**
* Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not * Override the HTTP/2 authority the channel claims to be connecting to. <em>This is not
* generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple * generally safe.</em> Overriding allows advanced users to re-use a single Channel for multiple
@ -172,6 +196,7 @@ public final class CallOptions {
authority = other.authority; authority = other.authority;
requestKey = other.requestKey; requestKey = other.requestKey;
executor = other.executor; executor = other.executor;
compressorName = other.compressorName;
} }
@Override @Override

View File

@ -76,6 +76,12 @@ public abstract class ForwardingServerCall<RespT> extends ServerCall<RespT> {
delegate().setMessageCompression(enabled); delegate().setMessageCompression(enabled);
} }
@Override
@ExperimentalApi
public void setCompression(String compressor) {
delegate().setCompression(compressor);
}
/** /**
* A simplified version of {@link ForwardingServerCall} where subclasses can pass in a {@link * A simplified version of {@link ForwardingServerCall} where subclasses can pass in a {@link
* ServerCall} as the delegate. * ServerCall} as the delegate.

View File

@ -176,4 +176,18 @@ public abstract class ServerCall<RespT> {
public void setMessageCompression(boolean enabled) { public void setMessageCompression(boolean enabled) {
// noop // noop
} }
/**
* Sets the compression algorithm for this call. If the server does not support the compression
* algorithm, the call will fail. This method may only be called before {@link #sendHeaders}.
* The compressor to use will be looked up in the {@link CompressorRegistry}. Default gRPC
* servers support the "gzip" compressor.
*
* @param compressor the name of the compressor to use.
* @throws IllegalArgumentException if the compressor name can not be found.
*/
@ExperimentalApi
public void setCompression(String compressor) {
// noop
}
} }

View File

@ -34,8 +34,7 @@ package io.grpc.inprocess;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -230,9 +229,6 @@ class InProcessTransport implements ServerTransport, ClientTransport {
clientStreamListener = listener; clientStreamListener = listener;
} }
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override @Override
public void request(int numMessages) { public void request(int numMessages) {
clientStream.serverRequested(numMessages); clientStream.serverRequested(numMessages);
@ -345,12 +341,10 @@ class InProcessTransport implements ServerTransport, ClientTransport {
} }
@Override @Override
public Compressor pickCompressor(Iterable<String> messageEncodings) { public void setCompressor(Compressor compressor) {}
return null;
}
@Override @Override
public void setCompressionRegistry(CompressorRegistry registry) {} public void setDecompressor(Decompressor decompressor) {}
} }
private class InProcessClientStream implements ClientStream { private class InProcessClientStream implements ClientStream {
@ -457,20 +451,9 @@ class InProcessTransport implements ServerTransport, ClientTransport {
} }
} }
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override @Override
public void setMessageCompression(boolean enable) {} public void setMessageCompression(boolean enable) {}
@Override
public Compressor pickCompressor(Iterable<String> messageEncodings) {
return null;
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {}
@Override @Override
public void start(ClientStreamListener listener) { public void start(ClientStreamListener listener) {
serverStream.setListener(listener); serverStream.setListener(listener);
@ -483,6 +466,12 @@ class InProcessTransport implements ServerTransport, ClientTransport {
streams.add(InProcessTransport.InProcessStream.this); streams.add(InProcessTransport.InProcessStream.this);
} }
} }
@Override
public void setCompressor(Compressor compressor) {}
@Override
public void setDecompressor(Decompressor decompressor) {}
} }
} }
} }

View File

@ -122,21 +122,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
log.log(Level.INFO, "Received headers on closed stream {0} {1}", log.log(Level.INFO, "Received headers on closed stream {0} {1}",
new Object[]{id(), headers}); new Object[]{id(), headers});
} }
if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) {
String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
try {
setDecompressor(messageEncoding);
} catch (IllegalArgumentException e) {
// Don't use INVALID_ARGUMENT since that is for servers to send clients.
Status status = Status.INTERNAL.withDescription("Unable to decompress message from server.")
.withCause(e);
// TODO(carl-mastrangelo): look back into tearing down this stream. sendCancel() can be
// buffered.
inboundTransportError(status, headers);
sendCancel(status);
return;
}
}
inboundPhase(Phase.MESSAGE); inboundPhase(Phase.MESSAGE);
listener.headersRead(headers); listener.headersRead(headers);

View File

@ -31,6 +31,8 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.MoreObjects.firstNonNull;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
@ -112,24 +114,18 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
return thisT(); return thisT();
} }
protected final DecompressorRegistry decompressorRegistry() {
return decompressorRegistry;
}
@Override @Override
public final T compressorRegistry(CompressorRegistry registry) { public final T compressorRegistry(CompressorRegistry registry) {
compressorRegistry = registry; compressorRegistry = registry;
return thisT(); return thisT();
} }
protected final CompressorRegistry compressorRegistry() {
return compressorRegistry;
}
@Override @Override
public ServerImpl build() { public ServerImpl build() {
io.grpc.internal.Server transportServer = buildTransportServer(); io.grpc.internal.Server transportServer = buildTransportServer();
return new ServerImpl(executor, registry, transportServer, Context.ROOT); return new ServerImpl(executor, registry, transportServer, Context.ROOT,
firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()));
} }
/** /**

View File

@ -32,14 +32,9 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Compressor;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
@ -61,7 +56,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
private ServerStreamListener listener; private ServerStreamListener listener;
private boolean headersSent = false; private boolean headersSent = false;
private String messageEncoding;
/** /**
* Whether the stream was closed gracefully by the application (vs. a transport-level failure). * Whether the stream was closed gracefully by the application (vs. a transport-level failure).
*/ */
@ -100,16 +94,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
@Override @Override
public final void writeHeaders(Metadata headers) { public final void writeHeaders(Metadata headers) {
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
headers.removeAll(MESSAGE_ENCODING_KEY);
if (messageEncoding != null) {
headers.put(MESSAGE_ENCODING_KEY, messageEncoding);
}
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
if (!decompressorRegistry().getAdvertisedMessageEncodings().isEmpty()) {
String acceptEncoding =
ACCEPT_ENCODING_JOINER.join(decompressorRegistry().getAdvertisedMessageEncodings());
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding);
}
outboundPhase(Phase.HEADERS); outboundPhase(Phase.HEADERS);
headersSent = true; headersSent = true;
@ -152,34 +136,6 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
* @param headers the parsed headers * @param headers the parsed headers
*/ */
protected void inboundHeadersReceived(Metadata headers) { protected void inboundHeadersReceived(Metadata headers) {
if (headers.containsKey(GrpcUtil.MESSAGE_ENCODING_KEY)) {
String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
try {
setDecompressor(messageEncoding);
} catch (IllegalArgumentException e) {
Status status = Status.INVALID_ARGUMENT
.withDescription("Unable to decompress encoding " + messageEncoding)
.withCause(e);
abortStream(status, true);
return;
}
}
// This checks to see if the client will accept any encoding. If so, a compressor is picked for
// the stream, and the decision is recorded. When the Server Call Handler writes the first
// headers, the negotiated encoding will be added in #writeHeaders(). It is safe to call
// pickCompressor multiple times before the headers have been written to the wire, though in
// practice this should never happen. There should only be one call to inboundHeadersReceived.
// Alternatively, compression could be negotiated after the server handler is invoked, but that
// would mean the inbound header would have to be stored until the first #writeHeaders call.
if (headers.containsKey(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)) {
Compressor c = pickCompressor(
ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)));
if (c != null) {
messageEncoding = c.getMessageEncoding();
}
}
inboundPhase(Phase.MESSAGE); inboundPhase(Phase.MESSAGE);
} }

View File

@ -31,7 +31,6 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
@ -40,9 +39,7 @@ import com.google.common.base.MoreObjects;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Decompressor; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import java.io.InputStream; import java.io.InputStream;
@ -101,10 +98,6 @@ public abstract class AbstractStream<IdT> implements Stream {
private boolean allocated; private boolean allocated;
private final Object onReadyLock = new Object(); private final Object onReadyLock = new Object();
private volatile DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance();
private volatile CompressorRegistry compressorRegistry =
CompressorRegistry.getDefaultInstance();
@VisibleForTesting @VisibleForTesting
class FramerSink implements MessageFramer.Sink { class FramerSink implements MessageFramer.Sink {
@ -305,47 +298,14 @@ public abstract class AbstractStream<IdT> implements Stream {
} }
} }
/** @Override
* Looks up the decompressor by its message encoding name, and sets it for this stream. public final void setCompressor(Compressor compressor) {
* Decompressors are registered with {@link DecompressorRegistry#register}. framer.setCompressor(checkNotNull(compressor, "compressor"));
*
* @param messageEncoding the name of the encoding provided by the remote host
* @throws IllegalArgumentException if the provided message encoding cannot be found.
*/
protected final void setDecompressor(String messageEncoding) {
Decompressor d = decompressorRegistry.lookupDecompressor(messageEncoding);
checkArgument(d != null,
"Unable to find decompressor for message encoding %s", messageEncoding);
deframer.setDecompressor(d);
} }
@Override @Override
public final void setDecompressionRegistry(DecompressorRegistry registry) { public final void setDecompressor(Decompressor decompressor) {
decompressorRegistry = checkNotNull(registry); deframer.setDecompressor(checkNotNull(decompressor, "decompressor"));
}
@Override
public final void setCompressionRegistry(CompressorRegistry registry) {
compressorRegistry = checkNotNull(registry);
}
@Override
public final Compressor pickCompressor(Iterable<String> messageEncodings) {
for (String messageEncoding : messageEncodings) {
Compressor c = compressorRegistry.lookupCompressor(messageEncoding);
if (c != null) {
// TODO(carl-mastrangelo): check that headers haven't already been sent. I can't find where
// the client stream changes outbound phase correctly, so I am ignoring it.
framer.setCompressor(c);
return c;
}
}
return null;
}
// TODO(carl-mastrangelo): this is a hack to get around registry passing. Remove it.
protected final DecompressorRegistry decompressorRegistry() {
return decompressorRegistry;
} }
/** /**

View File

@ -33,10 +33,8 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.addAll;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY; import static io.grpc.internal.GrpcUtil.AUTHORITY_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
@ -57,6 +55,7 @@ import io.grpc.Codec;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -64,8 +63,6 @@ import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status; import io.grpc.Status;
import java.io.InputStream; import java.io.InputStream;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
@ -91,7 +88,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
private final ClientTransportProvider clientTransportProvider; private final ClientTransportProvider clientTransportProvider;
private String userAgent; private String userAgent;
private ScheduledExecutorService deadlineCancellationExecutor; private ScheduledExecutorService deadlineCancellationExecutor;
private Set<String> knownMessageEncodingRegistry; private Compressor compressor;
private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance();
private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
@ -146,19 +143,9 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
return this; return this;
} }
/**
* Sets encodings known to be supported by the server. This set MUST be thread safe, and MAY be
* modified by any code as it learns about new supported encodings.
*/
ClientCallImpl<ReqT, RespT> setKnownMessageEncodingRegistry(Set<String> knownMessageEncodings) {
this.knownMessageEncodingRegistry = knownMessageEncodings;
return this;
}
@VisibleForTesting @VisibleForTesting
static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent, static void prepareHeaders(Metadata headers, CallOptions callOptions, String userAgent,
Set<String> knownMessageEncodings, DecompressorRegistry decompressorRegistry, DecompressorRegistry decompressorRegistry, Compressor compressor) {
CompressorRegistry compressorRegistry) {
// Hack to propagate authority. This should be properly pass to the transport.newStream // Hack to propagate authority. This should be properly pass to the transport.newStream
// somehow. // somehow.
headers.removeAll(AUTHORITY_KEY); headers.removeAll(AUTHORITY_KEY);
@ -173,12 +160,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
} }
headers.removeAll(MESSAGE_ENCODING_KEY); headers.removeAll(MESSAGE_ENCODING_KEY);
for (String messageEncoding : knownMessageEncodings) { if (compressor != Codec.Identity.NONE) {
Compressor compressor = compressorRegistry.lookupCompressor(messageEncoding); headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
if (compressor != null && compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
break;
}
} }
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
@ -207,8 +190,27 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
}); });
return; return;
} }
prepareHeaders(headers, callOptions, userAgent, final String compressorName = callOptions.getCompressor();
knownMessageEncodingRegistry, decompressorRegistry, compressorRegistry); if (compressorName != null) {
compressor = compressorRegistry.lookupCompressor(compressorName);
if (compressor == null) {
stream = NoopClientStream.INSTANCE;
callExecutor.execute(new ContextRunnable(context) {
@Override
public void runInContext() {
observer.onClose(
Status.INTERNAL.withDescription(
String.format("Unable to find compressor by name %s", compressorName)),
new Metadata());
}
});
return;
}
} else {
compressor = Codec.Identity.NONE;
}
prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor);
ListenableFuture<ClientTransport> transportFuture = clientTransportProvider.get(callOptions); ListenableFuture<ClientTransport> transportFuture = clientTransportProvider.get(callOptions);
@ -236,11 +238,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
transportFuture.isDone() ? directExecutor() : callExecutor); transportFuture.isDone() ? directExecutor() : callExecutor);
} }
stream.setDecompressionRegistry(decompressorRegistry); stream.setCompressor(compressor);
stream.setCompressionRegistry(compressorRegistry); if (compressor != Codec.Identity.NONE) {
if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
stream.pickCompressor(Collections.singleton(headers.get(MESSAGE_ENCODING_KEY)));
// TODO(carl-mastrangelo): move this to ClientCall.
stream.setMessageCompression(true); stream.setMessageCompression(true);
} }
@ -387,13 +386,18 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override @Override
public void headersRead(final Metadata headers) { public void headersRead(final Metadata headers) {
if (headers.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) { Decompressor decompressor = Codec.Identity.NONE;
// TODO(carl-mastrangelo): after the first time we contact the server, it almost certainly if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
// won't change. It might be possible to recover performance by not adding to the known String encoding = headers.get(MESSAGE_ENCODING_KEY);
// encodings if it isn't empty. decompressor = decompressorRegistry.lookupDecompressor(encoding);
String serverAcceptEncodings = headers.get(MESSAGE_ACCEPT_ENCODING_KEY); if (decompressor == null) {
addAll(knownMessageEncodingRegistry, ACCEPT_ENCODING_SPLITER.split(serverAcceptEncodings)); stream.cancel(Status.INTERNAL.withDescription(
String.format("Can't find decompressor for %s", encoding)));
return;
}
} }
stream.setDecompressor(decompressor);
callExecutor.execute(new ContextRunnable(context) { callExecutor.execute(new ContextRunnable(context) {
@Override @Override
public final void runInContext() { public final void runInContext() {

View File

@ -35,8 +35,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
@ -67,10 +66,6 @@ class DelayedStream implements ClientStream {
@GuardedBy("this") @GuardedBy("this")
private Status error; private Status error;
@GuardedBy("this")
private Iterable<String> compressionMessageEncodings;
@GuardedBy("this")
private DecompressorRegistry decompressionRegistry;
@GuardedBy("this") @GuardedBy("this")
private final List<PendingMessage> pendingMessages = new LinkedList<PendingMessage>(); private final List<PendingMessage> pendingMessages = new LinkedList<PendingMessage>();
private boolean messageCompressionEnabled; private boolean messageCompressionEnabled;
@ -81,7 +76,9 @@ class DelayedStream implements ClientStream {
@GuardedBy("this") @GuardedBy("this")
private boolean pendingFlush; private boolean pendingFlush;
@GuardedBy("this") @GuardedBy("this")
private CompressorRegistry compressionRegistry; private Compressor compressor;
@GuardedBy("this")
private Decompressor decompressor;
static final class PendingMessage { static final class PendingMessage {
final InputStream message; final InputStream message;
@ -118,15 +115,13 @@ class DelayedStream implements ClientStream {
checkState(listener != null, "listener"); checkState(listener != null, "listener");
realStream.start(listener); realStream.start(listener);
if (compressionMessageEncodings != null) { if (decompressor != null) {
realStream.pickCompressor(compressionMessageEncodings); realStream.setDecompressor(decompressor);
} }
if (this.decompressionRegistry != null) { if (compressor != null) {
realStream.setDecompressionRegistry(this.decompressionRegistry); realStream.setCompressor(compressor);
}
if (this.compressionRegistry != null) {
realStream.setCompressionRegistry(this.compressionRegistry);
} }
for (PendingMessage message : pendingMessages) { for (PendingMessage message : pendingMessages) {
realStream.setMessageCompression(message.shouldBeCompressed); realStream.setMessageCompression(message.shouldBeCompressed);
realStream.writeMessage(message.message); realStream.writeMessage(message.message);
@ -246,45 +241,29 @@ class DelayedStream implements ClientStream {
} }
@Override @Override
public Compressor pickCompressor(Iterable<String> messageEncodings) { public void setCompressor(Compressor compressor) {
if (startedRealStream == null) { if (startedRealStream == null) {
synchronized (this) { synchronized (this) {
if (startedRealStream == null) { if (startedRealStream == null) {
compressionMessageEncodings = messageEncodings; this.compressor = compressor;
// ClientCall never uses this. Since the stream doesn't exist yet, it can't say what
// stream it would pick. Eventually this will need a cleaner solution.
// TODO(carl-mastrangelo): Remove this.
return null;
}
}
}
return startedRealStream.pickCompressor(messageEncodings);
}
@Override
public void setCompressionRegistry(CompressorRegistry registry) {
if (startedRealStream == null) {
synchronized (this) {
if (startedRealStream == null) {
compressionRegistry = registry;
return; return;
} }
} }
} }
startedRealStream.setCompressionRegistry(registry); startedRealStream.setCompressor(compressor);
} }
@Override @Override
public void setDecompressionRegistry(DecompressorRegistry registry) { public void setDecompressor(Decompressor decompressor) {
if (startedRealStream == null) { if (startedRealStream == null) {
synchronized (this) { synchronized (this) {
if (startedRealStream == null) { if (startedRealStream == null) {
decompressionRegistry = registry; this.decompressor = decompressor;
return; return;
} }
} }
} }
startedRealStream.setDecompressionRegistry(registry); startedRealStream.setDecompressor(decompressor);
} }
@Override @Override

View File

@ -60,12 +60,9 @@ import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -96,18 +93,6 @@ public final class ManagedChannelImpl extends ManagedChannel {
private final String userAgent; private final String userAgent;
private final Object lock = new Object(); private final Object lock = new Object();
/* Compression related */
/**
* When a client connects to a server, it does not know what encodings are supported. This set
* is the union of all accept-encoding headers the server has sent. It is used to pick an
* encoding when contacting the server again. One problem with the gRPC protocol is that if
* there is only one RPC made (perhaps streaming, or otherwise long lived) an encoding will not
* be selected. To combat this you can preflight a request to the server to fill in the mapping
* for the next one. A better solution is if you have prior knowledge that the server supports
* an encoding, and fill this structure before the request.
*/
private final Set<String> knownAcceptEncodingRegistry =
Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
private final DecompressorRegistry decompressorRegistry; private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry; private final CompressorRegistry compressorRegistry;
@ -335,8 +320,7 @@ public final class ManagedChannelImpl extends ManagedChannel {
scheduledExecutor) scheduledExecutor)
.setUserAgent(userAgent) .setUserAgent(userAgent)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry)
.setCompressorRegistry(compressorRegistry) .setCompressorRegistry(compressorRegistry);
.setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry);
} }
@Override @Override

View File

@ -32,8 +32,7 @@
package io.grpc.internal; package io.grpc.internal;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Status; import io.grpc.Status;
import java.io.InputStream; import java.io.InputStream;
@ -67,19 +66,14 @@ public class NoopClientStream implements ClientStream {
@Override @Override
public void halfClose() {} public void halfClose() {}
@Override
public void setDecompressionRegistry(DecompressorRegistry registry) {}
@Override @Override
public void setMessageCompression(boolean enable) { public void setMessageCompression(boolean enable) {
// noop // noop
} }
@Override @Override
public Compressor pickCompressor(Iterable<String> messageEncodings) { public void setCompressor(Compressor compressor) {}
return null;
}
@Override @Override
public void setCompressionRegistry(CompressorRegistry registry) {} public void setDecompressor(Decompressor decompressor) {}
} }

View File

@ -31,13 +31,23 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
@ -46,22 +56,44 @@ import io.grpc.Status;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future; import java.util.concurrent.Future;
final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> { final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
private final ServerStream stream; private final ServerStream stream;
private final MethodDescriptor<ReqT, RespT> method; private final MethodDescriptor<ReqT, RespT> method;
private final Context.CancellableContext context; private final Context.CancellableContext context;
private Metadata inboundHeaders;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
// state // state
private volatile boolean cancelled; private volatile boolean cancelled;
private boolean sendHeadersCalled; private boolean sendHeadersCalled;
private boolean closeCalled; private boolean closeCalled;
private Compressor compressor;
ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method, ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
Context.CancellableContext context) { Metadata inboundHeaders, Context.CancellableContext context,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) {
this.stream = stream; this.stream = stream;
this.method = method; this.method = method;
this.context = context; this.context = context;
this.inboundHeaders = inboundHeaders;
this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry;
if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY);
Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding);
if (decompressor == null) {
throw Status.INTERNAL
.withDescription(String.format("Can't find decompressor for %s", encoding))
.asRuntimeException();
}
stream.setDecompressor(decompressor);
}
} }
@Override @Override
@ -73,6 +105,44 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
public void sendHeaders(Metadata headers) { public void sendHeaders(Metadata headers) {
checkState(!sendHeadersCalled, "sendHeaders has already been called"); checkState(!sendHeadersCalled, "sendHeaders has already been called");
checkState(!closeCalled, "call is closed"); checkState(!closeCalled, "call is closed");
headers.removeAll(MESSAGE_ENCODING_KEY);
if (compressor == null) {
compressor = Codec.Identity.NONE;
if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) {
String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
for (String acceptEncoding : ACCEPT_ENCODING_SPLITER.split(acceptEncodings)) {
Compressor c = compressorRegistry.lookupCompressor(acceptEncoding);
if (c != null) {
compressor = c;
break;
}
}
}
} else {
if (inboundHeaders.containsKey(MESSAGE_ACCEPT_ENCODING_KEY)) {
String acceptEncodings = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
List<String> acceptedEncodingsList = ACCEPT_ENCODING_SPLITER.splitToList(acceptEncodings);
if (!acceptedEncodingsList.contains(compressor.getMessageEncoding())) {
// resort to using no compression.
compressor = Codec.Identity.NONE;
}
} else {
compressor = Codec.Identity.NONE;
}
}
inboundHeaders = null;
if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
}
stream.setCompressor(compressor);
headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
Set<String> acceptEncodings = decompressorRegistry.getAdvertisedMessageEncodings();
if (!acceptEncodings.isEmpty()) {
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, ACCEPT_ENCODING_JOINER.join(acceptEncodings));
}
// Don't check if sendMessage has been called, since it requires that sendHeaders was already // Don't check if sendMessage has been called, since it requires that sendHeaders was already
// called. // called.
sendHeadersCalled = true; sendHeadersCalled = true;
@ -98,6 +168,15 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
stream.setMessageCompression(enable); stream.setMessageCompression(enable);
} }
@Override
public void setCompression(String compressorName) {
// Added here to give a better error message.
checkState(!sendHeadersCalled, "sendHeaders has been called");
compressor = compressorRegistry.lookupCompressor(compressorName);
checkArgument(compressor != null, "Unable to find compressor by name %s", compressorName);
}
@Override @Override
public boolean isReady() { public boolean isReady() {
return stream.isReady(); return stream.isReady();
@ -108,6 +187,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
try { try {
checkState(!closeCalled, "call already closed"); checkState(!closeCalled, "call already closed");
closeCalled = true; closeCalled = true;
inboundHeaders = null;
stream.close(status, trailers); stream.close(status, trailers);
} finally { } finally {
if (status.getCode() == Status.Code.OK) { if (status.getCode() == Status.Code.OK) {

View File

@ -40,7 +40,9 @@ import com.google.common.base.Throwables;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry; import io.grpc.HandlerRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerCall; import io.grpc.ServerCall;
@ -104,20 +106,26 @@ public final class ServerImpl extends io.grpc.Server {
private final ScheduledExecutorService timeoutService = SharedResourceHolder.get(TIMER_SERVICE); private final ScheduledExecutorService timeoutService = SharedResourceHolder.get(TIMER_SERVICE);
private final Context rootContext; private final Context rootContext;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
/** /**
* Construct a server. * Construct a server.
* *
* @param executor to call methods on behalf of remote clients * @param executor to call methods on behalf of remote clients
* @param registry of methods to expose to remote clients. * @param registry of methods to expose to remote clients.
*/ */
ServerImpl(Executor executor, HandlerRegistry registry, ServerImpl(Executor executor, HandlerRegistry registry, io.grpc.internal.Server transportServer,
io.grpc.internal.Server transportServer, Context rootContext) { Context rootContext, DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry) {
this.executor = executor; this.executor = executor;
this.registry = Preconditions.checkNotNull(registry, "registry"); this.registry = Preconditions.checkNotNull(registry, "registry");
this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer"); this.transportServer = Preconditions.checkNotNull(transportServer, "transportServer");
// Fork from the passed in context so that it does not propagate cancellation, it only // Fork from the passed in context so that it does not propagate cancellation, it only
// inherits values. // inherits values.
this.rootContext = Preconditions.checkNotNull(rootContext).fork(); this.rootContext = Preconditions.checkNotNull(rootContext).fork();
this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry;
} }
/** /**
@ -355,7 +363,8 @@ public final class ServerImpl extends io.grpc.Server {
Metadata headers, Context.CancellableContext context) { Metadata headers, Context.CancellableContext context) {
// TODO(ejona86): should we update fullMethodName to have the canonical path of the method? // TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>( ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
stream, methodDef.getMethodDescriptor(), context); stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry,
compressorRegistry);
ServerCall.Listener<ReqT> listener = methodDef.getServerCallHandler() ServerCall.Listener<ReqT> listener = methodDef.getServerCallHandler()
.startCall(methodDef.getMethodDescriptor(), call, headers); .startCall(methodDef.getMethodDescriptor(), call, headers);
if (listener == null) { if (listener == null) {

View File

@ -83,7 +83,7 @@ final class SingleTransportChannel extends Channel {
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor, return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor,
new SerializingExecutor(executor), callOptions, transportProvider, new SerializingExecutor(executor), callOptions, transportProvider,
deadlineCancellationExecutor).setKnownMessageEncodingRegistry(knownAcceptEncodingRegistry); deadlineCancellationExecutor);
} }
@Override @Override

View File

@ -32,13 +32,10 @@
package io.grpc.internal; package io.grpc.internal;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import java.io.InputStream; import java.io.InputStream;
import javax.annotation.Nullable;
/** /**
* A single stream of communication between two end-points within a transport. * A single stream of communication between two end-points within a transport.
* *
@ -85,43 +82,22 @@ public interface Stream {
boolean isReady(); boolean isReady();
/** /**
* Picks a compressor for for this stream. If no message encodings are acceptable, compression is * Sets the compressor on the framer.
* not used. It is undefined if this this method is invoked multiple times. If the stream has
* a {@code start()} method, pickCompressor must be called prior to start.
* *
* * @param compressor the compressor to use
* @param messageEncodings a group of message encoding names that the remote endpoint is known
* to support.
* @return The compressor chosen for the stream, or null if none selected.
*/ */
@Nullable void setCompressor(Compressor compressor);
Compressor pickCompressor(Iterable<String> messageEncodings);
/**
* Sets the decompressor on the deframer.
*
* @param decompressor the decompressor to use.
*/
void setDecompressor(Decompressor decompressor);
/** /**
* Enables per-message compression, if an encoding type has been negotiated. If no message * Enables per-message compression, if an encoding type has been negotiated. If no message
* encoding has been negotiated, this is a no-op. * encoding has been negotiated, this is a no-op.
*/ */
void setMessageCompression(boolean enable); void setMessageCompression(boolean enable);
/**
* Sets the decompressor registry to use when resolving {@code #setDecompressor(String)}. If
* unset, the default DecompressorRegistry will be used. If the stream has a {@code start()}
* method, setDecompressionRegistry must be called prior to start.
*
* @see DecompressorRegistry#getDefaultInstance()
*
* @param registry the decompressors to use.
*/
void setDecompressionRegistry(DecompressorRegistry registry);
/**
* Sets the compressor registry to use when resolving {@link #pickCompressor}. If unset, the
* default CompressorRegistry will be used. If the stream has a {@code start()} method,
* setCompressionRegistry must be called prior to start.
*
* @see CompressorRegistry#getDefaultInstance()
*
* @param registry the compressors to use.
*/
void setCompressionRegistry(CompressorRegistry registry);
} }

View File

@ -239,23 +239,6 @@ public class AbstractClientStreamTest {
verify(mockListener).headersRead(headers); verify(mockListener).headersRead(headers);
} }
@Test
public void inboundHeadersReceived_notifiesListenerOnBadEncoding() {
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);
stream.start(mockListener);
Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "bad");
Metadata.Key<String> randomKey = Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER);
headers.put(randomKey, "4");
stream.inboundHeadersReceived(headers);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(mockListener).closed(statusCaptor.capture(), metadataCaptor.capture());
assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals("4", metadataCaptor.getValue().get(randomKey));
}
@Test @Test
public void rstStreamClosesStream() { public void rstStreamClosesStream() {
AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator); AbstractClientStream<Integer> stream = new BaseAbstractClientStream<Integer>(allocator);

View File

@ -55,7 +55,6 @@ import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.Decompressor; import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
@ -80,7 +79,6 @@ import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -103,8 +101,6 @@ public class ClientCallImplTest {
private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeClock = new FakeClock();
private final ScheduledExecutorService deadlineCancellationExecutor = private final ScheduledExecutorService deadlineCancellationExecutor =
fakeClock.scheduledExecutorService; fakeClock.scheduledExecutorService;
private final Set<String> knownMessageEncodings = new HashSet<String>();
private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry = private final DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance(); DecompressorRegistry.getDefaultInstance();
private final MethodDescriptor<Void, Void> method = MethodDescriptor.create( private final MethodDescriptor<Void, Void> method = MethodDescriptor.create(
@ -159,9 +155,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT, CallOptions.DEFAULT,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry);
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
@ -178,8 +172,8 @@ public class ClientCallImplTest {
public void prepareHeaders_authorityAdded() { public void prepareHeaders_authorityAdded() {
Metadata m = new Metadata(); Metadata m = new Metadata();
CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth"); CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth");
ClientCallImpl.prepareHeaders(m, callOptions, "user agent", knownMessageEncodings, ClientCallImpl.prepareHeaders(m, callOptions, "user agent", decompressorRegistry,
decompressorRegistry, compressorRegistry); Codec.Identity.NONE);
assertEquals(m.get(GrpcUtil.AUTHORITY_KEY), "auth"); assertEquals(m.get(GrpcUtil.AUTHORITY_KEY), "auth");
} }
@ -187,28 +181,17 @@ public class ClientCallImplTest {
@Test @Test
public void prepareHeaders_userAgentAdded() { public void prepareHeaders_userAgentAdded() {
Metadata m = new Metadata(); Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry,
decompressorRegistry, compressorRegistry); Codec.Identity.NONE);
assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent"); assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent");
} }
@Test
public void prepareHeaders_messageEncodingAdded() {
Metadata m = new Metadata();
knownMessageEncodings.add(new Codec.Gzip().getMessageEncoding());
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings,
decompressorRegistry, compressorRegistry);
assertEquals(m.get(GrpcUtil.MESSAGE_ENCODING_KEY), new Codec.Gzip().getMessageEncoding());
}
@Test @Test
public void prepareHeaders_ignoreIdentityEncoding() { public void prepareHeaders_ignoreIdentityEncoding() {
Metadata m = new Metadata(); Metadata m = new Metadata();
knownMessageEncodings.add(Codec.Identity.NONE.getMessageEncoding()); ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", decompressorRegistry,
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, Codec.Identity.NONE);
decompressorRegistry, compressorRegistry);
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
} }
@ -251,8 +234,8 @@ public class ClientCallImplTest {
} }
}, false); // not advertised }, false); // not advertised
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", knownMessageEncodings, ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry,
customRegistry, compressorRegistry); Codec.Identity.NONE);
Iterable<String> acceptedEncodings = Iterable<String> acceptedEncodings =
ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
@ -269,8 +252,8 @@ public class ClientCallImplTest {
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null, knownMessageEncodings, ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, null,
DecompressorRegistry.newEmptyInstance(), compressorRegistry); DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.AUTHORITY_KEY)); assertNull(m.get(GrpcUtil.AUTHORITY_KEY));
assertNull(m.get(GrpcUtil.USER_AGENT_KEY)); assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
@ -294,9 +277,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT, CallOptions.DEFAULT,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry);
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
Context.ROOT.attach(); Context.ROOT.attach();
@ -372,9 +353,7 @@ public class ClientCallImplTest {
CallOptions.DEFAULT, CallOptions.DEFAULT,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry);
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
previous.attach(); previous.attach();
@ -454,9 +433,7 @@ public class ClientCallImplTest {
callOptions, callOptions,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry);
.setCompressorRegistry(compressorRegistry)
.setKnownMessageEncodingRegistry(knownMessageEncodings);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
assertFalse(future.isDone()); assertFalse(future.isDone());
fakeClock.forwardTime(1, TimeUnit.SECONDS); fakeClock.forwardTime(1, TimeUnit.SECONDS);

View File

@ -36,8 +36,7 @@ import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import io.grpc.CompressorRegistry; import io.grpc.Codec;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
@ -77,10 +76,8 @@ public class DelayedStreamTest {
@Test @Test
public void setStream_sendsAllMessages() { public void setStream_sendsAllMessages() {
stream.start(listener); stream.start(listener);
DecompressorRegistry decompressors = DecompressorRegistry.newEmptyInstance(); stream.setCompressor(Codec.Identity.NONE);
CompressorRegistry compressors = CompressorRegistry.newEmptyInstance(); stream.setDecompressor(Codec.Identity.NONE);
stream.setDecompressionRegistry(decompressors);
stream.setCompressionRegistry(compressors);
stream.setMessageCompression(true); stream.setMessageCompression(true);
InputStream message = new ByteArrayInputStream(new byte[]{'a'}); InputStream message = new ByteArrayInputStream(new byte[]{'a'});
@ -90,8 +87,8 @@ public class DelayedStreamTest {
stream.setStream(realStream); stream.setStream(realStream);
verify(realStream).setDecompressionRegistry(decompressors); verify(realStream).setCompressor(Codec.Identity.NONE);
verify(realStream).setCompressionRegistry(compressors); verify(realStream).setDecompressor(Codec.Identity.NONE);
// Verify that the order was correct, even though they should be interleaved with the // Verify that the order was correct, even though they should be interleaved with the
// writeMessage calls // writeMessage calls

View File

@ -54,6 +54,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
@ -195,8 +196,7 @@ public class ManagedChannelImplTest {
ClientTransport.Listener transportListener = transportListenerCaptor.getValue(); ClientTransport.Listener transportListener = transportListenerCaptor.getValue();
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers));
verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).start(streamListenerCaptor.capture());
verify(mockStream).setDecompressionRegistry(isA(DecompressorRegistry.class)); verify(mockStream).setCompressor(isA(Compressor.class));
verify(mockStream).setCompressionRegistry(isA(CompressorRegistry.class));
ClientStreamListener streamListener = streamListenerCaptor.getValue(); ClientStreamListener streamListener = streamListenerCaptor.getValue();
// Second call // Second call

View File

@ -44,7 +44,9 @@ import static org.mockito.Mockito.when;
import com.google.common.io.CharStreams; import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.Marshaller;
@ -89,7 +91,8 @@ public class ServerCallImplTest {
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
context = Context.ROOT.withCancellation(); context = Context.ROOT.withCancellation();
call = new ServerCallImpl<Long, Long>(stream, method, context); call = new ServerCallImpl<Long, Long>(stream, method, new Metadata(), context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance());
} }
@Test @Test

View File

@ -49,7 +49,10 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
import io.grpc.DecompressorRegistry;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -91,6 +94,9 @@ public class ServerImplTest {
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly"); private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
private static final Context.CancellableContext SERVER_CONTEXT = private static final Context.CancellableContext SERVER_CONTEXT =
Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation(); Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation();
private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance();
static { static {
// Cancel the root context. Server will fork it so the per-call context should not // Cancel the root context. Server will fork it so the per-call context should not
@ -101,7 +107,8 @@ public class ServerImplTest {
private ExecutorService executor = Executors.newSingleThreadExecutor(); private ExecutorService executor = Executors.newSingleThreadExecutor();
private MutableHandlerRegistry registry = new MutableHandlerRegistryImpl(); private MutableHandlerRegistry registry = new MutableHandlerRegistryImpl();
private SimpleServer transportServer = new SimpleServer(); private SimpleServer transportServer = new SimpleServer();
private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); private ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
@Mock @Mock
private ServerStream stream; private ServerStream stream;
@ -129,7 +136,8 @@ public class ServerImplTest {
@Override @Override
public void shutdown() {} public void shutdown() {}
}; };
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start(); server.start();
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
@ -146,7 +154,8 @@ public class ServerImplTest {
throw new AssertionError("Should not be called, because wasn't started"); throw new AssertionError("Should not be called, because wasn't started");
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
@ -154,7 +163,8 @@ public class ServerImplTest {
@Test @Test
public void startStopImmediateWithChildTransport() throws IOException { public void startStopImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start(); server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -186,7 +196,7 @@ public class ServerImplTest {
} }
ServerImpl server = new ServerImpl(executor, registry, new FailingStartupServer(), ServerImpl server = new ServerImpl(executor, registry, new FailingStartupServer(),
SERVER_CONTEXT); SERVER_CONTEXT, decompressorRegistry, compressorRegistry);
try { try {
server.start(); server.start();
fail("expected exception"); fail("expected exception");
@ -240,6 +250,7 @@ public class ServerImplTest {
responseHeaders.put(metadataKey, "response value"); responseHeaders.put(metadataKey, "response value");
call.sendHeaders(responseHeaders); call.sendHeaders(responseHeaders);
verify(stream).writeHeaders(responseHeaders); verify(stream).writeHeaders(responseHeaders);
verify(stream).setCompressor(isA(Compressor.class));
call.sendMessage(314); call.sendMessage(314);
ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class); ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class);
@ -322,7 +333,8 @@ public class ServerImplTest {
} }
transportServer = new MaybeDeadlockingServer(); transportServer = new MaybeDeadlockingServer();
ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT); ServerImpl server = new ServerImpl(executor, registry, transportServer, SERVER_CONTEXT,
decompressorRegistry, compressorRegistry);
server.start(); server.start();
new Thread() { new Thread() {
@Override @Override

View File

@ -230,11 +230,6 @@ public class CompressionTest {
assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); assertNull(clientResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
} }
// Must be null for the first call.
assertNull(serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
assertFalse(clientCodec.anyWritten);
assertFalse(serverCodec.anyRead);
if (clientAcceptEncoding) { if (clientAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY)); assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY));
} else { } else {
@ -242,7 +237,6 @@ public class CompressionTest {
} }
// Second call, once the client knows what the server supports. // Second call, once the client knows what the server supports.
stub.unaryCall(REQUEST);
if (clientEncoding && serverAcceptEncoding) { if (clientEncoding && serverAcceptEncoding) {
assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY)); assertEquals("fzip", serverResponseHeaders.get(MESSAGE_ENCODING_KEY));
if (enableClientMessageCompression) { if (enableClientMessageCompression) {
@ -274,7 +268,11 @@ public class CompressionTest {
@Override @Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
final ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); if (clientEncoding && serverAcceptEncoding) {
callOptions = callOptions.withCompression("fzip");
}
ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
return new ClientCompressor<ReqT, RespT>(call); return new ClientCompressor<ReqT, RespT>(call);
} }
} }

View File

@ -35,8 +35,6 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_BACKLOG;
import static io.netty.channel.ChannelOption.SO_KEEPALIVE; import static io.netty.channel.ChannelOption.SO_KEEPALIVE;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.internal.Server; import io.grpc.internal.Server;
import io.grpc.internal.ServerListener; import io.grpc.internal.ServerListener;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
@ -74,8 +72,6 @@ class NettyServer implements Server {
private EventLoopGroup workerGroup; private EventLoopGroup workerGroup;
private ServerListener listener; private ServerListener listener;
private Channel channel; private Channel channel;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize; private final int maxMessageSize;
private final int maxHeaderListSize; private final int maxHeaderListSize;
@ -83,8 +79,7 @@ class NettyServer implements Server {
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType, NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup, @Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
ProtocolNegotiator protocolNegotiator, DecompressorRegistry decompressorRegistry, ProtocolNegotiator protocolNegotiator, int maxStreamsPerConnection,
CompressorRegistry compressorRegistry, int maxStreamsPerConnection,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize) { int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
this.address = address; this.address = address;
this.channelType = checkNotNull(channelType, "channelType"); this.channelType = checkNotNull(channelType, "channelType");
@ -97,8 +92,6 @@ class NettyServer implements Server {
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize; this.maxHeaderListSize = maxHeaderListSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
} }
@Override @Override
@ -125,10 +118,8 @@ class NettyServer implements Server {
eventLoopReferenceCounter.release(); eventLoopReferenceCounter.release();
} }
}); });
NettyServerTransport transport NettyServerTransport transport = new NettyServerTransport(ch, protocolNegotiator,
= new NettyServerTransport(ch, protocolNegotiator, decompressorRegistry, maxStreamsPerConnection, flowControlWindow, maxMessageSize, maxHeaderListSize);
compressorRegistry, maxStreamsPerConnection, flowControlWindow, maxMessageSize,
maxHeaderListSize);
transport.start(listener.transportCreated(transport)); transport.start(listener.transportCreated(transport));
} }
}); });

View File

@ -31,14 +31,11 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.ExperimentalApi; import io.grpc.ExperimentalApi;
import io.grpc.HandlerRegistry; import io.grpc.HandlerRegistry;
import io.grpc.Internal; import io.grpc.Internal;
@ -244,12 +241,9 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
negotiator = sslContext != null ? ProtocolNegotiators.serverTls(sslContext) : negotiator = sslContext != null ? ProtocolNegotiators.serverTls(sslContext) :
ProtocolNegotiators.serverPlaintext(); ProtocolNegotiators.serverPlaintext();
} }
return new NettyServer(address, channelType, bossEventLoopGroup, return new NettyServer(address, channelType, bossEventLoopGroup, workerEventLoopGroup,
workerEventLoopGroup, negotiator, negotiator, maxConcurrentCallsPerConnection, flowControlWindow, maxMessageSize,
firstNonNull(decompressorRegistry(), DecompressorRegistry.getDefaultInstance()), maxHeaderListSize);
firstNonNull(compressorRegistry(), CompressorRegistry.getDefaultInstance()),
maxConcurrentCallsPerConnection, flowControlWindow,
maxMessageSize, maxHeaderListSize);
} }
@Override @Override

View File

@ -43,8 +43,6 @@ import static io.netty.handler.codec.http2.DefaultHttp2LocalFlowController.DEFAU
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
@ -97,8 +95,6 @@ class NettyServerHandler extends AbstractNettyHandler {
private static final Status GOAWAY_STATUS = Status.UNAVAILABLE; private static final Status GOAWAY_STATUS = Status.UNAVAILABLE;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final Http2Connection.PropertyKey streamKey; private final Http2Connection.PropertyKey streamKey;
private final ServerTransportListener transportListener; private final ServerTransportListener transportListener;
private final int maxMessageSize; private final int maxMessageSize;
@ -107,8 +103,6 @@ class NettyServerHandler extends AbstractNettyHandler {
private WriteQueue serverWriteQueue; private WriteQueue serverWriteQueue;
static NettyServerHandler newHandler(ServerTransportListener transportListener, static NettyServerHandler newHandler(ServerTransportListener transportListener,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxStreams, int maxStreams,
int flowControlWindow, int flowControlWindow,
int maxHeaderListSize, int maxHeaderListSize,
@ -121,15 +115,13 @@ class NettyServerHandler extends AbstractNettyHandler {
new DefaultHttp2FrameReader(headersDecoder), frameLogger); new DefaultHttp2FrameReader(headersDecoder), frameLogger);
Http2FrameWriter frameWriter = Http2FrameWriter frameWriter =
new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger); new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger);
return newHandler(frameReader, frameWriter, transportListener, decompressorRegistry, return newHandler(frameReader, frameWriter, transportListener, maxStreams, flowControlWindow,
compressorRegistry, maxStreams, flowControlWindow, maxMessageSize); maxMessageSize);
} }
@VisibleForTesting @VisibleForTesting
static NettyServerHandler newHandler(Http2FrameReader frameReader, Http2FrameWriter frameWriter, static NettyServerHandler newHandler(Http2FrameReader frameReader, Http2FrameWriter frameWriter,
ServerTransportListener transportListener, ServerTransportListener transportListener,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxStreams, int maxStreams,
int flowControlWindow, int flowControlWindow,
int maxMessageSize) { int maxMessageSize) {
@ -151,21 +143,16 @@ class NettyServerHandler extends AbstractNettyHandler {
settings.initialWindowSize(flowControlWindow); settings.initialWindowSize(flowControlWindow);
settings.maxConcurrentStreams(maxStreams); settings.maxConcurrentStreams(maxStreams);
return new NettyServerHandler(transportListener, decoder, encoder, settings, return new NettyServerHandler(transportListener, decoder, encoder, settings, maxMessageSize);
decompressorRegistry, compressorRegistry, maxMessageSize);
} }
private NettyServerHandler(ServerTransportListener transportListener, private NettyServerHandler(ServerTransportListener transportListener,
Http2ConnectionDecoder decoder, Http2ConnectionDecoder decoder,
Http2ConnectionEncoder encoder, Http2Settings settings, Http2ConnectionEncoder encoder, Http2Settings settings,
DecompressorRegistry decompressorRegistry,
CompressorRegistry compressorRegistry,
int maxMessageSize) { int maxMessageSize) {
super(decoder, encoder, settings); super(decoder, encoder, settings);
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0"); checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
streamKey = encoder.connection().newKey(); streamKey = encoder.connection().newKey();
this.transportListener = checkNotNull(transportListener, "transportListener"); this.transportListener = checkNotNull(transportListener, "transportListener");
@ -206,11 +193,6 @@ class NettyServerHandler extends AbstractNettyHandler {
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this, NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this,
maxMessageSize); maxMessageSize);
// These must be called before inboundHeadersReceived, because the framers depend on knowing
// the compression algorithms available before negotiation.
stream.setDecompressionRegistry(decompressorRegistry);
stream.setCompressionRegistry(compressorRegistry);
Metadata metadata = Utils.convertHeaders(headers); Metadata metadata = Utils.convertHeaders(headers);
stream.inboundHeadersReceived(metadata); stream.inboundHeadersReceived(metadata);

View File

@ -31,12 +31,8 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -55,8 +51,6 @@ class NettyServerTransport implements ServerTransport {
private final Channel channel; private final Channel channel;
private final ProtocolNegotiator protocolNegotiator; private final ProtocolNegotiator protocolNegotiator;
private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry;
private final int maxStreams; private final int maxStreams;
private ServerTransportListener listener; private ServerTransportListener listener;
private boolean terminated; private boolean terminated;
@ -64,18 +58,14 @@ class NettyServerTransport implements ServerTransport {
private final int maxMessageSize; private final int maxMessageSize;
private final int maxHeaderListSize; private final int maxHeaderListSize;
NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, int maxStreams,
DecompressorRegistry decompressorRegistry, int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
CompressorRegistry compressorRegistry, int maxStreams, int flowControlWindow,
int maxMessageSize, int maxHeaderListSize) {
this.channel = Preconditions.checkNotNull(channel, "channel"); this.channel = Preconditions.checkNotNull(channel, "channel");
this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator"); this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator");
this.maxStreams = maxStreams; this.maxStreams = maxStreams;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize; this.maxHeaderListSize = maxHeaderListSize;
this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry");
this.compressorRegistry = checkNotNull(compressorRegistry, "compressorRegistry");
} }
public void start(ServerTransportListener listener) { public void start(ServerTransportListener listener) {
@ -125,7 +115,7 @@ class NettyServerTransport implements ServerTransport {
* Creates the Netty handler to be used in the channel pipeline. * Creates the Netty handler to be used in the channel pipeline.
*/ */
private NettyServerHandler createHandler(ServerTransportListener transportListener) { private NettyServerHandler createHandler(ServerTransportListener transportListener) {
return NettyServerHandler.newHandler(transportListener, decompressorRegistry, return NettyServerHandler.newHandler(transportListener, maxStreams, flowControlWindow,
compressorRegistry, maxStreams, flowControlWindow, maxHeaderListSize, maxMessageSize); maxHeaderListSize, maxMessageSize);
} }
} }

View File

@ -43,8 +43,6 @@ import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.Marshaller;
@ -321,10 +319,8 @@ public class NettyClientTransportTest {
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key) SslContext serverContext = GrpcSslContexts.forServer(serverCert, key)
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(serverContext); ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(serverContext);
server = new NettyServer(address, NioServerSocketChannel.class, server = new NettyServer(address, NioServerSocketChannel.class, group, group, negotiator,
group, group, negotiator, DecompressorRegistry.getDefaultInstance(), maxStreamsPerConnection, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize);
CompressorRegistry.getDefaultInstance(), maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize);
server.start(serverListener); server.start(serverListener);
} }

View File

@ -54,8 +54,6 @@ import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.Status.Code; import io.grpc.Status.Code;
@ -347,7 +345,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
@Override @Override
protected NettyServerHandler newHandler() { protected NettyServerHandler newHandler() {
return NettyServerHandler.newHandler(frameReader(), frameWriter(), transportListener, return NettyServerHandler.newHandler(frameReader(), frameWriter(), transportListener,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
maxConcurrentStreams, flowControlWindow, DEFAULT_MAX_MESSAGE_SIZE); maxConcurrentStreams, flowControlWindow, DEFAULT_MAX_MESSAGE_SIZE);
} }

View File

@ -50,9 +50,7 @@ import static org.mockito.Mockito.when;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
@ -96,7 +94,6 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
stream.writeHeaders(new Metadata()); stream.writeHeaders(new Metadata());
Http2Headers headers = new DefaultHttp2Headers() Http2Headers headers = new DefaultHttp2Headers()
.status(Utils.STATUS_OK) .status(Utils.STATUS_OK)
.set(GrpcUtil.MESSAGE_ACCEPT_ENCODING, AsciiString.of("gzip"))
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC); .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true); verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
byte[] msg = smallMessage(); byte[] msg = smallMessage();

View File

@ -35,6 +35,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors; import io.grpc.ClientInterceptors;
import io.grpc.ExperimentalApi;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -116,6 +117,20 @@ public abstract class AbstractStub<S extends AbstractStub<S>> {
return build(channel, callOptions.withDeadlineAfter(duration, unit)); return build(channel, callOptions.withDeadlineAfter(duration, unit));
} }
/**
* Set's the compressor name to use for the call. It is the responsibility of the application
* to make sure the server supports decoding the compressor picked by the client. To be clear,
* this is the compressor used by the stub to compress messages to the server. To get
* compressed responses from the server, set the appropriate {@link io.grpc.DecompressorRegistry}
* on the {@link io.grpc.ManagedChannelBuilder}.
*
* @param compressorName the name (e.g. "gzip") of the compressor to use.
*/
@ExperimentalApi
public final S withCompression(String compressorName) {
return build(channel, callOptions.withCompression(compressorName));
}
/** /**
* Returns a new stub that uses the given channel. * Returns a new stub that uses the given channel.
*/ */