[release-1.10] Backport OAuth middlewares fixes into 1.10 [DO NOT SQUASH] (#2631)

This commit is contained in:
Alessandro (Ale) Segala 2023-03-04 01:20:51 +00:00 committed by GitHub
parent 1c48453fef
commit c264ba5df9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 8 deletions

View File

@ -43,7 +43,9 @@ type oAuth2MiddlewareMetadata struct {
// NewOAuth2Middleware returns a new oAuth2 middleware.
func NewOAuth2Middleware(log logger.Logger) middleware.Middleware {
return &Middleware{logger: log}
m := &Middleware{logger: log}
return m
}
// Middleware is an oAuth2 authentication middleware.
@ -82,11 +84,12 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
session := sessions.Start(w, r)
if session.GetString(meta.AuthHeaderName) != "" {
w.Header().Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
r.Header.Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
next.ServeHTTP(w, r)
return
}
// Redirect to the auth server
state := r.URL.Query().Get(stateParam)
if state == "" {
id, err := uuid.NewRandom()
@ -135,7 +138,6 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
authHeader := token.Type() + " " + token.AccessToken
session.Set(meta.AuthHeaderName, authHeader)
w.Header().Add(meta.AuthHeaderName, authHeader)
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
}
})

View File

@ -0,0 +1,63 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package oauth2
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/fasthttp-contrib/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger"
)
func TestOAuth2CreatesAuthorizationHeaderWhenInSessionState(t *testing.T) {
var metadata middleware.Metadata
metadata.Properties = map[string]string{
"clientID": "testId",
"clientSecret": "testSecret",
"scopes": "ascope",
"authURL": "https://idp:9999",
"tokenURL": "https://idp:9999",
"redirectUrl": "https://localhost:9999",
"authHeaderName": "someHeader",
}
log := logger.NewLogger("oauth2.test")
handler, err := NewOAuth2Middleware(log).GetHandler(metadata)
require.NoError(t, err)
// Create request and recorder
r := httptest.NewRequest(http.MethodGet, "http://dapr.io", nil)
w := httptest.NewRecorder()
session := sessions.Start(w, r)
session.Set("someHeader", "Bearer abcd")
// Copy the session cookie to the request
cookie := w.Header().Get("Set-Cookie")
r.Header.Add("Cookie", cookie)
handler(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("from mock"))
}),
).ServeHTTP(w, r)
assert.Equal(t, "Bearer abcd", r.Header.Get("someHeader"))
}

View File

@ -117,7 +117,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
headerValue = cachedToken.(string)
}
w.Header().Add(meta.HeaderName, headerValue)
r.Header.Add(meta.HeaderName, headerValue)
next.ServeHTTP(w, r)
})
}, nil

View File

@ -118,7 +118,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abcd", w.Header().Get("someHeader"))
assert.Equal(t, "Bearer abcd", r.Header.Get("someHeader"))
}
// TestOAuth2ClientCredentialsCache will check
@ -178,7 +178,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
assert.Equal(t, "Bearer abc", r.Header.Get("someHeader"))
// Second handler call should still return 'cached' abc Token
r = httptest.NewRequest(http.MethodGet, "http://dapr.io", nil)
@ -186,7 +186,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "Bearer abc", w.Header().Get("someHeader"))
assert.Equal(t, "Bearer abc", r.Header.Get("someHeader"))
// Wait at a second to invalidate cache entry for abc
time.Sleep(1 * time.Second)
@ -197,5 +197,5 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
// Assertion
assert.Equal(t, "MAC def", w.Header().Get("someHeader"))
assert.Equal(t, "MAC def", r.Header.Get("someHeader"))
}