From e3d2ada01ca2a8476e29f1601dd9e6fe23a72062 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:25:52 +0000 Subject: [PATCH] Updated more middlewares Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- middleware/http/oauth2/oauth2_middleware.go | 4 +- .../oauth2clientcredentials_middleware.go | 107 ++++++------- ...oauth2clientcredentials_middleware_test.go | 49 +++--- middleware/http/opa/middleware.go | 141 +++++++++--------- middleware/http/opa/middleware_test.go | 132 ++++++++++------ .../http/ratelimit/ratelimit_middleware.go | 22 +-- 6 files changed, 244 insertions(+), 211 deletions(-) diff --git a/middleware/http/oauth2/oauth2_middleware.go b/middleware/http/oauth2/oauth2_middleware.go index d04b932a4..76ec22a21 100644 --- a/middleware/http/oauth2/oauth2_middleware.go +++ b/middleware/http/oauth2/oauth2_middleware.go @@ -82,7 +82,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha session := sessions.Start(w, r) if session.GetString(meta.AuthHeaderName) != "" { - w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName)) + w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName)) next.ServeHTTP(w, r) return } @@ -135,7 +135,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha authHeader := token.Type() + " " + token.AccessToken session.Set(meta.AuthHeaderName, authHeader) - w.Header().Set(meta.AuthHeaderName, authHeader) + w.Header().Add(meta.AuthHeaderName, authHeader) httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String()) } }) diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go index f375973b9..56259b68a 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go @@ -16,18 +16,18 @@ package oauth2clientcredentials import ( "context" "crypto/sha256" - "encoding/json" + "encoding/hex" "fmt" + "net/http" "net/url" - "strconv" "strings" "time" "github.com/patrickmn/go-cache" - "github.com/valyala/fasthttp" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + mdutils "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/middleware" "github.com/dapr/kit/logger" ) @@ -40,8 +40,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct { TokenURL string `json:"tokenURL"` HeaderName string `json:"headerName"` EndpointParamsQuery string `json:"endpointParamsQuery,omitempty"` - AuthStyleString string `json:"authStyle"` - AuthStyle int `json:"-"` + AuthStyle int `json:"authStyle"` } // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests. @@ -69,53 +68,47 @@ type Middleware struct { } // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { +func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { - m.log.Errorf("getNativeMetadata error, %s", err) - + m.log.Errorf("getNativeMetadata error: %s", err) return nil, err } - return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - var headerValue string - // Check if valid Token is in the cache - cacheKey := m.getCacheKey(meta) - cachedToken, found := m.tokenCache.Get(cacheKey) + endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery) + if err != nil { + m.log.Errorf("Error parsing endpoint parameters: %s", err) + endpointParams, _ = url.ParseQuery("") + } + conf := &clientcredentials.Config{ + ClientID: meta.ClientID, + ClientSecret: meta.ClientSecret, + Scopes: strings.Split(meta.Scopes, ","), + TokenURL: meta.TokenURL, + EndpointParams: endpointParams, + AuthStyle: oauth2.AuthStyle(meta.AuthStyle), + } + + cacheKey := m.getCacheKey(meta) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var headerValue string + + // Check if valid token is in the cache + cachedToken, found := m.tokenCache.Get(cacheKey) if !found { m.log.Debugf("Cached token not found, try get one") - endpointParams, err := url.ParseQuery(meta.EndpointParamsQuery) + token, err := m.tokenProvider.GetToken(conf) if err != nil { - m.log.Errorf("Error parsing endpoint parameters, %s", err) - endpointParams, _ = url.ParseQuery("") - } - - conf := &clientcredentials.Config{ - ClientID: meta.ClientID, - ClientSecret: meta.ClientSecret, - Scopes: strings.Split(meta.Scopes, ","), - TokenURL: meta.TokenURL, - EndpointParams: endpointParams, - AuthStyle: oauth2.AuthStyle(meta.AuthStyle), - } - - token, tokenError := m.tokenProvider.GetToken(conf) - if tokenError != nil { - m.log.Errorf("Error acquiring token, %s", tokenError) - + m.log.Errorf("Error acquiring token: %s", err) return } - tokenExpirationDuration := token.Expiry.Sub(time.Now().In(time.UTC)) - m.log.Debugf("Duration in seconds %s, Expiry Time %s", tokenExpirationDuration, token.Expiry) - if err != nil { - m.log.Errorf("Error parsing duration string, %s", fmt.Sprintf("%ss", token.Expiry)) - - return - } + tokenExpirationDuration := token.Expiry.Sub(time.Now()) + m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration) headerValue = token.Type() + " " + token.AccessToken m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration) @@ -124,46 +117,37 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R headerValue = cachedToken.(string) } - ctx.Request.Header.Add(meta.HeaderName, headerValue) - h(ctx) - } + w.Header().Add(meta.HeaderName, headerValue) + next.ServeHTTP(w, r) + }) }, nil } func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2ClientCredentialsMiddlewareMetadata, error) { - b, err := json.Marshal(metadata.Properties) - if err != nil { - return nil, err - } var middlewareMetadata oAuth2ClientCredentialsMiddlewareMetadata - err = json.Unmarshal(b, &middlewareMetadata) + err := mdutils.DecodeMetadata(metadata.Properties, &middlewareMetadata) if err != nil { - return nil, err + return nil, fmt.Errorf("metadata errors: %w", err) } // Do input validation checks errorString := "" + // Check if values are present m.checkMetadataValueExists(&errorString, &middlewareMetadata.HeaderName, "headerName") m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientID, "clientID") m.checkMetadataValueExists(&errorString, &middlewareMetadata.ClientSecret, "clientSecret") m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes") m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL") - m.checkMetadataValueExists(&errorString, &middlewareMetadata.AuthStyleString, "authStyle") - // Converting AuthStyle to int and do a value check - authStyle, err := strconv.Atoi(middlewareMetadata.AuthStyleString) - if err != nil { - errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%s'. ", middlewareMetadata.AuthStyleString) - } else if authStyle < 0 || authStyle > 2 { - errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", authStyle) - } else { - middlewareMetadata.AuthStyle = authStyle + // Value-check AuthStyle + if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 { + errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle) } // Return errors if any found if errorString != "" { - return nil, fmt.Errorf("%s", errorString) + return nil, fmt.Errorf("metadata errors: %s", errorString) } return &middlewareMetadata, nil @@ -177,11 +161,8 @@ func (m *Middleware) checkMetadataValueExists(errorString *string, metadataValue func (m *Middleware) getCacheKey(meta *oAuth2ClientCredentialsMiddlewareMetadata) string { // we will hash the key components ClientID + Scopes is a unique composite key/identifier for a token - hashedKey := sha256.New() - key := strings.Join([]string{meta.ClientID, meta.Scopes}, "") - hashedKey.Write([]byte(key)) - - return fmt.Sprintf("%x", hashedKey.Sum(nil)) + hashedKey := sha256.Sum224([]byte(meta.ClientID + meta.Scopes)) + return hex.EncodeToString(hashedKey[:]) } // SetTokenProvider will enable to change the tokenProvider used after instanciation (needed for mocking). diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go index 8ee383086..9960b533a 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go @@ -14,13 +14,14 @@ limitations under the License. package oauth2clientcredentials import ( + "net/http" + "net/http/httptest" "testing" "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - fh "github.com/valyala/fasthttp" oauth2 "golang.org/x/oauth2" "github.com/dapr/components-contrib/middleware" @@ -28,7 +29,11 @@ import ( "github.com/dapr/kit/logger" ) -func mockedRequestHandler(ctx *fh.RequestCtx) {} +// mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body. +func mockedRequestHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("from mock")) +} // TestOAuth2ClientCredentialsMetadata will check // - if the metadata checks are correct in place. @@ -41,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) { log := logger.NewLogger("oauth2clientcredentials.test") _, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) - assert.EqualError(t, err, "Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. Parameter 'authStyle' needs to be set. Parameter 'authStyle' can only have the values 0,1,2. Received: ''. ") + assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ") // Invalid authStyle (non int) metadata.Properties = map[string]string{ @@ -53,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) { "authStyle": "asdf", // This is the value to test } _, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) - assert.EqualError(t, err2, "Parameter 'authStyle' can only have the values 0,1,2. Received: 'asdf'. ") + assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax") // Invalid authStyle (int > 2) metadata.Properties["authStyle"] = "3" _, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) - assert.EqualError(t, err3, "Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ") + assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ") // Invalid authStyle (int < 0) metadata.Properties["authStyle"] = "-1" _, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) - assert.EqualError(t, err4, "Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ") + assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ") } // TestOAuth2ClientCredentialsToken will check @@ -108,10 +113,12 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) { require.NoError(t, err) // First handler call should return abc Token - var requestContext1 fh.RequestCtx - handler(mockedRequestHandler)(&requestContext1) + r := httptest.NewRequest("GET", "http://dapr.io", nil) + w := httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + // Assertion - assert.Equal(t, "Bearer abcd", string(requestContext1.Request.Header.Peek("someHeader"))) + assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader")) } // TestOAuth2ClientCredentialsCache will check @@ -166,23 +173,29 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) { require.NoError(t, err) // First handler call should return abc Token - var requestContext1 fh.RequestCtx - handler(mockedRequestHandler)(&requestContext1) + r := httptest.NewRequest("GET", "http://dapr.io", nil) + w := httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + // Assertion - assert.Equal(t, "Bearer abc", string(requestContext1.Request.Header.Peek("someHeader"))) + assert.Equal(t, "Bearer abc", w.Header().Get("someHeader")) // Second handler call should still return 'cached' abc Token - var requestContext2 fh.RequestCtx - handler(mockedRequestHandler)(&requestContext2) + r = httptest.NewRequest("GET", "http://dapr.io", nil) + w = httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + // Assertion - assert.Equal(t, "Bearer abc", string(requestContext2.Request.Header.Peek("someHeader"))) + assert.Equal(t, "Bearer abc", w.Header().Get("someHeader")) // Wait at a second to invalidate cache entry for abc time.Sleep(1 * time.Second) // Third call should return def Token - var requestContext3 fh.RequestCtx - handler(mockedRequestHandler)(&requestContext3) + r = httptest.NewRequest("GET", "http://dapr.io", nil) + w = httptest.NewRecorder() + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) + // Assertion - assert.Equal(t, "MAC def", string(requestContext3.Request.Header.Peek("someHeader"))) + assert.Equal(t, "MAC def", w.Header().Get("someHeader")) } diff --git a/middleware/http/opa/middleware.go b/middleware/http/opa/middleware.go index 2f2167dde..5ac9de601 100644 --- a/middleware/http/opa/middleware.go +++ b/middleware/http/opa/middleware.go @@ -19,13 +19,18 @@ import ( "encoding/json" "errors" "fmt" + "io" "math" + "net/http" + "net/textproto" "strconv" "strings" "github.com/open-policy-agent/opa/rego" - "github.com/valyala/fasthttp" + "k8s.io/utils/strings/slices" + "github.com/dapr/components-contrib/internal/httputils" + "github.com/dapr/components-contrib/internal/utils" "github.com/dapr/components-contrib/middleware" "github.com/dapr/kit/logger" ) @@ -33,9 +38,11 @@ import ( type Status int type middlewareMetadata struct { - Rego string `json:"rego"` - DefaultStatus Status `json:"defaultStatus,omitempty"` - IncludedHeaders string `json:"includedHeaders,omitempty"` + Rego string `json:"rego"` + DefaultStatus Status `json:"defaultStatus,omitempty"` + IncludedHeaders string `json:"includedHeaders,omitempty"` + SkipBody string `json:"skipBody,omitempty"` + includedHeadersParsed []string `json:"-"` } // NewMiddleware returns a new Open Policy Agent middleware. @@ -98,110 +105,96 @@ func (s *Status) Valid() bool { } // GetHandler returns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { +func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err } - ctx := context.Background() - query, err := rego.New( rego.Query("result = data.http.allow"), rego.Module("inline.rego", meta.Rego), - ).PrepareForEval(ctx) + ).PrepareForEval(context.Background()) if err != nil { return nil, err } - return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - if allow := m.evalRequest(ctx, meta, &query); !allow { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if allow := m.evalRequest(w, r, meta, &query); !allow { return } - h(ctx) - } + next.ServeHTTP(w, r) + }) }, nil } -func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool { +func (m *Middleware) evalRequest(w http.ResponseWriter, r *http.Request, meta *middlewareMetadata, query *rego.PreparedEvalQuery) bool { headers := map[string]string{} - allowedHeaders := strings.Split(meta.IncludedHeaders, ",") - ctx.Request.Header.VisitAll(func(key, value []byte) { - for _, allowedHeader := range allowedHeaders { - scrubbedHeader := strings.ReplaceAll(allowedHeader, " ", "") - buf := []byte("") - result := fasthttp.AppendNormalizedHeaderKeyBytes(buf, []byte(scrubbedHeader)) - normalizedHeader := result[0:] - if bytes.Equal(key, normalizedHeader) { - headers[string(key)] = string(value) - } - } - }) - queryArgs := map[string][]string{} - ctx.QueryArgs().VisitAll(func(key, value []byte) { - if val, ok := queryArgs[string(key)]; ok { - queryArgs[string(key)] = append(val, string(value)) - } else { - queryArgs[string(key)] = []string{string(value)} + for key, value := range r.Header { + if slices.Contains(meta.includedHeadersParsed, key) { + headers[key] = value[0] } - }) + } - path := string(ctx.Path()) - pathParts := strings.Split(strings.Trim(path, "/"), "/") + var body string + if !utils.IsTruthy(meta.SkipBody) { + buf, _ := io.ReadAll(r.Body) + body = string(buf) + + // Put the body back in the request + r.Body = io.NopCloser(bytes.NewBuffer(buf)) + } + + pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") input := map[string]interface{}{ "request": map[string]interface{}{ - "method": string(ctx.Method()), - "path": path, + "method": r.Method, + "path": r.URL.Path, "path_parts": pathParts, - "raw_query": string(ctx.QueryArgs().QueryString()), - "query": queryArgs, + "raw_query": r.URL.RawQuery, + "query": map[string][]string(r.URL.Query()), "headers": headers, - "scheme": string(ctx.Request.URI().Scheme()), - "body": string(ctx.Request.Body()), + "scheme": r.URL.Scheme, + "body": body, }, } - results, err := query.Eval(context.TODO(), rego.EvalInput(input)) + results, err := query.Eval(r.Context(), rego.EvalInput(input)) if err != nil { - m.opaError(ctx, meta, err) - + m.opaError(w, meta, err) return false } if len(results) == 0 { - m.opaError(ctx, meta, errOpaNoResult) - + m.opaError(w, meta, errOpaNoResult) return false } - return m.handleRegoResult(ctx, meta, results[0].Bindings["result"]) + return m.handleRegoResult(w, meta, results[0].Bindings["result"]) } // handleRegoResult takes the in process request and open policy agent evaluation result // and maps it the appropriate response or headers. // It returns true if the request should continue, or false if a response should be immediately returned. -func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool { +func (m *Middleware) handleRegoResult(w http.ResponseWriter, meta *middlewareMetadata, result any) bool { if allowed, ok := result.(bool); ok { if !allowed { - ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) + httputils.RespondWithError(w, int(meta.DefaultStatus)) } - return allowed } - if _, ok := result.(map[string]interface{}); !ok { - m.opaError(ctx, meta, errOpaInvalidResultType) - + if _, ok := result.(map[string]any); !ok { + m.opaError(w, meta, errOpaInvalidResultType) return false } // Is it expensive to marshal back and forth? Should we just manually pull out properties? marshaled, err := json.Marshal(result) if err != nil { - m.opaError(ctx, meta, err) - + m.opaError(w, meta, err) return false } @@ -212,31 +205,26 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware } if err = json.Unmarshal(marshaled, ®oResult); err != nil { - m.opaError(ctx, meta, err) - + m.opaError(w, meta, err) return false } - // If the result isn't allowed, set the response status and - // apply the additional headers to the response. - // Otherwise, set the headers on the ongoing request (overriding as necessary). + // Set the headers on the ongoing request (overriding as necessary) + for key, value := range regoResult.AdditionalHeaders { + w.Header().Set(key, value) + } + + // If the result isn't allowed, set the response status if !regoResult.Allow { - ctx.Error(fasthttp.StatusMessage(regoResult.StatusCode), regoResult.StatusCode) - for key, value := range regoResult.AdditionalHeaders { - ctx.Response.Header.Set(key, value) - } - } else { - for key, value := range regoResult.AdditionalHeaders { - ctx.Request.Header.Set(key, value) - } + httputils.RespondWithError(w, regoResult.StatusCode) } return regoResult.Allow } -func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) { - ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) - ctx.Response.Header.Set(opaErrorHeaderKey, "true") +func (m *Middleware) opaError(w http.ResponseWriter, meta *middlewareMetadata, err error) { + w.Header().Set(opaErrorHeaderKey, "true") + httputils.RespondWithError(w, int(meta.DefaultStatus)) m.logger.Warnf("Error procesing rego policy: %v", err) } @@ -254,5 +242,16 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*middlewar return nil, err } + meta.includedHeadersParsed = strings.Split(meta.IncludedHeaders, ",") + n := 0 + for i := range meta.includedHeadersParsed { + scrubbed := strings.ReplaceAll(meta.includedHeadersParsed[i], " ", "") + if scrubbed != "" { + meta.includedHeadersParsed[n] = textproto.CanonicalMIMEHeaderKey(scrubbed) + n++ + } + } + meta.includedHeadersParsed = meta.includedHeadersParsed[:n] + return &meta, nil } diff --git a/middleware/http/opa/middleware_test.go b/middleware/http/opa/middleware_test.go index 507faaf2d..8ac427b97 100644 --- a/middleware/http/opa/middleware_test.go +++ b/middleware/http/opa/middleware_test.go @@ -15,11 +15,13 @@ package opa import ( "encoding/json" + "net/http" + "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - fh "github.com/valyala/fasthttp" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/middleware" @@ -27,17 +29,15 @@ import ( ) // mockedRequestHandler acts like an upstream service returns success status code 200 and a fixed response body. -func mockedRequestHandler(ctx *fh.RequestCtx) { - ctx.Response.SetStatusCode(200) - ctx.Response.SetBody([]byte("from mock")) +func mockedRequestHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("from mock")) } -type RequestConfiguator func(*fh.RequestCtx) - func TestOpaPolicy(t *testing.T) { tests := map[string]struct { meta middleware.Metadata - req RequestConfiguator + req func() *http.Request status int headers *[][]string body []string @@ -124,9 +124,8 @@ func TestOpaPolicy(t *testing.T) { `, }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.Request.SetHost("https://my.site") - ctx.Request.URI().SetPath("/allowed") + req: func() *http.Request { + return httptest.NewRequest("GET", "https://my.site/allowed", nil) }, status: 200, }, @@ -143,9 +142,8 @@ func TestOpaPolicy(t *testing.T) { `, }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.Request.SetHost("https://my.site") - ctx.Request.URI().SetPath("/forbidden") + req: func() *http.Request { + return httptest.NewRequest("GET", "https://my.site/forbidden", nil) }, status: 403, }, @@ -162,9 +160,10 @@ func TestOpaPolicy(t *testing.T) { `, }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.Request.SetHost("https://my.site") - ctx.Request.Header.Add("x-bad-header", "1") + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", nil) + r.Header.Add("x-bad-header", "1") + return r }, status: 200, }, @@ -182,9 +181,10 @@ func TestOpaPolicy(t *testing.T) { "includedHeaders": "x-bad-header", }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.Request.SetHost("https://my.site") - ctx.Request.Header.Add("x-bad-header", "1") + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", nil) + r.Header.Add("X-BAD-HEADER", "1") + return r }, status: 403, }, @@ -245,27 +245,46 @@ func TestOpaPolicy(t *testing.T) { "rego": ` package http default allow = false - - allow = { "status_code": 200 } { + allow = { "allow": true } { input.request.body == "allow" } `, }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.SetContentType("text/plain; charset=utf8") - ctx.Request.SetHost("https://my.site") - ctx.Request.SetBodyString("allow") + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow")) + r.Header.Add("content-type", "text/plain; charset=utf8") + return r }, status: 200, }, + "skip reading body": { + meta: middleware.Metadata{Base: metadata.Base{ + Properties: map[string]string{ + "skipBody": "true", + "rego": ` + package http + default allow = false + allow = { "status_code": 403 } { + input.request.body == "allow" + } + `, + }, + }}, + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", strings.NewReader("allow")) + r.Header.Add("content-type", "text/plain; charset=utf8") + return r + }, + status: 403, + }, "allow when multiple headers included with space": { meta: middleware.Metadata{Base: metadata.Base{ Properties: map[string]string{ "rego": ` package http default allow = false - allow = { "status_code": 200 } { + allow = { "allow": true } { input.request.headers["X-Jwt-Header"] input.request.headers["X-My-Custom-Header"] } @@ -273,47 +292,76 @@ func TestOpaPolicy(t *testing.T) { "includedHeaders": "x-my-custom-header, x-jwt-header", }, }}, - req: func(ctx *fh.RequestCtx) { - ctx.Request.SetHost("https://my.site") - ctx.Request.Header.Add("x-jwt-header", "1") - ctx.Request.Header.Add("x-my-custom-header", "2") + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", nil) + r.Header.Add("x-jwt-header", "1") + r.Header.Add("x-my-custom-header", "2") + return r }, status: 200, }, + "reject when multiple headers included with space": { + meta: middleware.Metadata{Base: metadata.Base{ + Properties: map[string]string{ + "rego": ` + package http + default allow = false + allow = { "allow": true } { + input.request.headers["X-Jwt-Header"] + input.request.headers["X-My-Custom-Header"] + } + `, + "includedHeaders": "x-my-custom-header, x-jwt-header", + }, + }}, + req: func() *http.Request { + r := httptest.NewRequest("GET", "https://my.site", nil) + r.Header.Add("x-jwt-header", "1") + r.Header.Add("x-bad-header", "2") + return r + }, + status: 403, + }, } + log := logger.NewLogger("opa.test") + for name, test := range tests { t.Run(name, func(t *testing.T) { - log := logger.NewLogger("opa.test") opaMiddleware := NewMiddleware(log) - handler, err := opaMiddleware.GetHandler(test.meta) + handler, err := opaMiddleware.GetHandler(test.meta) if test.shouldHandlerError { require.Error(t, err) - return } - require.NoError(t, err) - var reqCtx fh.RequestCtx + var r *http.Request if test.req != nil { - test.req(&reqCtx) + r = test.req() + } else { + r = httptest.NewRequest("GET", "https://my.site", nil) } - handler(mockedRequestHandler)(&reqCtx) + w := httptest.NewRecorder() + + handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r) if test.shouldRegoError { - assert.Equal(t, 403, reqCtx.Response.StatusCode()) - assert.Equal(t, "true", string(reqCtx.Response.Header.Peek(opaErrorHeaderKey))) - + assert.Equal(t, 403, w.Code) + assert.Equal(t, "true", w.Header().Get(opaErrorHeaderKey)) return } - assert.Equal(t, test.status, reqCtx.Response.StatusCode()) + assert.Equal(t, test.status, w.Code) + + if test.status == 200 { + assert.Equal(t, "from mock", w.Body.String()) + } if test.headers != nil { for _, header := range *test.headers { - assert.Equal(t, header[1], string(reqCtx.Response.Header.Peek(header[0]))) + assert.Equal(t, header[1], w.Header().Get(header[0])) } } }) diff --git a/middleware/http/ratelimit/ratelimit_middleware.go b/middleware/http/ratelimit/ratelimit_middleware.go index fa395a441..2ca439a78 100644 --- a/middleware/http/ratelimit/ratelimit_middleware.go +++ b/middleware/http/ratelimit/ratelimit_middleware.go @@ -15,14 +15,12 @@ package ratelimit import ( "fmt" + "net/http" "strconv" "github.com/didip/tollbooth" - "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttpadaptor" "github.com/dapr/components-contrib/middleware" - "github.com/dapr/components-contrib/middleware/http/nethttpadaptor" "github.com/dapr/kit/logger" ) @@ -39,17 +37,16 @@ const ( ) // NewRateLimitMiddleware returns a new ratelimit middleware. -func NewRateLimitMiddleware(logger logger.Logger) middleware.Middleware { - return &Middleware{logger: logger} +func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware { + return &Middleware{} } // Middleware is an ratelimit middleware. type Middleware struct { - logger logger.Logger } // GetHandler returns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { +func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err @@ -57,13 +54,8 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R limiter := tollbooth.NewLimiter(meta.MaxRequestsPerSecond, nil) - return func(h fasthttp.RequestHandler) fasthttp.RequestHandler { - limitHandler := tollbooth.LimitFuncHandler(limiter, nethttpadaptor.NewNetHTTPHandlerFunc(m.logger, h)) - wrappedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(limitHandler.ServeHTTP) - - return func(ctx *fasthttp.RequestCtx) { - wrappedHandler(ctx) - } + return func(next http.Handler) http.Handler { + return tollbooth.LimitHandler(limiter, next) }, nil } @@ -74,7 +66,7 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*rateLimit if val, ok := metadata.Properties[maxRequestsPerSecondKey]; ok { f, err := strconv.ParseFloat(val, 64) if err != nil { - return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %+v", maxRequestsPerSecondKey, err) + return nil, fmt.Errorf("error parsing ratelimit middleware property %s: %w", maxRequestsPerSecondKey, err) } if f <= 0 { return nil, fmt.Errorf("ratelimit middleware property %s must be a positive value", maxRequestsPerSecondKey)