diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 769309a5a1..758f99d535 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -170,6 +170,7 @@ public abstract class AbstractInteropTest { private ScheduledExecutorService testServiceExecutor; private Server server; + private Server handshakerServer; private final LinkedBlockingQueue serverStreamTracers = new LinkedBlockingQueue<>(); @@ -223,6 +224,7 @@ public abstract class AbstractInteropTest { protected static final Empty EMPTY = Empty.getDefaultInstance(); private void startServer() { + maybeStartHandshakerServer(); ServerBuilder builder = getServerBuilder(); if (builder == null) { server = null; @@ -251,6 +253,17 @@ public abstract class AbstractInteropTest { } } + private void maybeStartHandshakerServer() { + ServerBuilder handshakerServerBuilder = getHandshakerServerBuilder(); + if (handshakerServerBuilder != null) { + try { + handshakerServer = handshakerServerBuilder.build().start(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + private void stopServer() { if (server != null) { server.shutdownNow(); @@ -258,6 +271,9 @@ public abstract class AbstractInteropTest { if (testServiceExecutor != null) { testServiceExecutor.shutdown(); } + if (handshakerServer != null) { + handshakerServer.shutdownNow(); + } } @VisibleForTesting @@ -348,6 +364,11 @@ public abstract class AbstractInteropTest { return null; } + @Nullable + protected ServerBuilder getHandshakerServerBuilder() { + return null; + } + protected final ClientInterceptor createCensusStatsClientInterceptor() { return InternalCensusStatsAccessor diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java new file mode 100644 index 0000000000..bf4a2fe366 --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java @@ -0,0 +1,146 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START; +import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT; +import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START; + +import com.google.protobuf.ByteString; +import io.grpc.alts.internal.HandshakerReq; +import io.grpc.alts.internal.HandshakerResp; +import io.grpc.alts.internal.HandshakerResult; +import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase; +import io.grpc.alts.internal.Identity; +import io.grpc.alts.internal.RpcProtocolVersions; +import io.grpc.alts.internal.RpcProtocolVersions.Version; +import io.grpc.stub.StreamObserver; +import java.util.Random; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A fake HandshakeService for ALTS integration testing in non-gcp environments. + * */ +public class AltsHandshakerTestService extends HandshakerServiceImplBase { + private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName()); + + private final Random random = new Random(); + private static final int FIXED_LENGTH_OUTPUT = 16; + private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT); + private final ByteString secret = data(128); + private State expectState = State.CLIENT_INIT; + + @Override + public StreamObserver doHandshake( + final StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(HandshakerReq value) { + log.log(Level.FINE, "request received: " + value); + synchronized (this) { + switch (expectState) { + case CLIENT_INIT: + checkState(CLIENT_START.equals(value.getReqOneofCase())); + HandshakerResp initClient = HandshakerResp.newBuilder() + .setOutFrames(fakeOutput) + .build(); + log.log(Level.FINE, "init client response " + initClient); + responseObserver.onNext(initClient); + expectState = State.SERVER_INIT; + break; + case SERVER_INIT: + checkState(SERVER_START.equals(value.getReqOneofCase())); + HandshakerResp initServer = HandshakerResp.newBuilder() + .setBytesConsumed(FIXED_LENGTH_OUTPUT) + .setOutFrames(fakeOutput) + .build(); + log.log(Level.FINE, "init server response" + initServer); + responseObserver.onNext(initServer); + expectState = State.CLIENT_FINISH; + break; + case CLIENT_FINISH: + checkState(NEXT.equals(value.getReqOneofCase())); + HandshakerResp resp = HandshakerResp.newBuilder() + .setResult(getResult()) + .setBytesConsumed(FIXED_LENGTH_OUTPUT) + .setOutFrames(fakeOutput) + .build(); + log.log(Level.FINE, "client finished response " + resp); + responseObserver.onNext(resp); + expectState = State.SERVER_FINISH; + break; + case SERVER_FINISH: + resp = HandshakerResp.newBuilder() + .setResult(getResult()) + .setBytesConsumed(FIXED_LENGTH_OUTPUT) + .build(); + log.log(Level.FINE, "server finished response " + resp); + responseObserver.onNext(resp); + expectState = State.CLIENT_INIT; + break; + default: + throw new RuntimeException("unknown state"); + } + } + } + + @Override + public void onError(Throwable t) { + log.log(Level.INFO, "onError " + t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + private HandshakerResult getResult() { + return HandshakerResult.newBuilder().setApplicationProtocol("grpc") + .setRecordProtocol("ALTSRP_GCM_AES128_REKEY") + .setKeyData(secret) + .setMaxFrameSize(131072) + .setPeerIdentity(Identity.newBuilder() + .setServiceAccount("123456789-compute@developer.gserviceaccount.com") + .build()) + .setPeerRpcVersions(RpcProtocolVersions.newBuilder() + .setMaxRpcVersion(Version.newBuilder() + .setMajor(2).setMinor(1) + .build()) + .setMinRpcVersion(Version.newBuilder() + .setMajor(2).setMinor(1) + .build()) + .build()) + .build(); + } + + private ByteString data(int len) { + byte[] k = new byte[len]; + random.nextBytes(k); + return ByteString.copyFrom(k); + } + + private enum State { + CLIENT_INIT, + SERVER_INIT, + CLIENT_FINISH, + SERVER_FINISH + } +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index 82a379c212..3ca35dfb3a 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -21,8 +21,10 @@ import com.google.common.io.Files; import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; import io.grpc.TlsChannelCredentials; import io.grpc.alts.AltsChannelCredentials; import io.grpc.alts.ComputeEngineChannelCredentials; @@ -42,6 +44,7 @@ import java.io.File; import java.io.FileInputStream; import java.nio.charset.Charset; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; /** * Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs @@ -83,6 +86,7 @@ public class TestServiceClient { private String serviceAccountKeyFile; private String oauthScope; private boolean fullStreamDecompression; + private int localHandshakerPort = -1; private Tester tester = new Tester(); @@ -141,6 +145,8 @@ public class TestServiceClient { oauthScope = value; } else if ("full_stream_decompression".equals(key)) { fullStreamDecompression = Boolean.parseBoolean(value); + } else if ("local_handshaker_port".equals(key)) { + localHandshakerPort = Integer.parseInt(value); } else { System.err.println("Unknown argument: " + key); usage = true; @@ -165,6 +171,9 @@ public class TestServiceClient { + "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + "\n Default " + c.useAlts + + "\n --local_handshaker_port=PORT" + + "\n Use local ALTS handshaker service on the specified " + + "\n port for testing. Only effective when --use_alts=true." + "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism." + "\n Enabling h2c Upgrade will disable TLS." + "\n Default " + c.useH2cUpgrade @@ -398,7 +407,13 @@ public class TestServiceClient { } else if (useAlts) { useGeneric = true; // Retain old behavior; avoids erroring if incompatible - channelCredentials = AltsChannelCredentials.create(); + if (localHandshakerPort > -1) { + channelCredentials = AltsChannelCredentials.newBuilder() + .enableUntrustedAltsForTesting() + .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build(); + } else { + channelCredentials = AltsChannelCredentials.create(); + } } else if (useTls) { if (!useTestCa) { @@ -475,6 +490,18 @@ public class TestServiceClient { // TODO(zhangkun83): remove this override once the said issue is fixed. return false; } + + @Override + @Nullable + protected ServerBuilder getHandshakerServerBuilder() { + if (localHandshakerPort > -1) { + return Grpc.newServerBuilderForPort(localHandshakerPort, + InsecureServerCredentials.create()) + .addService(new AltsHandshakerTestService()); + } else { + return null; + } + } } private static String validTestCasesHelpText() { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index 2a5c0ebe55..19946ec4a7 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -70,6 +70,7 @@ public class TestServiceServer { private ScheduledExecutorService executor; private Server server; + private int localHandshakerPort = -1; @VisibleForTesting void parseArgs(String[] args) { @@ -98,6 +99,8 @@ public class TestServiceServer { useTls = Boolean.parseBoolean(value); } else if ("use_alts".equals(key)) { useAlts = Boolean.parseBoolean(value); + } else if ("local_handshaker_port".equals(key)) { + localHandshakerPort = Integer.parseInt(value); } else if ("grpc_version".equals(key)) { if (!"2".equals(value)) { System.err.println("Only grpc version 2 is supported"); @@ -122,6 +125,9 @@ public class TestServiceServer { + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + "\n Default " + s.useAlts + + "\n --local_handshaker_port=PORT" + + "\n Use local ALTS handshaker service on the specified port " + + "\n for testing. Only effective when --use_alts=true." ); System.exit(1); } @@ -132,7 +138,13 @@ public class TestServiceServer { executor = Executors.newSingleThreadScheduledExecutor(); ServerCredentials serverCreds; if (useAlts) { - serverCreds = AltsServerCredentials.create(); + if (localHandshakerPort > -1) { + serverCreds = AltsServerCredentials.newBuilder() + .enableUntrustedAltsForTesting() + .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build(); + } else { + serverCreds = AltsServerCredentials.create(); + } } else if (useTls) { serverCreds = TlsServerCredentials.create( TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java new file mode 100644 index 0000000000..c6c1d2b3e7 --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.ByteString; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerCredentials; +import io.grpc.alts.AltsChannelCredentials; +import io.grpc.alts.AltsServerCredentials; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.concurrent.DefaultThreadFactory; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AltsHandshakerTest { + @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private Server handshakerServer; + private Server testServer; + private ManagedChannel channel; + + private void startAltsServer() throws Exception { + ServerCredentials serverCredentials = AltsServerCredentials.newBuilder() + .enableUntrustedAltsForTesting() + .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort()) + .build(); + testServer = grpcCleanup.register( + Grpc.newServerBuilderForPort(0, serverCredentials) + .addService(new TestServiceGrpc.TestServiceImplBase() { + @Override + public void unaryCall(SimpleRequest request, StreamObserver so) { + so.onNext(SimpleResponse.getDefaultInstance()); + so.onCompleted(); + } + }) + .build()) + .start(); + } + + @Before + public void setup() throws Exception { + // create new EventLoopGroups to avoid deadlock at server side handshake negotiation, e.g. + // happens when handshakerServer and testServer child channels are on the same eventloop. + handshakerServer = grpcCleanup.register(NettyServerBuilder.forPort(0) + .bossEventLoopGroup( + new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-boss"))) + .workerEventLoopGroup( + new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-worker"))) + .channelType(NioServerSocketChannel.class) + .addService(new AltsHandshakerTestService()) + .build()).start(); + startAltsServer(); + + ChannelCredentials channelCredentials = AltsChannelCredentials.newBuilder() + .enableUntrustedAltsForTesting() + .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort()).build(); + channel = grpcCleanup.register( + Grpc.newChannelBuilderForAddress("localhost", testServer.getPort(), channelCredentials) + .build()); + } + + @Test + public void testAlts() { + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); + final SimpleRequest request = SimpleRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10]))) + .build(); + assertEquals(SimpleResponse.getDefaultInstance(), blockingStub.unaryCall(request)); + } +}