Updated more middlewares
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
8557183752
commit
e3d2ada01c
|
|
@ -82,7 +82,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
|
||||||
session := sessions.Start(w, r)
|
session := sessions.Start(w, r)
|
||||||
|
|
||||||
if session.GetString(meta.AuthHeaderName) != "" {
|
if session.GetString(meta.AuthHeaderName) != "" {
|
||||||
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
|
w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -135,7 +135,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
|
||||||
|
|
||||||
authHeader := token.Type() + " " + token.AccessToken
|
authHeader := token.Type() + " " + token.AccessToken
|
||||||
session.Set(meta.AuthHeaderName, authHeader)
|
session.Set(meta.AuthHeaderName, authHeader)
|
||||||
w.Header().Set(meta.AuthHeaderName, authHeader)
|
w.Header().Add(meta.AuthHeaderName, authHeader)
|
||||||
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
|
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -16,18 +16,18 @@ package oauth2clientcredentials
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/clientcredentials"
|
"golang.org/x/oauth2/clientcredentials"
|
||||||
|
|
||||||
|
mdutils "github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/middleware"
|
"github.com/dapr/components-contrib/middleware"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
@ -40,8 +40,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
|
||||||
TokenURL string `json:"tokenURL"`
|
TokenURL string `json:"tokenURL"`
|
||||||
HeaderName string `json:"headerName"`
|
HeaderName string `json:"headerName"`
|
||||||
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
|
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
|
||||||
AuthStyleString string `json:"authStyle"`
|
AuthStyle int `json:"authStyle"`
|
||||||
AuthStyle int `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
|
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
|
||||||
|
|
@ -69,53 +68,47 @@ type Middleware struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHandler retruns the HTTP handler provided by the middleware.
|
// GetHandler retruns the HTTP handler provided by the middleware.
|
||||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||||
meta, err := m.getNativeMetadata(metadata)
|
meta, err := m.getNativeMetadata(metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.log.Errorf("getNativeMetadata error, %s", err)
|
m.log.Errorf("getNativeMetadata error: %s", err)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
if err != nil {
|
||||||
var headerValue string
|
m.log.Errorf("Error parsing endpoint parameters: %s", err)
|
||||||
// Check if valid Token is in the cache
|
endpointParams, _ = url.ParseQuery("")
|
||||||
cacheKey := m.getCacheKey(meta)
|
}
|
||||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
|
||||||
|
|
||||||
|
conf := &clientcredentials.Config{
|
||||||
|
ClientID: meta.ClientID,
|
||||||
|
ClientSecret: meta.ClientSecret,
|
||||||
|
Scopes: strings.Split(meta.Scopes, ","),
|
||||||
|
TokenURL: meta.TokenURL,
|
||||||
|
EndpointParams: endpointParams,
|
||||||
|
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := m.getCacheKey(meta)
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var headerValue string
|
||||||
|
|
||||||
|
// Check if valid token is in the cache
|
||||||
|
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||||
if !found {
|
if !found {
|
||||||
m.log.Debugf("Cached token not found, try get one")
|
m.log.Debugf("Cached token not found, try get one")
|
||||||
|
|
||||||
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
|
token, err := m.tokenProvider.GetToken(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.log.Errorf("Error parsing endpoint parameters, %s", err)
|
m.log.Errorf("Error acquiring token: %s", err)
|
||||||
endpointParams, _ = url.ParseQuery("")
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &clientcredentials.Config{
|
|
||||||
ClientID: meta.ClientID,
|
|
||||||
ClientSecret: meta.ClientSecret,
|
|
||||||
Scopes: strings.Split(meta.Scopes, ","),
|
|
||||||
TokenURL: meta.TokenURL,
|
|
||||||
EndpointParams: endpointParams,
|
|
||||||
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
|
|
||||||
}
|
|
||||||
|
|
||||||
token, tokenError := m.tokenProvider.GetToken(conf)
|
|
||||||
if tokenError != nil {
|
|
||||||
m.log.Errorf("Error acquiring token, %s", tokenError)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC))
|
tokenExpirationDuration := token.Expiry.Sub(time.Now())
|
||||||
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry)
|
m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
|
||||||
if err != nil {
|
|
||||||
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
headerValue = token.Type() + " " + token.AccessToken
|
headerValue = token.Type() + " " + token.AccessToken
|
||||||
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
||||||
|
|
@ -124,46 +117,37 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
|
||||||
headerValue = cachedToken.(string)
|
headerValue = cachedToken.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Request.Header.Add(meta.HeaderName, headerValue)
|
w.Header().Add(meta.HeaderName, headerValue)
|
||||||
h(ctx)
|
next.ServeHTTP(w, r)
|
||||||
}
|
})
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
|
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
|
||||||
b, err := json.Marshal(metadata.Properties)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
|
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
|
||||||
err = json.Unmarshal(b, &middlewareMetadata)
|
err := mdutils.DecodeMetadata(metadata.Properties, &middlewareMetadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("metadata errors: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do input validation checks
|
// Do input validation checks
|
||||||
errorString := ""
|
errorString := ""
|
||||||
|
|
||||||
// Check if values are present
|
// Check if values are present
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
|
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
|
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
|
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
|
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
|
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
|
||||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
|
|
||||||
|
|
||||||
// Converting AuthStyle to int and do a value check
|
// Value-check AuthStyle
|
||||||
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString)
|
if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
|
||||||
if err != nil {
|
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
|
||||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
|
|
||||||
} else if authStyle < 0 || authStyle > 2 {
|
|
||||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
|
|
||||||
} else {
|
|
||||||
middlewareMetadata.AuthStyle = authStyle
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return errors if any found
|
// Return errors if any found
|
||||||
if errorString != "" {
|
if errorString != "" {
|
||||||
return nil, fmt.Errorf("%s", errorString)
|
return nil, fmt.Errorf("metadata errors: %s", errorString)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &middlewareMetadata, nil
|
return &middlewareMetadata, nil
|
||||||
|
|
@ -177,11 +161,8 @@ func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue
|
||||||
|
|
||||||
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
|
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
|
||||||
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
|
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
|
||||||
hashedKey := sha256.New()
|
hashedKey := sha256.Sum224([]byte(meta.ClientID + meta.Scopes))
|
||||||
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "")
|
return hex.EncodeToString(hashedKey[:])
|
||||||
hashedKey.Write([]byte(key))
|
|
||||||
|
|
||||||
return fmt.Sprintf("%x", hashedKey.Sum(nil))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking).
|
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking).
|
||||||
|
|
|
||||||
|
|
@ -14,13 +14,14 @@ limitations under the License.
|
||||||
package oauth2clientcredentials
|
package oauth2clientcredentials
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
fh "github.com/valyala/fasthttp"
|
|
||||||
oauth2 "golang.org/x/oauth2"
|
oauth2 "golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/middleware"
|
"github.com/dapr/components-contrib/middleware"
|
||||||
|
|
@ -28,7 +29,11 @@ import (
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mockedRequestHandler(ctx *fh.RequestCtx) {}
|
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
|
||||||
|
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("from mock"))
|
||||||
|
}
|
||||||
|
|
||||||
// TestOAuth2ClientCredentialsMetadata will check
|
// TestOAuth2ClientCredentialsMetadata will check
|
||||||
// - if the metadata checks are correct in place.
|
// - if the metadata checks are correct in place.
|
||||||
|
|
@ -41,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
|
||||||
|
|
||||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||||
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||||
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ")
|
assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
|
||||||
|
|
||||||
// Invalid authStyle (non int)
|
// Invalid authStyle (non int)
|
||||||
metadata.Properties = map[string]string{
|
metadata.Properties = map[string]string{
|
||||||
|
|
@ -53,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
|
||||||
"authStyle": "asdf", // This is the value to test
|
"authStyle": "asdf", // This is the value to test
|
||||||
}
|
}
|
||||||
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||||
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ")
|
assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
|
||||||
|
|
||||||
// Invalid authStyle (int > 2)
|
// Invalid authStyle (int > 2)
|
||||||
metadata.Properties["authStyle"] = "3"
|
metadata.Properties["authStyle"] = "3"
|
||||||
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||||
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
|
assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
|
||||||
|
|
||||||
// Invalid authStyle (int < 0)
|
// Invalid authStyle (int < 0)
|
||||||
metadata.Properties["authStyle"] = "-1"
|
metadata.Properties["authStyle"] = "-1"
|
||||||
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||||
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
|
assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestOAuth2ClientCredentialsToken will check
|
// TestOAuth2ClientCredentialsToken will check
|
||||||
|
|
@ -108,10 +113,12 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// First handler call should return abc Token
|
// First handler call should return abc Token
|
||||||
var requestContext1 fh.RequestCtx
|
r := httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||||
handler(mockedRequestHandler)(&requestContext1)
|
w := httptest.NewRecorder()
|
||||||
|
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||||
|
|
||||||
// Assertion
|
// Assertion
|
||||||
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader")))
|
assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestOAuth2ClientCredentialsCache will check
|
// TestOAuth2ClientCredentialsCache will check
|
||||||
|
|
@ -166,23 +173,29 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// First handler call should return abc Token
|
// First handler call should return abc Token
|
||||||
var requestContext1 fh.RequestCtx
|
r := httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||||
handler(mockedRequestHandler)(&requestContext1)
|
w := httptest.NewRecorder()
|
||||||
|
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||||
|
|
||||||
// Assertion
|
// Assertion
|
||||||
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader")))
|
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
|
||||||
|
|
||||||
// Second handler call should still return 'cached' abc Token
|
// Second handler call should still return 'cached' abc Token
|
||||||
var requestContext2 fh.RequestCtx
|
r = httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||||
handler(mockedRequestHandler)(&requestContext2)
|
w = httptest.NewRecorder()
|
||||||
|
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||||
|
|
||||||
// Assertion
|
// Assertion
|
||||||
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader")))
|
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
|
||||||
|
|
||||||
// Wait at a second to invalidate cache entry for abc
|
// Wait at a second to invalidate cache entry for abc
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
// Third call should return def Token
|
// Third call should return def Token
|
||||||
var requestContext3 fh.RequestCtx
|
r = httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||||
handler(mockedRequestHandler)(&requestContext3)
|
w = httptest.NewRecorder()
|
||||||
|
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||||
|
|
||||||
// Assertion
|
// Assertion
|
||||||
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader")))
|
assert.Equal(t, "MAC def", w.Header().Get("someHeader"))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,18 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
"github.com/valyala/fasthttp"
|
"k8s.io/utils/strings/slices"
|
||||||
|
|
||||||
|
"github.com/dapr/components-contrib/internal/httputils"
|
||||||
|
"github.com/dapr/components-contrib/internal/utils"
|
||||||
"github.com/dapr/components-contrib/middleware"
|
"github.com/dapr/components-contrib/middleware"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
@ -33,9 +38,11 @@ import (
|
||||||
type Status int
|
type Status int
|
||||||
|
|
||||||
type middlewareMetadata struct {
|
type middlewareMetadata struct {
|
||||||
Rego string `json:"rego"`
|
Rego string `json:"rego"`
|
||||||
DefaultStatus Status `json:"defaultStatus,omitempty"`
|
DefaultStatus Status `json:"defaultStatus,omitempty"`
|
||||||
IncludedHeaders string `json:"includedHeaders,omitempty"`
|
IncludedHeaders string `json:"includedHeaders,omitempty"`
|
||||||
|
SkipBody string `json:"skipBody,omitempty"`
|
||||||
|
includedHeadersParsed []string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMiddleware returns a new Open Policy Agent middleware.
|
// NewMiddleware returns a new Open Policy Agent middleware.
|
||||||
|
|
@ -98,110 +105,96 @@ func (s *Status) Valid() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHandler returns the HTTP handler provided by the middleware.
|
// GetHandler returns the HTTP handler provided by the middleware.
|
||||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||||
meta, err := m.getNativeMetadata(metadata)
|
meta, err := m.getNativeMetadata(metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
query, err := rego.New(
|
query, err := rego.New(
|
||||||
rego.Query("result = data.http.allow"),
|
rego.Query("result = data.http.allow"),
|
||||||
rego.Module("inline.rego", meta.Rego),
|
rego.Module("inline.rego", meta.Rego),
|
||||||
).PrepareForEval(ctx)
|
).PrepareForEval(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
return func(next http.Handler) http.Handler {
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if allow := m.evalRequest(ctx, meta, &query); !allow {
|
if allow := m.evalRequest(w, r, meta, &query); !allow {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h(ctx)
|
next.ServeHTTP(w, r)
|
||||||
}
|
})
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
|
func (m *Middleware) evalRequest(w http.ResponseWriter, r *http.Request, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
|
||||||
headers := map[string]string{}
|
headers := map[string]string{}
|
||||||
allowedHeaders := strings.Split(meta.IncludedHeaders, ",")
|
|
||||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
|
||||||
for _, allowedHeader := range allowedHeaders {
|
|
||||||
scrubbedHeader := strings.ReplaceAll(allowedHeader, " ", "")
|
|
||||||
buf := []byte("")
|
|
||||||
result := fasthttp.AppendNormalizedHeaderKeyBytes(buf, []byte(scrubbedHeader))
|
|
||||||
normalizedHeader := result[0:]
|
|
||||||
if bytes.Equal(key, normalizedHeader) {
|
|
||||||
headers[string(key)] = string(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
queryArgs := map[string][]string{}
|
for key, value := range r.Header {
|
||||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
if slices.Contains(meta.includedHeadersParsed, key) {
|
||||||
if val, ok := queryArgs[string(key)]; ok {
|
headers[key] = value[0]
|
||||||
queryArgs[string(key)] = append(val, string(value))
|
|
||||||
} else {
|
|
||||||
queryArgs[string(key)] = []string{string(value)}
|
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
path := string(ctx.Path())
|
var body string
|
||||||
pathParts := strings.Split(strings.Trim(path, "/"), "/")
|
if !utils.IsTruthy(meta.SkipBody) {
|
||||||
|
buf, _ := io.ReadAll(r.Body)
|
||||||
|
body = string(buf)
|
||||||
|
|
||||||
|
// Put the body back in the request
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||||||
input := map[string]interface{}{
|
input := map[string]interface{}{
|
||||||
"request": map[string]interface{}{
|
"request": map[string]interface{}{
|
||||||
"method": string(ctx.Method()),
|
"method": r.Method,
|
||||||
"path": path,
|
"path": r.URL.Path,
|
||||||
"path_parts": pathParts,
|
"path_parts": pathParts,
|
||||||
"raw_query": string(ctx.QueryArgs().QueryString()),
|
"raw_query": r.URL.RawQuery,
|
||||||
"query": queryArgs,
|
"query": map[string][]string(r.URL.Query()),
|
||||||
"headers": headers,
|
"headers": headers,
|
||||||
"scheme": string(ctx.Request.URI().Scheme()),
|
"scheme": r.URL.Scheme,
|
||||||
"body": string(ctx.Request.Body()),
|
"body": body,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := query.Eval(context.TODO(), rego.EvalInput(input))
|
results, err := query.Eval(r.Context(), rego.EvalInput(input))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.opaError(ctx, meta, err)
|
m.opaError(w, meta, err)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
m.opaError(ctx, meta, errOpaNoResult)
|
m.opaError(w, meta, errOpaNoResult)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRegoResult(ctx, meta, results[0].Bindings["result"])
|
return m.handleRegoResult(w, meta, results[0].Bindings["result"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRegoResult takes the in process request and open policy agent evaluation result
|
// handleRegoResult takes the in process request and open policy agent evaluation result
|
||||||
// and maps it the appropriate response or headers.
|
// and maps it the appropriate response or headers.
|
||||||
// It returns true if the request should continue, or false if a response should be immediately returned.
|
// It returns true if the request should continue, or false if a response should be immediately returned.
|
||||||
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool {
|
func (m *Middleware) handleRegoResult(w http.ResponseWriter, meta *middlewareMetadata, result any) bool {
|
||||||
if allowed, ok := result.(bool); ok {
|
if allowed, ok := result.(bool); ok {
|
||||||
if !allowed {
|
if !allowed {
|
||||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
httputils.RespondWithError(w, int(meta.DefaultStatus))
|
||||||
}
|
}
|
||||||
|
|
||||||
return allowed
|
return allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := result.(map[string]interface{}); !ok {
|
if _, ok := result.(map[string]any); !ok {
|
||||||
m.opaError(ctx, meta, errOpaInvalidResultType)
|
m.opaError(w, meta, errOpaInvalidResultType)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is it expensive to marshal back and forth? Should we just manually pull out properties?
|
// Is it expensive to marshal back and forth? Should we just manually pull out properties?
|
||||||
marshaled, err := json.Marshal(result)
|
marshaled, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.opaError(ctx, meta, err)
|
m.opaError(w, meta, err)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -212,31 +205,26 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = json.Unmarshal(marshaled, ®oResult); err != nil {
|
if err = json.Unmarshal(marshaled, ®oResult); err != nil {
|
||||||
m.opaError(ctx, meta, err)
|
m.opaError(w, meta, err)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the result isn't allowed, set the response status and
|
// Set the headers on the ongoing request (overriding as necessary)
|
||||||
// apply the additional headers to the response.
|
for key, value := range regoResult.AdditionalHeaders {
|
||||||
// Otherwise, set the headers on the ongoing request (overriding as necessary).
|
w.Header().Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the result isn't allowed, set the response status
|
||||||
if !regoResult.Allow {
|
if !regoResult.Allow {
|
||||||
ctx.Error(fasthttp.StatusMessage(regoResult.StatusCode), regoResult.StatusCode)
|
httputils.RespondWithError(w, regoResult.StatusCode)
|
||||||
for key, value := range regoResult.AdditionalHeaders {
|
|
||||||
ctx.Response.Header.Set(key, value)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for key, value := range regoResult.AdditionalHeaders {
|
|
||||||
ctx.Request.Header.Set(key, value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return regoResult.Allow
|
return regoResult.Allow
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) {
|
func (m *Middleware) opaError(w http.ResponseWriter, meta *middlewareMetadata, err error) {
|
||||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
w.Header().Set(opaErrorHeaderKey, "true")
|
||||||
ctx.Response.Header.Set(opaErrorHeaderKey, "true")
|
httputils.RespondWithError(w, int(meta.DefaultStatus))
|
||||||
m.logger.Warnf("Error procesing rego policy: %v", err)
|
m.logger.Warnf("Error procesing rego policy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -254,5 +242,16 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*middlewar
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
meta.includedHeadersParsed = strings.Split(meta.IncludedHeaders, ",")
|
||||||
|
n := 0
|
||||||
|
for i := range meta.includedHeadersParsed {
|
||||||
|
scrubbed := strings.ReplaceAll(meta.includedHeadersParsed[i], " ", "")
|
||||||
|
if scrubbed != "" {
|
||||||
|
meta.includedHeadersParsed[n] = textproto.CanonicalMIMEHeaderKey(scrubbed)
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
meta.includedHeadersParsed = meta.includedHeadersParsed[:n]
|
||||||
|
|
||||||
return &meta, nil
|
return &meta, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,13 @@ package opa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
fh "github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/metadata"
|
"github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/middleware"
|
"github.com/dapr/components-contrib/middleware"
|
||||||
|
|
@ -27,17 +29,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
|
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
|
||||||
func mockedRequestHandler(ctx *fh.RequestCtx) {
|
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx.Response.SetStatusCode(200)
|
w.WriteHeader(http.StatusOK)
|
||||||
ctx.Response.SetBody([]byte("from mock"))
|
w.Write([]byte("from mock"))
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestConfiguator func(*fh.RequestCtx)
|
|
||||||
|
|
||||||
func TestOpaPolicy(t *testing.T) {
|
func TestOpaPolicy(t *testing.T) {
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
meta middleware.Metadata
|
meta middleware.Metadata
|
||||||
req RequestConfiguator
|
req func() *http.Request
|
||||||
status int
|
status int
|
||||||
headers *[][]string
|
headers *[][]string
|
||||||
body []string
|
body []string
|
||||||
|
|
@ -124,9 +124,8 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.Request.SetHost("https://my.site")
|
return httptest.NewRequest("GET", "https://my.site/allowed", nil)
|
||||||
ctx.Request.URI().SetPath("/allowed")
|
|
||||||
},
|
},
|
||||||
status: 200,
|
status: 200,
|
||||||
},
|
},
|
||||||
|
|
@ -143,9 +142,8 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.Request.SetHost("https://my.site")
|
return httptest.NewRequest("GET", "https://my.site/forbidden", nil)
|
||||||
ctx.Request.URI().SetPath("/forbidden")
|
|
||||||
},
|
},
|
||||||
status: 403,
|
status: 403,
|
||||||
},
|
},
|
||||||
|
|
@ -162,9 +160,10 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.Request.SetHost("https://my.site")
|
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||||
ctx.Request.Header.Add("x-bad-header", "1")
|
r.Header.Add("x-bad-header", "1")
|
||||||
|
return r
|
||||||
},
|
},
|
||||||
status: 200,
|
status: 200,
|
||||||
},
|
},
|
||||||
|
|
@ -182,9 +181,10 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
"includedHeaders": "x-bad-header",
|
"includedHeaders": "x-bad-header",
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.Request.SetHost("https://my.site")
|
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||||
ctx.Request.Header.Add("x-bad-header", "1")
|
r.Header.Add("X-BAD-HEADER", "1")
|
||||||
|
return r
|
||||||
},
|
},
|
||||||
status: 403,
|
status: 403,
|
||||||
},
|
},
|
||||||
|
|
@ -245,27 +245,46 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
"rego": `
|
"rego": `
|
||||||
package http
|
package http
|
||||||
default allow = false
|
default allow = false
|
||||||
|
allow = { "allow": true } {
|
||||||
allow = { "status_code": 200 } {
|
|
||||||
input.request.body == "allow"
|
input.request.body == "allow"
|
||||||
}
|
}
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.SetContentType("text/plain; charset=utf8")
|
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
|
||||||
ctx.Request.SetHost("https://my.site")
|
r.Header.Add("content-type", "text/plain; charset=utf8")
|
||||||
ctx.Request.SetBodyString("allow")
|
return r
|
||||||
},
|
},
|
||||||
status: 200,
|
status: 200,
|
||||||
},
|
},
|
||||||
|
"skip reading body": {
|
||||||
|
meta: middleware.Metadata{Base: metadata.Base{
|
||||||
|
Properties: map[string]string{
|
||||||
|
"skipBody": "true",
|
||||||
|
"rego": `
|
||||||
|
package http
|
||||||
|
default allow = false
|
||||||
|
allow = { "status_code": 403 } {
|
||||||
|
input.request.body == "allow"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
req: func() *http.Request {
|
||||||
|
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
|
||||||
|
r.Header.Add("content-type", "text/plain; charset=utf8")
|
||||||
|
return r
|
||||||
|
},
|
||||||
|
status: 403,
|
||||||
|
},
|
||||||
"allow when multiple headers included with space": {
|
"allow when multiple headers included with space": {
|
||||||
meta: middleware.Metadata{Base: metadata.Base{
|
meta: middleware.Metadata{Base: metadata.Base{
|
||||||
Properties: map[string]string{
|
Properties: map[string]string{
|
||||||
"rego": `
|
"rego": `
|
||||||
package http
|
package http
|
||||||
default allow = false
|
default allow = false
|
||||||
allow = { "status_code": 200 } {
|
allow = { "allow": true } {
|
||||||
input.request.headers["X-Jwt-Header"]
|
input.request.headers["X-Jwt-Header"]
|
||||||
input.request.headers["X-My-Custom-Header"]
|
input.request.headers["X-My-Custom-Header"]
|
||||||
}
|
}
|
||||||
|
|
@ -273,47 +292,76 @@ func TestOpaPolicy(t *testing.T) {
|
||||||
"includedHeaders": "x-my-custom-header, x-jwt-header",
|
"includedHeaders": "x-my-custom-header, x-jwt-header",
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
req: func(ctx *fh.RequestCtx) {
|
req: func() *http.Request {
|
||||||
ctx.Request.SetHost("https://my.site")
|
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||||
ctx.Request.Header.Add("x-jwt-header", "1")
|
r.Header.Add("x-jwt-header", "1")
|
||||||
ctx.Request.Header.Add("x-my-custom-header", "2")
|
r.Header.Add("x-my-custom-header", "2")
|
||||||
|
return r
|
||||||
},
|
},
|
||||||
status: 200,
|
status: 200,
|
||||||
},
|
},
|
||||||
|
"reject when multiple headers included with space": {
|
||||||
|
meta: middleware.Metadata{Base: metadata.Base{
|
||||||
|
Properties: map[string]string{
|
||||||
|
"rego": `
|
||||||
|
package http
|
||||||
|
default allow = false
|
||||||
|
allow = { "allow": true } {
|
||||||
|
input.request.headers["X-Jwt-Header"]
|
||||||
|
input.request.headers["X-My-Custom-Header"]
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
"includedHeaders": "x-my-custom-header, x-jwt-header",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
req: func() *http.Request {
|
||||||
|
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||||
|
r.Header.Add("x-jwt-header", "1")
|
||||||
|
r.Header.Add("x-bad-header", "2")
|
||||||
|
return r
|
||||||
|
},
|
||||||
|
status: 403,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log := logger.NewLogger("opa.test")
|
||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
log := logger.NewLogger("opa.test")
|
|
||||||
opaMiddleware := NewMiddleware(log)
|
opaMiddleware := NewMiddleware(log)
|
||||||
handler, err := opaMiddleware.GetHandler(test.meta)
|
|
||||||
|
|
||||||
|
handler, err := opaMiddleware.GetHandler(test.meta)
|
||||||
if test.shouldHandlerError {
|
if test.shouldHandlerError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var reqCtx fh.RequestCtx
|
var r *http.Request
|
||||||
if test.req != nil {
|
if test.req != nil {
|
||||||
test.req(&reqCtx)
|
r = test.req()
|
||||||
|
} else {
|
||||||
|
r = httptest.NewRequest("GET", "https://my.site", nil)
|
||||||
}
|
}
|
||||||
handler(mockedRequestHandler)(&reqCtx)
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||||
|
|
||||||
if test.shouldRegoError {
|
if test.shouldRegoError {
|
||||||
assert.Equal(t, 403, reqCtx.Response.StatusCode())
|
assert.Equal(t, 403, w.Code)
|
||||||
assert.Equal(t, "true", string(reqCtx.Response.Header.Peek(opaErrorHeaderKey)))
|
assert.Equal(t, "true", w.Header().Get(opaErrorHeaderKey))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, test.status, reqCtx.Response.StatusCode())
|
assert.Equal(t, test.status, w.Code)
|
||||||
|
|
||||||
|
if test.status == 200 {
|
||||||
|
assert.Equal(t, "from mock", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
if test.headers != nil {
|
if test.headers != nil {
|
||||||
for _, header := range *test.headers {
|
for _, header := range *test.headers {
|
||||||
assert.Equal(t, header[1], string(reqCtx.Response.Header.Peek(header[0])))
|
assert.Equal(t, header[1], w.Header().Get(header[0]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,12 @@ package ratelimit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/didip/tollbooth"
|
"github.com/didip/tollbooth"
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/middleware"
|
"github.com/dapr/components-contrib/middleware"
|
||||||
"github.com/dapr/components-contrib/middleware/http/nethttpadaptor"
|
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -39,17 +37,16 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRateLimitMiddleware returns a new ratelimit middleware.
|
// NewRateLimitMiddleware returns a new ratelimit middleware.
|
||||||
func NewRateLimitMiddleware(logger logger.Logger) middleware.Middleware {
|
func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
|
||||||
return &Middleware{logger: logger}
|
return &Middleware{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware is an ratelimit middleware.
|
// Middleware is an ratelimit middleware.
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
logger logger.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHandler returns the HTTP handler provided by the middleware.
|
// GetHandler returns the HTTP handler provided by the middleware.
|
||||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||||
meta, err := m.getNativeMetadata(metadata)
|
meta, err := m.getNativeMetadata(metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -57,13 +54,8 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
|
||||||
|
|
||||||
limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil)
|
limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil)
|
||||||
|
|
||||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
return func(next http.Handler) http.Handler {
|
||||||
limitHandler := tollbooth.LimitFuncHandler(limiter, nethttpadaptor.NewNetHTTPHandlerFunc(m.logger, h))
|
return tollbooth.LimitHandler(limiter, next)
|
||||||
wrappedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(limitHandler.ServeHTTP)
|
|
||||||
|
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
|
||||||
wrappedHandler(ctx)
|
|
||||||
}
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -74,7 +66,7 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*rateLimit
|
||||||
if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok {
|
if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok {
|
||||||
f, err := strconv.ParseFloat(val, 64)
|
f, err := strconv.ParseFloat(val, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %+v", maxRequestsPerSecondKey, err)
|
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %w", maxRequestsPerSecondKey, err)
|
||||||
}
|
}
|
||||||
if f <= 0 {
|
if f <= 0 {
|
||||||
return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)
|
return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue