Merge branch 'main' into fix-kafka-consumer-shutdown
This commit is contained in:
commit
dc8e071e36
|
@ -18,6 +18,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp-contrib/sessions"
|
||||
|
@ -42,6 +43,9 @@ type oAuth2MiddlewareMetadata struct {
|
|||
AuthHeaderName string `json:"authHeaderName" mapstructure:"authHeaderName"`
|
||||
RedirectURL string `json:"redirectURL" mapstructure:"redirectURL"`
|
||||
ForceHTTPS string `json:"forceHTTPS" mapstructure:"forceHTTPS"`
|
||||
PathFilter string `json:"pathFilter" mapstructure:"pathFilter"`
|
||||
|
||||
pathFilterRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// NewOAuth2Middleware returns a new oAuth2 middleware.
|
||||
|
@ -84,6 +88,15 @@ func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadat
|
|||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if meta.pathFilterRegex != nil {
|
||||
matched := meta.pathFilterRegex.MatchString(r.URL.Path)
|
||||
if !matched {
|
||||
m.logger.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session := sessions.Start(w, r)
|
||||
|
||||
if session.GetString(meta.AuthHeaderName) != "" {
|
||||
|
@ -153,6 +166,15 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Mid
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if middlewareMetadata.PathFilter != "" {
|
||||
rx, err := regexp.Compile(middlewareMetadata.PathFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
middlewareMetadata.pathFilterRegex = rx
|
||||
}
|
||||
|
||||
return &middlewareMetadata, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -61,3 +61,52 @@ func TestOAuth2CreatesAuthorizationHeaderWhenInSessionState(t *testing.T) {
|
|||
|
||||
assert.Equal(t, "Bearer abcd", r.Header.Get("someHeader"))
|
||||
}
|
||||
|
||||
func TestOAuth2CreatesAuthorizationHeaderGetNativeMetadata(t *testing.T) {
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"authURL": "https://idp:9999",
|
||||
"tokenURL": "https://idp:9999",
|
||||
"redirectUrl": "https://localhost:9999",
|
||||
"authHeaderName": "someHeader",
|
||||
}
|
||||
|
||||
log := logger.NewLogger("oauth2.test")
|
||||
oauth2Middleware, ok := NewOAuth2Middleware(log).(*Middleware)
|
||||
require.True(t, ok)
|
||||
|
||||
tc := []struct {
|
||||
name string
|
||||
pathFilter string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "empty pathFilter", pathFilter: "", wantErr: false},
|
||||
{name: "wildcard pathFilter", pathFilter: ".*", wantErr: false},
|
||||
{name: "api path pathFilter", pathFilter: "/api/v1/users", wantErr: false},
|
||||
{name: "debug endpoint pathFilter", pathFilter: "^/debug/?$", wantErr: false},
|
||||
{name: "user id pathFilter", pathFilter: "^/user/[0-9]+$", wantErr: false},
|
||||
{name: "invalid wildcard pathFilter", pathFilter: "*invalid", wantErr: true},
|
||||
{name: "unclosed parenthesis pathFilter", pathFilter: "invalid(", wantErr: true},
|
||||
{name: "unopened parenthesis pathFilter", pathFilter: "invalid)", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tc {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
metadata.Properties["pathFilter"] = tt.pathFilter
|
||||
nativeMetadata, err := oauth2Middleware.getNativeMetadata(metadata)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tt.pathFilter != "" {
|
||||
require.NotNil(t, nativeMetadata.pathFilterRegex)
|
||||
} else {
|
||||
require.Nil(t, nativeMetadata.pathFilterRegex)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -43,6 +44,9 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
|
|||
HeaderName string `json:"headerName" mapstructure:"headerName"`
|
||||
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty" mapstructure:"endpointParamsQuery"`
|
||||
AuthStyle int `json:"authStyle" mapstructure:"authStyle"`
|
||||
PathFilter string `json:"pathFilter" mapstructure:"pathFilter"`
|
||||
|
||||
pathFilterRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
|
||||
|
@ -69,7 +73,7 @@ type Middleware struct {
|
|||
tokenProvider TokenProviderInterface
|
||||
}
|
||||
|
||||
// GetHandler retruns the HTTP handler provided by the middleware.
|
||||
// GetHandler returns the HTTP handler provided by the middleware.
|
||||
func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
if err != nil {
|
||||
|
@ -98,27 +102,38 @@ func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata)
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var headerValue string
|
||||
|
||||
// Check if valid token is in the cache
|
||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||
if !found {
|
||||
m.log.Debugf("Cached token not found, try get one")
|
||||
|
||||
token, err := m.tokenProvider.GetToken(r.Context(), conf)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error acquiring token: %s", err)
|
||||
if meta.pathFilterRegex != nil {
|
||||
matched := meta.pathFilterRegex.MatchString(r.URL.Path)
|
||||
if !matched {
|
||||
m.log.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tokenExpirationDuration := time.Until(token.Expiry)
|
||||
m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
|
||||
|
||||
headerValue = token.Type() + " " + token.AccessToken
|
||||
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
||||
} else {
|
||||
// Check if valid token is in the cache
|
||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||
if found {
|
||||
m.log.Debugf("Cached token found for key %s", cacheKey)
|
||||
headerValue = cachedToken.(string)
|
||||
r.Header.Add(meta.HeaderName, headerValue)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
m.log.Infof("Cached token not found, attempting to retrieve a new one")
|
||||
token, err := m.tokenProvider.GetToken(r.Context(), conf)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error acquiring token: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tokenExpirationDuration := time.Until(token.Expiry)
|
||||
m.log.Infof("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
|
||||
|
||||
headerValue = token.Type() + " " + token.AccessToken
|
||||
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
||||
|
||||
r.Header.Add(meta.HeaderName, headerValue)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
@ -142,6 +157,14 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Cli
|
|||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
|
||||
|
||||
if middlewareMetadata.PathFilter != "" {
|
||||
rx, err := regexp.Compile(middlewareMetadata.PathFilter)
|
||||
if err != nil {
|
||||
errorString += "Parameter 'pathFilter' is not a valid regex: " + err.Error() + ". "
|
||||
}
|
||||
middlewareMetadata.pathFilterRegex = rx
|
||||
}
|
||||
|
||||
// Value-check AuthStyle
|
||||
if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package oauth2clientcredentials
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
func BenchmarkTestOAuth2ClientCredentialsGetHandler(b *testing.B) {
|
||||
mockCtrl := gomock.NewController(b)
|
||||
defer mockCtrl.Finish()
|
||||
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
|
||||
gomock.InOrder(
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "abcd",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}, nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "authorization",
|
||||
"authStyle": "1",
|
||||
}
|
||||
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(b, ok)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata)
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := range b.N {
|
||||
url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTestOAuth2ClientCredentialsGetHandlerWithPathFilter(b *testing.B) {
|
||||
mockCtrl := gomock.NewController(b)
|
||||
defer mockCtrl.Finish()
|
||||
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
|
||||
gomock.InOrder(
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "abcd",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}, nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "authorization",
|
||||
"authStyle": "1",
|
||||
"pathFilter": "/api/v1/users/.*",
|
||||
}
|
||||
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(b, ok)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata)
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := range b.N {
|
||||
url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
@ -107,7 +108,8 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
|
|||
|
||||
// Initialize middleware component and inject mocked TokenProvider
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(t, ok)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata)
|
||||
require.NoError(t, err)
|
||||
|
@ -167,7 +169,8 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
|
|||
|
||||
// Initialize middleware component and inject mocked TokenProvider
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(t, ok)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata)
|
||||
require.NoError(t, err)
|
||||
|
@ -199,3 +202,103 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
|
|||
// Assertion
|
||||
assert.Equal(t, "MAC def", r.Header.Get("someHeader"))
|
||||
}
|
||||
|
||||
func TestOAuth2ClientCredentialsPathFilterGetNativeMetadata(t *testing.T) {
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
m, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(t, ok)
|
||||
|
||||
baseMiddlewareMetadata := middleware.Metadata{
|
||||
Base: metadata.Base{
|
||||
Properties: map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "someHeader",
|
||||
"authStyle": "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tc := []struct {
|
||||
name string
|
||||
pathFilter string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "empty pathFilter", pathFilter: "", wantErr: false},
|
||||
{name: "wildcard pathFilter", pathFilter: ".*", wantErr: false},
|
||||
{name: "api path pathFilter", pathFilter: "/api/v1/users", wantErr: false},
|
||||
{name: "debug endpoint pathFilter", pathFilter: "^/debug/?$", wantErr: false},
|
||||
{name: "user id pathFilter", pathFilter: "^/user/[0-9]+$", wantErr: false},
|
||||
{name: "invalid wildcard pathFilter", pathFilter: "*invalid", wantErr: true},
|
||||
{name: "unclosed parenthesis pathFilter", pathFilter: "invalid(", wantErr: true},
|
||||
{name: "unopened parenthesis pathFilter", pathFilter: "invalid)", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tc {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
baseMiddlewareMetadata.Properties["pathFilter"] = tt.pathFilter
|
||||
_, err := m.getNativeMetadata(baseMiddlewareMetadata)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2ClientCredentialsPathFilterGetHandler(t *testing.T) {
|
||||
// Setup
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
// Mock mockTokenProvider
|
||||
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
|
||||
|
||||
gomock.InOrder(
|
||||
// First call returning abc and Bearer, expires within 1 second
|
||||
mockTokenProvider.
|
||||
EXPECT().
|
||||
GetToken(gomock.Any()).
|
||||
Return(&oauth2.Token{
|
||||
AccessToken: "abcd",
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Now().In(time.UTC).Add(1 * time.Second),
|
||||
}, nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
var metadata middleware.Metadata
|
||||
metadata.Properties = map[string]string{
|
||||
"clientID": "testId",
|
||||
"clientSecret": "testSecret",
|
||||
"scopes": "ascope",
|
||||
"tokenURL": "https://localhost:9999",
|
||||
"headerName": "authorization",
|
||||
"authStyle": "1",
|
||||
"pathFilter": "/api/v1/users/.*",
|
||||
}
|
||||
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
|
||||
require.True(t, ok)
|
||||
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
|
||||
handler, err := oauth2clientcredentialsMiddleware.GetHandler(t.Context(), metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// pathFilter should match
|
||||
r := httptest.NewRequest(http.MethodGet, "http://dapr.io/api/v1/users/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, "Bearer abcd", r.Header.Get("authorization"))
|
||||
|
||||
// pathFilter should not match
|
||||
r = httptest.NewRequest(http.MethodGet, "http://dapr.io/api/v1/tokens/123", nil)
|
||||
w = httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, "", r.Header.Get("authorization"))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue