hub/internal/handlers/user/handlers_test.go

1724 lines
46 KiB
Go

package user
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/artifacthub/hub/internal/apikey"
"github.com/artifacthub/hub/internal/handlers/helpers"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/tests"
"github.com/artifacthub/hub/internal/user"
"github.com/go-chi/chi"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// TestMain disables zerolog output globally, runs the package's tests and
// exits the process with the code returned by the test run.
func TestMain(m *testing.M) {
	zerolog.SetGlobalLevel(zerolog.Disabled)
	exitCode := m.Run()
	os.Exit(exitCode)
}
// TestApproveSession exercises the ApproveSession handler with malformed or
// empty passcode payloads, with a missing or undecodable session cookie, and
// with both outcomes of the user manager's ApproveSession call.
func TestApproveSession(t *testing.T) {
	sessionID := "sessionID"

	t.Run("invalid input", func(t *testing.T) {
		testCases := []struct {
			desc      string
			inputJSON string
		}{
			{
				"invalid input",
				`{"passcode": "123456" ...`,
			},
			{
				"no passcode provided",
				`{"passcode": ""}`,
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.desc, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("PUT", "/", strings.NewReader(tc.inputJSON))

				hw := newHandlersWrapper()
				hw.h.ApproveSession(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
			})
		}
	})

	t.Run("session cookie not provided", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)

		hw := newHandlersWrapper()
		hw.h.ApproveSession(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("invalid session cookie", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)
		// The cookie value was not produced by the handlers' cookie encoder.
		r.AddCookie(&http.Cookie{
			Name:  sessionCookieName,
			Value: "invalidValue",
		})

		hw := newHandlersWrapper()
		hw.h.ApproveSession(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("error approving session", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)

		hw := newHandlersWrapper()
		hw.um.On("ApproveSession", r.Context(), sessionID, "123456").Return(tests.ErrFake)
		encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
		r.AddCookie(&http.Cookie{
			Name:  sessionCookieName,
			Value: encodedSessionID,
		})
		hw.h.ApproveSession(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
		// Verify the user manager was actually invoked with the decoded
		// session id and the passcode from the request body.
		hw.um.AssertExpectations(t)
	})

	t.Run("session approval succeeded", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)

		hw := newHandlersWrapper()
		hw.um.On("ApproveSession", r.Context(), sessionID, "123456").Return(nil)
		encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
		r.AddCookie(&http.Cookie{
			Name:  sessionCookieName,
			Value: encodedSessionID,
		})
		hw.h.ApproveSession(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusNoContent, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})
}
// TestBasicAuth checks the BasicAuth middleware, with basic auth enabled in
// the configuration, for a request without credentials and for a request
// carrying the configured username and password.
func TestBasicAuth(t *testing.T) {
	// A single wrapper is shared by the subtests; its configuration is only
	// written here, before any parallel subtest starts running.
	hw := newHandlersWrapper()
	hw.cfg.Set("server.basicAuth.enabled", true)
	hw.cfg.Set("server.basicAuth.username", "test")
	hw.cfg.Set("server.basicAuth.password", "test")

	scenarios := []struct {
		name            string
		sendCredentials bool
		expectedStatus  int
	}{
		{
			name:            "without basic auth credentials",
			sendCredentials: false,
			expectedStatus:  http.StatusUnauthorized,
		},
		{
			name:            "with basic auth credentials",
			sendCredentials: true,
			expectedStatus:  http.StatusOK,
		},
	}
	for _, sc := range scenarios {
		sc := sc
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()
			rec := httptest.NewRecorder()
			req, _ := http.NewRequest("GET", "/", nil)
			if sc.sendCredentials {
				req.SetBasicAuth("test", "test")
			}

			hw.h.BasicAuth(http.HandlerFunc(testsOK)).ServeHTTP(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, sc.expectedStatus, res.StatusCode)
		})
	}
}
// TestCheckPasswordStrength posts malformed JSON, a set of passwords expected
// to be rejected and a set expected to be accepted to the
// CheckPasswordStrength handler, checking the response status of each.
func TestCheckPasswordStrength(t *testing.T) {
	// postAndExpect sends payload to the handler using a fresh wrapper and
	// asserts the status code of the response.
	postAndExpect := func(t *testing.T, payload string, expectedStatus int) {
		rec := httptest.NewRecorder()
		req, _ := http.NewRequest("POST", "/", strings.NewReader(payload))

		hw := newHandlersWrapper()
		hw.h.CheckPasswordStrength(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, expectedStatus, res.StatusCode)
	}

	t.Run("invalid input", func(t *testing.T) {
		t.Parallel()
		postAndExpect(t, `{..`, http.StatusBadRequest)
	})

	groups := []struct {
		name           string
		passwords      []string
		expectedStatus int
	}{
		{
			name: "invalid passwords",
			passwords: []string{
				"invalid",
				"123",
				"weak12",
			},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name: "valid passwords",
			passwords: []string{
				"12uuYbaT.",
				"this password should be valid too",
				"19s-8*s.Y",
				"yet123-another-ONE",
			},
			expectedStatus: http.StatusNoContent,
		},
	}
	for _, group := range groups {
		group := group
		t.Run(group.name, func(t *testing.T) {
			for _, password := range group.passwords {
				password := password
				t.Run(password, func(t *testing.T) {
					t.Parallel()
					payload := fmt.Sprintf(`{"password": "%s"}`, password)
					postAndExpect(t, payload, group.expectedStatus)
				})
			}
		})
	}
}
// TestCheckAvailability drives the CheckAvailability handler with an invalid
// resource kind, with a successful availability check and with a failing one,
// checking the response status (and the Cache-Control header where asserted).
func TestCheckAvailability(t *testing.T) {
	// newRequest builds the HEAD request used by every subtest, exposing
	// resourceKind as a chi route parameter.
	newRequest := func(resourceKind string) *http.Request {
		req, _ := http.NewRequest("HEAD", "/?v=value", nil)
		routeCtx := &chi.Context{
			URLParams: chi.RouteParams{
				Keys:   []string{"resourceKind"},
				Values: []string{resourceKind},
			},
		}
		return req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx))
	}

	t.Run("invalid input", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newRequest("invalid")

		hw := newHandlersWrapper()
		hw.um.On("CheckAvailability", req.Context(), "invalid", "value").
			Return(false, hub.ErrInvalidInput)
		hw.h.CheckAvailability(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusBadRequest, res.StatusCode)
		assert.Equal(t, helpers.BuildCacheControlHeader(0), res.Header.Get("Cache-Control"))
		hw.um.AssertExpectations(t)
	})

	t.Run("valid input", func(t *testing.T) {
		t.Run("check availability succeeded", func(t *testing.T) {
			testCases := []struct {
				resourceKind string
				available    bool
			}{
				{resourceKind: "userAlias", available: true},
			}
			for _, tc := range testCases {
				tc := tc
				t.Run(fmt.Sprintf("resource kind: %s", tc.resourceKind), func(t *testing.T) {
					t.Parallel()
					rec := httptest.NewRecorder()
					req := newRequest(tc.resourceKind)

					hw := newHandlersWrapper()
					hw.um.On("CheckAvailability", req.Context(), tc.resourceKind, "value").Return(tc.available, nil)
					hw.h.CheckAvailability(rec, req)
					res := rec.Result()
					defer res.Body.Close()

					// An available resource is reported as not found; a
					// taken one as no content.
					expectedStatus := http.StatusNoContent
					if tc.available {
						expectedStatus = http.StatusNotFound
					}
					assert.Equal(t, expectedStatus, res.StatusCode)
					assert.Equal(t, helpers.BuildCacheControlHeader(0), res.Header.Get("Cache-Control"))
					hw.um.AssertExpectations(t)
				})
			}
		})

		t.Run("check availability failed", func(t *testing.T) {
			t.Parallel()
			rec := httptest.NewRecorder()
			req := newRequest("userAlias")

			hw := newHandlersWrapper()
			hw.um.On("CheckAvailability", req.Context(), "userAlias", "value").Return(false, tests.ErrFakeDB)
			hw.h.CheckAvailability(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
			hw.um.AssertExpectations(t)
		})
	})
}
// TestDisableTFA covers the DisableTFA handler: rejected payloads and both
// outcomes of the user manager's DisableTFA call.
func TestDisableTFA(t *testing.T) {
	t.Run("invalid input", func(t *testing.T) {
		badPayloads := []struct {
			desc      string
			inputJSON string
		}{
			{desc: "invalid input", inputJSON: `{"passcode": "123456" ...`},
			{desc: "no passcode provided", inputJSON: `{"passcode": ""}`},
		}
		for _, tc := range badPayloads {
			tc := tc
			t.Run(tc.desc, func(t *testing.T) {
				t.Parallel()
				rec := httptest.NewRecorder()
				req, _ := http.NewRequest("PUT", "/", strings.NewReader(tc.inputJSON))

				hw := newHandlersWrapper()
				hw.h.DisableTFA(rec, req)
				res := rec.Result()
				defer res.Body.Close()

				assert.Equal(t, http.StatusBadRequest, res.StatusCode)
			})
		}
	})

	// The remaining scenarios send the same well-formed payload and differ
	// only in what the user manager mock returns.
	outcomes := []struct {
		desc           string
		umErr          error
		expectedStatus int
	}{
		{desc: "disable tfa failed", umErr: tests.ErrFake, expectedStatus: http.StatusInternalServerError},
		{desc: "disable tfa succeeded", umErr: nil, expectedStatus: http.StatusNoContent},
	}
	for _, tc := range outcomes {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()
			rec := httptest.NewRecorder()
			req, _ := http.NewRequest("PUT", "/", strings.NewReader(`{"passcode": "123456"}`))

			hw := newHandlersWrapper()
			hw.um.On("DisableTFA", req.Context(), "123456").Return(tc.umErr)
			hw.h.DisableTFA(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, tc.expectedStatus, res.StatusCode)
			hw.um.AssertExpectations(t)
		})
	}
}
// TestEnableTFA covers the EnableTFA handler: rejected payloads and both
// outcomes of the user manager's EnableTFA call.
func TestEnableTFA(t *testing.T) {
	t.Run("invalid input", func(t *testing.T) {
		testCases := []struct {
			desc      string
			inputJSON string
		}{
			{
				"invalid input",
				`{"passcode": "123456" ...`,
			},
			{
				"no passcode provided",
				`{"passcode": ""}`,
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.desc, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("PUT", "/", strings.NewReader(tc.inputJSON))

				hw := newHandlersWrapper()
				hw.h.EnableTFA(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
			})
		}
	})

	t.Run("enable tfa failed", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)

		hw := newHandlersWrapper()
		hw.um.On("EnableTFA", r.Context(), "123456").Return(tests.ErrFake)
		hw.h.EnableTFA(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})

	// Subtest name corrected from the misspelled "enble tfa succeeded" so it
	// can be selected with -run alongside its "enable tfa failed" sibling.
	t.Run("enable tfa succeeded", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"passcode": "123456"}`)
		r, _ := http.NewRequest("PUT", "/", body)

		hw := newHandlersWrapper()
		hw.um.On("EnableTFA", r.Context(), "123456").Return(nil)
		hw.h.EnableTFA(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusNoContent, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})
}
// TestGetProfile checks the GetProfile handler when the user manager fails to
// return the profile and when it returns the profile JSON.
func TestGetProfile(t *testing.T) {
	// newRequest builds a GET request whose context carries a user id.
	newRequest := func() *http.Request {
		req, _ := http.NewRequest("GET", "/", nil)
		return req.WithContext(context.WithValue(req.Context(), hub.UserIDKey, "userID"))
	}

	t.Run("error getting profile", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newRequest()

		hw := newHandlersWrapper()
		hw.um.On("GetProfileJSON", req.Context()).Return(nil, tests.ErrFakeDB)
		hw.h.GetProfile(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("profile get succeeded", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newRequest()

		hw := newHandlersWrapper()
		hw.um.On("GetProfileJSON", req.Context()).Return([]byte("dataJSON"), nil)
		hw.h.GetProfile(rec, req)
		res := rec.Result()
		defer res.Body.Close()
		payload, _ := ioutil.ReadAll(res.Body)

		// The handler is expected to relay the manager's bytes untouched,
		// served as JSON with the zero max-age cache header.
		assert.Equal(t, http.StatusOK, res.StatusCode)
		assert.Equal(t, "application/json", res.Header.Get("Content-Type"))
		assert.Equal(t, helpers.BuildCacheControlHeader(0), res.Header.Get("Cache-Control"))
		assert.Equal(t, []byte("dataJSON"), payload)
		hw.um.AssertExpectations(t)
	})
}
func TestInjectUserID(t *testing.T) {
sessionID := "sessionID"
checkUserID := func(expectedUserID interface{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if expectedUserID == nil {
assert.Nil(t, r.Context().Value(hub.UserIDKey))
} else {
assert.Equal(t, expectedUserID, r.Context().Value(hub.UserIDKey).(string))
}
}
}
t.Run("session cookie not provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
hw.h.InjectUserID(checkUserID(nil)).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("invalid session cookie provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: "invalidValue",
})
hw := newHandlersWrapper()
hw.h.InjectUserID(checkUserID(nil)).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("error checking session", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
})
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
Return(nil, tests.ErrFakeDB)
hw.h.InjectUserID(checkUserID(nil)).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("invalid session provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
Return(&hub.CheckSessionOutput{UserID: "", Valid: false}, nil)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
})
hw.h.InjectUserID(checkUserID(nil)).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("inject user id succeeded", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
Return(&hub.CheckSessionOutput{UserID: "userID", Valid: true}, nil)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
})
hw.h.InjectUserID(checkUserID("userID")).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
hw.um.AssertExpectations(t)
})
}
// TestLogin walks the Login handler through malformed input, missing and
// invalid credentials, failures in the user manager, and successful logins
// with an approved and a not yet approved session.
func TestLogin(t *testing.T) {
	sessionID := "sessionID"

	// newLoginRequest builds the POST request carrying payload as its body.
	newLoginRequest := func(payload string) *http.Request {
		req, _ := http.NewRequest("POST", "/", strings.NewReader(payload))
		return req
	}

	t.Run("invalid", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newLoginRequest(`{"email": "email" ...`)

		hw := newHandlersWrapper()
		hw.h.Login(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("credentials not provided", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newLoginRequest(`{}`)

		hw := newHandlersWrapper()
		hw.um.On("CheckCredentials", req.Context(), "", "").Return(nil, hub.ErrInvalidInput)
		hw.h.Login(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("error checking credentials", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newLoginRequest(`{"email": "email", "password": "pass"}`)

		hw := newHandlersWrapper()
		hw.um.On("CheckCredentials", req.Context(), "email", "pass").Return(nil, tests.ErrFakeDB)
		hw.h.Login(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("invalid credentials provided", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newLoginRequest(`{"email": "email", "password": "pass2"}`)

		hw := newHandlersWrapper()
		hw.um.On("CheckCredentials", req.Context(), "email", "pass2").
			Return(&hub.CheckCredentialsOutput{Valid: false, UserID: ""}, nil)
		hw.h.Login(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("error registering session", func(t *testing.T) {
		t.Parallel()
		rec := httptest.NewRecorder()
		req := newLoginRequest(`{"email": "email", "password": "pass"}`)

		hw := newHandlersWrapper()
		hw.um.On("CheckCredentials", req.Context(), "email", "pass").
			Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
		hw.um.On("RegisterSession", req.Context(), &hub.Session{UserID: "userID"}).
			Return(nil, tests.ErrFakeDB)
		hw.h.Login(rec, req)
		res := rec.Result()
		defer res.Body.Close()

		assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
		hw.um.AssertExpectations(t)
	})

	// Both successful scenarios differ only in whether the registered session
	// comes back approved, which is mirrored in the session approved header.
	successCases := []struct {
		name           string
		approved       bool
		approvedHeader string
	}{
		{name: "login succeeded", approved: true, approvedHeader: "true"},
		{name: "login succeeded (tfa enabled)", approved: false, approvedHeader: "false"},
	}
	for _, tc := range successCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			rec := httptest.NewRecorder()
			req := newLoginRequest(`{"email": "email", "password": "pass"}`)

			hw := newHandlersWrapper()
			hw.um.On("CheckCredentials", req.Context(), "email", "pass").
				Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
			hw.um.On("RegisterSession", req.Context(), &hub.Session{UserID: "userID"}).
				Return(&hub.Session{
					SessionID: sessionID,
					Approved:  tc.approved,
				}, nil)
			hw.h.Login(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, http.StatusNoContent, res.StatusCode)
			cookies := res.Cookies()
			require.Len(t, cookies, 1)
			sessionCookie := cookies[0]
			assert.Equal(t, sessionCookieName, sessionCookie.Name)
			assert.Equal(t, "/", sessionCookie.Path)
			assert.True(t, sessionCookie.HttpOnly)
			assert.False(t, sessionCookie.Secure)

			// The cookie value must decode back to the registered session id.
			var decodedSessionID string
			err := hw.h.sc.Decode(sessionCookieName, sessionCookie.Value, &decodedSessionID)
			require.NoError(t, err)
			assert.Equal(t, sessionID, decodedSessionID)
			assert.Equal(t, tc.approvedHeader, res.Header.Get(SessionApprovedHeader))
			hw.um.AssertExpectations(t)
		})
	}
}
// TestLogout checks that the Logout handler always answers 204 and sends back
// an already expired session cookie, whether or not a usable session cookie
// was provided and whether or not deleting the session succeeds.
func TestLogout(t *testing.T) {
	t.Run("invalid or no session cookie provided", func(t *testing.T) {
		// Descriptions now match their fixtures: the nil cookie is the
		// "no cookie" case and the undecodable value is the "invalid" case.
		testCases := []struct {
			description string
			cookie      *http.Cookie
		}{
			{
				"no session cookie provided",
				nil,
			},
			{
				"invalid session cookie provided",
				&http.Cookie{
					Name:  sessionCookieName,
					Value: "invalidValue",
				},
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.description, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("GET", "/", nil)
				if tc.cookie != nil {
					r.AddCookie(tc.cookie)
				}

				hw := newHandlersWrapper()
				hw.h.Logout(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				assert.Equal(t, http.StatusNoContent, resp.StatusCode)
				require.Len(t, resp.Cookies(), 1)
				cookie := resp.Cookies()[0]
				assert.Equal(t, sessionCookieName, cookie.Name)
				// The returned cookie must have expired more than a day ago.
				assert.True(t, cookie.Expires.Before(time.Now().Add(-24*time.Hour)))
			})
		}
	})

	t.Run("valid session cookie provided", func(t *testing.T) {
		testCases := []struct {
			description string
			err         interface{}
		}{
			{
				"session deleted successfully",
				nil,
			},
			{
				"error deleting session",
				tests.ErrFakeDB,
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.description, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("GET", "/", nil)

				hw := newHandlersWrapper()
				hw.um.On("DeleteSession", r.Context(), "sessionID").Return(tc.err)
				encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, "sessionID")
				r.AddCookie(&http.Cookie{
					Name:  sessionCookieName,
					Value: encodedSessionID,
				})
				hw.h.Logout(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				// The response is the same even when deleting the session
				// fails.
				assert.Equal(t, http.StatusNoContent, resp.StatusCode)
				require.Len(t, resp.Cookies(), 1)
				cookie := resp.Cookies()[0]
				assert.Equal(t, sessionCookieName, cookie.Name)
				assert.Equal(t, "/", cookie.Path)
				assert.True(t, cookie.Expires.Before(time.Now().Add(-24*time.Hour)))
				hw.um.AssertExpectations(t)
			})
		}
	})
}
// TestOauthCallback verifies that the OauthCallback handler redirects to the
// oauth failed URL when the code, the state or a valid state cookie is
// missing from the request.
func TestOauthCallback(t *testing.T) {
	t.Run("invalid oauth code or state", func(t *testing.T) {
		state := &OauthState{
			Random:      "abcd",
			RedirectURL: "/",
		}
		urlWithState := "/?code=1234&state=" + state.String()

		scenarios := []struct {
			description string
			url         string
			cookie      *http.Cookie
		}{
			{description: "oauth code not provided", url: "/"},
			{description: "oauth state not provided", url: "/?code=1234"},
			{description: "state cookie not provided", url: urlWithState},
			{
				description: "invalid state cookie",
				url:         urlWithState,
				cookie: &http.Cookie{
					Name:  oauthStateCookieName,
					Value: "something not expected",
				},
			},
		}
		for _, sc := range scenarios {
			sc := sc
			t.Run(sc.description, func(t *testing.T) {
				t.Parallel()
				rec := httptest.NewRecorder()
				req, _ := http.NewRequest("GET", sc.url, nil)
				if sc.cookie != nil {
					req.AddCookie(sc.cookie)
				}

				hw := newHandlersWrapper()
				hw.h.OauthCallback(rec, req)
				res := rec.Result()
				defer res.Body.Close()

				assert.Equal(t, http.StatusSeeOther, res.StatusCode)
				location, err := res.Location()
				require.NoError(t, err)
				assert.Equal(t, oauthFailedURL, location.String())
			})
		}
	})
}
// TestOauthRedirect checks that the OauthRedirect handler, for the github
// provider, sets the oauth state cookie and redirects to the provider's
// authorization URL built from that state.
func TestOauthRedirect(t *testing.T) {
	t.Parallel()

	routeCtx := &chi.Context{
		URLParams: chi.RouteParams{
			Keys:   []string{"provider"},
			Values: []string{"github"},
		},
	}
	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/", nil)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx))

	hw := newHandlersWrapper()
	hw.h.OauthRedirect(rec, req)
	res := rec.Result()
	defer res.Body.Close()

	cookies := res.Cookies()
	require.Len(t, cookies, 1)
	stateCookie := cookies[0]
	assert.Equal(t, oauthStateCookieName, stateCookie.Name)
	assert.NotEmpty(t, stateCookie.Value)
	assert.Equal(t, "/", stateCookie.Path)
	assert.True(t, stateCookie.HttpOnly)
	assert.False(t, stateCookie.Secure)
	assert.Equal(t, http.StatusSeeOther, res.StatusCode)

	// The redirect target must carry a state built from the cookie value.
	expectedState := &OauthState{
		Random:      stateCookie.Value,
		RedirectURL: "/",
	}
	expectedLocation := hw.h.oauthConfig["github"].AuthCodeURL(expectedState.String())
	location, err := res.Location()
	require.NoError(t, err)
	assert.Equal(t, expectedLocation, location.String())
}
// TestRegisterPasswordResetCode checks the RegisterPasswordResetCode handler
// with a malformed body and with both outcomes of the user manager call. The
// handler is expected to answer 201 whether or not registering the code
// succeeded.
func TestRegisterPasswordResetCode(t *testing.T) {
	t.Run("invalid input", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`email`)
		r, _ := http.NewRequest("POST", "/", body)

		hw := newHandlersWrapper()
		hw.h.RegisterPasswordResetCode(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("register password reset code failed", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"email": "email"}`)
		r, _ := http.NewRequest("POST", "/", body)

		hw := newHandlersWrapper()
		hw.um.On("RegisterPasswordResetCode", r.Context(), "email", "baseURL").Return(tests.ErrFakeDB)
		hw.h.RegisterPasswordResetCode(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusCreated, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("register password reset code succeeded", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		body := strings.NewReader(`{"email": "email"}`)
		r, _ := http.NewRequest("POST", "/", body)

		hw := newHandlersWrapper()
		// The success path must have the manager return no error; previously
		// this subtest returned the fake DB error and so merely duplicated
		// the failure scenario above.
		hw.um.On("RegisterPasswordResetCode", r.Context(), "email", "baseURL").Return(nil)
		hw.h.RegisterPasswordResetCode(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusCreated, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})
}
// TestRegisterUser checks the RegisterUser handler with an empty body, with
// several invalid user payloads and with a valid user for which the user
// manager either succeeds or fails.
func TestRegisterUser(t *testing.T) {
	t.Run("no user provided", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("POST", "/", strings.NewReader(""))

		hw := newHandlersWrapper()
		hw.h.RegisterUser(w, r)
		resp := w.Result()
		defer resp.Body.Close()

		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
	})

	t.Run("invalid user provided", func(t *testing.T) {
		testCases := []struct {
			description string
			userJSON    string
			umErr       error
		}{
			{
				"invalid json",
				"-",
				nil,
			},
			{
				"missing password",
				`{"alias": "alias", "email": "email"}`,
				nil,
			},
			{
				"missing alias",
				`{"email": "email", "password": "password"}`,
				hub.ErrInvalidInput,
			},
			{
				"missing email",
				`{"alias": "alias", "password": "password"}`,
				hub.ErrInvalidInput,
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.description, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("POST", "/", strings.NewReader(tc.userJSON))

				hw := newHandlersWrapper()
				// Only cases with a manager error register an expectation;
				// for the others the manager must not be called at all.
				if tc.umErr != nil {
					hw.um.On("RegisterUser", r.Context(), mock.Anything, "baseURL").Return(tc.umErr)
				}
				hw.h.RegisterUser(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
				hw.um.AssertExpectations(t)
			})
		}
	})

	t.Run("valid user provided", func(t *testing.T) {
		userJSON := `
{
"alias": "alias",
"first_name": "first_name",
"last_name": "last_name",
"email": "email",
"password": "password"
}
`
		// Decode the fixture into the user the manager is expected to
		// receive. A broken fixture now fails the test right here instead of
		// silently producing an empty expected user.
		u := &hub.User{}
		require.NoError(t, json.Unmarshal([]byte(userJSON), u))

		testCases := []struct {
			description        string
			umErr              error
			expectedStatusCode int
		}{
			{
				"registration succeeded",
				nil,
				http.StatusCreated,
			},
			{
				"registration failed",
				tests.ErrFakeDB,
				http.StatusInternalServerError,
			},
		}
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.description, func(t *testing.T) {
				t.Parallel()
				w := httptest.NewRecorder()
				r, _ := http.NewRequest("POST", "/", strings.NewReader(userJSON))

				hw := newHandlersWrapper()
				hw.um.On("RegisterUser", r.Context(), u, "baseURL").Return(tc.umErr)
				hw.h.RegisterUser(w, r)
				resp := w.Result()
				defer resp.Body.Close()

				assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
				hw.um.AssertExpectations(t)
			})
		}
	})
}
// TestRequireLogin checks the RequireLogin middleware for its two
// authentication methods, API key headers and session cookie, as well as for
// requests using neither of them.
func TestRequireLogin(t *testing.T) {
	sessionID := "sessionID"

	t.Run("api key based authentication", func(t *testing.T) {
		apiKeyID := "keyID"
		apiKeySecret := "secret"

		t.Run("no api key id or secret provided", func(t *testing.T) {
			testCases := []struct {
				apiKeyID     string
				apiKeySecret string
			}{
				{
					"",
					"secret",
				},
				{
					"key",
					"",
				},
			}
			for i, tc := range testCases {
				tc := tc
				t.Run(strconv.Itoa(i), func(t *testing.T) {
					t.Parallel()
					w := httptest.NewRecorder()
					r, _ := http.NewRequest("GET", "/", nil)
					r.Header.Add(APIKeyIDHeader, tc.apiKeyID)
					r.Header.Add(APIKeySecretHeader, tc.apiKeySecret)

					hw := newHandlersWrapper()
					hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
					resp := w.Result()
					defer resp.Body.Close()
					h := resp.Header
					data, _ := ioutil.ReadAll(resp.Body)

					assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
					assert.Equal(t, "application/json", h.Get("Content-Type"))
					assert.Equal(t, buildError(""), data)
				})
			}
		})

		// The subtests below set their expectations on the API key manager
		// mock (hw.am), so that is the mock whose expectations are verified;
		// the user manager mock has none registered in these scenarios.
		t.Run("error checking api key", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)
			r.Header.Add(APIKeyIDHeader, apiKeyID)
			r.Header.Add(APIKeySecretHeader, apiKeySecret)

			hw := newHandlersWrapper()
			hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB)
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()
			h := resp.Header
			data, _ := ioutil.ReadAll(resp.Body)

			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
			assert.Equal(t, "application/json", h.Get("Content-Type"))
			assert.Equal(t, buildError(""), data)
			hw.am.AssertExpectations(t)
		})

		t.Run("invalid api key provided", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)
			r.Header.Add(APIKeyIDHeader, apiKeyID)
			r.Header.Add(APIKeySecretHeader, apiKeySecret)

			hw := newHandlersWrapper()
			hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
				Return(&hub.CheckAPIKeyOutput{UserID: "", Valid: false}, nil)
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()
			h := resp.Header
			data, _ := ioutil.ReadAll(resp.Body)

			assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
			assert.Equal(t, "application/json", h.Get("Content-Type"))
			assert.Equal(t, buildError(errInvalidAPIKey.Error()), data)
			hw.am.AssertExpectations(t)
		})

		t.Run("api key based authentication succeeded", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)
			r.Header.Add(APIKeyIDHeader, apiKeyID)
			r.Header.Add(APIKeySecretHeader, apiKeySecret)

			hw := newHandlersWrapper()
			hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
				Return(&hub.CheckAPIKeyOutput{UserID: "userID", Valid: true}, nil)
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()

			assert.Equal(t, http.StatusOK, resp.StatusCode)
			hw.am.AssertExpectations(t)
		})
	})

	t.Run("session cookie based authentication", func(t *testing.T) {
		t.Run("invalid session cookie provided", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)
			r.AddCookie(&http.Cookie{
				Name:  sessionCookieName,
				Value: "invalidValue",
			})

			hw := newHandlersWrapper()
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()
			h := resp.Header
			data, _ := ioutil.ReadAll(resp.Body)

			assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
			assert.Equal(t, "application/json", h.Get("Content-Type"))
			assert.Equal(t, buildError(errInvalidSession.Error()), data)
		})

		t.Run("error checking session", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)

			hw := newHandlersWrapper()
			hw.um.On("CheckSession", r.Context(), sessionID, sessionDuration).
				Return(nil, tests.ErrFakeDB)
			encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
			r.AddCookie(&http.Cookie{
				Name:  sessionCookieName,
				Value: encodedSessionID,
			})
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()
			h := resp.Header
			data, _ := ioutil.ReadAll(resp.Body)

			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
			assert.Equal(t, "application/json", h.Get("Content-Type"))
			assert.Equal(t, buildError(""), data)
			hw.um.AssertExpectations(t)
		})

		t.Run("invalid session provided", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)

			hw := newHandlersWrapper()
			hw.um.On("CheckSession", r.Context(), sessionID, sessionDuration).
				Return(&hub.CheckSessionOutput{UserID: "", Valid: false}, nil)
			encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
			r.AddCookie(&http.Cookie{
				Name:  sessionCookieName,
				Value: encodedSessionID,
			})
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()
			h := resp.Header
			data, _ := ioutil.ReadAll(resp.Body)

			assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
			assert.Equal(t, "application/json", h.Get("Content-Type"))
			assert.Equal(t, buildError(errInvalidSession.Error()), data)
			hw.um.AssertExpectations(t)
		})

		t.Run("session cookie based authentication succeeded", func(t *testing.T) {
			t.Parallel()
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", "/", nil)

			hw := newHandlersWrapper()
			hw.um.On("CheckSession", r.Context(), sessionID, sessionDuration).
				Return(&hub.CheckSessionOutput{UserID: "userID", Valid: true}, nil)
			encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
			r.AddCookie(&http.Cookie{
				Name:  sessionCookieName,
				Value: encodedSessionID,
			})
			hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
			resp := w.Result()
			defer resp.Body.Close()

			assert.Equal(t, http.StatusOK, resp.StatusCode)
			hw.um.AssertExpectations(t)
		})
	})

	t.Run("no authentication method used", func(t *testing.T) {
		t.Parallel()
		w := httptest.NewRecorder()
		r, _ := http.NewRequest("GET", "/", nil)

		hw := newHandlersWrapper()
		hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
		resp := w.Result()
		defer resp.Body.Close()
		h := resp.Header
		data, _ := ioutil.ReadAll(resp.Body)

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
		assert.Equal(t, "application/json", h.Get("Content-Type"))
		assert.Equal(t, buildError(""), data)
	})
}
func TestResetPassword(t *testing.T) {
t.Run("invalid input", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`code`)
r, _ := http.NewRequest("PUT", "/", body)
hw := newHandlersWrapper()
hw.h.ResetPassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("password reset failed (invalid code)", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"code": "code", "password": "password"}`)
r, _ := http.NewRequest("PUT", "/", body)
hw := newHandlersWrapper()
hw.um.On("ResetPassword", r.Context(), "code", "password", "baseURL").
Return(user.ErrInvalidPasswordResetCode)
hw.h.ResetPassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("password reset failed (db error)", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"code": "code", "password": "password"}`)
r, _ := http.NewRequest("PUT", "/", body)
hw := newHandlersWrapper()
hw.um.On("ResetPassword", r.Context(), "code", "password", "baseURL").Return(tests.ErrFakeDB)
hw.h.ResetPassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("password reset succeeded", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"code": "code", "password": "password"}`)
r, _ := http.NewRequest("PUT", "/", body)
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw := newHandlersWrapper()
hw.um.On("ResetPassword", r.Context(), "code", "password", "baseURL").Return(nil)
hw.h.ResetPassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
hw.um.AssertExpectations(t)
})
}
func TestSetupTFA(t *testing.T) {
t.Run("tfa setup failed", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "/", nil)
hw := newHandlersWrapper()
hw.um.On("SetupTFA", r.Context()).Return(nil, tests.ErrFake)
hw.h.SetupTFA(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("tfa setup succeeded", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "/", nil)
hw := newHandlersWrapper()
hw.um.On("SetupTFA", r.Context()).Return([]byte("dataJSON"), nil)
hw.h.SetupTFA(w, r)
resp := w.Result()
defer resp.Body.Close()
h := resp.Header
data, _ := ioutil.ReadAll(resp.Body)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.Equal(t, "application/json", h.Get("Content-Type"))
assert.Equal(t, []byte("dataJSON"), data)
hw.um.AssertExpectations(t)
})
}
func TestUpdatePassword(t *testing.T) {
t.Run("no old password provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"new": "new"}`)
r, _ := http.NewRequest("PUT", "/", body)
hw := newHandlersWrapper()
hw.um.On("UpdatePassword", r.Context(), "", "new").Return(hub.ErrInvalidInput)
hw.h.UpdatePassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("no new password provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"old": "old"}`)
r, _ := http.NewRequest("PUT", "/", body)
hw := newHandlersWrapper()
hw.um.On("UpdatePassword", r.Context(), "old", "").Return(hub.ErrInvalidInput)
hw.h.UpdatePassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("invalid old password provided", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"old": "invalid", "new": "new"}`)
r, _ := http.NewRequest("PUT", "/", body)
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw := newHandlersWrapper()
hw.um.On("UpdatePassword", r.Context(), "invalid", "new").
Return(user.ErrInvalidPassword)
hw.h.UpdatePassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("error updating password", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"old": "old", "new": "new"}`)
r, _ := http.NewRequest("PUT", "/", body)
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw := newHandlersWrapper()
hw.um.On("UpdatePassword", r.Context(), "old", "new").
Return(tests.ErrFakeDB)
hw.h.UpdatePassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
hw.um.AssertExpectations(t)
})
t.Run("password updated successfully", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
body := strings.NewReader(`{"old": "old", "new": "new"}`)
r, _ := http.NewRequest("PUT", "/", body)
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw := newHandlersWrapper()
hw.um.On("UpdatePassword", r.Context(), "old", "new").Return(nil)
hw.h.UpdatePassword(w, r)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
hw.um.AssertExpectations(t)
})
}
// TestUpdateProfile exercises the UpdateProfile handler with rejected inputs
// (all expected to yield 400) and with a well-formed profile for which the
// mocked user manager either fails or succeeds.
func TestUpdateProfile(t *testing.T) {
	const profileJSON = `{"first_name": "firstname", "last_name": "lastname"}`

	// Decode the fixture into the value the user manager mock is expected to
	// receive; the decode error is discarded.
	profile := new(hub.User)
	_ = json.Unmarshal([]byte(profileJSON), profile)

	t.Run("invalid input", func(t *testing.T) {
		badInputs := []struct {
			name    string
			payload string
			// managerErr, when non-nil, is returned by a mocked UpdateProfile
			// call; when nil no expectation is registered at all.
			managerErr error
		}{
			{name: "no user provided", payload: ""},
			{name: "invalid user json", payload: "{invalid json"},
			{name: "alias not provided", payload: "{}", managerErr: hub.ErrInvalidInput},
		}

		for _, bi := range badInputs {
			bi := bi
			t.Run(bi.name, func(t *testing.T) {
				t.Parallel()

				req, _ := http.NewRequest("PUT", "/", strings.NewReader(bi.payload))
				rec := httptest.NewRecorder()

				hw := newHandlersWrapper()
				if bi.managerErr != nil {
					hw.um.On("UpdateProfile", req.Context(), mock.Anything).Return(bi.managerErr)
				}
				hw.h.UpdateProfile(rec, req)

				resp := rec.Result()
				defer resp.Body.Close()
				assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
				hw.um.AssertExpectations(t)
			})
		}
	})

	outcomes := []struct {
		name       string
		managerErr error
		wantStatus int
	}{
		{
			name:       "error updating profile",
			managerErr: tests.ErrFakeDB,
			wantStatus: http.StatusInternalServerError,
		},
		{
			name:       "user profile updated successfully",
			managerErr: nil,
			wantStatus: http.StatusNoContent,
		},
	}

	for _, oc := range outcomes {
		oc := oc
		t.Run(oc.name, func(t *testing.T) {
			t.Parallel()

			req, _ := http.NewRequest("PUT", "/", strings.NewReader(profileJSON))
			req = req.WithContext(context.WithValue(req.Context(), hub.UserIDKey, "userID"))
			rec := httptest.NewRecorder()

			hw := newHandlersWrapper()
			hw.um.On("UpdateProfile", req.Context(), profile).Return(oc.managerErr)
			hw.h.UpdateProfile(rec, req)

			resp := rec.Result()
			defer resp.Body.Close()
			assert.Equal(t, oc.wantStatus, resp.StatusCode)
			hw.um.AssertExpectations(t)
		})
	}
}
// TestVerifyEmail exercises the VerifyEmail handler. Every scenario posts the
// same code ("1234"); only the result returned by the mocked user manager
// changes, and with it the expected status code.
func TestVerifyEmail(t *testing.T) {
	scenarios := []struct {
		name string
		// verified and managerErr are the two values returned by the mocked
		// VerifyEmail call.
		verified   bool
		managerErr error
		wantStatus int
	}{
		{
			name:       "code not provided",
			verified:   false,
			managerErr: hub.ErrInvalidInput,
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "code not verified",
			verified:   false,
			managerErr: nil,
			wantStatus: http.StatusGone,
		},
		{
			name:       "code verified",
			verified:   true,
			managerErr: nil,
			wantStatus: http.StatusNoContent,
		},
		{
			name:       "database error",
			verified:   false,
			managerErr: tests.ErrFakeDB,
			wantStatus: http.StatusInternalServerError,
		},
	}

	for _, sc := range scenarios {
		sc := sc
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			req, _ := http.NewRequest("POST", "/", strings.NewReader(`{"code": "1234"}`))
			rec := httptest.NewRecorder()

			hw := newHandlersWrapper()
			hw.um.On("VerifyEmail", req.Context(), "1234").Return(sc.verified, sc.managerErr)
			hw.h.VerifyEmail(rec, req)

			resp := rec.Result()
			defer resp.Body.Close()
			assert.Equal(t, sc.wantStatus, resp.StatusCode)
			hw.um.AssertExpectations(t)
		})
	}
}
// TestVerifyPasswordResetCode exercises the VerifyPasswordResetCode handler
// with a body that is not valid JSON, and with a well-formed body for each
// error value the mocked user manager returns.
func TestVerifyPasswordResetCode(t *testing.T) {
	t.Run("invalid input", func(t *testing.T) {
		t.Parallel()

		req, _ := http.NewRequest("POST", "/", strings.NewReader(`code`))
		rec := httptest.NewRecorder()

		// No expectation is registered on the user manager mock here.
		hw := newHandlersWrapper()
		hw.h.VerifyPasswordResetCode(rec, req)

		resp := rec.Result()
		defer resp.Body.Close()
		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
		hw.um.AssertExpectations(t)
	})

	t.Run("valid input", func(t *testing.T) {
		scenarios := []struct {
			name       string
			managerErr error
			wantStatus int
		}{
			{
				name:       "valid code",
				managerErr: nil,
				wantStatus: http.StatusOK,
			},
			{
				name:       "invalid code",
				managerErr: user.ErrInvalidPasswordResetCode,
				wantStatus: http.StatusGone,
			},
			{
				name:       "database error",
				managerErr: tests.ErrFakeDB,
				wantStatus: http.StatusInternalServerError,
			},
		}

		for _, sc := range scenarios {
			sc := sc
			t.Run(sc.name, func(t *testing.T) {
				t.Parallel()

				req, _ := http.NewRequest("POST", "/", strings.NewReader(`{"code": "code"}`))
				rec := httptest.NewRecorder()

				hw := newHandlersWrapper()
				hw.um.On("VerifyPasswordResetCode", req.Context(), "code").Return(sc.managerErr)
				hw.h.VerifyPasswordResetCode(rec, req)

				resp := rec.Result()
				defer resp.Body.Close()
				assert.Equal(t, sc.wantStatus, resp.StatusCode)
				hw.um.AssertExpectations(t)
			})
		}
	})
}
// testsOK is a no-op HTTP handler. The tests in this file wrap it with
// RequireLogin to stand in for the protected handler.
func testsOK(_ http.ResponseWriter, _ *http.Request) {
	// Intentionally empty: neither the response writer nor the request is
	// touched.
}
// handlersWrapper bundles the handlers under test together with the
// configuration and the manager mocks they were built with, so that tests can
// register expectations on the mocks and then invoke the handlers.
type handlersWrapper struct {
	// cfg is the configuration passed to NewHandlers.
	cfg *viper.Viper

	// um is the user manager mock passed to NewHandlers.
	um *user.ManagerMock

	// am is the API key manager mock passed to NewHandlers.
	am *apikey.ManagerMock

	// h is the Handlers instance returned by NewHandlers.
	h *Handlers
}
// newHandlersWrapper builds a Handlers instance on top of fresh user and API
// key manager mocks and a minimal configuration (base URL "baseURL" and an
// empty GitHub OAuth section), and returns all of them in a handlersWrapper.
//
// It panics if NewHandlers reports an error. Previously the error was
// discarded, so a construction failure would only surface later, when a test
// used the handlers value, with no hint of the root cause.
func newHandlersWrapper() *handlersWrapper {
	cfg := viper.New()
	cfg.Set("server.baseURL", "baseURL")
	cfg.Set("server.oauth.github", map[string]string{})

	um := &user.ManagerMock{}
	am := &apikey.ManagerMock{}

	h, err := NewHandlers(context.Background(), um, am, cfg)
	if err != nil {
		// This helper has no *testing.T to report through, so fail loudly at
		// the point where the problem actually happened.
		panic("newHandlersWrapper: creating handlers: " + err.Error())
	}

	return &handlersWrapper{
		cfg: cfg,
		um:  um,
		am:  am,
		h:   h,
	}
}
// buildError returns the JSON document {"message": msg} followed by a single
// newline byte. Tests compare it against the body of error responses.
func buildError(msg string) []byte {
	// The marshal error is discarded.
	encoded, _ := json.Marshal(map[string]string{"message": msg})
	encoded = append(encoded, '\n')
	return encoded
}