diff --git a/okhttp/BUILD.bazel b/okhttp/BUILD.bazel index d690086df8..e550634aca 100644 --- a/okhttp/BUILD.bazel +++ b/okhttp/BUILD.bazel @@ -11,6 +11,7 @@ java_library( deps = [ "//api", "//core:internal", + "//core:util", "@com_google_code_findbugs_jsr305//jar", "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_guava//jar", diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index a752885766..bb2e66c965 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -51,13 +51,14 @@ import io.grpc.okhttp.internal.CipherSuite; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.TlsVersion; +import io.grpc.util.CertificateUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.GeneralSecurityException; import java.security.KeyStore; -import java.security.cert.CertificateFactory; +import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.EnumSet; import java.util.Set; @@ -73,6 +74,7 @@ import javax.annotation.Nullable; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; @@ -597,7 +599,16 @@ public final class OkHttpChannelBuilder extends if (tlsCreds.getKeyManagers() != null) { km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]); } else if (tlsCreds.getPrivateKey() != null) { - return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager"); + if (tlsCreds.getPrivateKeyPassword() != null) { + return SslSocketFactoryResult.error("byte[]-based private key with password unsupported. " + + "Use unencrypted file or KeyManager"); + } + try { + km = createKeyManager(tlsCreds.getCertificateChain(), tlsCreds.getPrivateKey()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading private key from credential", gse); + return SslSocketFactoryResult.error("Unable to load private key: " + gse.getMessage()); + } } // else don't have a client cert TrustManager[] tm = null; if (tlsCreds.getTrustManagers() != null) { @@ -652,6 +663,39 @@ public final class OkHttpChannelBuilder extends } } + static KeyManager[] createKeyManager(byte[] certChain, byte[] privateKey) + throws GeneralSecurityException { + X509Certificate[] chain; + ByteArrayInputStream inCertChain = new ByteArrayInputStream(certChain); + try { + chain = CertificateUtils.getX509Certificates(inCertChain); + } finally { + GrpcUtil.closeQuietly(inCertChain); + } + PrivateKey key; + ByteArrayInputStream inPrivateKey = new ByteArrayInputStream(privateKey); + try { + key = CertificateUtils.getPrivateKey(inPrivateKey); + } catch (IOException uee) { + throw new GeneralSecurityException("Unable to decode private key", uee); + } finally { + GrpcUtil.closeQuietly(inPrivateKey); + } + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + ks.setKeyEntry("key", key, new char[0], chain); + + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(ks, new char[0]); + return keyManagerFactory.getKeyManagers(); + } + static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); try { @@ -660,15 +704,17 @@ public final class OkHttpChannelBuilder extends // Shouldn't really happen, as we're not loading any data. throw new GeneralSecurityException(ex); } - CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate[] certs; ByteArrayInputStream in = new ByteArrayInputStream(rootCerts); try { - X509Certificate cert = (X509Certificate) cf.generateCertificate(in); - X500Principal principal = cert.getSubjectX500Principal(); - ks.setCertificateEntry(principal.getName("RFC2253"), cert); + certs = CertificateUtils.getX509Certificates(in); } finally { GrpcUtil.closeQuietly(in); } + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index b72c506957..6026e5989c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -37,14 +37,9 @@ import io.grpc.internal.ServerImplBuilder; import io.grpc.internal.SharedResourcePool; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.internal.Platform; -import java.io.ByteArrayInputStream; -import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.GeneralSecurityException; -import java.security.KeyStore; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; import java.util.EnumSet; import java.util.List; import java.util.Set; @@ -57,8 +52,6 @@ import javax.net.ServerSocketFactory; import javax.net.ssl.KeyManager; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; /** * Build servers with the OkHttp transport. @@ -287,15 +280,25 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder serverSocket = SettableFuture.create(); + new Thread(new Runnable() { + @Override public void run() { + try { + SSLSocket socket = (SSLSocket) serverListenSocket.accept(); + socket.getSession(); // Force handshake + serverSocket.set(socket); + serverListenSocket.close(); + } catch (Throwable t) { + serverSocket.setException(t); + } + } + }).start(); + ChannelCredentials creds = TlsChannelCredentials.newBuilder() - .keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) + .keyManager(cert.certificate(), cert.privateKey()) + .trustManager(cert.certificate()) + .build(); + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + SSLSocket socket = + (SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort()); + socket.getSession(); // Force handshake + assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName()) + .isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST); + socket.close(); + serverSocket.get().close(); + } + + @Test + public void sslSocketFactoryFrom_tls_mtls_passwordUnsupported() throws Exception { + ChannelCredentials creds = TlsChannelCredentials.newBuilder() + .keyManager( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"), "password") .build(); OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(creds);