credentials/sts: PerRPCCreds Implementation (#3696)

This commit is contained in:
Easwar Swaminathan 2020-07-09 12:15:45 -07:00 committed by GitHub
parent 9af290fac4
commit e8fb6c1752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1159 additions and 0 deletions

395
credentials/sts/sts.go Normal file
View File

@ -0,0 +1,395 @@
// +build go1.13
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package sts implements call credentials using STS (Security Token Service) as
// defined in https://tools.ietf.org/html/rfc8693.
//
// Experimental
//
// Notice: All APIs in this package are experimental and may be changed or
// removed in a later release.
package sts
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sync"
"time"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
)
const (
// HTTP request timeout set on the http.Client used to make STS requests.
stsRequestTimeout = 5 * time.Second
// If lifetime left in a cached token is lesser than this value, we fetch a
// new one instead of returning the current one.
minCachedTokenLifetime = 300 * time.Second
tokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
)
// For overriding in tests.
var (
loadSystemCertPool = x509.SystemCertPool
makeHTTPDoer = makeHTTPClient
readSubjectTokenFrom = ioutil.ReadFile
readActorTokenFrom = ioutil.ReadFile
)
// Options configures the parameters used for an STS based token exchange.
type Options struct {
// TokenExchangeServiceURI is the address of the server which implements STS
// token exchange functionality.
TokenExchangeServiceURI string // Required.
// Resource is a URI that indicates the target service or resource where the
// client intends to use the requested security token.
Resource string // Optional.
// Audience is the logical name of the target service where the client
// intends to use the requested security token
Audience string // Optional.
// Scope is a list of space-delimited, case-sensitive strings, that allow
// the client to specify the desired scope of the requested security token
// in the context of the service or resource where the token will be used.
// If this field is left unspecified, a default value of
// https://www.googleapis.com/auth/cloud-platform will be used.
Scope string // Optional.
// RequestedTokenType is an identifier, as described in
// https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of
// the requested security token.
RequestedTokenType string // Optional.
// SubjectTokenPath is a filesystem path which contains the security token
// that represents the identity of the party on behalf of whom the request
// is being made.
SubjectTokenPath string // Required.
// SubjectTokenType is an identifier, as described in
// https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of
// the security token in the "subject_token_path" parameter.
SubjectTokenType string // Required.
// ActorTokenPath is a security token that represents the identity of the
// acting party.
ActorTokenPath string // Optional.
// ActorTokenType is an identifier, as described in
// https://tools.ietf.org/html/rfc8693#section-3, that indicates the type of
// the the security token in the "actor_token_path" parameter.
ActorTokenType string // Optional.
}
// NewCredentials returns a new PerRPCCredentials implementation, configured
// using opts, which performs token exchange using STS.
func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) {
if err := validateOptions(opts); err != nil {
return nil, err
}
// Load the system roots to validate the certificate presented by the STS
// endpoint during the TLS handshake.
roots, err := loadSystemCertPool()
if err != nil {
return nil, err
}
return &callCreds{
opts: opts,
client: makeHTTPDoer(roots),
}, nil
}
// callCreds provides the implementation of call credentials based on an STS
// token exchange.
type callCreds struct {
opts Options
client httpDoer
// Cached accessToken to avoid an STS token exchange for every call to
// GetRequestMetadata.
mu sync.Mutex
tokenMetadata map[string]string
tokenExpiry time.Time
}
// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err)
}
// Holding the lock for the whole duration of the STS request and response
// processing ensures that concurrent RPCs don't end up in multiple
// requests being made.
c.mu.Lock()
defer c.mu.Unlock()
if md := c.cachedMetadata(); md != nil {
return md, nil
}
req, err := constructRequest(ctx, c.opts)
if err != nil {
return nil, err
}
respBody, err := sendRequest(c.client, req)
if err != nil {
return nil, err
}
ti, err := tokenInfoFromResponse(respBody)
if err != nil {
return nil, err
}
c.tokenMetadata = map[string]string{"Authorization": fmt.Sprintf("%s %s", ti.tokenType, ti.token)}
c.tokenExpiry = ti.expiryTime
return c.tokenMetadata, nil
}
// RequireTransportSecurity indicates whether the credentials requires
// transport security.
func (c *callCreds) RequireTransportSecurity() bool {
return true
}
// httpDoer wraps the single method on the http.Client type that we use. This
// helps with overriding in unittests.
type httpDoer interface {
Do(req *http.Request) (*http.Response, error)
}
func makeHTTPClient(roots *x509.CertPool) httpDoer {
return &http.Client{
Timeout: stsRequestTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: roots,
},
},
}
}
// validateOptions performs the following validation checks on opts:
// - tokenExchangeServiceURI is not empty
// - tokenExchangeServiceURI is a valid URI with a http(s) scheme
// - subjectTokenPath and subjectTokenType are not empty.
func validateOptions(opts Options) error {
if opts.TokenExchangeServiceURI == "" {
return errors.New("empty token_exchange_service_uri in options")
}
u, err := url.Parse(opts.TokenExchangeServiceURI)
if err != nil {
return err
}
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("scheme is not supported: %s. Only http(s) is supported", u.Scheme)
}
if opts.SubjectTokenPath == "" {
return errors.New("required field SubjectTokenPath is not specified")
}
if opts.SubjectTokenType == "" {
return errors.New("required field SubjectTokenType is not specified")
}
return nil
}
// cachedMetadata returns the cached metadata provided it is not going to
// expire anytime soon.
//
// Caller must hold c.mu.
func (c *callCreds) cachedMetadata() map[string]string {
now := time.Now()
// If the cached token has not expired and the lifetime remaining on that
// token is greater than the minimum value we are willing to accept, go
// ahead and use it.
if c.tokenExpiry.After(now) && c.tokenExpiry.Sub(now) > minCachedTokenLifetime {
return c.tokenMetadata
}
return nil
}
// constructRequest creates the STS request body in JSON based on the provided
// options.
// - Contents of the subjectToken are read from the file specified in
// options. If we encounter an error here, we bail out.
// - Contents of the actorToken are read from the file specified in options.
// If we encounter an error here, we ignore this field because this is
// optional.
// - Most of the other fields in the request come directly from options.
//
// A new HTTP request is created by calling http.NewRequestWithContext() and
// passing the provided context, thereby enforcing any timeouts specified in
// the latter.
func constructRequest(ctx context.Context, opts Options) (*http.Request, error) {
subToken, err := readSubjectTokenFrom(opts.SubjectTokenPath)
if err != nil {
return nil, err
}
reqScope := opts.Scope
if reqScope == "" {
reqScope = defaultCloudPlatformScope
}
reqParams := &requestParameters{
GrantType: tokenExchangeGrantType,
Resource: opts.Resource,
Audience: opts.Audience,
Scope: reqScope,
RequestedTokenType: opts.RequestedTokenType,
SubjectToken: string(subToken),
SubjectTokenType: opts.SubjectTokenType,
}
if opts.ActorTokenPath != "" {
actorToken, err := readActorTokenFrom(opts.ActorTokenPath)
if err != nil {
return nil, err
}
reqParams.ActorToken = string(actorToken)
reqParams.ActorTokenType = opts.ActorTokenType
}
jsonBody, err := json.Marshal(reqParams)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", opts.TokenExchangeServiceURI, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create http request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func sendRequest(client httpDoer, req *http.Request) ([]byte, error) {
// http.Client returns a non-nil error only if it encounters an error
// caused by client policy (such as CheckRedirect), or failure to speak
// HTTP (such as a network connectivity problem). A non-2xx status code
// doesn't cause an error.
resp, err := client.Do(req)
if err != nil {
return nil, err
}
// When the http.Client returns a non-nil error, it is the
// responsibility of the caller to read the response body till an EOF is
// encountered and to close it.
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusOK {
return body, nil
}
grpclog.Warningf("http status %d, body: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("http status %d, body: %s", resp.StatusCode, string(body))
}
func tokenInfoFromResponse(respBody []byte) (*tokenInfo, error) {
respData := &responseParameters{}
if err := json.Unmarshal(respBody, respData); err != nil {
return nil, fmt.Errorf("json.Unmarshal(%v): %v", respBody, err)
}
if respData.AccessToken == "" {
return nil, fmt.Errorf("empty accessToken in response (%v)", string(respBody))
}
return &tokenInfo{
tokenType: respData.TokenType,
token: respData.AccessToken,
expiryTime: time.Now().Add(time.Duration(respData.ExpiresIn) * time.Second),
}, nil
}
// requestParameters stores all STS request attributes defined in
// https://tools.ietf.org/html/rfc8693#section-2.1.
type requestParameters struct {
// REQUIRED. The value "urn:ietf:params:oauth:grant-type:token-exchange"
// indicates that a token exchange is being performed.
GrantType string `json:"grant_type"`
// OPTIONAL. Indicates the location of the target service or resource where
// the client intends to use the requested security token.
Resource string `json:"resource,omitempty"`
// OPTIONAL. The logical name of the target service where the client intends
// to use the requested security token.
Audience string `json:"audience,omitempty"`
// OPTIONAL. A list of space-delimited, case-sensitive strings, that allow
// the client to specify the desired scope of the requested security token
// in the context of the service or Resource where the token will be used.
Scope string `json:"scope,omitempty"`
// OPTIONAL. An identifier, for the type of the requested security token.
RequestedTokenType string `json:"requested_token_type,omitempty"`
// REQUIRED. A security token that represents the identity of the party on
// behalf of whom the request is being made.
SubjectToken string `json:"subject_token"`
// REQUIRED. An identifier, that indicates the type of the security token in
// the "subject_token" parameter.
SubjectTokenType string `json:"subject_token_type"`
// OPTIONAL. A security token that represents the identity of the acting
// party.
ActorToken string `json:"actor_token,omitempty"`
// An identifier, that indicates the type of the security token in the
// "actor_token" parameter.
ActorTokenType string `json:"actor_token_type,omitempty"`
}
// nesponseParameters stores all attributes sent as JSON in a successful STS
// response. These attributes are defined in
// https://tools.ietf.org/html/rfc8693#section-2.2.1.
type responseParameters struct {
// REQUIRED. The security token issued by the authorization server
// in response to the token exchange request.
AccessToken string `json:"access_token"`
// REQUIRED. An identifier, representation of the issued security token.
IssuedTokenType string `json:"issued_token_type"`
// REQUIRED. A case-insensitive value specifying the method of using the access
// token issued. It provides the client with information about how to utilize the
// access token to access protected resources.
TokenType string `json:"token_type"`
// RECOMMENDED. The validity lifetime, in seconds, of the token issued by the
// authorization server.
ExpiresIn int64 `json:"expires_in"`
// OPTIONAL, if the Scope of the issued security token is identical to the
// Scope requested by the client; otherwise, REQUIRED.
Scope string `json:"scope"`
// OPTIONAL. A refresh token will typically not be issued when the exchange is
// of one temporary credential (the subject_token) for a different temporary
// credential (the issued token) for use in some other context.
RefreshToken string `json:"refresh_token"`
}
// tokenInfo wraps the information received in a successful STS response.
type tokenInfo struct {
tokenType string
token string
expiryTime time.Time
}

764
credentials/sts/sts_test.go Normal file
View File

@ -0,0 +1,764 @@
// +build go1.13
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package sts
import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
)
const (
requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"
actorTokenPath = "/var/run/secrets/token.jwt"
actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
actorTokenContents = "actorToken.jwt.contents"
accessTokenContents = "access_token"
subjectTokenPath = "/var/run/secrets/token.jwt"
subjectTokenType = "urn:ietf:params:oauth:token-type:id_token"
subjectTokenContents = "subjectToken.jwt.contents"
serviceURI = "http://localhost"
exampleResource = "https://backend.example.com/api"
exampleAudience = "example-backend-service"
testScope = "https://www.googleapis.com/auth/monitoring"
)
var (
goodOptions = Options{
TokenExchangeServiceURI: serviceURI,
Audience: exampleAudience,
RequestedTokenType: requestedTokenType,
SubjectTokenPath: subjectTokenPath,
SubjectTokenType: subjectTokenType,
}
goodRequestParams = &requestParameters{
GrantType: tokenExchangeGrantType,
Audience: exampleAudience,
Scope: defaultCloudPlatformScope,
RequestedTokenType: requestedTokenType,
SubjectToken: subjectTokenContents,
SubjectTokenType: subjectTokenType,
}
goodMetadata = map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
}
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// A struct that implements AuthInfo interface and added to the context passed
// to GetRequestMetadata from tests.
type testAuthInfo struct {
credentials.CommonAuthInfo
}
func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}
func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
ri := credentials.RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
}
// errReader implements the io.Reader interface and returns an error from the
// Read method.
type errReader struct{}
func (r errReader) Read(b []byte) (n int, err error) {
return 0, errors.New("read error")
}
// We need a function to construct the response instead of simply declaring it
// as a variable since the the response body will be consumed by the
// credentials, and therefore we will need a new one everytime.
func makeGoodResponse() *http.Response {
respJSON, _ := json.Marshal(responseParameters{
AccessToken: accessTokenContents,
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
return &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: respBody,
}
}
// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials
// code under test. It makes the http.Request made by the credentials available
// through a channel, and makes it possible to inject various responses.
type fakeHTTPDoer struct {
reqCh *testutils.Channel
respCh *testutils.Channel
err error
}
func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
fc.reqCh.Send(req)
val, err := fc.respCh.Receive()
if err != nil {
return nil, err
}
return val.(*http.Response), fc.err
}
// Overrides the http.Client with a fakeClient which sends a good response.
func overrideHTTPClientGood() (*fakeHTTPDoer, func()) {
fc := &fakeHTTPDoer{
reqCh: testutils.NewChannel(),
respCh: testutils.NewChannel(),
}
fc.respCh.Send(makeGoodResponse())
origMakeHTTPDoer := makeHTTPDoer
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
}
// Overrides the http.Client with the provided fakeClient.
func overrideHTTPClient(fc *fakeHTTPDoer) func() {
origMakeHTTPDoer := makeHTTPDoer
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
return func() { makeHTTPDoer = origMakeHTTPDoer }
}
// Overrides the subject token read to return a const which we can compare in
// our tests.
func overrideSubjectTokenGood() func() {
origReadSubjectTokenFrom := readSubjectTokenFrom
readSubjectTokenFrom = func(path string) ([]byte, error) {
return []byte(subjectTokenContents), nil
}
return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
}
// Overrides the subject token read to always return an error.
func overrideSubjectTokenError() func() {
origReadSubjectTokenFrom := readSubjectTokenFrom
readSubjectTokenFrom = func(path string) ([]byte, error) {
return nil, errors.New("error reading subject token")
}
return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
}
// Overrides the actor token read to return a const which we can compare in
// our tests.
func overrideActorTokenGood() func() {
origReadActorTokenFrom := readActorTokenFrom
readActorTokenFrom = func(path string) ([]byte, error) {
return []byte(actorTokenContents), nil
}
return func() { readActorTokenFrom = origReadActorTokenFrom }
}
// Overrides the actor token read to always return an error.
func overrideActorTokenError() func() {
origReadActorTokenFrom := readActorTokenFrom
readActorTokenFrom = func(path string) ([]byte, error) {
return nil, errors.New("error reading actor token")
}
return func() { readActorTokenFrom = origReadActorTokenFrom }
}
// compareRequest compares the http.Request received in the test with the
// expected requestParameters specified in wantReqParams.
func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
jsonBody, err := json.Marshal(wantReqParams)
if err != nil {
return err
}
wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
if err != nil {
return fmt.Errorf("failed to create http request: %v", err)
}
wantReq.Header.Set("Content-Type", "application/json")
wantR, err := httputil.DumpRequestOut(wantReq, true)
if err != nil {
return err
}
gotR, err := httputil.DumpRequestOut(gotRequest, true)
if err != nil {
return err
}
if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
}
return nil
}
// receiveAndCompareRequest waits for a request to be sent out by the
// credentials implementation using the fakeHTTPClient and compares it to an
// expected goodRequest. This is expected to be called in a separate goroutine
// by the tests. So, any errors encountered are pushed to an error channel
// which is monitored by the test.
func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
val, err := reqCh.Receive()
if err != nil {
errCh <- err
return
}
req := val.(*http.Request)
if err := compareRequest(req, goodRequestParams); err != nil {
errCh <- err
return
}
errCh <- nil
}
// TestGetRequestMetadataSuccess verifies the successful case of sending an
// token exchange request and processing the response.
func (s) TestGetRequestMetadataSuccess(t *testing.T) {
defer overrideSubjectTokenGood()()
fc, cancel := overrideHTTPClientGood()
defer cancel()
creds, err := NewCredentials(goodOptions)
if err != nil {
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
}
errCh := make(chan error, 1)
go receiveAndCompareRequest(fc.reqCh, errCh)
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
if err != nil {
t.Fatalf("creds.GetRequestMetadata() = %v", err)
}
if !cmp.Equal(gotMetadata, goodMetadata) {
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
}
if err := <-errCh; err != nil {
t.Fatal(err)
}
// Make another call to get request metadata and this should return contents
// from the cache. This will fail if the credentials tries to send a fresh
// request here since we have not configured our fakeClient to return any
// response on retries.
gotMetadata, err = creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
if err != nil {
t.Fatalf("creds.GetRequestMetadata() = %v", err)
}
if !cmp.Equal(gotMetadata, goodMetadata) {
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
}
}
// TestGetRequestMetadataBadSecurityLevel verifies the case where the
// securityLevel specified in the context passed to GetRequestMetadata is not
// sufficient.
func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
defer overrideSubjectTokenGood()()
creds, err := NewCredentials(goodOptions)
if err != nil {
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
}
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "")
if err == nil {
t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
}
}
// TestGetRequestMetadataCacheExpiry verifies the case where the cached access
// token has expired, and the credentials implementation will have to send a
// fresh token exchange request.
func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
const expiresInSecs = 1
defer overrideSubjectTokenGood()()
fc := &fakeHTTPDoer{
reqCh: testutils.NewChannel(),
respCh: testutils.NewChannel(),
}
defer overrideHTTPClient(fc)()
creds, err := NewCredentials(goodOptions)
if err != nil {
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
}
// The fakeClient is configured to return an access_token with a one second
// expiry. So, in the second iteration, the credentials will find the cache
// entry, but that would have expired, and therefore we expect it to send
// out a fresh request.
for i := 0; i < 2; i++ {
errCh := make(chan error, 1)
go receiveAndCompareRequest(fc.reqCh, errCh)
respJSON, _ := json.Marshal(responseParameters{
AccessToken: accessTokenContents,
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: expiresInSecs,
})
respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
resp := &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: respBody,
}
fc.respCh.Send(resp)
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
if err != nil {
t.Fatalf("creds.GetRequestMetadata() = %v", err)
}
if !cmp.Equal(gotMetadata, goodMetadata) {
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
}
if err := <-errCh; err != nil {
t.Fatal(err)
}
time.Sleep(expiresInSecs * time.Second)
}
}
// TestGetRequestMetadataBadResponses verifies the scenario where the token
// exchange server returns bad responses.
func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
tests := []struct {
name string
response *http.Response
}{
{
name: "bad JSON",
response: &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader("not JSON")),
},
},
{
name: "no access token",
response: &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader("{}")),
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer overrideSubjectTokenGood()()
fc := &fakeHTTPDoer{
reqCh: testutils.NewChannel(),
respCh: testutils.NewChannel(),
}
defer overrideHTTPClient(fc)()
creds, err := NewCredentials(goodOptions)
if err != nil {
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
}
errCh := make(chan error, 1)
go receiveAndCompareRequest(fc.reqCh, errCh)
fc.respCh.Send(test.response)
if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
}
if err := <-errCh; err != nil {
t.Fatal(err)
}
})
}
}
// TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
// attempt to read the subjectToken fails.
func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
defer overrideSubjectTokenError()()
fc, cancel := overrideHTTPClientGood()
defer cancel()
creds, err := NewCredentials(goodOptions)
if err != nil {
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
}
errCh := make(chan error, 1)
go func() {
if _, err := fc.reqCh.Receive(); err != testutils.ErrRecvTimeout {
errCh <- err
return
}
errCh <- nil
}()
if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
}
if err := <-errCh; err != nil {
t.Fatal(err)
}
}
func (s) TestNewCredentials(t *testing.T) {
tests := []struct {
name string
opts Options
errSystemRoots bool
wantErr bool
}{
{
name: "invalid options - empty subjectTokenPath",
opts: Options{
TokenExchangeServiceURI: serviceURI,
},
wantErr: true,
},
{
name: "invalid system root certs",
opts: goodOptions,
errSystemRoots: true,
wantErr: true,
},
{
name: "good case",
opts: goodOptions,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.errSystemRoots {
oldSystemRoots := loadSystemCertPool
loadSystemCertPool = func() (*x509.CertPool, error) {
return nil, errors.New("failed to load system cert pool")
}
defer func() {
loadSystemCertPool = oldSystemRoots
}()
}
creds, err := NewCredentials(test.opts)
if (err != nil) != test.wantErr {
t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
}
if err == nil {
if !creds.RequireTransportSecurity() {
t.Errorf("creds.RequireTransportSecurity() returned false")
}
}
})
}
}
func (s) TestValidateOptions(t *testing.T) {
tests := []struct {
name string
opts Options
wantErrPrefix string
}{
{
name: "empty token exchange service URI",
opts: Options{},
wantErrPrefix: "empty token_exchange_service_uri in options",
},
{
name: "invalid URI",
opts: Options{
TokenExchangeServiceURI: "\tI'm a bad URI\n",
},
wantErrPrefix: "invalid control character in URL",
},
{
name: "unsupported scheme",
opts: Options{
TokenExchangeServiceURI: "unix:///path/to/socket",
},
wantErrPrefix: "scheme is not supported",
},
{
name: "empty subjectTokenPath",
opts: Options{
TokenExchangeServiceURI: serviceURI,
},
wantErrPrefix: "required field SubjectTokenPath is not specified",
},
{
name: "empty subjectTokenType",
opts: Options{
TokenExchangeServiceURI: serviceURI,
SubjectTokenPath: subjectTokenPath,
},
wantErrPrefix: "required field SubjectTokenType is not specified",
},
{
name: "good options",
opts: goodOptions,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := validateOptions(test.opts)
if (err != nil) != (test.wantErrPrefix != "") {
t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
}
if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
}
})
}
}
func (s) TestConstructRequest(t *testing.T) {
tests := []struct {
name string
opts Options
subjectTokenReadErr bool
actorTokenReadErr bool
wantReqParams *requestParameters
wantErr bool
}{
{
name: "subject token read failure",
subjectTokenReadErr: true,
opts: goodOptions,
wantErr: true,
},
{
name: "actor token read failure",
actorTokenReadErr: true,
opts: Options{
TokenExchangeServiceURI: serviceURI,
Audience: exampleAudience,
RequestedTokenType: requestedTokenType,
SubjectTokenPath: subjectTokenPath,
SubjectTokenType: subjectTokenType,
ActorTokenPath: actorTokenPath,
ActorTokenType: actorTokenType,
},
wantErr: true,
},
{
name: "default cloud platform scope",
opts: goodOptions,
wantReqParams: goodRequestParams,
},
{
name: "all good",
opts: Options{
TokenExchangeServiceURI: serviceURI,
Resource: exampleResource,
Audience: exampleAudience,
Scope: testScope,
RequestedTokenType: requestedTokenType,
SubjectTokenPath: subjectTokenPath,
SubjectTokenType: subjectTokenType,
ActorTokenPath: actorTokenPath,
ActorTokenType: actorTokenType,
},
wantReqParams: &requestParameters{
GrantType: tokenExchangeGrantType,
Resource: exampleResource,
Audience: exampleAudience,
Scope: testScope,
RequestedTokenType: requestedTokenType,
SubjectToken: subjectTokenContents,
SubjectTokenType: subjectTokenType,
ActorToken: actorTokenContents,
ActorTokenType: actorTokenType,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.subjectTokenReadErr {
defer overrideSubjectTokenError()()
} else {
defer overrideSubjectTokenGood()()
}
if test.actorTokenReadErr {
defer overrideActorTokenError()()
} else {
defer overrideActorTokenGood()()
}
gotRequest, err := constructRequest(context.Background(), test.opts)
if (err != nil) != test.wantErr {
t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
}
if test.wantErr {
return
}
if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
t.Fatal(err)
}
})
}
}
func (s) TestSendRequest(t *testing.T) {
defer overrideSubjectTokenGood()()
req, err := constructRequest(context.Background(), goodOptions)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
resp *http.Response
respErr error
wantErr bool
}{
{
name: "client error",
respErr: errors.New("http.Client.Do failed"),
wantErr: true,
},
{
name: "bad response body",
resp: &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(errReader{}),
},
wantErr: true,
},
{
name: "nonOK status code",
resp: &http.Response{
Status: "400 BadRequest",
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader("")),
},
wantErr: true,
},
{
name: "good case",
resp: makeGoodResponse(),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client := &fakeHTTPDoer{
reqCh: testutils.NewChannel(),
respCh: testutils.NewChannel(),
err: test.respErr,
}
client.respCh.Send(test.resp)
_, err := sendRequest(client, req)
if (err != nil) != test.wantErr {
t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
}
})
}
}
func (s) TestTokenInfoFromResponse(t *testing.T) {
noAccessToken, _ := json.Marshal(responseParameters{
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
goodResponse, _ := json.Marshal(responseParameters{
IssuedTokenType: requestedTokenType,
AccessToken: accessTokenContents,
TokenType: "Bearer",
ExpiresIn: 3600,
})
tests := []struct {
name string
respBody []byte
wantTokenInfo *tokenInfo
wantErr bool
}{
{
name: "bad JSON",
respBody: []byte("not JSON"),
wantErr: true,
},
{
name: "empty response",
respBody: []byte(""),
wantErr: true,
},
{
name: "non-empty response with no access token",
respBody: noAccessToken,
wantErr: true,
},
{
name: "good response",
respBody: goodResponse,
wantTokenInfo: &tokenInfo{
tokenType: "Bearer",
token: accessTokenContents,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
if (err != nil) != test.wantErr {
t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
}
if test.wantErr {
return
}
// Can't do a cmp.Equal on the whole struct since the expiryField
// is populated based on time.Now().
if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
}
})
}
}