diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 3065ec2088..af352dc2c3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -17,13 +17,13 @@ package io.grpc.netty; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIME_NANOS; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; @@ -40,8 +40,10 @@ import io.grpc.internal.KeepAliveManager; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.TransportTracer; import io.netty.channel.Channel; +import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; +import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.ssl.SslContext; import java.net.InetSocketAddress; @@ -70,7 +72,8 @@ public final class NettyChannelBuilder private NegotiationType negotiationType = NegotiationType.TLS; private OverrideAuthorityChecker authorityChecker; - private Class channelType = NioSocketChannel.class; + private ChannelFactory channelFactory = + new ReflectiveChannelFactory<>(NioSocketChannel.class); @Nullable private EventLoopGroup eventLoopGroup; @@ -138,9 +141,23 @@ public final class NettyChannelBuilder /** * Specifies the channel type to use, by default we use {@link NioSocketChannel}. + * + *

You either use this or {@link #channelFactory(io.netty.channel.ChannelFactory)} if your + * {@link Channel} implementation has no no-args constructor. */ public NettyChannelBuilder channelType(Class channelType) { - this.channelType = Preconditions.checkNotNull(channelType, "channelType"); + checkNotNull(channelType, "channelType"); + return channelFactory(new ReflectiveChannelFactory<>(channelType)); + } + + /** + * Specifies the {@link ChannelFactory} to create {@link Channel} instances. This method is + * usually only used if the specific {@code Channel} requires complex logic which requires + * additional information to create the {@code Channel}. Otherwise, recommend to use {@link + * #channelType(Class)}. + */ + public NettyChannelBuilder channelFactory(ChannelFactory channelFactory) { + this.channelFactory = checkNotNull(channelFactory, "channelFactory"); return this; } @@ -390,7 +407,7 @@ public final class NettyChannelBuilder negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext); } return new NettyTransportFactory( - negotiator, channelType, channelOptions, + negotiator, channelFactory, channelOptions, eventLoopGroup, flowControlWindow, maxInboundMessageSize(), maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory.create(), localSocketPicker); @@ -453,7 +470,7 @@ public final class NettyChannelBuilder void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) { this.protocolNegotiatorFactory - = Preconditions.checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory"); + = checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory"); } @Override @@ -496,7 +513,7 @@ public final class NettyChannelBuilder @CheckReturnValue private static final class NettyTransportFactory implements ClientTransportFactory { private final ProtocolNegotiator protocolNegotiator; - private final Class channelType; + private final ChannelFactory channelFactory; private final Map, ?> channelOptions; private final EventLoopGroup group; private final boolean usingSharedGroup; @@ -512,12 +529,12 @@ public final class NettyChannelBuilder private boolean closed; NettyTransportFactory(ProtocolNegotiator protocolNegotiator, - Class channelType, Map, ?> channelOptions, + ChannelFactory channelFactory, Map, ?> channelOptions, EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, TransportTracer transportTracer, LocalSocketPicker localSocketPicker) { this.protocolNegotiator = protocolNegotiator; - this.channelType = channelType; + this.channelFactory = channelFactory; this.channelOptions = new HashMap, Object>(channelOptions); this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; @@ -562,7 +579,7 @@ public final class NettyChannelBuilder }; NettyClientTransport transport = new NettyClientTransport( - serverAddress, channelType, channelOptions, group, + serverAddress, channelFactory, channelOptions, group, localNegotiator, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index c2baedf95f..06e3f8492d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -43,13 +43,13 @@ import io.grpc.internal.TransportTracer; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; +import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; -import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException; import io.netty.util.AsciiString; import io.netty.util.concurrent.Future; @@ -67,7 +67,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final InternalLogId logId; private final Map, ?> channelOptions; private final SocketAddress remoteAddress; - private final Class channelType; + private final ChannelFactory channelFactory; private final EventLoopGroup group; private final ProtocolNegotiator negotiator; private final String authorityString; @@ -96,7 +96,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final LocalSocketPicker localSocketPicker; NettyClientTransport( - SocketAddress address, Class channelType, + SocketAddress address, ChannelFactory channelFactory, Map, ?> channelOptions, EventLoopGroup group, ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, @@ -106,7 +106,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.remoteAddress = Preconditions.checkNotNull(address, "address"); this.group = Preconditions.checkNotNull(group, "group"); - this.channelType = Preconditions.checkNotNull(channelType, "channelType"); + this.channelFactory = channelFactory; this.channelOptions = Preconditions.checkNotNull(channelOptions, "channelOptions"); this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; @@ -212,10 +212,9 @@ class NettyClientTransport implements ConnectionClientTransport { Bootstrap b = new Bootstrap(); b.group(eventLoop); - b.channel(channelType); - if (NioSocketChannel.class.isAssignableFrom(channelType)) { - b.option(SO_KEEPALIVE, true); - } + b.channelFactory(channelFactory); + // For non-socket based channel, the option will be ignored. + b.option(SO_KEEPALIVE, true); for (Map.Entry, ?> entry : channelOptions.entrySet()) { // Every entry in the map is obtained from // NettyChannelBuilder#withOption(ChannelOption option, T value) diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 3fb14aae26..a206e3954a 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -62,8 +62,12 @@ import io.grpc.internal.ServerTransportListener; import io.grpc.internal.TransportTracer; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; +import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelOption; +import io.netty.channel.ReflectiveChannelFactory; +import io.netty.channel.local.LocalChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannelConfig; import io.netty.channel.socket.nio.NioServerSocketChannel; @@ -178,10 +182,11 @@ public class NettyClientTransportTest { int soLinger = 123; channelOptions.put(ChannelOption.SO_LINGER, soLinger); NettyClientTransport transport = new NettyClientTransport( - address, NioSocketChannel.class, channelOptions, group, newNegotiator(), - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */, - tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); + address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group, + newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, + null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, + new SocketPicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -418,7 +423,8 @@ public class NettyClientTransportTest { address = TestUtils.testServerAddress(12345); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); NettyClientTransport transport = new NettyClientTransport( - address, CantConstructChannel.class, new HashMap, Object>(), group, + address, new ReflectiveChannelFactory<>(CantConstructChannel.class), + new HashMap, Object>(), group, newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); @@ -465,6 +471,32 @@ public class NettyClientTransportTest { } } + @Test + public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, + new ReflectiveChannelFactory<>(NioSocketChannel.class)); + + callMeMaybe(transport.start(clientTransportListener)); + + assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) + .isTrue(); + } + + @Test + public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, + new ReflectiveChannelFactory<>(LocalChannel.class)); + + callMeMaybe(transport.start(clientTransportListener)); + + assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) + .isNull(); + } + @Test public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { startServer(); @@ -594,14 +626,21 @@ public class NettyClientTransportTest { private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize, int maxHeaderListSize, String userAgent, boolean enableKeepAlive) { + return newTransport(negotiator, maxMsgSize, maxHeaderListSize, userAgent, enableKeepAlive, + new ReflectiveChannelFactory<>(NioSocketChannel.class)); + } + + private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize, + int maxHeaderListSize, String userAgent, boolean enableKeepAlive, + ChannelFactory channelFactory) { long keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED; long keepAliveTimeoutNano = TimeUnit.SECONDS.toNanos(1L); if (enableKeepAlive) { keepAliveTimeNano = TimeUnit.SECONDS.toNanos(10L); } NettyClientTransport transport = new NettyClientTransport( - address, NioSocketChannel.class, new HashMap, Object>(), group, negotiator, - DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, + address, channelFactory, new HashMap, Object>(), group, + negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, keepAliveTimeNano, keepAliveTimeoutNano, false, authority, userAgent, tooManyPingsRunnable, new TransportTracer(), eagAttributes, new SocketPicker());