mirror of https://github.com/grpc/grpc-java.git
okhttp: add socketFactory method to channel builder (#5378)
This commit is contained in:
parent
d44d015c44
commit
6c32eaf9d4
|
|
@ -49,6 +49,7 @@ import java.util.concurrent.Executors;
|
|||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
|
|
@ -120,6 +121,7 @@ public class OkHttpChannelBuilder extends
|
|||
private Executor transportExecutor;
|
||||
private ScheduledExecutorService scheduledExecutorService;
|
||||
|
||||
private SocketFactory socketFactory;
|
||||
private SSLSocketFactory sslSocketFactory;
|
||||
private HostnameVerifier hostnameVerifier;
|
||||
private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
|
||||
|
|
@ -156,6 +158,17 @@ public class OkHttpChannelBuilder extends
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Override the default {@link SocketFactory} used to create sockets. If the socket factory is not
|
||||
* set or set to null, a default one will be used.
|
||||
*
|
||||
* @since 1.20.0
|
||||
*/
|
||||
public final OkHttpChannelBuilder socketFactory(@Nullable SocketFactory socketFactory) {
|
||||
this.socketFactory = socketFactory;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the negotiation type for the HTTP/2 connection.
|
||||
*
|
||||
|
|
@ -397,10 +410,21 @@ public class OkHttpChannelBuilder extends
|
|||
@Internal
|
||||
protected final ClientTransportFactory buildTransportFactory() {
|
||||
boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED;
|
||||
return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService,
|
||||
createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(),
|
||||
enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow,
|
||||
keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory);
|
||||
return new OkHttpTransportFactory(
|
||||
transportExecutor,
|
||||
scheduledExecutorService,
|
||||
socketFactory,
|
||||
createSslSocketFactory(),
|
||||
hostnameVerifier,
|
||||
connectionSpec,
|
||||
maxInboundMessageSize(),
|
||||
enableKeepAlive,
|
||||
keepAliveTimeNanos,
|
||||
keepAliveTimeoutNanos,
|
||||
flowControlWindow,
|
||||
keepAliveWithoutCalls,
|
||||
maxInboundMetadataSize,
|
||||
transportTracerFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -417,7 +441,7 @@ public class OkHttpChannelBuilder extends
|
|||
|
||||
@VisibleForTesting
|
||||
@Nullable
|
||||
SSLSocketFactory createSocketFactory() {
|
||||
SSLSocketFactory createSslSocketFactory() {
|
||||
switch (negotiationType) {
|
||||
case TLS:
|
||||
try {
|
||||
|
|
@ -463,8 +487,8 @@ public class OkHttpChannelBuilder extends
|
|||
private final boolean usingSharedExecutor;
|
||||
private final boolean usingSharedScheduler;
|
||||
private final TransportTracer.Factory transportTracerFactory;
|
||||
@Nullable
|
||||
private final SSLSocketFactory socketFactory;
|
||||
private final SocketFactory socketFactory;
|
||||
@Nullable private final SSLSocketFactory sslSocketFactory;
|
||||
@Nullable
|
||||
private final HostnameVerifier hostnameVerifier;
|
||||
private final ConnectionSpec connectionSpec;
|
||||
|
|
@ -478,9 +502,11 @@ public class OkHttpChannelBuilder extends
|
|||
private final ScheduledExecutorService timeoutService;
|
||||
private boolean closed;
|
||||
|
||||
private OkHttpTransportFactory(Executor executor,
|
||||
private OkHttpTransportFactory(
|
||||
Executor executor,
|
||||
@Nullable ScheduledExecutorService timeoutService,
|
||||
@Nullable SSLSocketFactory socketFactory,
|
||||
@Nullable SocketFactory socketFactory,
|
||||
@Nullable SSLSocketFactory sslSocketFactory,
|
||||
@Nullable HostnameVerifier hostnameVerifier,
|
||||
ConnectionSpec connectionSpec,
|
||||
int maxMessageSize,
|
||||
|
|
@ -495,6 +521,7 @@ public class OkHttpChannelBuilder extends
|
|||
this.timeoutService = usingSharedScheduler
|
||||
? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService;
|
||||
this.socketFactory = socketFactory;
|
||||
this.sslSocketFactory = sslSocketFactory;
|
||||
this.hostnameVerifier = hostnameVerifier;
|
||||
this.connectionSpec = connectionSpec;
|
||||
this.maxMessageSize = maxMessageSize;
|
||||
|
|
@ -536,6 +563,7 @@ public class OkHttpChannelBuilder extends
|
|||
options.getUserAgent(),
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
connectionSpec,
|
||||
maxMessageSize,
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ import java.util.logging.Level;
|
|||
import java.util.logging.Logger;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.annotation.concurrent.GuardedBy;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.SSLSocket;
|
||||
|
|
@ -175,6 +176,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
private boolean stopped;
|
||||
@GuardedBy("lock")
|
||||
private boolean hasStream;
|
||||
private final SocketFactory socketFactory;
|
||||
private SSLSocketFactory sslSocketFactory;
|
||||
private HostnameVerifier hostnameVerifier;
|
||||
private Socket socket;
|
||||
|
|
@ -219,12 +221,21 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
Runnable connectingCallback;
|
||||
SettableFuture<Void> connectedFuture;
|
||||
|
||||
OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent,
|
||||
Executor executor, @Nullable SSLSocketFactory sslSocketFactory,
|
||||
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec,
|
||||
int maxMessageSize, int initialWindowSize,
|
||||
OkHttpClientTransport(
|
||||
InetSocketAddress address,
|
||||
String authority,
|
||||
@Nullable String userAgent,
|
||||
Executor executor,
|
||||
@Nullable SocketFactory socketFactory,
|
||||
@Nullable SSLSocketFactory sslSocketFactory,
|
||||
@Nullable HostnameVerifier hostnameVerifier,
|
||||
ConnectionSpec connectionSpec,
|
||||
int maxMessageSize,
|
||||
int initialWindowSize,
|
||||
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
|
||||
Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) {
|
||||
Runnable tooManyPingsRunnable,
|
||||
int maxInboundMetadataSize,
|
||||
TransportTracer transportTracer) {
|
||||
this.address = Preconditions.checkNotNull(address, "address");
|
||||
this.defaultAuthority = authority;
|
||||
this.maxMessageSize = maxMessageSize;
|
||||
|
|
@ -234,6 +245,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
// Client initiated streams are odd, server initiated ones are even. Server should not need to
|
||||
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
|
||||
nextStreamId = 3;
|
||||
this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory;
|
||||
this.sslSocketFactory = sslSocketFactory;
|
||||
this.hostnameVerifier = hostnameVerifier;
|
||||
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
|
||||
|
|
@ -273,6 +285,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
|
||||
this.executor = Preconditions.checkNotNull(executor, "executor");
|
||||
serializingExecutor = new SerializingExecutor(executor);
|
||||
this.socketFactory = SocketFactory.getDefault();
|
||||
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader");
|
||||
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter");
|
||||
this.socket = Preconditions.checkNotNull(socket, "socket");
|
||||
|
|
@ -506,7 +519,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
SSLSession sslSession = null;
|
||||
try {
|
||||
if (proxiedAddr == null) {
|
||||
sock = new Socket(address.getAddress(), address.getPort());
|
||||
sock = socketFactory.createSocket(address.getAddress(), address.getPort());
|
||||
} else {
|
||||
if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) {
|
||||
sock = createHttpProxySocket(
|
||||
|
|
@ -584,9 +597,10 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
Socket sock;
|
||||
// The proxy address may not be resolved
|
||||
if (proxyAddress.getAddress() != null) {
|
||||
sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort());
|
||||
sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort());
|
||||
} else {
|
||||
sock = new Socket(proxyAddress.getHostName(), proxyAddress.getPort());
|
||||
sock =
|
||||
socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort());
|
||||
}
|
||||
sock.setTcpNoDelay(true);
|
||||
|
||||
|
|
@ -771,6 +785,11 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
|
|||
return clientFrameHandler;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
SocketFactory getSocketFactory() {
|
||||
return socketFactory;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
int getPendingStreamSize() {
|
||||
synchronized (lock) {
|
||||
|
|
|
|||
|
|
@ -27,8 +27,11 @@ import io.grpc.internal.ClientTransportFactory;
|
|||
import io.grpc.internal.FakeClock;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.SharedResourceHolder;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.Socket;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import javax.net.SocketFactory;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
|
@ -125,10 +128,10 @@ public class OkHttpChannelBuilderTest {
|
|||
@Test
|
||||
public void usePlaintextCreatesNullSocketFactory() {
|
||||
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234);
|
||||
assertNotNull(builder.createSocketFactory());
|
||||
assertNotNull(builder.createSslSocketFactory());
|
||||
|
||||
builder.usePlaintext();
|
||||
assertNull(builder.createSocketFactory());
|
||||
assertNull(builder.createSslSocketFactory());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -159,5 +162,56 @@ public class OkHttpChannelBuilderTest {
|
|||
|
||||
clientTransportFactory.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void socketFactory_default() {
|
||||
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
|
||||
ClientTransportFactory transportFactory = builder.buildTransportFactory();
|
||||
OkHttpClientTransport transport =
|
||||
(OkHttpClientTransport)
|
||||
transportFactory.newClientTransport(
|
||||
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
|
||||
|
||||
assertSame(SocketFactory.getDefault(), transport.getSocketFactory());
|
||||
|
||||
transportFactory.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void socketFactory_custom() {
|
||||
SocketFactory socketFactory =
|
||||
new SocketFactory() {
|
||||
@Override
|
||||
public Socket createSocket(String s, int i) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(InetAddress inetAddress, int i) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(
|
||||
InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
OkHttpChannelBuilder builder =
|
||||
OkHttpChannelBuilder.forTarget("foo").socketFactory(socketFactory);
|
||||
ClientTransportFactory transportFactory = builder.buildTransportFactory();
|
||||
OkHttpClientTransport transport =
|
||||
(OkHttpClientTransport)
|
||||
transportFactory.newClientTransport(
|
||||
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());
|
||||
|
||||
assertSame(socketFactory, transport.getSocketFactory());
|
||||
|
||||
transportFactory.close();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ import java.io.ByteArrayInputStream;
|
|||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
|
|
@ -103,6 +104,7 @@ import java.util.concurrent.LinkedBlockingQueue;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import okio.Buffer;
|
||||
|
|
@ -146,6 +148,7 @@ public class OkHttpClientTransportTest {
|
|||
@Mock
|
||||
private ManagedClientTransport.Listener transportListener;
|
||||
|
||||
private final SocketFactory socketFactory = null;
|
||||
private final SSLSocketFactory sslSocketFactory = null;
|
||||
private final HostnameVerifier hostnameVerifier = null;
|
||||
private final TransportTracer transportTracer = new TransportTracer();
|
||||
|
|
@ -242,6 +245,7 @@ public class OkHttpClientTransportTest {
|
|||
"hostname",
|
||||
/*agent=*/ null,
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC,
|
||||
|
|
@ -1531,6 +1535,7 @@ public class OkHttpClientTransportTest {
|
|||
"invalid_authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
|
|
@ -1555,6 +1560,7 @@ public class OkHttpClientTransportTest {
|
|||
"authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
|
|
@ -1579,6 +1585,37 @@ public class OkHttpClientTransportTest {
|
|||
assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void customSocketFactory() throws Exception {
|
||||
RuntimeException exception = new RuntimeException("thrown by socket factory");
|
||||
SocketFactory socketFactory = new RuntimeExceptionThrowingSocketFactory(exception);
|
||||
|
||||
clientTransport =
|
||||
new OkHttpClientTransport(
|
||||
new InetSocketAddress("localhost", 0),
|
||||
"authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
DEFAULT_MAX_MESSAGE_SIZE,
|
||||
INITIAL_WINDOW_SIZE,
|
||||
NO_PROXY,
|
||||
tooManyPingsRunnable,
|
||||
DEFAULT_MAX_INBOUND_METADATA_SIZE,
|
||||
new TransportTracer());
|
||||
|
||||
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
|
||||
clientTransport.start(listener);
|
||||
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
|
||||
verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture());
|
||||
Status status = captor.getValue();
|
||||
assertEquals(Status.UNAVAILABLE.getCode(), status.getCode());
|
||||
assertSame(exception, status.getCause());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void proxy_200() throws Exception {
|
||||
ServerSocket serverSocket = new ServerSocket(0);
|
||||
|
|
@ -1588,6 +1625,7 @@ public class OkHttpClientTransportTest {
|
|||
"authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
|
|
@ -1642,6 +1680,7 @@ public class OkHttpClientTransportTest {
|
|||
"authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
|
|
@ -1695,6 +1734,7 @@ public class OkHttpClientTransportTest {
|
|||
"authority",
|
||||
"userAgent",
|
||||
executor,
|
||||
socketFactory,
|
||||
sslSocketFactory,
|
||||
hostnameVerifier,
|
||||
ConnectionSpec.CLEARTEXT,
|
||||
|
|
@ -2216,4 +2256,32 @@ public class OkHttpClientTransportTest {
|
|||
@Override
|
||||
public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException {}
|
||||
}
|
||||
|
||||
private static class RuntimeExceptionThrowingSocketFactory extends SocketFactory {
|
||||
RuntimeException exception;
|
||||
|
||||
private RuntimeExceptionThrowingSocketFactory(RuntimeException exception) {
|
||||
this.exception = exception;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(String s, int i) {
|
||||
throw exception;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
|
||||
throw exception;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(InetAddress inetAddress, int i) {
|
||||
throw exception;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
|
||||
throw exception;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue