From de0df9740a4c6fb2a54ae30701c5744501d53d1a Mon Sep 17 00:00:00 2001 From: Xiao Hang Date: Tue, 25 Apr 2017 15:55:22 -0700 Subject: [PATCH] okhttp: Support GET in okhttp transport --- .../io/grpc/okhttp/OkHttpClientStream.java | 6 ++- .../grpc/okhttp/OkHttpClientStreamTest.java | 37 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 47323feffc..282e0729f6 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -34,6 +34,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.io.BaseEncoding; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -83,7 +84,7 @@ class OkHttpClientStream extends AbstractClientStream2 { String authority, String userAgent, StatsTraceContext statsTraceCtx) { - super(new OkHttpWritableBufferAllocator(), statsTraceCtx, headers, false); + super(new OkHttpWritableBufferAllocator(), statsTraceCtx, headers, method.isSafe()); this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); this.method = method; this.authority = authority; @@ -127,6 +128,9 @@ class OkHttpClientStream extends AbstractClientStream2 { @Override public void writeHeaders(Metadata metadata, byte[] payload) { String defaultPath = "/" + method.getFullMethodName(); + if (payload != null) { + defaultPath += "?" + BaseEncoding.base64().encode(payload); + } metadata.discardAll(GrpcUtil.USER_AGENT_KEY); synchronized (state.lock) { state.streamReady(metadata, defaultPath); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index a95776aa5b..4398b00125 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -35,9 +35,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.io.BaseEncoding; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; @@ -47,7 +50,9 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.StatsTraceContext; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.Header; +import java.io.ByteArrayInputStream; import java.io.InputStream; +import java.nio.charset.Charset; import java.util.List; import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; @@ -176,6 +181,38 @@ public class OkHttpClientStreamTest { .inOrder(); } + @Test + public void getUnaryRequest() { + MethodDescriptor getMethod = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("/service/method") + .setIdempotent(true) + .setSafe(true) + .setRequestMarshaller(marshaller) + .setResponseMarshaller(marshaller) + .build(); + stream = new OkHttpClientStream(getMethod, new Metadata(), frameWriter, transport, + flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", + StatsTraceContext.NOOP); + stream.start(new BaseClientStreamListener()); + + // GET streams send headers after halfClose is called. + verify(frameWriter, times(0)).synStream( + eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); + + byte[] msg = "request".getBytes(Charset.forName("UTF-8")); + stream.writeMessage(new ByteArrayInputStream(msg)); + stream.halfClose(); + verify(transport).streamReadyToStart(eq(stream)); + stream.transportState().start(3); + + verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); + assertThat(headersCaptor.getValue()).contains( + new Header(Header.TARGET_PATH, "/" + getMethod.getFullMethodName() + "?" + + BaseEncoding.base64().encode(msg))); + } + // TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions // of it. private static class BaseClientStreamListener implements ClientStreamListener {