diff --git a/plugin/pkg/authenticator/token/oidc/metrics.go b/plugin/pkg/authenticator/token/oidc/metrics.go index 66c6254ec..109cde254 100644 --- a/plugin/pkg/authenticator/token/oidc/metrics.go +++ b/plugin/pkg/authenticator/token/oidc/metrics.go @@ -20,13 +20,13 @@ import ( "context" "crypto/sha256" "fmt" - "k8s.io/utils/clock" "sync" "time" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/component-base/metrics" "k8s.io/component-base/metrics/legacyregistry" + "k8s.io/utils/clock" ) const ( @@ -68,11 +68,11 @@ func getHash(data string) string { return "" } -func newInstrumentedAuthenticator(jwtIssuer string, delegate authenticator.Token) authenticator.Token { +func newInstrumentedAuthenticator(jwtIssuer string, delegate AuthenticatorTokenWithHealthCheck) AuthenticatorTokenWithHealthCheck { return newInstrumentedAuthenticatorWithClock(jwtIssuer, delegate, clock.RealClock{}) } -func newInstrumentedAuthenticatorWithClock(jwtIssuer string, delegate authenticator.Token, clock clock.PassiveClock) *instrumentedAuthenticator { +func newInstrumentedAuthenticatorWithClock(jwtIssuer string, delegate AuthenticatorTokenWithHealthCheck, clock clock.PassiveClock) *instrumentedAuthenticator { RegisterMetrics() return &instrumentedAuthenticator{ jwtIssuerHash: getHash(jwtIssuer), @@ -83,7 +83,7 @@ func newInstrumentedAuthenticatorWithClock(jwtIssuer string, delegate authentica type instrumentedAuthenticator struct { jwtIssuerHash string - delegate authenticator.Token + delegate AuthenticatorTokenWithHealthCheck clock clock.PassiveClock } @@ -104,3 +104,7 @@ func (a *instrumentedAuthenticator) AuthenticateToken(ctx context.Context, token } return response, ok, err } + +func (a *instrumentedAuthenticator) HealthCheck() error { + return a.delegate.HealthCheck() +} diff --git a/plugin/pkg/authenticator/token/oidc/metrics_test.go b/plugin/pkg/authenticator/token/oidc/metrics_test.go index 998b18466..cdcad1d05 100644 --- a/plugin/pkg/authenticator/token/oidc/metrics_test.go +++ b/plugin/pkg/authenticator/token/oidc/metrics_test.go @@ -35,7 +35,7 @@ const ( func TestRecordAuthenticationLatency(t *testing.T) { tests := []struct { name string - authenticator authenticator.Token + authenticator AuthenticatorTokenWithHealthCheck generateMetrics func() expectedValue string }{ @@ -117,6 +117,10 @@ func (a *dummyAuthenticator) AuthenticateToken(ctx context.Context, token string return a.response, a.ok, a.err } +func (a *dummyAuthenticator) HealthCheck() error { + panic("should not be called") +} + type dummyClock struct { } diff --git a/plugin/pkg/authenticator/token/oidc/oidc.go b/plugin/pkg/authenticator/token/oidc/oidc.go index 0ce213d9f..bdf50d89c 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc.go +++ b/plugin/pkg/authenticator/token/oidc/oidc.go @@ -74,7 +74,9 @@ const ( type Options struct { // JWTAuthenticator is the authenticator that will be used to verify the JWT. JWTAuthenticator apiserver.JWTAuthenticator + // Optional KeySet to allow for synchronous initialization instead of fetching from the remote issuer. + // Mutually exclusive with JWTAuthenticator.Issuer.DiscoveryURL. KeySet oidc.KeySet // PEM encoded root certificate contents of the provider. Mutually exclusive with Client. @@ -135,7 +137,7 @@ func newAsyncIDTokenVerifier(ctx context.Context, c *oidc.Config, iss string, au sync := make(chan struct{}) // Polls indefinitely in an attempt to initialize the distributed claims // verifier, or until context canceled. - initFn := func() (done bool, err error) { + initFn := func(ctx context.Context) (done bool, err error) { klog.V(4).Infof("oidc authenticator: attempting init: iss=%v", iss) v, err := initVerifier(ctx, c, iss, audiences) if err != nil { @@ -150,13 +152,14 @@ func newAsyncIDTokenVerifier(ctx context.Context, c *oidc.Config, iss string, au } go func() { - if done, _ := initFn(); !done { - go wait.PollUntil(time.Second*10, initFn, ctx.Done()) - } + _ = wait.PollUntilContextCancel(ctx, 10*time.Second, true, initFn) }() if synchronizeTokenIDVerifierForTest { - <-sync + select { + case <-sync: + case <-ctx.Done(): + } } return t @@ -169,15 +172,13 @@ func (a *asyncIDTokenVerifier) verifier() *idTokenVerifier { return a.v } -type Authenticator struct { +type jwtAuthenticator struct { jwtAuthenticator apiserver.JWTAuthenticator // Contains an *oidc.IDTokenVerifier. Do not access directly use the // idTokenVerifier method. verifier atomic.Value - cancel context.CancelFunc - // resolver is used to resolve distributed claims. resolver *claimResolver @@ -187,6 +188,8 @@ type Authenticator struct { // requiredClaims contains the list of claims that must be present in the token. requiredClaims map[string]string + + healthCheck atomic.Pointer[errorHolder] } // idTokenVerifier is a wrapper around oidc.IDTokenVerifier. It uses the oidc.IDTokenVerifier @@ -196,21 +199,22 @@ type idTokenVerifier struct { audiences sets.Set[string] } -func (a *Authenticator) setVerifier(v *idTokenVerifier) { +func (a *jwtAuthenticator) setVerifier(v *idTokenVerifier) { a.verifier.Store(v) + if v != nil { + // this must be done after the verifier has been stored so that a nil error + // from HealthCheck always means that the authenticator is ready for use. + a.healthCheck.Store(&errorHolder{}) + } } -func (a *Authenticator) idTokenVerifier() (*idTokenVerifier, bool) { +func (a *jwtAuthenticator) idTokenVerifier() (*idTokenVerifier, bool) { if v := a.verifier.Load(); v != nil { return v.(*idTokenVerifier), true } return nil, false } -func (a *Authenticator) Close() { - a.cancel() -} - func AllValidSigningAlgorithms() []string { return sets.List(sets.KeySet(allowedSigningAlgs)) } @@ -228,7 +232,18 @@ var allowedSigningAlgs = map[string]bool{ oidc.PS512: true, } -func New(opts Options) (authenticator.Token, error) { +type AuthenticatorTokenWithHealthCheck interface { + authenticator.Token + HealthCheck() error +} + +// New returns an authenticator that is asynchronously initialized when opts.KeySet is not set. +// The input lifecycleCtx is used to: +// - terminate background goroutines that are needed for asynchronous initialization +// - as the base context for any requests that are made (i.e. for key fetching) +// Thus, once the lifecycleCtx is canceled, the authenticator must not be used. +// A caller may check if the authenticator is healthy by calling the HealthCheck method. +func New(lifecycleCtx context.Context, opts Options) (AuthenticatorTokenWithHealthCheck, error) { celMapper, fieldErr := apiservervalidation.CompileAndValidateJWTAuthenticator(opts.JWTAuthenticator, opts.DisallowedIssuers) if err := fieldErr.ToAggregate(); err != nil { return nil, err @@ -280,6 +295,10 @@ func New(opts Options) (authenticator.Token, error) { // the discovery URL. This is useful for self-hosted providers, for example, // providers that run on top of Kubernetes itself. if len(opts.JWTAuthenticator.Issuer.DiscoveryURL) > 0 { + if opts.KeySet != nil { + return nil, fmt.Errorf("oidc: KeySet and DiscoveryURL are mutually exclusive") + } + discoveryURL, err := url.Parse(opts.JWTAuthenticator.Issuer.DiscoveryURL) if err != nil { return nil, fmt.Errorf("oidc: invalid discovery URL: %w", err) @@ -297,8 +316,7 @@ func New(opts Options) (authenticator.Token, error) { client = &clientWithDiscoveryURL } - ctx, cancel := context.WithCancel(context.Background()) - ctx = oidc.ClientContext(ctx, client) + lifecycleCtx = oidc.ClientContext(lifecycleCtx, client) now := opts.now if now == nil { @@ -324,7 +342,7 @@ func New(opts Options) (authenticator.Token, error) { var resolver *claimResolver groupsClaim := opts.JWTAuthenticator.ClaimMappings.Groups.Claim if groupsClaim != "" { - resolver = newClaimResolver(groupsClaim, client, verifierConfig, audiences) + resolver = newClaimResolver(lifecycleCtx, groupsClaim, client, verifierConfig, audiences) } requiredClaims := make(map[string]string) @@ -334,38 +352,51 @@ func New(opts Options) (authenticator.Token, error) { } } - authenticator := &Authenticator{ + authn := &jwtAuthenticator{ jwtAuthenticator: opts.JWTAuthenticator, - cancel: cancel, resolver: resolver, celMapper: celMapper, requiredClaims: requiredClaims, } + authn.healthCheck.Store(&errorHolder{ + err: fmt.Errorf("oidc: authenticator for issuer %q is not initialized", authn.jwtAuthenticator.Issuer.URL), + }) issuerURL := opts.JWTAuthenticator.Issuer.URL if opts.KeySet != nil { // We already have a key set, synchronously initialize the verifier. - authenticator.setVerifier(&idTokenVerifier{ + authn.setVerifier(&idTokenVerifier{ oidc.NewVerifier(issuerURL, opts.KeySet, verifierConfig), audiences, }) } else { // Asynchronously attempt to initialize the authenticator. This enables // self-hosted providers, providers that run on top of Kubernetes itself. - go wait.PollImmediateUntil(10*time.Second, func() (done bool, err error) { - provider, err := oidc.NewProvider(ctx, issuerURL) - if err != nil { - klog.Errorf("oidc authenticator: initializing plugin: %v", err) - return false, nil - } + go func() { + // we ignore any errors from polling because they can only come from the context being canceled + _ = wait.PollUntilContextCancel(lifecycleCtx, 10*time.Second, true, func(_ context.Context) (done bool, err error) { + // this must always use lifecycleCtx because NewProvider uses that context for future key set fetching. + // this also means that there is no correct way to control the timeout of the discovery request made by NewProvider. + // the global timeout of the http.Client is still honored. + provider, err := oidc.NewProvider(lifecycleCtx, issuerURL) + if err != nil { + klog.Errorf("oidc authenticator: initializing plugin: %v", err) + authn.healthCheck.Store(&errorHolder{err: err}) + return false, nil + } - verifier := provider.Verifier(verifierConfig) - authenticator.setVerifier(&idTokenVerifier{verifier, audiences}) - return true, nil - }, ctx.Done()) + verifier := provider.Verifier(verifierConfig) + authn.setVerifier(&idTokenVerifier{verifier, audiences}) + return true, nil + }) + }() } - return newInstrumentedAuthenticator(issuerURL, authenticator), nil + return newInstrumentedAuthenticator(issuerURL, authn), nil +} + +type errorHolder struct { + err error } // discoveryURLRoundTripper is a http.RoundTripper that rewrites the @@ -448,6 +479,8 @@ type endpoint struct { // claimResolver expands distributed claims by calling respective claim source // endpoints. type claimResolver struct { + ctx context.Context + // claim is the distributed claim that may be resolved. claim string @@ -471,8 +504,10 @@ type claimResolver struct { } // newClaimResolver creates a new resolver for distributed claims. -func newClaimResolver(claim string, client *http.Client, config *oidc.Config, audiences sets.Set[string]) *claimResolver { +// the input ctx is retained and is used as the base context for background requests such as key fetching. +func newClaimResolver(ctx context.Context, claim string, client *http.Client, config *oidc.Config, audiences sets.Set[string]) *claimResolver { return &claimResolver{ + ctx: ctx, claim: claim, audiences: audiences, client: client, @@ -487,8 +522,7 @@ func (r *claimResolver) Verifier(iss string) (*idTokenVerifier, error) { av := r.verifierPerIssuer[iss] if av == nil { // This lazy init should normally be very quick. - // TODO: Make this context cancelable. - ctx := oidc.ClientContext(context.Background(), r.client) + ctx := oidc.ClientContext(r.ctx, r.client) av = newAsyncIDTokenVerifier(ctx, r.config, iss, r.audiences) r.verifierPerIssuer[iss] = av } @@ -638,7 +672,7 @@ func (v *idTokenVerifier) verifyAudience(t *oidc.IDToken) error { return fmt.Errorf("oidc: expected audience in %q got %q", sets.List(v.audiences), t.Audience) } -func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { +func (a *jwtAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { if !hasCorrectIssuer(a.jwtAuthenticator.Issuer.URL, token) { return nil, false, nil } @@ -759,7 +793,15 @@ func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*a return &authenticator.Response{User: info}, true, nil } -func (a *Authenticator) getUsername(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) (string, error) { +func (a *jwtAuthenticator) HealthCheck() error { + if holder := *a.healthCheck.Load(); holder.err != nil { + return fmt.Errorf("oidc: authenticator for issuer %q is not healthy: %w", a.jwtAuthenticator.Issuer.URL, holder.err) + } + + return nil +} + +func (a *jwtAuthenticator) getUsername(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) (string, error) { if a.celMapper.Username != nil { evalResult, err := a.celMapper.Username.EvalClaimMapping(ctx, claimsUnstructured) if err != nil { @@ -807,7 +849,7 @@ func (a *Authenticator) getUsername(ctx context.Context, c claims, claimsUnstruc return username, nil } -func (a *Authenticator) getGroups(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) ([]string, error) { +func (a *jwtAuthenticator) getGroups(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) ([]string, error) { groupsClaim := a.jwtAuthenticator.ClaimMappings.Groups.Claim if len(groupsClaim) > 0 { if _, ok := c[groupsClaim]; ok { @@ -847,7 +889,7 @@ func (a *Authenticator) getGroups(ctx context.Context, c claims, claimsUnstructu return groups, nil } -func (a *Authenticator) getUID(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) (string, error) { +func (a *jwtAuthenticator) getUID(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) (string, error) { uidClaim := a.jwtAuthenticator.ClaimMappings.UID.Claim if len(uidClaim) > 0 { var uid string @@ -872,7 +914,7 @@ func (a *Authenticator) getUID(ctx context.Context, c claims, claimsUnstructured return evalResult.EvalResult.Value().(string), nil } -func (a *Authenticator) getExtra(ctx context.Context, claimsUnstructured *unstructured.Unstructured) (map[string][]string, error) { +func (a *jwtAuthenticator) getExtra(ctx context.Context, claimsUnstructured *unstructured.Unstructured) (map[string][]string, error) { if a.celMapper.Extra == nil { return nil, nil } diff --git a/plugin/pkg/authenticator/token/oidc/oidc_test.go b/plugin/pkg/authenticator/token/oidc/oidc_test.go index 3b246df4d..d4db3faa3 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc_test.go +++ b/plugin/pkg/authenticator/token/oidc/oidc_test.go @@ -145,6 +145,7 @@ type claimsTest struct { wantSkip bool wantErr string wantInitErr string + wantHealthErrPrefix string claimToResponseMap map[string]string openIDConfig string fetchKeysFromRemote bool @@ -283,8 +284,10 @@ func (c *claimsTest) run(t *testing.T) { expectInitErr := len(c.wantInitErr) > 0 + ctx := testContext(t) + // Initialize the authenticator. - a, err := New(c.options) + a, err := New(ctx, c.options) if err != nil { if !expectInitErr { t.Fatalf("initialize authenticator: %v", err) @@ -298,6 +301,25 @@ func (c *claimsTest) run(t *testing.T) { t.Fatalf("wanted initialization error %q but got none", c.wantInitErr) } + if len(c.wantHealthErrPrefix) > 0 { + if err := wait.PollUntilContextTimeout(ctx, time.Second, time.Minute, true, func(context.Context) (bool, error) { + healthErr := a.HealthCheck() + if healthErr == nil { + return false, fmt.Errorf("authenticator reported healthy when it should not") + } + + if strings.HasPrefix(healthErr.Error(), c.wantHealthErrPrefix) { + return true, nil + } + + t.Logf("saw health error prefix that did not match: want=%q got=%q", c.wantHealthErrPrefix, healthErr.Error()) + return false, nil + }); err != nil { + t.Fatalf("authenticator did not match wanted health error: %v", err) + } + return + } + claims := struct{}{} if err := json.Unmarshal([]byte(c.claims), &claims); err != nil { t.Fatalf("failed to unmarshal claims: %v", err) @@ -313,21 +335,9 @@ func (c *claimsTest) run(t *testing.T) { t.Fatalf("serialize token: %v", err) } - ia, ok := a.(*instrumentedAuthenticator) - if !ok { - t.Fatalf("expected authenticator to be instrumented") - } - authenticator, ok := ia.delegate.(*Authenticator) - if !ok { - t.Fatalf("expected delegate to be Authenticator") - } - ctx := testContext(t) - // wait for the authenticator to be initialized + // wait for the authenticator to be healthy err = wait.PollUntilContextCancel(ctx, time.Millisecond, true, func(context.Context) (bool, error) { - if v, _ := authenticator.idTokenVerifier(); v == nil { - return false, nil - } - return true, nil + return a.HealthCheck() == nil, nil }) if err != nil { t.Fatalf("failed to initialize the authenticator: %v", err) @@ -2060,6 +2070,51 @@ func TestToken(t *testing.T) { }, wantInitErr: "oidc: Client and CAContentProvider are mutually exclusive", }, + { + name: "keyset and discovery URL mutually exclusive", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://auth.example.com", + DiscoveryURL: "https://auth.example.com/foo", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Claim: "username", + Prefix: pointer.String("prefix:"), + }, + }, + }, + SupportedSigningAlgs: []string{"RS256"}, + now: func() time.Time { return now }, + KeySet: &staticKeySet{}, + }, + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + wantInitErr: "oidc: KeySet and DiscoveryURL are mutually exclusive", + }, + { + name: "health check failure", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://this-will-not-work.notatld", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Claim: "username", + Prefix: pointer.String("prefix:"), + }, + }, + }, + SupportedSigningAlgs: []string{"RS256"}, + }, + fetchKeysFromRemote: true, + wantHealthErrPrefix: `oidc: authenticator for issuer "https://this-will-not-work.notatld" is not healthy: Get "https://this-will-not-work.notatld/.well-known/openid-configuration": dial tcp: lookup this-will-not-work.notatld`, + }, { name: "accounts.google.com issuer", options: Options{ @@ -3306,7 +3361,7 @@ func TestToken(t *testing.T) { var successTestCount, failureTestCount int for _, test := range tests { t.Run(test.name, test.run) - if test.wantSkip || test.wantInitErr != "" { + if test.wantSkip || len(test.wantInitErr) > 0 || len(test.wantHealthErrPrefix) > 0 { continue } // check metrics for success and failure