diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java index 2d089183f9..88dfd62675 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java @@ -24,19 +24,21 @@ import java.util.Optional; /** Retrieves the authentication mechanism for a given local identity. */ @Immutable final class GetAuthenticationMechanisms { - private static final Optional TOKEN_MANAGER = AccessTokenManager.create(); + static final Optional TOKEN_MANAGER = AccessTokenManager.create(); /** * Retrieves the authentication mechanism for a given local identity. * * @param localIdentity the identity for which to fetch a token. + * @param tokenManager the token manager to use for fetching tokens. * @return an {@link AuthenticationMechanism} for the given local identity. */ - static Optional getAuthMechanism(Optional localIdentity) { - if (!TOKEN_MANAGER.isPresent()) { + static Optional getAuthMechanism(Optional localIdentity, + Optional tokenManager) { + if (!tokenManager.isPresent()) { return Optional.empty(); } - AccessTokenManager manager = TOKEN_MANAGER.get(); + AccessTokenManager manager = tokenManager.get(); // If no identity is provided, fetch the default access token and DO NOT attach an identity // to the request. if (!localIdentity.isPresent()) { diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java index 3e5481daa9..153f4de691 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java @@ -105,7 +105,8 @@ final class SslContextFactory { reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); } Optional authMechanism = - GetAuthenticationMechanisms.getAuthMechanism(localIdentity); + GetAuthenticationMechanisms.getAuthMechanism(localIdentity, + GetAuthenticationMechanisms.TOKEN_MANAGER); if (authMechanism.isPresent()) { reqBuilder.addAuthenticationMechanisms(authMechanism.get()); } diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java index d17d9ba99e..d69d84bf45 100644 --- a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java @@ -18,9 +18,11 @@ package io.grpc.s2a.internal.handshaker; import com.google.common.truth.Expect; import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.tokenmanager.AccessTokenManager; import io.grpc.s2a.internal.handshaker.tokenmanager.SingleTokenFetcher; import java.util.Optional; import org.junit.AfterClass; +import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -33,6 +35,7 @@ public final class GetAuthenticationMechanismsTest { @Rule public final Expect expect = Expect.create(); private static final String TOKEN = "access_token"; private static String originalAccessToken; + private Optional tokenManager; @BeforeClass public static void setUpClass() { @@ -41,6 +44,11 @@ public final class GetAuthenticationMechanismsTest { SingleTokenFetcher.setAccessToken(TOKEN); } + @Before + public void setUp() { + tokenManager = AccessTokenManager.create(); + } + @AfterClass public static void tearDownClass() { SingleTokenFetcher.setAccessToken(originalAccessToken); @@ -49,7 +57,7 @@ public final class GetAuthenticationMechanismsTest { @Test public void getAuthMechanisms_emptyIdentity_success() { expect - .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty())) + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty(), tokenManager)) .isEqualTo( Optional.of(AuthenticationMechanism.newBuilder().setToken("access_token").build())); } @@ -58,7 +66,7 @@ public final class GetAuthenticationMechanismsTest { public void getAuthMechanisms_nonEmptyIdentity_success() { S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); expect - .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity))) + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity), tokenManager)) .isEqualTo( Optional.of( AuthenticationMechanism.newBuilder()