diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 9634954607..77e7479504 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -30,7 +30,6 @@ import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.CompressorRegistry; import io.grpc.Context; -import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.InternalDecompressorRegistry; import io.grpc.Metadata; @@ -65,17 +64,6 @@ final class ServerCallImpl extends ServerCall { this.messageAcceptEncoding = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY); this.decompressorRegistry = decompressorRegistry; this.compressorRegistry = compressorRegistry; - - if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) { - String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY); - Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding); - if (decompressor == null) { - throw Status.UNIMPLEMENTED - .withDescription(String.format("Can't find decompressor for %s", encoding)) - .asRuntimeException(); - } - stream.setDecompressor(decompressor); - } } @Override diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 99bf32025d..b5411fa730 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.grpc.Contexts.statusFromCancelled; import static io.grpc.Status.DEADLINE_EXCEEDED; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -28,6 +29,7 @@ import com.google.common.base.Preconditions; import io.grpc.Attributes; import io.grpc.CompressorRegistry; import io.grpc.Context; +import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.HandlerRegistry; import io.grpc.Metadata; @@ -362,6 +364,18 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { @Override public void streamCreated( final ServerStream stream, final String methodName, final Metadata headers) { + if (headers.containsKey(MESSAGE_ENCODING_KEY)) { + String encoding = headers.get(MESSAGE_ENCODING_KEY); + Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding); + if (decompressor == null) { + stream.close( + Status.UNIMPLEMENTED.withDescription( + String.format("Can't find decompressor for %s", encoding)), + new Metadata()); + return; + } + stream.setDecompressor(decompressor); + } final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull( stream.statsTraceContext(), "statsTraceCtx not present from stream"); diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 6bce3a8484..09f90a4f71 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -369,6 +370,28 @@ public class ServerImplTest { assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode()); } + @Test + public void decompressorNotFound() throws Exception { + String decompressorName = "NON_EXISTENT_DECOMPRESSOR"; + createAndStartServer(NO_FILTERS); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + Metadata requestHeaders = new Metadata(); + requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, "Waiter/nonexist", requestHeaders); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + + transportListener.streamCreated(stream, "Waiter/nonexist", requestHeaders); + + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.UNIMPLEMENTED, status.getCode()); + assertEquals("Can't find decompressor for " + decompressorName, status.getDescription()); + verifyNoMoreInteractions(stream); + } + @Test public void basicExchangeSuccessful() throws Exception { createAndStartServer(NO_FILTERS);