okhttp: add socketFactory method to channel builder (#5378)

This commit is contained in:
Eric Gribkoff 2019-02-20 19:17:20 -08:00 committed by GitHub
parent d44d015c44
commit 6c32eaf9d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 189 additions and 20 deletions

View File

@ -49,6 +49,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@ -120,6 +121,7 @@ public class OkHttpChannelBuilder extends
private Executor transportExecutor;
private ScheduledExecutorService scheduledExecutorService;
private SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier;
private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
@ -156,6 +158,17 @@ public class OkHttpChannelBuilder extends
return this;
}
/**
* Override the default {@link SocketFactory} used to create sockets. If the socket factory is not
* set or set to null, a default one will be used.
*
* @since 1.20.0
*/
public final OkHttpChannelBuilder socketFactory(@Nullable SocketFactory socketFactory) {
this.socketFactory = socketFactory;
return this;
}
/**
* Sets the negotiation type for the HTTP/2 connection.
*
@ -397,10 +410,21 @@ public class OkHttpChannelBuilder extends
@Internal
protected final ClientTransportFactory buildTransportFactory() {
boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED;
return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService,
createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(),
enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow,
keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory);
return new OkHttpTransportFactory(
transportExecutor,
scheduledExecutorService,
socketFactory,
createSslSocketFactory(),
hostnameVerifier,
connectionSpec,
maxInboundMessageSize(),
enableKeepAlive,
keepAliveTimeNanos,
keepAliveTimeoutNanos,
flowControlWindow,
keepAliveWithoutCalls,
maxInboundMetadataSize,
transportTracerFactory);
}
@Override
@ -417,7 +441,7 @@ public class OkHttpChannelBuilder extends
@VisibleForTesting
@Nullable
SSLSocketFactory createSocketFactory() {
SSLSocketFactory createSslSocketFactory() {
switch (negotiationType) {
case TLS:
try {
@ -463,8 +487,8 @@ public class OkHttpChannelBuilder extends
private final boolean usingSharedExecutor;
private final boolean usingSharedScheduler;
private final TransportTracer.Factory transportTracerFactory;
@Nullable
private final SSLSocketFactory socketFactory;
private final SocketFactory socketFactory;
@Nullable private final SSLSocketFactory sslSocketFactory;
@Nullable
private final HostnameVerifier hostnameVerifier;
private final ConnectionSpec connectionSpec;
@ -478,9 +502,11 @@ public class OkHttpChannelBuilder extends
private final ScheduledExecutorService timeoutService;
private boolean closed;
private OkHttpTransportFactory(Executor executor,
private OkHttpTransportFactory(
Executor executor,
@Nullable ScheduledExecutorService timeoutService,
@Nullable SSLSocketFactory socketFactory,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
int maxMessageSize,
@ -495,6 +521,7 @@ public class OkHttpChannelBuilder extends
this.timeoutService = usingSharedScheduler
? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService;
this.socketFactory = socketFactory;
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize;
@ -536,6 +563,7 @@ public class OkHttpChannelBuilder extends
options.getUserAgent(),
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
connectionSpec,
maxMessageSize,

View File

@ -86,6 +86,7 @@ import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
@ -175,6 +176,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
private boolean stopped;
@GuardedBy("lock")
private boolean hasStream;
private final SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier;
private Socket socket;
@ -219,12 +221,21 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Runnable connectingCallback;
SettableFuture<Void> connectedFuture;
OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent,
Executor executor, @Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec,
int maxMessageSize, int initialWindowSize,
OkHttpClientTransport(
InetSocketAddress address,
String authority,
@Nullable String userAgent,
Executor executor,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
int maxMessageSize,
int initialWindowSize,
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) {
Runnable tooManyPingsRunnable,
int maxInboundMetadataSize,
TransportTracer transportTracer) {
this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize;
@ -234,6 +245,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory;
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
@ -273,6 +285,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.executor = Preconditions.checkNotNull(executor, "executor");
serializingExecutor = new SerializingExecutor(executor);
this.socketFactory = SocketFactory.getDefault();
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader");
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter");
this.socket = Preconditions.checkNotNull(socket, "socket");
@ -506,7 +519,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
SSLSession sslSession = null;
try {
if (proxiedAddr == null) {
sock = new Socket(address.getAddress(), address.getPort());
sock = socketFactory.createSocket(address.getAddress(), address.getPort());
} else {
if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) {
sock = createHttpProxySocket(
@ -584,9 +597,10 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
Socket sock;
// The proxy address may not be resolved
if (proxyAddress.getAddress() != null) {
sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort());
sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort());
} else {
sock = new Socket(proxyAddress.getHostName(), proxyAddress.getPort());
sock =
socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort());
}
sock.setTcpNoDelay(true);
@ -771,6 +785,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
return clientFrameHandler;
}
@VisibleForTesting
SocketFactory getSocketFactory() {
return socketFactory;
}
@VisibleForTesting
int getPendingStreamSize() {
synchronized (lock) {

View File

@ -27,8 +27,11 @@ import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -125,10 +128,10 @@ public class OkHttpChannelBuilderTest {
@Test
public void usePlaintextCreatesNullSocketFactory() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234);
assertNotNull(builder.createSocketFactory());
assertNotNull(builder.createSslSocketFactory());
builder.usePlaintext();
assertNull(builder.createSocketFactory());
assertNull(builder.createSslSocketFactory());
}
@Test
@ -159,5 +162,56 @@ public class OkHttpChannelBuilderTest {
clientTransportFactory.close();
}
}
@Test
public void socketFactory_default() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
assertSame(SocketFactory.getDefault(), transport.getSocketFactory());
transportFactory.close();
}
@Test
public void socketFactory_custom() {
SocketFactory socketFactory =
new SocketFactory() {
@Override
public Socket createSocket(String s, int i) {
return null;
}
@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
return null;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i) {
return null;
}
@Override
public Socket createSocket(
InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
return null;
}
};
OkHttpChannelBuilder builder =
OkHttpChannelBuilder.forTarget("foo").socketFactory(socketFactory);
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
assertSame(socketFactory, transport.getSocketFactory());
transportFactory.close();
}
}

View File

@ -86,6 +86,7 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
@ -103,6 +104,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import okio.Buffer;
@ -146,6 +148,7 @@ public class OkHttpClientTransportTest {
@Mock
private ManagedClientTransport.Listener transportListener;
private final SocketFactory socketFactory = null;
private final SSLSocketFactory sslSocketFactory = null;
private final HostnameVerifier hostnameVerifier = null;
private final TransportTracer transportTracer = new TransportTracer();
@ -242,6 +245,7 @@ public class OkHttpClientTransportTest {
"hostname",
/*agent=*/ null,
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC,
@ -1531,6 +1535,7 @@ public class OkHttpClientTransportTest {
"invalid_authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
@ -1555,6 +1560,7 @@ public class OkHttpClientTransportTest {
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
@ -1579,6 +1585,37 @@ public class OkHttpClientTransportTest {
assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode());
}
@Test
public void customSocketFactory() throws Exception {
RuntimeException exception = new RuntimeException("thrown by socket factory");
SocketFactory socketFactory = new RuntimeExceptionThrowingSocketFactory(exception);
clientTransport =
new OkHttpClientTransport(
new InetSocketAddress("localhost", 0),
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
DEFAULT_MAX_MESSAGE_SIZE,
INITIAL_WINDOW_SIZE,
NO_PROXY,
tooManyPingsRunnable,
DEFAULT_MAX_INBOUND_METADATA_SIZE,
new TransportTracer());
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
clientTransport.start(listener);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture());
Status status = captor.getValue();
assertEquals(Status.UNAVAILABLE.getCode(), status.getCode());
assertSame(exception, status.getCause());
}
@Test
public void proxy_200() throws Exception {
ServerSocket serverSocket = new ServerSocket(0);
@ -1588,6 +1625,7 @@ public class OkHttpClientTransportTest {
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
@ -1642,6 +1680,7 @@ public class OkHttpClientTransportTest {
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
@ -1695,6 +1734,7 @@ public class OkHttpClientTransportTest {
"authority",
"userAgent",
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
ConnectionSpec.CLEARTEXT,
@ -2216,4 +2256,32 @@ public class OkHttpClientTransportTest {
@Override
public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {}
}
private static class RuntimeExceptionThrowingSocketFactory extends SocketFactory {
RuntimeException exception;
private RuntimeExceptionThrowingSocketFactory(RuntimeException exception) {
this.exception = exception;
}
@Override
public Socket createSocket(String s, int i) {
throw exception;
}
@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
throw exception;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i) {
throw exception;
}
@Override
public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
throw exception;
}
}
}