Adding middleware component for non interactive oauth2 client credentials flow (#431)
* Initial commit with mock not working Signed-off-by: Florian Wagner <flwagner@microsoft.com> * changed structure for mocking, added first test Signed-off-by: Florian Wagner <flwagner@microsoft.com> * oauth2clientcredentials input checks and tests Signed-off-by: Florian Wagner <flwagner@microsoft.com> * rename metadata authHeaderName to HeaderName Signed-off-by: Florian Wagner <flwagner@microsoft.com> * Run 'go mod tidy' Signed-off-by: Florian Wagner <flwagner@microsoft.com> * use dapr logger from caller Signed-off-by: Florian Wagner <flwagner@microsoft.com> * Fix additional lint requirements by build pipeline Signed-off-by: Florian Wagner <flwagner@microsoft.com> Co-authored-by: Florian Wagner <flwagner@microsoft.com> Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
This commit is contained in:
parent
f5f807ca73
commit
e60fa843f2
1
go.mod
1
go.mod
|
|
@ -54,6 +54,7 @@ require (
|
|||
github.com/nats-io/nats.go v1.9.1
|
||||
github.com/nats-io/stan.go v0.6.0
|
||||
github.com/openzipkin/zipkin-go v0.1.6
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials (interfaces: TokenProviderInterface)
|
||||
|
||||
// Package mock_oauth2clientcredentials is a generated GoMock package.
|
||||
package mock_oauth2clientcredentials
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
clientcredentials "golang.org/x/oauth2/clientcredentials"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockTokenProviderInterface is a mock of TokenProviderInterface interface
|
||||
type MockTokenProviderInterface struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTokenProviderInterfaceMockRecorder
|
||||
}
|
||||
|
||||
// MockTokenProviderInterfaceMockRecorder is the mock recorder for MockTokenProviderInterface
|
||||
type MockTokenProviderInterfaceMockRecorder struct {
|
||||
mock *MockTokenProviderInterface
|
||||
}
|
||||
|
||||
// NewMockTokenProviderInterface creates a new mock instance
|
||||
func NewMockTokenProviderInterface(ctrl *gomock.Controller) *MockTokenProviderInterface {
|
||||
mock := &MockTokenProviderInterface{ctrl: ctrl}
|
||||
mock.recorder = &MockTokenProviderInterfaceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetToken mocks base method
|
||||
func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetToken", arg0)
|
||||
ret0, _ := ret[0].(*oauth2.Token)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetToken indicates an expected call of GetToken
|
||||
func (mr *MockTokenProviderInterfaceMockRecorder) GetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockTokenProviderInterface)(nil).GetToken), arg0)
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package oauth2clientcredentials
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
// Metadata is the oAuth clientcredentials middleware config
|
||||
type oAuth2ClientCredentialsMiddlewareMetadata struct {
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes string `json:"scopes"`
|
||||
TokenURL string `json:"tokenURL"`
|
||||
HeaderName string `json:"headerName"`
|
||||
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
|
||||
AuthStyleString string `json:"authStyle"`
|
||||
AuthStyle int `json:"-"`
|
||||
}
|
||||
|
||||
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests
|
||||
type TokenProviderInterface interface {
|
||||
GetToken(conf *clientcredentials.Config) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
// NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware
|
||||
func NewOAuth2ClientCredentialsMiddleware(logger logger.Logger) *Middleware {
|
||||
m := &Middleware{
|
||||
log: logger,
|
||||
tokenCache: cache.New(1*time.Hour, 10*time.Minute),
|
||||
}
|
||||
// Default: set Token Provider to real implementation (we will overwrite it for unit testing)
|
||||
m.SetTokenProvider(m)
|
||||
return m
|
||||
}
|
||||
|
||||
// Middleware is an oAuth2 authentication middleware
|
||||
type Middleware struct {
|
||||
log logger.Logger
|
||||
tokenCache *cache.Cache
|
||||
tokenProvider TokenProviderInterface
|
||||
}
|
||||
|
||||
// GetHandler retruns the HTTP handler provided by the middleware
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
if err != nil {
|
||||
m.log.Errorf("getNativeMetadata error, %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
var headerValue string
|
||||
// Check if valid Token is in the cache
|
||||
var cacheKey = m.getCacheKey(meta)
|
||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||
|
||||
if !found {
|
||||
m.log.Debugf("Cached token not found, try get one")
|
||||
|
||||
var endpointParams, err = url.ParseQuery(meta.EndpointParamsQuery)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error parsing endpoint parameters, %s", err)
|
||||
endpointParams, _ = url.ParseQuery("")
|
||||
}
|
||||
|
||||
conf := &clientcredentials.Config{
|
||||
ClientID: meta.ClientID,
|
||||
ClientSecret: meta.ClientSecret,
|
||||
Scopes: strings.Split(meta.Scopes, ","),
|
||||
TokenURL: meta.TokenURL,
|
||||
EndpointParams: endpointParams,
|
||||
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
|
||||
}
|
||||
|
||||
token, tokenError := m.tokenProvider.GetToken(conf)
|
||||
if tokenError != nil {
|
||||
m.log.Errorf("Error acquiring token, %s", tokenError)
|
||||
return
|
||||
}
|
||||
|
||||
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC))
|
||||
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
|
||||
return
|
||||
}
|
||||
|
||||
headerValue = token.Type() + " " + token.AccessToken
|
||||
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
||||
} else {
|
||||
m.log.Debugf("Cached token found for key %s", cacheKey)
|
||||
headerValue = cachedToken.(string)
|
||||
}
|
||||
|
||||
ctx.Request.Header.Add(meta.HeaderName, headerValue)
|
||||
h(ctx)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
|
||||
b, err := json.Marshal(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
|
||||
err = json.Unmarshal(b, &middlewareMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Do input validation checks
|
||||
errorString := ""
|
||||
// Check if values are present
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
|
||||
|
||||
// Converting AuthStyle to int and do a value check
|
||||
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString)
|
||||
if err != nil {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
|
||||
} else if authStyle < 0 || authStyle > 2 {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
|
||||
} else {
|
||||
middlewareMetadata.AuthStyle = authStyle
|
||||
}
|
||||
|
||||
// Return errors if any found
|
||||
if errorString != "" {
|
||||
return nil, fmt.Errorf("%s", errorString)
|
||||
}
|
||||
|
||||
return &middlewareMetadata, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue *string, metadataName string) {
|
||||
if *metadataValue == "" {
|
||||
*errorString += fmt.Sprintf("Parameter '%s' needs to be set. ", metadataName)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
|
||||
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
|
||||
hashedKey := sha256.New()
|
||||
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "")
|
||||
hashedKey.Write([]byte(key))
|
||||
return fmt.Sprintf("%x", hashedKey.Sum(nil))
|
||||
}
|
||||
|
||||
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking)
|
||||
func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
|
||||
m.tokenProvider = tokenProvider
|
||||
}
|
||||
|
||||
// GetToken returns a token from the current OAuth2 ClientCredentials Configuration
|
||||
func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) {
|
||||
tokenSource := conf.TokenSource(context.Background())
|
||||
return tokenSource.Token()
|
||||
}
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package oauth2clientcredentials
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
fh "github.com/valyala/fasthttp"
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func mockedRequestHandler(ctx *fh.RequestCtx) {}
|
||||
|
||||
// TestOAuth2ClientCredentialsMetadata will check
|
||||
// - if the metadata checks are correct in place
|
||||
func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
|
||||
// Specify components metadata
|
||||
var metadata middleware.Metadata
|
||||
|
||||
// Missing all
|
||||
metadata.Properties = map[string]string{}
|
||||
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ")
|
||||
|
||||
// Invalid authStyle (non int)
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "someHeader",
|
||||
"authStyle": "asdf", // This is the value to test
|
||||
}
|
||||
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ")
|
||||
|
||||
// Invalid authStyle (int > 2)
|
||||
metadata.Properties["authStyle"] = "3"
|
||||
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
|
||||
|
||||
// Invalid authStyle (int < 0)
|
||||
metadata.Properties["authStyle"] = "-1"
|
||||
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
|
||||
}
|
||||
|
||||
// TestOAuth2ClientCredentialsToken will check
|
||||
// - if the Token was added to the RequestHeader value specified
|
||||
func TestOAuth2ClientCredentialsToken(t *testing.T) {
|
||||
// Setup
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
// Mock mockTokenProvider
|
||||
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
|
||||
|
||||
gomock.InOrder(
|
||||
// First call returning abc and Bearer, expires within 1 second
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "abcd",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
|
||||
}, nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
// Specify components metadata
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "someHeader",
|
||||
"authStyle": "1",
|
||||
}
|
||||
|
||||
// Initialize middleware component and inject mocked TokenProvider
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware := NewOAuth2ClientCredentialsMiddleware(log)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First handler call should return abc Token
|
||||
var requestContext1 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext1)
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader")))
|
||||
}
|
||||
|
||||
// TestOAuth2ClientCredentialsCache will check
|
||||
// - if the Cache is working
|
||||
func TestOAuth2ClientCredentialsCache(t *testing.T) {
|
||||
// Setup
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
// Mock mockTokenProvider
|
||||
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
|
||||
|
||||
gomock.InOrder(
|
||||
// First call returning abc and Bearer, expires within 1 second
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "abc",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
|
||||
}, nil).
|
||||
Times(1),
|
||||
// Second call returning def and MAC, expires within 1 second
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "def",
|
||||
TokenType: "MAC",
|
||||
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
|
||||
}, nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
// Specify components metadata
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "someHeader",
|
||||
"authStyle": "1",
|
||||
}
|
||||
|
||||
// Initialize middleware component and inject mocked TokenProvider
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware := NewOAuth2ClientCredentialsMiddleware(log)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First handler call should return abc Token
|
||||
var requestContext1 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext1)
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader")))
|
||||
|
||||
// Second handler call should still return 'cached' abc Token
|
||||
var requestContext2 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext2)
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader")))
|
||||
|
||||
// Wait at a second to invalidate cache entry for abc
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Third call should return def Token
|
||||
var requestContext3 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext3)
|
||||
// Assertion
|
||||
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader")))
|
||||
}
|
||||
Loading…
Reference in New Issue