From 2e0e238fb2271ceb9645412b8de53b9ad6e3d3a2 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 9 Feb 2021 16:20:43 -0800 Subject: [PATCH] okhttp: Consume mTLS and Trust/KeyManager Credentials API --- .../TesterOkHttpChannelBuilder.java | 52 ++------ .../testing/integration/Http2OkHttpTest.java | 27 ++++- .../io/grpc/okhttp/OkHttpChannelBuilder.java | 67 +++++++++- .../grpc/okhttp/OkHttpChannelBuilderTest.java | 114 ++++++++++++++++++ 4 files changed, 212 insertions(+), 48 deletions(-) diff --git a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterOkHttpChannelBuilder.java b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterOkHttpChannelBuilder.java index 36a3b0ac9f..f4d8f5abcd 100644 --- a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterOkHttpChannelBuilder.java +++ b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterOkHttpChannelBuilder.java @@ -22,16 +22,9 @@ import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.okhttp.SslSocketFactoryChannelCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.okhttp.OkHttpChannelBuilder; import java.io.InputStream; -import java.security.KeyStore; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSocketFactory; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; /** * A helper class to create a OkHttp based channel. @@ -45,10 +38,14 @@ class TesterOkHttpChannelBuilder { @Nullable InputStream testCa) { ChannelCredentials credentials; if (useTls) { - try { - credentials = SslSocketFactoryChannelCredentials.create(getSslSocketFactory(testCa)); - } catch (Exception e) { - throw new RuntimeException(e); + if (testCa == null) { + credentials = TlsChannelCredentials.create(); + } else { + try { + credentials = TlsChannelCredentials.newBuilder().trustManager(testCa).build(); + } catch (Exception e) { + throw new RuntimeException(e); + } } } else { credentials = InsecureChannelCredentials.create(); @@ -57,36 +54,13 @@ class TesterOkHttpChannelBuilder { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilderForAddress( host, port, credentials) .maxInboundMessageSize(16 * 1024 * 1024); + if (!(channelBuilder instanceof OkHttpChannelBuilder)) { + throw new RuntimeException("Did not receive an OkHttpChannelBuilder"); + } if (serverHostOverride != null) { // Force the hostname to match the cert the server uses. channelBuilder.overrideAuthority(serverHostOverride); } return channelBuilder.build(); } - - private static SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa) - throws Exception { - if (testCa == null) { - return (SSLSocketFactory) SSLSocketFactory.getDefault(); - } - - SSLContext context = SSLContext.getInstance("TLS"); - context.init(null, getTrustManagers(testCa) , null); - return context.getSocketFactory(); - } - - private static TrustManager[] getTrustManagers(InputStream testCa) throws Exception { - KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); - ks.load(null); - CertificateFactory cf = CertificateFactory.getInstance("X.509"); - X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa); - X500Principal principal = cert.getSubjectX500Principal(); - ks.setCertificateEntry(principal.getName("RFC2253"), cert); - // Set up trust manager factory to use our key store. - TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - trustManagerFactory.init(ks); - return trustManagerFactory.getTrustManagers(); - } } - diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java index 3bc03a9292..7cce836fa9 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java @@ -22,9 +22,11 @@ import static org.junit.Assert.assertTrue; import com.google.common.base.Throwables; import com.squareup.okhttp.ConnectionSpec; +import io.grpc.ChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; +import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.testing.StreamRecorder; @@ -80,6 +82,25 @@ public class Http2OkHttpTest extends AbstractInteropTest { @Override protected OkHttpChannelBuilder createChannelBuilder() { + int port = ((InetSocketAddress) getListenAddress()).getPort(); + ChannelCredentials channelCreds; + try { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(TestUtils.loadCert("ca.pem")) + .build(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port, channelCreds) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .overrideAuthority(GrpcUtil.authorityFromHostAndPort( + TestUtils.TEST_SERVER_HOST, port)); + // Disable the default census stats interceptor, use testing interceptor instead. + InternalOkHttpChannelBuilder.setStatsEnabled(builder, false); + return builder.intercept(createCensusStatsClientInterceptor()); + } + + private OkHttpChannelBuilder createChannelBuilderPreCredentialsApi() { int port = ((InetSocketAddress) getListenAddress()).getPort(); OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port) .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) @@ -125,7 +146,7 @@ public class Http2OkHttpTest extends AbstractInteropTest { @Test public void wrongHostNameFailHostnameVerification() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() + ManagedChannel channel = createChannelBuilderPreCredentialsApi() .overrideAuthority(GrpcUtil.authorityFromHostAndPort( BAD_HOSTNAME, port)) .build(); @@ -148,7 +169,7 @@ public class Http2OkHttpTest extends AbstractInteropTest { @Test public void hostnameVerifierWithBadHostname() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() + ManagedChannel channel = createChannelBuilderPreCredentialsApi() .overrideAuthority(GrpcUtil.authorityFromHostAndPort( BAD_HOSTNAME, port)) .hostnameVerifier(new HostnameVerifier() { @@ -169,7 +190,7 @@ public class Http2OkHttpTest extends AbstractInteropTest { @Test public void hostnameVerifierWithCorrectHostname() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() + ManagedChannel channel = createChannelBuilderPreCredentialsApi() .overrideAuthority(GrpcUtil.authorityFromHostAndPort( TestUtils.TEST_SERVER_HOST, port)) .hostnameVerifier(new HostnameVerifier() { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index d003f735d2..f7d0d97380 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -49,9 +49,14 @@ import io.grpc.okhttp.internal.CipherSuite; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.TlsVersion; +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.util.EnumSet; import java.util.Set; import java.util.concurrent.Executor; @@ -59,19 +64,26 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.KeyManager; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.security.auth.x500.X500Principal; /** Convenience class for building channels with the OkHttp transport. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") public final class OkHttpChannelBuilder extends AbstractManagedChannelImplBuilder { - + private static final Logger log = Logger.getLogger(OkHttpChannelBuilder.class.getName()); public static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535; + private final ManagedChannelImplBuilder managedChannelImplBuilder; private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); @@ -526,7 +538,8 @@ public final class OkHttpChannelBuilder extends } private static final EnumSet understoodTlsFeatures = - EnumSet.noneOf(TlsChannelCredentials.Feature.class); + EnumSet.of( + TlsChannelCredentials.Feature.MTLS, TlsChannelCredentials.Feature.CUSTOM_MANAGERS); static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) { if (creds instanceof TlsChannelCredentials) { @@ -537,14 +550,32 @@ public final class OkHttpChannelBuilder extends return SslSocketFactoryResult.error( "TLS features not understood: " + incomprehensible); } - SSLSocketFactory sslSocketFactory; + KeyManager[] km = null; + if (tlsCreds.getKeyManagers() != null) { + km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]); + } else if (tlsCreds.getPrivateKey() != null) { + return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager"); + } // else don't have a client cert + TrustManager[] tm = null; + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + try { + tm = createTrustManager(tlsCreds.getRootCertificates()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading root certificates from credential", gse); + return SslSocketFactoryResult.error( + "Unable to load root certificates: " + gse.getMessage()); + } + } // else use system default + SSLContext sslContext; try { - SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider()); - sslSocketFactory = sslContext.getSocketFactory(); + sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider()); + sslContext.init(km, tm, null); } catch (GeneralSecurityException gse) { throw new RuntimeException("TLS Provider failure", gse); } - return SslSocketFactoryResult.factory(sslSocketFactory); + return SslSocketFactoryResult.factory(sslContext.getSocketFactory()); } else if (creds instanceof InsecureChannelCredentials) { return SslSocketFactoryResult.plaintext(); @@ -578,6 +609,30 @@ public final class OkHttpChannelBuilder extends } } + static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + ByteArrayInputStream in = new ByteArrayInputStream(rootCerts); + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(in); + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } finally { + GrpcUtil.closeQuietly(in); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + static final class SslSocketFactoryResult { /** {@code null} implies plaintext if {@code error == null}. */ public final SSLSocketFactory factory; diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 23d140ec9b..0063fc82ca 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.mockito.Mockito.mock; +import com.google.common.util.concurrent.SettableFuture; import com.squareup.okhttp.ConnectionSpec; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; @@ -38,14 +39,23 @@ import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.testing.TestUtils; import io.grpc.testing.GrpcCleanupRule; +import io.netty.handler.ssl.util.SelfSignedCertificate; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.security.KeyStore; +import java.security.cert.Certificate; import java.util.concurrent.ScheduledExecutorService; import javax.net.SocketFactory; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; +import javax.security.auth.x500.X500Principal; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -155,6 +165,110 @@ public class OkHttpChannelBuilderTest { assertThat(result.factory).isNull(); } + @Test + public void sslSocketFactoryFrom_tls_customRoots() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST); + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null); + keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()}); + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + + SSLContext serverContext = SSLContext.getInstance("TLS"); + serverContext.init(keyManagerFactory.getKeyManagers(), null, null); + final SSLServerSocket serverListenSocket = + (SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0); + final SettableFuture serverSocket = SettableFuture.create(); + new Thread(new Runnable() { + @Override public void run() { + try { + SSLSocket socket = (SSLSocket) serverListenSocket.accept(); + socket.getSession(); // Force handshake + serverSocket.set(socket); + serverListenSocket.close(); + } catch (Throwable t) { + serverSocket.setException(t); + } + } + }).start(); + + ChannelCredentials creds = TlsChannelCredentials.newBuilder() + .trustManager(cert.certificate()) + .build(); + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + SSLSocket socket = + (SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort()); + socket.getSession(); // Force handshake + socket.close(); + serverSocket.get().close(); + } + + @Test + public void sslSocketFactoryFrom_tls_mtls() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST); + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null); + keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()}); + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + + KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); + certStore.load(null); + certStore.setCertificateEntry("mycert", cert.cert()); + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(certStore); + + SSLContext serverContext = SSLContext.getInstance("TLS"); + serverContext.init( + keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null); + final SSLServerSocket serverListenSocket = + (SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0); + serverListenSocket.setNeedClientAuth(true); + final SettableFuture serverSocket = SettableFuture.create(); + new Thread(new Runnable() { + @Override public void run() { + try { + SSLSocket socket = (SSLSocket) serverListenSocket.accept(); + socket.getSession(); // Force handshake + serverSocket.set(socket); + serverListenSocket.close(); + } catch (Throwable t) { + serverSocket.setException(t); + } + } + }).start(); + + ChannelCredentials creds = TlsChannelCredentials.newBuilder() + .keyManager(keyManagerFactory.getKeyManagers()) + .trustManager(trustManagerFactory.getTrustManagers()) + .build(); + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + SSLSocket socket = + (SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort()); + socket.getSession(); // Force handshake + assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName()) + .isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST); + socket.close(); + serverSocket.get().close(); + } + + @Test + public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception { + ChannelCredentials creds = TlsChannelCredentials.newBuilder() + .keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) + .build(); + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + assertThat(result.error).contains("unsupported"); + assertThat(result.callCredentials).isNull(); + assertThat(result.factory).isNull(); + } + @Test public void sslSocketFactoryFrom_insecure() { OkHttpChannelBuilder.SslSocketFactoryResult result =