alts: if ALTS is not running on GCP, fails call (#4807)

alts: if ALTS is not running on GCP, fails call rather than RuntimeException
This commit is contained in:
Jiangtao Li 2018-08-29 08:59:16 -07:00 committed by GitHub
parent fd73209e0c
commit 87513d8e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 202 deletions

View File

@ -20,13 +20,15 @@ import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ClientInterceptor;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsTsiHandshaker;
@ -43,6 +45,8 @@ import io.grpc.netty.NettyChannelBuilder;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
@ -52,6 +56,7 @@ import javax.annotation.Nullable;
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4151")
public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChannelBuilder> {
private static final Logger logger = Logger.getLogger(AltsChannelBuilder.class.getName());
private final NettyChannelBuilder delegate;
private final AltsClientOptions.Builder handshakerOptionsBuilder =
new AltsClientOptions.Builder();
@ -115,13 +120,24 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
@Override
public ManagedChannel build() {
CheckGcpEnvironment.check(enableUntrustedAlts);
if (!CheckGcpEnvironment.isOnGcp()) {
if (enableUntrustedAlts) {
logger.log(
Level.WARNING,
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the "
+ "ALTS handshaker service");
} else {
Status status =
Status.INTERNAL.withDescription("ALTS is only allowed to run on Google Cloud Platform");
delegate().intercept(new FailingClientInterceptor(status));
}
}
TcpfFactory tcpfFactory = new TcpfFactory();
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
tcpfFactoryForTest = tcpfFactory;
return new AltsChannel(delegate().build());
return delegate().build();
}
@VisibleForTesting
@ -140,6 +156,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
}
private final class TcpfFactory implements TransportCreationParamsFilterFactory {
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
private final TsiHandshakerFactory altsHandshakerFactory =
@ -189,69 +206,19 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
}
}
static final class AltsChannel extends ManagedChannel {
private final ManagedChannel delegate;
/** An implementation of {@link ClientInterceptor} that fails each call. */
static final class FailingClientInterceptor implements ClientInterceptor {
AltsChannel(ManagedChannel delegate) {
this.delegate = delegate;
private final Status status;
public FailingClientInterceptor(Status status) {
this.status = status;
}
@Override
public ConnectivityState getState(boolean requestConnection) {
return delegate.getState(requestConnection);
}
@Override
public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) {
delegate.notifyWhenStateChanged(source, callback);
}
@Override
public AltsChannel shutdown() {
delegate.shutdown();
return this;
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public AltsChannel shutdownNow() {
delegate.shutdownNow();
return this;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return delegate.newCall(methodDescriptor, callOptions);
}
@Override
public String authority() {
return delegate.authority();
}
@Override
public void resetConnectBackoff() {
delegate.resetConnectBackoff();
}
@Override
public void enterIdle() {
delegate.enterIdle();
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new FailingClientCall<>(status);
}
}
}

View File

@ -16,18 +16,22 @@
package io.grpc.alts;
import com.google.common.base.MoreObjects;
import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.ExperimentalApi;
import io.grpc.HandlerRegistry;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer.Factory;
import io.grpc.ServerTransportFilter;
import io.grpc.Status;
import io.grpc.alts.internal.AltsHandshakerOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsTsiHandshaker;
@ -37,11 +41,11 @@ import io.grpc.alts.internal.TsiHandshaker;
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.netty.NettyServerBuilder;
import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* gRPC secure server builder used for ALTS. This class adds on the necessary ALTS support to create
@ -50,6 +54,7 @@ import java.util.concurrent.TimeUnit;
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4151")
public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
private static final Logger logger = Logger.getLogger(AltsServerBuilder.class.getName());
private final NettyServerBuilder delegate;
private boolean enableUntrustedAlts;
@ -170,7 +175,19 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
/** {@inheritDoc} */
@Override
public Server build() {
CheckGcpEnvironment.check(enableUntrustedAlts);
if (!CheckGcpEnvironment.isOnGcp()) {
if (enableUntrustedAlts) {
logger.log(
Level.WARNING,
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the "
+ "ALTS handshaker service");
} else {
Status status =
Status.INTERNAL.withDescription("ALTS is only allowed to run on Google Cloud Platform");
delegate.intercept(new FailingServerInterceptor(status));
}
}
delegate.protocolNegotiator(
AltsProtocolNegotiator.create(
new TsiHandshakerFactory() {
@ -182,77 +199,25 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
}
}));
return new AltsServer(delegate.build());
return delegate.build();
}
static final class AltsServer extends io.grpc.Server {
private final Server delegate;
/** An implementation of {@link ServerInterceptor} that fails each call. */
static final class FailingServerInterceptor implements ServerInterceptor {
AltsServer(Server delegate) {
this.delegate = delegate;
private final Status status;
public FailingServerInterceptor(Status status) {
this.status = status;
}
@Override
public List<ServerServiceDefinition> getImmutableServices() {
return delegate.getImmutableServices();
}
@Override
public List<ServerServiceDefinition> getMutableServices() {
return delegate.getMutableServices();
}
@Override
public int getPort() {
return delegate.getPort();
}
@Override
public List<ServerServiceDefinition> getServices() {
return delegate.getServices();
}
@Override
public Server start() throws IOException {
delegate.start();
return this;
}
@Override
public Server shutdown() {
delegate.shutdown();
return this;
}
@Override
public Server shutdownNow() {
delegate.shutdownNow();
return this;
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public void awaitTermination() throws InterruptedException {
delegate.awaitTermination();
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this).add("delegate", delegate).toString();
public <ReqT, RespT> Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> serverCall,
Metadata metadata,
ServerCallHandler<ReqT, RespT> nextHandler) {
serverCall.close(status, new Metadata());
return new Listener<ReqT>() {};
}
}
}

View File

@ -30,6 +30,7 @@ import org.apache.commons.lang3.SystemUtils;
/** Class for checking if the system is running on Google Cloud Platform (GCP). */
final class CheckGcpEnvironment {
private static final Logger logger = Logger.getLogger(CheckGcpEnvironment.class.getName());
private static final String DMI_PRODUCT_NAME = "/sys/class/dmi/id/product_name";
private static final String WINDOWS_COMMAND = "powershell.exe";
@ -38,18 +39,7 @@ final class CheckGcpEnvironment {
// Construct me not!
private CheckGcpEnvironment() {}
public static void check(boolean enableUntrustedAlts) {
if (enableUntrustedAlts) {
logger.log(
Level.WARNING,
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS "
+ "handshaker service.");
} else if (!isOnGcp()) {
throw new RuntimeException("ALTS is only allowed to run on Google Cloud Platform.");
}
}
private static synchronized boolean isOnGcp() {
static synchronized boolean isOnGcp() {
if (cachedResult == null) {
cachedResult = isRunningOnGcp();
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import io.grpc.ClientCall;
import io.grpc.Metadata;
import io.grpc.Status;
/** An implementation of {@link ClientCall} that fails when started. */
final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
private final Status error;
public FailingClientCall(Status error) {
this.error = error;
}
@Override
public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
listener.onClose(error, new Metadata());
}
@Override
public void request(int numMessages) {}
@Override
public void cancel(String message, Throwable cause) {}
@Override
public void halfClose() {}
@Override
public void sendMessage(ReqT message) {}
}

View File

@ -28,7 +28,6 @@ import io.grpc.ClientInterceptor;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
@ -91,7 +90,7 @@ public final class GoogleDefaultChannelBuilder
credentials = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault());
} catch (IOException e) {
status =
Status.FAILED_PRECONDITION
Status.UNAUTHENTICATED
.withDescription("Failed to get Google default credentials")
.withCause(e);
}
@ -188,31 +187,4 @@ public final class GoogleDefaultChannelBuilder
return next.newCall(method, callOptions.withCallCredentials(credentials));
}
}
/** An implementation of {@link ClientCall} that fails when started. */
static final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
private final Status error;
public FailingClientCall(Status error) {
this.error = error;
}
@Override
public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
listener.onClose(error, new Metadata());
}
@Override
public void request(int numMessages) {}
@Override
public void cancel(String message, Throwable cause) {}
@Override
public void halfClose() {}
@Override
public void sendMessage(ReqT message) {}
}
}

View File

@ -17,19 +17,12 @@
package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import com.google.common.base.Defaults;
import io.grpc.ManagedChannel;
import io.grpc.alts.AltsChannelBuilder.AltsChannel;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.ProtocolNegotiator;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.InetSocketAddress;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -49,8 +42,7 @@ public final class AltsChannelBuilderTest {
assertThat(tcpfFactory).isNull();
assertThat(altsClientOptions).isNull();
ManagedChannel channel = builder.build();
assertThat(channel).isInstanceOf(AltsChannel.class);
builder.build();
tcpfFactory = builder.getTcpfFactoryForTest();
altsClientOptions = builder.getAltsClientOptionsForTest();
@ -72,24 +64,4 @@ public final class AltsChannelBuilderTest {
.build();
assertThat(altsClientOptions.getRpcProtocolVersions()).isEqualTo(expectedVersions);
}
@Test
public void allAltsChannelMethodsForward() throws Exception {
ManagedChannel mockDelegate = mock(ManagedChannel.class);
AltsChannel altsChannel = new AltsChannel(mockDelegate);
for (Method method : ManagedChannel.class.getDeclaredMethods()) {
if (Modifier.isStatic(method.getModifiers()) || Modifier.isPrivate(method.getModifiers())) {
continue;
}
Class<?>[] argTypes = method.getParameterTypes();
Object[] args = new Object[argTypes.length];
for (int i = 0; i < argTypes.length; i++) {
args[i] = Defaults.defaultValue(argTypes[i]);
}
method.invoke(altsChannel, args);
method.invoke(verify(mockDelegate), args);
}
}
}