diff --git a/middleware/http/opa/middleware.go b/middleware/http/opa/middleware.go index 113457450..5a93fee36 100644 --- a/middleware/http/opa/middleware.go +++ b/middleware/http/opa/middleware.go @@ -5,17 +5,23 @@ import ( "context" "encoding/json" "errors" + "fmt" + "math" + "strconv" "strings" "github.com/dapr/components-contrib/middleware" "github.com/dapr/kit/logger" + "github.com/open-policy-agent/opa/rego" "github.com/valyala/fasthttp" ) +type Status int + type middlewareMetadata struct { Rego string `json:"rego"` - DefaultStatus int `json:"defaultStatus,omitempty"` + DefaultStatus Status `json:"defaultStatus,omitempty"` IncludedHeaders string `json:"includedHeaders,omitempty"` } @@ -43,6 +49,41 @@ var ( errOpaInvalidResultType = errors.New("got an invalid type back from repo policy. Only a boolean or map is valid") ) +func (s *Status) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + return nil + } + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + if value != math.Trunc(value) { + return fmt.Errorf("invalid float value %f parse to status(int)", value) + } + *s = Status(value) + case string: + intVal, err := strconv.Atoi(value) + if err != nil { + return err + } + *s = Status(intVal) + default: + return fmt.Errorf("invalid value %v parse to status(int)", value) + } + if !s.Valid() { + return fmt.Errorf("invalid status value %d expected in range [100-599]", *s) + } + + return nil +} + +// Check status is in the correct range for RFC 2616 status codes [100-599] +func (s *Status) Valid() bool { + return s != nil && *s >= 100 && *s < 600 +} + // GetHandler returns the HTTP handler provided by the middleware func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) { meta, err := m.getNativeMetadata(metadata) @@ -129,7 +170,7 @@ func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetad func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool { if allowed, ok := result.(bool); ok { if !allowed { - ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus) + ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) } return allowed @@ -151,7 +192,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware regoResult := RegoResult{ // By default, a non-allowed request with return a 403 response. - StatusCode: meta.DefaultStatus, + StatusCode: int(meta.DefaultStatus), AdditionalHeaders: make(map[string]string), } @@ -179,7 +220,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware } func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) { - ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus) + ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus)) ctx.Response.Header.Set(opaErrorHeaderKey, "true") m.logger.Warnf("Error procesing rego policy: %v", err) } diff --git a/middleware/http/opa/middleware_test.go b/middleware/http/opa/middleware_test.go index d984ffec2..e81f137e3 100644 --- a/middleware/http/opa/middleware_test.go +++ b/middleware/http/opa/middleware_test.go @@ -1,6 +1,7 @@ package opa import ( + "encoding/json" "testing" "github.com/dapr/components-contrib/middleware" @@ -193,6 +194,31 @@ func TestOpaPolicy(t *testing.T) { }, shouldRegoError: true, }, + "status config": { + meta: middleware.Metadata{ + Properties: map[string]string{ + "rego": ` + package http + allow = false`, + "defaultStatus": "500", + }, + }, + status: 500, + }, + "rego priority over defaultStatus metadata": { + meta: middleware.Metadata{ + Properties: map[string]string{ + "rego": ` + package http + allow = { + "allow": false, + "status_code": 301 + }`, + "defaultStatus": "500", + }, + }, + status: 301, + }, } for name, test := range tests { @@ -232,3 +258,63 @@ func TestOpaPolicy(t *testing.T) { }) } } + +func TestStatus_UnmarshalJSON(t *testing.T) { + type testObj struct { + Value Status `json:"value,omitempty"` + } + tests := map[string]struct { + jsonBytes []byte + expectValue Status + expectError bool + }{ + "int value": { + jsonBytes: []byte(`{"value":100}`), + expectValue: Status(100), + expectError: false, + }, + "string value": { + jsonBytes: []byte(`{"value":"100"}`), + expectValue: Status(100), + expectError: false, + }, + "empty value": { + jsonBytes: []byte(`{}`), + expectValue: Status(0), + expectError: false, + }, + "invalid status code value": { + jsonBytes: []byte(`{"value":600}`), + expectError: true, + }, + "invalid float value": { + jsonBytes: []byte(`{"value":2.9}`), + expectError: true, + }, + "invalid value null": { + jsonBytes: []byte(`{"value":null}`), + expectError: true, + }, + "invalid value []": { + jsonBytes: []byte(`{"value":[]}`), + expectError: true, + }, + "invalid value {}": { + jsonBytes: []byte(`{"value":{}}`), + expectError: true, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + var obj testObj + err := json.Unmarshal(test.jsonBytes, &obj) + if test.expectError { + assert.NotEmpty(t, err) + + return + } + assert.Nil(t, err) + assert.Equal(t, obj.Value, test.expectValue) + }) + } +}