diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java index 4711af97be..922bca8cd4 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java @@ -63,6 +63,19 @@ public final class InternalNettyChannelBuilder { builder.setDynamicParamsFactory(factory); } + /** A class that provides a Netty handler to control protocol negotiation. */ + public interface ProtocolNegotiatorFactory + extends NettyChannelBuilder.ProtocolNegotiatorFactory {} + + /** + * Sets the {@link ProtocolNegotiatorFactory} to be used. Overrides any specified negotiation type + * and {@code SslContext}. + */ + public static void setProtocolNegotiatorFactory( + NettyChannelBuilder builder, ProtocolNegotiatorFactory protocolNegotiator) { + builder.protocolNegotiatorFactory(protocolNegotiator); + } + public static void setStatsEnabled(NettyChannelBuilder builder, boolean value) { builder.setStatsEnabled(value); } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 633cb32324..15a3ddd5da 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -81,6 +81,7 @@ public final class NettyChannelBuilder private long keepAliveTimeoutNanos = DEFAULT_KEEPALIVE_TIMEOUT_NANOS; private boolean keepAliveWithoutCalls; private TransportCreationParamsFilterFactory dynamicParamsFactory; + private ProtocolNegotiatorFactory protocolNegotiatorFactory; /** * Creates a new builder with the given server address. This factory method is primarily intended @@ -334,16 +335,20 @@ public final class NettyChannelBuilder TransportCreationParamsFilterFactory transportCreationParamsFilterFactory = dynamicParamsFactory; if (transportCreationParamsFilterFactory == null) { - SslContext localSslContext = sslContext; - if (negotiationType == NegotiationType.TLS && localSslContext == null) { - try { - localSslContext = GrpcSslContexts.forClient().build(); - } catch (SSLException ex) { - throw new RuntimeException(ex); + ProtocolNegotiator negotiator; + if (protocolNegotiatorFactory != null) { + negotiator = protocolNegotiatorFactory.buildProtocolNegotiator(); + } else { + SslContext localSslContext = sslContext; + if (negotiationType == NegotiationType.TLS && localSslContext == null) { + try { + localSslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException ex) { + throw new RuntimeException(ex); + } } + negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext); } - ProtocolNegotiator negotiator = - createProtocolNegotiatorByType(negotiationType, localSslContext); transportCreationParamsFilterFactory = new DefaultNettyTransportCreationParamsFilterFactory(negotiator); } @@ -413,6 +418,11 @@ public final class NettyChannelBuilder this.dynamicParamsFactory = checkNotNull(factory, "factory"); } + void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) { + this.protocolNegotiatorFactory + = Preconditions.checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory"); + } + @Override protected void setTracingEnabled(boolean value) { super.setTracingEnabled(value); @@ -454,6 +464,14 @@ public final class NettyChannelBuilder ProtocolNegotiator getProtocolNegotiator(); } + interface ProtocolNegotiatorFactory { + /** + * Returns a ProtocolNegotatior instance configured for this Builder. This method is called + * during {@code ManagedChannelBuilder#build()}. + */ + ProtocolNegotiator buildProtocolNegotiator(); + } + /** * Creates Netty transports. Exposed for internal use, as it should be private. */