alts: Remove usage of TransportCreationParamsFilterFactory

Provide a ProtocolNegotiatorFactory instead.
TransportCreationParamsFilterFactory is being removed.
This commit is contained in:
Eric Anderson 2018-09-24 17:32:04 -07:00 committed by GitHub
parent acf80d63b0
commit 693779dba7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 166 deletions

View File

@ -16,8 +16,6 @@
package io.grpc.alts; package io.grpc.alts;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
@ -38,14 +36,9 @@ import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
import io.grpc.internal.ProxyParameters;
import io.grpc.internal.SharedResourcePool; import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilter;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -64,9 +57,11 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
new AltsClientOptions.Builder(); new AltsClientOptions.Builder();
private ObjectPool<ManagedChannel> handshakerChannelPool = private ObjectPool<ManagedChannel> handshakerChannelPool =
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL); SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
private TcpfFactory tcpfFactoryForTest;
private boolean enableUntrustedAlts; private boolean enableUntrustedAlts;
private AltsProtocolNegotiator negotiatorForTest;
private AltsClientOptions handshakerOptionsForTest;
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */ /** "Overrides" the static method in {@link ManagedChannelBuilder}. */
public static final AltsChannelBuilder forTarget(String target) { public static final AltsChannelBuilder forTarget(String target) {
return new AltsChannelBuilder(target); return new AltsChannelBuilder(target);
@ -85,6 +80,8 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
.keepAliveWithoutCalls(true); .keepAliveWithoutCalls(true);
handshakerOptionsBuilder.setRpcProtocolVersions( handshakerOptionsBuilder.setRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions()); RpcProtocolVersionsUtil.getRpcProtocolVersions());
InternalNettyChannelBuilder
.setProtocolNegotiatorFactory(delegate(), new ProtocolNegotiatorFactory());
} }
/** The server service account name for secure name checking. */ /** The server service account name for secure name checking. */
@ -140,84 +137,40 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
} }
} }
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()),
handshakerOptions);
}
};
AltsProtocolNegotiator negotiator = AltsProtocolNegotiator.create(altsHandshakerFactory);
TcpfFactory tcpfFactory = new TcpfFactory(handshakerOptions, negotiator);
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
tcpfFactoryForTest = tcpfFactory;
return delegate().build(); return delegate().build();
} }
@VisibleForTesting @VisibleForTesting
@Nullable @Nullable
TransportCreationParamsFilterFactory getTcpfFactoryForTest() { AltsProtocolNegotiator getProtocolNegotiatorForTest() {
return tcpfFactoryForTest; return negotiatorForTest;
} }
@VisibleForTesting @VisibleForTesting
@Nullable @Nullable
AltsClientOptions getAltsClientOptionsForTest() { AltsClientOptions getAltsClientOptionsForTest() {
if (tcpfFactoryForTest == null) { return handshakerOptionsForTest;
return null;
}
return tcpfFactoryForTest.handshakerOptions;
} }
private static final class TcpfFactory implements TransportCreationParamsFilterFactory { private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
final AltsClientOptions handshakerOptions;
private final AltsProtocolNegotiator negotiator;
public TcpfFactory(AltsClientOptions handshakerOptions, AltsProtocolNegotiator negotiator) {
this.handshakerOptions = handshakerOptions;
this.negotiator = negotiator;
}
@Override @Override
public TransportCreationParamsFilter create( public AltsProtocolNegotiator buildProtocolNegotiator() {
final SocketAddress serverAddress, final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
final String authority, TsiHandshakerFactory altsHandshakerFactory =
final String userAgent, new TsiHandshakerFactory() {
final ProxyParameters proxy) { @Override
checkArgument( public TsiHandshaker newHandshaker() {
serverAddress instanceof InetSocketAddress, // Used the shared grpc channel to connecting to the ALTS handshaker service.
"%s must be a InetSocketAddress", // TODO: Release the channel if it is not used.
serverAddress); // https://github.com/grpc/grpc-java/issues/4755.
return new TransportCreationParamsFilter() { return AltsTsiHandshaker.newClient(
@Override HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()),
public SocketAddress getTargetServerAddress() { handshakerOptions);
return serverAddress; }
} };
handshakerOptionsForTest = handshakerOptions;
@Override return negotiatorForTest = AltsProtocolNegotiator.create(altsHandshakerFactory);
public String getAuthority() {
return authority;
}
@Override
public String getUserAgent() {
return userAgent;
}
@Override
public AltsProtocolNegotiator getProtocolNegotiator() {
return negotiator;
}
};
} }
} }

View File

@ -16,8 +16,6 @@
package io.grpc.alts; package io.grpc.alts;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
@ -39,17 +37,12 @@ import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ProxyParameters;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilter;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
@ -61,37 +54,12 @@ public final class GoogleDefaultChannelBuilder
extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> { extends ForwardingChannelBuilder<GoogleDefaultChannelBuilder> {
private final NettyChannelBuilder delegate; private final NettyChannelBuilder delegate;
private final TcpfFactory tcpfFactory; private GoogleDefaultProtocolNegotiator negotiatorForTest;
private GoogleDefaultChannelBuilder(String target) { private GoogleDefaultChannelBuilder(String target) {
delegate = NettyChannelBuilder.forTarget(target); delegate = NettyChannelBuilder.forTarget(target);
InternalNettyChannelBuilder
final AltsClientOptions handshakerOptions = .setProtocolNegotiatorFactory(delegate(), new ProtocolNegotiatorFactory());
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.build();
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
ManagedChannel channel =
SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(channel), handshakerOptions);
}
};
SslContext sslContext;
try {
sslContext = GrpcSslContexts.forClient().build();
} catch (SSLException ex) {
throw new RuntimeException(ex);
}
tcpfFactory = new TcpfFactory(
new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext));
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
} }
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */ /** "Overrides" the static method in {@link ManagedChannelBuilder}. */
@ -125,48 +93,39 @@ public final class GoogleDefaultChannelBuilder
} }
@VisibleForTesting @VisibleForTesting
TransportCreationParamsFilterFactory getTcpfFactoryForTest() { GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {
return tcpfFactory; return negotiatorForTest;
} }
private static final class TcpfFactory implements TransportCreationParamsFilterFactory { private final class ProtocolNegotiatorFactory
private final GoogleDefaultProtocolNegotiator negotiator; implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
private TcpfFactory(GoogleDefaultProtocolNegotiator negotiator) {
this.negotiator = negotiator;
}
@Override @Override
public TransportCreationParamsFilter create( public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final SocketAddress serverAddress, final AltsClientOptions handshakerOptions =
final String authority, new AltsClientOptions.Builder()
final String userAgent, .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
final ProxyParameters proxy) { .build();
checkArgument( TsiHandshakerFactory altsHandshakerFactory =
serverAddress instanceof InetSocketAddress, new TsiHandshakerFactory() {
"%s must be a InetSocketAddress", @Override
serverAddress); public TsiHandshaker newHandshaker() {
return new TransportCreationParamsFilter() { // Used the shared grpc channel to connecting to the ALTS handshaker service.
@Override // TODO: Release the channel if it is not used.
public SocketAddress getTargetServerAddress() { // https://github.com/grpc/grpc-java/issues/4755.
return serverAddress; ManagedChannel channel =
} SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
return AltsTsiHandshaker.newClient(
@Override HandshakerServiceGrpc.newStub(channel), handshakerOptions);
public String getAuthority() { }
return authority; };
} SslContext sslContext;
try {
@Override sslContext = GrpcSslContexts.forClient().build();
public String getUserAgent() { } catch (SSLException ex) {
return userAgent; throw new RuntimeException(ex);
} }
return negotiatorForTest =
@Override new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext);
public GoogleDefaultProtocolNegotiator getProtocolNegotiator() {
return negotiator;
}
};
} }
} }

View File

@ -35,7 +35,8 @@ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator
@VisibleForTesting @VisibleForTesting
GoogleDefaultProtocolNegotiator( GoogleDefaultProtocolNegotiator(
ProtocolNegotiator altsProtocolNegotiator, ProtocolNegotiator tlsProtocolNegotiator) { ProtocolNegotiator altsProtocolNegotiator,
ProtocolNegotiator tlsProtocolNegotiator) {
this.altsProtocolNegotiator = altsProtocolNegotiator; this.altsProtocolNegotiator = altsProtocolNegotiator;
this.tlsProtocolNegotiator = tlsProtocolNegotiator; this.tlsProtocolNegotiator = tlsProtocolNegotiator;
} }

View File

@ -21,9 +21,7 @@ import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.AltsClientOptions; import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator; import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions; import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiator;
import java.net.InetSocketAddress;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -36,22 +34,18 @@ public final class AltsChannelBuilderTest {
AltsChannelBuilder builder = AltsChannelBuilder builder =
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting(); AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
TransportCreationParamsFilterFactory tcpfFactory = builder.getTcpfFactoryForTest(); ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
AltsClientOptions altsClientOptions = builder.getAltsClientOptionsForTest(); AltsClientOptions altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(tcpfFactory).isNull(); assertThat(protocolNegotiator).isNull();
assertThat(altsClientOptions).isNull(); assertThat(altsClientOptions).isNull();
builder.build(); builder.build();
tcpfFactory = builder.getTcpfFactoryForTest(); protocolNegotiator = builder.getProtocolNegotiatorForTest();
altsClientOptions = builder.getAltsClientOptionsForTest(); altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(tcpfFactory).isNotNull(); assertThat(protocolNegotiator).isNotNull();
ProtocolNegotiator protocolNegotiator =
tcpfFactory
.create(new InetSocketAddress(8080), "fakeAuthority", "fakeUserAgent", null)
.getProtocolNegotiator();
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class); assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
assertThat(altsClientOptions).isNotNull(); assertThat(altsClientOptions).isNotNull();

View File

@ -19,9 +19,7 @@ package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiator;
import java.net.InetSocketAddress;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -32,13 +30,9 @@ public final class GoogleDefaultChannelBuilderTest {
@Test @Test
public void buildsNettyChannel() throws Exception { public void buildsNettyChannel() throws Exception {
GoogleDefaultChannelBuilder builder = GoogleDefaultChannelBuilder.forTarget("localhost:8080"); GoogleDefaultChannelBuilder builder = GoogleDefaultChannelBuilder.forTarget("localhost:8080");
builder.build();
TransportCreationParamsFilterFactory tcpfFactory = builder.getTcpfFactoryForTest(); ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
assertThat(tcpfFactory).isNotNull();
ProtocolNegotiator protocolNegotiator =
tcpfFactory
.create(new InetSocketAddress(8080), "fakeAuthority", "fakeUserAgent", null)
.getProtocolNegotiator();
assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class); assertThat(protocolNegotiator).isInstanceOf(GoogleDefaultProtocolNegotiator.class);
} }
} }