interop-testing: add fake altsHandshakerService for test (#7847)

This commit is contained in:
yifeizhuang 2021-02-09 16:56:26 -08:00 committed by GitHub
parent 514101d90c
commit 7f7821c616
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 308 additions and 2 deletions

View File

@ -170,6 +170,7 @@ public abstract class AbstractInteropTest {
private ScheduledExecutorService testServiceExecutor; private ScheduledExecutorService testServiceExecutor;
private Server server; private Server server;
private Server handshakerServer;
private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers = private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<>(); new LinkedBlockingQueue<>();
@ -223,6 +224,7 @@ public abstract class AbstractInteropTest {
protected static final Empty EMPTY = Empty.getDefaultInstance(); protected static final Empty EMPTY = Empty.getDefaultInstance();
private void startServer() { private void startServer() {
maybeStartHandshakerServer();
ServerBuilder<?> builder = getServerBuilder(); ServerBuilder<?> builder = getServerBuilder();
if (builder == null) { if (builder == null) {
server = null; server = null;
@ -251,6 +253,17 @@ public abstract class AbstractInteropTest {
} }
} }
private void maybeStartHandshakerServer() {
ServerBuilder<?> handshakerServerBuilder = getHandshakerServerBuilder();
if (handshakerServerBuilder != null) {
try {
handshakerServer = handshakerServerBuilder.build().start();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
private void stopServer() { private void stopServer() {
if (server != null) { if (server != null) {
server.shutdownNow(); server.shutdownNow();
@ -258,6 +271,9 @@ public abstract class AbstractInteropTest {
if (testServiceExecutor != null) { if (testServiceExecutor != null) {
testServiceExecutor.shutdown(); testServiceExecutor.shutdown();
} }
if (handshakerServer != null) {
handshakerServer.shutdownNow();
}
} }
@VisibleForTesting @VisibleForTesting
@ -348,6 +364,11 @@ public abstract class AbstractInteropTest {
return null; return null;
} }
@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
return null;
}
protected final ClientInterceptor createCensusStatsClientInterceptor() { protected final ClientInterceptor createCensusStatsClientInterceptor() {
return return
InternalCensusStatsAccessor InternalCensusStatsAccessor

View File

@ -0,0 +1,146 @@
/*
* Copyright 2021 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.testing.integration;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START;
import com.google.protobuf.ByteString;
import io.grpc.alts.internal.HandshakerReq;
import io.grpc.alts.internal.HandshakerResp;
import io.grpc.alts.internal.HandshakerResult;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase;
import io.grpc.alts.internal.Identity;
import io.grpc.alts.internal.RpcProtocolVersions;
import io.grpc.alts.internal.RpcProtocolVersions.Version;
import io.grpc.stub.StreamObserver;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* A fake HandshakeService for ALTS integration testing in non-gcp environments.
* */
public class AltsHandshakerTestService extends HandshakerServiceImplBase {
private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName());
private final Random random = new Random();
private static final int FIXED_LENGTH_OUTPUT = 16;
private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT);
private final ByteString secret = data(128);
private State expectState = State.CLIENT_INIT;
@Override
public StreamObserver<HandshakerReq> doHandshake(
final StreamObserver<HandshakerResp> responseObserver) {
return new StreamObserver<HandshakerReq>() {
@Override
public void onNext(HandshakerReq value) {
log.log(Level.FINE, "request received: " + value);
synchronized (this) {
switch (expectState) {
case CLIENT_INIT:
checkState(CLIENT_START.equals(value.getReqOneofCase()));
HandshakerResp initClient = HandshakerResp.newBuilder()
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init client response " + initClient);
responseObserver.onNext(initClient);
expectState = State.SERVER_INIT;
break;
case SERVER_INIT:
checkState(SERVER_START.equals(value.getReqOneofCase()));
HandshakerResp initServer = HandshakerResp.newBuilder()
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init server response" + initServer);
responseObserver.onNext(initServer);
expectState = State.CLIENT_FINISH;
break;
case CLIENT_FINISH:
checkState(NEXT.equals(value.getReqOneofCase()));
HandshakerResp resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "client finished response " + resp);
responseObserver.onNext(resp);
expectState = State.SERVER_FINISH;
break;
case SERVER_FINISH:
resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.build();
log.log(Level.FINE, "server finished response " + resp);
responseObserver.onNext(resp);
expectState = State.CLIENT_INIT;
break;
default:
throw new RuntimeException("unknown state");
}
}
}
@Override
public void onError(Throwable t) {
log.log(Level.INFO, "onError " + t);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
private HandshakerResult getResult() {
return HandshakerResult.newBuilder().setApplicationProtocol("grpc")
.setRecordProtocol("ALTSRP_GCM_AES128_REKEY")
.setKeyData(secret)
.setMaxFrameSize(131072)
.setPeerIdentity(Identity.newBuilder()
.setServiceAccount("123456789-compute@developer.gserviceaccount.com")
.build())
.setPeerRpcVersions(RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.setMinRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.build())
.build();
}
private ByteString data(int len) {
byte[] k = new byte[len];
random.nextBytes(k);
return ByteString.copyFrom(k);
}
private enum State {
CLIENT_INIT,
SERVER_INIT,
CLIENT_FINISH,
SERVER_FINISH
}
}

View File

@ -21,8 +21,10 @@ import com.google.common.io.Files;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.ServerBuilder;
import io.grpc.TlsChannelCredentials; import io.grpc.TlsChannelCredentials;
import io.grpc.alts.AltsChannelCredentials; import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelCredentials; import io.grpc.alts.ComputeEngineChannelCredentials;
@ -42,6 +44,7 @@ import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
/** /**
* Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs * Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
@ -83,6 +86,7 @@ public class TestServiceClient {
private String serviceAccountKeyFile; private String serviceAccountKeyFile;
private String oauthScope; private String oauthScope;
private boolean fullStreamDecompression; private boolean fullStreamDecompression;
private int localHandshakerPort = -1;
private Tester tester = new Tester(); private Tester tester = new Tester();
@ -141,6 +145,8 @@ public class TestServiceClient {
oauthScope = value; oauthScope = value;
} else if ("full_stream_decompression".equals(key)) { } else if ("full_stream_decompression".equals(key)) {
fullStreamDecompression = Boolean.parseBoolean(value); fullStreamDecompression = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else { } else {
System.err.println("Unknown argument: " + key); System.err.println("Unknown argument: " + key);
usage = true; usage = true;
@ -165,6 +171,9 @@ public class TestServiceClient {
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls + "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useAlts + "\n Default " + c.useAlts
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified "
+ "\n port for testing. Only effective when --use_alts=true."
+ "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism." + "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism."
+ "\n Enabling h2c Upgrade will disable TLS." + "\n Enabling h2c Upgrade will disable TLS."
+ "\n Default " + c.useH2cUpgrade + "\n Default " + c.useH2cUpgrade
@ -398,7 +407,13 @@ public class TestServiceClient {
} else if (useAlts) { } else if (useAlts) {
useGeneric = true; // Retain old behavior; avoids erroring if incompatible useGeneric = true; // Retain old behavior; avoids erroring if incompatible
channelCredentials = AltsChannelCredentials.create(); if (localHandshakerPort > -1) {
channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
channelCredentials = AltsChannelCredentials.create();
}
} else if (useTls) { } else if (useTls) {
if (!useTestCa) { if (!useTestCa) {
@ -475,6 +490,18 @@ public class TestServiceClient {
// TODO(zhangkun83): remove this override once the said issue is fixed. // TODO(zhangkun83): remove this override once the said issue is fixed.
return false; return false;
} }
@Override
@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
if (localHandshakerPort > -1) {
return Grpc.newServerBuilderForPort(localHandshakerPort,
InsecureServerCredentials.create())
.addService(new AltsHandshakerTestService());
} else {
return null;
}
}
} }
private static String validTestCasesHelpText() { private static String validTestCasesHelpText() {

View File

@ -70,6 +70,7 @@ public class TestServiceServer {
private ScheduledExecutorService executor; private ScheduledExecutorService executor;
private Server server; private Server server;
private int localHandshakerPort = -1;
@VisibleForTesting @VisibleForTesting
void parseArgs(String[] args) { void parseArgs(String[] args) {
@ -98,6 +99,8 @@ public class TestServiceServer {
useTls = Boolean.parseBoolean(value); useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) { } else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value); useAlts = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else if ("grpc_version".equals(key)) { } else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) { if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported"); System.err.println("Only grpc version 2 is supported");
@ -122,6 +125,9 @@ public class TestServiceServer {
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts + "\n Default " + s.useAlts
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified port "
+ "\n for testing. Only effective when --use_alts=true."
); );
System.exit(1); System.exit(1);
} }
@ -132,7 +138,13 @@ public class TestServiceServer {
executor = Executors.newSingleThreadScheduledExecutor(); executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds; ServerCredentials serverCreds;
if (useAlts) { if (useAlts) {
serverCreds = AltsServerCredentials.create(); if (localHandshakerPort > -1) {
serverCreds = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
serverCreds = AltsServerCredentials.create();
}
} else if (useTls) { } else if (useTls) {
serverCreds = TlsServerCredentials.create( serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));

View File

@ -0,0 +1,100 @@
/*
* Copyright 2021 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals;
import com.google.protobuf.ByteString;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerCredentials;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.AltsServerCredentials;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class AltsHandshakerTest {
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private Server handshakerServer;
private Server testServer;
private ManagedChannel channel;
private void startAltsServer() throws Exception {
ServerCredentials serverCredentials = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort())
.build();
testServer = grpcCleanup.register(
Grpc.newServerBuilderForPort(0, serverCredentials)
.addService(new TestServiceGrpc.TestServiceImplBase() {
@Override
public void unaryCall(SimpleRequest request, StreamObserver<SimpleResponse> so) {
so.onNext(SimpleResponse.getDefaultInstance());
so.onCompleted();
}
})
.build())
.start();
}
@Before
public void setup() throws Exception {
// create new EventLoopGroups to avoid deadlock at server side handshake negotiation, e.g.
// happens when handshakerServer and testServer child channels are on the same eventloop.
handshakerServer = grpcCleanup.register(NettyServerBuilder.forPort(0)
.bossEventLoopGroup(
new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-boss")))
.workerEventLoopGroup(
new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-worker")))
.channelType(NioServerSocketChannel.class)
.addService(new AltsHandshakerTestService())
.build()).start();
startAltsServer();
ChannelCredentials channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort()).build();
channel = grpcCleanup.register(
Grpc.newChannelBuilderForAddress("localhost", testServer.getPort(), channelCredentials)
.build());
}
@Test
public void testAlts() {
TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel);
final SimpleRequest request = SimpleRequest.newBuilder()
.setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10])))
.build();
assertEquals(SimpleResponse.getDefaultInstance(), blockingStub.unaryCall(request));
}
}