okhttp: Consume mTLS and Trust/KeyManager Credentials API

This commit is contained in:
Eric Anderson 2021-02-09 16:20:43 -08:00 committed by Eric Anderson
parent 0eab1c9176
commit 2e0e238fb2
4 changed files with 212 additions and 48 deletions

View File

@ -22,16 +22,9 @@ import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.okhttp.SslSocketFactoryChannelCredentials;
import io.grpc.TlsChannelCredentials;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.InputStream;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/**
* A helper class to create a OkHttp based channel.
@ -45,11 +38,15 @@ class TesterOkHttpChannelBuilder {
@Nullable InputStream testCa) {
ChannelCredentials credentials;
if (useTls) {
if (testCa == null) {
credentials = TlsChannelCredentials.create();
} else {
try {
credentials = SslSocketFactoryChannelCredentials.create(getSslSocketFactory(testCa));
credentials = TlsChannelCredentials.newBuilder().trustManager(testCa).build();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
} else {
credentials = InsecureChannelCredentials.create();
}
@ -57,36 +54,13 @@ class TesterOkHttpChannelBuilder {
ManagedChannelBuilder<?> channelBuilder = Grpc.newChannelBuilderForAddress(
host, port, credentials)
.maxInboundMessageSize(16 * 1024 * 1024);
if (!(channelBuilder instanceof OkHttpChannelBuilder)) {
throw new RuntimeException("Did not receive an OkHttpChannelBuilder");
}
if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride);
}
return channelBuilder.build();
}
private static SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa)
throws Exception {
if (testCa == null) {
return (SSLSocketFactory) SSLSocketFactory.getDefault();
}
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, getTrustManagers(testCa) , null);
return context.getSocketFactory();
}
private static TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
// Set up trust manager factory to use our key store.
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
}

View File

@ -22,9 +22,11 @@ import static org.junit.Assert.assertTrue;
import com.google.common.base.Throwables;
import com.squareup.okhttp.ConnectionSpec;
import io.grpc.ChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.ServerCredentials;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.StreamRecorder;
@ -80,6 +82,25 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Override
protected OkHttpChannelBuilder createChannelBuilder() {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ChannelCredentials channelCreds;
try {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(TestUtils.loadCert("ca.pem"))
.build();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port, channelCreds)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port));
// Disable the default census stats interceptor, use testing interceptor instead.
InternalOkHttpChannelBuilder.setStatsEnabled(builder, false);
return builder.intercept(createCensusStatsClientInterceptor());
}
private OkHttpChannelBuilder createChannelBuilderPreCredentialsApi() {
int port = ((InetSocketAddress) getListenAddress()).getPort();
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
@ -125,7 +146,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test
public void wrongHostNameFailHostnameVerification() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder()
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port))
.build();
@ -148,7 +169,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test
public void hostnameVerifierWithBadHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder()
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port))
.hostnameVerifier(new HostnameVerifier() {
@ -169,7 +190,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
@Test
public void hostnameVerifierWithCorrectHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilder()
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port))
.hostnameVerifier(new HostnameVerifier() {

View File

@ -49,9 +49,14 @@ import io.grpc.okhttp.internal.CipherSuite;
import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.Platform;
import io.grpc.okhttp.internal.TlsVersion;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.Executor;
@ -59,19 +64,26 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/** Convenience class for building channels with the OkHttp transport. */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
public final class OkHttpChannelBuilder extends
AbstractManagedChannelImplBuilder<OkHttpChannelBuilder> {
private static final Logger log = Logger.getLogger(OkHttpChannelBuilder.class.getName());
public static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535;
private final ManagedChannelImplBuilder managedChannelImplBuilder;
private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory();
@ -526,7 +538,8 @@ public final class OkHttpChannelBuilder extends
}
private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
EnumSet.noneOf(TlsChannelCredentials.Feature.class);
EnumSet.of(
TlsChannelCredentials.Feature.MTLS, TlsChannelCredentials.Feature.CUSTOM_MANAGERS);
static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
if (creds instanceof TlsChannelCredentials) {
@ -537,14 +550,32 @@ public final class OkHttpChannelBuilder extends
return SslSocketFactoryResult.error(
"TLS features not understood: " + incomprehensible);
}
SSLSocketFactory sslSocketFactory;
KeyManager[] km = null;
if (tlsCreds.getKeyManagers() != null) {
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
} else if (tlsCreds.getPrivateKey() != null) {
return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager");
} // else don't have a client cert
TrustManager[] tm = null;
if (tlsCreds.getTrustManagers() != null) {
tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]);
} else if (tlsCreds.getRootCertificates() != null) {
try {
SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider());
sslSocketFactory = sslContext.getSocketFactory();
tm = createTrustManager(tlsCreds.getRootCertificates());
} catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading root certificates from credential", gse);
return SslSocketFactoryResult.error(
"Unable to load root certificates: " + gse.getMessage());
}
} // else use system default
SSLContext sslContext;
try {
sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider());
sslContext.init(km, tm, null);
} catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse);
}
return SslSocketFactoryResult.factory(sslSocketFactory);
return SslSocketFactoryResult.factory(sslContext.getSocketFactory());
} else if (creds instanceof InsecureChannelCredentials) {
return SslSocketFactoryResult.plaintext();
@ -578,6 +609,30 @@ public final class OkHttpChannelBuilder extends
}
}
static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
try {
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
} finally {
GrpcUtil.closeQuietly(in);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
static final class SslSocketFactoryResult {
/** {@code null} implies plaintext if {@code error == null}. */
public final SSLSocketFactory factory;

View File

@ -24,6 +24,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;
import com.google.common.util.concurrent.SettableFuture;
import com.squareup.okhttp.ConnectionSpec;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
@ -38,14 +39,23 @@ import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.testing.TestUtils;
import io.grpc.testing.GrpcCleanupRule;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -155,6 +165,110 @@ public class OkHttpChannelBuilderTest {
assertThat(result.factory).isNull();
}
@Test
public void sslSocketFactoryFrom_tls_customRoots() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null);
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(keyStore, new char[0]);
SSLContext serverContext = SSLContext.getInstance("TLS");
serverContext.init(keyManagerFactory.getKeyManagers(), null, null);
final SSLServerSocket serverListenSocket =
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
new Thread(new Runnable() {
@Override public void run() {
try {
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
socket.getSession(); // Force handshake
serverSocket.set(socket);
serverListenSocket.close();
} catch (Throwable t) {
serverSocket.setException(t);
}
}
}).start();
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.trustManager(cert.certificate())
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
SSLSocket socket =
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
socket.getSession(); // Force handshake
socket.close();
serverSocket.get().close();
}
@Test
public void sslSocketFactoryFrom_tls_mtls() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null);
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(keyStore, new char[0]);
KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
certStore.load(null);
certStore.setCertificateEntry("mycert", cert.cert());
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(certStore);
SSLContext serverContext = SSLContext.getInstance("TLS");
serverContext.init(
keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
final SSLServerSocket serverListenSocket =
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
serverListenSocket.setNeedClientAuth(true);
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
new Thread(new Runnable() {
@Override public void run() {
try {
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
socket.getSession(); // Force handshake
serverSocket.set(socket);
serverListenSocket.close();
} catch (Throwable t) {
serverSocket.setException(t);
}
}
}).start();
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(keyManagerFactory.getKeyManagers())
.trustManager(trustManagerFactory.getTrustManagers())
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
SSLSocket socket =
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
socket.getSession(); // Force handshake
assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName())
.isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST);
socket.close();
serverSocket.get().close();
}
@Test
public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception {
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
assertThat(result.error).contains("unsupported");
assertThat(result.callCredentials).isNull();
assertThat(result.factory).isNull();
}
@Test
public void sslSocketFactoryFrom_insecure() {
OkHttpChannelBuilder.SslSocketFactoryResult result =