diff --git a/middleware/http/oauth2/oauth2_middleware.go b/middleware/http/oauth2/oauth2_middleware.go index 3183dc381..c6a6eb209 100644 --- a/middleware/http/oauth2/oauth2_middleware.go +++ b/middleware/http/oauth2/oauth2_middleware.go @@ -73,17 +73,22 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R TokenURL: meta.TokenURL, }, } + session := sessions.StartFasthttp(ctx) if session.GetString(meta.AuthHeaderName) != "" { ctx.Request.Header.Add(meta.AuthHeaderName, session.GetString(meta.AuthHeaderName)) h(ctx) - return } + state := string(ctx.FormValue(stateParam)) //nolint:nestif if state == "" { - id, _ := uuid.NewUUID() + id, err := uuid.NewRandom() + if err != nil { + ctx.Error(fasthttp.StatusMessage(fasthttp.StatusInternalServerError), fasthttp.StatusInternalServerError) + return + } session.Set(savedState, id.String()) session.Set(redirectPath, string(ctx.RequestURI())) url := conf.AuthCodeURL(id.String(), oauth2.AccessTypeOffline) @@ -100,15 +105,17 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(h fasthttp.R code := string(ctx.FormValue(codeParam)) if code == "" { ctx.Error("code not found", fasthttp.StatusBadRequest) - } else { - token, err := conf.Exchange(context.Background(), code) - if err != nil { - ctx.Error(err.Error(), fasthttp.StatusInternalServerError) - } - session.Set(meta.AuthHeaderName, token.Type()+" "+token.AccessToken) - ctx.Request.Header.Add(meta.AuthHeaderName, token.Type()+" "+token.AccessToken) - ctx.Redirect(redirectURL, 302) + return } + + token, err := conf.Exchange(context.Background(), code) + if err != nil { + ctx.Error(err.Error(), fasthttp.StatusInternalServerError) + return + } + session.Set(meta.AuthHeaderName, token.Type()+" "+token.AccessToken) + ctx.Request.Header.Add(meta.AuthHeaderName, token.Type()+" "+token.AccessToken) + ctx.Redirect(redirectURL, 302) } } }