diff --git a/pkg/authentication/serviceaccount/util.go b/pkg/authentication/serviceaccount/util.go index 949b60922..dd11efbde 100644 --- a/pkg/authentication/serviceaccount/util.go +++ b/pkg/authentication/serviceaccount/util.go @@ -163,15 +163,6 @@ func (sa *ServiceAccountInfo) UserInfo() user.Info { return info } -// CredentialIDForJTI converts a given JTI string into a credential identifier for use in a -// users 'extra' info. -func CredentialIDForJTI(jti string) string { - if len(jti) == 0 { - return "" - } - return "JTI=" + jti -} - // IsServiceAccountToken returns true if the secret is a valid api token for the service account func IsServiceAccountToken(secret *v1.Secret, sa *v1.ServiceAccount) bool { if secret.Type != v1.SecretTypeServiceAccountToken { diff --git a/pkg/authentication/token/jwt/jwt.go b/pkg/authentication/token/jwt/jwt.go new file mode 100644 index 000000000..17b384949 --- /dev/null +++ b/pkg/authentication/token/jwt/jwt.go @@ -0,0 +1,26 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jwt + +// CredentialIDForJTI converts a given JTI string into a credential identifier for use in a +// users 'extra' info. +func CredentialIDForJTI(jti string) string { + if len(jti) == 0 { + return "" + } + return "JTI=" + jti +} diff --git a/plugin/pkg/authenticator/token/oidc/oidc.go b/plugin/pkg/authenticator/token/oidc/oidc.go index bdf50d89c..226ea22de 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc.go +++ b/plugin/pkg/authenticator/token/oidc/oidc.go @@ -56,6 +56,7 @@ import ( apiservervalidation "k8s.io/apiserver/pkg/apis/apiserver/validation" "k8s.io/apiserver/pkg/authentication/authenticator" authenticationcel "k8s.io/apiserver/pkg/authentication/cel" + authenticationtokenjwt "k8s.io/apiserver/pkg/authentication/token/jwt" "k8s.io/apiserver/pkg/authentication/user" certutil "k8s.io/client-go/util/cert" "k8s.io/klog/v2" @@ -726,7 +727,7 @@ func (a *jwtAuthenticator) AuthenticateToken(ctx context.Context, token string) return nil, false, err } - extra, err := a.getExtra(ctx, claimsUnstructured) + extra, err := a.getExtra(ctx, c, claimsUnstructured) if err != nil { return nil, false, err } @@ -914,17 +915,21 @@ func (a *jwtAuthenticator) getUID(ctx context.Context, c claims, claimsUnstructu return evalResult.EvalResult.Value().(string), nil } -func (a *jwtAuthenticator) getExtra(ctx context.Context, claimsUnstructured *unstructured.Unstructured) (map[string][]string, error) { +func (a *jwtAuthenticator) getExtra(ctx context.Context, c claims, claimsUnstructured *unstructured.Unstructured) (map[string][]string, error) { + extra := make(map[string][]string) + + if credentialID := getCredentialID(c); len(credentialID) > 0 { + extra[user.CredentialIDKey] = []string{credentialID} + } + if a.celMapper.Extra == nil { - return nil, nil + return extra, nil } evalResult, err := a.celMapper.Extra.EvalClaimMappings(ctx, claimsUnstructured) if err != nil { return nil, err } - - extra := make(map[string][]string, len(evalResult)) for _, result := range evalResult { extraMapping, ok := result.ExpressionAccessor.(*authenticationcel.ExtraMappingExpression) if !ok { @@ -936,16 +941,25 @@ func (a *jwtAuthenticator) getExtra(ctx context.Context, claimsUnstructured *uns return nil, fmt.Errorf("oidc: error evaluating extra claim expression: %s: %w", extraMapping.Expression, err) } - if len(extraValues) == 0 { - continue + if len(extraValues) > 0 { + extra[extraMapping.Key] = extraValues } - - extra[extraMapping.Key] = extraValues } return extra, nil } +func getCredentialID(c claims) string { + if _, ok := c["jti"]; ok { + var jti string + if err := c.unmarshalClaim("jti", &jti); err == nil { + return authenticationtokenjwt.CredentialIDForJTI(jti) + } + } + + return "" +} + // getClaimJWT gets a distributed claim JWT from url, using the supplied access // token as bearer token. If the access token is "", the authorization header // will not be set. diff --git a/plugin/pkg/authenticator/token/oidc/oidc_test.go b/plugin/pkg/authenticator/token/oidc/oidc_test.go index 986d79b10..b8ff716c3 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc_test.go +++ b/plugin/pkg/authenticator/token/oidc/oidc_test.go @@ -3356,6 +3356,152 @@ func TestToken(t *testing.T) { Name: "jane", }, }, + { + name: "credential id set in extra even when no extra claim mappings are defined", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://auth.example.com", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Expression: "claims.username", + }, + }, + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "exp": %d, + "jti": "1234" + }`, valid.Unix()), + want: &user.DefaultInfo{ + Name: "jane", + Extra: map[string][]string{ + user.CredentialIDKey: {"JTI=1234"}, + }, + }, + }, + { + name: "credential id set in extra when extra claim mappings are defined", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://auth.example.com", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Expression: "claims.username", + }, + Extra: []apiserver.ExtraMapping{ + { + Key: "example.org/foo", + ValueExpression: "claims.foo", + }, + { + Key: "example.org/bar", + ValueExpression: "claims.bar", + }, + }, + }, + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "exp": %d, + "jti": "1234", + "foo": "bar", + "bar": [ + "baz", + "qux" + ] + }`, valid.Unix()), + want: &user.DefaultInfo{ + Name: "jane", + Extra: map[string][]string{ + user.CredentialIDKey: {"JTI=1234"}, + "example.org/foo": {"bar"}, + "example.org/bar": {"baz", "qux"}, + }, + }, + }, + { + name: "non-string jti claim does not set credential id in extra or error", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://auth.example.com", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Expression: "claims.username", + }, + }, + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "exp": %d, + "jti": 1234 + }`, valid.Unix()), + want: &user.DefaultInfo{ + Name: "jane", + }, + }, + { + name: "missing jti claim does not set credential id in extra or error", + options: Options{ + JWTAuthenticator: apiserver.JWTAuthenticator{ + Issuer: apiserver.Issuer{ + URL: "https://auth.example.com", + Audiences: []string{"my-client"}, + }, + ClaimMappings: apiserver.ClaimMappings{ + Username: apiserver.PrefixedClaimOrExpression{ + Expression: "claims.username", + }, + }, + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "exp": %d + }`, valid.Unix()), + want: &user.DefaultInfo{ + Name: "jane", + }, + }, } var successTestCount, failureTestCount int