diff --git a/build.gradle b/build.gradle index ee9377ea4f..c15fe3d861 100644 --- a/build.gradle +++ b/build.gradle @@ -163,6 +163,7 @@ subprojects { netty: 'io.netty:netty-codec-http2:[4.1.7.Final]', netty_epoll: 'io.netty:netty-transport-native-epoll:4.1.7.Final' + epoll_suffix, + netty_proxy_handler: 'io.netty:netty-handler-proxy:4.1.7.Final', netty_tcnative: 'io.netty:netty-tcnative-boringssl-static:1.1.33.Fork25', // Test dependencies. diff --git a/netty/build.gradle b/netty/build.gradle index 4677048e29..6c03251bd0 100644 --- a/netty/build.gradle +++ b/netty/build.gradle @@ -1,7 +1,8 @@ description = "gRPC: Netty" dependencies { compile project(':grpc-core'), - libraries.netty + libraries.netty, + libraries.netty_proxy_handler // Tests depend on base class defined by core module. testCompile project(':grpc-core').sourceSets.test.output, diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 3f35ed121a..aec7b0f7e8 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -291,6 +291,25 @@ public final class NettyChannelBuilder String authority, NegotiationType negotiationType, SslContext sslContext) { + ProtocolNegotiator negotiator = + createProtocolNegotiatorByType(authority, negotiationType, sslContext); + String proxy = System.getenv("GRPC_PROXY_EXP"); + if (proxy != null) { + String[] parts = proxy.split(":", 2); + int port = 80; + if (parts.length > 1) { + port = Integer.parseInt(parts[1]); + } + InetSocketAddress proxyAddress = new InetSocketAddress(parts[0], port); + negotiator = ProtocolNegotiators.httpProxy(proxyAddress, null, null, negotiator); + } + return negotiator; + } + + private static ProtocolNegotiator createProtocolNegotiatorByType( + String authority, + NegotiationType negotiationType, + SslContext sslContext) { switch (negotiationType) { case PLAINTEXT: return ProtocolNegotiators.plaintext(); diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 0c77024cf2..6311831bb5 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -45,6 +45,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; @@ -54,6 +55,9 @@ import io.netty.handler.codec.http.HttpClientUpgradeHandler; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; +import io.netty.handler.proxy.HttpProxyHandler; +import io.netty.handler.proxy.ProxyConnectionEvent; +import io.netty.handler.proxy.ProxyHandler; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslEngine; import io.netty.handler.ssl.SslContext; @@ -61,6 +65,7 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.util.AsciiString; import io.netty.util.ReferenceCountUtil; +import java.net.SocketAddress; import java.net.URI; import java.util.ArrayDeque; import java.util.Arrays; @@ -189,6 +194,73 @@ public final class ProtocolNegotiators { } } + /** + * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. + */ + public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, + final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final ProtocolNegotiator negotiator) { + Preconditions.checkNotNull(proxyAddress, "proxyAddress"); + Preconditions.checkNotNull(negotiator, "negotiator"); + class ProxyNegotiator implements ProtocolNegotiator { + @Override + public Handler newHandler(GrpcHttp2ConnectionHandler http2Handler) { + HttpProxyHandler proxyHandler; + if (proxyUsername == null || proxyPassword == null) { + proxyHandler = new HttpProxyHandler(proxyAddress); + } else { + proxyHandler = new HttpProxyHandler(proxyAddress, proxyUsername, proxyPassword); + } + return new BufferUntilProxyTunnelledHandler( + proxyHandler, negotiator.newHandler(http2Handler)); + } + } + + return new ProxyNegotiator(); + } + + /** + * Buffers all writes until the HTTP CONNECT tunnel is established. + */ + static final class BufferUntilProxyTunnelledHandler extends AbstractBufferingHandler + implements ProtocolNegotiator.Handler { + private final ProtocolNegotiator.Handler originalHandler; + + public BufferUntilProxyTunnelledHandler( + ProxyHandler proxyHandler, ProtocolNegotiator.Handler handler) { + super(proxyHandler, handler); + this.originalHandler = handler; + } + + + @Override + public AsciiString scheme() { + return originalHandler.scheme(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof ProxyConnectionEvent) { + writeBufferedAndRemove(ctx); + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + fail(ctx, unavailableException("Connection broken while trying to CONNECT through proxy")); + super.channelInactive(ctx); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception { + if (ctx.channel().isActive()) { // This may be a notification that the socket was closed + fail(ctx, unavailableException("Channel closed while trying to CONNECT through proxy")); + } + super.close(ctx, future); + } + } + /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} @@ -366,10 +438,22 @@ public final class ProtocolNegotiators { * lifetime and we only want to configure it once. */ if (handlers != null) { - ctx.pipeline().addFirst(handlers); + for (ChannelHandler handler : handlers) { + ctx.pipeline().addBefore(ctx.name(), null, handler); + } + ChannelHandler handler0 = handlers[0]; + ChannelHandlerContext handler0Ctx = ctx.pipeline().context(handlers[0]); handlers = null; + if (handler0Ctx != null) { // The handler may have removed itself immediately + if (handler0 instanceof ChannelInboundHandler) { + ((ChannelInboundHandler) handler0).channelRegistered(handler0Ctx); + } else { + handler0Ctx.fireChannelRegistered(); + } + } + } else { + super.channelRegistered(ctx); } - super.channelRegistered(ctx); } @Override @@ -424,7 +508,10 @@ public final class ProtocolNegotiators { @Override public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception { - fail(ctx, unavailableException("Channel closed while performing protocol negotiation")); + if (ctx.channel().isActive()) { // This may be a notification that the socket was closed + fail(ctx, unavailableException("Channel closed while performing protocol negotiation")); + } + super.close(ctx, future); } protected final void fail(ChannelHandlerContext ctx, Throwable cause) { diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index b533ce5a9b..6d41f38108 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -31,25 +31,41 @@ package io.grpc.netty; +import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler; import io.grpc.netty.ProtocolNegotiators.TlsNegotiator; import io.grpc.testing.TestUtils; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.ssl.SupportedCipherSuiteFilter; import java.io.File; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.logging.Filter; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -63,10 +79,17 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public class ProtocolNegotiatorsTest { - @Rule public final ExpectedException thrown = ExpectedException.none(); + private static final Runnable NOOP_RUNNABLE = new Runnable() { + @Override public void run() {} + }; + + @Rule + public final ExpectedException thrown = ExpectedException.none(); private GrpcHttp2ConnectionHandler grpcHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -81,7 +104,7 @@ public class ProtocolNegotiatorsTest { File serverCert = TestUtils.loadCert("server1.pem"); File key = TestUtils.loadCert("server1.key"); sslContext = GrpcSslContexts.forServer(serverCert, key) - .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); + .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); engine = SSLContext.getDefault().createSSLEngine(); } @@ -272,4 +295,92 @@ public class ProtocolNegotiatorsTest { assertEquals("bad_host:1234", negotiator.getHost()); assertEquals(-1, negotiator.getPort()); } + + @Test + public void httpProxy_nullAddressNpe() throws Exception { + thrown.expect(NullPointerException.class); + ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext()); + } + + @Test + public void httpProxy_nullNegotiatorNpe() throws Exception { + thrown.expect(NullPointerException.class); + ProtocolNegotiators.httpProxy( + InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null); + } + + @Test + public void httpProxy_nullUserPassNoException() throws Exception { + assertNotNull(ProtocolNegotiators.httpProxy( + InetSocketAddress.createUnresolved("localhost", 80), null, null, + ProtocolNegotiators.plaintext())); + } + + @Test(timeout = 5000) + public void httpProxy_completes() throws Exception { + DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); + // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called + // the channel is already active. + LocalAddress proxy = new LocalAddress("httpProxy_completes"); + SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314); + + ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); + Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) + .childHandler(mockHandler) + .bind(proxy).sync().channel(); + + ProtocolNegotiator nego = + ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ChannelHandler handler = nego.newHandler(grpcHandler); + Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) + .register().sync().channel(); + pipeline = channel.pipeline(); + // Wait for initialization to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + // The grpcHandler must be in the pipeline, but we don't actually want it during our test + // because it will consume all events since it is a mock. We only use it because it is required + // to construct the Handler. + pipeline.remove(grpcHandler); + channel.connect(host).sync(); + serverChannel.close(); + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(ChannelHandlerContext.class); + Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); + ChannelHandlerContext serverContext = contextCaptor.getValue(); + + final String golden = "isThisThingOn?"; + ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); + + // Wait for sending initial request to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + ArgumentCaptor objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler) + .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture()); + ByteBuf b = (ByteBuf) objectCaptor.getValue(); + String request = b.toString(UTF_8); + b.release(); + assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n")); + assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 ")); + assertTrue("No host header: " + request, request.contains("host: specialHost:314")); + + assertFalse(negotiationFuture.isDone()); + serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync(); + negotiationFuture.sync(); + + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + objectCaptor.getAllValues().clear(); + Mockito.verify(mockHandler, times(2)) + .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture()); + b = (ByteBuf) objectCaptor.getAllValues().get(1); + // If we were using the real grpcHandler, this would have been the HTTP/2 preface + String preface = b.toString(UTF_8); + b.release(); + assertEquals(golden, preface); + + channel.close(); + } + + private static ByteBuf bb(String s, Channel c) { + return ByteBufUtil.writeUtf8(c.alloc(), s); + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 3fa57db4a9..7bf39d4bff 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -311,9 +311,20 @@ public class OkHttpChannelBuilder extends if (closed) { throw new IllegalStateException("The transport factory is closed."); } + InetSocketAddress proxyAddress = null; + String proxy = System.getenv("GRPC_PROXY_EXP"); + if (proxy != null) { + String[] parts = proxy.split(":", 2); + int port = 80; + if (parts.length > 1) { + port = Integer.parseInt(parts[1]); + } + proxyAddress = new InetSocketAddress(parts[0], port); + } InetSocketAddress inetSocketAddr = (InetSocketAddress) addr; OkHttpClientTransport transport = new OkHttpClientTransport(inetSocketAddr, authority, - userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize); + userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize, + proxyAddress, null, null); if (enableKeepAlive) { transport.enableKeepAlive(true, keepAliveDelayNanos, keepAliveTimeoutNanos); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index b1135d7b0b..225216a497 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -39,6 +39,10 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Ticker; import com.google.common.util.concurrent.SettableFuture; +import com.squareup.okhttp.Credentials; +import com.squareup.okhttp.HttpUrl; +import com.squareup.okhttp.Request; +import com.squareup.okhttp.internal.http.StatusLine; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Metadata; @@ -46,6 +50,7 @@ import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusException; import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; @@ -63,6 +68,7 @@ import io.grpc.okhttp.internal.framed.HeadersMode; import io.grpc.okhttp.internal.framed.Http2; import io.grpc.okhttp.internal.framed.Settings; import io.grpc.okhttp.internal.framed.Variant; +import java.io.EOFException; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; @@ -176,6 +182,12 @@ class OkHttpClientTransport implements ConnectionClientTransport { private boolean enableKeepAlive; private long keepAliveDelayNanos; private long keepAliveTimeoutNanos; + @Nullable + private final InetSocketAddress proxyAddress; + @Nullable + private final String proxyUsername; + @Nullable + private final String proxyPassword; // The following fields should only be used for test. Runnable connectingCallback; @@ -183,7 +195,8 @@ class OkHttpClientTransport implements ConnectionClientTransport { OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, Executor executor, @Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec, - int maxMessageSize) { + int maxMessageSize, @Nullable InetSocketAddress proxyAddress, @Nullable String proxyUsername, + @Nullable String proxyPassword) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = maxMessageSize; @@ -196,6 +209,9 @@ class OkHttpClientTransport implements ConnectionClientTransport { this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); this.ticker = Ticker.systemTicker(); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); + this.proxyAddress = proxyAddress; + this.proxyUsername = proxyUsername; + this.proxyPassword = proxyPassword; } /** @@ -220,6 +236,9 @@ class OkHttpClientTransport implements ConnectionClientTransport { this.connectionSpec = null; this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); + this.proxyAddress = null; + this.proxyUsername = null; + this.proxyPassword = null; } /** @@ -396,7 +415,12 @@ class OkHttpClientTransport implements ConnectionClientTransport { BufferedSink sink; Socket sock; try { - sock = new Socket(address.getAddress(), address.getPort()); + if (proxyAddress == null) { + sock = new Socket(address.getAddress(), address.getPort()); + } else { + sock = createHttpProxySocket(address, proxyAddress, proxyUsername, proxyPassword); + } + if (sslSocketFactory != null) { sock = OkHttpTlsUpgrader.upgrade( sslSocketFactory, sock, getOverridenHost(), getOverridenPort(), connectionSpec); @@ -404,6 +428,9 @@ class OkHttpClientTransport implements ConnectionClientTransport { sock.setTcpNoDelay(true); source = Okio.buffer(Okio.source(sock)); sink = Okio.buffer(Okio.sink(sock)); + } catch (StatusException e) { + startGoAway(0, ErrorCode.INTERNAL_ERROR, e.getStatus()); + return; } catch (Exception e) { onException(e); return; @@ -437,6 +464,93 @@ class OkHttpClientTransport implements ConnectionClientTransport { return null; } + private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, + String proxyUsername, String proxyPassword) throws IOException, StatusException { + try { + Socket sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort()); + sock.setTcpNoDelay(true); + + Source source = Okio.source(sock); + BufferedSink sink = Okio.buffer(Okio.sink(sock)); + + // Prepare headers and request method line + Request proxyRequest = createHttpProxyRequest(address, proxyUsername, proxyPassword); + HttpUrl url = proxyRequest.httpUrl(); + String requestLine = String.format("CONNECT %s:%d HTTP/1.1", url.host(), url.port()); + + // Write request to socket + sink.writeUtf8(requestLine).writeUtf8("\r\n"); + for (int i = 0, size = proxyRequest.headers().size(); i < size; i++) { + sink.writeUtf8(proxyRequest.headers().name(i)) + .writeUtf8(": ") + .writeUtf8(proxyRequest.headers().value(i)) + .writeUtf8("\r\n"); + } + sink.writeUtf8("\r\n"); + // Flush buffer (flushes socket and sends request) + sink.flush(); + + // Read status line, check if 2xx was returned + StatusLine statusLine = StatusLine.parse(readUtf8LineStrictUnbuffered(source)); + // Drain rest of headers + while (!readUtf8LineStrictUnbuffered(source).equals("")) {} + if (statusLine.code < 200 || statusLine.code >= 300) { + Buffer body = new Buffer(); + try { + sock.shutdownOutput(); + source.read(body, 1024); + } catch (IOException ex) { + body.writeUtf8("Unable to read body: " + ex.toString()); + } + try { + sock.close(); + } catch (IOException ignored) { + // ignored + } + String message = String.format( + "Response returned from proxy was not successful (expected 2xx, got %d %s). " + + "Response body:\n%s", + statusLine.code, statusLine.message, body.readUtf8()); + throw Status.UNAVAILABLE.withDescription(message).asException(); + } + return sock; + } catch (IOException e) { + throw Status.UNAVAILABLE.withDescription("Failed trying to connect with proxy").withCause(e) + .asException(); + } + } + + private Request createHttpProxyRequest(InetSocketAddress address, String proxyUsername, + String proxyPassword) { + HttpUrl tunnelUrl = new HttpUrl.Builder() + .scheme("https") + .host(address.getHostName()) + .port(address.getPort()) + .build(); + Request.Builder request = new Request.Builder() + .url(tunnelUrl) + .header("Host", tunnelUrl.host() + ":" + tunnelUrl.port()) + .header("User-Agent", userAgent); + + // If we have proxy credentials, set them right away + if (proxyUsername != null && proxyPassword != null) { + request.header("Proxy-Authorization", Credentials.basic(proxyUsername, proxyPassword)); + } + return request.build(); + } + + private static String readUtf8LineStrictUnbuffered(Source source) throws IOException { + Buffer buffer = new Buffer(); + while (true) { + if (source.read(buffer, 1) == -1) { + throw new EOFException("\\n not found: " + buffer.readByteString().hex()); + } + if (buffer.getByte(buffer.size() - 1) == '\n') { + return buffer.readUtf8LineStrict(); + } + } + } + @Override public String toString() { return getLogId() + "(" + address + ")"; diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index a4bd5c7704..6beff3d7b5 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -90,6 +90,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.InetSocketAddress; +import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; @@ -191,7 +192,8 @@ public class OkHttpClientTransportTest { InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); clientTransport = new OkHttpClientTransport( address, "hostname", null /* agent */, executor, null, - Utils.convertSpec(OkHttpChannelBuilder.DEFAULT_CONNECTION_SPEC), DEFAULT_MAX_MESSAGE_SIZE); + Utils.convertSpec(OkHttpChannelBuilder.DEFAULT_CONNECTION_SPEC), DEFAULT_MAX_MESSAGE_SIZE, + null, null, null); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); @@ -1334,7 +1336,10 @@ public class OkHttpClientTransportTest { executor, null, ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE); + DEFAULT_MAX_MESSAGE_SIZE, + null, + null, + null); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1352,7 +1357,10 @@ public class OkHttpClientTransportTest { executor, null, ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE); + DEFAULT_MAX_MESSAGE_SIZE, + null, + null, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1368,6 +1376,131 @@ public class OkHttpClientTransportTest { assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); } + @Test + public void proxy_200() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + clientTransport = new OkHttpClientTransport( + InetSocketAddress.createUnresolved("theservice", 80), + "authority", + "userAgent", + executor, + null, + ConnectionSpec.CLEARTEXT, + DEFAULT_MAX_MESSAGE_SIZE, + (InetSocketAddress) serverSocket.getLocalSocketAddress(), + null, + null); + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + sock.getOutputStream().write("HTTP/1.1 200 OK\r\nServer: test\r\n\r\n".getBytes(UTF_8)); + sock.getOutputStream().flush(); + + assertEquals("PRI * HTTP/2.0", reader.readLine()); + assertEquals("", reader.readLine()); + assertEquals("SM", reader.readLine()); + assertEquals("", reader.readLine()); + + // Empty SETTINGS + sock.getOutputStream().write(new byte[] {0, 0, 0, 0, 0x4, 0}); + // GOAWAY + sock.getOutputStream().write(new byte[] { + 0, 0, 0, 8, 0x7, 0, + 0, 0, 0, 0, // last stream id + 0, 0, 0, 0, // error code + }); + sock.getOutputStream().flush(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + while (sock.getInputStream().read() != -1) {} + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + sock.close(); + } + + @Test + public void proxy_500() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + clientTransport = new OkHttpClientTransport( + InetSocketAddress.createUnresolved("theservice", 80), + "authority", + "userAgent", + executor, + null, + ConnectionSpec.CLEARTEXT, + DEFAULT_MAX_MESSAGE_SIZE, + (InetSocketAddress) serverSocket.getLocalSocketAddress(), + null, + null); + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + final String errorText = "text describing error"; + sock.getOutputStream().write("HTTP/1.1 500 OH NO\r\n\r\n".getBytes(UTF_8)); + sock.getOutputStream().write(errorText.getBytes(UTF_8)); + sock.getOutputStream().flush(); + sock.shutdownOutput(); + + assertEquals(-1, sock.getInputStream().read()); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); + Status error = captor.getValue(); + assertTrue("Status didn't contain error code: " + captor.getValue(), + error.getDescription().contains("500")); + assertTrue("Status didn't contain error description: " + captor.getValue(), + error.getDescription().contains("OH NO")); + assertTrue("Status didn't contain error text: " + captor.getValue(), + error.getDescription().contains(errorText)); + assertEquals("Not UNAVAILABLE: " + captor.getValue(), + Status.UNAVAILABLE.getCode(), error.getCode()); + sock.close(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void proxy_immediateServerClose() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + clientTransport = new OkHttpClientTransport( + InetSocketAddress.createUnresolved("theservice", 80), + "authority", + "userAgent", + executor, + null, + ConnectionSpec.CLEARTEXT, + DEFAULT_MAX_MESSAGE_SIZE, + (InetSocketAddress) serverSocket.getLocalSocketAddress(), + null, + null); + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + sock.close(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); + Status error = captor.getValue(); + assertTrue("Status didn't contain proxy: " + captor.getValue(), + error.getDescription().contains("proxy")); + assertEquals("Not UNAVAILABLE: " + captor.getValue(), + Status.UNAVAILABLE.getCode(), error.getCode()); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + private int activeStreamCount() { return clientTransport.getActiveStreams().length; }