Make the OkHTTP transport AppEngine friendly. AppEngine may support

conscrypt at some point which would allow ALPN to function
Clarify the SSLContext.getDefault is not used when constructing the
default SSLSocketFactory.
This commit is contained in:
Louis Ryan 2016-05-11 15:05:01 -07:00
parent d2cc576320
commit f52b4e52cd
8 changed files with 156 additions and 56 deletions

View File

@ -39,6 +39,7 @@ import com.google.common.base.Splitter;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.Metadata;
@ -53,6 +54,7 @@ import java.util.Map.Entry;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
@ -62,6 +64,12 @@ import javax.annotation.Nullable;
*/
public final class GrpcUtil {
// Certain production AppEngine runtimes have constraints on threading and socket handling
// that need to be accommodated.
public static final boolean IS_RESTRICTED_APPENGINE =
"Production".equals(System.getProperty("com.google.appengine.runtime.environment"))
&& "1.7".equals(System.getProperty("java.specification.version"));
/**
* {@link io.grpc.Metadata.Key} for the timeout header.
*/
@ -374,10 +382,7 @@ public final class GrpcUtil {
private static final String name = "grpc-default-executor";
@Override
public ExecutorService create() {
return Executors.newCachedThreadPool(new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat(name + "-%d")
.build());
return Executors.newCachedThreadPool(getThreadFactory(name + "-%d", true));
}
@Override
@ -402,10 +407,8 @@ public final class GrpcUtil {
// ScheduledThreadPoolExecutor.
ScheduledExecutorService service = Executors.newScheduledThreadPool(
1,
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("grpc-timer-%d")
.build());
getThreadFactory("grpc-timer-%d", true));
// If there are long timeouts that are cancelled, they will not actually be removed from
// the executors queue. This forces immediate removal upon cancellation to avoid a
// memory leak. Reflection is used because we cannot use methods added in Java 1.7. If
@ -431,6 +434,27 @@ public final class GrpcUtil {
}
};
/**
* Get a {@link ThreadFactory} suitable for use in the current environment.
* @param nameFormat to apply to threads created by the factory.
* @param daemon {@code true} if the threads the factory creates are daemon threads, {@code false}
* otherwise.
* @return a {@link ThreadFactory}.
*/
public static ThreadFactory getThreadFactory(String nameFormat, boolean daemon) {
ThreadFactory threadFactory = MoreExecutors.platformThreadFactory();
if (IS_RESTRICTED_APPENGINE) {
return threadFactory;
} else {
return new ThreadFactoryBuilder()
.setThreadFactory(threadFactory)
.setDaemon(daemon)
.setNameFormat(nameFormat)
.build();
}
}
/**
* The factory of default Stopwatches.
*/

View File

@ -32,7 +32,6 @@
package io.grpc.internal;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.IdentityHashMap;
import java.util.concurrent.Executors;
@ -65,10 +64,8 @@ public final class SharedResourceHolder {
new ScheduledExecutorFactory() {
@Override
public ScheduledExecutorService createScheduledExecutor() {
return Executors.newSingleThreadScheduledExecutor(new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("grpc-shared-destroyer-%d")
.build());
return Executors.newSingleThreadScheduledExecutor(
GrpcUtil.getThreadFactory("grpc-shared-destroyer-%d", true));
}
});

View File

@ -45,6 +45,7 @@ import io.grpc.internal.GrpcUtil;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.internal.Platform;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.StreamRecorder;
import io.grpc.testing.TestUtils;
@ -60,6 +61,7 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.io.FileInputStream;
import java.io.IOException;
import javax.net.ssl.SSLPeerUnverifiedException;
@ -109,7 +111,8 @@ public class Http2OkHttpTest extends AbstractInteropTest {
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, getPort()));
try {
builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(TestUtils.loadCert("ca.pem")));
builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(),
new FileInputStream(TestUtils.loadCert("ca.pem"))));
} catch (Exception e) {
throw new RuntimeException(e);
}
@ -149,7 +152,8 @@ public class Http2OkHttpTest extends AbstractInteropTest {
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
"I.am.a.bad.hostname", getPort()));
ManagedChannel channel = builder.sslSocketFactory(
TestUtils.newSslSocketFactoryForCa(TestUtils.loadCert("ca.pem"))).build();
TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(),
new FileInputStream(TestUtils.loadCert("ca.pem")))).build();
TestServiceGrpc.TestServiceBlockingStub blockingStub =
TestServiceGrpc.newBlockingStub(channel);

View File

@ -38,7 +38,6 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.squareup.okhttp.CipherSuite;
import com.squareup.okhttp.ConnectionSpec;
@ -54,15 +53,18 @@ import io.grpc.internal.ConnectionClientTransport;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.okhttp.internal.Platform;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
/** Convenience class for building channels with the OkHttp transport. */
@ -90,10 +92,7 @@ public class OkHttpChannelBuilder extends
new Resource<ExecutorService>() {
@Override
public ExecutorService create() {
return Executors.newCachedThreadPool(new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("grpc-okhttp-%d")
.build());
return Executors.newCachedThreadPool(GrpcUtil.getThreadFactory("grpc-okhttp-%d", true));
}
@Override
@ -147,6 +146,11 @@ public class OkHttpChannelBuilder extends
/**
* Sets the negotiation type for the HTTP/2 connection.
*
* <p>If TLS is enabled a default {@link SSLSocketFactory} is created using the best
* {@link java.security.Provider} available and is NOT based on
* {@link SSLSocketFactory#getDefault}. To more precisely control the TLS configuration call
* {@link #sslSocketFactory} to override the socket factory used.
*
* <p>Default: <code>TLS</code>
*/
public final OkHttpChannelBuilder negotiationType(NegotiationType type) {
@ -180,7 +184,8 @@ public class OkHttpChannelBuilder extends
}
/**
* Provides a SSLSocketFactory to replace the default SSLSocketFactory used for TLS.
* Override the default {@link SSLSocketFactory} and enable {@link NegotiationType#TLS}
* negotiation.
*
* <p>By default, when TLS is enabled, <code>SSLSocketFactory.getDefault()</code> will be used.
*
@ -264,8 +269,16 @@ public class OkHttpChannelBuilder extends
SSLSocketFactory createSocketFactory() {
switch (negotiationType) {
case TLS:
return sslSocketFactory == null
? (SSLSocketFactory) SSLSocketFactory.getDefault() : sslSocketFactory;
try {
if (sslSocketFactory == null) {
SSLContext sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider());
sslContext.init(null, null, null);
sslSocketFactory = sslContext.getSocketFactory();
}
return sslSocketFactory;
} catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse);
}
case PLAINTEXT:
return null;
default:

View File

@ -33,10 +33,14 @@ package io.grpc.okhttp;
import io.grpc.Internal;
import io.grpc.ManagedChannelProvider;
import io.grpc.internal.GrpcUtil;
/** Provider for {@link OkHttpChannelBuilder} instances. */
/**
* Provider for {@link OkHttpChannelBuilder} instances.
*/
@Internal
public final class OkHttpChannelProvider extends ManagedChannelProvider {
@Override
public boolean isAvailable() {
return true;
@ -44,7 +48,7 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider {
@Override
public int priority() {
return isAndroid() ? 8 : 3;
return (GrpcUtil.IS_RESTRICTED_APPENGINE || isAndroid()) ? 8 : 3;
}
@Override

View File

@ -735,7 +735,9 @@ class OkHttpClientTransport implements ConnectionClientTransport {
@Override
public void run() {
String threadName = Thread.currentThread().getName();
Thread.currentThread().setName("OkHttpClientTransport");
if (!GrpcUtil.IS_RESTRICTED_APPENGINE) {
Thread.currentThread().setName("OkHttpClientTransport");
}
try {
// Read until the underlying socket closes.
while (frameReader.nextFrame(this)) {
@ -758,8 +760,10 @@ class OkHttpClientTransport implements ConnectionClientTransport {
log.log(Level.INFO, "Exception closing frame reader", ex);
}
listener.transportTerminated();
// Restore the original thread name.
Thread.currentThread().setName(threadName);
if (!GrpcUtil.IS_RESTRICTED_APPENGINE) {
// Restore the original thread name.
Thread.currentThread().setName(threadName);
}
}
}

View File

@ -28,10 +28,14 @@ import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import okio.Buffer;
@ -62,6 +66,12 @@ public class Platform {
return PLATFORM;
}
private final Provider sslProvider;
public Platform(Provider sslProvider) {
this.sslProvider = sslProvider;
}
/** Prefix used on custom headers. */
public String getPrefix() {
return "OkHttp";
@ -77,6 +87,10 @@ public class Platform {
public void untagSocket(Socket socket) throws SocketException {
}
public Provider getProvider() {
return sslProvider;
}
/**
* Configure TLS extensions on {@code sslSocket} for {@code route}.
*
@ -106,46 +120,72 @@ public class Platform {
/** Attempt to match the host runtime to a capable Platform implementation. */
private static Platform findPlatform() {
// Attempt to find Android 2.3+ APIs.
// Find the conscrypt security provider if we can
Class<?> rawProviderClass = null;
try {
rawProviderClass = Class.forName("org.conscrypt.OpenSSLProvider");
} catch (ClassNotFoundException cnfe1) {
try {
Class.forName("com.android.org.conscrypt.OpenSSLSocketImpl");
} catch (ClassNotFoundException e) {
// Older platform before being unbundled.
Class.forName("org.apache.harmony.xnet.provider.jsse.OpenSSLSocketImpl");
rawProviderClass = Class.forName("com.android.org.conscrypt.OpenSSLProvider");
} catch (ClassNotFoundException cnfe2) {
try {
rawProviderClass = Class.forName("org.apache.harmony.xnet.provider.jsse.OpenSSLProvider");
} catch (ClassNotFoundException cnfe3) {
// Stick with what we have
}
}
}
boolean haveConscrypt = false;
Provider sslProvider = null;
try {
sslProvider = SSLContext.getDefault().getProvider();
} catch (NoSuchAlgorithmException nsae) {
// Ignore
}
if (rawProviderClass != null) {
if (sslProvider != null && sslProvider.getClass().equals(rawProviderClass)) {
haveConscrypt = true;
} else {
try {
Class<? extends Provider> providerClass = rawProviderClass.asSubclass(Provider.class);
sslProvider = providerClass.newInstance();
haveConscrypt = true;
} catch (InstantiationException iae) {
// Unable to use conscrypt, fall through to Jetty
logger.log(Level.WARNING,
"Unable to create conscrypt provider " + rawProviderClass.getName(), iae);
} catch (IllegalAccessException iaxe) {
// Unable to use conscrypt, fall through to Jetty
logger.log(Level.WARNING,
"Unable to create conscrypt provider " + rawProviderClass.getName(), iaxe);
}
}
}
if (haveConscrypt) {
// Attempt to find Android 2.3+ APIs.
OptionalMethod<Socket> setUseSessionTickets
= new OptionalMethod<Socket>(null, "setUseSessionTickets", boolean.class);
OptionalMethod<Socket> setHostname
= new OptionalMethod<Socket>(null, "setHostname", String.class);
Method trafficStatsTagSocket = null;
Method trafficStatsUntagSocket = null;
OptionalMethod<Socket> getAlpnSelectedProtocol = null;
OptionalMethod<Socket> setAlpnProtocols = null;
OptionalMethod<Socket> getAlpnSelectedProtocol =
new OptionalMethod<Socket>(byte[].class, "getAlpnSelectedProtocol");
OptionalMethod<Socket> setAlpnProtocols =
new OptionalMethod<Socket>(null, "setAlpnProtocols", byte[].class);
// Attempt to find Android 4.0+ APIs.
try {
Class<?> trafficStats = Class.forName("android.net.TrafficStats");
trafficStatsTagSocket = trafficStats.getMethod("tagSocket", Socket.class);
trafficStatsUntagSocket = trafficStats.getMethod("untagSocket", Socket.class);
// Attempt to find Android 5.0+ APIs.
try {
Class.forName("android.net.Network"); // Arbitrary class added in Android 5.0.
getAlpnSelectedProtocol =
new OptionalMethod<Socket>(byte[].class, "getAlpnSelectedProtocol");
setAlpnProtocols = new OptionalMethod<Socket>(null, "setAlpnProtocols", byte[].class);
} catch (ClassNotFoundException ignored) {
}
} catch (ClassNotFoundException ignored) {
} catch (NoSuchMethodException ignored) {
}
return new Android(setUseSessionTickets, setHostname, trafficStatsTagSocket,
trafficStatsUntagSocket, getAlpnSelectedProtocol, setAlpnProtocols);
} catch (ClassNotFoundException ignored) {
// This isn't an Android runtime.
trafficStatsUntagSocket, getAlpnSelectedProtocol, setAlpnProtocols, sslProvider);
}
// Find Jetty's ALPN extension for OpenJDK.
@ -159,16 +199,17 @@ public class Platform {
Method getMethod = negoClass.getMethod("get", SSLSocket.class);
Method removeMethod = negoClass.getMethod("remove", SSLSocket.class);
return new JdkWithJettyBootPlatform(
putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass);
putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass, sslProvider);
} catch (ClassNotFoundException ignored) {
} catch (NoSuchMethodException ignored) {
}
return new Platform();
return new Platform(sslProvider);
}
/** Android 2.3 or better. */
private static class Android extends Platform {
private final OptionalMethod<Socket> setUseSessionTickets;
private final OptionalMethod<Socket> setHostname;
@ -182,7 +223,9 @@ public class Platform {
public Android(OptionalMethod<Socket> setUseSessionTickets, OptionalMethod<Socket> setHostname,
Method trafficStatsTagSocket, Method trafficStatsUntagSocket,
OptionalMethod<Socket> getAlpnSelectedProtocol, OptionalMethod<Socket> setAlpnProtocols) {
OptionalMethod<Socket> getAlpnSelectedProtocol, OptionalMethod<Socket> setAlpnProtocols,
Provider provider) {
super(provider);
this.setUseSessionTickets = setUseSessionTickets;
this.setHostname = setHostname;
this.trafficStatsTagSocket = trafficStatsTagSocket;
@ -213,14 +256,13 @@ public class Platform {
}
// Enable ALPN.
if (setAlpnProtocols != null && setAlpnProtocols.isSupported(sslSocket)) {
if (setAlpnProtocols.isSupported(sslSocket)) {
Object[] parameters = { concatLengthPrefixed(protocols) };
setAlpnProtocols.invokeWithoutCheckedException(sslSocket, parameters);
}
}
@Override public String getSelectedProtocol(SSLSocket socket) {
if (getAlpnSelectedProtocol == null) return null;
if (!getAlpnSelectedProtocol.isSupported(socket)) return null;
byte[] alpnResult = (byte[]) getAlpnSelectedProtocol.invokeWithoutCheckedException(socket);
@ -263,7 +305,8 @@ public class Platform {
private final Class<?> serverProviderClass;
public JdkWithJettyBootPlatform(Method putMethod, Method getMethod, Method removeMethod,
Class<?> clientProviderClass, Class<?> serverProviderClass) {
Class<?> clientProviderClass, Class<?> serverProviderClass, Provider provider) {
super(provider);
this.putMethod = putMethod;
this.getMethod = getMethod;
this.removeMethod = removeMethod;

View File

@ -51,6 +51,8 @@ import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.Security;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
@ -246,7 +248,16 @@ public class TestUtils {
/**
* Creates an SSLSocketFactory which contains {@code certChainFile} as its only root certificate.
*/
public static SSLSocketFactory newSslSocketFactoryForCa(InputStream certChain) throws Exception {
public static SSLSocketFactory newSslSocketFactoryForCa(
InputStream certChain) throws Exception {
return newSslSocketFactoryForCa(Security.getProviders()[0], certChain);
}
/**
* Creates an SSLSocketFactory which contains {@code certChainFile} as its only root certificate.
*/
public static SSLSocketFactory newSslSocketFactoryForCa(Provider provider,
InputStream certChain) throws Exception {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null, null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
@ -259,7 +270,7 @@ public class TestUtils {
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
SSLContext context = SSLContext.getInstance("TLS");
SSLContext context = SSLContext.getInstance("TLS", provider);
context.init(null, trustManagerFactory.getTrustManagers(), null);
return context.getSocketFactory();
}