fix http opa middleware status code parse (#1136)
* fix http opa middleware status code parse * fix for review * use json unmarshal to handle edge cases. * fix review Co-authored-by: Simon Leet <31784195+CodeMonkeyLeet@users.noreply.github.com> Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
This commit is contained in:
parent
189d2d6717
commit
3a109e2c74
|
@ -5,17 +5,23 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type Status int
|
||||
|
||||
type middlewareMetadata struct {
|
||||
Rego string `json:"rego"`
|
||||
DefaultStatus int `json:"defaultStatus,omitempty"`
|
||||
DefaultStatus Status `json:"defaultStatus,omitempty"`
|
||||
IncludedHeaders string `json:"includedHeaders,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -43,6 +49,41 @@ var (
|
|||
errOpaInvalidResultType = errors.New("got an invalid type back from repo policy. Only a boolean or map is valid")
|
||||
)
|
||||
|
||||
func (s *Status) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch value := v.(type) {
|
||||
case float64:
|
||||
if value != math.Trunc(value) {
|
||||
return fmt.Errorf("invalid float value %f parse to status(int)", value)
|
||||
}
|
||||
*s = Status(value)
|
||||
case string:
|
||||
intVal, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = Status(intVal)
|
||||
default:
|
||||
return fmt.Errorf("invalid value %v parse to status(int)", value)
|
||||
}
|
||||
if !s.Valid() {
|
||||
return fmt.Errorf("invalid status value %d expected in range [100-599]", *s)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check status is in the correct range for RFC 2616 status codes [100-599]
|
||||
func (s *Status) Valid() bool {
|
||||
return s != nil && *s >= 100 && *s < 600
|
||||
}
|
||||
|
||||
// GetHandler returns the HTTP handler provided by the middleware
|
||||
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
|
||||
meta, err := m.getNativeMetadata(metadata)
|
||||
|
@ -129,7 +170,7 @@ func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetad
|
|||
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool {
|
||||
if allowed, ok := result.(bool); ok {
|
||||
if !allowed {
|
||||
ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus)
|
||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
||||
}
|
||||
|
||||
return allowed
|
||||
|
@ -151,7 +192,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
|
|||
|
||||
regoResult := RegoResult{
|
||||
// By default, a non-allowed request with return a 403 response.
|
||||
StatusCode: meta.DefaultStatus,
|
||||
StatusCode: int(meta.DefaultStatus),
|
||||
AdditionalHeaders: make(map[string]string),
|
||||
}
|
||||
|
||||
|
@ -179,7 +220,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
|
|||
}
|
||||
|
||||
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) {
|
||||
ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus)
|
||||
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
|
||||
ctx.Response.Header.Set(opaErrorHeaderKey, "true")
|
||||
m.logger.Warnf("Error procesing rego policy: %v", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package opa
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
|
@ -193,6 +194,31 @@ func TestOpaPolicy(t *testing.T) {
|
|||
},
|
||||
shouldRegoError: true,
|
||||
},
|
||||
"status config": {
|
||||
meta: middleware.Metadata{
|
||||
Properties: map[string]string{
|
||||
"rego": `
|
||||
package http
|
||||
allow = false`,
|
||||
"defaultStatus": "500",
|
||||
},
|
||||
},
|
||||
status: 500,
|
||||
},
|
||||
"rego priority over defaultStatus metadata": {
|
||||
meta: middleware.Metadata{
|
||||
Properties: map[string]string{
|
||||
"rego": `
|
||||
package http
|
||||
allow = {
|
||||
"allow": false,
|
||||
"status_code": 301
|
||||
}`,
|
||||
"defaultStatus": "500",
|
||||
},
|
||||
},
|
||||
status: 301,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
|
@ -232,3 +258,63 @@ func TestOpaPolicy(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus_UnmarshalJSON(t *testing.T) {
|
||||
type testObj struct {
|
||||
Value Status `json:"value,omitempty"`
|
||||
}
|
||||
tests := map[string]struct {
|
||||
jsonBytes []byte
|
||||
expectValue Status
|
||||
expectError bool
|
||||
}{
|
||||
"int value": {
|
||||
jsonBytes: []byte(`{"value":100}`),
|
||||
expectValue: Status(100),
|
||||
expectError: false,
|
||||
},
|
||||
"string value": {
|
||||
jsonBytes: []byte(`{"value":"100"}`),
|
||||
expectValue: Status(100),
|
||||
expectError: false,
|
||||
},
|
||||
"empty value": {
|
||||
jsonBytes: []byte(`{}`),
|
||||
expectValue: Status(0),
|
||||
expectError: false,
|
||||
},
|
||||
"invalid status code value": {
|
||||
jsonBytes: []byte(`{"value":600}`),
|
||||
expectError: true,
|
||||
},
|
||||
"invalid float value": {
|
||||
jsonBytes: []byte(`{"value":2.9}`),
|
||||
expectError: true,
|
||||
},
|
||||
"invalid value null": {
|
||||
jsonBytes: []byte(`{"value":null}`),
|
||||
expectError: true,
|
||||
},
|
||||
"invalid value []": {
|
||||
jsonBytes: []byte(`{"value":[]}`),
|
||||
expectError: true,
|
||||
},
|
||||
"invalid value {}": {
|
||||
jsonBytes: []byte(`{"value":{}}`),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var obj testObj
|
||||
err := json.Unmarshal(test.jsonBytes, &obj)
|
||||
if test.expectError {
|
||||
assert.NotEmpty(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, obj.Value, test.expectValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue