mirror of https://github.com/grpc/grpc-java.git
okhttp: Consume mTLS and Trust/KeyManager Credentials API
This commit is contained in:
parent
0eab1c9176
commit
2e0e238fb2
|
|
@ -22,16 +22,9 @@ import io.grpc.Grpc;
|
|||
import io.grpc.InsecureChannelCredentials;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.okhttp.SslSocketFactoryChannelCredentials;
|
||||
import io.grpc.TlsChannelCredentials;
|
||||
import io.grpc.okhttp.OkHttpChannelBuilder;
|
||||
import java.io.InputStream;
|
||||
import java.security.KeyStore;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import javax.security.auth.x500.X500Principal;
|
||||
|
||||
/**
|
||||
* A helper class to create a OkHttp based channel.
|
||||
|
|
@ -45,10 +38,14 @@ class TesterOkHttpChannelBuilder {
|
|||
@Nullable InputStream testCa) {
|
||||
ChannelCredentials credentials;
|
||||
if (useTls) {
|
||||
try {
|
||||
credentials = SslSocketFactoryChannelCredentials.create(getSslSocketFactory(testCa));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
if (testCa == null) {
|
||||
credentials = TlsChannelCredentials.create();
|
||||
} else {
|
||||
try {
|
||||
credentials = TlsChannelCredentials.newBuilder().trustManager(testCa).build();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
credentials = InsecureChannelCredentials.create();
|
||||
|
|
@ -57,36 +54,13 @@ class TesterOkHttpChannelBuilder {
|
|||
ManagedChannelBuilder<?> channelBuilder = Grpc.newChannelBuilderForAddress(
|
||||
host, port, credentials)
|
||||
.maxInboundMessageSize(16 * 1024 * 1024);
|
||||
if (!(channelBuilder instanceof OkHttpChannelBuilder)) {
|
||||
throw new RuntimeException("Did not receive an OkHttpChannelBuilder");
|
||||
}
|
||||
if (serverHostOverride != null) {
|
||||
// Force the hostname to match the cert the server uses.
|
||||
channelBuilder.overrideAuthority(serverHostOverride);
|
||||
}
|
||||
return channelBuilder.build();
|
||||
}
|
||||
|
||||
private static SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa)
|
||||
throws Exception {
|
||||
if (testCa == null) {
|
||||
return (SSLSocketFactory) SSLSocketFactory.getDefault();
|
||||
}
|
||||
|
||||
SSLContext context = SSLContext.getInstance("TLS");
|
||||
context.init(null, getTrustManagers(testCa) , null);
|
||||
return context.getSocketFactory();
|
||||
}
|
||||
|
||||
private static TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
|
||||
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
ks.load(null);
|
||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||
X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
|
||||
X500Principal principal = cert.getSubjectX500Principal();
|
||||
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
|
||||
// Set up trust manager factory to use our key store.
|
||||
TrustManagerFactory trustManagerFactory =
|
||||
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
trustManagerFactory.init(ks);
|
||||
return trustManagerFactory.getTrustManagers();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ import static org.junit.Assert.assertTrue;
|
|||
|
||||
import com.google.common.base.Throwables;
|
||||
import com.squareup.okhttp.ConnectionSpec;
|
||||
import io.grpc.ChannelCredentials;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.ServerCredentials;
|
||||
import io.grpc.TlsChannelCredentials;
|
||||
import io.grpc.TlsServerCredentials;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.testing.StreamRecorder;
|
||||
|
|
@ -80,6 +82,25 @@ public class Http2OkHttpTest extends AbstractInteropTest {
|
|||
|
||||
@Override
|
||||
protected OkHttpChannelBuilder createChannelBuilder() {
|
||||
int port = ((InetSocketAddress) getListenAddress()).getPort();
|
||||
ChannelCredentials channelCreds;
|
||||
try {
|
||||
channelCreds = TlsChannelCredentials.newBuilder()
|
||||
.trustManager(TestUtils.loadCert("ca.pem"))
|
||||
.build();
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port, channelCreds)
|
||||
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
|
||||
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
|
||||
TestUtils.TEST_SERVER_HOST, port));
|
||||
// Disable the default census stats interceptor, use testing interceptor instead.
|
||||
InternalOkHttpChannelBuilder.setStatsEnabled(builder, false);
|
||||
return builder.intercept(createCensusStatsClientInterceptor());
|
||||
}
|
||||
|
||||
private OkHttpChannelBuilder createChannelBuilderPreCredentialsApi() {
|
||||
int port = ((InetSocketAddress) getListenAddress()).getPort();
|
||||
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port)
|
||||
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
|
||||
|
|
@ -125,7 +146,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
|
|||
@Test
|
||||
public void wrongHostNameFailHostnameVerification() throws Exception {
|
||||
int port = ((InetSocketAddress) getListenAddress()).getPort();
|
||||
ManagedChannel channel = createChannelBuilder()
|
||||
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
|
||||
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
|
||||
BAD_HOSTNAME, port))
|
||||
.build();
|
||||
|
|
@ -148,7 +169,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
|
|||
@Test
|
||||
public void hostnameVerifierWithBadHostname() throws Exception {
|
||||
int port = ((InetSocketAddress) getListenAddress()).getPort();
|
||||
ManagedChannel channel = createChannelBuilder()
|
||||
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
|
||||
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
|
||||
BAD_HOSTNAME, port))
|
||||
.hostnameVerifier(new HostnameVerifier() {
|
||||
|
|
@ -169,7 +190,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
|
|||
@Test
|
||||
public void hostnameVerifierWithCorrectHostname() throws Exception {
|
||||
int port = ((InetSocketAddress) getListenAddress()).getPort();
|
||||
ManagedChannel channel = createChannelBuilder()
|
||||
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
|
||||
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
|
||||
TestUtils.TEST_SERVER_HOST, port))
|
||||
.hostnameVerifier(new HostnameVerifier() {
|
||||
|
|
|
|||
|
|
@ -49,9 +49,14 @@ import io.grpc.okhttp.internal.CipherSuite;
|
|||
import io.grpc.okhttp.internal.ConnectionSpec;
|
||||
import io.grpc.okhttp.internal.Platform;
|
||||
import io.grpc.okhttp.internal.TlsVersion;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.security.KeyStore;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Executor;
|
||||
|
|
@ -59,19 +64,26 @@ import java.util.concurrent.ExecutorService;
|
|||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import javax.annotation.CheckReturnValue;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.KeyManager;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import javax.security.auth.x500.X500Principal;
|
||||
|
||||
/** Convenience class for building channels with the OkHttp transport. */
|
||||
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
|
||||
public final class OkHttpChannelBuilder extends
|
||||
AbstractManagedChannelImplBuilder<OkHttpChannelBuilder> {
|
||||
|
||||
private static final Logger log = Logger.getLogger(OkHttpChannelBuilder.class.getName());
|
||||
public static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535;
|
||||
|
||||
private final ManagedChannelImplBuilder managedChannelImplBuilder;
|
||||
private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory();
|
||||
|
||||
|
|
@ -526,7 +538,8 @@ public final class OkHttpChannelBuilder extends
|
|||
}
|
||||
|
||||
private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
|
||||
EnumSet.noneOf(TlsChannelCredentials.Feature.class);
|
||||
EnumSet.of(
|
||||
TlsChannelCredentials.Feature.MTLS, TlsChannelCredentials.Feature.CUSTOM_MANAGERS);
|
||||
|
||||
static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
|
||||
if (creds instanceof TlsChannelCredentials) {
|
||||
|
|
@ -537,14 +550,32 @@ public final class OkHttpChannelBuilder extends
|
|||
return SslSocketFactoryResult.error(
|
||||
"TLS features not understood: " + incomprehensible);
|
||||
}
|
||||
SSLSocketFactory sslSocketFactory;
|
||||
KeyManager[] km = null;
|
||||
if (tlsCreds.getKeyManagers() != null) {
|
||||
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
|
||||
} else if (tlsCreds.getPrivateKey() != null) {
|
||||
return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager");
|
||||
} // else don't have a client cert
|
||||
TrustManager[] tm = null;
|
||||
if (tlsCreds.getTrustManagers() != null) {
|
||||
tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]);
|
||||
} else if (tlsCreds.getRootCertificates() != null) {
|
||||
try {
|
||||
tm = createTrustManager(tlsCreds.getRootCertificates());
|
||||
} catch (GeneralSecurityException gse) {
|
||||
log.log(Level.FINE, "Exception loading root certificates from credential", gse);
|
||||
return SslSocketFactoryResult.error(
|
||||
"Unable to load root certificates: " + gse.getMessage());
|
||||
}
|
||||
} // else use system default
|
||||
SSLContext sslContext;
|
||||
try {
|
||||
SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider());
|
||||
sslSocketFactory = sslContext.getSocketFactory();
|
||||
sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider());
|
||||
sslContext.init(km, tm, null);
|
||||
} catch (GeneralSecurityException gse) {
|
||||
throw new RuntimeException("TLS Provider failure", gse);
|
||||
}
|
||||
return SslSocketFactoryResult.factory(sslSocketFactory);
|
||||
return SslSocketFactoryResult.factory(sslContext.getSocketFactory());
|
||||
|
||||
} else if (creds instanceof InsecureChannelCredentials) {
|
||||
return SslSocketFactoryResult.plaintext();
|
||||
|
|
@ -578,6 +609,30 @@ public final class OkHttpChannelBuilder extends
|
|||
}
|
||||
}
|
||||
|
||||
static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
|
||||
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
try {
|
||||
ks.load(null, null);
|
||||
} catch (IOException ex) {
|
||||
// Shouldn't really happen, as we're not loading any data.
|
||||
throw new GeneralSecurityException(ex);
|
||||
}
|
||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
|
||||
try {
|
||||
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
|
||||
X500Principal principal = cert.getSubjectX500Principal();
|
||||
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
|
||||
} finally {
|
||||
GrpcUtil.closeQuietly(in);
|
||||
}
|
||||
|
||||
TrustManagerFactory trustManagerFactory =
|
||||
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
trustManagerFactory.init(ks);
|
||||
return trustManagerFactory.getTrustManagers();
|
||||
}
|
||||
|
||||
static final class SslSocketFactoryResult {
|
||||
/** {@code null} implies plaintext if {@code error == null}. */
|
||||
public final SSLSocketFactory factory;
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import static org.junit.Assert.assertNull;
|
|||
import static org.junit.Assert.assertSame;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import com.google.common.util.concurrent.SettableFuture;
|
||||
import com.squareup.okhttp.ConnectionSpec;
|
||||
import io.grpc.CallCredentials;
|
||||
import io.grpc.ChannelCredentials;
|
||||
|
|
@ -38,14 +39,23 @@ import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
|
|||
import io.grpc.internal.FakeClock;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.SharedResourceHolder;
|
||||
import io.grpc.internal.testing.TestUtils;
|
||||
import io.grpc.testing.GrpcCleanupRule;
|
||||
import io.netty.handler.ssl.util.SelfSignedCertificate;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.Socket;
|
||||
import java.security.KeyStore;
|
||||
import java.security.cert.Certificate;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.KeyManagerFactory;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLServerSocket;
|
||||
import javax.net.ssl.SSLSocket;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import javax.security.auth.x500.X500Principal;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
|
@ -155,6 +165,110 @@ public class OkHttpChannelBuilderTest {
|
|||
assertThat(result.factory).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sslSocketFactoryFrom_tls_customRoots() throws Exception {
|
||||
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
|
||||
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
keyStore.load(null);
|
||||
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
|
||||
KeyManagerFactory keyManagerFactory =
|
||||
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
|
||||
keyManagerFactory.init(keyStore, new char[0]);
|
||||
|
||||
SSLContext serverContext = SSLContext.getInstance("TLS");
|
||||
serverContext.init(keyManagerFactory.getKeyManagers(), null, null);
|
||||
final SSLServerSocket serverListenSocket =
|
||||
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
|
||||
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
|
||||
new Thread(new Runnable() {
|
||||
@Override public void run() {
|
||||
try {
|
||||
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
|
||||
socket.getSession(); // Force handshake
|
||||
serverSocket.set(socket);
|
||||
serverListenSocket.close();
|
||||
} catch (Throwable t) {
|
||||
serverSocket.setException(t);
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
|
||||
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
|
||||
.trustManager(cert.certificate())
|
||||
.build();
|
||||
OkHttpChannelBuilder.SslSocketFactoryResult result =
|
||||
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
|
||||
SSLSocket socket =
|
||||
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
|
||||
socket.getSession(); // Force handshake
|
||||
socket.close();
|
||||
serverSocket.get().close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sslSocketFactoryFrom_tls_mtls() throws Exception {
|
||||
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
|
||||
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
keyStore.load(null);
|
||||
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
|
||||
KeyManagerFactory keyManagerFactory =
|
||||
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
|
||||
keyManagerFactory.init(keyStore, new char[0]);
|
||||
|
||||
KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
certStore.load(null);
|
||||
certStore.setCertificateEntry("mycert", cert.cert());
|
||||
TrustManagerFactory trustManagerFactory =
|
||||
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
trustManagerFactory.init(certStore);
|
||||
|
||||
SSLContext serverContext = SSLContext.getInstance("TLS");
|
||||
serverContext.init(
|
||||
keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
|
||||
final SSLServerSocket serverListenSocket =
|
||||
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
|
||||
serverListenSocket.setNeedClientAuth(true);
|
||||
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
|
||||
new Thread(new Runnable() {
|
||||
@Override public void run() {
|
||||
try {
|
||||
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
|
||||
socket.getSession(); // Force handshake
|
||||
serverSocket.set(socket);
|
||||
serverListenSocket.close();
|
||||
} catch (Throwable t) {
|
||||
serverSocket.setException(t);
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
|
||||
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
|
||||
.keyManager(keyManagerFactory.getKeyManagers())
|
||||
.trustManager(trustManagerFactory.getTrustManagers())
|
||||
.build();
|
||||
OkHttpChannelBuilder.SslSocketFactoryResult result =
|
||||
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
|
||||
SSLSocket socket =
|
||||
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
|
||||
socket.getSession(); // Force handshake
|
||||
assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName())
|
||||
.isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST);
|
||||
socket.close();
|
||||
serverSocket.get().close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception {
|
||||
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
|
||||
.keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
|
||||
.build();
|
||||
OkHttpChannelBuilder.SslSocketFactoryResult result =
|
||||
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
|
||||
assertThat(result.error).contains("unsupported");
|
||||
assertThat(result.callCredentials).isNull();
|
||||
assertThat(result.factory).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sslSocketFactoryFrom_insecure() {
|
||||
OkHttpChannelBuilder.SslSocketFactoryResult result =
|
||||
|
|
|
|||
Loading…
Reference in New Issue