alts: plumb authority to ALTS protocol negotiator (#4880)

alts: plumb authority to ALTS protocol negotiator
This commit is contained in:
Jiangtao Li 2018-09-27 19:27:45 -07:00 committed by GitHub
parent da87ffb329
commit 8e72351a65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 83 additions and 84 deletions

View File

@ -17,6 +17,7 @@
package io.grpc.alts;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
@ -53,14 +54,13 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
private static final Logger logger = Logger.getLogger(AltsChannelBuilder.class.getName());
private final NettyChannelBuilder delegate;
private final AltsClientOptions.Builder handshakerOptionsBuilder =
new AltsClientOptions.Builder();
private final ImmutableList.Builder<String> targetServiceAccountsBuilder =
ImmutableList.builder();
private ObjectPool<ManagedChannel> handshakerChannelPool =
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
private boolean enableUntrustedAlts;
private AltsProtocolNegotiator negotiatorForTest;
private AltsClientOptions handshakerOptionsForTest;
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
public static final AltsChannelBuilder forTarget(String target) {
@ -78,16 +78,8 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
.keepAliveTime(20, TimeUnit.SECONDS)
.keepAliveTimeout(10, TimeUnit.SECONDS)
.keepAliveWithoutCalls(true);
handshakerOptionsBuilder.setRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions());
InternalNettyChannelBuilder
.setProtocolNegotiatorFactory(delegate(), new ProtocolNegotiatorFactory());
}
/** The server service account name for secure name checking. */
public AltsChannelBuilder withSecureNamingTarget(String targetName) {
handshakerOptionsBuilder.setTargetName(targetName);
return this;
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(), new ProtocolNegotiatorFactory());
}
/**
@ -95,7 +87,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
* service account in the handshaker result. Otherwise, the handshake fails.
*/
public AltsChannelBuilder addTargetServiceAccount(String targetServiceAccount) {
handshakerOptionsBuilder.addTargetServiceAccount(targetServiceAccount);
targetServiceAccountsBuilder.add(targetServiceAccount);
return this;
}
@ -146,31 +138,32 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
return negotiatorForTest;
}
@VisibleForTesting
@Nullable
AltsClientOptions getAltsClientOptionsForTest() {
return handshakerOptionsForTest;
}
private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public AltsProtocolNegotiator buildProtocolNegotiator() {
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build();
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetServiceAccounts(targetServiceAccounts)
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()),
handshakerOptions);
}
};
handshakerOptionsForTest = handshakerOptions;
return negotiatorForTest = AltsProtocolNegotiator.create(altsHandshakerFactory);
return negotiatorForTest =
AltsProtocolNegotiator.createClientNegotiator(altsHandshakerFactory);
}
}

View File

@ -197,10 +197,10 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
}
delegate.protocolNegotiator(
AltsProtocolNegotiator.create(
AltsProtocolNegotiator.createServerNegotiator(
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.

View File

@ -58,8 +58,8 @@ public final class GoogleDefaultChannelBuilder
private GoogleDefaultChannelBuilder(String target) {
delegate = NettyChannelBuilder.forTarget(target);
InternalNettyChannelBuilder
.setProtocolNegotiatorFactory(delegate(), new ProtocolNegotiatorFactory());
InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
delegate(), new ProtocolNegotiatorFactory());
}
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
@ -101,19 +101,20 @@ public final class GoogleDefaultChannelBuilder
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
@Override
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.build();
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
ManagedChannel channel =
SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(channel), handshakerOptions);
}

View File

@ -16,42 +16,40 @@
package io.grpc.alts.internal;
import com.google.common.collect.ImmutableList;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
/** Handshaker options for creating ALTS client channel. */
public final class AltsClientOptions extends AltsHandshakerOptions {
// targetName is the server service account name for secure name checking. This field is not yet
// supported.
// targetName is the server service account name for secure name checking.
@Nullable private final String targetName;
// targetServiceAccounts contains a list of expected target service accounts. One of these service
// accounts should match peer service account in the handshaker result. Otherwise, the handshake
// fails.
private final List<String> targetServiceAccounts;
private final ImmutableList<String> targetServiceAccounts;
private AltsClientOptions(Builder builder) {
super(builder.rpcProtocolVersions);
targetName = builder.targetName;
targetServiceAccounts =
Collections.unmodifiableList(new ArrayList<>(builder.targetServiceAccounts));
targetServiceAccounts = builder.targetServiceAccounts;
}
public String getTargetName() {
return targetName;
}
public List<String> getTargetServiceAccounts() {
public ImmutableList<String> getTargetServiceAccounts() {
return targetServiceAccounts;
}
/** Builder for AltsClientOptions. */
public static final class Builder {
@Nullable private String targetName;
@Nullable private RpcProtocolVersions rpcProtocolVersions;
private ArrayList<String> targetServiceAccounts = new ArrayList<>();
private ImmutableList<String> targetServiceAccounts = ImmutableList.of();
public Builder setTargetName(String targetName) {
this.targetName = targetName;
@ -63,8 +61,8 @@ public final class AltsClientOptions extends AltsHandshakerOptions {
return this;
}
public Builder addTargetServiceAccount(String targetServiceAccount) {
targetServiceAccounts.add(targetServiceAccount);
public Builder setTargetServiceAccounts(ImmutableList<String> targetServiceAccounts) {
this.targetServiceAccounts = targetServiceAccounts;
return this;
}

View File

@ -36,8 +36,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.util.AsciiString;
/**
* A client-side GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that
* provides ALTS security on the wire, similar to Netty's {@code SslHandler}.
* A GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that provides ALTS
* security on the wire, similar to Netty's {@code SslHandler}.
*/
public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
@ -54,14 +54,31 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
return ALTS_CONTEXT_KEY;
}
/** Creates a negotiator used for ALTS. */
public static AltsProtocolNegotiator create(final TsiHandshakerFactory handshakerFactory) {
/** Creates a negotiator used for ALTS client. */
public static AltsProtocolNegotiator createClientNegotiator(
final TsiHandshakerFactory handshakerFactory) {
return new AltsProtocolNegotiator() {
@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return new BufferUntilAltsNegotiatedHandler(
grpcHandler,
new TsiHandshakeHandler(new NettyTsiHandshaker(handshakerFactory.newHandshaker())),
new TsiHandshakeHandler(
new NettyTsiHandshaker(
handshakerFactory.newHandshaker(grpcHandler.getAuthority()))),
new TsiFrameHandler());
}
};
}
/** Creates a negotiator used for ALTS server. */
public static AltsProtocolNegotiator createServerNegotiator(
final TsiHandshakerFactory handshakerFactory) {
return new AltsProtocolNegotiator() {
@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return new BufferUntilAltsNegotiatedHandler(
grpcHandler,
new TsiHandshakeHandler(new NettyTsiHandshaker(handshakerFactory.newHandshaker(null))),
new TsiFrameHandler());
}
};

View File

@ -29,14 +29,13 @@ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator
private final ProtocolNegotiator tlsProtocolNegotiator;
public GoogleDefaultProtocolNegotiator(TsiHandshakerFactory altsFactory, SslContext sslContext) {
altsProtocolNegotiator = AltsProtocolNegotiator.create(altsFactory);
altsProtocolNegotiator = AltsProtocolNegotiator.createClientNegotiator(altsFactory);
tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext);
}
@VisibleForTesting
GoogleDefaultProtocolNegotiator(
ProtocolNegotiator altsProtocolNegotiator,
ProtocolNegotiator tlsProtocolNegotiator) {
ProtocolNegotiator altsProtocolNegotiator, ProtocolNegotiator tlsProtocolNegotiator) {
this.altsProtocolNegotiator = altsProtocolNegotiator;
this.tlsProtocolNegotiator = tlsProtocolNegotiator;
}

View File

@ -16,9 +16,11 @@
package io.grpc.alts.internal;
import javax.annotation.Nullable;
/** Factory that manufactures instances of {@link TsiHandshaker}. */
public interface TsiHandshakerFactory {
/** Creates a new handshaker. */
TsiHandshaker newHandshaker();
TsiHandshaker newHandshaker(@Nullable String authority);
}

View File

@ -18,9 +18,7 @@ package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.netty.ProtocolNegotiator;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -35,27 +33,12 @@ public final class AltsChannelBuilderTest {
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
ProtocolNegotiator protocolNegotiator = builder.getProtocolNegotiatorForTest();
AltsClientOptions altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(protocolNegotiator).isNull();
assertThat(altsClientOptions).isNull();
builder.build();
protocolNegotiator = builder.getProtocolNegotiatorForTest();
altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(protocolNegotiator).isNotNull();
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
assertThat(altsClientOptions).isNotNull();
RpcProtocolVersions expectedVersions =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.build();
assertThat(altsClientOptions.getRpcProtocolVersions()).isEqualTo(expectedVersions);
}
}

View File

@ -18,6 +18,7 @@ package io.grpc.alts.internal;
import static com.google.common.truth.Truth.assertThat;
import com.google.common.collect.ImmutableList;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -39,12 +40,12 @@ public final class AltsClientOptionsTest {
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.build();
ImmutableList<String> serviceAccounts = ImmutableList.of(serviceAccount1, serviceAccount2);
AltsClientOptions options =
new AltsClientOptions.Builder()
.setTargetName(targetName)
.addTargetServiceAccount(serviceAccount1)
.addTargetServiceAccount(serviceAccount2)
.setTargetServiceAccounts(serviceAccounts)
.setRpcProtocolVersions(rpcVersions)
.build();

View File

@ -27,6 +27,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import io.grpc.alts.internal.Handshaker.HandshakeProtocol;
import io.grpc.alts.internal.Handshaker.HandshakerReq;
@ -61,7 +62,7 @@ public class AltsHandshakerClientTest {
clientOptions =
new AltsClientOptions.Builder()
.setTargetName(TEST_TARGET_NAME)
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
.setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
}
@ -249,7 +250,7 @@ public class AltsHandshakerClientTest {
clientOptions =
new AltsClientOptions.Builder()
.setTargetName(TEST_TARGET_NAME)
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
.setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
.setRpcProtocolVersions(rpcVersions)
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);

View File

@ -76,6 +76,7 @@ import org.junit.runners.JUnit4;
/** Tests for {@link AltsProtocolNegotiator}. */
@RunWith(JUnit4.class)
public class AltsProtocolNegotiatorTest {
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
private final List<ReferenceCounted> references = new ArrayList<>();
@ -133,8 +134,8 @@ public class AltsProtocolNegotiatorTest {
TsiHandshakerFactory handshakerFactory =
new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
@Override
public TsiHandshaker newHandshaker() {
return new DelegatingTsiHandshaker(super.newHandshaker()) {
public TsiHandshaker newHandshaker(String authority) {
return new DelegatingTsiHandshaker(super.newHandshaker(authority)) {
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
return mockedTsiPeer;
@ -147,7 +148,8 @@ public class AltsProtocolNegotiatorTest {
};
}
};
handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler);
handler =
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory).newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
}
@ -394,6 +396,7 @@ public class AltsProtocolNegotiatorTest {
}
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
private Attributes attrs;
private CapturingGrpcHttp2ConnectionHandler(
@ -421,8 +424,8 @@ public class AltsProtocolNegotiatorTest {
}
@Override
public TsiHandshaker newHandshaker() {
return delegate.newHandshaker();
public TsiHandshaker newHandshaker(String authority) {
return delegate.newHandshaker(authority);
}
}
@ -477,6 +480,7 @@ public class AltsProtocolNegotiatorTest {
}
private static class InterceptingProtector implements TsiFrameProtector {
private final TsiFrameProtector delegate;
final AtomicInteger flushes = new AtomicInteger();

View File

@ -37,7 +37,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
private static final TsiHandshakerFactory clientHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
public TsiHandshaker newHandshaker(String authority) {
return new FakeTsiHandshaker(true);
}
};
@ -45,7 +45,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
private static final TsiHandshakerFactory serverHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
public TsiHandshaker newHandshaker(String authority) {
return new FakeTsiHandshaker(false);
}
};
@ -83,11 +83,11 @@ public class FakeTsiHandshaker implements TsiHandshaker {
}
public static TsiHandshaker newFakeHandshakerClient() {
return clientHandshakerFactory.newHandshaker();
return clientHandshakerFactory.newHandshaker(null);
}
public static TsiHandshaker newFakeHandshakerServer() {
return serverHandshakerFactory.newHandshaker();
return serverHandshakerFactory.newHandshaker(null);
}
protected FakeTsiHandshaker(boolean isClient) {