From eae3312c9a3a5a1acd743bcaac50a1351e036bf9 Mon Sep 17 00:00:00 2001 From: Nelson Parente Date: Fri, 25 Jul 2025 16:31:34 +0100 Subject: [PATCH] Add Path Filter Support to OAuth2 and Client Credentials Middlewares (#3906) Signed-off-by: nelson.parente --- middleware/http/oauth2/oauth2_middleware.go | 22 ++++ .../http/oauth2/oauth2_middleware_test.go | 49 ++++++++ .../oauth2clientcredentials_middleware.go | 53 +++++--- ...ntcredentials_middleware_benchmark_test.go | 113 ++++++++++++++++++ ...oauth2clientcredentials_middleware_test.go | 107 ++++++++++++++++- 5 files changed, 327 insertions(+), 17 deletions(-) create mode 100644 middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_benchmark_test.go diff --git a/middleware/http/oauth2/oauth2_middleware.go b/middleware/http/oauth2/oauth2_middleware.go index 4c7b4b662..2211caa58 100644 --- a/middleware/http/oauth2/oauth2_middleware.go +++ b/middleware/http/oauth2/oauth2_middleware.go @@ -18,6 +18,7 @@ import ( "net/http" "net/url" "reflect" + "regexp" "strings" "github.com/fasthttp-contrib/sessions" @@ -42,6 +43,9 @@ type oAuth2MiddlewareMetadata struct { AuthHeaderName string `json:"authHeaderName" mapstructure:"authHeaderName"` RedirectURL string `json:"redirectURL" mapstructure:"redirectURL"` ForceHTTPS string `json:"forceHTTPS" mapstructure:"forceHTTPS"` + PathFilter string `json:"pathFilter" mapstructure:"pathFilter"` + + pathFilterRegex *regexp.Regexp } // NewOAuth2Middleware returns a new oAuth2 middleware. @@ -84,6 +88,15 @@ func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadat return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if meta.pathFilterRegex != nil { + matched := meta.pathFilterRegex.MatchString(r.URL.Path) + if !matched { + m.logger.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path) + next.ServeHTTP(w, r) + return + } + } + session := sessions.Start(w, r) if session.GetString(meta.AuthHeaderName) != "" { @@ -153,6 +166,15 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Mid if err != nil { return nil, err } + + if middlewareMetadata.PathFilter != "" { + rx, err := regexp.Compile(middlewareMetadata.PathFilter) + if err != nil { + return nil, err + } + middlewareMetadata.pathFilterRegex = rx + } + return &middlewareMetadata, nil } diff --git a/middleware/http/oauth2/oauth2_middleware_test.go b/middleware/http/oauth2/oauth2_middleware_test.go index 4f12a5825..e8f67e418 100644 --- a/middleware/http/oauth2/oauth2_middleware_test.go +++ b/middleware/http/oauth2/oauth2_middleware_test.go @@ -61,3 +61,52 @@ func TestOAuth2CreatesAuthorizationHeaderWhenInSessionState(t *testing.T) { assert.Equal(t, "Bearer abcd", r.Header.Get("someHeader")) } + +func TestOAuth2CreatesAuthorizationHeaderGetNativeMetadata(t *testing.T) { + var metadata middleware.Metadata + metadata.Properties = map[string]string{ + "clientID": "testId", + "clientSecret": "testSecret", + "scopes": "ascope", + "authURL": "https://idp:9999", + "tokenURL": "https://idp:9999", + "redirectUrl": "https://localhost:9999", + "authHeaderName": "someHeader", + } + + log := logger.NewLogger("oauth2.test") + oauth2Middleware, ok := NewOAuth2Middleware(log).(*Middleware) + require.True(t, ok) + + tc := []struct { + name string + pathFilter string + wantErr bool + }{ + {name: "empty pathFilter", pathFilter: "", wantErr: false}, + {name: "wildcard pathFilter", pathFilter: ".*", wantErr: false}, + {name: "api path pathFilter", pathFilter: "/api/v1/users", wantErr: false}, + {name: "debug endpoint pathFilter", pathFilter: "^/debug/?$", wantErr: false}, + {name: "user id pathFilter", pathFilter: "^/user/[0-9]+$", wantErr: false}, + {name: "invalid wildcard pathFilter", pathFilter: "*invalid", wantErr: true}, + {name: "unclosed parenthesis pathFilter", pathFilter: "invalid(", wantErr: true}, + {name: "unopened parenthesis pathFilter", pathFilter: "invalid)", wantErr: true}, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + metadata.Properties["pathFilter"] = tt.pathFilter + nativeMetadata, err := oauth2Middleware.getNativeMetadata(metadata) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + if tt.pathFilter != "" { + require.NotNil(t, nativeMetadata.pathFilterRegex) + } else { + require.Nil(t, nativeMetadata.pathFilterRegex) + } + } + }) + } +} diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go index 484376efe..12f955d64 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go @@ -21,6 +21,7 @@ import ( "net/http" "net/url" "reflect" + "regexp" "strings" "time" @@ -43,6 +44,9 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct { HeaderName string `json:"headerName" mapstructure:"headerName"` EndpointParamsQuery string `json:"endpointParamsQuery,omitempty" mapstructure:"endpointParamsQuery"` AuthStyle int `json:"authStyle" mapstructure:"authStyle"` + PathFilter string `json:"pathFilter" mapstructure:"pathFilter"` + + pathFilterRegex *regexp.Regexp } // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests. @@ -69,7 +73,7 @@ type Middleware struct { tokenProvider TokenProviderInterface } -// GetHandler retruns the HTTP handler provided by the middleware. +// GetHandler returns the HTTP handler provided by the middleware. func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { @@ -98,27 +102,38 @@ func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var headerValue string - // Check if valid token is in the cache - cachedToken, found := m.tokenCache.Get(cacheKey) - if !found { - m.log.Debugf("Cached token not found, try get one") - - token, err := m.tokenProvider.GetToken(r.Context(), conf) - if err != nil { - m.log.Errorf("Error acquiring token: %s", err) + if meta.pathFilterRegex != nil { + matched := meta.pathFilterRegex.MatchString(r.URL.Path) + if !matched { + m.log.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path) + next.ServeHTTP(w, r) return } + } - tokenExpirationDuration := time.Until(token.Expiry) - m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration) - - headerValue = token.Type() + " " + token.AccessToken - m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration) - } else { + // Check if valid token is in the cache + cachedToken, found := m.tokenCache.Get(cacheKey) + if found { m.log.Debugf("Cached token found for key %s", cacheKey) headerValue = cachedToken.(string) + r.Header.Add(meta.HeaderName, headerValue) + next.ServeHTTP(w, r) + return } + m.log.Infof("Cached token not found, attempting to retrieve a new one") + token, err := m.tokenProvider.GetToken(r.Context(), conf) + if err != nil { + m.log.Errorf("Error acquiring token: %s", err) + return + } + + tokenExpirationDuration := time.Until(token.Expiry) + m.log.Infof("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration) + + headerValue = token.Type() + " " + token.AccessToken + m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration) + r.Header.Add(meta.HeaderName, headerValue) next.ServeHTTP(w, r) }) @@ -142,6 +157,14 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Cli m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes") m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL") + if middlewareMetadata.PathFilter != "" { + rx, err := regexp.Compile(middlewareMetadata.PathFilter) + if err != nil { + errorString += "Parameter 'pathFilter' is not a valid regex: " + err.Error() + ". " + } + middlewareMetadata.pathFilterRegex = rx + } + // Value-check AuthStyle if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 { errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle) diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_benchmark_test.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_benchmark_test.go new file mode 100644 index 000000000..032cbc11d --- /dev/null +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_benchmark_test.go @@ -0,0 +1,113 @@ +/* +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oauth2clientcredentials + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/dapr/components-contrib/middleware" + mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks" + "github.com/dapr/kit/logger" +) + +func BenchmarkTestOAuth2ClientCredentialsGetHandler(b *testing.B) { + mockCtrl := gomock.NewController(b) + defer mockCtrl.Finish() + mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl) + gomock.InOrder( + mockTokenProvider. + EXPECT(). + GetToken(gomock.Any()). + Return(&oauth2.Token{ + AccessToken: "abcd", + TokenType: "Bearer", + Expiry: time.Now().Add(1 * time.Minute), + }, nil). + Times(1), + ) + + var metadata middleware.Metadata + metadata.Properties = map[string]string{ + "clientID": "testId", + "clientSecret": "testSecret", + "scopes": "ascope", + "tokenURL": "https://localhost:9999", + "headerName": "authorization", + "authStyle": "1", + } + + log := logger.NewLogger("oauth2clientcredentials.test") + oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(b, ok) + oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) + handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata) + require.NoError(b, err) + + for i := range b.N { + url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i) + r := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + } +} + +func BenchmarkTestOAuth2ClientCredentialsGetHandlerWithPathFilter(b *testing.B) { + mockCtrl := gomock.NewController(b) + defer mockCtrl.Finish() + mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl) + gomock.InOrder( + mockTokenProvider. + EXPECT(). + GetToken(gomock.Any()). + Return(&oauth2.Token{ + AccessToken: "abcd", + TokenType: "Bearer", + Expiry: time.Now().Add(1 * time.Minute), + }, nil). + Times(1), + ) + + var metadata middleware.Metadata + metadata.Properties = map[string]string{ + "clientID": "testId", + "clientSecret": "testSecret", + "scopes": "ascope", + "tokenURL": "https://localhost:9999", + "headerName": "authorization", + "authStyle": "1", + "pathFilter": "/api/v1/users/.*", + } + + log := logger.NewLogger("oauth2clientcredentials.test") + oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(b, ok) + oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) + handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata) + require.NoError(b, err) + + for i := range b.N { + url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i) + r := httptest.NewRequest(http.MethodGet, url, nil) + w := httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + } +} diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go index 48d27754e..7fdbc86c6 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" oauth2 "golang.org/x/oauth2" + "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/middleware" mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks" "github.com/dapr/kit/logger" @@ -107,7 +108,8 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) { // Initialize middleware component and inject mocked TokenProvider log := logger.NewLogger("oauth2clientcredentials.test") - oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(t, ok) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata) require.NoError(t, err) @@ -167,7 +169,8 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) { // Initialize middleware component and inject mocked TokenProvider log := logger.NewLogger("oauth2clientcredentials.test") - oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(t, ok) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata) require.NoError(t, err) @@ -199,3 +202,103 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) { // Assertion assert.Equal(t, "MAC def", r.Header.Get("someHeader")) } + +func TestOAuth2ClientCredentialsPathFilterGetNativeMetadata(t *testing.T) { + log := logger.NewLogger("oauth2clientcredentials.test") + m, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(t, ok) + + baseMiddlewareMetadata := middleware.Metadata{ + Base: metadata.Base{ + Properties: map[string]string{ + "clientID": "testId", + "clientSecret": "testSecret", + "scopes": "ascope", + "tokenURL": "https://localhost:9999", + "headerName": "someHeader", + "authStyle": "1", + }, + }, + } + + tc := []struct { + name string + pathFilter string + wantErr bool + }{ + {name: "empty pathFilter", pathFilter: "", wantErr: false}, + {name: "wildcard pathFilter", pathFilter: ".*", wantErr: false}, + {name: "api path pathFilter", pathFilter: "/api/v1/users", wantErr: false}, + {name: "debug endpoint pathFilter", pathFilter: "^/debug/?$", wantErr: false}, + {name: "user id pathFilter", pathFilter: "^/user/[0-9]+$", wantErr: false}, + {name: "invalid wildcard pathFilter", pathFilter: "*invalid", wantErr: true}, + {name: "unclosed parenthesis pathFilter", pathFilter: "invalid(", wantErr: true}, + {name: "unopened parenthesis pathFilter", pathFilter: "invalid)", wantErr: true}, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + baseMiddlewareMetadata.Properties["pathFilter"] = tt.pathFilter + _, err := m.getNativeMetadata(baseMiddlewareMetadata) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestOAuth2ClientCredentialsPathFilterGetHandler(t *testing.T) { + // Setup + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + // Mock mockTokenProvider + mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl) + + gomock.InOrder( + // First call returning abc and Bearer, expires within 1 second + mockTokenProvider. + EXPECT(). + GetToken(gomock.Any()). + Return(&oauth2.Token{ + AccessToken: "abcd", + TokenType: "Bearer", + Expiry: time.Now().In(time.UTC).Add(1 * time.Second), + }, nil). + Times(1), + ) + + var metadata middleware.Metadata + metadata.Properties = map[string]string{ + "clientID": "testId", + "clientSecret": "testSecret", + "scopes": "ascope", + "tokenURL": "https://localhost:9999", + "headerName": "authorization", + "authStyle": "1", + "pathFilter": "/api/v1/users/.*", + } + + log := logger.NewLogger("oauth2clientcredentials.test") + oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) + require.True(t, ok) + oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) + handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata) + require.NoError(t, err) + + // pathFilter should match + r := httptest.NewRequest(http.MethodGet, "http://dapr.io/api/v1/users/123", nil) + w := httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + + assert.Equal(t, "Bearer abcd", r.Header.Get("authorization")) + + // pathFilter should not match + r = httptest.NewRequest(http.MethodGet, "http://dapr.io/api/v1/tokens/123", nil) + w = httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + + assert.Equal(t, "", r.Header.Get("authorization")) +}