diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 0fab5c23b0..7a2152133d 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -22,6 +22,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Compressor; import io.grpc.Decompressor; +import io.grpc.DecompressorRegistry; import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -569,7 +570,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans public void setCompressor(Compressor compressor) {} @Override - public void setDecompressor(Decompressor decompressor) {} + public void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) {} @Override public void setMaxInboundMessageSize(int maxSize) {} diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index be2eefccb7..4adec0dac2 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -16,9 +16,14 @@ package io.grpc.internal; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; + import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import io.grpc.Codec; import io.grpc.Compressor; +import io.grpc.Decompressor; +import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.Status; import java.io.InputStream; @@ -111,6 +116,11 @@ public abstract class AbstractClientStream extends AbstractStream transportState().setMaxInboundMessageSize(maxSize); } + @Override + public final void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { + transportState().setDecompressorRegistry(decompressorRegistry); + } + /** {@inheritDoc} */ @Override protected abstract TransportState transportState(); @@ -172,6 +182,7 @@ public abstract class AbstractClientStream extends AbstractStream private final StatsTraceContext statsTraceCtx; private boolean listenerClosed; private ClientStreamListener listener; + private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private Runnable deliveryStalledTask; @@ -186,6 +197,12 @@ public abstract class AbstractClientStream extends AbstractStream this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); } + private void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { + Preconditions.checkState(this.listener == null, "Already called start"); + this.decompressorRegistry = + Preconditions.checkNotNull(decompressorRegistry, "decompressorRegistry"); + } + @VisibleForTesting public final void setListener(ClientStreamListener listener) { Preconditions.checkState(this.listener == null, "Already called setListener"); @@ -218,6 +235,19 @@ public abstract class AbstractClientStream extends AbstractStream protected void inboundHeadersReceived(Metadata headers) { Preconditions.checkState(!statusReported, "Received headers on closed stream"); statsTraceCtx.clientInboundHeaders(); + + Decompressor decompressor = Codec.Identity.NONE; + String encoding = headers.get(MESSAGE_ENCODING_KEY); + if (encoding != null) { + decompressor = decompressorRegistry.lookupDecompressor(encoding); + if (decompressor == null) { + deframeFailed(Status.INTERNAL.withDescription( + String.format("Can't find decompressor for %s", encoding)).asRuntimeException()); + return; + } + } + setDecompressor(decompressor); + listener().headersRead(headers); } diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index 84c8ebd0d3..14486e202c 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -18,6 +18,7 @@ package io.grpc.internal; import com.google.common.base.Preconditions; import io.grpc.Attributes; +import io.grpc.Decompressor; import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.Status; @@ -150,6 +151,11 @@ public abstract class AbstractServerStream extends AbstractStream return super.isReady(); } + @Override + public final void setDecompressor(Decompressor decompressor) { + transportState().setDecompressor(Preconditions.checkNotNull(decompressor, "decompressor")); + } + @Override public Attributes getAttributes() { return Attributes.EMPTY; } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index c8f835d35e..cbddf523c6 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -73,11 +73,6 @@ public abstract class AbstractStream implements Stream { framer().setCompressor(checkNotNull(compressor, "compressor")); } - @Override - public final void setDecompressor(Decompressor decompressor) { - transportState().setDecompressor(checkNotNull(decompressor, "decompressor")); - } - @Override public boolean isReady() { if (framer().isClosed()) { @@ -207,7 +202,7 @@ public abstract class AbstractStream implements Stream { return statsTraceCtx; } - private void setDecompressor(Decompressor decompressor) { + protected final void setDecompressor(Decompressor decompressor) { if (deframer.isClosed()) { return; } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index fdf0a5580d..f423e2c0de 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -37,7 +37,6 @@ import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.Context; import io.grpc.Deadline; -import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.InternalDecompressorRegistry; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -222,6 +221,7 @@ final class ClientCallImpl extends ClientCall stream.setMaxOutboundMessageSize(callOptions.getMaxOutboundMessageSize()); } stream.setCompressor(compressor); + stream.setDecompressorRegistry(decompressorRegistry); stream.start(new ClientStreamListenerImpl(observer)); // Delay any sources of cancellation after start(), because most of the transports are broken if @@ -429,18 +429,6 @@ final class ClientCallImpl extends ClientCall @Override public void headersRead(final Metadata headers) { - Decompressor decompressor = Codec.Identity.NONE; - if (headers.containsKey(MESSAGE_ENCODING_KEY)) { - String encoding = headers.get(MESSAGE_ENCODING_KEY); - decompressor = decompressorRegistry.lookupDecompressor(encoding); - if (decompressor == null) { - stream.cancel(Status.INTERNAL.withDescription( - String.format("Can't find decompressor for %s", encoding))); - return; - } - } - stream.setDecompressor(decompressor); - class HeadersRead extends ContextRunnable { HeadersRead() { super(context); diff --git a/core/src/main/java/io/grpc/internal/ClientStream.java b/core/src/main/java/io/grpc/internal/ClientStream.java index 32ec612da3..1ff492cf66 100644 --- a/core/src/main/java/io/grpc/internal/ClientStream.java +++ b/core/src/main/java/io/grpc/internal/ClientStream.java @@ -17,6 +17,7 @@ package io.grpc.internal; import io.grpc.Attributes; +import io.grpc.DecompressorRegistry; import io.grpc.Status; /** @@ -51,6 +52,14 @@ public interface ClientStream extends Stream { */ void setAuthority(String authority); + /** + * Sets the registry to find a decompressor for the framer. May only be called before {@link + * #start}. If the transport does not support compression, this may do nothing. + * + * @param decompressorRegistry the registry of decompressors for decoding responses + */ + void setDecompressorRegistry(DecompressorRegistry decompressorRegistry); + /** * Starts stream. This method may only be called once. It is safe to do latent initialization of * the stream up until {@link #start} is called. diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index e52d400cda..b1073f1291 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -22,7 +22,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; import io.grpc.Compressor; -import io.grpc.Decompressor; +import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.Status; import java.io.InputStream; @@ -303,16 +303,14 @@ class DelayedStream implements ClientStream { } @Override - public void setDecompressor(Decompressor decompressor) { - checkNotNull(decompressor, "decompressor"); - // This method being called only makes sense after setStream() has been called (but not - // necessarily returned), but there is not necessarily a happens-before relationship. This - // synchronized block creates one. - synchronized (this) { } - checkState(realStream != null, "How did we receive a reply before the request is sent?"); - // ClientStreamListenerImpl (in ClientCallImpl) requires setDecompressor to be set immediately, - // since messages may be processed immediately after this method returns. - realStream.setDecompressor(decompressor); + public void setDecompressorRegistry(final DecompressorRegistry decompressorRegistry) { + checkNotNull(decompressorRegistry, "decompressorRegistry"); + delayOrExecute(new Runnable() { + @Override + public void run() { + realStream.setDecompressorRegistry(decompressorRegistry); + } + }); } @Override diff --git a/core/src/main/java/io/grpc/internal/NoopClientStream.java b/core/src/main/java/io/grpc/internal/NoopClientStream.java index 8e85abe893..c051b50a01 100644 --- a/core/src/main/java/io/grpc/internal/NoopClientStream.java +++ b/core/src/main/java/io/grpc/internal/NoopClientStream.java @@ -18,7 +18,7 @@ package io.grpc.internal; import io.grpc.Attributes; import io.grpc.Compressor; -import io.grpc.Decompressor; +import io.grpc.DecompressorRegistry; import io.grpc.Status; import java.io.InputStream; @@ -68,7 +68,7 @@ public class NoopClientStream implements ClientStream { public void setCompressor(Compressor compressor) {} @Override - public void setDecompressor(Decompressor decompressor) {} + public void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) {} @Override public void setMaxInboundMessageSize(int maxSize) {} diff --git a/core/src/main/java/io/grpc/internal/ServerStream.java b/core/src/main/java/io/grpc/internal/ServerStream.java index 5a79497607..7990b91b5a 100644 --- a/core/src/main/java/io/grpc/internal/ServerStream.java +++ b/core/src/main/java/io/grpc/internal/ServerStream.java @@ -17,6 +17,7 @@ package io.grpc.internal; import io.grpc.Attributes; +import io.grpc.Decompressor; import io.grpc.Metadata; import io.grpc.Status; import javax.annotation.Nullable; @@ -58,6 +59,14 @@ public interface ServerStream extends Stream { */ void cancel(Status status); + /** + * Sets the decompressor on the deframer. If the transport does not support compression, this may + * do nothing. + * + * @param decompressor the decompressor to use. + */ + void setDecompressor(Decompressor decompressor); + /** * Attributes describing stream. This is inherited from the transport attributes, and used * as the basis of {@link io.grpc.ServerCall#getAttributes}. diff --git a/core/src/main/java/io/grpc/internal/Stream.java b/core/src/main/java/io/grpc/internal/Stream.java index 5d85798b67..79667fabfc 100644 --- a/core/src/main/java/io/grpc/internal/Stream.java +++ b/core/src/main/java/io/grpc/internal/Stream.java @@ -17,7 +17,6 @@ package io.grpc.internal; import io.grpc.Compressor; -import io.grpc.Decompressor; import java.io.InputStream; /** @@ -72,13 +71,6 @@ public interface Stream { */ void setCompressor(Compressor compressor); - /** - * Sets the decompressor on the deframer. - * - * @param decompressor the decompressor to use. - */ - void setDecompressor(Decompressor decompressor); - /** * Enables per-message compression, if an encoding type has been negotiated. If no message * encoding has been negotiated, this is a no-op. By default per-message compression is enabled, diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 1444dfdca9..0c7473363f 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -34,6 +34,7 @@ import static org.mockito.Mockito.when; import io.grpc.Attributes; import io.grpc.Attributes.Key; import io.grpc.Codec; +import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.Status; import java.io.ByteArrayInputStream; @@ -87,16 +88,11 @@ public class DelayedStreamTest { stream.start(mock(ClientStreamListener.class)); } - @Test(expected = IllegalStateException.class) - public void setDecompressor_beforeSetStream() { - stream.start(listener); - stream.setDecompressor(Codec.Identity.NONE); - } - @Test public void setStream_sendsAllMessages() { stream.start(listener); stream.setCompressor(Codec.Identity.NONE); + stream.setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); stream.setMessageCompression(true); InputStream message = new ByteArrayInputStream(new byte[]{'a'}); @@ -105,10 +101,9 @@ public class DelayedStreamTest { stream.writeMessage(message); stream.setStream(realStream); - stream.setDecompressor(Codec.Identity.NONE); verify(realStream).setCompressor(Codec.Identity.NONE); - verify(realStream).setDecompressor(Codec.Identity.NONE); + verify(realStream).setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); verify(realStream).setMessageCompression(true); verify(realStream).setMessageCompression(false);