s2a: fix flake in FakeS2AServerTest (#11673)

While here:
 * add an awaitTermination to after calling shutdown on server
 * don't use port picker

Fixes #11648
This commit is contained in:
Riya Mehta 2024-11-08 10:25:49 -08:00 committed by GitHub
parent 5081e60626
commit 546efd79f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 23 additions and 19 deletions

View File

@ -20,13 +20,13 @@ import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.concurrent.TimeUnit.SECONDS;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Server; import io.grpc.Server;
import io.grpc.ServerBuilder; import io.grpc.ServerBuilder;
import io.grpc.benchmarks.Utils;
import io.grpc.s2a.internal.handshaker.ValidatePeerCertificateChainReq.VerificationMode; import io.grpc.s2a.internal.handshaker.ValidatePeerCertificateChainReq.VerificationMode;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
import java.io.IOException; import java.io.IOException;
@ -49,51 +49,52 @@ public final class FakeS2AServerTest {
private static final ImmutableList<ByteString> FAKE_CERT_DER_CHAIN = private static final ImmutableList<ByteString> FAKE_CERT_DER_CHAIN =
ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII))); ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII)));
private int port;
private String serverAddress; private String serverAddress;
private SessionResp response = null;
private Server fakeS2AServer; private Server fakeS2AServer;
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
port = Utils.pickUnusedPort(); fakeS2AServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build();
fakeS2AServer = ServerBuilder.forPort(port).addService(new FakeS2AServer()).build();
fakeS2AServer.start(); fakeS2AServer.start();
serverAddress = String.format("localhost:%d", port); serverAddress = String.format("localhost:%d", fakeS2AServer.getPort());
} }
@After @After
public void tearDown() { public void tearDown() throws Exception {
fakeS2AServer.shutdown(); fakeS2AServer.shutdown();
fakeS2AServer.awaitTermination(10, SECONDS);
} }
@Test @Test
public void callS2AServerOnce_getTlsConfiguration_returnsValidResult() public void callS2AServerOnce_getTlsConfiguration_returnsValidResult()
throws InterruptedException, IOException { throws InterruptedException, IOException, java.util.concurrent.ExecutionException {
ExecutorService executor = Executors.newSingleThreadExecutor(); ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress); logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel = ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create())
.executor(executor) .executor(executor)
.build(); .build();
SettableFuture<SessionResp> respFuture = SettableFuture.create();
try { try {
S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel);
StreamObserver<SessionReq> requestObserver = StreamObserver<SessionReq> requestObserver =
asyncStub.setUpSession( asyncStub.setUpSession(
new StreamObserver<SessionResp>() { new StreamObserver<SessionResp>() {
SessionResp recvResp;
@Override @Override
public void onNext(SessionResp resp) { public void onNext(SessionResp resp) {
response = resp; recvResp = resp;
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
throw new RuntimeException(t); respFuture.setException(t);
} }
@Override @Override
public void onCompleted() {} public void onCompleted() {
respFuture.set(recvResp);
}
}); });
try { try {
requestObserver.onNext( requestObserver.onNext(
@ -138,36 +139,39 @@ public final class FakeS2AServerTest {
.addCiphersuites( .addCiphersuites(
Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)))
.build(); .build();
assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected);
} }
@Test @Test
public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult() public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult()
throws InterruptedException { throws InterruptedException, java.util.concurrent.ExecutionException {
ExecutorService executor = Executors.newSingleThreadExecutor(); ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress); logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel = ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create())
.executor(executor) .executor(executor)
.build(); .build();
SettableFuture<SessionResp> respFuture = SettableFuture.create();
try { try {
S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel);
StreamObserver<SessionReq> requestObserver = StreamObserver<SessionReq> requestObserver =
asyncStub.setUpSession( asyncStub.setUpSession(
new StreamObserver<SessionResp>() { new StreamObserver<SessionResp>() {
private SessionResp recvResp;
@Override @Override
public void onNext(SessionResp resp) { public void onNext(SessionResp resp) {
response = resp; recvResp = resp;
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
throw new RuntimeException(t); respFuture.setException(t);
} }
@Override @Override
public void onCompleted() {} public void onCompleted() {
respFuture.set(recvResp);
}
}); });
try { try {
requestObserver.onNext( requestObserver.onNext(
@ -200,7 +204,7 @@ public final class FakeS2AServerTest {
ValidatePeerCertificateChainResp.newBuilder() ValidatePeerCertificateChainResp.newBuilder()
.setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS))
.build(); .build();
assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected);
} }
@Test @Test