Updated OAuth2 middleware
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
fb76c277af
commit
038bcf4666
|
@ -13,7 +13,21 @@ func RespondWithError(w http.ResponseWriter, statusCode int) {
|
|||
statusCode = http.StatusInternalServerError
|
||||
statusText = http.StatusText(statusCode)
|
||||
}
|
||||
RespondWithErrorAndMessage(w, statusCode, statusText)
|
||||
}
|
||||
|
||||
// RespondWithErrorAndMessage responds to a http.ResponseWriter with an error status code.
|
||||
// The message is included in the body as response.
|
||||
// This method should be invoked before calling w.WriteHeader, and callers should abort the request after calling this method.
|
||||
func RespondWithErrorAndMessage(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("content-type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(statusCode)
|
||||
w.Write([]byte(statusText))
|
||||
w.Write([]byte(message))
|
||||
}
|
||||
|
||||
// RespondWithRedirect responds to a http.ResponseWriter with a redirect.
|
||||
// This method should be invoked before calling w.WriteHeader, and callers should abort the request after calling this method.
|
||||
func RespondWithRedirect(w http.ResponseWriter, statusCode int, location string) {
|
||||
w.Header().Set("location", location)
|
||||
w.WriteHeader(statusCode)
|
||||
}
|
||||
|
|
|
@ -32,13 +32,12 @@ type bearerMiddlewareMetadata struct {
|
|||
}
|
||||
|
||||
// NewBearerMiddleware returns a new oAuth2 middleware.
|
||||
func NewBearerMiddleware(logger logger.Logger) middleware.Middleware {
|
||||
return &Middleware{logger: logger}
|
||||
func NewBearerMiddleware(_ logger.Logger) middleware.Middleware {
|
||||
return &Middleware{}
|
||||
}
|
||||
|
||||
// Middleware is an oAuth2 authentication middleware.
|
||||
type Middleware struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
@ -14,17 +14,19 @@ limitations under the License.
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp-contrib/sessions"
|
||||
"github.com/google/uuid"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/dapr/components-contrib/internal/httputils"
|
||||
"github.com/dapr/components-contrib/internal/utils"
|
||||
mdutils "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/middleware"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
// Metadata is the oAuth middleware config.
|
||||
|
@ -40,19 +42,20 @@ type oAuth2MiddlewareMetadata struct {
|
|||
}
|
||||
|
||||
// NewOAuth2Middleware returns a new oAuth2 middleware.
|
||||
func NewOAuth2Middleware() middleware.Middleware {
|
||||
return &Middleware{}
|
||||
func NewOAuth2Middleware(log logger.Logger) middleware.Middleware {
|
||||
return &Middleware{logger: log}
|
||||
}
|
||||
|
||||
// Middleware is an oAuth2 authentication middleware.
|
||||
type Middleware struct{}
|
||||
type Middleware struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
const (
|
||||
stateParam = "state"
|
||||
savedState = "auth-state"
|
||||
redirectPath = "redirect-url"
|
||||
codeParam = "code"
|
||||
https = "https://"
|
||||
)
|
||||
|
||||
// GetHandler retruns the HTTP handler provided by the middleware.
|
||||
|
@ -62,55 +65,78 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
|
|||
return nil, err
|
||||
}
|
||||
|
||||
forceHTTPS := utils.IsTruthy(meta.ForceHTTPS)
|
||||
conf := &oauth2.Config{
|
||||
ClientID: meta.ClientID,
|
||||
ClientSecret: meta.ClientSecret,
|
||||
Scopes: strings.Split(meta.Scopes, ","),
|
||||
RedirectURL: meta.RedirectURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: meta.AuthURL,
|
||||
TokenURL: meta.TokenURL,
|
||||
},
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conf := &oauth2.Config{
|
||||
ClientID: meta.ClientID,
|
||||
ClientSecret: meta.ClientSecret,
|
||||
Scopes: strings.Split(meta.Scopes, ","),
|
||||
RedirectURL: meta.RedirectURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: meta.AuthURL,
|
||||
TokenURL: meta.TokenURL,
|
||||
},
|
||||
}
|
||||
session := sessions.Start(w, r)
|
||||
|
||||
if session.GetString(meta.AuthHeaderName) != "" {
|
||||
w.Header().Set(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName))
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
state := string(ctx.FormValue(stateParam))
|
||||
//nolint:nestif
|
||||
|
||||
state := r.URL.Query().Get(stateParam)
|
||||
if state == "" {
|
||||
id, _ := uuid.NewUUID()
|
||||
session.Set(savedState, id.String())
|
||||
session.Set(redirectPath, string(ctx.RequestURI()))
|
||||
url := conf.AuthCodeURL(id.String(), oauth2.AccessTypeOffline)
|
||||
ctx.Redirect(url, 302)
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
httputils.RespondWithError(w, http.StatusInternalServerError)
|
||||
m.logger.Errorf("Failed to generate UUID: %v", err)
|
||||
return
|
||||
}
|
||||
idStr := id.String()
|
||||
|
||||
session.Set(savedState, idStr)
|
||||
session.Set(redirectPath, r.URL)
|
||||
|
||||
url := conf.AuthCodeURL(idStr, oauth2.AccessTypeOffline)
|
||||
httputils.RespondWithRedirect(w, http.StatusFound, url)
|
||||
} else {
|
||||
authState := session.GetString(savedState)
|
||||
redirectURL := session.GetString(redirectPath)
|
||||
if strings.EqualFold(meta.ForceHTTPS, "true") {
|
||||
redirectURL = https + string(ctx.Request.Host()) + redirectURL
|
||||
redirectURL, ok := session.Get(redirectPath).(*url.URL)
|
||||
if !ok {
|
||||
httputils.RespondWithError(w, http.StatusInternalServerError)
|
||||
m.logger.Errorf("Value saved in state key '%s' is not a *url.URL")
|
||||
return
|
||||
}
|
||||
|
||||
if forceHTTPS {
|
||||
redirectURL.Scheme = "https"
|
||||
}
|
||||
|
||||
if state != authState {
|
||||
ctx.Error("invalid state", fasthttp.StatusBadRequest)
|
||||
} else {
|
||||
code := string(ctx.FormValue(codeParam))
|
||||
if code == "" {
|
||||
ctx.Error("code not found", fasthttp.StatusBadRequest)
|
||||
} else {
|
||||
token, err := conf.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
|
||||
}
|
||||
session.Set(meta.AuthHeaderName, token.Type()+" "+token.AccessToken)
|
||||
ctx.Request.Header.Add(meta.AuthHeaderName, token.Type()+" "+token.AccessToken)
|
||||
ctx.Redirect(redirectURL, 302)
|
||||
}
|
||||
httputils.RespondWithErrorAndMessage(w, http.StatusBadRequest, "invalid state")
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get(codeParam)
|
||||
if code == "" {
|
||||
httputils.RespondWithErrorAndMessage(w, http.StatusBadRequest, "code not found")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
httputils.RespondWithError(w, http.StatusInternalServerError)
|
||||
m.logger.Errorf("Failed to exchange token: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := token.Type() + " " + token.AccessToken
|
||||
session.Set(meta.AuthHeaderName, authHeader)
|
||||
w.Header().Set(meta.AuthHeaderName, authHeader)
|
||||
httputils.RespondWithRedirect(w, http.StatusFound, redirectURL.String())
|
||||
}
|
||||
})
|
||||
}, nil
|
||||
|
|
Loading…
Reference in New Issue