Updated more middlewares
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
8557183752
commit
e3d2ada01c
|
|
@ -82,7 +82,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
|
|||
session := sessions.Start(w, r)
|
||||
|
||||
if session.GetString(meta.AuthHeaderName) != "" {
|
||||
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
|
||||
w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
|
@ -135,7 +135,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
|
|||
|
||||
authHeader := token.Type() + " " + token.AccessToken
|
||||
session.Set(meta.AuthHeaderName, authHeader)
|
||||
w.Header().Set(meta.AuthHeaderName, authHeader)
|
||||
w.Header().Add(meta.AuthHeaderName, authHeader)
|
||||
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -16,18 +16,18 @@ package oauth2clientcredentials
|
|||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
|
||||
mdutils "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
|
@ -40,8 +40,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
|
|||
TokenURL string `json:"tokenURL"`
|
||||
HeaderName string `json:"headerName"`
|
||||
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
|
||||
AuthStyleString string `json:"authStyle"`
|
||||
AuthStyle int `json:"-"`
|
||||
AuthStyle int `json:"authStyle"`
|
||||
}
|
||||
|
||||
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
|
||||
|
|
@ -69,27 +68,16 @@ type Middleware struct {
|
|||
}
|
||||
|
||||
// GetHandler retruns the HTTP handler provided by the middleware.
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
if err != nil {
|
||||
m.log.Errorf("getNativeMetadata error, %s", err)
|
||||
|
||||
m.log.Errorf("getNativeMetadata error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
var headerValue string
|
||||
// Check if valid Token is in the cache
|
||||
cacheKey := m.getCacheKey(meta)
|
||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||
|
||||
if !found {
|
||||
m.log.Debugf("Cached token not found, try get one")
|
||||
|
||||
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error parsing endpoint parameters, %s", err)
|
||||
m.log.Errorf("Error parsing endpoint parameters: %s", err)
|
||||
endpointParams, _ = url.ParseQuery("")
|
||||
}
|
||||
|
||||
|
|
@ -102,21 +90,26 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
|
|||
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
|
||||
}
|
||||
|
||||
token, tokenError := m.tokenProvider.GetToken(conf)
|
||||
if tokenError != nil {
|
||||
m.log.Errorf("Error acquiring token, %s", tokenError)
|
||||
cacheKey := m.getCacheKey(meta)
|
||||
|
||||
return
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var headerValue string
|
||||
|
||||
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC))
|
||||
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry)
|
||||
// Check if valid token is in the cache
|
||||
cachedToken, found := m.tokenCache.Get(cacheKey)
|
||||
if !found {
|
||||
m.log.Debugf("Cached token not found, try get one")
|
||||
|
||||
token, err := m.tokenProvider.GetToken(conf)
|
||||
if err != nil {
|
||||
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
|
||||
|
||||
m.log.Errorf("Error acquiring token: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tokenExpirationDuration := token.Expiry.Sub(time.Now())
|
||||
m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
|
||||
|
||||
headerValue = token.Type() + " " + token.AccessToken
|
||||
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
|
||||
} else {
|
||||
|
|
@ -124,46 +117,37 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
|
|||
headerValue = cachedToken.(string)
|
||||
}
|
||||
|
||||
ctx.Request.Header.Add(meta.HeaderName, headerValue)
|
||||
h(ctx)
|
||||
}
|
||||
w.Header().Add(meta.HeaderName, headerValue)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
|
||||
b, err := json.Marshal(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
|
||||
err = json.Unmarshal(b, &middlewareMetadata)
|
||||
err := mdutils.DecodeMetadata(metadata.Properties, &middlewareMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("metadata errors: %w", err)
|
||||
}
|
||||
|
||||
// Do input validation checks
|
||||
errorString := ""
|
||||
|
||||
// Check if values are present
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
|
||||
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
|
||||
|
||||
// Converting AuthStyle to int and do a value check
|
||||
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString)
|
||||
if err != nil {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
|
||||
} else if authStyle < 0 || authStyle > 2 {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
|
||||
} else {
|
||||
middlewareMetadata.AuthStyle = authStyle
|
||||
// Value-check AuthStyle
|
||||
if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
|
||||
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
|
||||
}
|
||||
|
||||
// Return errors if any found
|
||||
if errorString != "" {
|
||||
return nil, fmt.Errorf("%s", errorString)
|
||||
return nil, fmt.Errorf("metadata errors: %s", errorString)
|
||||
}
|
||||
|
||||
return &middlewareMetadata, nil
|
||||
|
|
@ -177,11 +161,8 @@ func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue
|
|||
|
||||
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
|
||||
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
|
||||
hashedKey := sha256.New()
|
||||
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "")
|
||||
hashedKey.Write([]byte(key))
|
||||
|
||||
return fmt.Sprintf("%x", hashedKey.Sum(nil))
|
||||
hashedKey := sha256.Sum224([]byte(meta.ClientID + meta.Scopes))
|
||||
return hex.EncodeToString(hashedKey[:])
|
||||
}
|
||||
|
||||
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking).
|
||||
|
|
|
|||
|
|
@ -14,13 +14,14 @@ limitations under the License.
|
|||
package oauth2clientcredentials
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
fh "github.com/valyala/fasthttp"
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
|
|
@ -28,7 +29,11 @@ import (
|
|||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
func mockedRequestHandler(ctx *fh.RequestCtx) {}
|
||||
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
|
||||
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("from mock"))
|
||||
}
|
||||
|
||||
// TestOAuth2ClientCredentialsMetadata will check
|
||||
// - if the metadata checks are correct in place.
|
||||
|
|
@ -41,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
|
|||
|
||||
log := logger.NewLogger("oauth2clientcredentials.test")
|
||||
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ")
|
||||
assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
|
||||
|
||||
// Invalid authStyle (non int)
|
||||
metadata.Properties = map[string]string{
|
||||
|
|
@ -53,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
|
|||
"authStyle": "asdf", // This is the value to test
|
||||
}
|
||||
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ")
|
||||
assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
|
||||
|
||||
// Invalid authStyle (int > 2)
|
||||
metadata.Properties["authStyle"] = "3"
|
||||
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
|
||||
assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
|
||||
|
||||
// Invalid authStyle (int < 0)
|
||||
metadata.Properties["authStyle"] = "-1"
|
||||
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
|
||||
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
|
||||
assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
|
||||
}
|
||||
|
||||
// TestOAuth2ClientCredentialsToken will check
|
||||
|
|
@ -108,10 +113,12 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// First handler call should return abc Token
|
||||
var requestContext1 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext1)
|
||||
r := httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader")))
|
||||
assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader"))
|
||||
}
|
||||
|
||||
// TestOAuth2ClientCredentialsCache will check
|
||||
|
|
@ -166,23 +173,29 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// First handler call should return abc Token
|
||||
var requestContext1 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext1)
|
||||
r := httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader")))
|
||||
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
|
||||
|
||||
// Second handler call should still return 'cached' abc Token
|
||||
var requestContext2 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext2)
|
||||
r = httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||
w = httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
// Assertion
|
||||
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader")))
|
||||
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
|
||||
|
||||
// Wait at a second to invalidate cache entry for abc
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Third call should return def Token
|
||||
var requestContext3 fh.RequestCtx
|
||||
handler(mockedRequestHandler)(&requestContext3)
|
||||
r = httptest.NewRequest("GET", "http://dapr.io", nil)
|
||||
w = httptest.NewRecorder()
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
// Assertion
|
||||
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader")))
|
||||
assert.Equal(t, "MAC def", w.Header().Get("someHeader"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,13 +19,18 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/valyala/fasthttp"
|
||||
"k8s.io/utils/strings/slices"
|
||||
|
||||
"github.com/dapr/components-contrib/internal/httputils"
|
||||
"github.com/dapr/components-contrib/internal/utils"
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
|
@ -36,6 +41,8 @@ type middlewareMetadata struct {
|
|||
Rego string `json:"rego"`
|
||||
DefaultStatus Status `json:"defaultStatus,omitempty"`
|
||||
IncludedHeaders string `json:"includedHeaders,omitempty"`
|
||||
SkipBody string `json:"skipBody,omitempty"`
|
||||
includedHeadersParsed []string `json:"-"`
|
||||
}
|
||||
|
||||
// NewMiddleware returns a new Open Policy Agent middleware.
|
||||
|
|
@ -98,110 +105,96 @@ func (s *Status) Valid() bool {
|
|||
}
|
||||
|
||||
// GetHandler returns the HTTP handler provided by the middleware.
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
query, err := rego.New(
|
||||
rego.Query("result = data.http.allow"),
|
||||
rego.Module("inline.rego", meta.Rego),
|
||||
).PrepareForEval(ctx)
|
||||
).PrepareForEval(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
if allow := m.evalRequest(ctx, meta, &query); !allow {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if allow := m.evalRequest(w, r, meta, &query); !allow {
|
||||
return
|
||||
}
|
||||
h(ctx)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
|
||||
func (m *Middleware) evalRequest(w http.ResponseWriter, r *http.Request, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
|
||||
headers := map[string]string{}
|
||||
allowedHeaders := strings.Split(meta.IncludedHeaders, ",")
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
for _, allowedHeader := range allowedHeaders {
|
||||
scrubbedHeader := strings.ReplaceAll(allowedHeader, " ", "")
|
||||
buf := []byte("")
|
||||
result := fasthttp.AppendNormalizedHeaderKeyBytes(buf, []byte(scrubbedHeader))
|
||||
normalizedHeader := result[0:]
|
||||
if bytes.Equal(key, normalizedHeader) {
|
||||
headers[string(key)] = string(value)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
queryArgs := map[string][]string{}
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
if val, ok := queryArgs[string(key)]; ok {
|
||||
queryArgs[string(key)] = append(val, string(value))
|
||||
} else {
|
||||
queryArgs[string(key)] = []string{string(value)}
|
||||
for key, value := range r.Header {
|
||||
if slices.Contains(meta.includedHeadersParsed, key) {
|
||||
headers[key] = value[0]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
path := string(ctx.Path())
|
||||
pathParts := strings.Split(strings.Trim(path, "/"), "/")
|
||||
var body string
|
||||
if !utils.IsTruthy(meta.SkipBody) {
|
||||
buf, _ := io.ReadAll(r.Body)
|
||||
body = string(buf)
|
||||
|
||||
// Put the body back in the request
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(buf))
|
||||
}
|
||||
|
||||
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||||
input := map[string]interface{}{
|
||||
"request": map[string]interface{}{
|
||||
"method": string(ctx.Method()),
|
||||
"path": path,
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
"path_parts": pathParts,
|
||||
"raw_query": string(ctx.QueryArgs().QueryString()),
|
||||
"query": queryArgs,
|
||||
"raw_query": r.URL.RawQuery,
|
||||
"query": map[string][]string(r.URL.Query()),
|
||||
"headers": headers,
|
||||
"scheme": string(ctx.Request.URI().Scheme()),
|
||||
"body": string(ctx.Request.Body()),
|
||||
"scheme": r.URL.Scheme,
|
||||
"body": body,
|
||||
},
|
||||
}
|
||||
|
||||
results, err := query.Eval(context.TODO(), rego.EvalInput(input))
|
||||
results, err := query.Eval(r.Context(), rego.EvalInput(input))
|
||||
if err != nil {
|
||||
m.opaError(ctx, meta, err)
|
||||
|
||||
m.opaError(w, meta, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
m.opaError(ctx, meta, errOpaNoResult)
|
||||
|
||||
m.opaError(w, meta, errOpaNoResult)
|
||||
return false
|
||||
}
|
||||
|
||||
return m.handleRegoResult(ctx, meta, results[0].Bindings["result"])
|
||||
return m.handleRegoResult(w, meta, results[0].Bindings["result"])
|
||||
}
|
||||
|
||||
// handleRegoResult takes the in process request and open policy agent evaluation result
|
||||
// and maps it the appropriate response or headers.
|
||||
// It returns true if the request should continue, or false if a response should be immediately returned.
|
||||
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool {
|
||||
func (m *Middleware) handleRegoResult(w http.ResponseWriter, meta *middlewareMetadata, result any) bool {
|
||||
if allowed, ok := result.(bool); ok {
|
||||
if !allowed {
|
||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
||||
httputils.RespondWithError(w, int(meta.DefaultStatus))
|
||||
}
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
if _, ok := result.(map[string]interface{}); !ok {
|
||||
m.opaError(ctx, meta, errOpaInvalidResultType)
|
||||
|
||||
if _, ok := result.(map[string]any); !ok {
|
||||
m.opaError(w, meta, errOpaInvalidResultType)
|
||||
return false
|
||||
}
|
||||
|
||||
// Is it expensive to marshal back and forth? Should we just manually pull out properties?
|
||||
marshaled, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
m.opaError(ctx, meta, err)
|
||||
|
||||
m.opaError(w, meta, err)
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -212,31 +205,26 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
|
|||
}
|
||||
|
||||
if err = json.Unmarshal(marshaled, ®oResult); err != nil {
|
||||
m.opaError(ctx, meta, err)
|
||||
|
||||
m.opaError(w, meta, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// If the result isn't allowed, set the response status and
|
||||
// apply the additional headers to the response.
|
||||
// Otherwise, set the headers on the ongoing request (overriding as necessary).
|
||||
// Set the headers on the ongoing request (overriding as necessary)
|
||||
for key, value := range regoResult.AdditionalHeaders {
|
||||
w.Header().Set(key, value)
|
||||
}
|
||||
|
||||
// If the result isn't allowed, set the response status
|
||||
if !regoResult.Allow {
|
||||
ctx.Error(fasthttp.StatusMessage(regoResult.StatusCode), regoResult.StatusCode)
|
||||
for key, value := range regoResult.AdditionalHeaders {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
} else {
|
||||
for key, value := range regoResult.AdditionalHeaders {
|
||||
ctx.Request.Header.Set(key, value)
|
||||
}
|
||||
httputils.RespondWithError(w, regoResult.StatusCode)
|
||||
}
|
||||
|
||||
return regoResult.Allow
|
||||
}
|
||||
|
||||
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) {
|
||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
||||
ctx.Response.Header.Set(opaErrorHeaderKey, "true")
|
||||
func (m *Middleware) opaError(w http.ResponseWriter, meta *middlewareMetadata, err error) {
|
||||
w.Header().Set(opaErrorHeaderKey, "true")
|
||||
httputils.RespondWithError(w, int(meta.DefaultStatus))
|
||||
m.logger.Warnf("Error procesing rego policy: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -254,5 +242,16 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*middlewar
|
|||
return nil, err
|
||||
}
|
||||
|
||||
meta.includedHeadersParsed = strings.Split(meta.IncludedHeaders, ",")
|
||||
n := 0
|
||||
for i := range meta.includedHeadersParsed {
|
||||
scrubbed := strings.ReplaceAll(meta.includedHeadersParsed[i], " ", "")
|
||||
if scrubbed != "" {
|
||||
meta.includedHeadersParsed[n] = textproto.CanonicalMIMEHeaderKey(scrubbed)
|
||||
n++
|
||||
}
|
||||
}
|
||||
meta.includedHeadersParsed = meta.includedHeadersParsed[:n]
|
||||
|
||||
return &meta, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,11 +15,13 @@ package opa
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
fh "github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
|
|
@ -27,17 +29,15 @@ import (
|
|||
)
|
||||
|
||||
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
|
||||
func mockedRequestHandler(ctx *fh.RequestCtx) {
|
||||
ctx.Response.SetStatusCode(200)
|
||||
ctx.Response.SetBody([]byte("from mock"))
|
||||
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("from mock"))
|
||||
}
|
||||
|
||||
type RequestConfiguator func(*fh.RequestCtx)
|
||||
|
||||
func TestOpaPolicy(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
meta middleware.Metadata
|
||||
req RequestConfiguator
|
||||
req func() *http.Request
|
||||
status int
|
||||
headers *[][]string
|
||||
body []string
|
||||
|
|
@ -124,9 +124,8 @@ func TestOpaPolicy(t *testing.T) {
|
|||
`,
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.URI().SetPath("/allowed")
|
||||
req: func() *http.Request {
|
||||
return httptest.NewRequest("GET", "https://my.site/allowed", nil)
|
||||
},
|
||||
status: 200,
|
||||
},
|
||||
|
|
@ -143,9 +142,8 @@ func TestOpaPolicy(t *testing.T) {
|
|||
`,
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.URI().SetPath("/forbidden")
|
||||
req: func() *http.Request {
|
||||
return httptest.NewRequest("GET", "https://my.site/forbidden", nil)
|
||||
},
|
||||
status: 403,
|
||||
},
|
||||
|
|
@ -162,9 +160,10 @@ func TestOpaPolicy(t *testing.T) {
|
|||
`,
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.Header.Add("x-bad-header", "1")
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||
r.Header.Add("x-bad-header", "1")
|
||||
return r
|
||||
},
|
||||
status: 200,
|
||||
},
|
||||
|
|
@ -182,9 +181,10 @@ func TestOpaPolicy(t *testing.T) {
|
|||
"includedHeaders": "x-bad-header",
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.Header.Add("x-bad-header", "1")
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||
r.Header.Add("X-BAD-HEADER", "1")
|
||||
return r
|
||||
},
|
||||
status: 403,
|
||||
},
|
||||
|
|
@ -245,27 +245,46 @@ func TestOpaPolicy(t *testing.T) {
|
|||
"rego": `
|
||||
package http
|
||||
default allow = false
|
||||
|
||||
allow = { "status_code": 200 } {
|
||||
allow = { "allow": true } {
|
||||
input.request.body == "allow"
|
||||
}
|
||||
`,
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.SetContentType("text/plain; charset=utf8")
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.SetBodyString("allow")
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
|
||||
r.Header.Add("content-type", "text/plain; charset=utf8")
|
||||
return r
|
||||
},
|
||||
status: 200,
|
||||
},
|
||||
"skip reading body": {
|
||||
meta: middleware.Metadata{Base: metadata.Base{
|
||||
Properties: map[string]string{
|
||||
"skipBody": "true",
|
||||
"rego": `
|
||||
package http
|
||||
default allow = false
|
||||
allow = { "status_code": 403 } {
|
||||
input.request.body == "allow"
|
||||
}
|
||||
`,
|
||||
},
|
||||
}},
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
|
||||
r.Header.Add("content-type", "text/plain; charset=utf8")
|
||||
return r
|
||||
},
|
||||
status: 403,
|
||||
},
|
||||
"allow when multiple headers included with space": {
|
||||
meta: middleware.Metadata{Base: metadata.Base{
|
||||
Properties: map[string]string{
|
||||
"rego": `
|
||||
package http
|
||||
default allow = false
|
||||
allow = { "status_code": 200 } {
|
||||
allow = { "allow": true } {
|
||||
input.request.headers["X-Jwt-Header"]
|
||||
input.request.headers["X-My-Custom-Header"]
|
||||
}
|
||||
|
|
@ -273,47 +292,76 @@ func TestOpaPolicy(t *testing.T) {
|
|||
"includedHeaders": "x-my-custom-header, x-jwt-header",
|
||||
},
|
||||
}},
|
||||
req: func(ctx *fh.RequestCtx) {
|
||||
ctx.Request.SetHost("https://my.site")
|
||||
ctx.Request.Header.Add("x-jwt-header", "1")
|
||||
ctx.Request.Header.Add("x-my-custom-header", "2")
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||
r.Header.Add("x-jwt-header", "1")
|
||||
r.Header.Add("x-my-custom-header", "2")
|
||||
return r
|
||||
},
|
||||
status: 200,
|
||||
},
|
||||
"reject when multiple headers included with space": {
|
||||
meta: middleware.Metadata{Base: metadata.Base{
|
||||
Properties: map[string]string{
|
||||
"rego": `
|
||||
package http
|
||||
default allow = false
|
||||
allow = { "allow": true } {
|
||||
input.request.headers["X-Jwt-Header"]
|
||||
input.request.headers["X-My-Custom-Header"]
|
||||
}
|
||||
`,
|
||||
"includedHeaders": "x-my-custom-header, x-jwt-header",
|
||||
},
|
||||
}},
|
||||
req: func() *http.Request {
|
||||
r := httptest.NewRequest("GET", "https://my.site", nil)
|
||||
r.Header.Add("x-jwt-header", "1")
|
||||
r.Header.Add("x-bad-header", "2")
|
||||
return r
|
||||
},
|
||||
status: 403,
|
||||
},
|
||||
}
|
||||
|
||||
log := logger.NewLogger("opa.test")
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
log := logger.NewLogger("opa.test")
|
||||
opaMiddleware := NewMiddleware(log)
|
||||
handler, err := opaMiddleware.GetHandler(test.meta)
|
||||
|
||||
handler, err := opaMiddleware.GetHandler(test.meta)
|
||||
if test.shouldHandlerError {
|
||||
require.Error(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
var reqCtx fh.RequestCtx
|
||||
var r *http.Request
|
||||
if test.req != nil {
|
||||
test.req(&reqCtx)
|
||||
r = test.req()
|
||||
} else {
|
||||
r = httptest.NewRequest("GET", "https://my.site", nil)
|
||||
}
|
||||
handler(mockedRequestHandler)(&reqCtx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
|
||||
|
||||
if test.shouldRegoError {
|
||||
assert.Equal(t, 403, reqCtx.Response.StatusCode())
|
||||
assert.Equal(t, "true", string(reqCtx.Response.Header.Peek(opaErrorHeaderKey)))
|
||||
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "true", w.Header().Get(opaErrorHeaderKey))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, test.status, reqCtx.Response.StatusCode())
|
||||
assert.Equal(t, test.status, w.Code)
|
||||
|
||||
if test.status == 200 {
|
||||
assert.Equal(t, "from mock", w.Body.String())
|
||||
}
|
||||
|
||||
if test.headers != nil {
|
||||
for _, header := range *test.headers {
|
||||
assert.Equal(t, header[1], string(reqCtx.Response.Header.Peek(header[0])))
|
||||
assert.Equal(t, header[1], w.Header().Get(header[0]))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -15,14 +15,12 @@ package ratelimit
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/didip/tollbooth"
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/components-contrib/middleware/http/nethttpadaptor"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
|
|
@ -39,17 +37,16 @@ const (
|
|||
)
|
||||
|
||||
// NewRateLimitMiddleware returns a new ratelimit middleware.
|
||||
func NewRateLimitMiddleware(logger logger.Logger) middleware.Middleware {
|
||||
return &Middleware{logger: logger}
|
||||
func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
|
||||
return &Middleware{}
|
||||
}
|
||||
|
||||
// Middleware is an ratelimit middleware.
|
||||
type Middleware struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
// GetHandler returns the HTTP handler provided by the middleware.
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -57,13 +54,8 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
|
|||
|
||||
limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil)
|
||||
|
||||
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
limitHandler := tollbooth.LimitFuncHandler(limiter, nethttpadaptor.NewNetHTTPHandlerFunc(m.logger, h))
|
||||
wrappedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(limitHandler.ServeHTTP)
|
||||
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
wrappedHandler(ctx)
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return tollbooth.LimitHandler(limiter, next)
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -74,7 +66,7 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*rateLimit
|
|||
if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok {
|
||||
f, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %+v", maxRequestsPerSecondKey, err)
|
||||
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %w", maxRequestsPerSecondKey, err)
|
||||
}
|
||||
if f <= 0 {
|
||||
return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)
|
||||
|
|
|
|||
Loading…
Reference in New Issue