netty: Add ChannelCredentials

This commit is contained in:
Eric Anderson 2020-07-31 16:54:38 -07:00 committed by Eric Anderson
parent 5733cd481a
commit 1ffde15471
13 changed files with 560 additions and 56 deletions

View File

@ -18,14 +18,17 @@ package io.grpc.netty.shaded;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettySslContextChannelCredentials;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
@ -67,7 +70,7 @@ public final class ShadingTest {
@Test
public void serviceLoaderFindsNetty() throws Exception {
assertThat(ServerBuilder.forPort(0)).isInstanceOf(NettyServerBuilder.class);
assertThat(ManagedChannelBuilder.forAddress("localhost", 1234))
assertThat(Grpc.newChannelBuilder("localhost:1234", InsecureChannelCredentials.create()))
.isInstanceOf(NettyChannelBuilder.class);
}
@ -76,9 +79,8 @@ public final class ShadingTest {
server = ServerBuilder.forPort(0)
.addService(new SimpleServiceImpl())
.build().start();
channel = ManagedChannelBuilder
.forAddress("localhost", server.getPort())
.usePlaintext()
channel = Grpc.newChannelBuilder(
"localhost:" + server.getPort(), InsecureChannelCredentials.create())
.build();
SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);
assertThat(SimpleResponse.getDefaultInstance())
@ -91,11 +93,10 @@ public final class ShadingTest {
.useTransportSecurity(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.addService(new SimpleServiceImpl())
.build().start();
channel = NettyChannelBuilder
.forAddress("localhost", server.getPort())
.sslContext(
GrpcSslContexts.configure(SslContextBuilder.forClient(), SslProvider.OPENSSL)
.trustManager(TestUtils.loadCert("ca.pem")).build())
ChannelCredentials creds = NettySslContextChannelCredentials.create(
GrpcSslContexts.configure(SslContextBuilder.forClient(), SslProvider.OPENSSL)
.trustManager(TestUtils.loadCert("ca.pem")).build());
channel = Grpc.newChannelBuilder("localhost:" + server.getPort(), creds)
.overrideAuthority("foo.test.google.fr")
.build();
SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);

View File

@ -0,0 +1,31 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
/** An insecure credential that upgrades from HTTP/1 to HTTP/2. */
@ExperimentalApi("There is no plan to make this API stable, given transport API instability")
public final class InsecureFromHttp1ChannelCredentials {
private InsecureFromHttp1ChannelCredentials() {}
/** Creates an insecure credential that will upgrade from HTTP/1 to HTTP/2. */
public static ChannelCredentials create() {
return NettyChannelCredentials.create(ProtocolNegotiators.plaintextUpgradeClientFactory());
}
}

View File

@ -18,6 +18,7 @@ package io.grpc.netty;
import io.grpc.Internal;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourcePool;
import io.netty.channel.socket.nio.NioSocketChannel;
@ -37,10 +38,7 @@ public final class InternalNettyChannelBuilder {
}
/** A class that provides a Netty handler to control protocol negotiation. */
public interface ProtocolNegotiatorFactory
extends NettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public interface ProtocolNegotiatorFactory {
InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator();
}
@ -49,7 +47,24 @@ public final class InternalNettyChannelBuilder {
* and {@code SslContext}.
*/
public static void setProtocolNegotiatorFactory(
NettyChannelBuilder builder, ProtocolNegotiatorFactory protocolNegotiator) {
NettyChannelBuilder builder, final ProtocolNegotiatorFactory protocolNegotiator) {
builder.protocolNegotiatorFactory(new ProtocolNegotiator.ClientFactory() {
@Override public ProtocolNegotiator newNegotiator() {
return protocolNegotiator.buildProtocolNegotiator();
}
@Override public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_SSL;
}
});
}
/**
* Sets the {@link ProtocolNegotiatorFactory} to be used. Overrides any specified negotiation type
* and {@code SslContext}.
*/
public static void setProtocolNegotiatorFactory(
NettyChannelBuilder builder, InternalProtocolNegotiator.ClientFactory protocolNegotiator) {
builder.protocolNegotiatorFactory(protocolNegotiator);
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import io.grpc.ChannelCredentials;
import io.grpc.Internal;
/**
* Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the
* gRPC team. If you *really* think you need to use this, contact the gRPC team first.
*/
@Internal
public final class InternalNettyChannelCredentials {
private InternalNettyChannelCredentials() {}
/** Creates a {@link ChannelCredentials} that will use the provided {@code negotiator}. */
public static ChannelCredentials create(InternalProtocolNegotiator.ClientFactory negotiator) {
return NettyChannelCredentials.create(negotiator);
}
}

View File

@ -16,12 +16,19 @@
package io.grpc.netty;
import io.grpc.Internal;
/**
* Internal accessor for {@link ProtocolNegotiator}.
*/
@Internal
public final class InternalProtocolNegotiator {
private InternalProtocolNegotiator() {}
public interface ProtocolNegotiator extends io.grpc.netty.ProtocolNegotiator {}
public interface ClientFactory extends io.grpc.netty.ProtocolNegotiator.ClientFactory {
@Override ProtocolNegotiator newNegotiator();
}
}

View File

@ -25,6 +25,8 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED;
import com.google.common.annotations.VisibleForTesting;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
@ -91,10 +93,8 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
private final ManagedChannelImplBuilder managedChannelImplBuilder;
private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory();
private final Map<ChannelOption<?>, Object> channelOptions = new HashMap<>();
private NegotiationType negotiationType = NegotiationType.TLS;
private ChannelFactory<? extends Channel> channelFactory = DEFAULT_CHANNEL_FACTORY;
private ObjectPool<? extends EventLoopGroup> eventLoopGroupPool = DEFAULT_EVENT_LOOP_GROUP_POOL;
private SslContext sslContext;
private boolean autoFlowControl = DEFAULT_AUTO_FLOW_CONTROL;
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
private int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
@ -102,7 +102,9 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
private long keepAliveTimeNanos = KEEPALIVE_TIME_NANOS_DISABLED;
private long keepAliveTimeoutNanos = DEFAULT_KEEPALIVE_TIMEOUT_NANOS;
private boolean keepAliveWithoutCalls;
private ProtocolNegotiatorFactory protocolNegotiatorFactory;
private ProtocolNegotiator.ClientFactory protocolNegotiatorFactory
= new DefaultProtocolNegotiator();
private final boolean freezeProtocolNegotiatorFactory;
private LocalSocketPicker localSocketPicker;
/**
@ -128,7 +130,15 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
*/
@CheckReturnValue
public static NettyChannelBuilder forAddress(String host, int port) {
return new NettyChannelBuilder(host, port);
return forTarget(GrpcUtil.authorityFromHostAndPort(host, port));
}
/**
* Creates a new builder with the given host and port.
*/
@CheckReturnValue
public static NettyChannelBuilder forAddress(String host, int port, ChannelCredentials creds) {
return forTarget(GrpcUtil.authorityFromHostAndPort(host, port), creds);
}
/**
@ -140,9 +150,18 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
return new NettyChannelBuilder(target);
}
/**
* Creates a new builder with the given target string that will be resolved by
* {@link io.grpc.NameResolver}.
*/
@CheckReturnValue
NettyChannelBuilder(String host, int port) {
this(GrpcUtil.authorityFromHostAndPort(host, port));
public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) {
ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds);
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
return new NettyChannelBuilder(
target, result.negotiator, result.callCredentials);
}
private final class NettyChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder {
@ -155,7 +174,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
private final class NettyChannelDefaultPortProvider implements ChannelBuilderDefaultPortProvider {
@Override
public int getDefaultPort() {
return NettyChannelBuilder.this.getDefaultPort();
return protocolNegotiatorFactory.getDefaultPort();
}
}
@ -164,6 +183,18 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
managedChannelImplBuilder = new ManagedChannelImplBuilder(target,
new NettyChannelTransportFactoryBuilder(),
new NettyChannelDefaultPortProvider());
this.freezeProtocolNegotiatorFactory = false;
}
@CheckReturnValue
NettyChannelBuilder(
String target, ProtocolNegotiator.ClientFactory negotiator,
@Nullable CallCredentials callCreds) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCreds,
new NettyChannelTransportFactoryBuilder(),
new NettyChannelDefaultPortProvider());
this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator");
this.freezeProtocolNegotiatorFactory = true;
}
@CheckReturnValue
@ -172,6 +203,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
getAuthorityFromAddress(address),
new NettyChannelTransportFactoryBuilder(),
new NettyChannelDefaultPortProvider());
this.freezeProtocolNegotiatorFactory = false;
}
@Internal
@ -242,7 +274,13 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
* <p>Default: <code>TLS</code>
*/
public NettyChannelBuilder negotiationType(NegotiationType type) {
negotiationType = type;
checkState(!freezeProtocolNegotiatorFactory,
"Cannot change security when using ChannelCredentials");
if (!(protocolNegotiatorFactory instanceof DefaultProtocolNegotiator)) {
// Do nothing for compatibility
return this;
}
((DefaultProtocolNegotiator) protocolNegotiatorFactory).negotiationType = type;
return this;
}
@ -276,12 +314,18 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
* GrpcSslContexts}, but options could have been overridden.
*/
public NettyChannelBuilder sslContext(SslContext sslContext) {
checkState(!freezeProtocolNegotiatorFactory,
"Cannot change security when using ChannelCredentials");
if (sslContext != null) {
checkArgument(sslContext.isClient(),
"Server SSL context can not be used for client channel");
GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator());
}
this.sslContext = sslContext;
if (!(protocolNegotiatorFactory instanceof DefaultProtocolNegotiator)) {
// Do nothing for compatibility
return this;
}
((DefaultProtocolNegotiator) protocolNegotiatorFactory).sslContext = sslContext;
return this;
}
@ -452,22 +496,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
ClientTransportFactory buildTransportFactory() {
assertEventLoopAndChannelType();
ProtocolNegotiator negotiator;
if (protocolNegotiatorFactory != null) {
negotiator = protocolNegotiatorFactory.buildProtocolNegotiator();
} else {
SslContext localSslContext = sslContext;
if (negotiationType == NegotiationType.TLS && localSslContext == null) {
try {
localSslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
}
negotiator = createProtocolNegotiatorByType(negotiationType, localSslContext,
this.managedChannelImplBuilder.getOffloadExecutorPool());
}
ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator();
return new NettyTransportFactory(
negotiator, channelFactory, channelOptions,
eventLoopGroupPool, autoFlowControl, flowControlWindow, maxInboundMessageSize,
@ -488,15 +517,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
@CheckReturnValue
int getDefaultPort() {
switch (negotiationType) {
case PLAINTEXT:
case PLAINTEXT_UPGRADE:
return GrpcUtil.DEFAULT_PORT_PLAINTEXT;
case TLS:
return GrpcUtil.DEFAULT_PORT_SSL;
default:
throw new AssertionError(negotiationType + " not handled");
}
return protocolNegotiatorFactory.getDefaultPort();
}
@VisibleForTesting
@ -527,7 +548,9 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
return this;
}
void protocolNegotiatorFactory(ProtocolNegotiatorFactory protocolNegotiatorFactory) {
void protocolNegotiatorFactory(ProtocolNegotiator.ClientFactory protocolNegotiatorFactory) {
checkState(!freezeProtocolNegotiatorFactory,
"Cannot change security when using ChannelCredentials");
this.protocolNegotiatorFactory
= checkNotNull(protocolNegotiatorFactory, "protocolNegotiatorFactory");
}
@ -558,12 +581,36 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder<NettyCha
return this;
}
interface ProtocolNegotiatorFactory {
/**
* Returns a ProtocolNegotatior instance configured for this Builder. This method is called
* during {@code ManagedChannelBuilder#build()}.
*/
ProtocolNegotiator buildProtocolNegotiator();
private final class DefaultProtocolNegotiator implements ProtocolNegotiator.ClientFactory {
private NegotiationType negotiationType = NegotiationType.TLS;
private SslContext sslContext;
@Override
public ProtocolNegotiator newNegotiator() {
SslContext localSslContext = sslContext;
if (negotiationType == NegotiationType.TLS && localSslContext == null) {
try {
localSslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
}
return createProtocolNegotiatorByType(negotiationType, localSslContext,
managedChannelImplBuilder.getOffloadExecutorPool());
}
@Override
public int getDefaultPort() {
switch (negotiationType) {
case PLAINTEXT:
case PLAINTEXT_UPGRADE:
return GrpcUtil.DEFAULT_PORT_PLAINTEXT;
case TLS:
return GrpcUtil.DEFAULT_PORT_SSL;
default:
throw new AssertionError(negotiationType + " not handled");
}
}
}
/**

View File

@ -0,0 +1,37 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import com.google.common.base.Preconditions;
import io.grpc.ChannelCredentials;
/** A credential with full control over the security handshake. */
final class NettyChannelCredentials extends ChannelCredentials {
public static ChannelCredentials create(ProtocolNegotiator.ClientFactory negotiator) {
return new NettyChannelCredentials(negotiator);
}
private final ProtocolNegotiator.ClientFactory negotiator;
private NettyChannelCredentials(ProtocolNegotiator.ClientFactory negotiator) {
this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator");
}
public ProtocolNegotiator.ClientFactory getNegotiator() {
return negotiator;
}
}

View File

@ -16,6 +16,7 @@
package io.grpc.netty;
import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.grpc.ManagedChannelProvider;
@ -41,4 +42,14 @@ public final class NettyChannelProvider extends ManagedChannelProvider {
public NettyChannelBuilder builderForTarget(String target) {
return NettyChannelBuilder.forTarget(target);
}
@Override
public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) {
ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds);
if (result.error != null) {
return NewChannelBuilderResult.error(result.error);
}
return NewChannelBuilderResult.channelBuilder(new NettyChannelBuilder(
target, result.negotiator, result.callCredentials));
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.netty;
import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.netty.handler.ssl.SslContext;
/** A credential that performs TLS with Netty's SslContext as configuration. */
@ExperimentalApi("There is no plan to make this API stable, given transport API instability")
public final class NettySslContextChannelCredentials {
private NettySslContextChannelCredentials() {}
/**
* Create a credential using Netty's SslContext as configuration. It must have been configured
* with {@link GrpcSslContexts}, but options could have been overridden.
*/
public static ChannelCredentials create(SslContext sslContext) {
return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext));
}
}

View File

@ -44,4 +44,12 @@ interface ProtocolNegotiator {
* on client-side.
*/
void close();
interface ClientFactory {
/** Creates a new negotiator. */
ProtocolNegotiator newNegotiator();
/** Returns the implicit port to use if no port was specified explicitly by the user. */
int getDefaultPort();
}
}

View File

@ -23,13 +23,20 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.ForOverride;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ChoiceChannelCredentials;
import io.grpc.CompositeCallCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InternalChannelz.Security;
import io.grpc.InternalChannelz.Tls;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
@ -60,11 +67,14 @@ import java.net.SocketAddress;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
@ -73,10 +83,90 @@ import javax.net.ssl.SSLSession;
*/
final class ProtocolNegotiators {
private static final Logger log = Logger.getLogger(ProtocolNegotiators.class.getName());
private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
EnumSet.noneOf(TlsChannelCredentials.Feature.class);
private ProtocolNegotiators() {
}
public static FromChannelCredentialsResult from(ChannelCredentials creds) {
if (creds instanceof TlsChannelCredentials) {
TlsChannelCredentials tlsCreds = (TlsChannelCredentials) creds;
Set<TlsChannelCredentials.Feature> incomprehensible =
tlsCreds.incomprehensible(understoodTlsFeatures);
if (!incomprehensible.isEmpty()) {
return FromChannelCredentialsResult.error(
"TLS features not understood: " + incomprehensible);
}
return FromChannelCredentialsResult.negotiator(tlsClientFactory(null));
} else if (creds instanceof InsecureChannelCredentials) {
return FromChannelCredentialsResult.negotiator(plaintextClientFactory());
} else if (creds instanceof CompositeChannelCredentials) {
CompositeChannelCredentials compCreds = (CompositeChannelCredentials) creds;
return from(compCreds.getChannelCredentials())
.withCallCredentials(compCreds.getCallCredentials());
} else if (creds instanceof NettyChannelCredentials) {
NettyChannelCredentials nettyCreds = (NettyChannelCredentials) creds;
return FromChannelCredentialsResult.negotiator(nettyCreds.getNegotiator());
} else if (creds instanceof ChoiceChannelCredentials) {
ChoiceChannelCredentials choiceCreds = (ChoiceChannelCredentials) creds;
StringBuilder error = new StringBuilder();
for (ChannelCredentials innerCreds : choiceCreds.getCredentialsList()) {
FromChannelCredentialsResult result = from(innerCreds);
if (result.error == null) {
return result;
}
error.append(", ");
error.append(result.error);
}
return FromChannelCredentialsResult.error(error.substring(2));
} else {
return FromChannelCredentialsResult.error(
"Unsupported credential type: " + creds.getClass().getName());
}
}
public static final class FromChannelCredentialsResult {
public final ProtocolNegotiator.ClientFactory negotiator;
public final CallCredentials callCredentials;
public final String error;
private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator,
CallCredentials creds, String error) {
this.negotiator = negotiator;
this.callCredentials = creds;
this.error = error;
}
public static FromChannelCredentialsResult error(String error) {
return new FromChannelCredentialsResult(
null, null, Preconditions.checkNotNull(error, "error"));
}
public static FromChannelCredentialsResult negotiator(
ProtocolNegotiator.ClientFactory factory) {
return new FromChannelCredentialsResult(
Preconditions.checkNotNull(factory, "factory"), null, null);
}
public FromChannelCredentialsResult withCallCredentials(CallCredentials callCreds) {
Preconditions.checkNotNull(callCreds, "callCreds");
if (error != null) {
return this;
}
if (this.callCredentials != null) {
callCreds = new CompositeCallCredentials(this.callCredentials, callCreds);
}
return new FromChannelCredentialsResult(negotiator, callCreds, null);
}
}
static ChannelLogger negotiationLogger(ChannelHandlerContext ctx) {
return negotiationLogger(ctx.channel());
}
@ -446,6 +536,36 @@ final class ProtocolNegotiators {
return tls(sslContext, null);
}
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
return new TlsProtocolNegotiatorClientFactory(sslContext);
}
@VisibleForTesting
static final class TlsProtocolNegotiatorClientFactory
implements ProtocolNegotiator.ClientFactory {
private final SslContext sslContext;
public TlsProtocolNegotiatorClientFactory(SslContext sslContext) {
this.sslContext = sslContext;
}
@Override public ProtocolNegotiator newNegotiator() {
SslContext sslContext = this.sslContext;
if (sslContext == null) {
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
}
return tls(sslContext);
}
@Override public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_SSL;
}
}
/** A tuple of (host, port). */
@VisibleForTesting
static final class HostPort {
@ -465,6 +585,21 @@ final class ProtocolNegotiators {
return new PlaintextUpgradeProtocolNegotiator();
}
public static ProtocolNegotiator.ClientFactory plaintextUpgradeClientFactory() {
return new PlaintextUpgradeProtocolNegotiatorClientFactory();
}
private static final class PlaintextUpgradeProtocolNegotiatorClientFactory
implements ProtocolNegotiator.ClientFactory {
@Override public ProtocolNegotiator newNegotiator() {
return plaintextUpgrade();
}
@Override public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_PLAINTEXT;
}
}
static final class PlaintextUpgradeProtocolNegotiator implements ProtocolNegotiator {
@Override
@ -548,6 +683,22 @@ final class ProtocolNegotiators {
return new PlaintextProtocolNegotiator();
}
public static ProtocolNegotiator.ClientFactory plaintextClientFactory() {
return new PlaintextProtocolNegotiatorClientFactory();
}
@VisibleForTesting
static final class PlaintextProtocolNegotiatorClientFactory
implements ProtocolNegotiator.ClientFactory {
@Override public ProtocolNegotiator newNegotiator() {
return plaintext();
}
@Override public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_PLAINTEXT;
}
}
private static RuntimeException unavailableException(String msg) {
return Status.UNAVAILABLE.withDescription(msg).asRuntimeException();
}

View File

@ -16,6 +16,7 @@
package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
@ -23,7 +24,9 @@ import static org.junit.Assert.fail;
import io.grpc.InternalServiceProviders;
import io.grpc.ManagedChannelProvider;
import io.grpc.ManagedChannelProvider.NewChannelBuilderResult;
import io.grpc.ManagedChannelRegistryAccessor;
import io.grpc.TlsChannelCredentials;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -65,4 +68,23 @@ public class NettyChannelProviderTest {
public void builderIsANettyBuilder() {
assertSame(NettyChannelBuilder.class, provider.builderForAddress("localhost", 443).getClass());
}
@Test
public void builderForTarget() {
assertThat(provider.builderForTarget("localhost:443")).isInstanceOf(NettyChannelBuilder.class);
}
@Test
public void newChannelBuilder_success() {
NewChannelBuilderResult result =
provider.newChannelBuilder("localhost:443", TlsChannelCredentials.create());
assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class);
}
@Test
public void newChannelBuilder_fail() {
NewChannelBuilderResult result = provider.newChannelBuilder("localhost:443",
TlsChannelCredentials.newBuilder().requireFakeFeature().build());
assertThat(result.getError()).contains("FAKE");
}
}

View File

@ -30,11 +30,17 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ChoiceChannelCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InternalChannelz.Security;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
@ -161,6 +167,105 @@ public class ProtocolNegotiatorsTest {
group.shutdownGracefully();
}
@Test
public void from_unknown() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(new ChannelCredentials() {});
assertThat(result.error).isNotNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator).isNull();
}
@Test
public void from_tls() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.create());
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
}
@Test
public void from_unspportedTls() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(TlsChannelCredentials.newBuilder().requireFakeFeature().build());
assertThat(result.error).contains("FAKE");
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator).isNull();
}
@Test
public void from_insecure() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(InsecureChannelCredentials.create());
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
}
@Test
public void from_composite() {
CallCredentials callCredentials = mock(CallCredentials.class);
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(CompositeChannelCredentials.create(
TlsChannelCredentials.create(), callCredentials));
assertThat(result.error).isNull();
assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
result = ProtocolNegotiators.from(CompositeChannelCredentials.create(
InsecureChannelCredentials.create(), callCredentials));
assertThat(result.error).isNull();
assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
}
@Test
public void from_netty() {
ProtocolNegotiator.ClientFactory factory = mock(ProtocolNegotiator.ClientFactory.class);
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(NettyChannelCredentials.create(factory));
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator).isSameInstanceAs(factory);
}
@Test
public void from_choice() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(ChoiceChannelCredentials.create(
new ChannelCredentials() {},
TlsChannelCredentials.create(),
InsecureChannelCredentials.create()));
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
result = ProtocolNegotiators.from(ChoiceChannelCredentials.create(
InsecureChannelCredentials.create(),
new ChannelCredentials() {},
TlsChannelCredentials.create()));
assertThat(result.error).isNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator)
.isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
}
@Test
public void from_choice_unknown() {
ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(ChoiceChannelCredentials.create(
new ChannelCredentials() {}));
assertThat(result.error).isNotNull();
assertThat(result.callCredentials).isNull();
assertThat(result.negotiator).isNull();
}
@Test
public void waitUntilActiveHandler_handlerAdded() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);