540 lines
16 KiB
Go
540 lines
16 KiB
Go
/*
|
|
Copyright 2021 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package bearer_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
// Import the embed package.
|
|
_ "embed"
|
|
|
|
"github.com/dapr/components-contrib/metadata"
|
|
"github.com/dapr/components-contrib/middleware"
|
|
bearerMw "github.com/dapr/components-contrib/middleware/http/bearer"
|
|
"github.com/dapr/components-contrib/tests/certification/embedded"
|
|
"github.com/dapr/components-contrib/tests/certification/flow"
|
|
"github.com/dapr/components-contrib/tests/certification/flow/app"
|
|
"github.com/dapr/components-contrib/tests/certification/flow/sidecar"
|
|
httpMiddlewareLoader "github.com/dapr/dapr/pkg/components/middleware/http"
|
|
"github.com/dapr/dapr/pkg/config/protocol"
|
|
httpMiddleware "github.com/dapr/dapr/pkg/middleware/http"
|
|
dapr_testing "github.com/dapr/dapr/pkg/testing"
|
|
"github.com/dapr/go-sdk/service/common"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
// Fixture constants shared by the app, the sidecars and the token service.
const (
	// appID is the Dapr app ID of both sidecars and the service invocation
	// target of every test request.
	appID = "myapp"
	// invokeMethod is the method registered by the test application.
	invokeMethod = "mymethod"

	// Token service settings. These must stay in sync with bearer.yaml.
	tokenServicePort = 7470                                   // Defined in bearer.yaml
	tokenAudience    = "26b9502f-1336-4479-ad5b-c8366edb7206" // Defined in bearer.yaml
	tokenIssuer      = "http://localhost:7470"                // Defined in bearer.yaml
)
|
|
|
|
var (
	// jwksData is the JSON Web Key Set returned by the test server's
	// /.well-known/jwks.json handler.
	//go:embed jwks.json
	jwksData string

	// privateKeyData is parsed into the key set used to sign the test tokens.
	//go:embed private.json
	privateKeyData string

	// Logger shared by the test, the test JWKS server handlers and the
	// bearer middleware instances created by the test.
	log = logger.NewLogger("dapr.components")
)
|
|
|
|
func init() {
|
|
log.SetOutputLevel(logger.DebugLevel)
|
|
}
|
|
|
|
func TestHTTPMiddlewareBearer(t *testing.T) {
|
|
var grpcPorts, httpPorts, appPorts [2]int
|
|
client := http.Client{}
|
|
|
|
for {
|
|
ports, err := dapr_testing.GetFreePorts(6)
|
|
require.NoError(t, err)
|
|
|
|
// Ensure tokenServicePort isn't included
|
|
if slices.Index(ports, tokenServicePort) > -1 {
|
|
continue
|
|
}
|
|
|
|
grpcPorts = [2]int{ports[0], ports[3]}
|
|
httpPorts = [2]int{ports[1], ports[4]}
|
|
appPorts = [2]int{ports[2], ports[5]}
|
|
break
|
|
}
|
|
|
|
// Load the private keys
|
|
privateKeys, err := jwk.Parse([]byte(privateKeyData))
|
|
require.NoError(t, err)
|
|
|
|
// Counters for requests coming to the web server
|
|
requestsOpenIDConfiguration := atomic.Int32{}
|
|
requestsJWKS := atomic.Int32{}
|
|
|
|
// Step function that sets up a HTTP server that returns the JWKS
|
|
setupJWKSServerStepFn := func() (string, flow.Runnable, flow.Runnable) {
|
|
r := chi.NewRouter()
|
|
|
|
openIDConfigurationHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
requestsOpenIDConfiguration.Add(1)
|
|
log.Info("Received request for OpenID Configuration document")
|
|
w.Header().Set("content-type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
res := map[string]string{
|
|
"issuer": tokenIssuer,
|
|
"jwks_uri": tokenIssuer + "/.well-known/jwks.json",
|
|
}
|
|
json.NewEncoder(w).Encode(res)
|
|
}
|
|
r.Get("/.well-known/openid-configuration", openIDConfigurationHandler)
|
|
r.Get("/foo/.well-known/openid-configuration", openIDConfigurationHandler)
|
|
|
|
r.Get("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
|
|
requestsJWKS.Add(1)
|
|
log.Info("Received request for JWKS")
|
|
w.Header().Set("content-type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(jwksData))
|
|
})
|
|
|
|
srv := &http.Server{
|
|
Addr: fmt.Sprintf("127.0.0.1:%d", tokenServicePort),
|
|
Handler: r,
|
|
}
|
|
|
|
start := flow.Runnable(func(ctx flow.Context) error {
|
|
go func() {
|
|
err := srv.ListenAndServe()
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
ctx.T.Fatalf("server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
return nil
|
|
})
|
|
|
|
stop := flow.Runnable(func(ctx flow.Context) error {
|
|
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 1*time.Second)
|
|
defer shutdownCancel()
|
|
return srv.Shutdown(shutdownCtx)
|
|
})
|
|
|
|
return "start JWKS server", start, stop
|
|
}
|
|
|
|
type sendRequestOpts struct {
|
|
AuthorizationHeader string
|
|
}
|
|
|
|
sendRequest := func(parentCtx context.Context, port int, opts *sendRequestOpts) (int, error) {
|
|
invokeUrl := fmt.Sprintf("http://localhost:%d/v1.0/invoke/%s/method/%s", port, appID, invokeMethod)
|
|
|
|
reqCtx, reqCancel := context.WithTimeout(parentCtx, 5*time.Second)
|
|
defer reqCancel()
|
|
|
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, invokeUrl, nil)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
if opts != nil && opts.AuthorizationHeader != "" {
|
|
req.Header.Set("authorization", opts.AuthorizationHeader)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("request error: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
// Drain before closing
|
|
_, _ = io.Copy(io.Discard, res.Body)
|
|
res.Body.Close()
|
|
}()
|
|
|
|
return res.StatusCode, nil
|
|
}
|
|
|
|
type runTestOpts struct {
|
|
AuthFailure string // Options include: "empty"
|
|
}
|
|
|
|
// Run tests to check if the bearer token is validated correctly
|
|
bearerTests := func(ctx flow.Context) error {
|
|
now := time.Now()
|
|
|
|
tests := []struct {
|
|
name string
|
|
buildTokenFn func(builder *jwt.Builder)
|
|
signTokenFn func(builder *jwt.Builder) ([]byte, error)
|
|
authHeaderFn func(token string) string
|
|
signingKeyID int
|
|
statusCode int
|
|
}{
|
|
{
|
|
name: "valid auth header",
|
|
statusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "lowercase bearer in token",
|
|
statusCode: http.StatusOK,
|
|
authHeaderFn: func(token string) string {
|
|
return "bearer " + token
|
|
},
|
|
},
|
|
{
|
|
name: "empty authorization header",
|
|
statusCode: http.StatusUnauthorized,
|
|
authHeaderFn: func(token string) string {
|
|
return ""
|
|
},
|
|
},
|
|
{
|
|
name: "empty bearer token 1",
|
|
statusCode: http.StatusUnauthorized,
|
|
authHeaderFn: func(token string) string {
|
|
return "Bearer"
|
|
},
|
|
},
|
|
{
|
|
name: "empty bearer token 2",
|
|
statusCode: http.StatusUnauthorized,
|
|
authHeaderFn: func(token string) string {
|
|
return "Bearer "
|
|
},
|
|
},
|
|
{
|
|
name: "malformed JWT",
|
|
statusCode: http.StatusUnauthorized,
|
|
authHeaderFn: func(token string) string {
|
|
return "Bearer iMjZiOTUwMmYtMTMzN"
|
|
},
|
|
},
|
|
{
|
|
name: "expired token",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.IssuedAt(now.Add(-20 * time.Minute))
|
|
builder.Expiration(now.Add(-10 * time.Minute))
|
|
},
|
|
},
|
|
{
|
|
name: "token but within allowed clock skew",
|
|
statusCode: http.StatusOK,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.IssuedAt(now.Add(-20 * time.Minute))
|
|
builder.Expiration(now.Add(-1 * time.Minute))
|
|
},
|
|
},
|
|
{
|
|
name: "token not yet valid",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.NotBefore(now.Add(20 * time.Minute))
|
|
builder.IssuedAt(now.Add(20 * time.Minute))
|
|
},
|
|
},
|
|
{
|
|
name: "token not yet valid but within allowed clock skew",
|
|
statusCode: http.StatusOK,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.NotBefore(now.Add(1 * time.Minute))
|
|
builder.IssuedAt(now.Add(1 * time.Minute))
|
|
},
|
|
},
|
|
{
|
|
name: "invalid token audience",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.Audience([]string{"foo"})
|
|
},
|
|
},
|
|
{
|
|
name: "empty token audience",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.Audience([]string{})
|
|
},
|
|
},
|
|
{
|
|
name: "invalid token issuer",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.Issuer("foo")
|
|
},
|
|
},
|
|
{
|
|
name: "empty token issuer",
|
|
statusCode: http.StatusUnauthorized,
|
|
buildTokenFn: func(builder *jwt.Builder) {
|
|
builder.Issuer("")
|
|
},
|
|
},
|
|
{
|
|
name: "reject tokens with alg 'none'",
|
|
statusCode: http.StatusUnauthorized,
|
|
signTokenFn: func(builder *jwt.Builder) ([]byte, error) {
|
|
// {"alg":"none"}
|
|
const joseHeader = `eyJhbGciOiJub25lIn0`
|
|
token, err := builder.Build()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claimSet, err := jwt.NewSerializer().Serialize(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []byte(joseHeader + "." + base64.RawURLEncoding.EncodeToString(claimSet) + "."), nil
|
|
},
|
|
},
|
|
{
|
|
name: "token signed with wrong key",
|
|
statusCode: http.StatusUnauthorized,
|
|
signingKeyID: 1,
|
|
},
|
|
}
|
|
|
|
ctx.T.Run("bearer token validation", func(t *testing.T) {
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Generate a new JWT
|
|
builder := jwt.NewBuilder().
|
|
Audience([]string{tokenAudience}).
|
|
Issuer(tokenIssuer).
|
|
IssuedAt(now).
|
|
Expiration(now.Add(2 * time.Minute))
|
|
|
|
// If we have a tokenFn, invoke that
|
|
if tt.buildTokenFn != nil {
|
|
tt.buildTokenFn(builder)
|
|
}
|
|
|
|
// Build the token
|
|
// If we have a signTokenFn, invoke that
|
|
var signedToken []byte
|
|
if tt.signTokenFn != nil {
|
|
signedToken, err = tt.signTokenFn(builder)
|
|
require.NoError(t, err)
|
|
} else {
|
|
token, err := builder.Build()
|
|
require.NoError(t, err)
|
|
useKey, _ := privateKeys.Key(tt.signingKeyID)
|
|
signedToken, err = jwt.Sign(token, jwt.WithKey(jwa.PS256, useKey))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Set the auth header
|
|
var authHeader string
|
|
if tt.authHeaderFn != nil {
|
|
authHeader = tt.authHeaderFn(string(signedToken))
|
|
} else {
|
|
authHeader = "Bearer " + string(signedToken)
|
|
}
|
|
|
|
// Invoke both sidecars
|
|
resStatus, err := sendRequest(ctx.Context, httpPorts[0], &sendRequestOpts{
|
|
AuthorizationHeader: authHeader,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.statusCode, resStatus)
|
|
|
|
resStatus, err = sendRequest(ctx.Context, httpPorts[1], &sendRequestOpts{
|
|
AuthorizationHeader: authHeader,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.statusCode, resStatus)
|
|
})
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// Run tests to validate component initialization
|
|
initializationTests := func(ctx flow.Context) error {
|
|
ctx.T.Run("component initialization", func(t *testing.T) {
|
|
initMiddleware := func(md map[string]string) error {
|
|
_, err := bearerMw.
|
|
NewBearerMiddleware(log).
|
|
GetHandler(context.Background(), middleware.Metadata{Base: metadata.Base{
|
|
Name: "test",
|
|
Properties: md,
|
|
}})
|
|
return err
|
|
}
|
|
|
|
t.Run("successful initialization", func(t *testing.T) {
|
|
curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load()
|
|
curRequestsJWKS := requestsJWKS.Load()
|
|
|
|
err := initMiddleware(map[string]string{
|
|
"issuer": tokenIssuer,
|
|
"audience": tokenAudience,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Both endpoints should be requested
|
|
assert.Equal(t, curRequestsOpenIDConfiguration+1, requestsOpenIDConfiguration.Load())
|
|
assert.Equal(t, curRequestsJWKS+1, requestsJWKS.Load())
|
|
})
|
|
|
|
t.Run("explicit JWKS URL", func(t *testing.T) {
|
|
curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load()
|
|
curRequestsJWKS := requestsJWKS.Load()
|
|
|
|
err := initMiddleware(map[string]string{
|
|
"issuer": tokenIssuer,
|
|
"audience": tokenAudience,
|
|
"jwksURL": tokenIssuer + "/.well-known/jwks.json",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Only the JWKS endpoint should be requested
|
|
assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load())
|
|
assert.Equal(t, curRequestsJWKS+1, requestsJWKS.Load())
|
|
})
|
|
|
|
t.Run("cannot find OpenID configuration document", func(t *testing.T) {
|
|
curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load()
|
|
curRequestsJWKS := requestsJWKS.Load()
|
|
|
|
err := initMiddleware(map[string]string{
|
|
"issuer": tokenIssuer + "/notfound",
|
|
"audience": tokenAudience,
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "invalid response status code: 404")
|
|
|
|
// No endpoint should be requested
|
|
assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load())
|
|
assert.Equal(t, curRequestsJWKS, requestsJWKS.Load())
|
|
})
|
|
|
|
t.Run("token issuer mismatch in OpenID configuration document", func(t *testing.T) {
|
|
curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load()
|
|
curRequestsJWKS := requestsJWKS.Load()
|
|
|
|
err := initMiddleware(map[string]string{
|
|
"issuer": tokenIssuer + "/foo",
|
|
"audience": tokenAudience,
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "the issuer found in the OpenID Configuration document")
|
|
|
|
// Only the OpenID Configuration endpoint should be requested
|
|
assert.Equal(t, curRequestsOpenIDConfiguration+1, requestsOpenIDConfiguration.Load())
|
|
assert.Equal(t, curRequestsJWKS, requestsJWKS.Load())
|
|
})
|
|
|
|
t.Run("cannot find JWKS", func(t *testing.T) {
|
|
curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load()
|
|
curRequestsJWKS := requestsJWKS.Load()
|
|
|
|
err := initMiddleware(map[string]string{
|
|
"issuer": tokenIssuer,
|
|
"audience": tokenAudience,
|
|
"jwksURL": tokenIssuer + "/not-found/jwks.json",
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to fetch JWKS")
|
|
|
|
// No endpoint should be requested
|
|
assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load())
|
|
assert.Equal(t, curRequestsJWKS, requestsJWKS.Load())
|
|
})
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// Application setup code
|
|
application := func(ctx flow.Context, s common.Service) error {
|
|
s.AddServiceInvocationHandler(invokeMethod, func(ctx context.Context, in *common.InvocationEvent) (out *common.Content, err error) {
|
|
return &common.Content{}, nil
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Run tests
|
|
flow.New(t, "Bearer HTTP middleware").
|
|
// Setup steps
|
|
Step(setupJWKSServerStepFn()).
|
|
// Start app and sidecar 1
|
|
Step(app.Run("Start application 1", fmt.Sprintf(":%d", appPorts[0]), application)).
|
|
Step(sidecar.Run(appID,
|
|
append(componentRuntimeOptions(),
|
|
embedded.WithAppProtocol(protocol.HTTPProtocol, strconv.Itoa(appPorts[0])),
|
|
embedded.WithDaprGRPCPort(strconv.Itoa(grpcPorts[0])),
|
|
embedded.WithDaprHTTPPort(strconv.Itoa(httpPorts[0])),
|
|
embedded.WithResourcesPath("./resources"),
|
|
embedded.WithAPILoggingEnabled(false),
|
|
embedded.WithProfilingEnabled(false),
|
|
)...,
|
|
)).
|
|
// Start app and sidecar 2
|
|
Step(app.Run("Start application 2", fmt.Sprintf(":%d", appPorts[1]), application)).
|
|
Step(sidecar.Run(appID,
|
|
append(componentRuntimeOptions(),
|
|
embedded.WithAppProtocol(protocol.HTTPProtocol, strconv.Itoa(appPorts[1])),
|
|
embedded.WithDaprGRPCPort(strconv.Itoa(grpcPorts[1])),
|
|
embedded.WithDaprHTTPPort(strconv.Itoa(httpPorts[1])),
|
|
embedded.WithResourcesPath("./resources"),
|
|
embedded.WithAPILoggingEnabled(false),
|
|
embedded.WithProfilingEnabled(false),
|
|
)...,
|
|
)).
|
|
// Tests
|
|
Step("bearer token validation", bearerTests).
|
|
Step("component initialization", initializationTests).
|
|
// Run
|
|
Run()
|
|
}
|
|
|
|
func componentRuntimeOptions() []embedded.Option {
|
|
middlewareRegistry := httpMiddlewareLoader.NewRegistry()
|
|
middlewareRegistry.Logger = log
|
|
middlewareRegistry.RegisterComponent(func(log logger.Logger) httpMiddlewareLoader.FactoryMethod {
|
|
return func(metadata middleware.Metadata) (httpMiddleware.Middleware, error) {
|
|
return bearerMw.NewBearerMiddleware(log).GetHandler(context.Background(), metadata)
|
|
}
|
|
}, "bearer")
|
|
|
|
return []embedded.Option{
|
|
embedded.WithHTTPMiddlewares(middlewareRegistry),
|
|
}
|
|
}
|