diff --git a/pkg/util/flag/map_string_string.go b/pkg/util/flag/map_string_string.go index 48468a8dd..00c550b04 100644 --- a/pkg/util/flag/map_string_string.go +++ b/pkg/util/flag/map_string_string.go @@ -23,11 +23,14 @@ import ( ) // MapStringString can be set from the command line with the format `--flag "string=string"`. -// Multiple comma-separated key-value pairs in a single invocation are supported. For example: `--flag "a=foo,b=bar"`. -// Multiple flag invocations are supported. For example: `--flag "a=foo" --flag "b=bar"`. +// Multiple flag invocations are supported. For example: `--flag "a=foo" --flag "b=bar"`. If this is desired +// to be the only type invocation `NoSplit` should be set to true. +// Multiple comma-separated key-value pairs in a single invocation are supported if `NoSplit` +// is set to false. For example: `--flag "a=foo,b=bar"`. type MapStringString struct { Map *map[string]string initialized bool + NoSplit bool } // NewMapStringString takes a pointer to a map[string]string and returns the @@ -36,6 +39,15 @@ func NewMapStringString(m *map[string]string) *MapStringString { return &MapStringString{Map: m} } +// NewMapStringString takes a pointer to a map[string]string and sets `NoSplit` +// value to `true` and returns the MapStringString flag parsing shim for that map +func NewMapStringStringNoSplit(m *map[string]string) *MapStringString { + return &MapStringString{ + Map: m, + NoSplit: true, + } +} + // String implements github.com/spf13/pflag.Value func (m *MapStringString) String() string { pairs := []string{} @@ -56,19 +68,34 @@ func (m *MapStringString) Set(value string) error { *m.Map = make(map[string]string) m.initialized = true } - for _, s := range strings.Split(value, ",") { - if len(s) == 0 { - continue + + // account for comma-separated key-value pairs in a single invocation + if !m.NoSplit { + for _, s := range strings.Split(value, ",") { + if len(s) == 0 { + continue + } + arr := strings.SplitN(s, "=", 2) + if len(arr) != 2 { + return fmt.Errorf("malformed pair, expect string=string") + } + k := strings.TrimSpace(arr[0]) + v := strings.TrimSpace(arr[1]) + (*m.Map)[k] = v } - arr := strings.SplitN(s, "=", 2) - if len(arr) != 2 { - return fmt.Errorf("malformed pair, expect string=string") - } - k := strings.TrimSpace(arr[0]) - v := strings.TrimSpace(arr[1]) - (*m.Map)[k] = v + return nil } + + // account for only one key-value pair in a single invocation + arr := strings.SplitN(value, "=", 2) + if len(arr) != 2 { + return fmt.Errorf("malformed pair, expect string=string") + } + k := strings.TrimSpace(arr[0]) + v := strings.TrimSpace(arr[1]) + (*m.Map)[k] = v return nil + } // Type implements github.com/spf13/pflag.Value diff --git a/pkg/util/flag/map_string_string_test.go b/pkg/util/flag/map_string_string_test.go index 8feb62524..aba984e7a 100644 --- a/pkg/util/flag/map_string_string_test.go +++ b/pkg/util/flag/map_string_string_test.go @@ -58,6 +58,7 @@ func TestSetMapStringString(t *testing.T) { &MapStringString{ initialized: true, Map: &map[string]string{}, + NoSplit: false, }, ""}, // make sure we still allocate for "initialized" maps where Map was initially set to a nil map {"allocates map if currently nil", []string{""}, @@ -65,6 +66,7 @@ func TestSetMapStringString(t *testing.T) { &MapStringString{ initialized: true, Map: &map[string]string{}, + NoSplit: false, }, ""}, // for most cases, we just reuse nilMap, which should be allocated by Set, and is reset before each test case {"empty", []string{""}, @@ -72,36 +74,56 @@ func TestSetMapStringString(t *testing.T) { &MapStringString{ initialized: true, Map: &map[string]string{}, + NoSplit: false, }, ""}, {"one key", []string{"one=foo"}, NewMapStringString(&nilMap), &MapStringString{ initialized: true, Map: &map[string]string{"one": "foo"}, + NoSplit: false, }, ""}, {"two keys", []string{"one=foo,two=bar"}, NewMapStringString(&nilMap), &MapStringString{ initialized: true, Map: &map[string]string{"one": "foo", "two": "bar"}, + NoSplit: false, + }, ""}, + {"one key, multi flag invocation only", []string{"one=foo,bar"}, + NewMapStringStringNoSplit(&nilMap), + &MapStringString{ + initialized: true, + Map: &map[string]string{"one": "foo,bar"}, + NoSplit: true, + }, ""}, + {"two keys, multi flag invocation only", []string{"one=foo,bar", "two=foo,bar"}, + NewMapStringStringNoSplit(&nilMap), + &MapStringString{ + initialized: true, + Map: &map[string]string{"one": "foo,bar", "two": "foo,bar"}, + NoSplit: true, }, ""}, {"two keys, multiple Set invocations", []string{"one=foo", "two=bar"}, NewMapStringString(&nilMap), &MapStringString{ initialized: true, Map: &map[string]string{"one": "foo", "two": "bar"}, + NoSplit: false, }, ""}, {"two keys with space", []string{"one=foo, two=bar"}, NewMapStringString(&nilMap), &MapStringString{ initialized: true, Map: &map[string]string{"one": "foo", "two": "bar"}, + NoSplit: false, }, ""}, {"empty key", []string{"=foo"}, NewMapStringString(&nilMap), &MapStringString{ initialized: true, Map: &map[string]string{"": "foo"}, + NoSplit: false, }, ""}, {"missing value", []string{"one"}, NewMapStringString(&nilMap), diff --git a/plugin/pkg/authenticator/token/oidc/oidc.go b/plugin/pkg/authenticator/token/oidc/oidc.go index 8a1030fd5..94c5d8b22 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc.go +++ b/plugin/pkg/authenticator/token/oidc/oidc.go @@ -98,6 +98,10 @@ type Options struct { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation SupportedSigningAlgs []string + // RequiredClaims, if specified, causes the OIDCAuthenticator to verify that all the + // required claims key value pairs are present in the ID Token. + RequiredClaims map[string]string + // now is used for testing. It defaults to time.Now. now func() time.Time } @@ -109,6 +113,7 @@ type Authenticator struct { usernamePrefix string groupsClaim string groupsPrefix string + requiredClaims map[string]string // Contains an *oidc.IDTokenVerifier. Do not access directly use the // idTokenVerifier method. @@ -218,6 +223,7 @@ func newAuthenticator(opts Options, initVerifier func(ctx context.Context, a *Au usernamePrefix: opts.UsernamePrefix, groupsClaim: opts.GroupsClaim, groupsPrefix: opts.GroupsPrefix, + requiredClaims: opts.RequiredClaims, cancel: cancel, } @@ -323,6 +329,23 @@ func (a *Authenticator) AuthenticateToken(token string) (user.Info, bool, error) info.Groups[i] = a.groupsPrefix + group } } + + // check to ensure all required claims are present in the ID token and have matching values. + for claim, value := range a.requiredClaims { + if !c.hasClaim(claim) { + return nil, false, fmt.Errorf("oidc: required claim %s not present in ID token", claim) + } + + // NOTE: Only string values are supported as valid required claim values. + var claimValue string + if err := c.unmarshalClaim(claim, &claimValue); err != nil { + return nil, false, fmt.Errorf("oidc: parse claim %s: %v", claim, err) + } + if claimValue != value { + return nil, false, fmt.Errorf("oidc: required claim %s value does not match. Got = %s, want = %s", claim, claimValue, value) + } + } + return info, true, nil } diff --git a/plugin/pkg/authenticator/token/oidc/oidc_test.go b/plugin/pkg/authenticator/token/oidc/oidc_test.go index 1e1248a67..b74c45ab1 100644 --- a/plugin/pkg/authenticator/token/oidc/oidc_test.go +++ b/plugin/pkg/authenticator/token/oidc/oidc_test.go @@ -428,6 +428,84 @@ func TestToken(t *testing.T) { }`, valid.Unix()), wantErr: true, }, + { + name: "required-claim", + options: Options{ + IssuerURL: "https://auth.example.com", + ClientID: "my-client", + UsernameClaim: "username", + GroupsClaim: "groups", + RequiredClaims: map[string]string{ + "hd": "example.com", + "sub": "test", + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "hd": "example.com", + "sub": "test", + "exp": %d + }`, valid.Unix()), + want: &user.DefaultInfo{ + Name: "jane", + }, + }, + { + name: "no-required-claim", + options: Options{ + IssuerURL: "https://auth.example.com", + ClientID: "my-client", + UsernameClaim: "username", + GroupsClaim: "groups", + RequiredClaims: map[string]string{ + "hd": "example.com", + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "exp": %d + }`, valid.Unix()), + wantErr: true, + }, + { + name: "invalid-required-claim", + options: Options{ + IssuerURL: "https://auth.example.com", + ClientID: "my-client", + UsernameClaim: "username", + GroupsClaim: "groups", + RequiredClaims: map[string]string{ + "hd": "example.com", + }, + now: func() time.Time { return now }, + }, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://auth.example.com", + "aud": "my-client", + "username": "jane", + "hd": "example.org", + "exp": %d + }`, valid.Unix()), + wantErr: true, + }, { name: "invalid-signature", options: Options{