Updated more middlewares

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-09-29 18:25:52 +00:00
parent 8557183752
commit e3d2ada01c
6 changed files with 244 additions and 211 deletions

View File

@ -82,7 +82,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
session := sessions.Start(w, r)
if session.GetString(meta.AuthHeaderName) != "" {
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
next.ServeHTTP(w, r)
return
}
@ -135,7 +135,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
authHeader := token.Type() + " " + token.AccessToken
session.Set(meta.AuthHeaderName, authHeader)
w.Header().Set(meta.AuthHeaderName, authHeader)
w.Header().Add(meta.AuthHeaderName, authHeader)
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
}
})

View File

@ -16,18 +16,18 @@ package oauth2clientcredentials
import (
"context"
"crypto/sha256"
"encoding/json"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/patrickmn/go-cache"
"github.com/valyala/fasthttp"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger"
)
@ -40,8 +40,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
TokenURL string `json:"tokenURL"`
HeaderName string `json:"headerName"`
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
AuthStyleString string `json:"authStyle"`
AuthStyle int `json:"-"`
AuthStyle int `json:"authStyle"`
}
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
@ -69,53 +68,47 @@ type Middleware struct {
}
// GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
m.log.Errorf("getNativeMetadata error, %s", err)
m.log.Errorf("getNativeMetadata error: %s", err)
return nil, err
}
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
var headerValue string
// Check if valid Token is in the cache
cacheKey := m.getCacheKey(meta)
cachedToken, found := m.tokenCache.Get(cacheKey)
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
if err != nil {
m.log.Errorf("Error parsing endpoint parameters: %s", err)
endpointParams, _ = url.ParseQuery("")
}
conf := &clientcredentials.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
TokenURL: meta.TokenURL,
EndpointParams: endpointParams,
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
}
cacheKey := m.getCacheKey(meta)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var headerValue string
// Check if valid token is in the cache
cachedToken, found := m.tokenCache.Get(cacheKey)
if !found {
m.log.Debugf("Cached token not found, try get one")
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
token, err := m.tokenProvider.GetToken(conf)
if err != nil {
m.log.Errorf("Error parsing endpoint parameters, %s", err)
endpointParams, _ = url.ParseQuery("")
}
conf := &clientcredentials.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
TokenURL: meta.TokenURL,
EndpointParams: endpointParams,
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
}
token, tokenError := m.tokenProvider.GetToken(conf)
if tokenError != nil {
m.log.Errorf("Error acquiring token, %s", tokenError)
m.log.Errorf("Error acquiring token: %s", err)
return
}
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC))
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry)
if err != nil {
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
return
}
tokenExpirationDuration := token.Expiry.Sub(time.Now())
m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
headerValue = token.Type() + " " + token.AccessToken
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
@ -124,46 +117,37 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
headerValue = cachedToken.(string)
}
ctx.Request.Header.Add(meta.HeaderName, headerValue)
h(ctx)
}
w.Header().Add(meta.HeaderName, headerValue)
next.ServeHTTP(w, r)
})
}, nil
}
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
b, err := json.Marshal(metadata.Properties)
if err != nil {
return nil, err
}
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
err = json.Unmarshal(b, &middlewareMetadata)
err := mdutils.DecodeMetadata(metadata.Properties, &middlewareMetadata)
if err != nil {
return nil, err
return nil, fmt.Errorf("metadata errors: %w", err)
}
// Do input validation checks
errorString := ""
// Check if values are present
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
// Converting AuthStyle to int and do a value check
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString)
if err != nil {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
} else if authStyle < 0 || authStyle > 2 {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
} else {
middlewareMetadata.AuthStyle = authStyle
// Value-check AuthStyle
if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
}
// Return errors if any found
if errorString != "" {
return nil, fmt.Errorf("%s", errorString)
return nil, fmt.Errorf("metadata errors: %s", errorString)
}
return &middlewareMetadata, nil
@ -177,11 +161,8 @@ func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
hashedKey := sha256.New()
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "")
hashedKey.Write([]byte(key))
return fmt.Sprintf("%x", hashedKey.Sum(nil))
hashedKey := sha256.Sum224([]byte(meta.ClientID + meta.Scopes))
return hex.EncodeToString(hashedKey[:])
}
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking).

View File

@ -14,13 +14,14 @@ limitations under the License.
package oauth2clientcredentials
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fh "github.com/valyala/fasthttp"
oauth2 "golang.org/x/oauth2"
"github.com/dapr/components-contrib/middleware"
@ -28,7 +29,11 @@ import (
"github.com/dapr/kit/logger"
)
func mockedRequestHandler(ctx *fh.RequestCtx) {}
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("from mock"))
}
// TestOAuth2ClientCredentialsMetadata will check
// - if the metadata checks are correct in place.
@ -41,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test")
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ")
assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
// Invalid authStyle (non int)
metadata.Properties = map[string]string{
@ -53,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
"authStyle": "asdf", // This is the value to test
}
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ")
assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
// Invalid authStyle (int > 2)
metadata.Properties["authStyle"] = "3"
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
// Invalid authStyle (int < 0)
metadata.Properties["authStyle"] = "-1"
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
}
// TestOAuth2ClientCredentialsToken will check
@ -108,10 +113,12 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
require.NoError(t, err)
// First handler call should return abc Token
var requestContext1 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext1)
r := httptest.NewRequest("GET", "http://dapr.io", nil)
w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader")))
assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader"))
}
// TestOAuth2ClientCredentialsCache will check
@ -166,23 +173,29 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
require.NoError(t, err)
// First handler call should return abc Token
var requestContext1 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext1)
r := httptest.NewRequest("GET", "http://dapr.io", nil)
w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader")))
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
// Second handler call should still return 'cached' abc Token
var requestContext2 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext2)
r = httptest.NewRequest("GET", "http://dapr.io", nil)
w = httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader")))
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
// Wait at a second to invalidate cache entry for abc
time.Sleep(1 * time.Second)
// Third call should return def Token
var requestContext3 fh.RequestCtx
handler(mockedRequestHandler)(&requestContext3)
r = httptest.NewRequest("GET", "http://dapr.io", nil)
w = httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader")))
assert.Equal(t, "MAC def", w.Header().Get("someHeader"))
}

View File

@ -19,13 +19,18 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/textproto"
"strconv"
"strings"
"github.com/open-policy-agent/opa/rego"
"github.com/valyala/fasthttp"
"k8s.io/utils/strings/slices"
"github.com/dapr/components-contrib/internal/httputils"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger"
)
@ -33,9 +38,11 @@ import (
type Status int
type middlewareMetadata struct {
Rego string `json:"rego"`
DefaultStatus Status `json:"defaultStatus,omitempty"`
IncludedHeaders string `json:"includedHeaders,omitempty"`
Rego string `json:"rego"`
DefaultStatus Status `json:"defaultStatus,omitempty"`
IncludedHeaders string `json:"includedHeaders,omitempty"`
SkipBody string `json:"skipBody,omitempty"`
includedHeadersParsed []string `json:"-"`
}
// NewMiddleware returns a new Open Policy Agent middleware.
@ -98,110 +105,96 @@ func (s *Status) Valid() bool {
}
// GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err
}
ctx := context.Background()
query, err := rego.New(
rego.Query("result = data.http.allow"),
rego.Module("inline.rego", meta.Rego),
).PrepareForEval(ctx)
).PrepareForEval(context.Background())
if err != nil {
return nil, err
}
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
if allow := m.evalRequest(ctx, meta, &query); !allow {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if allow := m.evalRequest(w, r, meta, &query); !allow {
return
}
h(ctx)
}
next.ServeHTTP(w, r)
})
}, nil
}
func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
func (m *Middleware) evalRequest(w http.ResponseWriter, r *http.Request, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
headers := map[string]string{}
allowedHeaders := strings.Split(meta.IncludedHeaders, ",")
ctx.Request.Header.VisitAll(func(key, value []byte) {
for _, allowedHeader := range allowedHeaders {
scrubbedHeader := strings.ReplaceAll(allowedHeader, " ", "")
buf := []byte("")
result := fasthttp.AppendNormalizedHeaderKeyBytes(buf, []byte(scrubbedHeader))
normalizedHeader := result[0:]
if bytes.Equal(key, normalizedHeader) {
headers[string(key)] = string(value)
}
}
})
queryArgs := map[string][]string{}
ctx.QueryArgs().VisitAll(func(key, value []byte) {
if val, ok := queryArgs[string(key)]; ok {
queryArgs[string(key)] = append(val, string(value))
} else {
queryArgs[string(key)] = []string{string(value)}
for key, value := range r.Header {
if slices.Contains(meta.includedHeadersParsed, key) {
headers[key] = value[0]
}
})
}
path := string(ctx.Path())
pathParts := strings.Split(strings.Trim(path, "/"), "/")
var body string
if !utils.IsTruthy(meta.SkipBody) {
buf, _ := io.ReadAll(r.Body)
body = string(buf)
// Put the body back in the request
r.Body = io.NopCloser(bytes.NewBuffer(buf))
}
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
input := map[string]interface{}{
"request": map[string]interface{}{
"method": string(ctx.Method()),
"path": path,
"method": r.Method,
"path": r.URL.Path,
"path_parts": pathParts,
"raw_query": string(ctx.QueryArgs().QueryString()),
"query": queryArgs,
"raw_query": r.URL.RawQuery,
"query": map[string][]string(r.URL.Query()),
"headers": headers,
"scheme": string(ctx.Request.URI().Scheme()),
"body": string(ctx.Request.Body()),
"scheme": r.URL.Scheme,
"body": body,
},
}
results, err := query.Eval(context.TODO(), rego.EvalInput(input))
results, err := query.Eval(r.Context(), rego.EvalInput(input))
if err != nil {
m.opaError(ctx, meta, err)
m.opaError(w, meta, err)
return false
}
if len(results) == 0 {
m.opaError(ctx, meta, errOpaNoResult)
m.opaError(w, meta, errOpaNoResult)
return false
}
return m.handleRegoResult(ctx, meta, results[0].Bindings["result"])
return m.handleRegoResult(w, meta, results[0].Bindings["result"])
}
// handleRegoResult takes the in process request and open policy agent evaluation result
// and maps it the appropriate response or headers.
// It returns true if the request should continue, or false if a response should be immediately returned.
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool {
func (m *Middleware) handleRegoResult(w http.ResponseWriter, meta *middlewareMetadata, result any) bool {
if allowed, ok := result.(bool); ok {
if !allowed {
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
httputils.RespondWithError(w, int(meta.DefaultStatus))
}
return allowed
}
if _, ok := result.(map[string]interface{}); !ok {
m.opaError(ctx, meta, errOpaInvalidResultType)
if _, ok := result.(map[string]any); !ok {
m.opaError(w, meta, errOpaInvalidResultType)
return false
}
// Is it expensive to marshal back and forth? Should we just manually pull out properties?
marshaled, err := json.Marshal(result)
if err != nil {
m.opaError(ctx, meta, err)
m.opaError(w, meta, err)
return false
}
@ -212,31 +205,26 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
}
if err = json.Unmarshal(marshaled, &regoResult); err != nil {
m.opaError(ctx, meta, err)
m.opaError(w, meta, err)
return false
}
// If the result isn't allowed, set the response status and
// apply the additional headers to the response.
// Otherwise, set the headers on the ongoing request (overriding as necessary).
// Set the headers on the ongoing request (overriding as necessary)
for key, value := range regoResult.AdditionalHeaders {
w.Header().Set(key, value)
}
// If the result isn't allowed, set the response status
if !regoResult.Allow {
ctx.Error(fasthttp.StatusMessage(regoResult.StatusCode), regoResult.StatusCode)
for key, value := range regoResult.AdditionalHeaders {
ctx.Response.Header.Set(key, value)
}
} else {
for key, value := range regoResult.AdditionalHeaders {
ctx.Request.Header.Set(key, value)
}
httputils.RespondWithError(w, regoResult.StatusCode)
}
return regoResult.Allow
}
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) {
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
ctx.Response.Header.Set(opaErrorHeaderKey, "true")
func (m *Middleware) opaError(w http.ResponseWriter, meta *middlewareMetadata, err error) {
w.Header().Set(opaErrorHeaderKey, "true")
httputils.RespondWithError(w, int(meta.DefaultStatus))
m.logger.Warnf("Error procesing rego policy: %v", err)
}
@ -254,5 +242,16 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*middlewar
return nil, err
}
meta.includedHeadersParsed = strings.Split(meta.IncludedHeaders, ",")
n := 0
for i := range meta.includedHeadersParsed {
scrubbed := strings.ReplaceAll(meta.includedHeadersParsed[i], " ", "")
if scrubbed != "" {
meta.includedHeadersParsed[n] = textproto.CanonicalMIMEHeaderKey(scrubbed)
n++
}
}
meta.includedHeadersParsed = meta.includedHeadersParsed[:n]
return &meta, nil
}

View File

@ -15,11 +15,13 @@ package opa
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fh "github.com/valyala/fasthttp"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/middleware"
@ -27,17 +29,15 @@ import (
)
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
func mockedRequestHandler(ctx *fh.RequestCtx) {
ctx.Response.SetStatusCode(200)
ctx.Response.SetBody([]byte("from mock"))
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("from mock"))
}
type RequestConfiguator func(*fh.RequestCtx)
func TestOpaPolicy(t *testing.T) {
tests := map[string]struct {
meta middleware.Metadata
req RequestConfiguator
req func() *http.Request
status int
headers *[][]string
body []string
@ -124,9 +124,8 @@ func TestOpaPolicy(t *testing.T) {
`,
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.Request.SetHost("https://my.site")
ctx.Request.URI().SetPath("/allowed")
req: func() *http.Request {
return httptest.NewRequest("GET", "https://my.site/allowed", nil)
},
status: 200,
},
@ -143,9 +142,8 @@ func TestOpaPolicy(t *testing.T) {
`,
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.Request.SetHost("https://my.site")
ctx.Request.URI().SetPath("/forbidden")
req: func() *http.Request {
return httptest.NewRequest("GET", "https://my.site/forbidden", nil)
},
status: 403,
},
@ -162,9 +160,10 @@ func TestOpaPolicy(t *testing.T) {
`,
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.Request.SetHost("https://my.site")
ctx.Request.Header.Add("x-bad-header", "1")
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", nil)
r.Header.Add("x-bad-header", "1")
return r
},
status: 200,
},
@ -182,9 +181,10 @@ func TestOpaPolicy(t *testing.T) {
"includedHeaders": "x-bad-header",
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.Request.SetHost("https://my.site")
ctx.Request.Header.Add("x-bad-header", "1")
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", nil)
r.Header.Add("X-BAD-HEADER", "1")
return r
},
status: 403,
},
@ -245,27 +245,46 @@ func TestOpaPolicy(t *testing.T) {
"rego": `
package http
default allow = false
allow = { "status_code": 200 } {
allow = { "allow": true } {
input.request.body == "allow"
}
`,
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.SetContentType("text/plain; charset=utf8")
ctx.Request.SetHost("https://my.site")
ctx.Request.SetBodyString("allow")
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
r.Header.Add("content-type", "text/plain; charset=utf8")
return r
},
status: 200,
},
"skip reading body": {
meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{
"skipBody": "true",
"rego": `
package http
default allow = false
allow = { "status_code": 403 } {
input.request.body == "allow"
}
`,
},
}},
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
r.Header.Add("content-type", "text/plain; charset=utf8")
return r
},
status: 403,
},
"allow when multiple headers included with space": {
meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{
"rego": `
package http
default allow = false
allow = { "status_code": 200 } {
allow = { "allow": true } {
input.request.headers["X-Jwt-Header"]
input.request.headers["X-My-Custom-Header"]
}
@ -273,47 +292,76 @@ func TestOpaPolicy(t *testing.T) {
"includedHeaders": "x-my-custom-header, x-jwt-header",
},
}},
req: func(ctx *fh.RequestCtx) {
ctx.Request.SetHost("https://my.site")
ctx.Request.Header.Add("x-jwt-header", "1")
ctx.Request.Header.Add("x-my-custom-header", "2")
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", nil)
r.Header.Add("x-jwt-header", "1")
r.Header.Add("x-my-custom-header", "2")
return r
},
status: 200,
},
"reject when multiple headers included with space": {
meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{
"rego": `
package http
default allow = false
allow = { "allow": true } {
input.request.headers["X-Jwt-Header"]
input.request.headers["X-My-Custom-Header"]
}
`,
"includedHeaders": "x-my-custom-header, x-jwt-header",
},
}},
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", nil)
r.Header.Add("x-jwt-header", "1")
r.Header.Add("x-bad-header", "2")
return r
},
status: 403,
},
}
log := logger.NewLogger("opa.test")
for name, test := range tests {
t.Run(name, func(t *testing.T) {
log := logger.NewLogger("opa.test")
opaMiddleware := NewMiddleware(log)
handler, err := opaMiddleware.GetHandler(test.meta)
handler, err := opaMiddleware.GetHandler(test.meta)
if test.shouldHandlerError {
require.Error(t, err)
return
}
require.NoError(t, err)
var reqCtx fh.RequestCtx
var r *http.Request
if test.req != nil {
test.req(&reqCtx)
r = test.req()
} else {
r = httptest.NewRequest("GET", "https://my.site", nil)
}
handler(mockedRequestHandler)(&reqCtx)
w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
if test.shouldRegoError {
assert.Equal(t, 403, reqCtx.Response.StatusCode())
assert.Equal(t, "true", string(reqCtx.Response.Header.Peek(opaErrorHeaderKey)))
assert.Equal(t, 403, w.Code)
assert.Equal(t, "true", w.Header().Get(opaErrorHeaderKey))
return
}
assert.Equal(t, test.status, reqCtx.Response.StatusCode())
assert.Equal(t, test.status, w.Code)
if test.status == 200 {
assert.Equal(t, "from mock", w.Body.String())
}
if test.headers != nil {
for _, header := range *test.headers {
assert.Equal(t, header[1], string(reqCtx.Response.Header.Peek(header[0])))
assert.Equal(t, header[1], w.Header().Get(header[0]))
}
}
})

View File

@ -15,14 +15,12 @@ package ratelimit
import (
"fmt"
"net/http"
"strconv"
"github.com/didip/tollbooth"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/components-contrib/middleware/http/nethttpadaptor"
"github.com/dapr/kit/logger"
)
@ -39,17 +37,16 @@ const (
)
// NewRateLimitMiddleware returns a new ratelimit middleware.
func NewRateLimitMiddleware(logger logger.Logger) middleware.Middleware {
return &Middleware{logger: logger}
func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
return &Middleware{}
}
// Middleware is an ratelimit middleware.
type Middleware struct {
logger logger.Logger
}
// GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err
@ -57,13 +54,8 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil)
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
limitHandler := tollbooth.LimitFuncHandler(limiter, nethttpadaptor.NewNetHTTPHandlerFunc(m.logger, h))
wrappedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(limitHandler.ServeHTTP)
return func(ctx *fasthttp.RequestCtx) {
wrappedHandler(ctx)
}
return func(next http.Handler) http.Handler {
return tollbooth.LimitHandler(limiter, next)
}, nil
}
@ -74,7 +66,7 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*rateLimit
if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok {
f, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %+v", maxRequestsPerSecondKey, err)
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %w", maxRequestsPerSecondKey, err)
}
if f <= 0 {
return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)