xds: implement STS based OAuth 2.0 credentials exchange (#7232)

This commit is contained in:
sanjaypujare 2020-07-22 16:36:38 -07:00 committed by GitHub
parent e4215b422d
commit c60f5ff95b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 377 additions and 0 deletions

View File

@ -23,6 +23,7 @@ dependencies {
project(':grpc-stub'), project(':grpc-stub'),
project(':grpc-core'), project(':grpc-core'),
project(':grpc-services'), project(':grpc-services'),
project(':grpc-auth'),
project(path: ':grpc-alts', configuration: 'shadow'), project(path: ':grpc-alts', configuration: 'shadow'),
libraries.gson, libraries.gson,
libraries.re2j libraries.re2j

View File

@ -0,0 +1,169 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sts;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.http.HttpStatusCodes;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.client.http.json.JsonHttpContent;
import com.google.api.client.json.JsonObjectParser;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.client.util.GenericData;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
// TODO(sanjaypujare): replace with the official implementation from google-auth once ready
/** Implementation of OAuth2 Token Exchange as per https://tools.ietf.org/html/rfc8693. */
public final class StsCredentials extends GoogleCredentials {
private static final long serialVersionUID = 6647041424685484932L;
private static final HttpTransportFactory defaultHttpTransportFactory =
new DefaultHttpTransportFactory();
private static final String CLOUD_PLATFORM_SCOPE =
"https://www.googleapis.com/auth/cloud-platform";
private final String sourceCredentialsFileLocation;
private final String identityTokenEndpoint;
private final String audience;
private transient HttpTransportFactory transportFactory;
private StsCredentials(
String identityTokenEndpoint,
String audience,
String sourceCredentialsFileLocation,
HttpTransportFactory transportFactory) {
this.identityTokenEndpoint = identityTokenEndpoint;
this.audience = audience;
this.sourceCredentialsFileLocation = sourceCredentialsFileLocation;
this.transportFactory = transportFactory;
}
/**
* Creates an StsCredentials.
*
* @param identityTokenEndpoint URL of the token exchange service to use.
* @param audience Audience to use in the STS request.
* @param sourceCredentialsFileLocation file-system location that contains the
* source creds e.g. JWT contents.
*/
public static StsCredentials create(
String identityTokenEndpoint, String audience, String sourceCredentialsFileLocation) {
return create(
identityTokenEndpoint,
audience,
sourceCredentialsFileLocation,
getFromServiceLoader(HttpTransportFactory.class, defaultHttpTransportFactory));
}
@VisibleForTesting
static StsCredentials create(
String identityTokenEndpoint,
String audience,
String sourceCredentialsFileLocation,
HttpTransportFactory transportFactory) {
return new StsCredentials(
identityTokenEndpoint, audience, sourceCredentialsFileLocation, transportFactory);
}
@Override
public AccessToken refreshAccessToken() throws IOException {
AccessToken tok = getSourceAccessTokenFromFileLocation();
HttpTransport httpTransport = this.transportFactory.create();
JsonObjectParser parser = new JsonObjectParser(JacksonFactory.getDefaultInstance());
HttpRequestFactory requestFactory = httpTransport.createRequestFactory();
GenericUrl url = new GenericUrl(identityTokenEndpoint);
Map<String, String> params = new HashMap<>();
params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange");
params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt");
params.put("requested_token_type", "urn:ietf:params:oauth:token-type:access_token");
params.put("subject_token", tok.getTokenValue());
params.put("scope", CLOUD_PLATFORM_SCOPE);
params.put("audience", audience);
HttpContent content = new JsonHttpContent(parser.getJsonFactory(), params);
HttpRequest request = requestFactory.buildPostRequest(url, content);
request.setParser(parser);
HttpResponse response = null;
try {
response = request.execute();
} catch (IOException e) {
throw new IOException("Error requesting access token", e);
}
if (response.getStatusCode() != HttpStatusCodes.STATUS_CODE_OK) {
throw new IOException("Error getting access token: " + getStatusString(response));
}
GenericData responseData = null;
try {
responseData = response.parseAs(GenericData.class);
} finally {
response.disconnect();
}
String access_token = (String) responseData.get("access_token");
Date expiryTime = null; // just in case expired_in value is not present
if (responseData.containsKey("expires_in")) {
expiryTime = new Date(System.currentTimeMillis()
+ ((BigDecimal) responseData.get("expires_in")).longValue() * 1000L);
}
return new AccessToken(access_token, expiryTime);
}
private AccessToken getSourceAccessTokenFromFileLocation() throws IOException {
return new AccessToken(
Files.asCharSource(new File(sourceCredentialsFileLocation), StandardCharsets.UTF_8).read(),
null);
}
private static String getStatusString(HttpResponse response) {
return response.getStatusCode() + " : " + response.getStatusMessage();
}
@Override
public Builder toBuilder() {
throw new UnsupportedOperationException("toBuilder not supported");
}
private static class DefaultHttpTransportFactory implements HttpTransportFactory {
private static final HttpTransport netHttpTransport = new NetHttpTransport();
@Override
public HttpTransport create() {
return netHttpTransport;
}
}
}

View File

@ -0,0 +1,207 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sts;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.api.client.http.HttpStatusCodes;
import com.google.api.client.http.json.JsonHttpContent;
import com.google.api.client.testing.http.MockHttpTransport;
import com.google.api.client.testing.http.MockLowLevelHttpRequest;
import com.google.api.client.testing.http.MockLowLevelHttpResponse;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AccessToken;
import com.google.common.collect.Range;
import com.google.common.io.Files;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.CallCredentials;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.auth.MoreCallCredentials;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
/** Tests for {@link StsCredentials}. */
@RunWith(JUnit4.class)
public class StsCredentialsTest {
public static final String AUDIENCE_VALUE =
"identitynamespace:my-trust-domain:my-identity-provider";
public static final String STS_URL = "https://securetoken.googleapis.com/v1/identitybindingtoken";
private static final String TOKEN_FILE_NAME = "istio-token.txt";
static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);
@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();
private String currentFileContents;
private File tempTokenFile;
@Before
public void setUp() throws IOException {
tempTokenFile = tempFolder.newFile(TOKEN_FILE_NAME);
currentFileContents = "test-token-content-time-" + System.currentTimeMillis();
Files.write(currentFileContents.getBytes(StandardCharsets.UTF_8), tempTokenFile);
}
@SuppressWarnings("unchecked")
@Test
public void testStsRequestResponse() throws IOException {
MockHttpTransport.Builder builder = new MockHttpTransport.Builder();
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setContent(MOCK_RESPONSE);
builder.setLowLevelHttpResponse(response);
MockHttpTransport httpTransport = builder.build();
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials =
StsCredentials.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
AccessToken token = stsCredentials.refreshAccessToken();
assertThat(token).isNotNull();
assertThat(token.getTokenValue()).isEqualTo(ACCESS_TOKEN);
long diff = token.getExpirationTime().getTime() - System.currentTimeMillis();
assertThat(diff).isIn(Range.closed(3550000L, 3650000L));
MockLowLevelHttpRequest request = httpTransport.getLowLevelHttpRequest();
assertThat(request.getUrl()).isEqualTo(STS_URL);
assertThat(request.getContentType()).isEqualTo("application/json; charset=UTF-8");
assertThat(request.getStreamingContent()).isInstanceOf(JsonHttpContent.class);
Map<String, Object> map =
(Map<String, Object>) ((JsonHttpContent) request.getStreamingContent()).getData();
assertThat(map.get("subject_token_type")).isEqualTo("urn:ietf:params:oauth:token-type:jwt");
assertThat(map.get("grant_type")).isEqualTo("urn:ietf:params:oauth:grant-type:token-exchange");
assertThat(map.get("subject_token")).isEqualTo(currentFileContents);
assertThat(map.get("requested_token_type"))
.isEqualTo("urn:ietf:params:oauth:token-type:access_token");
assertThat(map.get("scope")).isEqualTo("https://www.googleapis.com/auth/cloud-platform");
assertThat(map.get("audience")).isEqualTo(AUDIENCE_VALUE);
}
@Test
public void stsCredentialsInCallCreds() throws IOException {
MockHttpTransport.Builder builder = new MockHttpTransport.Builder();
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setContent(MOCK_RESPONSE);
builder.setLowLevelHttpResponse(response);
MockHttpTransport httpTransport = builder.build();
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials =
StsCredentials.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
CallCredentials callCreds = MoreCallCredentials.from(stsCredentials);
CallCredentials.RequestInfo requestInfo = mock(CallCredentials.RequestInfo.class);
when(requestInfo.getSecurityLevel()).thenReturn(SecurityLevel.PRIVACY_AND_INTEGRITY);
when(requestInfo.getAuthority()).thenReturn("auth");
MethodDescriptor.Marshaller<?> requestMarshaller = mock(MethodDescriptor.Marshaller.class);
MethodDescriptor.Marshaller<?> responseMarshaller = mock(MethodDescriptor.Marshaller.class);
MethodDescriptor.Builder<?, ?> mBuilder =
MethodDescriptor.newBuilder(requestMarshaller, responseMarshaller);
mBuilder.setType(MethodDescriptor.MethodType.UNARY);
mBuilder.setFullMethodName("service/method");
MethodDescriptor<?, ?> methodDescriptor = mBuilder.build();
doReturn(methodDescriptor).when(requestInfo).getMethodDescriptor();
CallCredentials.MetadataApplier applier = mock(CallCredentials.MetadataApplier.class);
callCreds.applyRequestMetadata(requestInfo, MoreExecutors.directExecutor(), applier);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(null);
verify(applier).apply(metadataCaptor.capture());
Metadata metadata = metadataCaptor.getValue();
assertThat(metadata).isNotNull();
String authValue = metadata.get(KEY_FOR_AUTHORIZATION);
assertThat(authValue).isEqualTo("Bearer " + ACCESS_TOKEN);
}
@Test
public void testStsRequest_exception() throws IOException {
MockHttpTransport.Builder builder = new MockHttpTransport.Builder();
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(HttpStatusCodes.STATUS_CODE_BAD_REQUEST);
response.setContent(MOCK_RESPONSE);
builder.setLowLevelHttpResponse(response);
MockHttpTransport httpTransport = builder.build();
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials =
StsCredentials.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try {
stsCredentials.refreshAccessToken();
fail("exception expected");
} catch (IOException ioe) {
assertThat(ioe.getMessage()).isEqualTo("Error requesting access token");
}
}
@Test
public void testStsRequest_nonSuccessCode() throws IOException {
MockHttpTransport.Builder builder = new MockHttpTransport.Builder();
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setStatusCode(HttpStatusCodes.STATUS_CODE_NO_CONTENT);
response.setContent(MOCK_RESPONSE);
builder.setLowLevelHttpResponse(response);
MockHttpTransport httpTransport = builder.build();
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
when(httpTransportFactory.create()).thenReturn(httpTransport);
StsCredentials stsCredentials =
StsCredentials.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try {
stsCredentials.refreshAccessToken();
fail("exception expected");
} catch (IOException ioe) {
assertThat(ioe.getMessage()).isEqualTo("Error getting access token: 204 : null");
}
}
@Test
public void toBuilder_unsupportedException() {
HttpTransportFactory httpTransportFactory = mock(HttpTransportFactory.class);
StsCredentials stsCredentials =
StsCredentials.create(
STS_URL, AUDIENCE_VALUE, tempTokenFile.getAbsolutePath(), httpTransportFactory);
try {
stsCredentials.toBuilder();
fail("exception expected");
} catch (UnsupportedOperationException uoe) {
assertThat(uoe.getMessage()).isEqualTo("toBuilder not supported");
}
}
private static final String ACCESS_TOKEN = "eyJhbGciOiJSU";
private static final String MOCK_RESPONSE =
"{\"access_token\": \""
+ ACCESS_TOKEN
+ "\",\n"
+ " \"issued_token_type\": \"urn:ietf:params:oauth:token-type:access_token\",\n"
+ " \"token_type\": \"Bearer\",\n"
+ " \"expires_in\": 3600\n"
+ "}";
}