diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 0311976e94..2b4b1162ed 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -49,6 +49,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -120,6 +121,7 @@ public class OkHttpChannelBuilder extends private Executor transportExecutor; private ScheduledExecutorService scheduledExecutorService; + private SocketFactory socketFactory; private SSLSocketFactory sslSocketFactory; private HostnameVerifier hostnameVerifier; private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC; @@ -156,6 +158,17 @@ public class OkHttpChannelBuilder extends return this; } + /** + * Override the default {@link SocketFactory} used to create sockets. If the socket factory is not + * set or set to null, a default one will be used. + * + * @since 1.20.0 + */ + public final OkHttpChannelBuilder socketFactory(@Nullable SocketFactory socketFactory) { + this.socketFactory = socketFactory; + return this; + } + /** * Sets the negotiation type for the HTTP/2 connection. * @@ -397,10 +410,21 @@ public class OkHttpChannelBuilder extends @Internal protected final ClientTransportFactory buildTransportFactory() { boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED; - return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService, - createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(), - enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow, - keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory); + return new OkHttpTransportFactory( + transportExecutor, + scheduledExecutorService, + socketFactory, + createSslSocketFactory(), + hostnameVerifier, + connectionSpec, + maxInboundMessageSize(), + enableKeepAlive, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + flowControlWindow, + keepAliveWithoutCalls, + maxInboundMetadataSize, + transportTracerFactory); } @Override @@ -417,7 +441,7 @@ public class OkHttpChannelBuilder extends @VisibleForTesting @Nullable - SSLSocketFactory createSocketFactory() { + SSLSocketFactory createSslSocketFactory() { switch (negotiationType) { case TLS: try { @@ -463,8 +487,8 @@ public class OkHttpChannelBuilder extends private final boolean usingSharedExecutor; private final boolean usingSharedScheduler; private final TransportTracer.Factory transportTracerFactory; - @Nullable - private final SSLSocketFactory socketFactory; + private final SocketFactory socketFactory; + @Nullable private final SSLSocketFactory sslSocketFactory; @Nullable private final HostnameVerifier hostnameVerifier; private final ConnectionSpec connectionSpec; @@ -478,9 +502,11 @@ public class OkHttpChannelBuilder extends private final ScheduledExecutorService timeoutService; private boolean closed; - private OkHttpTransportFactory(Executor executor, + private OkHttpTransportFactory( + Executor executor, @Nullable ScheduledExecutorService timeoutService, - @Nullable SSLSocketFactory socketFactory, + @Nullable SocketFactory socketFactory, + @Nullable SSLSocketFactory sslSocketFactory, @Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, int maxMessageSize, @@ -495,6 +521,7 @@ public class OkHttpChannelBuilder extends this.timeoutService = usingSharedScheduler ? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService; this.socketFactory = socketFactory; + this.sslSocketFactory = sslSocketFactory; this.hostnameVerifier = hostnameVerifier; this.connectionSpec = connectionSpec; this.maxMessageSize = maxMessageSize; @@ -536,6 +563,7 @@ public class OkHttpChannelBuilder extends options.getUserAgent(), executor, socketFactory, + sslSocketFactory, hostnameVerifier, connectionSpec, maxMessageSize, diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index d97dd40545..8fd88f2cc0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -86,6 +86,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; +import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; @@ -175,6 +176,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep private boolean stopped; @GuardedBy("lock") private boolean hasStream; + private final SocketFactory socketFactory; private SSLSocketFactory sslSocketFactory; private HostnameVerifier hostnameVerifier; private Socket socket; @@ -219,12 +221,21 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, - Executor executor, @Nullable SSLSocketFactory sslSocketFactory, - @Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, - int maxMessageSize, int initialWindowSize, + OkHttpClientTransport( + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Executor executor, + @Nullable SocketFactory socketFactory, + @Nullable SSLSocketFactory sslSocketFactory, + @Nullable HostnameVerifier hostnameVerifier, + ConnectionSpec connectionSpec, + int maxMessageSize, + int initialWindowSize, @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) { + Runnable tooManyPingsRunnable, + int maxInboundMetadataSize, + TransportTracer transportTracer) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = maxMessageSize; @@ -234,6 +245,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep // Client initiated streams are odd, server initiated ones are even. Server should not need to // use it. We start clients at 3 to avoid conflicting with HTTP negotiation. nextStreamId = 3; + this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory; this.sslSocketFactory = sslSocketFactory; this.hostnameVerifier = hostnameVerifier; this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); @@ -273,6 +285,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.executor = Preconditions.checkNotNull(executor, "executor"); serializingExecutor = new SerializingExecutor(executor); + this.socketFactory = SocketFactory.getDefault(); this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); this.socket = Preconditions.checkNotNull(socket, "socket"); @@ -506,7 +519,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep SSLSession sslSession = null; try { if (proxiedAddr == null) { - sock = new Socket(address.getAddress(), address.getPort()); + sock = socketFactory.createSocket(address.getAddress(), address.getPort()); } else { if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) { sock = createHttpProxySocket( @@ -584,9 +597,10 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep Socket sock; // The proxy address may not be resolved if (proxyAddress.getAddress() != null) { - sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort()); + sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort()); } else { - sock = new Socket(proxyAddress.getHostName(), proxyAddress.getPort()); + sock = + socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort()); } sock.setTcpNoDelay(true); @@ -771,6 +785,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep return clientFrameHandler; } + @VisibleForTesting + SocketFactory getSocketFactory() { + return socketFactory; + } + @VisibleForTesting int getPendingStreamSize() { synchronized (lock) { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 3f9104cf72..1e84a00414 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -27,8 +27,11 @@ import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; +import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.Socket; import java.util.concurrent.ScheduledExecutorService; +import javax.net.SocketFactory; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -125,10 +128,10 @@ public class OkHttpChannelBuilderTest { @Test public void usePlaintextCreatesNullSocketFactory() { OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234); - assertNotNull(builder.createSocketFactory()); + assertNotNull(builder.createSslSocketFactory()); builder.usePlaintext(); - assertNull(builder.createSocketFactory()); + assertNull(builder.createSslSocketFactory()); } @Test @@ -159,5 +162,56 @@ public class OkHttpChannelBuilderTest { clientTransportFactory.close(); } -} + @Test + public void socketFactory_default() { + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo"); + ClientTransportFactory transportFactory = builder.buildTransportFactory(); + OkHttpClientTransport transport = + (OkHttpClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions()); + + assertSame(SocketFactory.getDefault(), transport.getSocketFactory()); + + transportFactory.close(); + } + + @Test + public void socketFactory_custom() { + SocketFactory socketFactory = + new SocketFactory() { + @Override + public Socket createSocket(String s, int i) { + return null; + } + + @Override + public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) { + return null; + } + + @Override + public Socket createSocket(InetAddress inetAddress, int i) { + return null; + } + + @Override + public Socket createSocket( + InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) { + return null; + } + }; + OkHttpChannelBuilder builder = + OkHttpChannelBuilder.forTarget("foo").socketFactory(socketFactory); + ClientTransportFactory transportFactory = builder.buildTransportFactory(); + OkHttpClientTransport transport = + (OkHttpClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions()); + + assertSame(socketFactory, transport.getSocketFactory()); + + transportFactory.close(); + } +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index c6115d1197..555fbceb9f 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -86,6 +86,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; @@ -103,6 +104,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; +import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSocketFactory; import okio.Buffer; @@ -146,6 +148,7 @@ public class OkHttpClientTransportTest { @Mock private ManagedClientTransport.Listener transportListener; + private final SocketFactory socketFactory = null; private final SSLSocketFactory sslSocketFactory = null; private final HostnameVerifier hostnameVerifier = null; private final TransportTracer transportTracer = new TransportTracer(); @@ -242,6 +245,7 @@ public class OkHttpClientTransportTest { "hostname", /*agent=*/ null, executor, + socketFactory, sslSocketFactory, hostnameVerifier, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC, @@ -1531,6 +1535,7 @@ public class OkHttpClientTransportTest { "invalid_authority", "userAgent", executor, + socketFactory, sslSocketFactory, hostnameVerifier, ConnectionSpec.CLEARTEXT, @@ -1555,6 +1560,7 @@ public class OkHttpClientTransportTest { "authority", "userAgent", executor, + socketFactory, sslSocketFactory, hostnameVerifier, ConnectionSpec.CLEARTEXT, @@ -1579,6 +1585,37 @@ public class OkHttpClientTransportTest { assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); } + @Test + public void customSocketFactory() throws Exception { + RuntimeException exception = new RuntimeException("thrown by socket factory"); + SocketFactory socketFactory = new RuntimeExceptionThrowingSocketFactory(exception); + + clientTransport = + new OkHttpClientTransport( + new InetSocketAddress("localhost", 0), + "authority", + "userAgent", + executor, + socketFactory, + sslSocketFactory, + hostnameVerifier, + ConnectionSpec.CLEARTEXT, + DEFAULT_MAX_MESSAGE_SIZE, + INITIAL_WINDOW_SIZE, + NO_PROXY, + tooManyPingsRunnable, + DEFAULT_MAX_INBOUND_METADATA_SIZE, + new TransportTracer()); + + ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); + clientTransport.start(listener); + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); + Status status = captor.getValue(); + assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); + assertSame(exception, status.getCause()); + } + @Test public void proxy_200() throws Exception { ServerSocket serverSocket = new ServerSocket(0); @@ -1588,6 +1625,7 @@ public class OkHttpClientTransportTest { "authority", "userAgent", executor, + socketFactory, sslSocketFactory, hostnameVerifier, ConnectionSpec.CLEARTEXT, @@ -1642,6 +1680,7 @@ public class OkHttpClientTransportTest { "authority", "userAgent", executor, + socketFactory, sslSocketFactory, hostnameVerifier, ConnectionSpec.CLEARTEXT, @@ -1695,6 +1734,7 @@ public class OkHttpClientTransportTest { "authority", "userAgent", executor, + socketFactory, sslSocketFactory, hostnameVerifier, ConnectionSpec.CLEARTEXT, @@ -2216,4 +2256,32 @@ public class OkHttpClientTransportTest { @Override public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {} } + + private static class RuntimeExceptionThrowingSocketFactory extends SocketFactory { + RuntimeException exception; + + private RuntimeExceptionThrowingSocketFactory(RuntimeException exception) { + this.exception = exception; + } + + @Override + public Socket createSocket(String s, int i) { + throw exception; + } + + @Override + public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) { + throw exception; + } + + @Override + public Socket createSocket(InetAddress inetAddress, int i) { + throw exception; + } + + @Override + public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) { + throw exception; + } + } }