core,netty,okhttp: move user agent removal closer to where it is set

This commit is contained in:
Carl Mastrangelo 2016-05-26 12:27:50 -07:00
parent 028d0844dd
commit 02eb24b3bd
6 changed files with 75 additions and 9 deletions

View File

@ -41,7 +41,6 @@ import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
import static java.lang.Math.max;
import com.google.common.annotations.VisibleForTesting;
@ -141,9 +140,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@VisibleForTesting
static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry,
Compressor compressor) {
// Remove user agent. Agent are added in the transport.
headers.removeAll(USER_AGENT_KEY);
headers.removeAll(MESSAGE_ENCODING_KEY);
if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());

View File

@ -196,12 +196,13 @@ public class ClientCallImplTest {
}
@Test
public void prepareHeaders_userAgentRemove() {
public void prepareHeaders_userAgentIgnored() {
Metadata m = new Metadata();
m.put(GrpcUtil.USER_AGENT_KEY, "batmobile");
ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE);
assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNull();
// User Agent is removed and set by the transport
assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull();
}
@Test
@ -262,13 +263,11 @@ public class ClientCallImplTest {
@Test
public void prepareHeaders_removeReservedHeaders() {
Metadata m = new Metadata();
m.put(GrpcUtil.USER_AGENT_KEY, "user agent");
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
}

View File

@ -39,6 +39,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
@ -94,6 +95,7 @@ abstract class NettyClientStream extends Http2ClientStream implements StreamIdHo
// Convert the headers into Netty HTTP/2 headers.
AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName());
headers.removeAll(GrpcUtil.USER_AGENT_KEY);
Http2Headers http2Headers
= Utils.convertClientHeaders(headers, scheme, defaultPath, authority, userAgent);
headers = null;

View File

@ -31,6 +31,7 @@
package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
@ -52,10 +53,13 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableListMultimap;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
@ -70,6 +74,7 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@ -369,6 +374,25 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
assertTrue(stream.isReady());
}
@Test
public void removeUserAgentFromApplicationHeaders() {
Metadata metadata = new Metadata();
metadata.put(GrpcUtil.USER_AGENT_KEY, "bad agent");
listener = mock(ClientStreamListener.class);
Mockito.reset(writeQueue);
when(writeQueue.enqueue(any(), any(boolean.class))).thenReturn(future);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler,
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"),
AsciiString.of("good agent"));
stream.start(listener);
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
verify(writeQueue).enqueue(cmdCap.capture(), eq(false));
assertThat(ImmutableListMultimap.copyOf(cmdCap.getValue().headers()))
.containsEntry(Utils.USER_AGENT, AsciiString.of("good agent"));
}
@Override
protected NettyClientStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);

View File

@ -38,6 +38,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.WritableBuffer;
import io.grpc.okhttp.internal.framed.ErrorCode;
@ -138,6 +139,7 @@ class OkHttpClientStream extends Http2ClientStream {
public void start(ClientStreamListener listener) {
super.start(listener);
String defaultPath = "/" + method.getFullMethodName();
headers.removeAll(GrpcUtil.USER_AGENT_KEY);
List<Header> requestHeaders =
Headers.createRequestHeaders(headers, defaultPath, authority, userAgent);
headers = null;

View File

@ -31,8 +31,10 @@
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -41,12 +43,16 @@ import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.Header;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
@ -54,6 +60,7 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
@RunWith(JUnit4.class)
@ -64,6 +71,8 @@ public class OkHttpClientStreamTest {
@Mock private AsyncFrameWriter frameWriter;
@Mock private OkHttpClientTransport transport;
@Mock private OutboundFlowController flowController;
@Captor private ArgumentCaptor<List<Header>> headersCaptor;
private final Object lock = new Object();
private MethodDescriptor<?, ?> methodDescriptor;
@ -126,8 +135,42 @@ public class OkHttpClientStreamTest {
verifyNoMoreInteractions(frameWriter);
}
@Test
public void start_userAgentRemoved() {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application");
stream.start(new BaseClientStreamListener());
stream.start(3);
// TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other defintions
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue())
.contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"));
}
@Test
public void start_headerFieldOrder() {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application");
stream.start(new BaseClientStreamListener());
stream.start(3);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue()).containsExactly(
Headers.SCHEME_HEADER,
Headers.METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "localhost"),
new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()),
new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"),
Headers.CONTENT_TYPE_HEADER,
Headers.TE_HEADER)
.inOrder();
}
// TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions
// of it.
private static class BaseClientStreamListener implements ClientStreamListener {
@Override