Updated more middlewares

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-09-29 18:25:52 +00:00
parent 8557183752
commit e3d2ada01c
6 changed files with 244 additions and 211 deletions

View File

@ -82,7 +82,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
session := sessions.Start(w, r) session := sessions.Start(w, r)
if session.GetString(meta.AuthHeaderName) != "" { if session.GetString(meta.AuthHeaderName) != "" {
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName)) w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
@ -135,7 +135,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
authHeader := token.Type() + " " + token.AccessToken authHeader := token.Type() + " " + token.AccessToken
session.Set(meta.AuthHeaderName, authHeader) session.Set(meta.AuthHeaderName, authHeader)
w.Header().Set(meta.AuthHeaderName, authHeader) w.Header().Add(meta.AuthHeaderName, authHeader)
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String()) httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
} }
}) })

View File

@ -16,18 +16,18 @@ package oauth2clientcredentials
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/hex"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/valyala/fasthttp"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials" "golang.org/x/oauth2/clientcredentials"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
@ -40,8 +40,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
TokenURL string `json:"tokenURL"` TokenURL string `json:"tokenURL"`
HeaderName string `json:"headerName"` HeaderName string `json:"headerName"`
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"` EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"`
AuthStyleString string `json:"authStyle"` AuthStyle int `json:"authStyle"`
AuthStyle int `json:"-"`
} }
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests. // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
@ -69,53 +68,47 @@ type Middleware struct {
} }
// GetHandler retruns the HTTP handler provided by the middleware. // GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
m.log.Errorf("getNativeMetadata error, %s", err) m.log.Errorf("getNativeMetadata error: %s", err)
return nil, err return nil, err
} }
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery)
return func(ctx *fasthttp.RequestCtx) { if err != nil {
var headerValue string m.log.Errorf("Error parsing endpoint parameters: %s", err)
// Check if valid Token is in the cache endpointParams, _ = url.ParseQuery("")
cacheKey := m.getCacheKey(meta) }
cachedToken, found := m.tokenCache.Get(cacheKey)
conf := &clientcredentials.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
TokenURL: meta.TokenURL,
EndpointParams: endpointParams,
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
}
cacheKey := m.getCacheKey(meta)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var headerValue string
// Check if valid token is in the cache
cachedToken, found := m.tokenCache.Get(cacheKey)
if !found { if !found {
m.log.Debugf("Cached token not found, try get one") m.log.Debugf("Cached token not found, try get one")
endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery) token, err := m.tokenProvider.GetToken(conf)
if err != nil { if err != nil {
m.log.Errorf("Error parsing endpoint parameters, %s", err) m.log.Errorf("Error acquiring token: %s", err)
endpointParams, _ = url.ParseQuery("")
}
conf := &clientcredentials.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
TokenURL: meta.TokenURL,
EndpointParams: endpointParams,
AuthStyle: oauth2.AuthStyle(meta.AuthStyle),
}
token, tokenError := m.tokenProvider.GetToken(conf)
if tokenError != nil {
m.log.Errorf("Error acquiring token, %s", tokenError)
return return
} }
tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC)) tokenExpirationDuration := token.Expiry.Sub(time.Now())
m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry) m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
if err != nil {
m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry))
return
}
headerValue = token.Type() + " " + token.AccessToken headerValue = token.Type() + " " + token.AccessToken
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration) m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
@ -124,46 +117,37 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
headerValue = cachedToken.(string) headerValue = cachedToken.(string)
} }
ctx.Request.Header.Add(meta.HeaderName, headerValue) w.Header().Add(meta.HeaderName, headerValue)
h(ctx) next.ServeHTTP(w, r)
} })
}, nil }, nil
} }
func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) { func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) {
b, err := json.Marshal(metadata.Properties)
if err != nil {
return nil, err
}
var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata
err = json.Unmarshal(b, &middlewareMetadata) err := mdutils.DecodeMetadata(metadata.Properties, &middlewareMetadata)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("metadata errors: %w", err)
} }
// Do input validation checks // Do input validation checks
errorString := "" errorString := ""
// Check if values are present // Check if values are present
m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName") m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID") m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret") m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes") m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL") m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle")
// Converting AuthStyle to int and do a value check // Value-check AuthStyle
authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString) if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
if err != nil { errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString)
} else if authStyle < 0 || authStyle > 2 {
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle)
} else {
middlewareMetadata.AuthStyle = authStyle
} }
// Return errors if any found // Return errors if any found
if errorString != "" { if errorString != "" {
return nil, fmt.Errorf("%s", errorString) return nil, fmt.Errorf("metadata errors: %s", errorString)
} }
return &middlewareMetadata, nil return &middlewareMetadata, nil
@ -177,11 +161,8 @@ func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue
func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string { func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string {
// we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token // we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token
hashedKey := sha256.New() hashedKey := sha256.Sum224([]byte(meta.ClientID + meta.Scopes))
key := strings.Join([]string{meta.ClientID, meta.Scopes}, "") return hex.EncodeToString(hashedKey[:])
hashedKey.Write([]byte(key))
return fmt.Sprintf("%x", hashedKey.Sum(nil))
} }
// SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking). // SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking).

View File

@ -14,13 +14,14 @@ limitations under the License.
package oauth2clientcredentials package oauth2clientcredentials
import ( import (
"net/http"
"net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fh "github.com/valyala/fasthttp"
oauth2 "golang.org/x/oauth2" oauth2 "golang.org/x/oauth2"
"github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/middleware"
@ -28,7 +29,11 @@ import (
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
func mockedRequestHandler(ctx *fh.RequestCtx) {} // mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("from mock"))
}
// TestOAuth2ClientCredentialsMetadata will check // TestOAuth2ClientCredentialsMetadata will check
// - if the metadata checks are correct in place. // - if the metadata checks are correct in place.
@ -41,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test") log := logger.NewLogger("oauth2clientcredentials.test")
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ") assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
// Invalid authStyle (non int) // Invalid authStyle (non int)
metadata.Properties = map[string]string{ metadata.Properties = map[string]string{
@ -53,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
"authStyle": "asdf", // This is the value to test "authStyle": "asdf", // This is the value to test
} }
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ") assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
// Invalid authStyle (int > 2) // Invalid authStyle (int > 2)
metadata.Properties["authStyle"] = "3" metadata.Properties["authStyle"] = "3"
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ") assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
// Invalid authStyle (int < 0) // Invalid authStyle (int < 0)
metadata.Properties["authStyle"] = "-1" metadata.Properties["authStyle"] = "-1"
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ") assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
} }
// TestOAuth2ClientCredentialsToken will check // TestOAuth2ClientCredentialsToken will check
@ -108,10 +113,12 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// First handler call should return abc Token // First handler call should return abc Token
var requestContext1 fh.RequestCtx r := httptest.NewRequest("GET", "http://dapr.io", nil)
handler(mockedRequestHandler)(&requestContext1) w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion // Assertion
assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader"))) assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader"))
} }
// TestOAuth2ClientCredentialsCache will check // TestOAuth2ClientCredentialsCache will check
@ -166,23 +173,29 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// First handler call should return abc Token // First handler call should return abc Token
var requestContext1 fh.RequestCtx r := httptest.NewRequest("GET", "http://dapr.io", nil)
handler(mockedRequestHandler)(&requestContext1) w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion // Assertion
assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader"))) assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
// Second handler call should still return 'cached' abc Token // Second handler call should still return 'cached' abc Token
var requestContext2 fh.RequestCtx r = httptest.NewRequest("GET", "http://dapr.io", nil)
handler(mockedRequestHandler)(&requestContext2) w = httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion // Assertion
assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader"))) assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
// Wait at a second to invalidate cache entry for abc // Wait at a second to invalidate cache entry for abc
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Third call should return def Token // Third call should return def Token
var requestContext3 fh.RequestCtx r = httptest.NewRequest("GET", "http://dapr.io", nil)
handler(mockedRequestHandler)(&requestContext3) w = httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion // Assertion
assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader"))) assert.Equal(t, "MAC def", w.Header().Get("someHeader"))
} }

View File

@ -19,13 +19,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"math" "math"
"net/http"
"net/textproto"
"strconv" "strconv"
"strings" "strings"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"github.com/valyala/fasthttp" "k8s.io/utils/strings/slices"
"github.com/dapr/components-contrib/internal/httputils"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
@ -33,9 +38,11 @@ import (
type Status int type Status int
type middlewareMetadata struct { type middlewareMetadata struct {
Rego string `json:"rego"` Rego string `json:"rego"`
DefaultStatus Status `json:"defaultStatus,omitempty"` DefaultStatus Status `json:"defaultStatus,omitempty"`
IncludedHeaders string `json:"includedHeaders,omitempty"` IncludedHeaders string `json:"includedHeaders,omitempty"`
SkipBody string `json:"skipBody,omitempty"`
includedHeadersParsed []string `json:"-"`
} }
// NewMiddleware returns a new Open Policy Agent middleware. // NewMiddleware returns a new Open Policy Agent middleware.
@ -98,110 +105,96 @@ func (s *Status) Valid() bool {
} }
// GetHandler returns the HTTP handler provided by the middleware. // GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx := context.Background()
query, err := rego.New( query, err := rego.New(
rego.Query("result = data.http.allow"), rego.Query("result = data.http.allow"),
rego.Module("inline.rego", meta.Rego), rego.Module("inline.rego", meta.Rego),
).PrepareForEval(ctx) ).PrepareForEval(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { return func(next http.Handler) http.Handler {
return func(ctx *fasthttp.RequestCtx) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if allow := m.evalRequest(ctx, meta, &query); !allow { if allow := m.evalRequest(w, r, meta, &query); !allow {
return return
} }
h(ctx) next.ServeHTTP(w, r)
} })
}, nil }, nil
} }
func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool { func (m *Middleware) evalRequest(w http.ResponseWriter, r *http.Request, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool {
headers := map[string]string{} headers := map[string]string{}
allowedHeaders := strings.Split(meta.IncludedHeaders, ",")
ctx.Request.Header.VisitAll(func(key, value []byte) {
for _, allowedHeader := range allowedHeaders {
scrubbedHeader := strings.ReplaceAll(allowedHeader, " ", "")
buf := []byte("")
result := fasthttp.AppendNormalizedHeaderKeyBytes(buf, []byte(scrubbedHeader))
normalizedHeader := result[0:]
if bytes.Equal(key, normalizedHeader) {
headers[string(key)] = string(value)
}
}
})
queryArgs := map[string][]string{} for key, value := range r.Header {
ctx.QueryArgs().VisitAll(func(key, value []byte) { if slices.Contains(meta.includedHeadersParsed, key) {
if val, ok := queryArgs[string(key)]; ok { headers[key] = value[0]
queryArgs[string(key)] = append(val, string(value))
} else {
queryArgs[string(key)] = []string{string(value)}
} }
}) }
path := string(ctx.Path()) var body string
pathParts := strings.Split(strings.Trim(path, "/"), "/") if !utils.IsTruthy(meta.SkipBody) {
buf, _ := io.ReadAll(r.Body)
body = string(buf)
// Put the body back in the request
r.Body = io.NopCloser(bytes.NewBuffer(buf))
}
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
input := map[string]interface{}{ input := map[string]interface{}{
"request": map[string]interface{}{ "request": map[string]interface{}{
"method": string(ctx.Method()), "method": r.Method,
"path": path, "path": r.URL.Path,
"path_parts": pathParts, "path_parts": pathParts,
"raw_query": string(ctx.QueryArgs().QueryString()), "raw_query": r.URL.RawQuery,
"query": queryArgs, "query": map[string][]string(r.URL.Query()),
"headers": headers, "headers": headers,
"scheme": string(ctx.Request.URI().Scheme()), "scheme": r.URL.Scheme,
"body": string(ctx.Request.Body()), "body": body,
}, },
} }
results, err := query.Eval(context.TODO(), rego.EvalInput(input)) results, err := query.Eval(r.Context(), rego.EvalInput(input))
if err != nil { if err != nil {
m.opaError(ctx, meta, err) m.opaError(w, meta, err)
return false return false
} }
if len(results) == 0 { if len(results) == 0 {
m.opaError(ctx, meta, errOpaNoResult) m.opaError(w, meta, errOpaNoResult)
return false return false
} }
return m.handleRegoResult(ctx, meta, results[0].Bindings["result"]) return m.handleRegoResult(w, meta, results[0].Bindings["result"])
} }
// handleRegoResult takes the in process request and open policy agent evaluation result // handleRegoResult takes the in process request and open policy agent evaluation result
// and maps it the appropriate response or headers. // and maps it the appropriate response or headers.
// It returns true if the request should continue, or false if a response should be immediately returned. // It returns true if the request should continue, or false if a response should be immediately returned.
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool { func (m *Middleware) handleRegoResult(w http.ResponseWriter, meta *middlewareMetadata, result any) bool {
if allowed, ok := result.(bool); ok { if allowed, ok := result.(bool); ok {
if !allowed { if !allowed {
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) httputils.RespondWithError(w, int(meta.DefaultStatus))
} }
return allowed return allowed
} }
if _, ok := result.(map[string]interface{}); !ok { if _, ok := result.(map[string]any); !ok {
m.opaError(ctx, meta, errOpaInvalidResultType) m.opaError(w, meta, errOpaInvalidResultType)
return false return false
} }
// Is it expensive to marshal back and forth? Should we just manually pull out properties? // Is it expensive to marshal back and forth? Should we just manually pull out properties?
marshaled, err := json.Marshal(result) marshaled, err := json.Marshal(result)
if err != nil { if err != nil {
m.opaError(ctx, meta, err) m.opaError(w, meta, err)
return false return false
} }
@ -212,31 +205,26 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
} }
if err = json.Unmarshal(marshaled, &regoResult); err != nil { if err = json.Unmarshal(marshaled, &regoResult); err != nil {
m.opaError(ctx, meta, err) m.opaError(w, meta, err)
return false return false
} }
// If the result isn't allowed, set the response status and // Set the headers on the ongoing request (overriding as necessary)
// apply the additional headers to the response. for key, value := range regoResult.AdditionalHeaders {
// Otherwise, set the headers on the ongoing request (overriding as necessary). w.Header().Set(key, value)
}
// If the result isn't allowed, set the response status
if !regoResult.Allow { if !regoResult.Allow {
ctx.Error(fasthttp.StatusMessage(regoResult.StatusCode), regoResult.StatusCode) httputils.RespondWithError(w, regoResult.StatusCode)
for key, value := range regoResult.AdditionalHeaders {
ctx.Response.Header.Set(key, value)
}
} else {
for key, value := range regoResult.AdditionalHeaders {
ctx.Request.Header.Set(key, value)
}
} }
return regoResult.Allow return regoResult.Allow
} }
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) { func (m *Middleware) opaError(w http.ResponseWriter, meta *middlewareMetadata, err error) {
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) w.Header().Set(opaErrorHeaderKey, "true")
ctx.Response.Header.Set(opaErrorHeaderKey, "true") httputils.RespondWithError(w, int(meta.DefaultStatus))
m.logger.Warnf("Error procesing rego policy: %v", err) m.logger.Warnf("Error procesing rego policy: %v", err)
} }
@ -254,5 +242,16 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*middlewar
return nil, err return nil, err
} }
meta.includedHeadersParsed = strings.Split(meta.IncludedHeaders, ",")
n := 0
for i := range meta.includedHeadersParsed {
scrubbed := strings.ReplaceAll(meta.includedHeadersParsed[i], " ", "")
if scrubbed != "" {
meta.includedHeadersParsed[n] = textproto.CanonicalMIMEHeaderKey(scrubbed)
n++
}
}
meta.includedHeadersParsed = meta.includedHeadersParsed[:n]
return &meta, nil return &meta, nil
} }

View File

@ -15,11 +15,13 @@ package opa
import ( import (
"encoding/json" "encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fh "github.com/valyala/fasthttp"
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/middleware"
@ -27,17 +29,15 @@ import (
) )
// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body. // mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body.
func mockedRequestHandler(ctx *fh.RequestCtx) { func mockedRequestHandler(w http.ResponseWriter, r *http.Request) {
ctx.Response.SetStatusCode(200) w.WriteHeader(http.StatusOK)
ctx.Response.SetBody([]byte("from mock")) w.Write([]byte("from mock"))
} }
type RequestConfiguator func(*fh.RequestCtx)
func TestOpaPolicy(t *testing.T) { func TestOpaPolicy(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
meta middleware.Metadata meta middleware.Metadata
req RequestConfiguator req func() *http.Request
status int status int
headers *[][]string headers *[][]string
body []string body []string
@ -124,9 +124,8 @@ func TestOpaPolicy(t *testing.T) {
`, `,
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.Request.SetHost("https://my.site") return httptest.NewRequest("GET", "https://my.site/allowed", nil)
ctx.Request.URI().SetPath("/allowed")
}, },
status: 200, status: 200,
}, },
@ -143,9 +142,8 @@ func TestOpaPolicy(t *testing.T) {
`, `,
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.Request.SetHost("https://my.site") return httptest.NewRequest("GET", "https://my.site/forbidden", nil)
ctx.Request.URI().SetPath("/forbidden")
}, },
status: 403, status: 403,
}, },
@ -162,9 +160,10 @@ func TestOpaPolicy(t *testing.T) {
`, `,
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.Request.SetHost("https://my.site") r := httptest.NewRequest("GET", "https://my.site", nil)
ctx.Request.Header.Add("x-bad-header", "1") r.Header.Add("x-bad-header", "1")
return r
}, },
status: 200, status: 200,
}, },
@ -182,9 +181,10 @@ func TestOpaPolicy(t *testing.T) {
"includedHeaders": "x-bad-header", "includedHeaders": "x-bad-header",
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.Request.SetHost("https://my.site") r := httptest.NewRequest("GET", "https://my.site", nil)
ctx.Request.Header.Add("x-bad-header", "1") r.Header.Add("X-BAD-HEADER", "1")
return r
}, },
status: 403, status: 403,
}, },
@ -245,27 +245,46 @@ func TestOpaPolicy(t *testing.T) {
"rego": ` "rego": `
package http package http
default allow = false default allow = false
allow = { "allow": true } {
allow = { "status_code": 200 } {
input.request.body == "allow" input.request.body == "allow"
} }
`, `,
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.SetContentType("text/plain; charset=utf8") r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
ctx.Request.SetHost("https://my.site") r.Header.Add("content-type", "text/plain; charset=utf8")
ctx.Request.SetBodyString("allow") return r
}, },
status: 200, status: 200,
}, },
"skip reading body": {
meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{
"skipBody": "true",
"rego": `
package http
default allow = false
allow = { "status_code": 403 } {
input.request.body == "allow"
}
`,
},
}},
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow"))
r.Header.Add("content-type", "text/plain; charset=utf8")
return r
},
status: 403,
},
"allow when multiple headers included with space": { "allow when multiple headers included with space": {
meta: middleware.Metadata{Base: metadata.Base{ meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{ Properties: map[string]string{
"rego": ` "rego": `
package http package http
default allow = false default allow = false
allow = { "status_code": 200 } { allow = { "allow": true } {
input.request.headers["X-Jwt-Header"] input.request.headers["X-Jwt-Header"]
input.request.headers["X-My-Custom-Header"] input.request.headers["X-My-Custom-Header"]
} }
@ -273,47 +292,76 @@ func TestOpaPolicy(t *testing.T) {
"includedHeaders": "x-my-custom-header, x-jwt-header", "includedHeaders": "x-my-custom-header, x-jwt-header",
}, },
}}, }},
req: func(ctx *fh.RequestCtx) { req: func() *http.Request {
ctx.Request.SetHost("https://my.site") r := httptest.NewRequest("GET", "https://my.site", nil)
ctx.Request.Header.Add("x-jwt-header", "1") r.Header.Add("x-jwt-header", "1")
ctx.Request.Header.Add("x-my-custom-header", "2") r.Header.Add("x-my-custom-header", "2")
return r
}, },
status: 200, status: 200,
}, },
"reject when multiple headers included with space": {
meta: middleware.Metadata{Base: metadata.Base{
Properties: map[string]string{
"rego": `
package http
default allow = false
allow = { "allow": true } {
input.request.headers["X-Jwt-Header"]
input.request.headers["X-My-Custom-Header"]
}
`,
"includedHeaders": "x-my-custom-header, x-jwt-header",
},
}},
req: func() *http.Request {
r := httptest.NewRequest("GET", "https://my.site", nil)
r.Header.Add("x-jwt-header", "1")
r.Header.Add("x-bad-header", "2")
return r
},
status: 403,
},
} }
log := logger.NewLogger("opa.test")
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
log := logger.NewLogger("opa.test")
opaMiddleware := NewMiddleware(log) opaMiddleware := NewMiddleware(log)
handler, err := opaMiddleware.GetHandler(test.meta)
handler, err := opaMiddleware.GetHandler(test.meta)
if test.shouldHandlerError { if test.shouldHandlerError {
require.Error(t, err) require.Error(t, err)
return return
} }
require.NoError(t, err) require.NoError(t, err)
var reqCtx fh.RequestCtx var r *http.Request
if test.req != nil { if test.req != nil {
test.req(&reqCtx) r = test.req()
} else {
r = httptest.NewRequest("GET", "https://my.site", nil)
} }
handler(mockedRequestHandler)(&reqCtx) w := httptest.NewRecorder()
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
if test.shouldRegoError { if test.shouldRegoError {
assert.Equal(t, 403, reqCtx.Response.StatusCode()) assert.Equal(t, 403, w.Code)
assert.Equal(t, "true", string(reqCtx.Response.Header.Peek(opaErrorHeaderKey))) assert.Equal(t, "true", w.Header().Get(opaErrorHeaderKey))
return return
} }
assert.Equal(t, test.status, reqCtx.Response.StatusCode()) assert.Equal(t, test.status, w.Code)
if test.status == 200 {
assert.Equal(t, "from mock", w.Body.String())
}
if test.headers != nil { if test.headers != nil {
for _, header := range *test.headers { for _, header := range *test.headers {
assert.Equal(t, header[1], string(reqCtx.Response.Header.Peek(header[0]))) assert.Equal(t, header[1], w.Header().Get(header[0]))
} }
} }
}) })

View File

@ -15,14 +15,12 @@ package ratelimit
import ( import (
"fmt" "fmt"
"net/http"
"strconv" "strconv"
"github.com/didip/tollbooth" "github.com/didip/tollbooth"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
"github.com/dapr/components-contrib/middleware" "github.com/dapr/components-contrib/middleware"
"github.com/dapr/components-contrib/middleware/http/nethttpadaptor"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
@ -39,17 +37,16 @@ const (
) )
// NewRateLimitMiddleware returns a new ratelimit middleware. // NewRateLimitMiddleware returns a new ratelimit middleware.
func NewRateLimitMiddleware(logger logger.Logger) middleware.Middleware { func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
return &Middleware{logger: logger} return &Middleware{}
} }
// Middleware is an ratelimit middleware. // Middleware is an ratelimit middleware.
type Middleware struct { type Middleware struct {
logger logger.Logger
} }
// GetHandler returns the HTTP handler provided by the middleware. // GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,13 +54,8 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R
limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil) limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil)
return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { return func(next http.Handler) http.Handler {
limitHandler := tollbooth.LimitFuncHandler(limiter, nethttpadaptor.NewNetHTTPHandlerFunc(m.logger, h)) return tollbooth.LimitHandler(limiter, next)
wrappedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(limitHandler.ServeHTTP)
return func(ctx *fasthttp.RequestCtx) {
wrappedHandler(ctx)
}
}, nil }, nil
} }
@ -74,7 +66,7 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*rateLimit
if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok { if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok {
f, err := strconv.ParseFloat(val, 64) f, err := strconv.ParseFloat(val, 64)
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %+v", maxRequestsPerSecondKey, err) return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %w", maxRequestsPerSecondKey, err)
} }
if f <= 0 { if f <= 0 {
return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey) return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)