core,netty,okhttp: move user agent removal closer to where it is set

This commit is contained in:
Carl Mastrangelo 2016-05-26 12:27:50 -07:00
parent 028d0844dd
commit 02eb24b3bd
6 changed files with 75 additions and 9 deletions

View File

@ -41,7 +41,6 @@ import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_JOINER;
import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
import static java.lang.Math.max; import static java.lang.Math.max;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
@ -141,9 +140,6 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@VisibleForTesting @VisibleForTesting
static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry, static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry,
Compressor compressor) { Compressor compressor) {
// Remove user agent. Agent are added in the transport.
headers.removeAll(USER_AGENT_KEY);
headers.removeAll(MESSAGE_ENCODING_KEY); headers.removeAll(MESSAGE_ENCODING_KEY);
if (compressor != Codec.Identity.NONE) { if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());

View File

@ -196,12 +196,13 @@ public class ClientCallImplTest {
} }
@Test @Test
public void prepareHeaders_userAgentRemove() { public void prepareHeaders_userAgentIgnored() {
Metadata m = new Metadata(); Metadata m = new Metadata();
m.put(GrpcUtil.USER_AGENT_KEY, "batmobile"); m.put(GrpcUtil.USER_AGENT_KEY, "batmobile");
ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE);
assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNull(); // User Agent is removed and set by the transport
assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull();
} }
@Test @Test
@ -262,13 +263,11 @@ public class ClientCallImplTest {
@Test @Test
public void prepareHeaders_removeReservedHeaders() { public void prepareHeaders_removeReservedHeaders() {
Metadata m = new Metadata(); Metadata m = new Metadata();
m.put(GrpcUtil.USER_AGENT_KEY, "user agent");
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, DecompressorRegistry.newEmptyInstance(), Codec.Identity.NONE);
assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
} }

View File

@ -39,6 +39,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream; import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
@ -94,6 +95,7 @@ abstract class NettyClientStream extends Http2ClientStream implements StreamIdHo
// Convert the headers into Netty HTTP/2 headers. // Convert the headers into Netty HTTP/2 headers.
AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName()); AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName());
headers.removeAll(GrpcUtil.USER_AGENT_KEY);
Http2Headers http2Headers Http2Headers http2Headers
= Utils.convertClientHeaders(headers, scheme, defaultPath, authority, userAgent); = Utils.convertClientHeaders(headers, scheme, defaultPath, authority, userAgent);
headers = null; headers = null;

View File

@ -31,6 +31,7 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
@ -52,10 +53,13 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableListMultimap;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -70,6 +74,7 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -369,6 +374,25 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
assertTrue(stream.isReady()); assertTrue(stream.isReady());
} }
@Test
public void removeUserAgentFromApplicationHeaders() {
Metadata metadata = new Metadata();
metadata.put(GrpcUtil.USER_AGENT_KEY, "bad agent");
listener = mock(ClientStreamListener.class);
Mockito.reset(writeQueue);
when(writeQueue.enqueue(any(), any(boolean.class))).thenReturn(future);
stream = new NettyClientStreamImpl(methodDescriptor, new Metadata(), channel, handler,
DEFAULT_MAX_MESSAGE_SIZE, AsciiString.of("localhost"), AsciiString.of("http"),
AsciiString.of("good agent"));
stream.start(listener);
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
verify(writeQueue).enqueue(cmdCap.capture(), eq(false));
assertThat(ImmutableListMultimap.copyOf(cmdCap.getValue().headers()))
.containsEntry(Utils.USER_AGENT, AsciiString.of("good agent"));
}
@Override @Override
protected NettyClientStream createStream() { protected NettyClientStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue); when(handler.getWriteQueue()).thenReturn(writeQueue);

View File

@ -38,6 +38,7 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream; import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
@ -138,6 +139,7 @@ class OkHttpClientStream extends Http2ClientStream {
public void start(ClientStreamListener listener) { public void start(ClientStreamListener listener) {
super.start(listener); super.start(listener);
String defaultPath = "/" + method.getFullMethodName(); String defaultPath = "/" + method.getFullMethodName();
headers.removeAll(GrpcUtil.USER_AGENT_KEY);
List<Header> requestHeaders = List<Header> requestHeaders =
Headers.createRequestHeaders(headers, defaultPath, authority, userAgent); Headers.createRequestHeaders(headers, defaultPath, authority, userAgent);
headers = null; headers = null;

View File

@ -31,8 +31,10 @@
package io.grpc.okhttp; package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -41,12 +43,16 @@ import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.Header;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
@ -54,6 +60,7 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.io.InputStream; import java.io.InputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
@ -64,6 +71,8 @@ public class OkHttpClientStreamTest {
@Mock private AsyncFrameWriter frameWriter; @Mock private AsyncFrameWriter frameWriter;
@Mock private OkHttpClientTransport transport; @Mock private OkHttpClientTransport transport;
@Mock private OutboundFlowController flowController; @Mock private OutboundFlowController flowController;
@Captor private ArgumentCaptor<List<Header>> headersCaptor;
private final Object lock = new Object(); private final Object lock = new Object();
private MethodDescriptor<?, ?> methodDescriptor; private MethodDescriptor<?, ?> methodDescriptor;
@ -126,8 +135,42 @@ public class OkHttpClientStreamTest {
verifyNoMoreInteractions(frameWriter); verifyNoMoreInteractions(frameWriter);
} }
@Test
public void start_userAgentRemoved() {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application");
stream.start(new BaseClientStreamListener());
stream.start(3);
// TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other defintions verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue())
.contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"));
}
@Test
public void start_headerFieldOrder() {
Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application");
stream.start(new BaseClientStreamListener());
stream.start(3);
verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
assertThat(headersCaptor.getValue()).containsExactly(
Headers.SCHEME_HEADER,
Headers.METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "localhost"),
new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()),
new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"),
Headers.CONTENT_TYPE_HEADER,
Headers.TE_HEADER)
.inOrder();
}
// TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions
// of it. // of it.
private static class BaseClientStreamListener implements ClientStreamListener { private static class BaseClientStreamListener implements ClientStreamListener {
@Override @Override