netty: Per-rpc authority verification against peer cert subject names (#11724)

Per-rpc verification of authority specified via call options or set by the LB API against peer cert's subject names.
This commit is contained in:
Kannan J 2025-02-24 14:58:11 +00:00 committed by GitHub
parent 57124d6b29
commit cdab410b81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1228 additions and 83 deletions

View File

@ -0,0 +1,24 @@
/*
* Copyright 2025 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import io.grpc.Status;
/** Verifier for the outgoing authority pseudo-header against peer cert. */
public interface AuthorityVerifier {
Status verifyAuthority(String authority);
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collection;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/**
* Contains certificate/key PEM file utility method(s) for internal usage.
*/
public final class CertificateUtils {
/**
* Creates X509TrustManagers using the provided CA certs.
*/
public static TrustManager[] createTrustManager(InputStream rootCerts)
throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts);
for (X509Certificate cert : certs) {
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
private static X509Certificate[] getX509Certificates(InputStream inputStream)
throws CertificateException {
CertificateFactory factory = CertificateFactory.getInstance("X.509");
Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
return certs.toArray(new X509Certificate[0]);
}
}

View File

@ -42,5 +42,8 @@ public final class GrpcAttributes {
public static final Attributes.Key<Attributes> ATTR_CLIENT_EAG_ATTRS =
Attributes.Key.create("io.grpc.internal.GrpcAttributes.clientEagAttrs");
public static final Attributes.Key<AuthorityVerifier> ATTR_AUTHORITY_VERIFIER =
Attributes.Key.create("io.grpc.internal.GrpcAttributes.authorityVerifier");
private GrpcAttributes() {}
}

View File

@ -0,0 +1,132 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import java.security.Principal;
import java.security.cert.Certificate;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
/** A no-op ssl session, to facilitate overriding only the required methods in specific
* implementations.
*/
public class NoopSslSession implements SSLSession {
@Override
public byte[] getId() {
return new byte[0];
}
@Override
public SSLSessionContext getSessionContext() {
return null;
}
@Override
@SuppressWarnings("deprecation")
public javax.security.cert.X509Certificate[] getPeerCertificateChain() {
throw new UnsupportedOperationException("This method is deprecated and marked for removal. "
+ "Use the getPeerCertificates() method instead.");
}
@Override
public long getCreationTime() {
return 0;
}
@Override
public long getLastAccessedTime() {
return 0;
}
@Override
public void invalidate() {
}
@Override
public boolean isValid() {
return false;
}
@Override
public void putValue(String s, Object o) {
}
@Override
public Object getValue(String s) {
return null;
}
@Override
public void removeValue(String s) {
}
@Override
public String[] getValueNames() {
return new String[0];
}
@Override
public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
return new Certificate[0];
}
@Override
public Certificate[] getLocalCertificates() {
return new Certificate[0];
}
@Override
public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
return null;
}
@Override
public Principal getLocalPrincipal() {
return null;
}
@Override
public String getCipherSuite() {
return null;
}
@Override
public String getProtocol() {
return null;
}
@Override
public String getPeerHost() {
return null;
}
@Override
public int getPeerPort() {
return 0;
}
@Override
public int getPacketBufferSize() {
return 0;
}
@Override
public int getApplicationBufferSize() {
return 0;
}
}

View File

@ -66,6 +66,16 @@ final class GrpcHttp2OutboundHeaders extends AbstractHttp2Headers {
this.preHeaders = preHeaders;
}
@Override
public CharSequence authority() {
for (int i = 0; i < preHeaders.length / 2; i++) {
if (preHeaders[i * 2].equals(Http2Headers.PseudoHeaderName.AUTHORITY.value())) {
return preHeaders[i * 2 + 1];
}
}
return null;
}
@Override
@SuppressWarnings("ReferenceEquality") // STATUS.value() never changes.
public CharSequence status() {

View File

@ -44,7 +44,7 @@ public final class InternalProtocolNegotiators {
ObjectPool<? extends Executor> executorPool,
Optional<Runnable> handshakeCompleteRunnable) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool, handshakeCompleteRunnable);
executorPool, handshakeCompleteRunnable, null);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
@ -170,7 +170,7 @@ public final class InternalProtocolNegotiators {
ChannelHandler next, SslContext sslContext, String authority,
ChannelLogger negotiationLogger) {
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
Optional.absent());
Optional.absent(), null, null);
}
public static class ProtocolNegotiationHandler

View File

@ -652,7 +652,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh
case PLAINTEXT_UPGRADE:
return ProtocolNegotiators.plaintextUpgrade();
case TLS:
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent());
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null);
default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
}

View File

@ -28,6 +28,7 @@ import com.google.common.base.Ticker;
import io.grpc.Attributes;
import io.grpc.ChannelLogger;
import io.grpc.InternalChannelz;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
@ -83,6 +84,8 @@ import io.perfmark.PerfMark;
import io.perfmark.Tag;
import io.perfmark.TaskCloseable;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -94,6 +97,8 @@ import javax.annotation.Nullable;
*/
class NettyClientHandler extends AbstractNettyHandler {
private static final Logger logger = Logger.getLogger(NettyClientHandler.class.getName());
static boolean enablePerRpcAuthorityCheck =
GrpcUtil.getFlag("GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK", false);
/**
* A message that simply passes through the channel without any real processing. It is useful to
@ -128,6 +133,13 @@ class NettyClientHandler extends AbstractNettyHandler {
lifecycleManager.notifyInUse(false);
}
};
private final Map<String, Status> peerVerificationResults =
new LinkedHashMap<String, Status>() {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Status> eldest) {
return size() > 100;
}
};
private WriteQueue clientWriteQueue;
private Http2Ping ping;
@ -591,6 +603,56 @@ class NettyClientHandler extends AbstractNettyHandler {
return;
}
CharSequence authorityHeader = command.headers().authority();
if (authorityHeader == null) {
Status authorityVerificationStatus = Status.UNAVAILABLE.withDescription(
"Missing authority header");
command.stream().setNonExistent();
command.stream().transportReportStatus(
Status.UNAVAILABLE, RpcProgress.PROCESSED, true, new Metadata());
promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace(
authorityVerificationStatus, null));
return;
}
// No need to verify authority for the rpc outgoing header if it is same as the authority
// for the transport
if (!authority.contentEquals(authorityHeader)) {
Status authorityVerificationStatus = peerVerificationResults.get(
authorityHeader.toString());
if (authorityVerificationStatus == null) {
if (attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) == null) {
authorityVerificationStatus = Status.UNAVAILABLE.withDescription(
"Authority verifier not found to verify authority");
command.stream().setNonExistent();
command.stream().transportReportStatus(
authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata());
promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace(
authorityVerificationStatus, null));
return;
}
authorityVerificationStatus = attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER)
.verifyAuthority(authorityHeader.toString());
peerVerificationResults.put(authorityHeader.toString(), authorityVerificationStatus);
if (!authorityVerificationStatus.isOk() && !enablePerRpcAuthorityCheck) {
logger.log(Level.WARNING, String.format("%s.%s",
authorityVerificationStatus.getDescription(),
enablePerRpcAuthorityCheck
? "" : " This will be an error in the future."),
InternalStatus.asRuntimeExceptionWithoutStacktrace(
authorityVerificationStatus, null));
}
}
if (!authorityVerificationStatus.isOk()) {
if (enablePerRpcAuthorityCheck) {
command.stream().setNonExistent();
command.stream().transportReportStatus(
authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata());
promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace(
authorityVerificationStatus, null));
return;
}
}
}
// Get the stream ID for the new stream.
int streamId;
try {

View File

@ -106,6 +106,7 @@ class NettyClientTransport implements ConnectionClientTransport {
private final boolean useGetForSafeMethods;
private final Ticker ticker;
NettyClientTransport(
SocketAddress address,
ChannelFactory<? extends Channel> channelFactory,

View File

@ -34,6 +34,6 @@ public final class NettySslContextChannelCredentials {
Preconditions.checkArgument(sslContext.isClient(),
"Server SSL context can not be used for client channel");
GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator());
return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext));
return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext, null));
}
}

View File

@ -0,0 +1,151 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import java.nio.ByteBuffer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
/**
* A no-op implementation of SslEngine, to facilitate overriding only the required methods in
* specific implementations.
*/
class NoopSslEngine extends SSLEngine {
@Override
public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst)
throws SSLException {
return null;
}
@Override
public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length)
throws SSLException {
return null;
}
@Override
public Runnable getDelegatedTask() {
return null;
}
@Override
public void closeInbound() throws SSLException {
}
@Override
public boolean isInboundDone() {
return false;
}
@Override
public void closeOutbound() {
}
@Override
public boolean isOutboundDone() {
return false;
}
@Override
public String[] getSupportedCipherSuites() {
return new String[0];
}
@Override
public String[] getEnabledCipherSuites() {
return new String[0];
}
@Override
public void setEnabledCipherSuites(String[] suites) {
}
@Override
public String[] getSupportedProtocols() {
return new String[0];
}
@Override
public String[] getEnabledProtocols() {
return new String[0];
}
@Override
public void setEnabledProtocols(String[] protocols) {
}
@Override
public SSLSession getSession() {
return null;
}
@Override
public void beginHandshake() throws SSLException {
}
@Override
public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
return null;
}
@Override
public void setUseClientMode(boolean mode) {
}
@Override
public boolean getUseClientMode() {
return false;
}
@Override
public void setNeedClientAuth(boolean need) {
}
@Override
public boolean getNeedClientAuth() {
return false;
}
@Override
public void setWantClientAuth(boolean want) {
}
@Override
public boolean getWantClientAuth() {
return false;
}
@Override
public void setEnableSessionCreation(boolean flag) {
}
@Override
public boolean getEnableSessionCreation() {
return false;
}
}

View File

@ -63,4 +63,5 @@ interface ProtocolNegotiator {
*/
ProtocolNegotiator newNegotiator(ObjectPool<? extends Executor> offloadExecutorPool);
}
}

View File

@ -16,7 +16,6 @@
package io.grpc.netty;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
@ -42,8 +41,10 @@ import io.grpc.ServerCredentials;
import io.grpc.Status;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.CertificateUtils;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.NoopSslSession;
import io.grpc.internal.ObjectPool;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
@ -71,8 +72,11 @@ import java.io.ByteArrayInputStream;
import java.net.SocketAddress;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.logging.Level;
@ -82,6 +86,9 @@ import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
/**
@ -95,7 +102,15 @@ final class ProtocolNegotiators {
private static final EnumSet<TlsServerCredentials.Feature> understoodServerTlsFeatures =
EnumSet.of(
TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS);
private static Class<?> x509ExtendedTrustManagerClass;
static {
try {
x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager");
} catch (ClassNotFoundException e) {
// Will disallow per-rpc authority override via call option.
}
}
private ProtocolNegotiators() {
}
@ -118,14 +133,32 @@ final class ProtocolNegotiators {
new ByteArrayInputStream(tlsCreds.getPrivateKey()),
tlsCreds.getPrivateKeyPassword());
}
if (tlsCreds.getTrustManagers() != null) {
builder.trustManager(new FixedTrustManagerFactory(tlsCreds.getTrustManagers()));
} else if (tlsCreds.getRootCertificates() != null) {
builder.trustManager(new ByteArrayInputStream(tlsCreds.getRootCertificates()));
} // else use system default
try {
return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build()));
} catch (SSLException ex) {
List<TrustManager> trustManagers;
if (tlsCreds.getTrustManagers() != null) {
trustManagers = tlsCreds.getTrustManagers();
} else if (tlsCreds.getRootCertificates() != null) {
trustManagers = Arrays.asList(CertificateUtils.createTrustManager(
new ByteArrayInputStream(tlsCreds.getRootCertificates())));
} else { // else use system default
TrustManagerFactory tmf = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
tmf.init((KeyStore) null);
trustManagers = Arrays.asList(tmf.getTrustManagers());
}
builder.trustManager(new FixedTrustManagerFactory(trustManagers));
TrustManager x509ExtendedTrustManager = null;
if (x509ExtendedTrustManagerClass != null) {
for (TrustManager trustManager : trustManagers) {
if (x509ExtendedTrustManagerClass.isInstance(trustManager)) {
x509ExtendedTrustManager = trustManager;
break;
}
}
}
return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(),
(X509TrustManager) x509ExtendedTrustManager));
} catch (SSLException | GeneralSecurityException ex) {
log.log(Level.FINE, "Exception building SslContext", ex);
return FromChannelCredentialsResult.error(
"Unable to create SslContext: " + ex.getMessage());
@ -411,8 +444,8 @@ final class ProtocolNegotiators {
ServerTlsHandler(ChannelHandler next,
SslContext sslContext,
final ObjectPool<? extends Executor> executorPool) {
this.sslContext = checkNotNull(sslContext, "sslContext");
this.next = checkNotNull(next, "next");
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
this.next = Preconditions.checkNotNull(next, "next");
if (executorPool != null) {
this.executor = executorPool.getObject();
}
@ -469,8 +502,8 @@ final class ProtocolNegotiators {
public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress,
final @Nullable String proxyUsername, final @Nullable String proxyPassword,
final ProtocolNegotiator negotiator) {
checkNotNull(negotiator, "negotiator");
checkNotNull(proxyAddress, "proxyAddress");
Preconditions.checkNotNull(negotiator, "negotiator");
Preconditions.checkNotNull(proxyAddress, "proxyAddress");
final AsciiString scheme = negotiator.scheme();
class ProxyNegotiator implements ProtocolNegotiator {
@Override
@ -516,7 +549,7 @@ final class ProtocolNegotiators {
ChannelHandler next,
ChannelLogger negotiationLogger) {
super(next, negotiationLogger);
this.address = checkNotNull(address, "address");
this.address = Preconditions.checkNotNull(address, "address");
this.userName = userName;
this.password = password;
}
@ -545,18 +578,21 @@ final class ProtocolNegotiators {
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
public ClientTlsProtocolNegotiator(SslContext sslContext,
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
this.sslContext = checkNotNull(sslContext, "sslContext");
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
X509TrustManager x509ExtendedTrustManager) {
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
this.executorPool = executorPool;
if (this.executorPool != null) {
this.executor = this.executorPool.getObject();
}
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
}
private final SslContext sslContext;
private final ObjectPool<? extends Executor> executorPool;
private final Optional<Runnable> handshakeCompleteRunnable;
private final X509TrustManager x509ExtendedTrustManager;
private Executor executor;
@Override
@ -569,7 +605,8 @@ final class ProtocolNegotiators {
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
this.executor, negotiationLogger, handshakeCompleteRunnable);
this.executor, negotiationLogger, handshakeCompleteRunnable, this,
x509ExtendedTrustManager);
return new WaitUntilActiveHandler(cth, negotiationLogger);
}
@ -579,6 +616,11 @@ final class ProtocolNegotiators {
this.executorPool.returnObject(this.executor);
}
}
@VisibleForTesting
boolean hasX509ExtendedTrustManager() {
return x509ExtendedTrustManager != null;
}
}
static final class ClientTlsHandler extends ProtocolNegotiationHandler {
@ -588,23 +630,28 @@ final class ProtocolNegotiators {
private final int port;
private Executor executor;
private final Optional<Runnable> handshakeCompleteRunnable;
private final X509TrustManager x509ExtendedTrustManager;
private SSLEngine sslEngine;
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
Executor executor, ChannelLogger negotiationLogger,
Optional<Runnable> handshakeCompleteRunnable) {
Optional<Runnable> handshakeCompleteRunnable,
ClientTlsProtocolNegotiator clientTlsProtocolNegotiator,
X509TrustManager x509ExtendedTrustManager) {
super(next, negotiationLogger);
this.sslContext = checkNotNull(sslContext, "sslContext");
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
HostPort hostPort = parseAuthority(authority);
this.host = hostPort.host;
this.port = hostPort.port;
this.executor = executor;
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
}
@Override
@IgnoreJRERequirement
protected void handlerAdded0(ChannelHandlerContext ctx) {
SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
SSLParameters sslParams = sslEngine.getSSLParameters();
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
sslEngine.setSSLParameters(sslParams);
@ -661,6 +708,8 @@ final class ProtocolNegotiators {
Attributes attrs = existingPne.getAttributes().toBuilder()
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, new X509AuthorityVerifier(
sslEngine, x509ExtendedTrustManager))
.build();
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
if (handshakeCompleteRunnable.isPresent()) {
@ -700,8 +749,10 @@ final class ProtocolNegotiators {
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
X509TrustManager x509ExtendedTrustManager) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable,
x509ExtendedTrustManager);
}
/**
@ -709,25 +760,30 @@ final class ProtocolNegotiators {
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
*/
public static ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null, Optional.absent());
public static ProtocolNegotiator tls(SslContext sslContext,
X509TrustManager x509ExtendedTrustManager) {
return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager);
}
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
return new TlsProtocolNegotiatorClientFactory(sslContext);
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext,
X509TrustManager x509ExtendedTrustManager) {
return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager);
}
@VisibleForTesting
static final class TlsProtocolNegotiatorClientFactory
implements ProtocolNegotiator.ClientFactory {
private final SslContext sslContext;
private final X509TrustManager x509ExtendedTrustManager;
public TlsProtocolNegotiatorClientFactory(SslContext sslContext) {
public TlsProtocolNegotiatorClientFactory(SslContext sslContext,
X509TrustManager x509ExtendedTrustManager) {
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
}
@Override public ProtocolNegotiator newNegotiator() {
return tls(sslContext);
return tls(sslContext, x509ExtendedTrustManager);
}
@Override public int getDefaultPort() {
@ -780,7 +836,9 @@ final class ProtocolNegotiators {
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelHandler upgradeHandler =
new Http2UpgradeAndGrpcHandler(grpcHandler.getAuthority(), grpcHandler);
return new WaitUntilActiveHandler(upgradeHandler, grpcHandler.getNegotiationLogger());
ChannelHandler plaintextHandler =
new PlaintextHandler(upgradeHandler, grpcHandler.getNegotiationLogger());
return new WaitUntilActiveHandler(plaintextHandler, grpcHandler.getNegotiationLogger());
}
@Override
@ -801,8 +859,8 @@ final class ProtocolNegotiators {
private ProtocolNegotiationEvent pne;
Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) {
this.authority = checkNotNull(authority, "authority");
this.next = checkNotNull(next, "next");
this.authority = Preconditions.checkNotNull(authority, "authority");
this.next = Preconditions.checkNotNull(next, "next");
this.negotiationLogger = next.getNegotiationLogger();
}
@ -846,9 +904,9 @@ final class ProtocolNegotiators {
}
/**
* Returns a {@link ChannelHandler} that ensures that the {@code handler} is added to the
* pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, even before it
* is active.
* Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is
* added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately,
* even before it is active.
*/
public static ProtocolNegotiator plaintext() {
return new PlaintextProtocolNegotiator();
@ -926,7 +984,7 @@ final class ProtocolNegotiators {
private final GrpcHttp2ConnectionHandler next;
public GrpcNegotiationHandler(GrpcHttp2ConnectionHandler next) {
this.next = checkNotNull(next, "next");
this.next = Preconditions.checkNotNull(next, "next");
}
@Override
@ -977,7 +1035,9 @@ final class ProtocolNegotiators {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelHandler grpcNegotiationHandler = new GrpcNegotiationHandler(grpcHandler);
ChannelHandler activeHandler = new WaitUntilActiveHandler(grpcNegotiationHandler,
ChannelHandler plaintextHandler =
new PlaintextHandler(grpcNegotiationHandler, grpcHandler.getNegotiationLogger());
ChannelHandler activeHandler = new WaitUntilActiveHandler(plaintextHandler,
grpcHandler.getNegotiationLogger());
return activeHandler;
}
@ -991,6 +1051,22 @@ final class ProtocolNegotiators {
}
}
static final class PlaintextHandler extends ProtocolNegotiationHandler {
PlaintextHandler(ChannelHandler next, ChannelLogger negotiationLogger) {
super(next, negotiationLogger);
}
@Override
protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) {
ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent();
Attributes attrs = existingPne.getAttributes().toBuilder()
.set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK)
.build();
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs));
fireProtocolNegotiationEvent(ctx);
}
}
/**
* Waits for the channel to be active, and then installs the next Handler. Using this allows
* subsequent handlers to assume the channel is active and ready to send. Additionally, this a
@ -1048,15 +1124,15 @@ final class ProtocolNegotiators {
protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName,
ChannelLogger negotiationLogger) {
this.next = checkNotNull(next, "next");
this.next = Preconditions.checkNotNull(next, "next");
this.negotiatorName = negotiatorName;
this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger");
this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger");
}
protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiationLogger) {
this.next = checkNotNull(next, "next");
this.next = Preconditions.checkNotNull(next, "next");
this.negotiatorName = getClass().getSimpleName().replace("Handler", "");
this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger");
this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger");
}
@Override
@ -1097,7 +1173,7 @@ final class ProtocolNegotiators {
protected final void replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) {
checkState(this.pne != null, "previous protocol negotiation event hasn't triggered");
this.pne = checkNotNull(pne);
this.pne = Preconditions.checkNotNull(pne);
}
protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) {
@ -1107,4 +1183,42 @@ final class ProtocolNegotiators {
ctx.fireUserEventTriggered(pne);
}
}
static final class SslEngineWrapper extends NoopSslEngine {
private final SSLEngine sslEngine;
private final String peerHost;
SslEngineWrapper(SSLEngine sslEngine, String peerHost) {
this.sslEngine = sslEngine;
this.peerHost = peerHost;
}
@Override
public String getPeerHost() {
return peerHost;
}
@Override
public SSLSession getHandshakeSession() {
return new FakeSslSession(peerHost);
}
@Override
public SSLParameters getSSLParameters() {
return sslEngine.getSSLParameters();
}
}
static final class FakeSslSession extends NoopSslSession {
private final String peerHost;
FakeSslSession(String peerHost) {
this.peerHost = peerHost;
}
@Override
public String getPeerHost() {
return peerHost;
}
}
}

View File

@ -0,0 +1,108 @@
/*
* Copyright 2025 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Status;
import io.grpc.internal.AuthorityVerifier;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.annotation.Nonnull;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.X509TrustManager;
final class X509AuthorityVerifier implements AuthorityVerifier {
private final SSLEngine sslEngine;
private final X509TrustManager x509ExtendedTrustManager;
private static final Method checkServerTrustedMethod;
static {
Method method = null;
try {
Class<?> x509ExtendedTrustManagerClass =
Class.forName("javax.net.ssl.X509ExtendedTrustManager");
method = x509ExtendedTrustManagerClass.getMethod("checkServerTrusted",
X509Certificate[].class, String.class, SSLEngine.class);
} catch (ClassNotFoundException e) {
// Per-rpc authority overriding via call options will be disallowed.
} catch (NoSuchMethodException e) {
// Should never happen since X509ExtendedTrustManager was introduced in Android API level 24
// along with checkServerTrusted.
}
checkServerTrustedMethod = method;
}
public X509AuthorityVerifier(SSLEngine sslEngine, X509TrustManager x509ExtendedTrustManager) {
this.sslEngine = checkNotNull(sslEngine, "sslEngine");
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
}
@Override
public Status verifyAuthority(@Nonnull String authority) {
if (x509ExtendedTrustManager == null) {
return Status.UNAVAILABLE.withDescription(
"Can't allow authority override in rpc when X509ExtendedTrustManager"
+ " is not available");
}
Status peerVerificationStatus;
try {
// Because the authority pseudo-header can contain a port number:
// https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.3
verifyAuthorityAllowedForPeerCert(removeAnyPortNumber(authority));
peerVerificationStatus = Status.OK;
} catch (SSLPeerUnverifiedException | CertificateException | InvocationTargetException
| IllegalAccessException | IllegalStateException e) {
peerVerificationStatus = Status.UNAVAILABLE.withDescription(
String.format("Peer hostname verification during rpc failed for authority '%s'",
authority)).withCause(e);
}
return peerVerificationStatus;
}
private String removeAnyPortNumber(String authority) {
int closingSquareBracketIndex = authority.lastIndexOf(']');
int portNumberSeperatorColonIndex = authority.lastIndexOf(':');
if (portNumberSeperatorColonIndex > closingSquareBracketIndex) {
return authority.substring(0, portNumberSeperatorColonIndex);
}
return authority;
}
private void verifyAuthorityAllowedForPeerCert(String authority)
throws SSLPeerUnverifiedException, CertificateException, InvocationTargetException,
IllegalAccessException {
SSLEngine sslEngineWrapper = new ProtocolNegotiators.SslEngineWrapper(sslEngine, authority);
// The typecasting of Certificate to X509Certificate should work because this method will only
// be called when using TLS and thus X509.
Certificate[] peerCertificates = sslEngine.getSession().getPeerCertificates();
X509Certificate[] x509PeerCertificates = new X509Certificate[peerCertificates.length];
for (int i = 0; i < peerCertificates.length; i++) {
x509PeerCertificates[i] = (X509Certificate) peerCertificates[i];
}
if (checkServerTrustedMethod == null) {
throw new IllegalStateException("checkServerTrustedMethod not found");
}
checkServerTrustedMethod.invoke(
x509ExtendedTrustManager, x509PeerCertificates, "RSA", sslEngineWrapper);
}
}

View File

@ -36,6 +36,7 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
@ -64,6 +65,7 @@ import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientStreamListener.RpcProgress;
import io.grpc.internal.ClientTransport;
import io.grpc.internal.ClientTransport.PingCallback;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ManagedClientTransport;
@ -90,10 +92,12 @@ import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2Stream;
import io.netty.util.AsciiString;
import java.io.InputStream;
import java.security.cert.CertificateException;
import java.text.MessageFormat;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Handler;
@ -189,7 +193,11 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
})
.when(streamListener)
.messagesAvailable(ArgumentMatchers.<StreamListener.MessageProducer>any());
doAnswer((attributes) -> Attributes.newBuilder().set(
GrpcAttributes.ATTR_AUTHORITY_VERIFIER,
(authority) -> Status.OK).build())
.when(listener)
.filterTransport(ArgumentMatchers.any(Attributes.class));
lifecycleManager = new ClientTransportLifecycleManager(listener);
// This mocks the keepalive manager only for there's in which we verify it. For other tests
// it'll be null which will be testing if we behave correctly when it's not present.
@ -919,6 +927,159 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
assertFalse(channel().isOpen());
}
@Test
public void missingAuthorityHeader_streamCreationShouldFail() throws Exception {
Http2Headers grpcHeadersWithoutAuthority = new DefaultHttp2Headers()
.scheme(HTTPS)
.path(as("/fakemethod"))
.method(HTTP_METHOD)
.add(as("auth"), as("sometoken"))
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.add(TE_HEADER, TE_TRAILERS);
ChannelFuture channelFuture = enqueue(newCreateStreamCommand(
grpcHeadersWithoutAuthority, streamTransportState));
try {
channelFuture.get();
fail("Expected stream creation failure");
} catch (ExecutionException e) {
assertThat(e.getCause().getMessage()).isEqualTo("UNAVAILABLE: Missing authority header");
}
}
@Test
public void missingAuthorityVerifierInAttributes_streamCreationShouldFail() throws Exception {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
StreamListener.MessageProducer producer =
(StreamListener.MessageProducer) invocation.getArguments()[0];
InputStream message;
while ((message = producer.next()) != null) {
streamListenerMessageQueue.add(message);
}
return null;
}
})
.when(streamListener)
.messagesAvailable(ArgumentMatchers.<StreamListener.MessageProducer>any());
doAnswer((attributes) -> Attributes.EMPTY)
.when(listener)
.filterTransport(ArgumentMatchers.any(Attributes.class));
lifecycleManager = new ClientTransportLifecycleManager(listener);
// This mocks the keepalive manager only for there's in which we verify it. For other tests
// it'll be null which will be testing if we behave correctly when it's not present.
if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) {
mockKeepAliveManager = mock(KeepAliveManager.class);
}
initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));
streamTransportState = new TransportStateImpl(
handler(),
channel().eventLoop(),
DEFAULT_MAX_MESSAGE_SIZE,
transportTracer);
streamTransportState.setListener(streamListener);
grpcHeaders = new DefaultHttp2Headers()
.scheme(HTTPS)
.authority(as("www.fake.com"))
.path(as("/fakemethod"))
.method(HTTP_METHOD)
.add(as("auth"), as("sometoken"))
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.add(TE_HEADER, TE_TRAILERS);
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
ChannelFuture channelFuture = createStream();
try {
channelFuture.get();
fail("Expected stream creation failure");
} catch (ExecutionException e) {
assertThat(e.getCause().getMessage()).isEqualTo(
"UNAVAILABLE: Authority verifier not found to verify authority");
}
}
@Test
public void authorityVerificationSuccess_streamCreationSucceeds() throws Exception {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
ChannelFuture channelFuture = createStream();
channelFuture.get();
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;
}
}
@Test
public void authorityVerificationFailure_streamCreationFails() throws Exception {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
StreamListener.MessageProducer producer =
(StreamListener.MessageProducer) invocation.getArguments()[0];
InputStream message;
while ((message = producer.next()) != null) {
streamListenerMessageQueue.add(message);
}
return null;
}
})
.when(streamListener)
.messagesAvailable(ArgumentMatchers.<StreamListener.MessageProducer>any());
doAnswer((attributes) -> Attributes.newBuilder().set(
GrpcAttributes.ATTR_AUTHORITY_VERIFIER,
(authority) -> Status.UNAVAILABLE.withCause(
new CertificateException("Peer verification failed"))).build())
.when(listener)
.filterTransport(ArgumentMatchers.any(Attributes.class));
lifecycleManager = new ClientTransportLifecycleManager(listener);
// This mocks the keepalive manager only for there's in which we verify it. For other tests
// it'll be null which will be testing if we behave correctly when it's not present.
if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) {
mockKeepAliveManager = mock(KeepAliveManager.class);
}
initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));
streamTransportState = new TransportStateImpl(
handler(),
channel().eventLoop(),
DEFAULT_MAX_MESSAGE_SIZE,
transportTracer);
streamTransportState.setListener(streamListener);
grpcHeaders = new DefaultHttp2Headers()
.scheme(HTTPS)
.authority(as("www.fake.com"))
.path(as("/fakemethod"))
.method(HTTP_METHOD)
.add(as("auth"), as("sometoken"))
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.add(TE_HEADER, TE_TRAILERS);
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
ChannelFuture channelFuture = createStream();
try {
channelFuture.get();
fail("Expected stream creation failure");
} catch (ExecutionException e) {
assertThat(e.getMessage()).isEqualTo("io.grpc.InternalStatusRuntimeException: UNAVAILABLE");
}
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;
}
}
@Override
protected void makeStream() throws Exception {
createStream();

View File

@ -59,6 +59,7 @@ import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusException;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.ClientStream;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientTransport;
@ -76,6 +77,7 @@ import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest;
import io.grpc.testing.TlsTesting;
import io.grpc.util.CertificateUtils;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
@ -101,9 +103,14 @@ import io.netty.util.ReferenceCountUtil;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@ -115,8 +122,15 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@ -131,6 +145,7 @@ import org.mockito.junit.MockitoRule;
* Tests for {@link NettyClientTransport}.
*/
@RunWith(JUnit4.class)
@IgnoreJRERequirement
public class NettyClientTransportTest {
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
@ -203,7 +218,7 @@ public class NettyClientTransportTest {
}
@Test
public void setSoLingerChannelOption() throws IOException {
public void setSoLingerChannelOption() throws IOException, GeneralSecurityException {
startServer();
Map<ChannelOption<?>, Object> channelOptions = new HashMap<>();
// set SO_LINGER option
@ -354,7 +369,7 @@ public class NettyClientTransportTest {
.trustManager(caCert)
.keyManager(clientCert, clientKey)
.build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext);
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, null);
final NettyClientTransport transport = newTransport(negotiator);
callMeMaybe(transport.start(clientTransportListener));
verify(clientTransportListener, timeout(5000)).transportTerminated();
@ -821,7 +836,7 @@ public class NettyClientTransportTest {
.keyManager(clientCert, clientKey)
.build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
Optional.absent());
Optional.absent(), null);
// after starting the client, the Executor in the client pool should be used
assertEquals(true, clientExecutorPool.isInUse());
final NettyClientTransport transport = newTransport(negotiator);
@ -836,6 +851,179 @@ public class NettyClientTransportTest {
assertEquals(false, serverExecutorPool.isInUse());
}
/**
* This test tests the case of TlsCredentials passed to ProtocolNegotiators not having an instance
* of X509ExtendedTrustManager (this is not testable in ProtocolNegotiatorsTest without creating
* accessors for the internal state of negotiator whether it has a X509ExtendedTrustManager,
* hence the need to test it in this class instead). To establish a successful handshake we create
* a fake X509TrustManager not implementing X509ExtendedTrustManager but wraps the real
* X509ExtendedTrustManager.
*/
@Test
public void authorityOverrideInCallOptions_noX509ExtendedTrustManager_newStreamCreationFails()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
startServer();
InputStream caCert = TlsTesting.loadCert("ca.pem");
X509TrustManager x509ExtendedTrustManager =
(X509TrustManager) getX509ExtendedTrustManager(caCert);
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.newBuilder()
.trustManager(new FakeTrustManager(x509ExtendedTrustManager)).build());
NettyClientTransport transport = newTransport(result.negotiator.newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in");
try {
rpc.waitForClose();
fail("Expected exception in starting stream");
} catch (ExecutionException ex) {
Status status = ((StatusException) ex.getCause()).getStatus();
assertThat(status.getDescription()).isEqualTo("Can't allow authority override in rpc "
+ "when X509ExtendedTrustManager is not available");
assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE);
}
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;
}
}
@Test
public void authorityOverrideInCallOptions_doesntMatchServerPeerHost_newStreamCreationFails()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
startServer();
NettyClientTransport transport = newTransport(newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in");
try {
rpc.waitForClose();
fail("Expected exception in starting stream");
} catch (ExecutionException ex) {
Status status = ((StatusException) ex.getCause()).getStatus();
assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc "
+ "failed for authority 'foo.test.google.in'");
assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE);
assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException())
.isInstanceOf(CertificateException.class);
assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()
.getMessage()).isEqualTo(
"No subject alternative DNS name matching foo.test.google.in found.");
}
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;
}
}
@Test
public void authorityOverrideInCallOptions_matchesServerPeerHost_newStreamCreationSucceeds()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
startServer();
NettyClientTransport transport = newTransport(newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
new Rpc(transport, new Metadata(), "foo.test.google.fr").waitForResponse();
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;;
}
}
// Without removing the port number part that {@link X509AuthorityVerifier} does, there will be a
// java.security.cert.CertificateException: Illegal given domain name: foo.test.google.fr:12345
@Test
public void authorityOverrideInCallOptions_portNumberInAuthority_isStrippedForPeerVerification()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
startServer();
NettyClientTransport transport = newTransport(newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
new Rpc(transport, new Metadata(), "foo.test.google.fr:12345").waitForResponse();
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;;
}
}
@Test
public void authorityOverrideInCallOptions_portNumberAndIpv6_isStrippedForPeerVerification()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
NettyClientHandler.enablePerRpcAuthorityCheck = true;
try {
startServer();
NettyClientTransport transport = newTransport(newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
new Rpc(transport, new Metadata(), "[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345")
.waitForResponse();
} catch (ExecutionException ex) {
Status status = ((StatusException) ex.getCause()).getStatus();
assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc "
+ "failed for authority '[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345'");
assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE);
assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException())
.isInstanceOf(CertificateException.class);
// Port number is removed by {@link X509AuthorityVerifier}.
assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()
.getMessage()).isEqualTo(
"No subject alternative names matching IP address 2001:db8:3333:4444:5555:6666:1.2.3.4 "
+ "found");
} finally {
NettyClientHandler.enablePerRpcAuthorityCheck = false;;
}
}
@Test
public void authorityOverrideInCallOptions_notMatches_flagDisabled_createsStream()
throws IOException, InterruptedException, GeneralSecurityException, ExecutionException,
TimeoutException {
startServer();
NettyClientTransport transport = newTransport(newNegotiator());
SettableFuture<Void> connected = SettableFuture.create();
FakeClientTransportListener fakeClientTransportListener =
new FakeClientTransportListener(connected);
callMeMaybe(transport.start(fakeClientTransportListener));
connected.get(10, TimeUnit.SECONDS);
assertThat(fakeClientTransportListener.isConnected()).isTrue();
new Rpc(transport, new Metadata(), "foo.test.google.in").waitForResponse();
}
private Throwable getRootCause(Throwable t) {
if (t.getCause() == null) {
return t;
@ -843,10 +1031,37 @@ public class NettyClientTransportTest {
return getRootCause(t.getCause());
}
private ProtocolNegotiator newNegotiator() throws IOException {
private ProtocolNegotiator newNegotiator() throws IOException, GeneralSecurityException {
InputStream caCert = TlsTesting.loadCert("ca.pem");
SslContext clientContext = GrpcSslContexts.forClient().trustManager(caCert).build();
return ProtocolNegotiators.tls(clientContext);
return ProtocolNegotiators.tls(clientContext,
(X509TrustManager) getX509ExtendedTrustManager(TlsTesting.loadCert("ca.pem")));
}
private static TrustManager getX509ExtendedTrustManager(InputStream rootCerts)
throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts);
for (X509Certificate cert : certs) {
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) {
if (trustManager instanceof X509ExtendedTrustManager) {
return trustManager;
}
}
return null;
}
private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
@ -965,13 +1180,20 @@ public class NettyClientTransportTest {
final TestClientStreamListener listener = new TestClientStreamListener();
Rpc(NettyClientTransport transport) {
this(transport, new Metadata());
this(transport, new Metadata(), null);
}
Rpc(NettyClientTransport transport, Metadata headers) {
this(transport, headers, null);
}
Rpc(NettyClientTransport transport, Metadata headers, String authorityOverride) {
stream = transport.newStream(
METHOD, headers, CallOptions.DEFAULT,
new ClientStreamTracer[]{ new ClientStreamTracer() {} });
if (authorityOverride != null) {
stream.setAuthority(authorityOverride);
}
stream.start(listener);
stream.request(1);
stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8)));
@ -1169,4 +1391,62 @@ public class NettyClientTransportTest {
@Override
public void log(ChannelLogLevel level, String messageFormat, Object... args) {}
}
static class FakeClientTransportListener implements ManagedClientTransport.Listener {
private final SettableFuture<Void> connected;
@GuardedBy("this")
private boolean isConnected = false;
public FakeClientTransportListener(SettableFuture<Void> connected) {
this.connected = connected;
}
@Override
public void transportShutdown(Status s) {}
@Override
public void transportTerminated() {}
@Override
public void transportReady() {
synchronized (this) {
isConnected = true;
}
connected.set(null);
}
synchronized boolean isConnected() {
return isConnected;
}
@Override
public void transportInUse(boolean inUse) {}
}
private static class FakeTrustManager implements X509TrustManager {
private final X509TrustManager delegate;
public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) {
this.delegate = x509ExtendedTrustManager;
}
@Override
public void checkClientTrusted(X509Certificate[] x509Certificates, String s)
throws CertificateException {
delegate.checkClientTrusted(x509Certificates, s);
}
@Override
public void checkServerTrusted(X509Certificate[] x509Certificates, String s)
throws CertificateException {
delegate.checkServerTrusted(x509Certificates, s);
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return delegate.getAcceptedIssuers();
}
}
}

View File

@ -112,10 +112,14 @@ import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayDeque;
import java.util.Arrays;
@ -222,13 +226,52 @@ public class ProtocolNegotiatorsTest {
}
@Test
public void fromClient_tls() {
public void fromClient_tls_trustManager()
throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException {
KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
certStore.load(null);
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
try (InputStream ca = TlsTesting.loadCert("ca.pem")) {
for (X509Certificate cert : CertificateUtils.getX509Certificates(ca)) {
certStore.setCertificateEntry(cert.getSubjectX500Principal().getName("RFC2253"), cert);
}
}
trustManagerFactory.init(certStore);
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.newBuilder()
.trustManager(trustManagerFactory.getTrustManagers()).build());
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator())
.hasX509ExtendedTrustManager()).isTrue();
}
@Test
public void fromClient_tls_CaCertsInputStream() throws IOException {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.newBuilder()
.trustManager(TlsTesting.loadCert("ca.pem")).build());
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator())
.hasX509ExtendedTrustManager()).isTrue();
}
@Test
public void fromClient_tls_systemDefault() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.create());
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator())
.hasX509ExtendedTrustManager()).isTrue();
}
@Test
@ -877,7 +920,8 @@ public class ProtocolNegotiatorsTest {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger, Optional.absent());
"authority", elg, noopLogger, Optional.absent(),
getClientTlsProtocolNegotiator(), null);
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@ -915,7 +959,8 @@ public class ProtocolNegotiatorsTest {
.applicationProtocolConfig(apn).build();
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger, Optional.absent());
"authority", elg, noopLogger, Optional.absent(),
getClientTlsProtocolNegotiator(), null);
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@ -939,7 +984,8 @@ public class ProtocolNegotiatorsTest {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger, Optional.absent());
"authority", elg, noopLogger, Optional.absent(),
getClientTlsProtocolNegotiator(), null);
pipeline.addLast(handler);
final AtomicReference<Throwable> error = new AtomicReference<>();
@ -967,7 +1013,8 @@ public class ProtocolNegotiatorsTest {
@Test
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", null, noopLogger, Optional.absent());
"authority", null, noopLogger, Optional.absent(),
getClientTlsProtocolNegotiator(), null);
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
@ -979,6 +1026,12 @@ public class ProtocolNegotiatorsTest {
.isEqualTo(Status.Code.UNAVAILABLE);
}
private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException {
return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager(
TlsTesting.loadCert("ca.pem")).build(),
null, Optional.absent(), null);
}
@Test
public void engineLog() {
ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
@ -1007,7 +1060,7 @@ public class ProtocolNegotiatorsTest {
public void tls_failsOnNullSslContext() {
thrown.expect(NullPointerException.class);
Object unused = ProtocolNegotiators.tls(null);
Object unused = ProtocolNegotiators.tls(null, null);
}
@Test
@ -1230,7 +1283,7 @@ public class ProtocolNegotiatorsTest {
}
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
null, Optional.absent());
null, Optional.absent(), null);
WriteBufferingAndExceptionHandler clientWbaeh =
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));

View File

@ -81,8 +81,6 @@ import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/** Convenience class for building channels with the OkHttp transport. */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
@ -705,32 +703,12 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2<OkHttp
static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
InputStream rootCertsStream = new ByteArrayInputStream(rootCerts);
try {
return createTrustManager(rootCertsStream);
return io.grpc.internal.CertificateUtils.createTrustManager(rootCertsStream);
} finally {
GrpcUtil.closeQuietly(rootCertsStream);
}
}
static TrustManager[] createTrustManager(InputStream rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts);
for (X509Certificate cert : certs) {
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
static Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}

View File

@ -34,6 +34,7 @@ import io.grpc.CompositeChannelCredentials;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.CertificateUtils;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.FakeClock;
@ -212,7 +213,7 @@ public class OkHttpChannelBuilderTest {
TrustManager[] trustManagers;
try (InputStream ca = TlsTesting.loadCert("ca.pem")) {
trustManagers = OkHttpChannelBuilder.createTrustManager(ca);
trustManagers = CertificateUtils.createTrustManager(ca);
}
SSLContext serverContext = SSLContext.getInstance("TLS");
@ -257,7 +258,7 @@ public class OkHttpChannelBuilderTest {
InputStream ca = TlsTesting.loadCert("ca.pem")) {
serverContext.init(
OkHttpChannelBuilder.createKeyManager(server1Chain, server1Key),
OkHttpChannelBuilder.createTrustManager(ca),
CertificateUtils.createTrustManager(ca),
null);
}
final SSLServerSocket serverListenSocket =