fix http opa middleware status code parse (#1136)

* fix http opa middleware status code parse

* fix for review

* use json unmarshal to handle edge cases.

* fix review

Co-authored-by: Simon Leet <31784195+CodeMonkeyLeet@users.noreply.github.com>
Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
This commit is contained in:
Taction 2021-09-18 01:58:01 +08:00 committed by GitHub
parent 189d2d6717
commit 3a109e2c74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 4 deletions

View File

@ -5,17 +5,23 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger"
"github.com/open-policy-agent/opa/rego"
"github.com/valyala/fasthttp"
)
type Status int
type middlewareMetadata struct {
Rego string `json:"rego"`
DefaultStatus int `json:"defaultStatus,omitempty"`
DefaultStatus Status `json:"defaultStatus,omitempty"`
IncludedHeaders string `json:"includedHeaders,omitempty"`
}
@ -43,6 +49,41 @@ var (
errOpaInvalidResultType = errors.New("got an invalid type back from repo policy. Only a boolean or map is valid")
)
func (s *Status) UnmarshalJSON(b []byte) error {
if len(b) == 0 {
return nil
}
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
if value != math.Trunc(value) {
return fmt.Errorf("invalid float value %f parse to status(int)", value)
}
*s = Status(value)
case string:
intVal, err := strconv.Atoi(value)
if err != nil {
return err
}
*s = Status(intVal)
default:
return fmt.Errorf("invalid value %v parse to status(int)", value)
}
if !s.Valid() {
return fmt.Errorf("invalid status value %d expected in range [100-599]", *s)
}
return nil
}
// Check status is in the correct range for RFC 2616 status codes [100-599]
func (s *Status) Valid() bool {
return s != nil && *s >= 100 && *s < 600
}
// GetHandler returns the HTTP handler provided by the middleware
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.RequestHandler) fasthttp.RequestHandler, error) {
meta, err := m.getNativeMetadata(metadata)
@ -129,7 +170,7 @@ func (m *Middleware) evalRequest(ctx *fasthttp.RequestCtx, meta *middlewareMetad
func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, result interface{}) bool {
if allowed, ok := result.(bool); ok {
if !allowed {
ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus)
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
}
return allowed
@ -151,7 +192,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
regoResult := RegoResult{
// By default, a non-allowed request with return a 403 response.
StatusCode: meta.DefaultStatus,
StatusCode: int(meta.DefaultStatus),
AdditionalHeaders: make(map[string]string),
}
@ -179,7 +220,7 @@ func (m *Middleware) handleRegoResult(ctx *fasthttp.RequestCtx, meta *middleware
}
func (m *Middleware) opaError(ctx *fasthttp.RequestCtx, meta *middlewareMetadata, err error) {
ctx.Error(fasthttp.StatusMessage(meta.DefaultStatus), meta.DefaultStatus)
ctx.Error(fasthttp.StatusMessage(int(meta.DefaultStatus)), int(meta.DefaultStatus))
ctx.Response.Header.Set(opaErrorHeaderKey, "true")
m.logger.Warnf("Error procesing rego policy: %v", err)
}

View File

@ -1,6 +1,7 @@
package opa
import (
"encoding/json"
"testing"
"github.com/dapr/components-contrib/middleware"
@ -193,6 +194,31 @@ func TestOpaPolicy(t *testing.T) {
},
shouldRegoError: true,
},
"status config": {
meta: middleware.Metadata{
Properties: map[string]string{
"rego": `
package http
allow = false`,
"defaultStatus": "500",
},
},
status: 500,
},
"rego priority over defaultStatus metadata": {
meta: middleware.Metadata{
Properties: map[string]string{
"rego": `
package http
allow = {
"allow": false,
"status_code": 301
}`,
"defaultStatus": "500",
},
},
status: 301,
},
}
for name, test := range tests {
@ -232,3 +258,63 @@ func TestOpaPolicy(t *testing.T) {
})
}
}
func TestStatus_UnmarshalJSON(t *testing.T) {
type testObj struct {
Value Status `json:"value,omitempty"`
}
tests := map[string]struct {
jsonBytes []byte
expectValue Status
expectError bool
}{
"int value": {
jsonBytes: []byte(`{"value":100}`),
expectValue: Status(100),
expectError: false,
},
"string value": {
jsonBytes: []byte(`{"value":"100"}`),
expectValue: Status(100),
expectError: false,
},
"empty value": {
jsonBytes: []byte(`{}`),
expectValue: Status(0),
expectError: false,
},
"invalid status code value": {
jsonBytes: []byte(`{"value":600}`),
expectError: true,
},
"invalid float value": {
jsonBytes: []byte(`{"value":2.9}`),
expectError: true,
},
"invalid value null": {
jsonBytes: []byte(`{"value":null}`),
expectError: true,
},
"invalid value []": {
jsonBytes: []byte(`{"value":[]}`),
expectError: true,
},
"invalid value {}": {
jsonBytes: []byte(`{"value":{}}`),
expectError: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
var obj testObj
err := json.Unmarshal(test.jsonBytes, &obj)
if test.expectError {
assert.NotEmpty(t, err)
return
}
assert.Nil(t, err)
assert.Equal(t, obj.Value, test.expectValue)
})
}
}