diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index e903bb716d..211e1d14e2 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -144,7 +144,7 @@ public final class GoogleDefaultChannelBuilder "%s must be a InetSocketAddress", serverAddress); final GoogleDefaultProtocolNegotiator negotiator = - new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext, authority); + new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext); return new TransportCreationParamsFilter() { @Override public SocketAddress getTargetServerAddress() { diff --git a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java index aae9f34560..84288945f3 100644 --- a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java @@ -28,10 +28,9 @@ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator private final ProtocolNegotiator altsProtocolNegotiator; private final ProtocolNegotiator tlsProtocolNegotiator; - public GoogleDefaultProtocolNegotiator( - TsiHandshakerFactory altsFactory, SslContext sslContext, String authority) { + public GoogleDefaultProtocolNegotiator(TsiHandshakerFactory altsFactory, SslContext sslContext) { altsProtocolNegotiator = AltsProtocolNegotiator.create(altsFactory); - tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext, authority); + tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext); } @VisibleForTesting diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java index 7d5b1dbf27..2c127b16b1 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java @@ -84,4 +84,13 @@ public abstract class GrpcHttp2ConnectionHandler extends Http2ConnectionHandler public Attributes getEagAttributes() { return Attributes.EMPTY; } + + /** + * Returns the authority of the server. Only available on the client-side. + * + * @throws UnsupportedOperationException if on server-side + */ + public String getAuthority() { + throw new UnsupportedOperationException(); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index a0d4525e09..633cb32324 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -331,8 +331,25 @@ public final class NettyChannelBuilder @CheckReturnValue @Internal protected ClientTransportFactory buildTransportFactory() { - return new NettyTransportFactory(dynamicParamsFactory, channelType, channelOptions, - negotiationType, sslContext, eventLoopGroup, flowControlWindow, maxInboundMessageSize(), + TransportCreationParamsFilterFactory transportCreationParamsFilterFactory = + dynamicParamsFactory; + if (transportCreationParamsFilterFactory == null) { + SslContext localSslContext = sslContext; + if (negotiationType == NegotiationType.TLS && localSslContext == null) { + try { + localSslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException ex) { + throw new RuntimeException(ex); + } + } + ProtocolNegotiator negotiator = + createProtocolNegotiatorByType(negotiationType, localSslContext); + transportCreationParamsFilterFactory = + new DefaultNettyTransportCreationParamsFilterFactory(negotiator); + } + return new NettyTransportFactory( + transportCreationParamsFilterFactory, channelType, channelOptions, + eventLoopGroup, flowControlWindow, maxInboundMessageSize(), maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory.create()); } @@ -362,23 +379,7 @@ public final class NettyChannelBuilder @VisibleForTesting @CheckReturnValue - static ProtocolNegotiator createProtocolNegotiator( - String authority, - NegotiationType negotiationType, - SslContext sslContext, - ProxyParameters proxy) { - ProtocolNegotiator negotiator = - createProtocolNegotiatorByType(authority, negotiationType, sslContext); - if (proxy != null) { - negotiator = ProtocolNegotiators.httpProxy( - proxy.proxyAddress, proxy.username, proxy.password, negotiator); - } - return negotiator; - } - - @CheckReturnValue - private static ProtocolNegotiator createProtocolNegotiatorByType( - String authority, + static ProtocolNegotiator createProtocolNegotiatorByType( NegotiationType negotiationType, SslContext sslContext) { switch (negotiationType) { @@ -387,7 +388,7 @@ public final class NettyChannelBuilder case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, authority); + return ProtocolNegotiators.tls(sslContext); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } @@ -461,7 +462,6 @@ public final class NettyChannelBuilder private final TransportCreationParamsFilterFactory transportCreationParamsFilterFactory; private final Class channelType; private final Map, ?> channelOptions; - private final NegotiationType negotiationType; private final EventLoopGroup group; private final boolean usingSharedGroup; private final int flowControlWindow; @@ -476,27 +476,20 @@ public final class NettyChannelBuilder NettyTransportFactory(TransportCreationParamsFilterFactory transportCreationParamsFilterFactory, Class channelType, Map, ?> channelOptions, - NegotiationType negotiationType, SslContext sslContext, EventLoopGroup group, - int flowControlWindow, int maxMessageSize, int maxHeaderListSize, + EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, TransportTracer transportTracer) { - this.channelType = channelType; - this.negotiationType = negotiationType; - this.channelOptions = new HashMap, Object>(channelOptions); - this.transportTracer = transportTracer; - - if (transportCreationParamsFilterFactory == null) { - transportCreationParamsFilterFactory = - new DefaultNettyTransportCreationParamsFilterFactory(sslContext); - } this.transportCreationParamsFilterFactory = transportCreationParamsFilterFactory; - + this.channelType = channelType; + this.channelOptions = new HashMap, Object>(channelOptions); this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; + this.transportTracer = transportTracer; + usingSharedGroup = group == null; if (usingSharedGroup) { // The group was unspecified, using the shared group. @@ -550,71 +543,69 @@ public final class NettyChannelBuilder SharedResourceHolder.release(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP, group); } } + } - private final class DefaultNettyTransportCreationParamsFilterFactory - implements TransportCreationParamsFilterFactory { - private final SslContext sslContext; + private static final class DefaultNettyTransportCreationParamsFilterFactory + implements TransportCreationParamsFilterFactory { + final ProtocolNegotiator negotiator; - private DefaultNettyTransportCreationParamsFilterFactory(SslContext sslContext) { - if (negotiationType == NegotiationType.TLS && sslContext == null) { - try { - sslContext = GrpcSslContexts.forClient().build(); - } catch (SSLException ex) { - throw new RuntimeException(ex); - } - } - this.sslContext = sslContext; + DefaultNettyTransportCreationParamsFilterFactory(ProtocolNegotiator negotiator) { + this.negotiator = negotiator; + } + + @Override + public TransportCreationParamsFilter create( + SocketAddress targetServerAddress, + String authority, + String userAgent, + ProxyParameters proxyParams) { + ProtocolNegotiator localNegotiator = negotiator; + if (proxyParams != null) { + localNegotiator = ProtocolNegotiators.httpProxy( + proxyParams.proxyAddress, proxyParams.username, proxyParams.password, negotiator); } + return new DynamicNettyTransportParams( + targetServerAddress, authority, userAgent, localNegotiator); + } + } - @Override - public TransportCreationParamsFilter create( - SocketAddress targetServerAddress, - String authority, - String userAgent, - ProxyParameters proxyParams) { - return new DynamicNettyTransportParams( - targetServerAddress, authority, userAgent, proxyParams); - } + @CheckReturnValue + private static final class DynamicNettyTransportParams implements TransportCreationParamsFilter { - @CheckReturnValue - private final class DynamicNettyTransportParams implements TransportCreationParamsFilter { + private final SocketAddress targetServerAddress; + private final String authority; + @Nullable private final String userAgent; + private final ProtocolNegotiator protocolNegotiator; - private final SocketAddress targetServerAddress; - private final String authority; - @Nullable private final String userAgent; - private ProxyParameters proxyParams; + private DynamicNettyTransportParams( + SocketAddress targetServerAddress, + String authority, + String userAgent, + ProtocolNegotiator protocolNegotiator) { + this.targetServerAddress = targetServerAddress; + this.authority = authority; + this.userAgent = userAgent; + this.protocolNegotiator = protocolNegotiator; + } - private DynamicNettyTransportParams( - SocketAddress targetServerAddress, - String authority, - String userAgent, - ProxyParameters proxyParams) { - this.targetServerAddress = targetServerAddress; - this.authority = authority; - this.userAgent = userAgent; - this.proxyParams = proxyParams; - } + @Override + public SocketAddress getTargetServerAddress() { + return targetServerAddress; + } - @Override - public SocketAddress getTargetServerAddress() { - return targetServerAddress; - } + @Override + public String getAuthority() { + return authority; + } - @Override - public String getAuthority() { - return authority; - } + @Override + public String getUserAgent() { + return userAgent; + } - @Override - public String getUserAgent() { - return userAgent; - } - - @Override - public ProtocolNegotiator getProtocolNegotiator() { - return createProtocolNegotiator(authority, negotiationType, sslContext, proxyParams); - } - } + @Override + public ProtocolNegotiator getProtocolNegotiator() { + return protocolNegotiator; } } } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 1b77855a53..e059af96d5 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -103,6 +103,7 @@ class NettyClientHandler extends AbstractNettyHandler { private final Supplier stopwatchFactory; private final TransportTracer transportTracer; private final Attributes eagAttributes; + private final String authority; private WriteQueue clientWriteQueue; private Http2Ping ping; private Attributes attributes = Attributes.EMPTY; @@ -116,7 +117,8 @@ class NettyClientHandler extends AbstractNettyHandler { Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, - Attributes eagAttributes) { + Attributes eagAttributes, + String authority) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -139,7 +141,8 @@ class NettyClientHandler extends AbstractNettyHandler { stopwatchFactory, tooManyPingsRunnable, transportTracer, - eagAttributes); + eagAttributes, + authority); } @VisibleForTesting @@ -154,7 +157,8 @@ class NettyClientHandler extends AbstractNettyHandler { Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, - Attributes eagAttributes) { + Attributes eagAttributes, + String authority) { Preconditions.checkNotNull(connection, "connection"); Preconditions.checkNotNull(frameReader, "frameReader"); Preconditions.checkNotNull(lifecycleManager, "lifecycleManager"); @@ -163,6 +167,7 @@ class NettyClientHandler extends AbstractNettyHandler { Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); Preconditions.checkNotNull(eagAttributes, "eagAttributes"); + Preconditions.checkNotNull(authority, "authority"); Http2FrameLogger frameLogger = new Http2FrameLogger(LogLevel.DEBUG, NettyClientHandler.class); frameReader = new Http2InboundFrameLogger(frameReader, frameLogger); @@ -205,7 +210,8 @@ class NettyClientHandler extends AbstractNettyHandler { stopwatchFactory, tooManyPingsRunnable, transportTracer, - eagAttributes); + eagAttributes, + authority); } private NettyClientHandler( @@ -217,13 +223,15 @@ class NettyClientHandler extends AbstractNettyHandler { Supplier stopwatchFactory, final Runnable tooManyPingsRunnable, TransportTracer transportTracer, - Attributes eagAttributes) { + Attributes eagAttributes, + String authority) { super(/* channelUnused= */ null, decoder, encoder, settings); this.lifecycleManager = lifecycleManager; this.keepAliveManager = keepAliveManager; this.stopwatchFactory = stopwatchFactory; this.transportTracer = Preconditions.checkNotNull(transportTracer); this.eagAttributes = eagAttributes; + this.authority = authority; // Set the frame listener on the decoder. decoder().frameListener(new FrameListener()); @@ -429,6 +437,11 @@ class NettyClientHandler extends AbstractNettyHandler { return eagAttributes; } + @Override + public String getAuthority() { + return authority; + } + InternalChannelz.Security getSecurityInfo() { return securityInfo; } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index 57c4efda27..6141db2fe7 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -68,6 +68,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final Class channelType; private final EventLoopGroup group; private final ProtocolNegotiator negotiator; + private final String authorityString; private final AsciiString authority; private final AsciiString userAgent; private final int flowControlWindow; @@ -109,6 +110,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.keepAliveTimeNanos = keepAliveTimeNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; + this.authorityString = authority; this.authority = new AsciiString(authority); this.userAgent = new AsciiString(GrpcUtil.getGrpcUserAgent("netty", userAgent)); this.tooManyPingsRunnable = @@ -195,7 +197,8 @@ class NettyClientTransport implements ConnectionClientTransport { GrpcUtil.STOPWATCH_SUPPLIER, tooManyPingsRunnable, transportTracer, - eagAttributes); + eagAttributes, + authorityString); NettyHandlerSettings.setAutoWindow(handler); negotiationHandler = negotiator.newHandler(handler); diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 356cae0ea1..3a13207cb4 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -259,57 +259,49 @@ public final class ProtocolNegotiators { * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static ProtocolNegotiator tls(SslContext sslContext, String authority) { - Preconditions.checkNotNull(sslContext, "sslContext"); - URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority")); - String host; - int port; - if (uri.getHost() != null) { - host = uri.getHost(); - port = uri.getPort(); - } else { - /* - * Implementation note: We pick -1 as the port here rather than deriving it from the original - * socket address. The SSL engine doens't use this port number when contacting the remote - * server, but rather it is used for other things like SSL Session caching. When an invalid - * authority is provided (like "bad_cert"), picking the original port and passing it in would - * mean that the port might used under the assumption that it was correct. By using -1 here, - * it forces the SSL implementation to treat it as invalid. - */ - host = authority; - port = -1; - } - - return new TlsNegotiator(sslContext, host, port); + public static ProtocolNegotiator tls(SslContext sslContext) { + return new TlsNegotiator(sslContext); } + @VisibleForTesting static final class TlsNegotiator implements ProtocolNegotiator { private final SslContext sslContext; - private final String host; - private final int port; - TlsNegotiator(SslContext sslContext, String host, int port) { + TlsNegotiator(SslContext sslContext) { this.sslContext = checkNotNull(sslContext, "sslContext"); - this.host = checkNotNull(host, "host"); - this.port = port; } @VisibleForTesting - String getHost() { - return host; - } - - @VisibleForTesting - int getPort() { - return port; + HostPort parseAuthority(String authority) { + URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority")); + String host; + int port; + if (uri.getHost() != null) { + host = uri.getHost(); + port = uri.getPort(); + } else { + /* + * Implementation note: We pick -1 as the port here rather than deriving it from the + * original socket address. The SSL engine doens't use this port number when contacting the + * remote server, but rather it is used for other things like SSL Session caching. When an + * invalid authority is provided (like "bad_cert"), picking the original port and passing it + * in would mean that the port might used under the assumption that it was correct. By + * using -1 here, it forces the SSL implementation to treat it as invalid. + */ + host = authority; + port = -1; + } + return new HostPort(host, port); } @Override public Handler newHandler(GrpcHttp2ConnectionHandler handler) { + final HostPort hostPort = parseAuthority(handler.getAuthority()); + ChannelHandler sslBootstrap = new ChannelHandlerAdapter() { @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), hostPort.host, hostPort.port); SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -320,6 +312,18 @@ public final class ProtocolNegotiators { } } + /** A tuple of (host, port). */ + @VisibleForTesting + static final class HostPort { + final String host; + final int port; + + public HostPort(String host, int port) { + this.host = host; + this.port = port; + } + } + /** * Returns a {@link ProtocolNegotiator} used for upgrading to HTTP/2 from HTTP/1.x. */ diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index ad7b3ec0f0..2d653a4ca9 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -139,65 +139,57 @@ public class NettyChannelBuilderTest { } @Test - public void createProtocolNegotiator_plaintext() { - ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiator( - "authority", + public void createProtocolNegotiatorByType_plaintext() { + ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType( NegotiationType.PLAINTEXT, - noSslContext, - noProxy); + noSslContext); // just check that the classes are the same, and that negotiator is not null. assertTrue(negotiator instanceof ProtocolNegotiators.PlaintextNegotiator); } @Test - public void createProtocolNegotiator_plaintextUpgrade() { - ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiator( - "authority", + public void createProtocolNegotiatorByType_plaintextUpgrade() { + ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType( NegotiationType.PLAINTEXT_UPGRADE, - noSslContext, - noProxy); + noSslContext); // just check that the classes are the same, and that negotiator is not null. assertTrue(negotiator instanceof ProtocolNegotiators.PlaintextUpgradeNegotiator); } @Test - public void createProtocolNegotiator_tlsWithNoContext() { + public void createProtocolNegotiatorByType_tlsWithNoContext() { thrown.expect(NullPointerException.class); - NettyChannelBuilder.createProtocolNegotiator( - "authority:1234", + NettyChannelBuilder.createProtocolNegotiatorByType( NegotiationType.TLS, - noSslContext, - noProxy); + noSslContext); } @Test - public void createProtocolNegotiator_tlsWithClientContext() throws SSLException { - ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiator( - "authority:1234", + public void createProtocolNegotiatorByType_tlsWithClientContext() throws SSLException { + ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType( NegotiationType.TLS, - GrpcSslContexts.forClient().build(), - noProxy); + GrpcSslContexts.forClient().build()); assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator); ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator; + ProtocolNegotiators.HostPort hostPort = n.parseAuthority("authority:1234"); - assertEquals("authority", n.getHost()); - assertEquals(1234, n.getPort()); + assertEquals("authority", hostPort.host); + assertEquals(1234, hostPort.port); } @Test - public void createProtocolNegotiator_tlsWithAuthorityFallback() throws SSLException { - ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiator( - "bad_authority", + public void createProtocolNegotiatorByType_tlsWithAuthorityFallback() throws SSLException { + ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType( NegotiationType.TLS, - GrpcSslContexts.forClient().build(), - noProxy); + GrpcSslContexts.forClient().build()); assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator); ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator; + ProtocolNegotiators.HostPort hostPort = n.parseAuthority("bad_authority"); - assertEquals("bad_authority", n.getHost()); - assertEquals(-1, n.getPort()); + assertEquals("bad_authority", hostPort.host); + assertEquals(-1, hostPort.port); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 8b6540bff1..1e02d27ec5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -722,7 +722,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase