/* Copyright 2021 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bearer_test import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "slices" "strconv" "sync/atomic" "testing" "time" "github.com/go-chi/chi/v5" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" // Import the embed package. _ "embed" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/middleware" bearerMw "github.com/dapr/components-contrib/middleware/http/bearer" "github.com/dapr/components-contrib/tests/certification/embedded" "github.com/dapr/components-contrib/tests/certification/flow" "github.com/dapr/components-contrib/tests/certification/flow/app" "github.com/dapr/components-contrib/tests/certification/flow/sidecar" httpMiddlewareLoader "github.com/dapr/dapr/pkg/components/middleware/http" "github.com/dapr/dapr/pkg/config/protocol" httpMiddleware "github.com/dapr/dapr/pkg/middleware/http" dapr_testing "github.com/dapr/dapr/pkg/testing" "github.com/dapr/go-sdk/service/common" "github.com/dapr/kit/logger" ) const ( appID = "myapp" invokeMethod = "mymethod" tokenServicePort = 7470 // Defined in bearer.yaml tokenAudience = "26b9502f-1336-4479-ad5b-c8366edb7206" // Defined in bearer.yaml tokenIssuer = "http://localhost:7470" // Defined in bearer.yaml ) var ( //go:embed jwks.json jwksData string //go:embed private.json privateKeyData string // Logger log = logger.NewLogger("dapr.components") ) func init() { log.SetOutputLevel(logger.DebugLevel) } func TestHTTPMiddlewareBearer(t *testing.T) { var grpcPorts, httpPorts, appPorts [2]int client := http.Client{} for { ports, err := dapr_testing.GetFreePorts(6) require.NoError(t, err) // Ensure tokenServicePort isn't included if slices.Index(ports, tokenServicePort) > -1 { continue } grpcPorts = [2]int{ports[0], ports[3]} httpPorts = [2]int{ports[1], ports[4]} appPorts = [2]int{ports[2], ports[5]} break } // Load the private keys privateKeys, err := jwk.Parse([]byte(privateKeyData)) require.NoError(t, err) // Counters for requests coming to the web server requestsOpenIDConfiguration := atomic.Int32{} requestsJWKS := atomic.Int32{} // Step function that sets up a HTTP server that returns the JWKS setupJWKSServerStepFn := func() (string, flow.Runnable, flow.Runnable) { r := chi.NewRouter() openIDConfigurationHandler := func(w http.ResponseWriter, r *http.Request) { requestsOpenIDConfiguration.Add(1) log.Info("Received request for OpenID Configuration document") w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) res := map[string]string{ "issuer": tokenIssuer, "jwks_uri": tokenIssuer + "/.well-known/jwks.json", } json.NewEncoder(w).Encode(res) } r.Get("/.well-known/openid-configuration", openIDConfigurationHandler) r.Get("/foo/.well-known/openid-configuration", openIDConfigurationHandler) r.Get("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { requestsJWKS.Add(1) log.Info("Received request for JWKS") w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(jwksData)) }) srv := &http.Server{ Addr: fmt.Sprintf("127.0.0.1:%d", tokenServicePort), Handler: r, } start := flow.Runnable(func(ctx flow.Context) error { go func() { err := srv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { ctx.T.Fatalf("server error: %v", err) } }() time.Sleep(500 * time.Millisecond) return nil }) stop := flow.Runnable(func(ctx flow.Context) error { shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 1*time.Second) defer shutdownCancel() return srv.Shutdown(shutdownCtx) }) return "start JWKS server", start, stop } type sendRequestOpts struct { AuthorizationHeader string } sendRequest := func(parentCtx context.Context, port int, opts *sendRequestOpts) (int, error) { invokeUrl := fmt.Sprintf("http://localhost:%d/v1.0/invoke/%s/method/%s", port, appID, invokeMethod) reqCtx, reqCancel := context.WithTimeout(parentCtx, 5*time.Second) defer reqCancel() req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, invokeUrl, nil) if err != nil { return 0, fmt.Errorf("failed to create request: %w", err) } if opts != nil && opts.AuthorizationHeader != "" { req.Header.Set("authorization", opts.AuthorizationHeader) } res, err := client.Do(req) if err != nil { return 0, fmt.Errorf("request error: %w", err) } defer func() { // Drain before closing _, _ = io.Copy(io.Discard, res.Body) res.Body.Close() }() return res.StatusCode, nil } type runTestOpts struct { AuthFailure string // Options include: "empty" } // Run tests to check if the bearer token is validated correctly bearerTests := func(ctx flow.Context) error { now := time.Now() tests := []struct { name string buildTokenFn func(builder *jwt.Builder) signTokenFn func(builder *jwt.Builder) ([]byte, error) authHeaderFn func(token string) string signingKeyID int statusCode int }{ { name: "valid auth header", statusCode: http.StatusOK, }, { name: "lowercase bearer in token", statusCode: http.StatusOK, authHeaderFn: func(token string) string { return "bearer " + token }, }, { name: "empty authorization header", statusCode: http.StatusUnauthorized, authHeaderFn: func(token string) string { return "" }, }, { name: "empty bearer token 1", statusCode: http.StatusUnauthorized, authHeaderFn: func(token string) string { return "Bearer" }, }, { name: "empty bearer token 2", statusCode: http.StatusUnauthorized, authHeaderFn: func(token string) string { return "Bearer " }, }, { name: "malformed JWT", statusCode: http.StatusUnauthorized, authHeaderFn: func(token string) string { return "Bearer iMjZiOTUwMmYtMTMzN" }, }, { name: "expired token", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.IssuedAt(now.Add(-20 * time.Minute)) builder.Expiration(now.Add(-10 * time.Minute)) }, }, { name: "token but within allowed clock skew", statusCode: http.StatusOK, buildTokenFn: func(builder *jwt.Builder) { builder.IssuedAt(now.Add(-20 * time.Minute)) builder.Expiration(now.Add(-1 * time.Minute)) }, }, { name: "token not yet valid", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.NotBefore(now.Add(20 * time.Minute)) builder.IssuedAt(now.Add(20 * time.Minute)) }, }, { name: "token not yet valid but within allowed clock skew", statusCode: http.StatusOK, buildTokenFn: func(builder *jwt.Builder) { builder.NotBefore(now.Add(1 * time.Minute)) builder.IssuedAt(now.Add(1 * time.Minute)) }, }, { name: "invalid token audience", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.Audience([]string{"foo"}) }, }, { name: "empty token audience", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.Audience([]string{}) }, }, { name: "invalid token issuer", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.Issuer("foo") }, }, { name: "empty token issuer", statusCode: http.StatusUnauthorized, buildTokenFn: func(builder *jwt.Builder) { builder.Issuer("") }, }, { name: "reject tokens with alg 'none'", statusCode: http.StatusUnauthorized, signTokenFn: func(builder *jwt.Builder) ([]byte, error) { // {"alg":"none"} const joseHeader = `eyJhbGciOiJub25lIn0` token, err := builder.Build() if err != nil { return nil, err } claimSet, err := jwt.NewSerializer().Serialize(token) if err != nil { return nil, err } return []byte(joseHeader + "." + base64.RawURLEncoding.EncodeToString(claimSet) + "."), nil }, }, { name: "token signed with wrong key", statusCode: http.StatusUnauthorized, signingKeyID: 1, }, } ctx.T.Run("bearer token validation", func(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Generate a new JWT builder := jwt.NewBuilder(). Audience([]string{tokenAudience}). Issuer(tokenIssuer). IssuedAt(now). Expiration(now.Add(2 * time.Minute)) // If we have a tokenFn, invoke that if tt.buildTokenFn != nil { tt.buildTokenFn(builder) } // Build the token // If we have a signTokenFn, invoke that var signedToken []byte if tt.signTokenFn != nil { signedToken, err = tt.signTokenFn(builder) require.NoError(t, err) } else { token, err := builder.Build() require.NoError(t, err) useKey, _ := privateKeys.Key(tt.signingKeyID) signedToken, err = jwt.Sign(token, jwt.WithKey(jwa.PS256, useKey)) require.NoError(t, err) } // Set the auth header var authHeader string if tt.authHeaderFn != nil { authHeader = tt.authHeaderFn(string(signedToken)) } else { authHeader = "Bearer " + string(signedToken) } // Invoke both sidecars resStatus, err := sendRequest(ctx.Context, httpPorts[0], &sendRequestOpts{ AuthorizationHeader: authHeader, }) require.NoError(t, err) assert.Equal(t, tt.statusCode, resStatus) resStatus, err = sendRequest(ctx.Context, httpPorts[1], &sendRequestOpts{ AuthorizationHeader: authHeader, }) require.NoError(t, err) assert.Equal(t, tt.statusCode, resStatus) }) } }) return nil } // Run tests to validate component initialization initializationTests := func(ctx flow.Context) error { ctx.T.Run("component initialization", func(t *testing.T) { initMiddleware := func(md map[string]string) error { _, err := bearerMw. NewBearerMiddleware(log). GetHandler(context.Background(), middleware.Metadata{Base: metadata.Base{ Name: "test", Properties: md, }}) return err } t.Run("successful initialization", func(t *testing.T) { curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load() curRequestsJWKS := requestsJWKS.Load() err := initMiddleware(map[string]string{ "issuer": tokenIssuer, "audience": tokenAudience, }) require.NoError(t, err) // Both endpoints should be requested assert.Equal(t, curRequestsOpenIDConfiguration+1, requestsOpenIDConfiguration.Load()) assert.Equal(t, curRequestsJWKS+1, requestsJWKS.Load()) }) t.Run("explicit JWKS URL", func(t *testing.T) { curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load() curRequestsJWKS := requestsJWKS.Load() err := initMiddleware(map[string]string{ "issuer": tokenIssuer, "audience": tokenAudience, "jwksURL": tokenIssuer + "/.well-known/jwks.json", }) require.NoError(t, err) // Only the JWKS endpoint should be requested assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load()) assert.Equal(t, curRequestsJWKS+1, requestsJWKS.Load()) }) t.Run("cannot find OpenID configuration document", func(t *testing.T) { curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load() curRequestsJWKS := requestsJWKS.Load() err := initMiddleware(map[string]string{ "issuer": tokenIssuer + "/notfound", "audience": tokenAudience, }) require.Error(t, err) require.ErrorContains(t, err, "invalid response status code: 404") // No endpoint should be requested assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load()) assert.Equal(t, curRequestsJWKS, requestsJWKS.Load()) }) t.Run("token issuer mismatch in OpenID configuration document", func(t *testing.T) { curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load() curRequestsJWKS := requestsJWKS.Load() err := initMiddleware(map[string]string{ "issuer": tokenIssuer + "/foo", "audience": tokenAudience, }) require.Error(t, err) require.ErrorContains(t, err, "the issuer found in the OpenID Configuration document") // Only the OpenID Configuration endpoint should be requested assert.Equal(t, curRequestsOpenIDConfiguration+1, requestsOpenIDConfiguration.Load()) assert.Equal(t, curRequestsJWKS, requestsJWKS.Load()) }) t.Run("cannot find JWKS", func(t *testing.T) { curRequestsOpenIDConfiguration := requestsOpenIDConfiguration.Load() curRequestsJWKS := requestsJWKS.Load() err := initMiddleware(map[string]string{ "issuer": tokenIssuer, "audience": tokenAudience, "jwksURL": tokenIssuer + "/not-found/jwks.json", }) require.Error(t, err) require.ErrorContains(t, err, "failed to fetch JWKS") // No endpoint should be requested assert.Equal(t, curRequestsOpenIDConfiguration, requestsOpenIDConfiguration.Load()) assert.Equal(t, curRequestsJWKS, requestsJWKS.Load()) }) }) return nil } // Application setup code application := func(ctx flow.Context, s common.Service) error { s.AddServiceInvocationHandler(invokeMethod, func(ctx context.Context, in *common.InvocationEvent) (out *common.Content, err error) { return &common.Content{}, nil }) return nil } // Run tests flow.New(t, "Bearer HTTP middleware"). // Setup steps Step(setupJWKSServerStepFn()). // Start app and sidecar 1 Step(app.Run("Start application 1", fmt.Sprintf(":%d", appPorts[0]), application)). Step(sidecar.Run(appID, append(componentRuntimeOptions(), embedded.WithAppProtocol(protocol.HTTPProtocol, strconv.Itoa(appPorts[0])), embedded.WithDaprGRPCPort(strconv.Itoa(grpcPorts[0])), embedded.WithDaprHTTPPort(strconv.Itoa(httpPorts[0])), embedded.WithResourcesPath("./resources"), embedded.WithAPILoggingEnabled(false), embedded.WithProfilingEnabled(false), )..., )). // Start app and sidecar 2 Step(app.Run("Start application 2", fmt.Sprintf(":%d", appPorts[1]), application)). Step(sidecar.Run(appID, append(componentRuntimeOptions(), embedded.WithAppProtocol(protocol.HTTPProtocol, strconv.Itoa(appPorts[1])), embedded.WithDaprGRPCPort(strconv.Itoa(grpcPorts[1])), embedded.WithDaprHTTPPort(strconv.Itoa(httpPorts[1])), embedded.WithResourcesPath("./resources"), embedded.WithAPILoggingEnabled(false), embedded.WithProfilingEnabled(false), )..., )). // Tests Step("bearer token validation", bearerTests). Step("component initialization", initializationTests). // Run Run() } func componentRuntimeOptions() []embedded.Option { middlewareRegistry := httpMiddlewareLoader.NewRegistry() middlewareRegistry.Logger = log middlewareRegistry.RegisterComponent(func(log logger.Logger) httpMiddlewareLoader.FactoryMethod { return func(metadata middleware.Metadata) (httpMiddleware.Middleware, error) { return bearerMw.NewBearerMiddleware(log).GetHandler(context.Background(), metadata) } }, "bearer") return []embedded.Option{ embedded.WithHTTPMiddlewares(middlewareRegistry), } }