mirror of https://github.com/grpc/grpc-java.git
alts: if ALTS is not running on GCP, fails call (#4807)
alts: if ALTS is not running on GCP, fails call rather than RuntimeException
This commit is contained in:
parent
fd73209e0c
commit
87513d8e83
|
|
@ -20,13 +20,15 @@ import static com.google.common.base.Preconditions.checkArgument;
|
|||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.CallOptions;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.ConnectivityState;
|
||||
import io.grpc.ClientInterceptor;
|
||||
import io.grpc.ExperimentalApi;
|
||||
import io.grpc.ForwardingChannelBuilder;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
|
|
@ -43,6 +45,8 @@ import io.grpc.netty.NettyChannelBuilder;
|
|||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
|
|
@ -52,6 +56,7 @@ import javax.annotation.Nullable;
|
|||
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4151")
|
||||
public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChannelBuilder> {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(AltsChannelBuilder.class.getName());
|
||||
private final NettyChannelBuilder delegate;
|
||||
private final AltsClientOptions.Builder handshakerOptionsBuilder =
|
||||
new AltsClientOptions.Builder();
|
||||
|
|
@ -115,13 +120,24 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
|
||||
@Override
|
||||
public ManagedChannel build() {
|
||||
CheckGcpEnvironment.check(enableUntrustedAlts);
|
||||
if (!CheckGcpEnvironment.isOnGcp()) {
|
||||
if (enableUntrustedAlts) {
|
||||
logger.log(
|
||||
Level.WARNING,
|
||||
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the "
|
||||
+ "ALTS handshaker service");
|
||||
} else {
|
||||
Status status =
|
||||
Status.INTERNAL.withDescription("ALTS is only allowed to run on Google Cloud Platform");
|
||||
delegate().intercept(new FailingClientInterceptor(status));
|
||||
}
|
||||
}
|
||||
|
||||
TcpfFactory tcpfFactory = new TcpfFactory();
|
||||
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
|
||||
|
||||
tcpfFactoryForTest = tcpfFactory;
|
||||
|
||||
return new AltsChannel(delegate().build());
|
||||
return delegate().build();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
|
@ -140,6 +156,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
}
|
||||
|
||||
private final class TcpfFactory implements TransportCreationParamsFilterFactory {
|
||||
|
||||
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
|
||||
|
||||
private final TsiHandshakerFactory altsHandshakerFactory =
|
||||
|
|
@ -189,69 +206,19 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
|||
}
|
||||
}
|
||||
|
||||
static final class AltsChannel extends ManagedChannel {
|
||||
private final ManagedChannel delegate;
|
||||
/** An implementation of {@link ClientInterceptor} that fails each call. */
|
||||
static final class FailingClientInterceptor implements ClientInterceptor {
|
||||
|
||||
AltsChannel(ManagedChannel delegate) {
|
||||
this.delegate = delegate;
|
||||
private final Status status;
|
||||
|
||||
public FailingClientInterceptor(Status status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConnectivityState getState(boolean requestConnection) {
|
||||
return delegate.getState(requestConnection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) {
|
||||
delegate.notifyWhenStateChanged(source, callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AltsChannel shutdown() {
|
||||
delegate.shutdown();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isShutdown() {
|
||||
return delegate.isShutdown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTerminated() {
|
||||
return delegate.isTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AltsChannel shutdownNow() {
|
||||
delegate.shutdownNow();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
|
||||
return delegate.awaitTermination(timeout, unit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
|
||||
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
|
||||
return delegate.newCall(methodDescriptor, callOptions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String authority() {
|
||||
return delegate.authority();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void resetConnectBackoff() {
|
||||
delegate.resetConnectBackoff();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enterIdle() {
|
||||
delegate.enterIdle();
|
||||
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
|
||||
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
|
||||
return new FailingClientCall<>(status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,18 +16,22 @@
|
|||
|
||||
package io.grpc.alts;
|
||||
|
||||
import com.google.common.base.MoreObjects;
|
||||
import io.grpc.BindableService;
|
||||
import io.grpc.CompressorRegistry;
|
||||
import io.grpc.DecompressorRegistry;
|
||||
import io.grpc.ExperimentalApi;
|
||||
import io.grpc.HandlerRegistry;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCall.Listener;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServerStreamTracer.Factory;
|
||||
import io.grpc.ServerTransportFilter;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsHandshakerOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.alts.internal.AltsTsiHandshaker;
|
||||
|
|
@ -37,11 +41,11 @@ import io.grpc.alts.internal.TsiHandshaker;
|
|||
import io.grpc.alts.internal.TsiHandshakerFactory;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* gRPC secure server builder used for ALTS. This class adds on the necessary ALTS support to create
|
||||
|
|
@ -50,6 +54,7 @@ import java.util.concurrent.TimeUnit;
|
|||
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4151")
|
||||
public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(AltsServerBuilder.class.getName());
|
||||
private final NettyServerBuilder delegate;
|
||||
private boolean enableUntrustedAlts;
|
||||
|
||||
|
|
@ -170,7 +175,19 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
|
|||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public Server build() {
|
||||
CheckGcpEnvironment.check(enableUntrustedAlts);
|
||||
if (!CheckGcpEnvironment.isOnGcp()) {
|
||||
if (enableUntrustedAlts) {
|
||||
logger.log(
|
||||
Level.WARNING,
|
||||
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the "
|
||||
+ "ALTS handshaker service");
|
||||
} else {
|
||||
Status status =
|
||||
Status.INTERNAL.withDescription("ALTS is only allowed to run on Google Cloud Platform");
|
||||
delegate.intercept(new FailingServerInterceptor(status));
|
||||
}
|
||||
}
|
||||
|
||||
delegate.protocolNegotiator(
|
||||
AltsProtocolNegotiator.create(
|
||||
new TsiHandshakerFactory() {
|
||||
|
|
@ -182,77 +199,25 @@ public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
|
|||
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
|
||||
}
|
||||
}));
|
||||
return new AltsServer(delegate.build());
|
||||
return delegate.build();
|
||||
}
|
||||
|
||||
static final class AltsServer extends io.grpc.Server {
|
||||
private final Server delegate;
|
||||
/** An implementation of {@link ServerInterceptor} that fails each call. */
|
||||
static final class FailingServerInterceptor implements ServerInterceptor {
|
||||
|
||||
AltsServer(Server delegate) {
|
||||
this.delegate = delegate;
|
||||
private final Status status;
|
||||
|
||||
public FailingServerInterceptor(Status status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getImmutableServices() {
|
||||
return delegate.getImmutableServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getMutableServices() {
|
||||
return delegate.getMutableServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getPort() {
|
||||
return delegate.getPort();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getServices() {
|
||||
return delegate.getServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server start() throws IOException {
|
||||
delegate.start();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server shutdown() {
|
||||
delegate.shutdown();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server shutdownNow() {
|
||||
delegate.shutdownNow();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isShutdown() {
|
||||
return delegate.isShutdown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTerminated() {
|
||||
return delegate.isTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
|
||||
return delegate.awaitTermination(timeout, unit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void awaitTermination() throws InterruptedException {
|
||||
delegate.awaitTermination();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return MoreObjects.toStringHelper(this).add("delegate", delegate).toString();
|
||||
public <ReqT, RespT> Listener<ReqT> interceptCall(
|
||||
ServerCall<ReqT, RespT> serverCall,
|
||||
Metadata metadata,
|
||||
ServerCallHandler<ReqT, RespT> nextHandler) {
|
||||
serverCall.close(status, new Metadata());
|
||||
return new Listener<ReqT>() {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import org.apache.commons.lang3.SystemUtils;
|
|||
|
||||
/** Class for checking if the system is running on Google Cloud Platform (GCP). */
|
||||
final class CheckGcpEnvironment {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(CheckGcpEnvironment.class.getName());
|
||||
private static final String DMI_PRODUCT_NAME = "/sys/class/dmi/id/product_name";
|
||||
private static final String WINDOWS_COMMAND = "powershell.exe";
|
||||
|
|
@ -38,18 +39,7 @@ final class CheckGcpEnvironment {
|
|||
// Construct me not!
|
||||
private CheckGcpEnvironment() {}
|
||||
|
||||
public static void check(boolean enableUntrustedAlts) {
|
||||
if (enableUntrustedAlts) {
|
||||
logger.log(
|
||||
Level.WARNING,
|
||||
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS "
|
||||
+ "handshaker service.");
|
||||
} else if (!isOnGcp()) {
|
||||
throw new RuntimeException("ALTS is only allowed to run on Google Cloud Platform.");
|
||||
}
|
||||
}
|
||||
|
||||
private static synchronized boolean isOnGcp() {
|
||||
static synchronized boolean isOnGcp() {
|
||||
if (cachedResult == null) {
|
||||
cachedResult = isRunningOnGcp();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Copyright 2018 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
|
||||
/** An implementation of {@link ClientCall} that fails when started. */
|
||||
final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
|
||||
|
||||
private final Status error;
|
||||
|
||||
public FailingClientCall(Status error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
|
||||
listener.onClose(error, new Metadata());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void request(int numMessages) {}
|
||||
|
||||
@Override
|
||||
public void cancel(String message, Throwable cause) {}
|
||||
|
||||
@Override
|
||||
public void halfClose() {}
|
||||
|
||||
@Override
|
||||
public void sendMessage(ReqT message) {}
|
||||
}
|
||||
|
|
@ -28,7 +28,6 @@ import io.grpc.ClientInterceptor;
|
|||
import io.grpc.ForwardingChannelBuilder;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
|
|
@ -91,7 +90,7 @@ public final class GoogleDefaultChannelBuilder
|
|||
credentials = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault());
|
||||
} catch (IOException e) {
|
||||
status =
|
||||
Status.FAILED_PRECONDITION
|
||||
Status.UNAUTHENTICATED
|
||||
.withDescription("Failed to get Google default credentials")
|
||||
.withCause(e);
|
||||
}
|
||||
|
|
@ -188,31 +187,4 @@ public final class GoogleDefaultChannelBuilder
|
|||
return next.newCall(method, callOptions.withCallCredentials(credentials));
|
||||
}
|
||||
}
|
||||
|
||||
/** An implementation of {@link ClientCall} that fails when started. */
|
||||
static final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
|
||||
|
||||
private final Status error;
|
||||
|
||||
public FailingClientCall(Status error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
|
||||
listener.onClose(error, new Metadata());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void request(int numMessages) {}
|
||||
|
||||
@Override
|
||||
public void cancel(String message, Throwable cause) {}
|
||||
|
||||
@Override
|
||||
public void halfClose() {}
|
||||
|
||||
@Override
|
||||
public void sendMessage(ReqT message) {}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,19 +17,12 @@
|
|||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
import com.google.common.base.Defaults;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.alts.AltsChannelBuilder.AltsChannel;
|
||||
import io.grpc.alts.internal.AltsClientOptions;
|
||||
import io.grpc.alts.internal.AltsProtocolNegotiator;
|
||||
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
|
||||
import io.grpc.netty.ProtocolNegotiator;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
|
@ -49,8 +42,7 @@ public final class AltsChannelBuilderTest {
|
|||
assertThat(tcpfFactory).isNull();
|
||||
assertThat(altsClientOptions).isNull();
|
||||
|
||||
ManagedChannel channel = builder.build();
|
||||
assertThat(channel).isInstanceOf(AltsChannel.class);
|
||||
builder.build();
|
||||
|
||||
tcpfFactory = builder.getTcpfFactoryForTest();
|
||||
altsClientOptions = builder.getAltsClientOptionsForTest();
|
||||
|
|
@ -72,24 +64,4 @@ public final class AltsChannelBuilderTest {
|
|||
.build();
|
||||
assertThat(altsClientOptions.getRpcProtocolVersions()).isEqualTo(expectedVersions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void allAltsChannelMethodsForward() throws Exception {
|
||||
ManagedChannel mockDelegate = mock(ManagedChannel.class);
|
||||
AltsChannel altsChannel = new AltsChannel(mockDelegate);
|
||||
|
||||
for (Method method : ManagedChannel.class.getDeclaredMethods()) {
|
||||
if (Modifier.isStatic(method.getModifiers()) || Modifier.isPrivate(method.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
Class<?>[] argTypes = method.getParameterTypes();
|
||||
Object[] args = new Object[argTypes.length];
|
||||
for (int i = 0; i < argTypes.length; i++) {
|
||||
args[i] = Defaults.defaultValue(argTypes[i]);
|
||||
}
|
||||
|
||||
method.invoke(altsChannel, args);
|
||||
method.invoke(verify(mockDelegate), args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue