Adding middleware component for non interactive oauth2 client credentials flow (#431)

* Initial commit with mock not working

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* changed structure for mocking, added first test

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* oauth2clientcredentials input checks and tests

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* rename metadata authHeaderName to HeaderName

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* Run 'go mod tidy'

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* use dapr logger from caller

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

* Fix additional lint requirements by build pipeline

Signed-off-by: Florian Wagner <flwagner@microsoft.com>

Co-authored-by: Florian Wagner <flwagner@microsoft.com>
Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
This commit is contained in:
Florian Wagner 2020-08-17 03:05:24 +09:00 committed by GitHub
parent f5f807ca73
commit e60fa843f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 413 additions and 0 deletions

1
go.mod
View File

@ -54,6 +54,7 @@ require (
github.com/nats-io/nats.go v1.9.1
github.com/nats-io/stan.go v0.6.0
github.com/openzipkin/zipkin-go v0.1.6
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/robfig/cron/v3 v3.0.1
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da

View File

@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials (interfaces: TokenProviderInterface)
// Package mock_oauth2clientcredentials is a generated GoMock package.
package mock_oauth2clientcredentials
import (
gomock "github.com/golang/mock/gomock"
oauth2 "golang.org/x/oauth2"
clientcredentials "golang.org/x/oauth2/clientcredentials"
reflect "reflect"
)
// MockTokenProviderInterface is a mock of TokenProviderInterface interface
type MockTokenProviderInterface struct {
ctrl *gomock.Controller
recorder *MockTokenProviderInterfaceMockRecorder
}
// MockTokenProviderInterfaceMockRecorder is the mock recorder for MockTokenProviderInterface
type MockTokenProviderInterfaceMockRecorder struct {
mock *MockTokenProviderInterface
}
// NewMockTokenProviderInterface creates a new mock instance
func NewMockTokenProviderInterface(ctrl *gomock.Controller) *MockTokenProviderInterface {
mock := &MockTokenProviderInterface{ctrl: ctrl}
mock.recorder = &MockTokenProviderInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRecorder {
return m.recorder
}
// GetToken mocks base method
func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetToken", arg0)
ret0, _ := ret[0].(*oauth2.Token)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetToken indicates an expected call of GetToken
func (mr *MockTokenProviderInterfaceMockRecorder) GetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockTokenProviderInterface)(nil).GetToken), arg0)
}

View File

@ -0,0 +1,182 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package oauth2clientcredentials
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/dapr/pkg/logger"
"github.com/patrickmn/go-cache"
"github.com/valyala/fasthttp"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
// Metadata is the oAuth clientcredentials middleware config
type oAuth2ClientCredentialsMiddlewareMetadata struct {
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes string `json:"scopes"`
TokenURL string `json:"tokenURL"`
HeaderName string `json:"headerName"`
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
AuthStyleString string `json:"authStyle"`
AuthStyle int `json:"-"`
}
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests
type TokenProviderInterface interface {
GetToken(conf *clientcredentials.Config) (*oauth2.Token, error)
}
// NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware
func NewOAuth2ClientCredentialsMiddleware(logger logger.Logger) *Middleware {
m := &Middleware{
log: logger,
tokenCache: cache.New(1*time.Hour, 10*time.Minute),
}
// Default: set Token Provider to real implementation (we will overwrite it for unit testing)
m.SetTokenProvider(m)
return m
}
// Middleware is an oAuth2 authentication middleware
type Middleware struct {
log logger.Logger
tokenCache *cache.Cache
tokenProvider TokenProviderInterface
}
// GetHandler retruns the HTTP handler provided by the middleware
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
m.log.Errorf("getNativeMetadata error, %s", err)
return nil, err
}
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
var headerValue string
// Check if valid Token is in the cache
var cacheKey = m.getCacheKey(meta)
cachedToken, found := m.tokenCache.Get(cacheKey)
if !found {
m.log.Debugf("Cached token not found, try get one")
var endpointParams, err = url.ParseQuery(meta.EndpointParamsQuery)
if err != nil {
m.log.Errorf("Error parsing endpoint parameters, %s", err)
endpointParams, _ = url.ParseQuery("")
}
conf := &clientcredentials.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
TokenURL: meta.TokenURL,
EndpointParams: endpointParams,
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
}
token, tokenError := m.tokenProvider.GetToken(conf)
if tokenError != nil {
m.log.Errorf("Error acquiring token, %s", tokenError)
return
}
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC))
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry)
if err != nil {
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
return
}
headerValue = token.Type() + " " + token.AccessToken
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
} else {
m.log.Debugf("Cached token found for key %s", cacheKey)
headerValue = cachedToken.(string)
}
ctx.Request.Header.Add(meta.HeaderName, headerValue)
h(ctx)
}
}, nil
}
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
b, err := json.Marshal(metadata.Properties)
if err != nil {
return nil, err
}
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
err = json.Unmarshal(b, &middlewareMetadata)
if err != nil {
return nil, err
}
// Do input validation checks
errorString := ""
// Check if values are present
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
// Converting AuthStyle to int and do a value check
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString)
if err != nil {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
} else if authStyle < 0 || authStyle > 2 {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
} else {
middlewareMetadata.AuthStyle = authStyle
}
// Return errors if any found
if errorString != "" {
return nil, fmt.Errorf("%s", errorString)
}
return &middlewareMetadata, nil
}
func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue *string, metadataName string) {
if *metadataValue == "" {
*errorString += fmt.Sprintf("Parameter '%s' needs to be set. ", metadataName)
}
}
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
hashedKey := sha256.New()
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "")
hashedKey.Write([]byte(key))
return fmt.Sprintf("%x", hashedKey.Sum(nil))
}
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking)
func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
m.tokenProvider = tokenProvider
}
// GetToken returns a token from the current OAuth2 ClientCredentials Configuration
func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) {
tokenSource := conf.TokenSource(context.Background())
return tokenSource.Token()
}

View File

@ -0,0 +1,180 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package oauth2clientcredentials
import (
"testing"
"time"
mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks"
"github.com/dapr/dapr/pkg/logger"
"github.com/dapr/components-contrib/middleware"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fh "github.com/valyala/fasthttp"
oauth2 "golang.org/x/oauth2"
)
func mockedRequestHandler(ctx *fh.RequestCtx) {}
// TestOAuth2ClientCredentialsMetadata will check
// - if the metadata checks are correct in place
func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
// Specify components metadata
var metadata middleware.Metadata
// Missing all
metadata.Properties = map[string]string{}
log := logger.NewLogger("oauth2clientcredentials.test")
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ")
// Invalid authStyle (non int)
metadata.Properties = map[string]string{
"clientID": "testId",
"clientSecret": "testSecret",
"scopes": "ascope",
"tokenURL": "https://localhost:9999",
"headerName": "someHeader",
"authStyle": "asdf", // This is the value to test
}
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ")
// Invalid authStyle (int > 2)
metadata.Properties["authStyle"] = "3"
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
// Invalid authStyle (int < 0)
metadata.Properties["authStyle"] = "-1"
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
}
// TestOAuth2ClientCredentialsToken will check
// - if the Token was added to the RequestHeader value specified
func TestOAuth2ClientCredentialsToken(t *testing.T) {
// Setup
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
// Mock mockTokenProvider
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
gomock.InOrder(
// First call returning abc and Bearer, expires within 1 second
mockTokenProvider.
EXPECT().
GetToken(gomock.Any()).
Return(&oauth2.Token{
AccessToken: "abcd",
TokenType: "Bearer",
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
}, nil).
Times(1),
)
// Specify components metadata
var metadata middleware.Metadata
metadata.Properties = map[string]string{
"clientID": "testId",
"clientSecret": "testSecret",
"scopes": "ascope",
"tokenURL": "https://localhost:9999",
"headerName": "someHeader",
"authStyle": "1",
}
// Initialize middleware component and inject mocked TokenProvider
log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware := NewOAuth2ClientCredentialsMiddleware(log)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
require.NoError(t, err)
// First handler call should return abc Token
var requestContext1 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext1)
// Assertion
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader")))
}
// TestOAuth2ClientCredentialsCache will check
// - if the Cache is working
func TestOAuth2ClientCredentialsCache(t *testing.T) {
// Setup
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
// Mock mockTokenProvider
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
gomock.InOrder(
// First call returning abc and Bearer, expires within 1 second
mockTokenProvider.
EXPECT().
GetToken(gomock.Any()).
Return(&oauth2.Token{
AccessToken: "abc",
TokenType: "Bearer",
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
}, nil).
Times(1),
// Second call returning def and MAC, expires within 1 second
mockTokenProvider.
EXPECT().
GetToken(gomock.Any()).
Return(&oauth2.Token{
AccessToken: "def",
TokenType: "MAC",
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
}, nil).
Times(1),
)
// Specify components metadata
var metadata middleware.Metadata
metadata.Properties = map[string]string{
"clientID": "testId",
"clientSecret": "testSecret",
"scopes": "ascope",
"tokenURL": "https://localhost:9999",
"headerName": "someHeader",
"authStyle": "1",
}
// Initialize middleware component and inject mocked TokenProvider
log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware := NewOAuth2ClientCredentialsMiddleware(log)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
require.NoError(t, err)
// First handler call should return abc Token
var requestContext1 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext1)
// Assertion
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader")))
// Second handler call should still return 'cached' abc Token
var requestContext2 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext2)
// Assertion
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader")))
// Wait at a second to invalidate cache entry for abc
time.Sleep(1 * time.Second)
// Third call should return def Token
var requestContext3 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext3)
// Assertion
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader")))
}