okhttp: Consume mTLS and Trust/KeyManager Credentials API

This commit is contained in:
Eric Anderson 2021-02-09 16:20:43 -08:00 committed by Eric Anderson
parent 0eab1c9176
commit 2e0e238fb2
4 changed files with 212 additions and 48 deletions

View File

@ -22,16 +22,9 @@ import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.okhttp.SslSocketFactoryChannelCredentials; import io.grpc.TlsChannelCredentials;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.InputStream; import java.io.InputStream;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/** /**
* A helper class to create a OkHttp based channel. * A helper class to create a OkHttp based channel.
@ -45,11 +38,15 @@ class TesterOkHttpChannelBuilder {
@Nullable InputStream testCa) { @Nullable InputStream testCa) {
ChannelCredentials credentials; ChannelCredentials credentials;
if (useTls) { if (useTls) {
if (testCa == null) {
credentials = TlsChannelCredentials.create();
} else {
try { try {
credentials = SslSocketFactoryChannelCredentials.create(getSslSocketFactory(testCa)); credentials = TlsChannelCredentials.newBuilder().trustManager(testCa).build();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
}
} else { } else {
credentials = InsecureChannelCredentials.create(); credentials = InsecureChannelCredentials.create();
} }
@ -57,36 +54,13 @@ class TesterOkHttpChannelBuilder {
ManagedChannelBuilder<?> channelBuilder = Grpc.newChannelBuilderForAddress( ManagedChannelBuilder<?> channelBuilder = Grpc.newChannelBuilderForAddress(
host, port, credentials) host, port, credentials)
.maxInboundMessageSize(16 * 1024 * 1024); .maxInboundMessageSize(16 * 1024 * 1024);
if (!(channelBuilder instanceof OkHttpChannelBuilder)) {
throw new RuntimeException("Did not receive an OkHttpChannelBuilder");
}
if (serverHostOverride != null) { if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses. // Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride); channelBuilder.overrideAuthority(serverHostOverride);
} }
return channelBuilder.build(); return channelBuilder.build();
} }
private static SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa)
throws Exception {
if (testCa == null) {
return (SSLSocketFactory) SSLSocketFactory.getDefault();
} }
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, getTrustManagers(testCa) , null);
return context.getSocketFactory();
}
private static TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
// Set up trust manager factory to use our key store.
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
}

View File

@ -22,9 +22,11 @@ import static org.junit.Assert.assertTrue;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.squareup.okhttp.ConnectionSpec; import com.squareup.okhttp.ConnectionSpec;
import io.grpc.ChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder; import io.grpc.ServerBuilder;
import io.grpc.ServerCredentials; import io.grpc.ServerCredentials;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials; import io.grpc.TlsServerCredentials;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.StreamRecorder; import io.grpc.internal.testing.StreamRecorder;
@ -80,6 +82,25 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Override @Override
protected OkHttpChannelBuilder createChannelBuilder() { protected OkHttpChannelBuilder createChannelBuilder() {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ChannelCredentials channelCreds;
try {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(TestUtils.loadCert("ca.pem"))
.build();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port, channelCreds)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port));
// Disable the default census stats interceptor, use testing interceptor instead.
InternalOkHttpChannelBuilder.setStatsEnabled(builder, false);
return builder.intercept(createCensusStatsClientInterceptor());
}
private OkHttpChannelBuilder createChannelBuilderPreCredentialsApi() {
int port = ((InetSocketAddress) getListenAddress()).getPort(); int port = ((InetSocketAddress) getListenAddress()).getPort();
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port) OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
@ -125,7 +146,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test @Test
public void wrongHostNameFailHostnameVerification() throws Exception { public void wrongHostNameFailHostnameVerification() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort(); int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder() ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort( .overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port)) BAD_HOSTNAME, port))
.build(); .build();
@ -148,7 +169,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test @Test
public void hostnameVerifierWithBadHostname() throws Exception { public void hostnameVerifierWithBadHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort(); int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder() ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort( .overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port)) BAD_HOSTNAME, port))
.hostnameVerifier(new HostnameVerifier() { .hostnameVerifier(new HostnameVerifier() {
@ -169,7 +190,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test @Test
public void hostnameVerifierWithCorrectHostname() throws Exception { public void hostnameVerifierWithCorrectHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort(); int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder() ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort( .overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port)) TestUtils.TEST_SERVER_HOST, port))
.hostnameVerifier(new HostnameVerifier() { .hostnameVerifier(new HostnameVerifier() {

View File

@ -49,9 +49,14 @@ import io.grpc.okhttp.internal.CipherSuite;
import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.Platform;
import io.grpc.okhttp.internal.TlsVersion; import io.grpc.okhttp.internal.TlsVersion;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
@ -59,19 +64,26 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.CheckReturnValue; import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/** Convenience class for building channels with the OkHttp transport. */ /** Convenience class for building channels with the OkHttp transport. */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
public final class OkHttpChannelBuilder extends public final class OkHttpChannelBuilder extends
AbstractManagedChannelImplBuilder<OkHttpChannelBuilder> { AbstractManagedChannelImplBuilder<OkHttpChannelBuilder> {
private static final Logger log = Logger.getLogger(OkHttpChannelBuilder.class.getName());
public static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535; public static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535;
private final ManagedChannelImplBuilder managedChannelImplBuilder; private final ManagedChannelImplBuilder managedChannelImplBuilder;
private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory();
@ -526,7 +538,8 @@ public final class OkHttpChannelBuilder extends
} }
private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures = private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
EnumSet.noneOf(TlsChannelCredentials.Feature.class); EnumSet.of(
TlsChannelCredentials.Feature.MTLS, TlsChannelCredentials.Feature.CUSTOM_MANAGERS);
static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) { static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
if (creds instanceof TlsChannelCredentials) { if (creds instanceof TlsChannelCredentials) {
@ -537,14 +550,32 @@ public final class OkHttpChannelBuilder extends
return SslSocketFactoryResult.error( return SslSocketFactoryResult.error(
"TLS features not understood: " + incomprehensible); "TLS features not understood: " + incomprehensible);
} }
SSLSocketFactory sslSocketFactory; KeyManager[] km = null;
if (tlsCreds.getKeyManagers() != null) {
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
} else if (tlsCreds.getPrivateKey() != null) {
return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager");
} // else don't have a client cert
TrustManager[] tm = null;
if (tlsCreds.getTrustManagers() != null) {
tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]);
} else if (tlsCreds.getRootCertificates() != null) {
try { try {
SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider()); tm = createTrustManager(tlsCreds.getRootCertificates());
sslSocketFactory = sslContext.getSocketFactory(); } catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading root certificates from credential", gse);
return SslSocketFactoryResult.error(
"Unable to load root certificates: " + gse.getMessage());
}
} // else use system default
SSLContext sslContext;
try {
sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider());
sslContext.init(km, tm, null);
} catch (GeneralSecurityException gse) { } catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse); throw new RuntimeException("TLS Provider failure", gse);
} }
return SslSocketFactoryResult.factory(sslSocketFactory); return SslSocketFactoryResult.factory(sslContext.getSocketFactory());
} else if (creds instanceof InsecureChannelCredentials) { } else if (creds instanceof InsecureChannelCredentials) {
return SslSocketFactoryResult.plaintext(); return SslSocketFactoryResult.plaintext();
@ -578,6 +609,30 @@ public final class OkHttpChannelBuilder extends
} }
} }
static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
try {
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
} finally {
GrpcUtil.closeQuietly(in);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
static final class SslSocketFactoryResult { static final class SslSocketFactoryResult {
/** {@code null} implies plaintext if {@code error == null}. */ /** {@code null} implies plaintext if {@code error == null}. */
public final SSLSocketFactory factory; public final SSLSocketFactory factory;

View File

@ -24,6 +24,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import com.google.common.util.concurrent.SettableFuture;
import com.squareup.okhttp.ConnectionSpec; import com.squareup.okhttp.ConnectionSpec;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
@ -38,14 +39,23 @@ import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.testing.TestUtils;
import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcCleanupRule;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
@ -155,6 +165,110 @@ public class OkHttpChannelBuilderTest {
assertThat(result.factory).isNull(); assertThat(result.factory).isNull();
} }
@Test
public void sslSocketFactoryFrom_tls_customRoots() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null);
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(keyStore, new char[0]);
SSLContext serverContext = SSLContext.getInstance("TLS");
serverContext.init(keyManagerFactory.getKeyManagers(), null, null);
final SSLServerSocket serverListenSocket =
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
new Thread(new Runnable() {
@Override public void run() {
try {
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
socket.getSession(); // Force handshake
serverSocket.set(socket);
serverListenSocket.close();
} catch (Throwable t) {
serverSocket.setException(t);
}
}
}).start();
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.trustManager(cert.certificate())
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
SSLSocket socket =
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
socket.getSession(); // Force handshake
socket.close();
serverSocket.get().close();
}
@Test
public void sslSocketFactoryFrom_tls_mtls() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null);
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(keyStore, new char[0]);
KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
certStore.load(null);
certStore.setCertificateEntry("mycert", cert.cert());
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(certStore);
SSLContext serverContext = SSLContext.getInstance("TLS");
serverContext.init(
keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
final SSLServerSocket serverListenSocket =
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
serverListenSocket.setNeedClientAuth(true);
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
new Thread(new Runnable() {
@Override public void run() {
try {
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
socket.getSession(); // Force handshake
serverSocket.set(socket);
serverListenSocket.close();
} catch (Throwable t) {
serverSocket.setException(t);
}
}
}).start();
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(keyManagerFactory.getKeyManagers())
.trustManager(trustManagerFactory.getTrustManagers())
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
SSLSocket socket =
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
socket.getSession(); // Force handshake
assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName())
.isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST);
socket.close();
serverSocket.get().close();
}
@Test
public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception {
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
assertThat(result.error).contains("unsupported");
assertThat(result.callCredentials).isNull();
assertThat(result.factory).isNull();
}
@Test @Test
public void sslSocketFactoryFrom_insecure() { public void sslSocketFactoryFrom_insecure() {
OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.SslSocketFactoryResult result =