diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index 20f6fa1b6f..ac9b3301ee 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -18,14 +18,17 @@ package io.grpc.netty.shaded; import static com.google.common.truth.Truth.assertThat; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.netty.shaded.io.grpc.netty.NettySslContextChannelCredentials; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; import io.grpc.stub.StreamObserver; @@ -67,7 +70,7 @@ public final class ShadingTest { @Test public void serviceLoaderFindsNetty() throws Exception { assertThat(ServerBuilder.forPort(0)).isInstanceOf(NettyServerBuilder.class); - assertThat(ManagedChannelBuilder.forAddress("localhost", 1234)) + assertThat(Grpc.newChannelBuilder("localhost:1234", InsecureChannelCredentials.create())) .isInstanceOf(NettyChannelBuilder.class); } @@ -76,9 +79,8 @@ public final class ShadingTest { server = ServerBuilder.forPort(0) .addService(new SimpleServiceImpl()) .build().start(); - channel = ManagedChannelBuilder - .forAddress("localhost", server.getPort()) - .usePlaintext() + channel = Grpc.newChannelBuilder( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()) .build(); SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel); assertThat(SimpleResponse.getDefaultInstance()) @@ -91,11 +93,10 @@ public final class ShadingTest { .useTransportSecurity(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) .addService(new SimpleServiceImpl()) .build().start(); - channel = NettyChannelBuilder - .forAddress("localhost", server.getPort()) - .sslContext( - GrpcSslContexts.configure(SslContextBuilder.forClient(), SslProvider.OPENSSL) - .trustManager(TestUtils.loadCert("ca.pem")).build()) + ChannelCredentials creds = NettySslContextChannelCredentials.create( + GrpcSslContexts.configure(SslContextBuilder.forClient(), SslProvider.OPENSSL) + .trustManager(TestUtils.loadCert("ca.pem")).build()); + channel = Grpc.newChannelBuilder("localhost:" + server.getPort(), creds) .overrideAuthority("foo.test.google.fr") .build(); SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel); diff --git a/netty/src/main/java/io/grpc/netty/InsecureFromHttp1ChannelCredentials.java b/netty/src/main/java/io/grpc/netty/InsecureFromHttp1ChannelCredentials.java new file mode 100644 index 0000000000..658f0f1183 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/InsecureFromHttp1ChannelCredentials.java @@ -0,0 +1,31 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; + +/** An insecure credential that upgrades from HTTP/1 to HTTP/2. */ +@ExperimentalApi("There is no plan to make this API stable, given transport API instability") +public final class InsecureFromHttp1ChannelCredentials { + private InsecureFromHttp1ChannelCredentials() {} + + /** Creates an insecure credential that will upgrade from HTTP/1 to HTTP/2. */ + public static ChannelCredentials create() { + return NettyChannelCredentials.create(ProtocolNegotiators.plaintextUpgradeClientFactory()); + } +} diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java index ff4d4074ec..363e0c8ef5 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java @@ -18,6 +18,7 @@ package io.grpc.netty; import io.grpc.Internal; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourcePool; import io.netty.channel.socket.nio.NioSocketChannel; @@ -37,10 +38,7 @@ public final class InternalNettyChannelBuilder { } /** A class that provides a Netty handler to control protocol negotiation. */ - public interface ProtocolNegotiatorFactory - extends NettyChannelBuilder.ProtocolNegotiatorFactory { - - @Override + public interface ProtocolNegotiatorFactory { InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator(); } @@ -49,7 +47,24 @@ public final class InternalNettyChannelBuilder { * and {@code SslContext}. */ public static void setProtocolNegotiatorFactory( - NettyChannelBuilder builder, ProtocolNegotiatorFactory protocolNegotiator) { + NettyChannelBuilder builder, final ProtocolNegotiatorFactory protocolNegotiator) { + builder.protocolNegotiatorFactory(new ProtocolNegotiator.ClientFactory() { + @Override public ProtocolNegotiator newNegotiator() { + return protocolNegotiator.buildProtocolNegotiator(); + } + + @Override public int getDefaultPort() { + return GrpcUtil.DEFAULT_PORT_SSL; + } + }); + } + + /** + * Sets the {@link ProtocolNegotiatorFactory} to be used. Overrides any specified negotiation type + * and {@code SslContext}. + */ + public static void setProtocolNegotiatorFactory( + NettyChannelBuilder builder, InternalProtocolNegotiator.ClientFactory protocolNegotiator) { builder.protocolNegotiatorFactory(protocolNegotiator); } diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java new file mode 100644 index 0000000000..81051a1833 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java @@ -0,0 +1,34 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.ChannelCredentials; +import io.grpc.Internal; + +/** + * Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the + * gRPC team. If you *really* think you need to use this, contact the gRPC team first. + */ +@Internal +public final class InternalNettyChannelCredentials { + private InternalNettyChannelCredentials() {} + + /** Creates a {@link ChannelCredentials} that will use the provided {@code negotiator}. */ + public static ChannelCredentials create(InternalProtocolNegotiator.ClientFactory negotiator) { + return NettyChannelCredentials.create(negotiator); + } +} diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java index a6a8335b2e..0efa85eea7 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java @@ -16,12 +16,19 @@ package io.grpc.netty; +import io.grpc.Internal; + /** * Internal accessor for {@link ProtocolNegotiator}. */ +@Internal public final class InternalProtocolNegotiator { private InternalProtocolNegotiator() {} public interface ProtocolNegotiator extends io.grpc.netty.ProtocolNegotiator {} + + public interface ClientFactory extends io.grpc.netty.ProtocolNegotiator.ClientFactory { + @Override ProtocolNegotiator newNegotiator(); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 9757cbb3f9..009629620f 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -25,6 +25,8 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; @@ -91,10 +93,8 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder, Object> channelOptions = new HashMap<>(); - private NegotiationType negotiationType = NegotiationType.TLS; private ChannelFactory channelFactory = DEFAULT_CHANNEL_FACTORY; private ObjectPool eventLoopGroupPool = DEFAULT_EVENT_LOOP_GROUP_POOL; - private SslContext sslContext; private boolean autoFlowControl = DEFAULT_AUTO_FLOW_CONTROL; private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; private int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; @@ -102,7 +102,9 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilderDefault: TLS */ public NettyChannelBuilder negotiationType(NegotiationType type) { - negotiationType = type; + checkState(!freezeProtocolNegotiatorFactory, + "Cannot change security when using ChannelCredentials"); + if (!(protocolNegotiatorFactory instanceof DefaultProtocolNegotiator)) { + // Do nothing for compatibility + return this; + } + ((DefaultProtocolNegotiator) protocolNegotiatorFactory).negotiationType = type; return this; } @@ -276,12 +314,18 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder understoodTlsFeatures = + EnumSet.noneOf(TlsChannelCredentials.Feature.class); + private ProtocolNegotiators() { } + public static FromChannelCredentialsResult from(ChannelCredentials creds) { + if (creds instanceof TlsChannelCredentials) { + TlsChannelCredentials tlsCreds = (TlsChannelCredentials) creds; + Set incomprehensible = + tlsCreds.incomprehensible(understoodTlsFeatures); + if (!incomprehensible.isEmpty()) { + return FromChannelCredentialsResult.error( + "TLS features not understood: " + incomprehensible); + } + return FromChannelCredentialsResult.negotiator(tlsClientFactory(null)); + + } else if (creds instanceof InsecureChannelCredentials) { + return FromChannelCredentialsResult.negotiator(plaintextClientFactory()); + + } else if (creds instanceof CompositeChannelCredentials) { + CompositeChannelCredentials compCreds = (CompositeChannelCredentials) creds; + return from(compCreds.getChannelCredentials()) + .withCallCredentials(compCreds.getCallCredentials()); + + } else if (creds instanceof NettyChannelCredentials) { + NettyChannelCredentials nettyCreds = (NettyChannelCredentials) creds; + return FromChannelCredentialsResult.negotiator(nettyCreds.getNegotiator()); + + } else if (creds instanceof ChoiceChannelCredentials) { + ChoiceChannelCredentials choiceCreds = (ChoiceChannelCredentials) creds; + StringBuilder error = new StringBuilder(); + for (ChannelCredentials innerCreds : choiceCreds.getCredentialsList()) { + FromChannelCredentialsResult result = from(innerCreds); + if (result.error == null) { + return result; + } + error.append(", "); + error.append(result.error); + } + return FromChannelCredentialsResult.error(error.substring(2)); + + } else { + return FromChannelCredentialsResult.error( + "Unsupported credential type: " + creds.getClass().getName()); + } + } + + public static final class FromChannelCredentialsResult { + public final ProtocolNegotiator.ClientFactory negotiator; + public final CallCredentials callCredentials; + public final String error; + + private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, + CallCredentials creds, String error) { + this.negotiator = negotiator; + this.callCredentials = creds; + this.error = error; + } + + public static FromChannelCredentialsResult error(String error) { + return new FromChannelCredentialsResult( + null, null, Preconditions.checkNotNull(error, "error")); + } + + public static FromChannelCredentialsResult negotiator( + ProtocolNegotiator.ClientFactory factory) { + return new FromChannelCredentialsResult( + Preconditions.checkNotNull(factory, "factory"), null, null); + } + + public FromChannelCredentialsResult withCallCredentials(CallCredentials callCreds) { + Preconditions.checkNotNull(callCreds, "callCreds"); + if (error != null) { + return this; + } + if (this.callCredentials != null) { + callCreds = new CompositeCallCredentials(this.callCredentials, callCreds); + } + return new FromChannelCredentialsResult(negotiator, callCreds, null); + } + } + static ChannelLogger negotiationLogger(ChannelHandlerContext ctx) { return negotiationLogger(ctx.channel()); } @@ -446,6 +536,36 @@ final class ProtocolNegotiators { return tls(sslContext, null); } + public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { + return new TlsProtocolNegotiatorClientFactory(sslContext); + } + + @VisibleForTesting + static final class TlsProtocolNegotiatorClientFactory + implements ProtocolNegotiator.ClientFactory { + private final SslContext sslContext; + + public TlsProtocolNegotiatorClientFactory(SslContext sslContext) { + this.sslContext = sslContext; + } + + @Override public ProtocolNegotiator newNegotiator() { + SslContext sslContext = this.sslContext; + if (sslContext == null) { + try { + sslContext = GrpcSslContexts.forClient().build(); + } catch (SSLException ex) { + throw new RuntimeException(ex); + } + } + return tls(sslContext); + } + + @Override public int getDefaultPort() { + return GrpcUtil.DEFAULT_PORT_SSL; + } + } + /** A tuple of (host, port). */ @VisibleForTesting static final class HostPort { @@ -465,6 +585,21 @@ final class ProtocolNegotiators { return new PlaintextUpgradeProtocolNegotiator(); } + public static ProtocolNegotiator.ClientFactory plaintextUpgradeClientFactory() { + return new PlaintextUpgradeProtocolNegotiatorClientFactory(); + } + + private static final class PlaintextUpgradeProtocolNegotiatorClientFactory + implements ProtocolNegotiator.ClientFactory { + @Override public ProtocolNegotiator newNegotiator() { + return plaintextUpgrade(); + } + + @Override public int getDefaultPort() { + return GrpcUtil.DEFAULT_PORT_PLAINTEXT; + } + } + static final class PlaintextUpgradeProtocolNegotiator implements ProtocolNegotiator { @Override @@ -548,6 +683,22 @@ final class ProtocolNegotiators { return new PlaintextProtocolNegotiator(); } + public static ProtocolNegotiator.ClientFactory plaintextClientFactory() { + return new PlaintextProtocolNegotiatorClientFactory(); + } + + @VisibleForTesting + static final class PlaintextProtocolNegotiatorClientFactory + implements ProtocolNegotiator.ClientFactory { + @Override public ProtocolNegotiator newNegotiator() { + return plaintext(); + } + + @Override public int getDefaultPort() { + return GrpcUtil.DEFAULT_PORT_PLAINTEXT; + } + } + private static RuntimeException unavailableException(String msg) { return Status.UNAVAILABLE.withDescription(msg).asRuntimeException(); } diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java index 40bfeb2802..86c1389f00 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java @@ -16,6 +16,7 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -23,7 +24,9 @@ import static org.junit.Assert.fail; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; +import io.grpc.ManagedChannelProvider.NewChannelBuilderResult; import io.grpc.ManagedChannelRegistryAccessor; +import io.grpc.TlsChannelCredentials; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -65,4 +68,23 @@ public class NettyChannelProviderTest { public void builderIsANettyBuilder() { assertSame(NettyChannelBuilder.class, provider.builderForAddress("localhost", 443).getClass()); } + + @Test + public void builderForTarget() { + assertThat(provider.builderForTarget("localhost:443")).isInstanceOf(NettyChannelBuilder.class); + } + + @Test + public void newChannelBuilder_success() { + NewChannelBuilderResult result = + provider.newChannelBuilder("localhost:443", TlsChannelCredentials.create()); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + } + + @Test + public void newChannelBuilder_fail() { + NewChannelBuilderResult result = provider.newChannelBuilder("localhost:443", + TlsChannelCredentials.newBuilder().requireFakeFeature().build()); + assertThat(result.getError()).contains("FAKE"); + } } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index d83a5fabaf..2e87a089e8 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -30,11 +30,17 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; +import io.grpc.ChoiceChannelCredentials; +import io.grpc.CompositeChannelCredentials; import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.InternalChannelz.Security; import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; @@ -161,6 +167,105 @@ public class ProtocolNegotiatorsTest { group.shutdownGracefully(); } + @Test + public void from_unknown() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(new ChannelCredentials() {}); + assertThat(result.error).isNotNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator).isNull(); + } + + @Test + public void from_tls() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.create()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + } + + @Test + public void from_unspportedTls() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder().requireFakeFeature().build()); + assertThat(result.error).contains("FAKE"); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator).isNull(); + } + + @Test + public void from_insecure() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(InsecureChannelCredentials.create()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); + } + + @Test + public void from_composite() { + CallCredentials callCredentials = mock(CallCredentials.class); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(CompositeChannelCredentials.create( + TlsChannelCredentials.create(), callCredentials)); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isSameInstanceAs(callCredentials); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + + result = ProtocolNegotiators.from(CompositeChannelCredentials.create( + InsecureChannelCredentials.create(), callCredentials)); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isSameInstanceAs(callCredentials); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); + } + + @Test + public void from_netty() { + ProtocolNegotiator.ClientFactory factory = mock(ProtocolNegotiator.ClientFactory.class); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(NettyChannelCredentials.create(factory)); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator).isSameInstanceAs(factory); + } + + @Test + public void from_choice() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(ChoiceChannelCredentials.create( + new ChannelCredentials() {}, + TlsChannelCredentials.create(), + InsecureChannelCredentials.create())); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + + result = ProtocolNegotiators.from(ChoiceChannelCredentials.create( + InsecureChannelCredentials.create(), + new ChannelCredentials() {}, + TlsChannelCredentials.create())); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); + } + + @Test + public void from_choice_unknown() { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(ChoiceChannelCredentials.create( + new ChannelCredentials() {})); + assertThat(result.error).isNotNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator).isNull(); + } + @Test public void waitUntilActiveHandler_handlerAdded() throws Exception { final CountDownLatch latch = new CountDownLatch(1);