okhttp: Add missing server support for TLS ClientAuth (#9711)

This commit is contained in:
Eric Anderson 2022-11-22 17:09:03 -08:00 committed by GitHub
parent b593871801
commit c80b587579
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 351 additions and 2 deletions

View File

@ -21,6 +21,7 @@ dependencies {
testImplementation project(':grpc-core').sourceSets.test.output, testImplementation project(':grpc-core').sourceSets.test.output,
project(':grpc-api').sourceSets.test.output, project(':grpc-api').sourceSets.test.output,
project(':grpc-testing'), project(':grpc-testing'),
project(':grpc-testing-proto'),
libraries.netty.codec.http2, libraries.netty.codec.http2,
libraries.okhttp libraries.okhttp
signature libraries.signature.java signature libraries.signature.java

View File

@ -40,7 +40,10 @@ import io.grpc.internal.ServerImplBuilder;
import io.grpc.internal.SharedResourcePool; import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer; import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.Platform;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.util.EnumSet; import java.util.EnumSet;
@ -54,6 +57,8 @@ import java.util.logging.Logger;
import javax.net.ServerSocketFactory; import javax.net.ServerSocketFactory;
import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
/** /**
@ -422,9 +427,26 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder<OkHttpSer
} catch (GeneralSecurityException gse) { } catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse); throw new RuntimeException("TLS Provider failure", gse);
} }
SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
switch (tlsCreds.getClientAuth()) {
case OPTIONAL:
sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, false);
break;
case REQUIRE:
sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, true);
break;
case NONE:
// NOOP; this is the SSLContext default
break;
default:
return HandshakerSocketFactoryResult.error(
"Unknown TlsServerCredentials.ClientAuth value: " + tlsCreds.getClientAuth());
}
return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory( return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory(
new SslSocketFactoryServerCredentials.ServerCredentials( new SslSocketFactoryServerCredentials.ServerCredentials(sslSocketFactory)));
sslContext.getSocketFactory())));
} else if (creds instanceof InsecureServerCredentials) { } else if (creds instanceof InsecureServerCredentials) {
return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory()); return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory());
@ -473,4 +495,59 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder<OkHttpSer
Preconditions.checkNotNull(factory, "factory"), null); Preconditions.checkNotNull(factory, "factory"), null);
} }
} }
static final class ClientCertRequestingSocketFactory extends SSLSocketFactory {
private final SSLSocketFactory socketFactory;
private final boolean required;
public ClientCertRequestingSocketFactory(SSLSocketFactory socketFactory, boolean required) {
this.socketFactory = Preconditions.checkNotNull(socketFactory, "socketFactory");
this.required = required;
}
private Socket apply(Socket s) throws IOException {
if (!(s instanceof SSLSocket)) {
throw new IOException(
"SocketFactory " + socketFactory + " did not produce an SSLSocket: " + s.getClass());
}
SSLSocket sslSocket = (SSLSocket) s;
if (required) {
sslSocket.setNeedClientAuth(true);
} else {
sslSocket.setWantClientAuth(true);
}
return sslSocket;
}
@Override public Socket createSocket(Socket s, String host, int port, boolean autoClose)
throws IOException {
return apply(socketFactory.createSocket(s, host, port, autoClose));
}
@Override public Socket createSocket(String host, int port) throws IOException {
return apply(socketFactory.createSocket(host, port));
}
@Override public Socket createSocket(
String host, int port, InetAddress localHost, int localPort) throws IOException {
return apply(socketFactory.createSocket(host, port, localHost, localPort));
}
@Override public Socket createSocket(InetAddress host, int port) throws IOException {
return apply(socketFactory.createSocket(host, port));
}
@Override public Socket createSocket(
InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException {
return apply(socketFactory.createSocket(host, port, localAddress, localPort));
}
@Override public String[] getDefaultCipherSuites() {
return socketFactory.getDefaultCipherSuites();
}
@Override public String[] getSupportedCipherSuites() {
return socketFactory.getSupportedCipherSuites();
}
}
} }

View File

@ -0,0 +1,271 @@
/*
* Copyright 2015 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import com.google.common.base.Throwables;
import io.grpc.ChannelCredentials;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.ServerCredentials;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.testing.TestUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import java.io.IOException;
import java.io.InputStream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Verify OkHttp's TLS integration. */
@RunWith(JUnit4.class)
public class TlsTest {
@Rule
public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
@Before
public void checkForAlpnApi() throws Exception {
// This checks for the "Java 9 ALPN API" which was backported to Java 8u252. The Kokoro Windows
// CI is on too old of a JDK for us to assume this is available.
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, null, null);
SSLEngine engine = context.createSSLEngine();
try {
SSLEngine.class.getMethod("getApplicationProtocol").invoke(engine);
} catch (NoSuchMethodException | UnsupportedOperationException ex) {
Assume.assumeNoException(ex);
}
}
@Test
public void mtls_succeeds() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();
}
ChannelCredentials channelCreds;
try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(clientCertChain, clientPrivateKey)
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
}
@Test
public void untrustedClient_fails() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();
}
ChannelCredentials channelCreds;
try (InputStream clientCertChain = TlsTesting.loadCert("badclient.pem");
InputStream clientPrivateKey = TlsTesting.loadCert("badclient.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(clientCertChain, clientPrivateKey)
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
assertRpcFails(channel);
}
@Test
public void missingOptionalClientCert_succeeds() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL)
.build();
}
ChannelCredentials channelCreds;
try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
}
@Test
public void missingRequiredClientCert_fails() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();
}
ChannelCredentials channelCreds;
try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
assertRpcFails(channel);
}
@Test
public void untrustedServer_fails() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("badserver.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("badserver.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.build();
}
ChannelCredentials channelCreds;
try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(clientCertChain, clientPrivateKey)
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
assertRpcFails(channel);
}
@Test
public void unmatchedServerSubjectAlternativeNames_fails() throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.trustManager(caCert)
.build();
}
ChannelCredentials channelCreds;
try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
InputStream caCert = TlsTesting.loadCert("ca.pem")) {
channelCreds = TlsChannelCredentials.newBuilder()
.keyManager(clientCertChain, clientPrivateKey)
.trustManager(caCert)
.build();
}
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(clientChannelBuilder(server, channelCreds)
.overrideAuthority("notgonnamatch.example.com")
.build());
assertRpcFails(channel);
}
private static Server server(ServerCredentials creds) throws IOException {
return OkHttpServerBuilder.forPort(0, creds)
.directExecutor()
.addService(new SimpleServiceImpl())
.build()
.start();
}
private static ManagedChannelBuilder<?> clientChannelBuilder(
Server server, ChannelCredentials creds) {
return OkHttpChannelBuilder.forAddress("localhost", server.getPort(), creds)
.directExecutor()
.overrideAuthority(TestUtils.TEST_SERVER_HOST);
}
private static ManagedChannel clientChannel(Server server, ChannelCredentials creds) {
return clientChannelBuilder(server, creds).build();
}
private static void assertRpcFails(ManagedChannel channel) {
SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);
try {
stub.unaryRpc(SimpleRequest.getDefaultInstance());
assertWithMessage("TLS handshake should have failed, but didn't; received RPC response")
.fail();
} catch (StatusRuntimeException e) {
assertWithMessage(Throwables.getStackTraceAsString(e))
.that(e.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
}
// We really want to see TRANSIENT_FAILURE here, but if the test runs slowly the 1s backoff
// may be exceeded by the time the failure happens (since it counts from the start of the
// attempt). Even so, CONNECTING is a strong indicator that the handshake failed; otherwise we'd
// expect READY or IDLE.
assertThat(channel.getState(false))
.isAnyOf(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING);
}
private static final class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
@Override
public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> respOb) {
respOb.onNext(SimpleResponse.getDefaultInstance());
respOb.onCompleted();
}
}
}