mirror of https://github.com/artifacthub/hub.git
1562 lines
40 KiB
Go
1562 lines
40 KiB
Go
package user
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha512"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/artifacthub/hub/internal/email"
|
|
"github.com/artifacthub/hub/internal/hub"
|
|
"github.com/artifacthub/hub/internal/tests"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/pquerna/otp/totp"
|
|
"github.com/satori/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
func TestApproveSession(t *testing.T) {
|
|
ctx := context.Background()
|
|
sessionID := []byte("sessionID")
|
|
hashedSessionID := hashSessionID([]byte("sessionID"))
|
|
opts := totp.GenerateOpts{
|
|
Issuer: "Artifact Hub",
|
|
AccountName: "test@email.com",
|
|
}
|
|
key, _ := totp.Generate(opts)
|
|
code1 := "code1"
|
|
tfaConfig := &hub.TFAConfig{
|
|
Enabled: true,
|
|
URL: key.URL(),
|
|
RecoveryCodes: []string{code1},
|
|
}
|
|
tfaConfigJSON, _ := json.Marshal(tfaConfig)
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
sessionID []byte
|
|
passcode string
|
|
}{
|
|
{
|
|
"sessionID not provided",
|
|
nil,
|
|
"123456",
|
|
},
|
|
{
|
|
"sessionID not provided",
|
|
[]byte(""),
|
|
"123456",
|
|
},
|
|
{
|
|
"passcode not provided",
|
|
[]byte("sessionID"),
|
|
"",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.ApproveSession(ctx, tc.sessionID, tc.passcode)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("error getting user id from session", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashedSessionID).Return("", tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.ApproveSession(ctx, sessionID, "123456")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error getting tfa config from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.ApproveSession(ctx, sessionID, "123456")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("invalid passcode provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.ApproveSession(ctx, sessionID, "123456")
|
|
assert.Equal(t, errInvalidTFAPasscode, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("session approved successfully", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, "").Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.ApproveSession(ctx, sessionID, passcode)
|
|
assert.Nil(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("session approved successfully (using valid recovery code)", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, code1).Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.ApproveSession(ctx, sessionID, code1)
|
|
assert.Nil(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestCheckAPIKey(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
apiKeyID string
|
|
apiKeySecret string
|
|
}{
|
|
{
|
|
"api key id or secret not provided",
|
|
"",
|
|
"secret",
|
|
},
|
|
{
|
|
"api key id or secret not provided",
|
|
"key",
|
|
"",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.CheckAPIKey(ctx, tc.apiKeyID, tc.apiKeySecret)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("key info not found in database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, pgx.ErrNoRows)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error getting key info from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, output)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost)
|
|
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
|
assert.NoError(t, err)
|
|
assert.True(t, output.Valid)
|
|
assert.Equal(t, "userID", output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("valid key (secret hashed with sha512)", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret")))
|
|
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
|
assert.NoError(t, err)
|
|
assert.True(t, output.Valid)
|
|
assert.Equal(t, "userID", output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestCheckAvailability(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
resourceKind string
|
|
value string
|
|
}{
|
|
{
|
|
"invalid resource kind",
|
|
"invalid",
|
|
"value",
|
|
},
|
|
{
|
|
"invalid value",
|
|
"userAlias",
|
|
"",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.CheckAvailability(ctx, tc.resourceKind, tc.value)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database query succeeded", func(t *testing.T) {
|
|
testCases := []struct {
|
|
resourceKind string
|
|
dbQuery string
|
|
available bool
|
|
}{
|
|
{
|
|
"userAlias",
|
|
checkUserAliasAvailDBQ,
|
|
true,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(fmt.Sprintf("resource kind: %s", tc.resourceKind), func(t *testing.T) {
|
|
t.Parallel()
|
|
tc.dbQuery = fmt.Sprintf("select not exists (%s)", tc.dbQuery)
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, tc.dbQuery, "value").Return(tc.available, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
available, err := m.CheckAvailability(ctx, tc.resourceKind, "value")
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.available, available)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
dbQuery := fmt.Sprintf(`select not exists (%s)`, checkUserAliasAvailDBQ)
|
|
db.On("QueryRow", ctx, dbQuery, "value").Return(false, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
available, err := m.CheckAvailability(ctx, "userAlias", "value")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.False(t, available)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestCheckCredentials(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
email string
|
|
password string
|
|
}{
|
|
{
|
|
"email not provided",
|
|
"",
|
|
"password",
|
|
},
|
|
{
|
|
"password not provided",
|
|
"email",
|
|
"",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.CheckCredentials(ctx, tc.email, tc.password)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("credentials provided not found in database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, checkUserCredsDBQ, "email").Return(nil, pgx.ErrNoRows)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckCredentials(ctx, "email", "pass")
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error getting credentials from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, checkUserCredsDBQ, "email").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckCredentials(ctx, "email", "pass")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, output)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("invalid credentials provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
pw, _ := bcrypt.GenerateFromPassword([]byte("pass"), bcrypt.DefaultCost)
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, checkUserCredsDBQ, "email").Return([]interface{}{"userID", string(pw)}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckCredentials(ctx, "email", "pass2")
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("valid credentials provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
pw, _ := bcrypt.GenerateFromPassword([]byte("pass"), bcrypt.DefaultCost)
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, checkUserCredsDBQ, "email").Return([]interface{}{"userID", string(pw)}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckCredentials(ctx, "email", "pass")
|
|
assert.NoError(t, err)
|
|
assert.True(t, output.Valid)
|
|
assert.Equal(t, "userID", output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestCheckSession(t *testing.T) {
|
|
ctx := context.Background()
|
|
hashedSessionID := hashSessionID([]byte("sessionID"))
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
sessionID []byte
|
|
duration time.Duration
|
|
}{
|
|
{
|
|
"session id not provided",
|
|
nil,
|
|
10,
|
|
},
|
|
{
|
|
"session id not provided",
|
|
[]byte(""),
|
|
10,
|
|
},
|
|
{
|
|
"duration not provided",
|
|
[]byte("sessionID"),
|
|
0,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.CheckSession(ctx, tc.sessionID, tc.duration)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("session not found in database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, pgx.ErrNoRows)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error getting session from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, output)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("session has expired", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return([]interface{}{
|
|
"userID",
|
|
int64(1),
|
|
true,
|
|
}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("session is not approved", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return([]interface{}{
|
|
"userID",
|
|
time.Now().Unix(),
|
|
false,
|
|
}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
|
assert.NoError(t, err)
|
|
assert.False(t, output.Valid)
|
|
assert.Empty(t, output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("valid session", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return([]interface{}{
|
|
"userID",
|
|
time.Now().Unix(),
|
|
true,
|
|
}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
|
assert.NoError(t, err)
|
|
assert.True(t, output.Valid)
|
|
assert.Equal(t, "userID", output.UserID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestDeleteSession(t *testing.T) {
|
|
ctx := context.Background()
|
|
hashedSessionID := hashSessionID([]byte("sessionID"))
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
sessionID []byte
|
|
duration time.Duration
|
|
}{
|
|
{
|
|
"session id not provided",
|
|
nil,
|
|
10,
|
|
},
|
|
{
|
|
"session id not provided",
|
|
[]byte(""),
|
|
10,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.DeleteSession(ctx, tc.sessionID)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("valid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
description string
|
|
dbResponse interface{}
|
|
}{
|
|
{
|
|
"session deleted successfully",
|
|
nil,
|
|
},
|
|
{
|
|
"error deleting session from database",
|
|
tests.ErrFakeDB,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("Exec", ctx, deleteSessionDBQ, hashedSessionID).Return(tc.dbResponse)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.DeleteSession(ctx, []byte("sessionID"))
|
|
assert.Equal(t, tc.dbResponse, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDisableTFA(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
opts := totp.GenerateOpts{
|
|
Issuer: "Artifact Hub",
|
|
AccountName: "test@email.com",
|
|
}
|
|
key, _ := totp.Generate(opts)
|
|
code1 := "code1"
|
|
tfaConfig := &hub.TFAConfig{
|
|
Enabled: true,
|
|
URL: key.URL(),
|
|
RecoveryCodes: []string{code1},
|
|
}
|
|
tfaConfigJSON, _ := json.Marshal(tfaConfig)
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_ = m.DisableTFA(ctx, "123456")
|
|
})
|
|
})
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.DisableTFA(ctx, "")
|
|
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
})
|
|
|
|
t.Run("error getting tfa config from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.DisableTFA(ctx, "123456")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("invalid passcode provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.DisableTFA(ctx, "123456")
|
|
assert.Equal(t, errInvalidTFAPasscode, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error setting tfa as disabled in the database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, disableTFADBQ, "userID").Return(tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.DisableTFA(ctx, passcode)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("tfa disabled successfully", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, disableTFADBQ, "userID").Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.DisableTFA(ctx, passcode)
|
|
assert.Nil(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("tfa disabled successfully (using valid recovery code)", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, disableTFADBQ, "userID").Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.DisableTFA(ctx, code1)
|
|
assert.Nil(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestEnableTFA(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
opts := totp.GenerateOpts{
|
|
Issuer: "Artifact Hub",
|
|
AccountName: "test@email.com",
|
|
}
|
|
key, _ := totp.Generate(opts)
|
|
tfaConfig := &hub.TFAConfig{
|
|
Enabled: true,
|
|
URL: key.URL(),
|
|
}
|
|
tfaConfigJSON, _ := json.Marshal(tfaConfig)
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_ = m.EnableTFA(ctx, "123456")
|
|
})
|
|
})
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.EnableTFA(ctx, "")
|
|
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
})
|
|
|
|
t.Run("error getting tfa config from database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.EnableTFA(ctx, "123456")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("invalid passcode provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.EnableTFA(ctx, "123456")
|
|
assert.Equal(t, errInvalidTFAPasscode, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error setting tfa as enabled in the database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, enableTFADBQ, "userID").Return(tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.EnableTFA(ctx, passcode)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error sending tfa enabled email nofication", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, enableTFADBQ, "userID").Return(nil)
|
|
db.On("QueryRow", ctx, getUserEmailDBQ, "userID").Return("email", nil)
|
|
es := &email.SenderMock{}
|
|
es.On("SendEmail", mock.Anything).Return(email.ErrFakeSenderFailure)
|
|
m := NewManager(db, es)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.EnableTFA(ctx, passcode)
|
|
assert.Equal(t, email.ErrFakeSenderFailure, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("tfa enabled successfully", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
|
db.On("Exec", ctx, enableTFADBQ, "userID").Return(nil)
|
|
db.On("QueryRow", ctx, getUserEmailDBQ, "userID").Return("email", nil)
|
|
es := &email.SenderMock{}
|
|
es.On("SendEmail", mock.Anything).Return(nil)
|
|
m := NewManager(db, es)
|
|
|
|
passcode, _ := totp.GenerateCode(key.Secret(), time.Now())
|
|
err := m.EnableTFA(ctx, passcode)
|
|
assert.Nil(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestGetProfile(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_, _ = m.GetProfile(context.Background())
|
|
})
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserProfileDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
profile, err := m.GetProfile(ctx)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, profile)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database query succeeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
expectedProfile := &hub.User{
|
|
Alias: "alias",
|
|
FirstName: "first_name",
|
|
LastName: "last_name",
|
|
Email: "email",
|
|
ProfileImageID: "profile_image_id",
|
|
PasswordSet: true,
|
|
TFAEnabled: true,
|
|
}
|
|
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserProfileDBQ, "userID").Return([]byte(`
|
|
{
|
|
"alias": "alias",
|
|
"first_name": "first_name",
|
|
"last_name": "last_name",
|
|
"email": "email",
|
|
"profile_image_id": "profile_image_id",
|
|
"password_set": true,
|
|
"tfa_enabled": true
|
|
}
|
|
`), nil)
|
|
m := NewManager(db, nil)
|
|
|
|
profile, err := m.GetProfile(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedProfile, profile)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestGetProfileJSON(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_, _ = m.GetProfileJSON(context.Background())
|
|
})
|
|
})
|
|
|
|
t.Run("database query succeeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserProfileDBQ, "userID").Return([]byte("dataJSON"), nil)
|
|
m := NewManager(db, nil)
|
|
|
|
data, err := m.GetProfileJSON(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("dataJSON"), data)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserProfileDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
data, err := m.GetProfileJSON(ctx)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, data)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestGetUserID(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.GetUserID(ctx, "")
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
})
|
|
|
|
t.Run("database query succeeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromEmailDBQ, "email").Return("userID", nil)
|
|
m := NewManager(db, nil)
|
|
|
|
userID, err := m.GetUserID(ctx, "email")
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "userID", userID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserIDFromEmailDBQ, "email").Return("", tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
userID, err := m.GetUserID(ctx, "email")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Empty(t, userID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestRegisterSession(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
s := &hub.Session{
|
|
UserID: "00000000-0000-0000-0000-000000000001",
|
|
IP: "192.168.1.100",
|
|
UserAgent: "Safari 13.0.5",
|
|
}
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
userID string
|
|
}{
|
|
{
|
|
"user id not provided",
|
|
"",
|
|
},
|
|
{
|
|
"invalid user id",
|
|
"invalid",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
s := &hub.Session{UserID: tc.userID}
|
|
_, err := m.RegisterSession(ctx, s)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
sessionID, err := m.RegisterSession(ctx, s)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.Nil(t, sessionID)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("successful session registration", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return([]interface{}{
|
|
[]byte("sessionID"),
|
|
true,
|
|
}, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
session, err := m.RegisterSession(ctx, s)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("sessionID"), session.SessionID)
|
|
assert.True(t, session.Approved)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestRegisterPasswordResetCode(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
userEmail string
|
|
baseURL string
|
|
}{
|
|
{
|
|
"email not provided",
|
|
"",
|
|
"http://baseurl.com",
|
|
},
|
|
{
|
|
"invalid base url",
|
|
"email@email.com",
|
|
"invalid",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
es := &email.SenderMock{}
|
|
m := NewManager(nil, es)
|
|
|
|
err := m.RegisterPasswordResetCode(ctx, tc.userEmail, tc.baseURL)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("successful password reset code registration in database", func(t *testing.T) {
|
|
testCases := []struct {
|
|
description string
|
|
emailSenderResponse error
|
|
}{
|
|
{
|
|
"password reset code sent successfully",
|
|
nil,
|
|
},
|
|
{
|
|
"error sending password reset code",
|
|
email.ErrFakeSenderFailure,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return([]byte("code"), nil)
|
|
es := &email.SenderMock{}
|
|
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
|
|
m := NewManager(db, es)
|
|
|
|
err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com")
|
|
assert.Equal(t, tc.emailSenderResponse, err)
|
|
db.AssertExpectations(t)
|
|
es.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error registering password reset code", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(nil, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestRegisterUser(t *testing.T) {
|
|
ctx := context.Background()
|
|
password := "a66bV.Xp2" // #nosec
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
user *hub.User
|
|
baseURL string
|
|
}{
|
|
{
|
|
"alias not provided",
|
|
&hub.User{},
|
|
"http://baseurl.com",
|
|
},
|
|
{
|
|
"email not provided",
|
|
&hub.User{Alias: "user1"},
|
|
"http://baseurl.com",
|
|
},
|
|
{
|
|
"invalid base url",
|
|
&hub.User{Alias: "user1", Email: "email"},
|
|
"invalid",
|
|
},
|
|
{
|
|
"invalid profile image id",
|
|
&hub.User{Alias: "user1", Email: "email", ProfileImageID: "invalid"},
|
|
"http://baseurl.com",
|
|
},
|
|
{
|
|
"insecure password",
|
|
&hub.User{Alias: "user1", Email: "email", Password: "hello"},
|
|
"http://baseurl.com",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
es := &email.SenderMock{}
|
|
m := NewManager(nil, es)
|
|
|
|
err := m.RegisterUser(ctx, tc.user, tc.baseURL)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("successful user registration in database", func(t *testing.T) {
|
|
code := "emailVerificationCode"
|
|
testCases := []struct {
|
|
description string
|
|
emailSenderResponse error
|
|
}{
|
|
{
|
|
"email verification code sent successfully",
|
|
nil,
|
|
},
|
|
{
|
|
"error sending email verification code",
|
|
email.ErrFakeSenderFailure,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerUserDBQ, mock.Anything).Return(&code, nil)
|
|
es := &email.SenderMock{}
|
|
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
|
|
m := NewManager(db, es)
|
|
|
|
u := &hub.User{
|
|
Alias: "alias",
|
|
FirstName: "first_name",
|
|
LastName: "last_name",
|
|
Email: "email@email.com",
|
|
Password: password,
|
|
ProfileImageID: "00000000-0000-0000-0000-000000000001",
|
|
}
|
|
err := m.RegisterUser(ctx, u, "http://baseurl.com")
|
|
assert.Equal(t, tc.emailSenderResponse, err)
|
|
db.AssertExpectations(t)
|
|
es.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error registering user", func(t *testing.T) {
|
|
t.Parallel()
|
|
code := ""
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, registerUserDBQ, mock.Anything).Return(&code, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
u := &hub.User{
|
|
Alias: "alias",
|
|
Email: "email@email.com",
|
|
Password: password,
|
|
}
|
|
err := m.RegisterUser(ctx, u, "http://baseurl.com")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestResetPassword(t *testing.T) {
|
|
ctx := context.Background()
|
|
code := []byte("code")
|
|
codeB64 := base64.URLEncoding.EncodeToString(code)
|
|
newPassword := "a66bV.Xp2" // #nosec
|
|
baseURL := "http://baseurl.com"
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
codeB64 string
|
|
newPassword string
|
|
baseURL string
|
|
}{
|
|
{
|
|
"code not provided",
|
|
"",
|
|
newPassword,
|
|
baseURL,
|
|
},
|
|
{
|
|
"new password not provided",
|
|
"code",
|
|
"",
|
|
baseURL,
|
|
},
|
|
{
|
|
"invalid base url",
|
|
"code",
|
|
newPassword,
|
|
"invalid",
|
|
},
|
|
{
|
|
"insecure password",
|
|
"code",
|
|
"password",
|
|
baseURL,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
es := &email.SenderMock{}
|
|
m := NewManager(nil, es)
|
|
err := m.ResetPassword(ctx, tc.codeB64, tc.newPassword, tc.baseURL)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error resetting password", func(t *testing.T) {
|
|
testCases := []struct {
|
|
dbErr error
|
|
expectedErr error
|
|
}{
|
|
{
|
|
tests.ErrFake,
|
|
tests.ErrFake,
|
|
},
|
|
{
|
|
errInvalidPasswordResetCodeDB,
|
|
ErrInvalidPasswordResetCode,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.dbErr.Error(), func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("", tc.dbErr)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
|
|
assert.Equal(t, tc.expectedErr, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("successful password reset in database", func(t *testing.T) {
|
|
testCases := []struct {
|
|
description string
|
|
emailSenderResponse error
|
|
}{
|
|
{
|
|
"password reset success email sent successfully",
|
|
nil,
|
|
},
|
|
{
|
|
"error sending password reset success email",
|
|
email.ErrFakeSenderFailure,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("email", nil)
|
|
es := &email.SenderMock{}
|
|
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
|
|
m := NewManager(db, es)
|
|
|
|
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
|
|
assert.Equal(t, tc.emailSenderResponse, err)
|
|
db.AssertExpectations(t)
|
|
es.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSetupTFA(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_, _ = m.SetupTFA(context.Background())
|
|
})
|
|
})
|
|
|
|
t.Run("error getting requesting user email", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserEmailDBQ, "userID").Return("", tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
dataJSON, err := m.SetupTFA(ctx)
|
|
assert.Nil(t, dataJSON)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("error storing tfa info in database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserEmailDBQ, "userID").Return("email", nil)
|
|
db.On("Exec", ctx, updateTFAInfoDBQ, "userID", mock.Anything, mock.Anything).Return(tests.ErrFake)
|
|
m := NewManager(db, nil)
|
|
|
|
dataJSON, err := m.SetupTFA(ctx)
|
|
assert.Nil(t, dataJSON)
|
|
assert.Equal(t, tests.ErrFake, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("setup tfa succeeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserEmailDBQ, "userID").Return("email", nil)
|
|
db.On("Exec", ctx, updateTFAInfoDBQ, "userID", mock.Anything, mock.Anything).Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
dataJSON, err := m.SetupTFA(ctx)
|
|
assert.NotNil(t, dataJSON)
|
|
assert.Nil(t, err)
|
|
var output *hub.SetupTFAOutput
|
|
err = json.Unmarshal(dataJSON, &output)
|
|
require.NoError(t, err)
|
|
assert.NotEmpty(t, output.QRCode)
|
|
assert.NotEmpty(t, output.Secret)
|
|
for _, code := range output.RecoveryCodes {
|
|
_, err := uuid.FromString(code)
|
|
assert.NoError(t, err, code)
|
|
}
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestUpdatePassword(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
oldHashed, _ := bcrypt.GenerateFromPassword([]byte("old"), bcrypt.DefaultCost)
|
|
new := "a66bV.Xp2"
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_ = m.UpdatePassword(context.Background(), "old", "new")
|
|
})
|
|
})
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
old string
|
|
new string
|
|
}{
|
|
{
|
|
"old password not provided",
|
|
"",
|
|
"new",
|
|
},
|
|
{
|
|
"new password not provided",
|
|
"old",
|
|
"",
|
|
},
|
|
{
|
|
"insecure password",
|
|
"old",
|
|
"new",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.UpdatePassword(ctx, tc.old, tc.new)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error getting user password", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserPasswordDBQ, "userID").Return("", tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdatePassword(ctx, "old", new)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("invalid user password provided", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserPasswordDBQ, "userID").Return(string(oldHashed), nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdatePassword(ctx, "old2", new)
|
|
assert.Error(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database error updating password", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserPasswordDBQ, "userID").Return(string(oldHashed), nil)
|
|
db.On("Exec", ctx, updateUserPasswordDBQ, "userID", mock.Anything, mock.Anything).
|
|
Return(tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdatePassword(ctx, "old", new)
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("successful password update", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, getUserPasswordDBQ, "userID").Return(string(oldHashed), nil)
|
|
db.On("Exec", ctx, updateUserPasswordDBQ, "userID", mock.Anything, mock.Anything).Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdatePassword(ctx, "old", new)
|
|
assert.NoError(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestUpdateProfile(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
|
|
|
|
t.Run("user id not found in ctx", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
assert.Panics(t, func() {
|
|
_ = m.UpdateProfile(context.Background(), &hub.User{})
|
|
})
|
|
})
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
user *hub.User
|
|
}{
|
|
{
|
|
"alias not provided",
|
|
&hub.User{},
|
|
},
|
|
{
|
|
"invalid profile image id",
|
|
&hub.User{Alias: "user1", Email: "email", ProfileImageID: "invalid"},
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.UpdateProfile(ctx, tc.user)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database query succeeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("Exec", ctx, updateUserProfileDBQ, "userID", mock.Anything).Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdateProfile(ctx, &hub.User{Alias: "user1"})
|
|
assert.NoError(t, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database error", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("Exec", ctx, updateUserProfileDBQ, "userID", mock.Anything).Return(tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.UpdateProfile(ctx, &hub.User{Alias: "user1"})
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestVerifyEmail(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
_, err := m.VerifyEmail(ctx, "")
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
})
|
|
|
|
t.Run("successful email verification", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, verifyEmailDBQ, "emailVerificationCode").Return(true, nil)
|
|
m := NewManager(db, nil)
|
|
|
|
verified, err := m.VerifyEmail(ctx, "emailVerificationCode")
|
|
assert.NoError(t, err)
|
|
assert.True(t, verified)
|
|
db.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("database error verifying email", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("QueryRow", ctx, verifyEmailDBQ, "emailVerificationCode").Return(false, tests.ErrFakeDB)
|
|
m := NewManager(db, nil)
|
|
|
|
verified, err := m.VerifyEmail(ctx, "emailVerificationCode")
|
|
assert.Equal(t, tests.ErrFakeDB, err)
|
|
assert.False(t, verified)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestVerifyPasswordResetCode(t *testing.T) {
|
|
ctx := context.Background()
|
|
code := []byte("code")
|
|
codeB64 := base64.URLEncoding.EncodeToString(code)
|
|
|
|
t.Run("invalid input", func(t *testing.T) {
|
|
testCases := []struct {
|
|
errMsg string
|
|
codeB64 string
|
|
}{
|
|
{
|
|
"code not provided",
|
|
"",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.errMsg, func(t *testing.T) {
|
|
t.Parallel()
|
|
m := NewManager(nil, nil)
|
|
err := m.VerifyPasswordResetCode(ctx, tc.codeB64)
|
|
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
|
assert.Contains(t, err.Error(), tc.errMsg)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("database error verifying password reset code", func(t *testing.T) {
|
|
testCases := []struct {
|
|
dbErr error
|
|
expectedErr error
|
|
}{
|
|
{
|
|
tests.ErrFake,
|
|
tests.ErrFake,
|
|
},
|
|
{
|
|
errInvalidPasswordResetCodeDB,
|
|
ErrInvalidPasswordResetCode,
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.dbErr.Error(), func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(tc.dbErr)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.VerifyPasswordResetCode(ctx, codeB64)
|
|
assert.Equal(t, tc.expectedErr, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("password code verified successfully in database", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := &tests.DBMock{}
|
|
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(nil)
|
|
m := NewManager(db, nil)
|
|
|
|
err := m.VerifyPasswordResetCode(ctx, codeB64)
|
|
assert.Equal(t, nil, err)
|
|
db.AssertExpectations(t)
|
|
})
|
|
}
|