okhttp: add socketFactory method to channel builder (#5378)

This commit is contained in:
Eric Gribkoff 2019-02-20 19:17:20 -08:00 committed by GitHub
parent d44d015c44
commit 6c32eaf9d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 189 additions and 20 deletions

View File

@ -49,6 +49,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
@ -120,6 +121,7 @@ public class OkHttpChannelBuilder extends
private Executor transportExecutor; private Executor transportExecutor;
private ScheduledExecutorService scheduledExecutorService; private ScheduledExecutorService scheduledExecutorService;
private SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory; private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier; private HostnameVerifier hostnameVerifier;
private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC; private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
@ -156,6 +158,17 @@ public class OkHttpChannelBuilder extends
return this; return this;
} }
/**
* Override the default {@link SocketFactory} used to create sockets. If the socket factory is not
* set or set to null, a default one will be used.
*
* @since 1.20.0
*/
public final OkHttpChannelBuilder socketFactory(@Nullable SocketFactory socketFactory) {
this.socketFactory = socketFactory;
return this;
}
/** /**
* Sets the negotiation type for the HTTP/2 connection. * Sets the negotiation type for the HTTP/2 connection.
* *
@ -397,10 +410,21 @@ public class OkHttpChannelBuilder extends
@Internal @Internal
protected final ClientTransportFactory buildTransportFactory() { protected final ClientTransportFactory buildTransportFactory() {
boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED; boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED;
return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService, return new OkHttpTransportFactory(
createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(), transportExecutor,
enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow, scheduledExecutorService,
keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory); socketFactory,
createSslSocketFactory(),
hostnameVerifier,
connectionSpec,
maxInboundMessageSize(),
enableKeepAlive,
keepAliveTimeNanos,
keepAliveTimeoutNanos,
flowControlWindow,
keepAliveWithoutCalls,
maxInboundMetadataSize,
transportTracerFactory);
} }
@Override @Override
@ -417,7 +441,7 @@ public class OkHttpChannelBuilder extends
@VisibleForTesting @VisibleForTesting
@Nullable @Nullable
SSLSocketFactory createSocketFactory() { SSLSocketFactory createSslSocketFactory() {
switch (negotiationType) { switch (negotiationType) {
case TLS: case TLS:
try { try {
@ -463,8 +487,8 @@ public class OkHttpChannelBuilder extends
private final boolean usingSharedExecutor; private final boolean usingSharedExecutor;
private final boolean usingSharedScheduler; private final boolean usingSharedScheduler;
private final TransportTracer.Factory transportTracerFactory; private final TransportTracer.Factory transportTracerFactory;
@Nullable private final SocketFactory socketFactory;
private final SSLSocketFactory socketFactory; @Nullable private final SSLSocketFactory sslSocketFactory;
@Nullable @Nullable
private final HostnameVerifier hostnameVerifier; private final HostnameVerifier hostnameVerifier;
private final ConnectionSpec connectionSpec; private final ConnectionSpec connectionSpec;
@ -478,9 +502,11 @@ public class OkHttpChannelBuilder extends
private final ScheduledExecutorService timeoutService; private final ScheduledExecutorService timeoutService;
private boolean closed; private boolean closed;
private OkHttpTransportFactory(Executor executor, private OkHttpTransportFactory(
Executor executor,
@Nullable ScheduledExecutorService timeoutService, @Nullable ScheduledExecutorService timeoutService,
@Nullable SSLSocketFactory socketFactory, @Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier, @Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec, ConnectionSpec connectionSpec,
int maxMessageSize, int maxMessageSize,
@ -495,6 +521,7 @@ public class OkHttpChannelBuilder extends
this.timeoutService = usingSharedScheduler this.timeoutService = usingSharedScheduler
? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService; ? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService;
this.socketFactory = socketFactory; this.socketFactory = socketFactory;
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier; this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = connectionSpec; this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -536,6 +563,7 @@ public class OkHttpChannelBuilder extends
options.getUserAgent(), options.getUserAgent(),
executor, executor,
socketFactory, socketFactory,
sslSocketFactory,
hostnameVerifier, hostnameVerifier,
connectionSpec, connectionSpec,
maxMessageSize, maxMessageSize,

View File

@ -86,6 +86,7 @@ import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.GuardedBy;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocket;
@ -175,6 +176,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private boolean stopped; private boolean stopped;
@GuardedBy("lock") @GuardedBy("lock")
private boolean hasStream; private boolean hasStream;
private final SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory; private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier; private HostnameVerifier hostnameVerifier;
private Socket socket; private Socket socket;
@ -219,12 +221,21 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Runnable connectingCallback; Runnable connectingCallback;
SettableFuture<Void> connectedFuture; SettableFuture<Void> connectedFuture;
OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent, OkHttpClientTransport(
Executor executor, @Nullable SSLSocketFactory sslSocketFactory, InetSocketAddress address,
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec, String authority,
int maxMessageSize, int initialWindowSize, @Nullable String userAgent,
Executor executor,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
int maxMessageSize,
int initialWindowSize,
@Nullable HttpConnectProxiedSocketAddress proxiedAddr, @Nullable HttpConnectProxiedSocketAddress proxiedAddr,
Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) { Runnable tooManyPingsRunnable,
int maxInboundMetadataSize,
TransportTracer transportTracer) {
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority; this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -234,6 +245,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
// Client initiated streams are odd, server initiated ones are even. Server should not need to // Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation. // use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3; nextStreamId = 3;
this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory;
this.sslSocketFactory = sslSocketFactory; this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier; this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
@ -273,6 +285,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.executor = Preconditions.checkNotNull(executor, "executor"); this.executor = Preconditions.checkNotNull(executor, "executor");
serializingExecutor = new SerializingExecutor(executor); serializingExecutor = new SerializingExecutor(executor);
this.socketFactory = SocketFactory.getDefault();
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader");
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter");
this.socket = Preconditions.checkNotNull(socket, "socket"); this.socket = Preconditions.checkNotNull(socket, "socket");
@ -506,7 +519,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
SSLSession sslSession = null; SSLSession sslSession = null;
try { try {
if (proxiedAddr == null) { if (proxiedAddr == null) {
sock = new Socket(address.getAddress(), address.getPort()); sock = socketFactory.createSocket(address.getAddress(), address.getPort());
} else { } else {
if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) { if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) {
sock = createHttpProxySocket( sock = createHttpProxySocket(
@ -584,9 +597,10 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Socket sock; Socket sock;
// The proxy address may not be resolved // The proxy address may not be resolved
if (proxyAddress.getAddress() != null) { if (proxyAddress.getAddress() != null) {
sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort()); sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort());
} else { } else {
sock = new Socket(proxyAddress.getHostName(), proxyAddress.getPort()); sock =
socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort());
} }
sock.setTcpNoDelay(true); sock.setTcpNoDelay(true);
@ -771,6 +785,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
return clientFrameHandler; return clientFrameHandler;
} }
@VisibleForTesting
SocketFactory getSocketFactory() {
return socketFactory;
}
@VisibleForTesting @VisibleForTesting
int getPendingStreamSize() { int getPendingStreamSize() {
synchronized (lock) { synchronized (lock) {

View File

@ -27,8 +27,11 @@ import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
@ -125,10 +128,10 @@ public class OkHttpChannelBuilderTest {
@Test @Test
public void usePlaintextCreatesNullSocketFactory() { public void usePlaintextCreatesNullSocketFactory() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234); OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234);
assertNotNull(builder.createSocketFactory()); assertNotNull(builder.createSslSocketFactory());
builder.usePlaintext(); builder.usePlaintext();
assertNull(builder.createSocketFactory()); assertNull(builder.createSslSocketFactory());
} }
@Test @Test
@ -159,5 +162,56 @@ public class OkHttpChannelBuilderTest {
clientTransportFactory.close(); clientTransportFactory.close();
} }
}
@Test
public void socketFactory_default() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
assertSame(SocketFactory.getDefault(), transport.getSocketFactory());
transportFactory.close();
}
@Test
public void socketFactory_custom() {
SocketFactory socketFactory =
new SocketFactory() {
@Override
public Socket createSocket(String s, int i) {
return null;
}
@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
return null;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i) {
return null;
}
@Override
public Socket createSocket(
InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
return null;
}
};
OkHttpChannelBuilder builder =
OkHttpChannelBuilder.forTarget("foo").socketFactory(socketFactory);
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
assertSame(socketFactory, transport.getSocketFactory());
transportFactory.close();
}
}

View File

@ -86,6 +86,7 @@ import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
@ -103,6 +104,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import okio.Buffer; import okio.Buffer;
@ -146,6 +148,7 @@ public class OkHttpClientTransportTest {
@Mock @Mock
private ManagedClientTransport.Listener transportListener; private ManagedClientTransport.Listener transportListener;
private final SocketFactory socketFactory = null;
private final SSLSocketFactory sslSocketFactory = null; private final SSLSocketFactory sslSocketFactory = null;
private final HostnameVerifier hostnameVerifier = null; private final HostnameVerifier hostnameVerifier = null;
private final TransportTracer transportTracer = new TransportTracer(); private final TransportTracer transportTracer = new TransportTracer();
@ -242,6 +245,7 @@ public class OkHttpClientTransportTest {
"hostname", "hostname",
/*agent=*/ null, /*agent=*/ null,
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC,
@ -1531,6 +1535,7 @@ public class OkHttpClientTransportTest {
"invalid_authority", "invalid_authority",
"userAgent", "userAgent",
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -1555,6 +1560,7 @@ public class OkHttpClientTransportTest {
"authority", "authority",
"userAgent", "userAgent",
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -1579,6 +1585,37 @@ public class OkHttpClientTransportTest {
assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode());
} }
@Test
public void customSocketFactory() throws Exception {
RuntimeException exception = new RuntimeException("thrown by socket factory");
SocketFactory socketFactory = new RuntimeExceptionThrowingSocketFactory(exception);
clientTransport =
new OkHttpClientTransport(
new InetSocketAddress("localhost", 0),
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
DEFAULT_MAX_MESSAGE_SIZE,
INITIAL_WINDOW_SIZE,
NO_PROXY,
tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
new TransportTracer());
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
clientTransport.start(listener);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture());
Status status = captor.getValue();
assertEquals(Status.UNAVAILABLE.getCode(), status.getCode());
assertSame(exception, status.getCause());
}
@Test @Test
public void proxy_200() throws Exception { public void proxy_200() throws Exception {
ServerSocket serverSocket = new ServerSocket(0); ServerSocket serverSocket = new ServerSocket(0);
@ -1588,6 +1625,7 @@ public class OkHttpClientTransportTest {
"authority", "authority",
"userAgent", "userAgent",
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -1642,6 +1680,7 @@ public class OkHttpClientTransportTest {
"authority", "authority",
"userAgent", "userAgent",
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -1695,6 +1734,7 @@ public class OkHttpClientTransportTest {
"authority", "authority",
"userAgent", "userAgent",
executor, executor,
socketFactory,
sslSocketFactory, sslSocketFactory,
hostnameVerifier, hostnameVerifier,
ConnectionSpec.CLEARTEXT, ConnectionSpec.CLEARTEXT,
@ -2216,4 +2256,32 @@ public class OkHttpClientTransportTest {
@Override @Override
public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {} public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {}
} }
private static class RuntimeExceptionThrowingSocketFactory extends SocketFactory {
RuntimeException exception;
private RuntimeExceptionThrowingSocketFactory(RuntimeException exception) {
this.exception = exception;
}
@Override
public Socket createSocket(String s, int i) {
throw exception;
}
@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
throw exception;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i) {
throw exception;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
throw exception;
}
}
} }