core: set server stream decompressor in transport thread

This commit is contained in:
Eric Gribkoff 2017-06-12 12:00:21 -07:00 committed by GitHub
parent 6038b0987e
commit 49b9216e83
3 changed files with 37 additions and 12 deletions

View File

@ -30,7 +30,6 @@ import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.InternalDecompressorRegistry;
import io.grpc.Metadata;
@ -65,17 +64,6 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
this.messageAcceptEncoding = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry;
if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY);
Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding);
if (decompressor == null) {
throw Status.UNIMPLEMENTED
.withDescription(String.format("Can't find decompressor for %s", encoding))
.asRuntimeException();
}
stream.setDecompressor(decompressor);
}
}
@Override

View File

@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.grpc.Contexts.statusFromCancelled;
import static io.grpc.Status.DEADLINE_EXCEEDED;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
@ -28,6 +29,7 @@ import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.Metadata;
@ -362,6 +364,18 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
@Override
public void streamCreated(
final ServerStream stream, final String methodName, final Metadata headers) {
if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = headers.get(MESSAGE_ENCODING_KEY);
Decompressor decompressor = decompressorRegistry.lookupDecompressor(encoding);
if (decompressor == null) {
stream.close(
Status.UNIMPLEMENTED.withDescription(
String.format("Can't find decompressor for %s", encoding)),
new Metadata());
return;
}
stream.setDecompressor(decompressor);
}
final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull(
stream.statsTraceContext(), "statsTraceCtx not present from stream");

View File

@ -16,6 +16,7 @@
package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@ -369,6 +370,28 @@ public class ServerImplTest {
assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode());
}
@Test
public void decompressorNotFound() throws Exception {
String decompressorName = "NON_EXISTENT_DECOMPRESSOR";
createAndStartServer(NO_FILTERS);
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName);
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
streamTracerFactories, "Waiter/nonexist", requestHeaders);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
transportListener.streamCreated(stream, "Waiter/nonexist", requestHeaders);
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
Status status = statusCaptor.getValue();
assertEquals(Status.Code.UNIMPLEMENTED, status.getCode());
assertEquals("Can't find decompressor for " + decompressorName, status.getDescription());
verifyNoMoreInteractions(stream);
}
@Test
public void basicExchangeSuccessful() throws Exception {
createAndStartServer(NO_FILTERS);