diff --git a/pkg/endpoints/filters/impersonation.go b/pkg/endpoints/filters/impersonation.go index ab767695f..1246ae863 100644 --- a/pkg/endpoints/filters/impersonation.go +++ b/pkg/endpoints/filters/impersonation.go @@ -117,10 +117,37 @@ func WithImpersonation(handler http.Handler, a authorizer.Authorizer, s runtime. } } - if !groupsSpecified && username != user.Anonymous { - // When impersonating a non-anonymous user, if no groups were specified - // include the system:authenticated group in the impersonated user info - groups = append(groups, user.AllAuthenticated) + if username != user.Anonymous { + // When impersonating a non-anonymous user, include the 'system:authenticated' group + // in the impersonated user info: + // - if no groups were specified + // - if a group has been specified other than 'system:authenticated' + // + // If 'system:unauthenticated' group has been specified we should not include + // the 'system:authenticated' group. + addAuthenticated := true + for _, group := range groups { + if group == user.AllAuthenticated || group == user.AllUnauthenticated { + addAuthenticated = false + break + } + } + + if addAuthenticated { + groups = append(groups, user.AllAuthenticated) + } + } else { + addUnauthenticated := true + for _, group := range groups { + if group == user.AllUnauthenticated { + addUnauthenticated = false + break + } + } + + if addUnauthenticated { + groups = append(groups, user.AllUnauthenticated) + } } newUser := &user.DefaultInfo{ diff --git a/pkg/endpoints/filters/impersonation_test.go b/pkg/endpoints/filters/impersonation_test.go index 1408d2b1c..376182b58 100644 --- a/pkg/endpoints/filters/impersonation_test.go +++ b/pkg/endpoints/filters/impersonation_test.go @@ -163,7 +163,7 @@ func TestImpersonationFilter(t *testing.T) { impersonationGroups: []string{"some-group"}, expectedUser: &user.DefaultInfo{ Name: "system:admin", - Groups: []string{"some-group"}, + Groups: []string{"some-group", "system:authenticated"}, Extra: map[string][]string{}, }, expectedCode: http.StatusOK, @@ -308,7 +308,7 @@ func TestImpersonationFilter(t *testing.T) { impersonationUser: "system:anonymous", expectedUser: &user.DefaultInfo{ Name: "system:anonymous", - Groups: []string{}, + Groups: []string{"system:unauthenticated"}, Extra: map[string][]string{}, }, expectedCode: http.StatusOK, @@ -341,6 +341,48 @@ func TestImpersonationFilter(t *testing.T) { }, expectedCode: http.StatusOK, }, + { + name: "specified-authenticated-group-prevents-double-adding-authenticated-group", + user: &user.DefaultInfo{ + Name: "dev", + Groups: []string{"wheel", "group-impersonater"}, + }, + impersonationUser: "system:admin", + impersonationGroups: []string{"some-group", "system:authenticated"}, + expectedUser: &user.DefaultInfo{ + Name: "system:admin", + Groups: []string{"some-group", "system:authenticated"}, + Extra: map[string][]string{}, + }, + expectedCode: http.StatusOK, + }, + { + name: "anonymous-user-should-include-unauthenticated-group", + user: &user.DefaultInfo{ + Name: "system:admin", + }, + impersonationUser: "system:anonymous", + expectedUser: &user.DefaultInfo{ + Name: "system:anonymous", + Groups: []string{"system:unauthenticated"}, + Extra: map[string][]string{}, + }, + expectedCode: http.StatusOK, + }, + { + name: "anonymous-user-prevents-double-adding-unauthenticated-group", + user: &user.DefaultInfo{ + Name: "system:admin", + }, + impersonationUser: "system:anonymous", + impersonationGroups: []string{"system:unauthenticated"}, + expectedUser: &user.DefaultInfo{ + Name: "system:anonymous", + Groups: []string{"system:unauthenticated"}, + Extra: map[string][]string{}, + }, + expectedCode: http.StatusOK, + }, } var ctx context.Context @@ -398,42 +440,44 @@ func TestImpersonationFilter(t *testing.T) { defer server.Close() for _, tc := range testCases { - func() { - lock.Lock() - defer lock.Unlock() - ctx = request.WithUser(request.NewContext(), tc.user) - }() + t.Run(tc.name, func(t *testing.T) { + func() { + lock.Lock() + defer lock.Unlock() + ctx = request.WithUser(request.NewContext(), tc.user) + }() - req, err := http.NewRequest("GET", server.URL, nil) - if err != nil { - t.Errorf("%s: unexpected error: %v", tc.name, err) - continue - } - if len(tc.impersonationUser) > 0 { - req.Header.Add(authenticationapi.ImpersonateUserHeader, tc.impersonationUser) - } - for _, group := range tc.impersonationGroups { - req.Header.Add(authenticationapi.ImpersonateGroupHeader, group) - } - for extraKey, values := range tc.impersonationUserExtras { - for _, value := range values { - req.Header.Add(authenticationapi.ImpersonateUserExtraHeaderPrefix+extraKey, value) + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Errorf("%s: unexpected error: %v", tc.name, err) + return + } + if len(tc.impersonationUser) > 0 { + req.Header.Add(authenticationapi.ImpersonateUserHeader, tc.impersonationUser) + } + for _, group := range tc.impersonationGroups { + req.Header.Add(authenticationapi.ImpersonateGroupHeader, group) + } + for extraKey, values := range tc.impersonationUserExtras { + for _, value := range values { + req.Header.Add(authenticationapi.ImpersonateUserExtraHeaderPrefix+extraKey, value) + } } - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Errorf("%s: unexpected error: %v", tc.name, err) - continue - } - if resp.StatusCode != tc.expectedCode { - t.Errorf("%s: expected %v, actual %v", tc.name, tc.expectedCode, resp.StatusCode) - continue - } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("%s: unexpected error: %v", tc.name, err) + return + } + if resp.StatusCode != tc.expectedCode { + t.Errorf("%s: expected %v, actual %v", tc.name, tc.expectedCode, resp.StatusCode) + return + } - if !reflect.DeepEqual(actualUser, tc.expectedUser) { - t.Errorf("%s: expected %#v, actual %#v", tc.name, tc.expectedUser, actualUser) - continue - } + if !reflect.DeepEqual(actualUser, tc.expectedUser) { + t.Errorf("%s: expected %#v, actual %#v", tc.name, tc.expectedUser, actualUser) + return + } + }) } }