Updated OAuth2 middleware

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-09-29 00:26:47 +00:00
parent fb76c277af
commit 038bcf4666
3 changed files with 84 additions and 45 deletions

View File

@ -13,7 +13,21 @@ func RespondWithError(w http.ResponseWriter, statusCode int) {
statusCode = http.StatusInternalServerError
statusText = http.StatusText(statusCode)
}
RespondWithErrorAndMessage(w, statusCode, statusText)
}
// RespondWithErrorAndMessage responds to a http.ResponseWriter with an error status code.
// The message is included in the body as response.
// This method should be invoked before calling w.WriteHeader, and callers should abort the request after calling this method.
func RespondWithErrorAndMessage(w http.ResponseWriter, statusCode int, message string) {
w.Header().Set("content-type", "text/plain; charset=utf-8")
w.WriteHeader(statusCode)
w.Write([]byte(statusText))
w.Write([]byte(message))
}
// RespondWithRedirect responds to a http.ResponseWriter with a redirect.
// This method should be invoked before calling w.WriteHeader, and callers should abort the request after calling this method.
func RespondWithRedirect(w http.ResponseWriter, statusCode int, location string) {
w.Header().Set("location", location)
w.WriteHeader(statusCode)
}

View File

@ -32,13 +32,12 @@ type bearerMiddlewareMetadata struct {
}
// NewBearerMiddleware returns a new oAuth2 middleware.
func NewBearerMiddleware(logger logger.Logger) middleware.Middleware {
return &Middleware{logger: logger}
func NewBearerMiddleware(_ logger.Logger) middleware.Middleware {
return &Middleware{}
}
// Middleware is an oAuth2 authentication middleware.
type Middleware struct {
logger logger.Logger
}
const (

View File

@ -14,17 +14,19 @@ limitations under the License.
package oauth2
import (
"context"
"net/http"
"net/url"
"strings"
"github.com/fasthttp-contrib/sessions"
"github.com/google/uuid"
"github.com/valyala/fasthttp"
"golang.org/x/oauth2"
"github.com/dapr/components-contrib/internal/httputils"
"github.com/dapr/components-contrib/internal/utils"
mdutils "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/middleware"
"github.com/dapr/kit/logger"
)
// Metadata is the oAuth middleware config.
@ -40,19 +42,20 @@ type oAuth2MiddlewareMetadata struct {
}
// NewOAuth2Middleware returns a new oAuth2 middleware.
func NewOAuth2Middleware() middleware.Middleware {
return &Middleware{}
func NewOAuth2Middleware(log logger.Logger) middleware.Middleware {
return &Middleware{logger: log}
}
// Middleware is an oAuth2 authentication middleware.
type Middleware struct{}
type Middleware struct {
logger logger.Logger
}
const (
stateParam = "state"
savedState = "auth-state"
redirectPath = "redirect-url"
codeParam = "code"
https = "https://"
)
// GetHandler retruns the HTTP handler provided by the middleware.
@ -62,55 +65,78 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
return nil, err
}
forceHTTPS := utils.IsTruthy(meta.ForceHTTPS)
conf := &oauth2.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
RedirectURL: meta.RedirectURL,
Endpoint: oauth2.Endpoint{
AuthURL: meta.AuthURL,
TokenURL: meta.TokenURL,
},
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conf := &oauth2.Config{
ClientID: meta.ClientID,
ClientSecret: meta.ClientSecret,
Scopes: strings.Split(meta.Scopes, ","),
RedirectURL: meta.RedirectURL,
Endpoint: oauth2.Endpoint{
AuthURL: meta.AuthURL,
TokenURL: meta.TokenURL,
},
}
session := sessions.Start(w, r)
if session.GetString(meta.AuthHeaderName) != "" {
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
next.ServeHTTP(w, r)
return
}
state := string(ctx.FormValue(stateParam))
//nolint:nestif
state := r.URL.Query().Get(stateParam)
if state == "" {
id, _ := uuid.NewUUID()
session.Set(savedState, id.String())
session.Set(redirectPath, string(ctx.RequestURI()))
url := conf.AuthCodeURL(id.String(), oauth2.AccessTypeOffline)
ctx.Redirect(url, 302)
id, err := uuid.NewRandom()
if err != nil {
httputils.RespondWithError(w, http.StatusInternalServerError)
m.logger.Errorf("Failed to generate UUID: %v", err)
return
}
idStr := id.String()
session.Set(savedState, idStr)
session.Set(redirectPath, r.URL)
url := conf.AuthCodeURL(idStr, oauth2.AccessTypeOffline)
httputils.RespondWithRedirect(w, http.StatusFound, url)
} else {
authState := session.GetString(savedState)
redirectURL := session.GetString(redirectPath)
if strings.EqualFold(meta.ForceHTTPS, "true") {
redirectURL = https + string(ctx.Request.Host()) + redirectURL
redirectURL, ok := session.Get(redirectPath).(*url.URL)
if !ok {
httputils.RespondWithError(w, http.StatusInternalServerError)
m.logger.Errorf("Value saved in state key '%s' is not a *url.URL")
return
}
if forceHTTPS {
redirectURL.Scheme = "https"
}
if state != authState {
ctx.Error("invalid state", fasthttp.StatusBadRequest)
} else {
code := string(ctx.FormValue(codeParam))
if code == "" {
ctx.Error("code not found", fasthttp.StatusBadRequest)
} else {
token, err := conf.Exchange(context.Background(), code)
if err != nil {
ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
}
session.Set(meta.AuthHeaderName, token.Type()+" "+token.AccessToken)
ctx.Request.Header.Add(meta.AuthHeaderName, token.Type()+" "+token.AccessToken)
ctx.Redirect(redirectURL, 302)
}
httputils.RespondWithErrorAndMessage(w, http.StatusBadRequest, "invalid state")
return
}
code := r.URL.Query().Get(codeParam)
if code == "" {
httputils.RespondWithErrorAndMessage(w, http.StatusBadRequest, "code not found")
return
}
token, err := conf.Exchange(r.Context(), code)
if err != nil {
httputils.RespondWithError(w, http.StatusInternalServerError)
m.logger.Errorf("Failed to exchange token: %v", err)
return
}
authHeader := token.Type() + " " + token.AccessToken
session.Set(meta.AuthHeaderName, authHeader)
w.Header().Set(meta.AuthHeaderName, authHeader)
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
}
})
}, nil