// hub/internal/handlers/user/handlers.go
//
// 968 lines
// 30 KiB
// Go

package user
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/artifacthub/hub/internal/handlers/helpers"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/user"
"github.com/coreos/go-oidc"
"github.com/go-chi/chi"
"github.com/google/go-github/github"
"github.com/gorilla/securecookie"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/satori/uuid"
"github.com/spf13/viper"
pwvalidator "github.com/wagslane/go-password-validator"
"golang.org/x/oauth2"
oagithub "golang.org/x/oauth2/github"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/api/people/v1"
)
const (
// APIKeyIDHeader represents the header used to provide an API key ID.
APIKeyIDHeader = "X-API-KEY-ID"
// APIKeySecretHeader represents the header used to provide an API key
// secret.
APIKeySecretHeader = "X-API-KEY-SECRET" // #nosec
// SessionApprovedHeader represents the header used to indicate to a client
// whether a session is approved or not. When a user has enabled TFA,
// sessions need to be approved by providing a TFA passcode to the session
// validation endpoint. If TFA is not enabled, sessions will be approved on
// creation.
SessionApprovedHeader = "X-SESSION-APPROVED"
// sessionCookieName is the name of the cookie that carries the encoded
// session id.
sessionCookieName = "sid"
// oauthStateCookieName is the name of the cookie that stores the random
// value used to validate the oauth callback request.
oauthStateCookieName = "oas"
// sessionDuration is the lifetime of a session (30 days). It is used as the
// session cookie expiration, as the secure cookie max age and when checking
// if a session is still valid.
sessionDuration = 30 * 24 * time.Hour
// oauthFailedURL is the path users are redirected to when the oauth
// authentication process fails.
oauthFailedURL = "/oauth-failed"
)
var (
// errInvalidAPIKey indicates that the API key provided is not valid. It is
// reported to clients by the RequireLogin middleware.
errInvalidAPIKey = errors.New("invalid api key")
// errInvalidSession indicates that the session provided is not valid, either
// because the session cookie is missing or cannot be decoded, or because
// the session check reported it as not valid.
errInvalidSession = errors.New("invalid session")
)
// Handlers represents a group of http handlers in charge of handling
// users operations.
type Handlers struct {
// userManager provides the users related operations the handlers rely on
// (sessions, credentials, profiles, TFA, etc).
userManager hub.UserManager
// apiKeyManager is used to check the API keys provided in requests.
apiKeyManager hub.APIKeyManager
// cfg holds the server configuration.
cfg *viper.Viper
// sc encodes and decodes the value stored in the session cookie.
sc *securecookie.SecureCookie
// oauthConfig holds the oauth2 configuration of each of the providers
// set up, indexed by provider name.
oauthConfig map[string]*oauth2.Config
// oidcProvider is only set when the oidc provider has been configured.
oidcProvider *oidc.Provider
// logger is the handlers' logger.
logger zerolog.Logger
}
// NewHandlers builds a Handlers instance from the managers and configuration
// provided. It prepares the secure cookie codec used for the session cookie
// and the oauth2 configuration of every supported provider (github, google,
// oidc) found under the server.oauth configuration key; other entries are
// ignored. An error is returned when the oidc provider cannot be set up.
func NewHandlers(
	ctx context.Context,
	userManager hub.UserManager,
	apiKeyManager hub.APIKeyManager,
	cfg *viper.Viper,
) (*Handlers, error) {
	// Secure cookie codec: only a hash key is provided (no block key)
	cookieCodec := securecookie.New([]byte(cfg.GetString("server.cookie.hashKey")), nil)
	cookieCodec.MaxAge(int(sessionDuration.Seconds()))

	// Oauth providers configuration
	var oidcProvider *oidc.Provider
	providersCfg := make(map[string]*oauth2.Config)
	for name := range cfg.GetStringMap("server.oauth") {
		keyPrefix := fmt.Sprintf("server.oauth.%s.", name)

		var endpoint oauth2.Endpoint
		switch name {
		case "github":
			endpoint = oagithub.Endpoint
		case "google":
			endpoint = google.Endpoint
		case "oidc":
			p, err := oidc.NewProvider(ctx, cfg.GetString(keyPrefix+"issuerURL"))
			if err != nil {
				return nil, fmt.Errorf("error setting up oidc provider: %w", err)
			}
			oidcProvider = p
			endpoint = p.Endpoint()
		default:
			// Unknown providers are skipped
			continue
		}

		providerCfg := &oauth2.Config{Endpoint: endpoint}
		providerCfg.ClientID = cfg.GetString(keyPrefix + "clientID")
		providerCfg.ClientSecret = cfg.GetString(keyPrefix + "clientSecret")
		providerCfg.Scopes = cfg.GetStringSlice(keyPrefix + "scopes")
		providerCfg.RedirectURL = cfg.GetString(keyPrefix + "redirectURL")
		providersCfg[name] = providerCfg
	}

	handlers := &Handlers{
		userManager:   userManager,
		apiKeyManager: apiKeyManager,
		cfg:           cfg,
		sc:            cookieCodec,
		oauthConfig:   providersCfg,
		oidcProvider:  oidcProvider,
		logger:        log.With().Str("handlers", "user").Logger(),
	}
	return handlers, nil
}
// ApproveSession is an http handler that approves the session identified by
// the session cookie using the TFA passcode provided in the request body.
// Sessions created for users who have enabled TFA are not valid until they
// have been approved this way.
func (h *Handlers) ApproveSession(w http.ResponseWriter, r *http.Request) {
	const method = "ApproveSession"

	// Read the passcode from the request body
	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	passcode := payload["passcode"]
	if decodeErr != nil || passcode == "" {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	// Recover the session id from the session cookie
	sessionCookie, cookieErr := r.Cookie(sessionCookieName)
	if cookieErr != nil {
		h.logger.Error().Err(cookieErr).Str("method", method).Msg("session cookie not found")
		helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)
		return
	}
	var sessionID string
	if err := h.sc.Decode(sessionCookieName, sessionCookie.Value, &sessionID); err != nil {
		h.logger.Error().Err(err).Str("method", method).Msg("sessionID decoding failed")
		helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)
		return
	}

	// Approve the session with the passcode provided
	err := h.userManager.ApproveSession(r.Context(), sessionID, passcode)
	if err != nil {
		h.logger.Error().Err(err).Str("method", method).Send()
		helpers.RenderErrorJSON(w, err)
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// BasicAuth is a middleware that provides basic auth support. Requests must
// carry the username and password set in the server.basicAuth configuration;
// otherwise a 401 response including a Basic authentication challenge is
// returned and the next handler is not called.
func (h *Handlers) BasicAuth(next http.Handler) http.Handler {
	validUser := []byte(h.cfg.GetString("server.basicAuth.username"))
	validPass := []byte(h.cfg.GetString("server.basicAuth.password"))

	// The realm is a quoted-string, so both the opening and the closing
	// quotes are required for the challenge to be well formed.
	const challenge = `Basic realm="Artifact Hub"`

	// Both values are always compared (no short-circuit) so that the time
	// taken does not reveal which of the two credentials was wrong.
	areCredentialsValid := func(user, pass []byte) bool {
		userOK := subtle.ConstantTimeCompare(user, validUser)
		passOK := subtle.ConstantTimeCompare(pass, validPass)
		return userOK&passOK == 1
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user, pass, ok := r.BasicAuth()
		if !ok || !areCredentialsValid([]byte(user), []byte(pass)) {
			w.Header().Set("WWW-Authenticate", challenge)
			w.WriteHeader(http.StatusUnauthorized)
			_, _ = io.WriteString(w, "Unauthorized\n")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// CheckPasswordStrength is an http handler that validates the strength of the
// password provided in the request body against the minimum entropy required
// for user passwords. The validation error, if any, is rendered using the 400
// status code.
func (h *Handlers) CheckPasswordStrength(w http.ResponseWriter, r *http.Request) {
	const method = "CheckPasswordStrength"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	strengthErr := pwvalidator.Validate(payload["password"], user.PasswordMinEntropyBits)
	if strengthErr != nil {
		h.logger.Error().Err(strengthErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorWithCodeJSON(w, strengthErr, http.StatusBadRequest)
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
// CheckAvailability is an http handler that checks the availability of the
// value provided (v form value) for the resource kind given in the url. A 404
// status code is used when the users manager reports the value as available
// and a 204 when it does not. Responses are sent with a cache control header
// built for a zero max age.
func (h *Handlers) CheckAvailability(w http.ResponseWriter, r *http.Request) {
	kind := chi.URLParam(r, "resourceKind")
	candidate := r.FormValue("v")
	w.Header().Set("Cache-Control", helpers.BuildCacheControlHeader(0))

	isAvailable, err := h.userManager.CheckAvailability(r.Context(), kind, candidate)
	switch {
	case err != nil:
		h.logger.Error().Err(err).Str("method", "CheckAvailability").Send()
		helpers.RenderErrorJSON(w, err)
	case isAvailable:
		helpers.RenderErrorWithCodeJSON(w, nil, http.StatusNotFound)
	default:
		w.WriteHeader(http.StatusNoContent)
	}
}
// DisableTFA is an http handler used to disable two-factor authentication.
// The request body must include a non empty passcode, which is handed to the
// users manager to complete the operation.
func (h *Handlers) DisableTFA(w http.ResponseWriter, r *http.Request) {
	const method = "DisableTFA"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	passcode := payload["passcode"]
	if decodeErr != nil || passcode == "" {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	err := h.userManager.DisableTFA(r.Context(), passcode)
	if err != nil {
		h.logger.Error().Err(err).Str("method", method).Send()
		helpers.RenderErrorJSON(w, err)
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// EnableTFA is an http handler used to enable two-factor authentication. The
// request body must include a non empty passcode, which is handed to the
// users manager to complete the operation.
func (h *Handlers) EnableTFA(w http.ResponseWriter, r *http.Request) {
	const method = "EnableTFA"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	passcode := payload["passcode"]
	if decodeErr != nil || passcode == "" {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	err := h.userManager.EnableTFA(r.Context(), passcode)
	if err != nil {
		h.logger.Error().Err(err).Str("method", method).Send()
		helpers.RenderErrorJSON(w, err)
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// GetProfile is an http handler that returns the profile of the user doing
// the request, as provided by the users manager in json format.
func (h *Handlers) GetProfile(w http.ResponseWriter, r *http.Request) {
	profileJSON, err := h.userManager.GetProfileJSON(r.Context())
	if err == nil {
		helpers.RenderJSON(w, profileJSON, 0, http.StatusOK)
		return
	}
	h.logger.Error().Err(err).Str("method", "GetProfile").Send()
	helpers.RenderErrorJSON(w, err)
}
// InjectUserID is a middleware that injects the id of the user doing the
// request into the request context when a valid session id is provided. The
// next handler is always called: when the session cookie is missing, cannot
// be decoded or the session is not valid, the request is forwarded untouched.
func (h *Handlers) InjectUserID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var userID string

		// Whatever the outcome of the session lookup below, the next handler
		// is called on the way out, adding the user id to the context only
		// when one was found.
		defer func() {
			req := r
			if userID != "" {
				req = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, userID))
			}
			next.ServeHTTP(w, req)
		}()

		// Get session id from the session cookie
		sessionCookie, err := r.Cookie(sessionCookieName)
		if err != nil {
			return
		}
		var sessionID string
		if h.sc.Decode(sessionCookieName, sessionCookie.Value, &sessionID) != nil {
			return
		}

		// Only valid sessions yield a user id
		result, err := h.userManager.CheckSession(r.Context(), sessionID, sessionDuration)
		if err != nil || !result.Valid {
			return
		}
		userID = result.UserID
	})
}
// Login is an http handler used to log a user in. It checks the credentials
// (email and password) provided in the request body and, when they are valid,
// registers a new session and sets the session cookie in the response. The
// SessionApprovedHeader header tells the client if the session created has
// already been approved.
func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) {
	// Extract credentials from request
	var input map[string]string
	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
		h.logger.Error().Err(err).Str("method", "Login").Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	// Check if the credentials provided are valid
	checkCredentialsOutput, err := h.userManager.CheckCredentials(r.Context(), input["email"], input["password"])
	if err != nil {
		h.logger.Error().Err(err).Str("method", "Login").Msg("checkCredentials failed")
		helpers.RenderErrorJSON(w, err)
		return
	}
	if !checkCredentialsOutput.Valid {
		helpers.RenderErrorWithCodeJSON(w, nil, http.StatusUnauthorized)
		return
	}

	// Register user session. The error splitting the remote address is
	// deliberately ignored: the session is registered with an empty IP when
	// the address cannot be parsed.
	ip, _, _ := net.SplitHostPort(r.RemoteAddr)
	session, err := h.userManager.RegisterSession(r.Context(), &hub.Session{
		UserID:    checkCredentialsOutput.UserID,
		IP:        ip,
		UserAgent: r.UserAgent(),
	})
	if err != nil {
		h.logger.Error().Err(err).Str("method", "Login").Msg("registerSession failed")
		helpers.RenderErrorJSON(w, err)
		return
	}

	// Generate and set session cookie
	encodedSessionID, err := h.sc.Encode(sessionCookieName, session.SessionID)
	if err != nil {
		h.logger.Error().Err(err).Str("method", "Login").Msg("sessionID encoding failed")
		helpers.RenderErrorJSON(w, err)
		return
	}
	cookie := &http.Cookie{
		Name:     sessionCookieName,
		Value:    encodedSessionID,
		Path:     "/",
		Expires:  time.Now().Add(sessionDuration),
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   h.cfg.GetBool("server.cookie.secure"),
	}
	http.SetCookie(w, cookie)
	w.Header().Set(SessionApprovedHeader, strconv.FormatBool(session.Approved))
	w.WriteHeader(http.StatusNoContent)
}
// Logout is an http handler used to log a user out. The session referenced by
// the session cookie (if any) is deleted and the browser is asked to drop the
// cookie. The response is always a 204, even when the session could not be
// deleted.
func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
	// Delete the user session when the cookie is present and can be decoded
	if sessionCookie, cookieErr := r.Cookie(sessionCookieName); cookieErr == nil {
		var sessionID string
		if h.sc.Decode(sessionCookieName, sessionCookie.Value, &sessionID) == nil {
			if err := h.userManager.DeleteSession(r.Context(), sessionID); err != nil {
				h.logger.Error().Err(err).Str("method", "Logout").Msg("deleteSession failed")
			}
		}
	}

	// An expiration date in the past requests the browser to delete the
	// session cookie
	http.SetCookie(w, &http.Cookie{
		Name:    sessionCookieName,
		Path:    "/",
		Expires: time.Now().Add(-24 * time.Hour),
	})
	w.WriteHeader(http.StatusNoContent)
}
// OauthCallback is an http handler in charge of completing the oauth
// authentication process, registering the user if needed. On success, a
// session is registered for the user, the session cookie is set and the user
// is redirected to the url stored in the oauth state. On any failure the user
// is redirected to oauthFailedURL.
//
// NOTE(review): the redirect url comes from the oauth state, which is built
// from request data in OauthRedirect and is not restricted to this site here
// - confirm whether it should be validated.
func (h *Handlers) OauthCallback(w http.ResponseWriter, r *http.Request) {
	logger := h.logger.With().Str("method", "OauthCallback").Logger()
	fail := func() {
		http.Redirect(w, r, oauthFailedURL, http.StatusSeeOther)
	}

	// Validate oauth code and state
	code := r.FormValue("code")
	if code == "" {
		logger.Error().Msg("oauth code not provided")
		fail()
		return
	}
	state, err := NewOauthState(r.FormValue("state"))
	if err != nil || state == nil {
		// A nil state would otherwise be dereferenced below
		logger.Error().Err(err).Msg("oauth state not provided")
		fail()
		return
	}
	stateCookie, err := r.Cookie(oauthStateCookieName)
	if err != nil {
		logger.Error().Err(err).Msg("state cookie not provided")
		fail()
		return
	}
	if state.Random != stateCookie.Value {
		logger.Error().Msg("invalid state cookie")
		fail()
		return
	}

	// The state cookie has served its purpose, request browser to delete it
	http.SetCookie(w, &http.Cookie{
		Name:    oauthStateCookieName,
		Path:    "/",
		Expires: time.Now().Add(-24 * time.Hour),
	})

	// Make sure the provider requested has been set up. Looking up an
	// unknown provider yields a nil configuration, which must not be used.
	provider := chi.URLParam(r, "provider")
	providerConfig := h.oauthConfig[provider]
	if providerConfig == nil {
		logger.Error().Str("provider", provider).Msg("oauth provider not supported")
		fail()
		return
	}

	// Register user if needed, or return his id if already registered
	oauthToken, err := providerConfig.Exchange(r.Context(), code)
	if err != nil {
		logger.Error().Err(err).Msg("oauth code exchange failed")
		fail()
		return
	}
	userID, err := h.registerUserWithOauth(r.Context(), provider, providerConfig, oauthToken)
	if err != nil {
		logger.Error().Err(err).Msg("registerUserWithOauth failed")
		fail()
		return
	}

	// Register user session and set session cookie. The error splitting the
	// remote address is deliberately ignored: the session is registered with
	// an empty IP when the address cannot be parsed.
	ip, _, _ := net.SplitHostPort(r.RemoteAddr)
	session, err := h.userManager.RegisterSession(r.Context(), &hub.Session{
		UserID:    userID,
		IP:        ip,
		UserAgent: r.UserAgent(),
	})
	if err != nil {
		logger.Error().Err(err).Msg("registerSession failed")
		fail()
		return
	}
	encodedSessionID, err := h.sc.Encode(sessionCookieName, session.SessionID)
	if err != nil {
		logger.Error().Err(err).Msg("sessionID encoding failed")
		fail()
		return
	}
	sessionCookie := &http.Cookie{
		Name:     sessionCookieName,
		Value:    encodedSessionID,
		Path:     "/",
		Expires:  time.Now().Add(sessionDuration),
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   h.cfg.GetBool("server.cookie.secure"),
	}
	http.SetCookie(w, sessionCookie)
	http.Redirect(w, r, state.RedirectURL, http.StatusSeeOther)
}
// OauthRedirect is an http handler that redirects the user to the oauth
// provider to proceed with the authorization. The url the user should be sent
// back to once the process completes is taken from the redirect_url form
// value, falling back to the request referer and then to "/". When the
// provider requested has not been set up, the user is redirected to
// oauthFailedURL.
func (h *Handlers) OauthRedirect(w http.ResponseWriter, r *http.Request) {
	// Make sure the provider requested has been set up before doing anything
	// else. Looking up an unknown provider yields a nil configuration, which
	// must not be used.
	provider := chi.URLParam(r, "provider")
	providerConfig := h.oauthConfig[provider]
	if providerConfig == nil {
		h.logger.Error().
			Str("method", "OauthRedirect").
			Str("provider", provider).
			Msg("oauth provider not supported")
		http.Redirect(w, r, oauthFailedURL, http.StatusSeeOther)
		return
	}

	// Generate random value for oauth session and store it in browser. It'll
	// be used later to validate the callback request is done by the same user.
	random := uuid.NewV4().String()
	cookie := &http.Cookie{
		Name:     oauthStateCookieName,
		Value:    random,
		Path:     "/",
		HttpOnly: true,
		Secure:   h.cfg.GetBool("server.cookie.secure"),
	}
	http.SetCookie(w, cookie)

	// Prepare oauth state and redirect user to oauth provider
	redirectURL := r.FormValue("redirect_url")
	if redirectURL == "" {
		redirectURL = r.Referer()
	}
	if redirectURL == "" {
		redirectURL = "/"
	}
	state := &OauthState{
		Random:      random,
		RedirectURL: redirectURL,
	}
	authCodeURL := providerConfig.AuthCodeURL(state.String())
	http.Redirect(w, r, authCodeURL, http.StatusSeeOther)
}
// RegisterPasswordResetCode is an http handler used to register a code to
// reset the password. The code will be emailed to the address provided. Once
// the request body has been decoded, the response is always a 201, no matter
// what the outcome of the registration is.
func (h *Handlers) RegisterPasswordResetCode(w http.ResponseWriter, r *http.Request) {
	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().
			Err(decodeErr).
			Str("method", "RegisterPasswordResetCode").
			Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	// The result is intentionally discarded.
	// NOTE(review): presumably done to avoid revealing whether the email
	// provided belongs to a registered user - confirm.
	baseURL := h.cfg.GetString("server.baseURL")
	_ = h.userManager.RegisterPasswordResetCode(r.Context(), payload["email"], baseURL)

	w.WriteHeader(http.StatusCreated)
}
// RegisterUser is an http handler used to register a user in the hub database.
// The user is read from the request body and must include a password. The
// email verified flag is always cleared, regardless of what the request body
// says.
func (h *Handlers) RegisterUser(w http.ResponseWriter, r *http.Request) {
	// Decode into the already allocated user (not into a pointer to the
	// pointer): this way a JSON null body leaves the user untouched instead
	// of replacing it with a nil pointer that would be dereferenced below.
	u := &hub.User{}
	if err := json.NewDecoder(r.Body).Decode(u); err != nil {
		h.logger.Error().Err(err).Str("method", "RegisterUser").Msg("invalid user")
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}
	u.EmailVerified = false
	if u.Password == "" {
		errMsg := "password not provided"
		h.logger.Error().Str("method", "RegisterUser").Msg(errMsg)
		helpers.RenderErrorJSON(w, fmt.Errorf("%w: %s", hub.ErrInvalidInput, errMsg))
		return
	}
	if err := h.userManager.RegisterUser(r.Context(), u, h.cfg.GetString("server.baseURL")); err != nil {
		h.logger.Error().Err(err).Str("method", "RegisterUser").Send()
		helpers.RenderErrorJSON(w, err)
		return
	}
	w.WriteHeader(http.StatusCreated)
}
// registerUserWithOauth is a helper function that registers a user using the
// details from his oauth provider if he's not already registered, returning
// the user id. Users are identified by the email obtained from the provider's
// profile. An error is returned if the provider is not supported.
func (h *Handlers) registerUserWithOauth(
	ctx context.Context,
	provider string,
	providerConfig *oauth2.Config,
	oauthToken *oauth2.Token,
) (string, error) {
	// Build user from profile from oauth provider
	var u *hub.User
	var err error
	switch provider {
	case "github":
		u, err = h.newUserFromGithubProfile(ctx, oauthToken)
	case "google":
		u, err = h.newUserFromGoogleProfile(ctx, providerConfig, oauthToken)
	case "oidc":
		u, err = h.newUserFromOIDProfile(ctx, oauthToken)
	default:
		err = fmt.Errorf("invalid provider: %s", provider)
	}
	if err != nil {
		return "", err
	}

	// Check if user exists. Already registered users are done at this point:
	// the alias is only needed to register new users, so there is no point in
	// checking its availability (or failing because of it) for existing ones.
	userID, err := h.userManager.GetUserID(ctx, u.Email)
	if err != nil && !errors.Is(err, user.ErrNotFound) {
		return "", err
	}
	if userID != "" {
		return userID, nil
	}

	// Check user alias availability and append suffix to it if needed
	available, err := h.userManager.CheckAvailability(ctx, "userAlias", u.Alias)
	if err != nil {
		return "", err
	}
	if !available {
		randomSuffix, err := getRandomSuffix()
		if err != nil {
			return "", err
		}
		u.Alias += randomSuffix
	}

	// Register user. The email is marked as verified as only verified emails
	// are accepted when building the user from the provider's profile.
	u.EmailVerified = true
	if err := h.userManager.RegisterUser(ctx, u, ""); err != nil {
		return "", err
	}
	userID, err = h.userManager.GetUserID(ctx, u.Email)
	if err != nil {
		return "", err
	}
	return userID, nil
}
// newUserFromGithubProfile builds a new hub.User instance from the user's
// Github profile, using the oauth token provided to fetch the profile and the
// list of emails. The user's email is the one flagged as both primary and
// verified; an error is returned when there is none.
func (h *Handlers) newUserFromGithubProfile(
	ctx context.Context,
	oauthToken *oauth2.Token,
) (*hub.User, error) {
	// Fetch profile and emails using a client authenticated with the token
	client := github.NewClient(oauth2.NewClient(ctx, oauth2.StaticTokenSource(oauthToken)))
	profile, _, err := client.Users.Get(ctx, "")
	if err != nil {
		return nil, err
	}
	userEmails, _, err := client.Users.ListEmails(ctx, nil)
	if err != nil {
		return nil, err
	}

	// Pick the primary email, provided it has been verified
	var primaryEmail string
	for i := range userEmails {
		if e := userEmails[i]; e.GetPrimary() && e.GetVerified() {
			primaryEmail = e.GetEmail()
		}
	}
	if primaryEmail == "" {
		return nil, errors.New("no valid email available for use")
	}

	newUser := &hub.User{
		Alias:     profile.GetLogin(),
		Email:     primaryEmail,
		FirstName: profile.GetName(),
	}
	return newUser, nil
}
// newUserFromGoogleProfile builds a new hub.User instance from the user's
// Google profile, which is fetched from the People API using the oauth token
// provided. The user's email is the one flagged as both primary and verified;
// an error is returned when there is none. The alias is the local part of the
// email. First and last names are only set when the profile includes a name.
func (h *Handlers) newUserFromGoogleProfile(
	ctx context.Context,
	providerConfig *oauth2.Config,
	oauthToken *oauth2.Token,
) (*hub.User, error) {
	// Get user profile
	opt := option.WithTokenSource(providerConfig.TokenSource(ctx, oauthToken))
	peopleService, err := people.NewService(ctx, opt)
	if err != nil {
		return nil, err
	}
	profile, err := peopleService.People.
		Get("people/me").
		PersonFields("names,emailAddresses").
		Context(ctx). // honor cancellation of the incoming request
		Do()
	if err != nil {
		return nil, err
	}

	// Get user's primary email and check if it has been verified. Entries
	// and their metadata are pointers, so they are checked before being used.
	var email string
	for _, entry := range profile.EmailAddresses {
		if entry == nil || entry.Metadata == nil {
			continue
		}
		if entry.Metadata.Primary && entry.Metadata.Verified {
			email = entry.Value
		}
	}
	if email == "" {
		return nil, errors.New("no valid email available for use")
	}

	u := &hub.User{
		Alias: strings.Split(email, "@")[0],
		Email: email,
	}
	// The profile may not include any name
	if len(profile.Names) > 0 && profile.Names[0] != nil {
		u.FirstName = profile.Names[0].GivenName
		u.LastName = profile.Names[0].FamilyName
	}
	return u, nil
}
// newUserFromOIDProfile builds a new hub.User instance from the claims found
// in the OpenID id token that comes with the oauth token provided. The id
// token is verified using the oidc provider's verifier before the claims are
// read. An error is returned when the id token is missing or not valid, or
// when the email claim is empty or not verified. The alias is the preferred
// username, falling back to the local part of the email when it's not set.
func (h *Handlers) newUserFromOIDProfile(
	ctx context.Context,
	oauthToken *oauth2.Token,
) (*hub.User, error) {
	// The id token travels as an extra field in the oauth token
	rawIDToken, isString := oauthToken.Extra("id_token").(string)
	if !isString {
		return nil, errors.New("id token not available")
	}

	// Parse and verify the id token
	verifierCfg := &oidc.Config{
		ClientID: h.cfg.GetString("server.oauth.oidc.clientID"),
	}
	idToken, err := h.oidcProvider.Verifier(verifierCfg).Verify(ctx, rawIDToken)
	if err != nil {
		return nil, fmt.Errorf("invalid id token: %w", err)
	}

	// Read the claims of interest from the id token payload
	var claims struct {
		Email             string `json:"email"`
		EmailVerified     bool   `json:"email_verified"`
		GivenName         string `json:"given_name"`
		FamilyName        string `json:"family_name"`
		PreferredUsername string `json:"preferred_username"`
	}
	if err := idToken.Claims(&claims); err != nil {
		return nil, fmt.Errorf("error extracting claims from id token: %w", err)
	}
	if !claims.EmailVerified || claims.Email == "" {
		return nil, errors.New("no valid email available for use")
	}

	newUser := &hub.User{
		Alias:     claims.PreferredUsername,
		Email:     claims.Email,
		FirstName: claims.GivenName,
		LastName:  claims.FamilyName,
	}
	if newUser.Alias == "" {
		newUser.Alias = strings.Split(claims.Email, "@")[0]
	}
	return newUser, nil
}
// RequireLogin is a middleware that verifies if a user is logged in. API key
// based authentication is used when both API key headers are present in the
// request; otherwise the session cookie is used. When authentication succeeds
// the user id is injected in the request context and the next handler is
// called. Otherwise an error response is rendered and the next handler is not
// called.
func (h *Handlers) RequireLogin(next http.Handler) http.Handler {
	const method = "RequireLogin"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var userID string

		keyID := r.Header.Get(APIKeyIDHeader)
		keySecret := r.Header.Get(APIKeySecretHeader)
		if keyID == "" || keySecret == "" {
			// No (complete) API key provided: use cookie based authentication.
			// A missing cookie is not reported here, it just leaves the user
			// id empty.
			if sessionCookie, cookieErr := r.Cookie(sessionCookieName); cookieErr == nil {
				var sessionID string
				if err := h.sc.Decode(sessionCookieName, sessionCookie.Value, &sessionID); err != nil {
					h.logger.Error().Err(err).Str("method", method).Msg("sessionID decoding failed")
					helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)
					return
				}
				sessionCheck, err := h.userManager.CheckSession(r.Context(), sessionID, sessionDuration)
				if err != nil {
					h.logger.Error().Err(err).Str("method", method).Msg("checkSession failed")
					helpers.RenderErrorWithCodeJSON(w, nil, http.StatusInternalServerError)
					return
				}
				if !sessionCheck.Valid {
					helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)
					return
				}
				userID = sessionCheck.UserID
			}
		} else {
			// API key based authentication
			keyCheck, err := h.apiKeyManager.Check(r.Context(), keyID, keySecret)
			if err != nil {
				h.logger.Error().Err(err).Str("method", method).Msg("checkAPIKey failed")
				helpers.RenderErrorWithCodeJSON(w, nil, http.StatusInternalServerError)
				return
			}
			if !keyCheck.Valid {
				helpers.RenderErrorWithCodeJSON(w, errInvalidAPIKey, http.StatusUnauthorized)
				return
			}
			userID = keyCheck.UserID
		}

		// No authentication method yielded a user id
		if userID == "" {
			helpers.RenderErrorWithCodeJSON(w, nil, http.StatusUnauthorized)
			return
		}

		// Hand the request over to the next handler with the user id in the
		// context
		authCtx := context.WithValue(r.Context(), hub.UserIDKey, userID)
		next.ServeHTTP(w, r.WithContext(authCtx))
	})
}
// ResetPassword is an http handler used to reset the user's password, using
// the password reset code and the new password provided in the request body.
// An invalid password reset code is reported using the 400 status code.
func (h *Handlers) ResetPassword(w http.ResponseWriter, r *http.Request) {
	const method = "ResetPassword"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	code, newPassword := payload["code"], payload["password"]
	baseURL := h.cfg.GetString("server.baseURL")
	err := h.userManager.ResetPassword(r.Context(), code, newPassword, baseURL)
	if err == nil {
		w.WriteHeader(http.StatusNoContent)
		return
	}

	h.logger.Error().Err(err).Str("method", method).Send()
	if errors.Is(err, user.ErrInvalidPasswordResetCode) {
		helpers.RenderErrorWithCodeJSON(w, err, http.StatusBadRequest)
		return
	}
	helpers.RenderErrorJSON(w, err)
}
// SetupTFA is an http handler used to setup two-factor authentication. The
// json data returned by the users manager is sent to the client using the 201
// status code.
func (h *Handlers) SetupTFA(w http.ResponseWriter, r *http.Request) {
	setupJSON, err := h.userManager.SetupTFA(r.Context())
	if err == nil {
		helpers.RenderJSON(w, setupJSON, 0, http.StatusCreated)
		return
	}
	h.logger.Error().Err(err).Str("method", "SetupTFA").Send()
	helpers.RenderErrorJSON(w, err)
}
// UpdatePassword is an http handler used to update the password in the hub
// database. The request body provides both the old and the new passwords. An
// invalid password error is reported using the 401 status code.
func (h *Handlers) UpdatePassword(w http.ResponseWriter, r *http.Request) {
	const method = "UpdatePassword"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	err := h.userManager.UpdatePassword(r.Context(), payload["old"], payload["new"])
	if err == nil {
		w.WriteHeader(http.StatusNoContent)
		return
	}

	h.logger.Error().Err(err).Str("method", method).Send()
	if errors.Is(err, user.ErrInvalidPassword) {
		helpers.RenderErrorWithCodeJSON(w, nil, http.StatusUnauthorized)
		return
	}
	helpers.RenderErrorJSON(w, err)
}
// UpdateProfile is an http handler used to update the user in the hub
// database with the details provided in the request body.
func (h *Handlers) UpdateProfile(w http.ResponseWriter, r *http.Request) {
	// Decode into the already allocated user (not into a pointer to the
	// pointer): this way a JSON null body leaves the user untouched instead
	// of replacing it with a nil pointer that would be handed to the users
	// manager.
	u := &hub.User{}
	if err := json.NewDecoder(r.Body).Decode(u); err != nil {
		h.logger.Error().Err(err).Str("method", "UpdateProfile").Msg("invalid user")
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}
	if err := h.userManager.UpdateProfile(r.Context(), u); err != nil {
		h.logger.Error().Err(err).Str("method", "UpdateProfile").Send()
		helpers.RenderErrorJSON(w, err)
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// VerifyEmail is an http handler used to verify a user's email address using
// the verification code provided in the request body. When the users manager
// reports the email as not verified, an expired code error is sent using the
// 410 status code.
func (h *Handlers) VerifyEmail(w http.ResponseWriter, r *http.Request) {
	const method = "VerifyEmail"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	verified, err := h.userManager.VerifyEmail(r.Context(), payload["code"])
	switch {
	case err != nil:
		h.logger.Error().Err(err).Str("method", method).Send()
		helpers.RenderErrorJSON(w, err)
	case !verified:
		expiredErr := fmt.Errorf("email verification code has expired")
		helpers.RenderErrorWithCodeJSON(w, expiredErr, http.StatusGone)
	default:
		w.WriteHeader(http.StatusNoContent)
	}
}
// VerifyPasswordResetCode is an http handler used to verify the reset
// password code provided in the request body. An invalid code is reported
// using the 410 status code, whereas a 200 is used when the verification
// succeeds.
func (h *Handlers) VerifyPasswordResetCode(w http.ResponseWriter, r *http.Request) {
	const method = "VerifyPasswordResetCode"

	var payload map[string]string
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		h.logger.Error().Err(decodeErr).Str("method", method).Msg(hub.ErrInvalidInput.Error())
		helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
		return
	}

	err := h.userManager.VerifyPasswordResetCode(r.Context(), payload["code"])
	if err == nil {
		w.WriteHeader(http.StatusOK)
		return
	}

	h.logger.Error().Err(err).Str("method", method).Send()
	if errors.Is(err, user.ErrInvalidPasswordResetCode) {
		helpers.RenderErrorWithCodeJSON(w, err, http.StatusGone)
		return
	}
	helpers.RenderErrorJSON(w, err)
}
// OauthState represents the state of an oauth authorization session, used to
// increase the security of the process and to restore the state of the
// application.
type OauthState struct {
	// Random is a random value generated when the oauth process starts.
	Random string
	// RedirectURL is the url to send the user to once the oauth process has
	// been completed.
	RedirectURL string
}

// String returns an OauthState instance as a string: its json representation
// encoded using the standard base64 encoding.
func (s *OauthState) String() string {
	// Marshaling a struct that only contains string fields cannot fail, so
	// the error is safe to ignore.
	dataJSON, _ := json.Marshal(s)
	return base64.StdEncoding.EncodeToString(dataJSON)
}

// NewOauthState creates a new OauthState instance from the string
// representation provided (see OauthState.String). An error is returned when
// the string is not valid base64, when the decoded data is not valid json or
// when it does not contain a state (json null). The state returned is never
// nil when the error is nil.
func NewOauthState(s string) (*OauthState, error) {
	dataJSON, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return nil, err
	}
	var state *OauthState
	if err := json.Unmarshal(dataJSON, &state); err != nil {
		return nil, err
	}
	// Unmarshaling a json null into a pointer leaves it set to nil without
	// reporting any error, so this case must be handled explicitly to avoid
	// handing a nil state to callers.
	if state == nil {
		return nil, errors.New("invalid oauth state")
	}
	return state, nil
}
// getRandomSuffix is a helper function that returns a random numerical suffix
// to be used in user aliases when the selected alias is already taken. The
// suffix is the decimal representation of a number in the [0, 1000) range
// obtained from the cryptographic random source.
func getRandomSuffix() (string, error) {
	upperBound := big.NewInt(1000)
	n, err := rand.Int(rand.Reader, upperBound)
	if err != nil {
		return "", err
	}
	suffix := strconv.Itoa(int(n.Int64()))
	return suffix, nil
}