netty: add channelFactory to NettyChannelBuilder (#5312)

add channelFactory in NettyChannelBuilder & NettyClientTransport
This commit is contained in:
Jihun Cho 2019-02-04 10:51:23 -08:00 committed by GitHub
parent 3a39b81cf5
commit 71d067e8f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 24 deletions

View File

@ -17,13 +17,13 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS;
import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIME_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIME_NANOS;
import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
@ -40,8 +40,10 @@ import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.TransportTracer; import io.grpc.internal.TransportTracer;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
@ -70,7 +72,8 @@ public final class NettyChannelBuilder
private NegotiationType negotiationType = NegotiationType.TLS; private NegotiationType negotiationType = NegotiationType.TLS;
private OverrideAuthorityChecker authorityChecker; private OverrideAuthorityChecker authorityChecker;
private Class<? extends Channel> channelType = NioSocketChannel.class; private ChannelFactory<? extends Channel> channelFactory =
new ReflectiveChannelFactory<>(NioSocketChannel.class);
@Nullable @Nullable
private EventLoopGroup eventLoopGroup; private EventLoopGroup eventLoopGroup;
@ -138,9 +141,23 @@ public final class NettyChannelBuilder
/** /**
* Specifies the channel type to use, by default we use {@link NioSocketChannel}. * Specifies the channel type to use, by default we use {@link NioSocketChannel}.
*
* <p>You either use this or {@link #channelFactory(io.netty.channel.ChannelFactory)} if your
* {@link Channel} implementation has no no-args constructor.
*/ */
public NettyChannelBuilder channelType(Class<? extends Channel> channelType) { public NettyChannelBuilder channelType(Class<? extends Channel> channelType) {
this.channelType = Preconditions.checkNotNull(channelType, "channelType"); checkNotNull(channelType, "channelType");
return channelFactory(new ReflectiveChannelFactory<>(channelType));
}
/**
* Specifies the {@link ChannelFactory} to create {@link Channel} instances. This method is
* usually only used if the specific {@code Channel} requires complex logic which requires
* additional information to create the {@code Channel}. Otherwise, recommend to use {@link
* #channelType(Class)}.
*/
public NettyChannelBuilder channelFactory(ChannelFactory<? extends Channel> channelFactory) {
this.channelFactory = checkNotNull(channelFactory, "channelFactory");
return this; return this;
} }
@ -390,7 +407,7 @@ public final class NettyChannelBuilder
negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext); negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext);
} }
return new NettyTransportFactory( return new NettyTransportFactory(
negotiator, channelType, channelOptions, negotiator, channelFactory, channelOptions,
eventLoopGroup, flowControlWindow, maxInboundMessageSize(), eventLoopGroup, flowControlWindow, maxInboundMessageSize(),
maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls,
transportTracerFactory.create(), localSocketPicker); transportTracerFactory.create(), localSocketPicker);
@ -453,7 +470,7 @@ public final class NettyChannelBuilder
void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) { void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) {
this.protocolNegotiatorFactory this.protocolNegotiatorFactory
= Preconditions.checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory"); = checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory");
} }
@Override @Override
@ -496,7 +513,7 @@ public final class NettyChannelBuilder
@CheckReturnValue @CheckReturnValue
private static final class NettyTransportFactory implements ClientTransportFactory { private static final class NettyTransportFactory implements ClientTransportFactory {
private final ProtocolNegotiator protocolNegotiator; private final ProtocolNegotiator protocolNegotiator;
private final Class<? extends Channel> channelType; private final ChannelFactory<? extends Channel> channelFactory;
private final Map<ChannelOption<?>, ?> channelOptions; private final Map<ChannelOption<?>, ?> channelOptions;
private final EventLoopGroup group; private final EventLoopGroup group;
private final boolean usingSharedGroup; private final boolean usingSharedGroup;
@ -512,12 +529,12 @@ public final class NettyChannelBuilder
private boolean closed; private boolean closed;
NettyTransportFactory(ProtocolNegotiator protocolNegotiator, NettyTransportFactory(ProtocolNegotiator protocolNegotiator,
Class<? extends Channel> channelType, Map<ChannelOption<?>, ?> channelOptions, ChannelFactory<? extends Channel> channelFactory, Map<ChannelOption<?>, ?> channelOptions,
EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize,
long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls,
TransportTracer transportTracer, LocalSocketPicker localSocketPicker) { TransportTracer transportTracer, LocalSocketPicker localSocketPicker) {
this.protocolNegotiator = protocolNegotiator; this.protocolNegotiator = protocolNegotiator;
this.channelType = channelType; this.channelFactory = channelFactory;
this.channelOptions = new HashMap<ChannelOption<?>, Object>(channelOptions); this.channelOptions = new HashMap<ChannelOption<?>, Object>(channelOptions);
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -562,7 +579,7 @@ public final class NettyChannelBuilder
}; };
NettyClientTransport transport = new NettyClientTransport( NettyClientTransport transport = new NettyClientTransport(
serverAddress, channelType, channelOptions, group, serverAddress, channelFactory, channelOptions, group,
localNegotiator, flowControlWindow, localNegotiator, flowControlWindow,
maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos,
keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(),

View File

@ -43,13 +43,13 @@ import io.grpc.internal.TransportTracer;
import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException; import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException;
import io.netty.util.AsciiString; import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
@ -67,7 +67,7 @@ class NettyClientTransport implements ConnectionClientTransport {
private final InternalLogId logId; private final InternalLogId logId;
private final Map<ChannelOption<?>, ?> channelOptions; private final Map<ChannelOption<?>, ?> channelOptions;
private final SocketAddress remoteAddress; private final SocketAddress remoteAddress;
private final Class<? extends Channel> channelType; private final ChannelFactory<? extends Channel> channelFactory;
private final EventLoopGroup group; private final EventLoopGroup group;
private final ProtocolNegotiator negotiator; private final ProtocolNegotiator negotiator;
private final String authorityString; private final String authorityString;
@ -96,7 +96,7 @@ class NettyClientTransport implements ConnectionClientTransport {
private final LocalSocketPicker localSocketPicker; private final LocalSocketPicker localSocketPicker;
NettyClientTransport( NettyClientTransport(
SocketAddress address, Class<? extends Channel> channelType, SocketAddress address, ChannelFactory<? extends Channel> channelFactory,
Map<ChannelOption<?>, ?> channelOptions, EventLoopGroup group, Map<ChannelOption<?>, ?> channelOptions, EventLoopGroup group,
ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize, ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize,
int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos,
@ -106,7 +106,7 @@ class NettyClientTransport implements ConnectionClientTransport {
this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator");
this.remoteAddress = Preconditions.checkNotNull(address, "address"); this.remoteAddress = Preconditions.checkNotNull(address, "address");
this.group = Preconditions.checkNotNull(group, "group"); this.group = Preconditions.checkNotNull(group, "group");
this.channelType = Preconditions.checkNotNull(channelType, "channelType"); this.channelFactory = channelFactory;
this.channelOptions = Preconditions.checkNotNull(channelOptions, "channelOptions"); this.channelOptions = Preconditions.checkNotNull(channelOptions, "channelOptions");
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
@ -212,10 +212,9 @@ class NettyClientTransport implements ConnectionClientTransport {
Bootstrap b = new Bootstrap(); Bootstrap b = new Bootstrap();
b.group(eventLoop); b.group(eventLoop);
b.channel(channelType); b.channelFactory(channelFactory);
if (NioSocketChannel.class.isAssignableFrom(channelType)) { // For non-socket based channel, the option will be ignored.
b.option(SO_KEEPALIVE, true); b.option(SO_KEEPALIVE, true);
}
for (Map.Entry<ChannelOption<?>, ?> entry : channelOptions.entrySet()) { for (Map.Entry<ChannelOption<?>, ?> entry : channelOptions.entrySet()) {
// Every entry in the map is obtained from // Every entry in the map is obtained from
// NettyChannelBuilder#withOption(ChannelOption<T> option, T value) // NettyChannelBuilder#withOption(ChannelOption<T> option, T value)

View File

@ -62,8 +62,12 @@ import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.TransportTracer; import io.grpc.internal.TransportTracer;
import io.grpc.internal.testing.TestUtils; import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannelConfig; import io.netty.channel.socket.SocketChannelConfig;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
@ -178,10 +182,11 @@ public class NettyClientTransportTest {
int soLinger = 123; int soLinger = 123;
channelOptions.put(ChannelOption.SO_LINGER, soLinger); channelOptions.put(ChannelOption.SO_LINGER, soLinger);
NettyClientTransport transport = new NettyClientTransport( NettyClientTransport transport = new NettyClientTransport(
address, NioSocketChannel.class, channelOptions, group, newNegotiator(), address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group,
DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority,
tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY,
new SocketPicker());
transports.add(transport); transports.add(transport);
callMeMaybe(transport.start(clientTransportListener)); callMeMaybe(transport.start(clientTransportListener));
@ -418,7 +423,8 @@ public class NettyClientTransportTest {
address = TestUtils.testServerAddress(12345); address = TestUtils.testServerAddress(12345);
authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
NettyClientTransport transport = new NettyClientTransport( NettyClientTransport transport = new NettyClientTransport(
address, CantConstructChannel.class, new HashMap<ChannelOption<?>, Object>(), group, address, new ReflectiveChannelFactory<>(CantConstructChannel.class),
new HashMap<ChannelOption<?>, Object>(), group,
newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority,
null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker());
@ -465,6 +471,32 @@ public class NettyClientTransportTest {
} }
} }
@Test
public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception {
startServer();
NettyClientTransport transport = newTransport(newNegotiator(),
DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
new ReflectiveChannelFactory<>(NioSocketChannel.class));
callMeMaybe(transport.start(clientTransportListener));
assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
.isTrue();
}
@Test
public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception {
startServer();
NettyClientTransport transport = newTransport(newNegotiator(),
DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
new ReflectiveChannelFactory<>(LocalChannel.class));
callMeMaybe(transport.start(clientTransportListener));
assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
.isNull();
}
@Test @Test
public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception {
startServer(); startServer();
@ -594,14 +626,21 @@ public class NettyClientTransportTest {
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize, private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
int maxHeaderListSize, String userAgent, boolean enableKeepAlive) { int maxHeaderListSize, String userAgent, boolean enableKeepAlive) {
return newTransport(negotiator, maxMsgSize, maxHeaderListSize, userAgent, enableKeepAlive,
new ReflectiveChannelFactory<>(NioSocketChannel.class));
}
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
int maxHeaderListSize, String userAgent, boolean enableKeepAlive,
ChannelFactory<? extends Channel> channelFactory) {
long keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED; long keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED;
long keepAliveTimeoutNano = TimeUnit.SECONDS.toNanos(1L); long keepAliveTimeoutNano = TimeUnit.SECONDS.toNanos(1L);
if (enableKeepAlive) { if (enableKeepAlive) {
keepAliveTimeNano = TimeUnit.SECONDS.toNanos(10L); keepAliveTimeNano = TimeUnit.SECONDS.toNanos(10L);
} }
NettyClientTransport transport = new NettyClientTransport( NettyClientTransport transport = new NettyClientTransport(
address, NioSocketChannel.class, new HashMap<ChannelOption<?>, Object>(), group, negotiator, address, channelFactory, new HashMap<ChannelOption<?>, Object>(), group,
DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize,
keepAliveTimeNano, keepAliveTimeoutNano, keepAliveTimeNano, keepAliveTimeoutNano,
false, authority, userAgent, tooManyPingsRunnable, false, authority, userAgent, tooManyPingsRunnable,
new TransportTracer(), eagAttributes, new SocketPicker()); new TransportTracer(), eagAttributes, new SocketPicker());