diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
index 3065ec2088..af352dc2c3 100644
--- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
+++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
@@ -17,13 +17,13 @@
package io.grpc.netty;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS;
import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIME_NANOS;
import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
@@ -40,8 +40,10 @@ import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.TransportTracer;
import io.netty.channel.Channel;
+import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import java.net.InetSocketAddress;
@@ -70,7 +72,8 @@ public final class NettyChannelBuilder
private NegotiationType negotiationType = NegotiationType.TLS;
private OverrideAuthorityChecker authorityChecker;
- private Class extends Channel> channelType = NioSocketChannel.class;
+ private ChannelFactory extends Channel> channelFactory =
+ new ReflectiveChannelFactory<>(NioSocketChannel.class);
@Nullable
private EventLoopGroup eventLoopGroup;
@@ -138,9 +141,23 @@ public final class NettyChannelBuilder
/**
* Specifies the channel type to use, by default we use {@link NioSocketChannel}.
+ *
+ *
You either use this or {@link #channelFactory(io.netty.channel.ChannelFactory)} if your
+ * {@link Channel} implementation has no no-args constructor.
*/
public NettyChannelBuilder channelType(Class extends Channel> channelType) {
- this.channelType = Preconditions.checkNotNull(channelType, "channelType");
+ checkNotNull(channelType, "channelType");
+ return channelFactory(new ReflectiveChannelFactory<>(channelType));
+ }
+
+ /**
+ * Specifies the {@link ChannelFactory} to create {@link Channel} instances. This method is
+ * usually only used if the specific {@code Channel} requires complex logic which requires
+ * additional information to create the {@code Channel}. Otherwise, recommend to use {@link
+ * #channelType(Class)}.
+ */
+ public NettyChannelBuilder channelFactory(ChannelFactory extends Channel> channelFactory) {
+ this.channelFactory = checkNotNull(channelFactory, "channelFactory");
return this;
}
@@ -390,7 +407,7 @@ public final class NettyChannelBuilder
negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext);
}
return new NettyTransportFactory(
- negotiator, channelType, channelOptions,
+ negotiator, channelFactory, channelOptions,
eventLoopGroup, flowControlWindow, maxInboundMessageSize(),
maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls,
transportTracerFactory.create(), localSocketPicker);
@@ -453,7 +470,7 @@ public final class NettyChannelBuilder
void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) {
this.protocolNegotiatorFactory
- = Preconditions.checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory");
+ = checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory");
}
@Override
@@ -496,7 +513,7 @@ public final class NettyChannelBuilder
@CheckReturnValue
private static final class NettyTransportFactory implements ClientTransportFactory {
private final ProtocolNegotiator protocolNegotiator;
- private final Class extends Channel> channelType;
+ private final ChannelFactory extends Channel> channelFactory;
private final Map, ?> channelOptions;
private final EventLoopGroup group;
private final boolean usingSharedGroup;
@@ -512,12 +529,12 @@ public final class NettyChannelBuilder
private boolean closed;
NettyTransportFactory(ProtocolNegotiator protocolNegotiator,
- Class extends Channel> channelType, Map, ?> channelOptions,
+ ChannelFactory extends Channel> channelFactory, Map, ?> channelOptions,
EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize,
long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls,
TransportTracer transportTracer, LocalSocketPicker localSocketPicker) {
this.protocolNegotiator = protocolNegotiator;
- this.channelType = channelType;
+ this.channelFactory = channelFactory;
this.channelOptions = new HashMap, Object>(channelOptions);
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
@@ -562,7 +579,7 @@ public final class NettyChannelBuilder
};
NettyClientTransport transport = new NettyClientTransport(
- serverAddress, channelType, channelOptions, group,
+ serverAddress, channelFactory, channelOptions, group,
localNegotiator, flowControlWindow,
maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos,
keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(),
diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java
index c2baedf95f..06e3f8492d 100644
--- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java
+++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java
@@ -43,13 +43,13 @@ import io.grpc.internal.TransportTracer;
import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
+import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
-import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException;
import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
@@ -67,7 +67,7 @@ class NettyClientTransport implements ConnectionClientTransport {
private final InternalLogId logId;
private final Map, ?> channelOptions;
private final SocketAddress remoteAddress;
- private final Class extends Channel> channelType;
+ private final ChannelFactory extends Channel> channelFactory;
private final EventLoopGroup group;
private final ProtocolNegotiator negotiator;
private final String authorityString;
@@ -96,7 +96,7 @@ class NettyClientTransport implements ConnectionClientTransport {
private final LocalSocketPicker localSocketPicker;
NettyClientTransport(
- SocketAddress address, Class extends Channel> channelType,
+ SocketAddress address, ChannelFactory extends Channel> channelFactory,
Map, ?> channelOptions, EventLoopGroup group,
ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize,
int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos,
@@ -106,7 +106,7 @@ class NettyClientTransport implements ConnectionClientTransport {
this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator");
this.remoteAddress = Preconditions.checkNotNull(address, "address");
this.group = Preconditions.checkNotNull(group, "group");
- this.channelType = Preconditions.checkNotNull(channelType, "channelType");
+ this.channelFactory = channelFactory;
this.channelOptions = Preconditions.checkNotNull(channelOptions, "channelOptions");
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
@@ -212,10 +212,9 @@ class NettyClientTransport implements ConnectionClientTransport {
Bootstrap b = new Bootstrap();
b.group(eventLoop);
- b.channel(channelType);
- if (NioSocketChannel.class.isAssignableFrom(channelType)) {
- b.option(SO_KEEPALIVE, true);
- }
+ b.channelFactory(channelFactory);
+ // For non-socket based channel, the option will be ignored.
+ b.option(SO_KEEPALIVE, true);
for (Map.Entry, ?> entry : channelOptions.entrySet()) {
// Every entry in the map is obtained from
// NettyChannelBuilder#withOption(ChannelOption option, T value)
diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
index 3fb14aae26..a206e3954a 100644
--- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
@@ -62,8 +62,12 @@ import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
+import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
+import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption;
+import io.netty.channel.ReflectiveChannelFactory;
+import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannelConfig;
import io.netty.channel.socket.nio.NioServerSocketChannel;
@@ -178,10 +182,11 @@ public class NettyClientTransportTest {
int soLinger = 123;
channelOptions.put(ChannelOption.SO_LINGER, soLinger);
NettyClientTransport transport = new NettyClientTransport(
- address, NioSocketChannel.class, channelOptions, group, newNegotiator(),
- DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE,
- KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */,
- tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker());
+ address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group,
+ newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
+ GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority,
+ null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY,
+ new SocketPicker());
transports.add(transport);
callMeMaybe(transport.start(clientTransportListener));
@@ -418,7 +423,8 @@ public class NettyClientTransportTest {
address = TestUtils.testServerAddress(12345);
authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
NettyClientTransport transport = new NettyClientTransport(
- address, CantConstructChannel.class, new HashMap, Object>(), group,
+ address, new ReflectiveChannelFactory<>(CantConstructChannel.class),
+ new HashMap, Object>(), group,
newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority,
null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker());
@@ -465,6 +471,32 @@ public class NettyClientTransportTest {
}
}
+ @Test
+ public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception {
+ startServer();
+ NettyClientTransport transport = newTransport(newNegotiator(),
+ DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
+ new ReflectiveChannelFactory<>(NioSocketChannel.class));
+
+ callMeMaybe(transport.start(clientTransportListener));
+
+ assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
+ .isTrue();
+ }
+
+ @Test
+ public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception {
+ startServer();
+ NettyClientTransport transport = newTransport(newNegotiator(),
+ DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
+ new ReflectiveChannelFactory<>(LocalChannel.class));
+
+ callMeMaybe(transport.start(clientTransportListener));
+
+ assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
+ .isNull();
+ }
+
@Test
public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception {
startServer();
@@ -594,14 +626,21 @@ public class NettyClientTransportTest {
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
int maxHeaderListSize, String userAgent, boolean enableKeepAlive) {
+ return newTransport(negotiator, maxMsgSize, maxHeaderListSize, userAgent, enableKeepAlive,
+ new ReflectiveChannelFactory<>(NioSocketChannel.class));
+ }
+
+ private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
+ int maxHeaderListSize, String userAgent, boolean enableKeepAlive,
+ ChannelFactory extends Channel> channelFactory) {
long keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED;
long keepAliveTimeoutNano = TimeUnit.SECONDS.toNanos(1L);
if (enableKeepAlive) {
keepAliveTimeNano = TimeUnit.SECONDS.toNanos(10L);
}
NettyClientTransport transport = new NettyClientTransport(
- address, NioSocketChannel.class, new HashMap, Object>(), group, negotiator,
- DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize,
+ address, channelFactory, new HashMap, Object>(), group,
+ negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize,
keepAliveTimeNano, keepAliveTimeoutNano,
false, authority, userAgent, tooManyPingsRunnable,
new TransportTracer(), eagAttributes, new SocketPicker());