diff --git a/database/migrations/functions/users/register_password_reset_code.sql b/database/migrations/functions/users/register_password_reset_code.sql index 860619c7..347eb48b 100644 --- a/database/migrations/functions/users/register_password_reset_code.sql +++ b/database/migrations/functions/users/register_password_reset_code.sql @@ -1,11 +1,18 @@ -- register_password_reset_code registers a password reset code for the user -- identified by the email provided. create or replace function register_password_reset_code(p_email text) -returns uuid as $$ - insert into password_reset_code (user_id) - select user_id from "user" where email = p_email and email_verified = true +returns bytea as $$ +declare + v_code bytea := gen_random_bytes(32); +begin + insert into password_reset_code (password_reset_code_id, user_id) + select sha512(v_code), user_id from "user" where email = p_email and email_verified = true on conflict (user_id) do update set - password_reset_code_id = gen_random_uuid(), - created_at = current_timestamp - returning password_reset_code_id; -$$ language sql; + password_reset_code_id = sha512(v_code), + created_at = current_timestamp; + if not found then + raise 'invalid email'; + end if; + return v_code; +end +$$ language plpgsql; diff --git a/database/migrations/functions/users/reset_user_password.sql b/database/migrations/functions/users/reset_user_password.sql index a8acb17a..4fcc7d32 100644 --- a/database/migrations/functions/users/reset_user_password.sql +++ b/database/migrations/functions/users/reset_user_password.sql @@ -1,25 +1,25 @@ -- reset_user_password resets the password of the user associated to the code -- provided if it is still valid, returning the email of the user. -create or replace function reset_user_password(p_password_reset_code_id uuid, p_new_password text) +create or replace function reset_user_password(p_code bytea, p_new_password text) returns text as $$ declare v_user_id uuid; v_email text; begin -- Verify password reset code - perform verify_password_reset_code(p_password_reset_code_id); + perform verify_password_reset_code(p_code); -- Get id and email of the user associated with the code select u.user_id, u.email into v_user_id, v_email from "user" u join password_reset_code prc using (user_id) - where password_reset_code_id = p_password_reset_code_id; + where password_reset_code_id = sha512(p_code); -- Update user password update "user" set password = p_new_password where user_id = v_user_id; -- Delete password reset code - delete from password_reset_code where password_reset_code_id = p_password_reset_code_id; + delete from password_reset_code where user_id = v_user_id; -- Invalidate current user sessions delete from session where user_id = v_user_id; diff --git a/database/migrations/functions/users/verify_password_reset_code.sql b/database/migrations/functions/users/verify_password_reset_code.sql index e588c9dc..964237f4 100644 --- a/database/migrations/functions/users/verify_password_reset_code.sql +++ b/database/migrations/functions/users/verify_password_reset_code.sql @@ -1,10 +1,10 @@ -- verify_password_reset_code verifies is the password reset code provided is -- valid. The code must exist and not have expired. -create or replace function verify_password_reset_code(p_password_reset_code_id uuid) +create or replace function verify_password_reset_code(p_code bytea) returns void as $$ begin perform from password_reset_code - where password_reset_code_id = p_password_reset_code_id + where password_reset_code_id = sha512(p_code) and created_at + '15 minute'::interval > current_timestamp; if not found then raise 'invalid password reset code'; diff --git a/database/migrations/schema/009_hash_password_reset_code.sql b/database/migrations/schema/009_hash_password_reset_code.sql new file mode 100644 index 00000000..57425132 --- /dev/null +++ b/database/migrations/schema/009_hash_password_reset_code.sql @@ -0,0 +1,10 @@ +delete from password_reset_code; +alter table password_reset_code alter column password_reset_code_id drop default; +alter table password_reset_code alter column password_reset_code_id type bytea USING password_reset_code_id::text::bytea; +drop function if exists register_password_reset_code(text); + +---- create above / drop below ---- + +delete from password_reset_code; +alter table password_reset_code alter column password_reset_code_id type uuid; +alter table password_reset_code alter column password_reset_code_id set default gen_random_uuid(); diff --git a/database/tests/functions/users/register_password_reset_code.sql b/database/tests/functions/users/register_password_reset_code.sql index 0214f00a..38248e9b 100644 --- a/database/tests/functions/users/register_password_reset_code.sql +++ b/database/tests/functions/users/register_password_reset_code.sql @@ -12,7 +12,7 @@ values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', fals select register_password_reset_code('user1@email.com') as code1 \gset select is( password_reset_code_id, - :'code1', + sha512(:'code1'), 'Password reset code for user1 should be registered' ) from password_reset_code @@ -23,37 +23,31 @@ where alias = 'user1'; select register_password_reset_code('user1@email.com') as code2 \gset select is( password_reset_code_id, - :'code2', + sha512(:'code2'), 'Password reset code for user1 should have been updated' ) from password_reset_code join "user" using (user_id) where alias = 'user1'; select isnt( - :'code1'::uuid, - :'code2'::uuid, + :'code1'::bytea, + :'code2'::bytea, 'Password reset code must have changed' ); -- Try registering password reset code using non verified email -select register_password_reset_code('user2@email.com'); -select is_empty( - $$ - select password_reset_code_id - from password_reset_code prc join "user" u using (user_id) - where u.email = 'user2@email.com' - $$, +select throws_ok( + $$ select register_password_reset_code('user2@email.com') $$, + 'P0001', + 'invalid email', 'No password reset code should be registered for non verified email user2@email.com' ); -- Try registering password reset code using unregistered email -select register_password_reset_code('user3@email.com'); -select is_empty( - $$ - select password_reset_code_id - from password_reset_code prc join "user" u using (user_id) - where u.email = 'user3@email.com' - $$, +select throws_ok( + $$ select register_password_reset_code('user3@email.com') $$, + 'P0001', + 'invalid email', 'No password reset code should be registered for unregistered email user3@email.com' ); diff --git a/database/tests/functions/users/reset_user_password.sql b/database/tests/functions/users/reset_user_password.sql index 5e169e0a..9d0410fb 100644 --- a/database/tests/functions/users/reset_user_password.sql +++ b/database/tests/functions/users/reset_user_password.sql @@ -14,9 +14,9 @@ values (:'user1ID', 'user1', 'user1@email.com'); insert into "user" (user_id, alias, email) values (:'user2ID', 'user2', 'user2@email.com'); insert into password_reset_code (password_reset_code_id, user_id, created_at) -values (:'code1ID', :'user1ID', current_timestamp); +values (sha512(:'code1ID'), :'user1ID', current_timestamp); insert into password_reset_code (password_reset_code_id, user_id, created_at) -values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval); +values (sha512(:'code2ID'), :'user2ID', current_timestamp - '30 minute'::interval); insert into session (user_id) values (:'user1ID'); -- Password reset should fail in the following cases diff --git a/database/tests/functions/users/verify_password_reset_code.sql b/database/tests/functions/users/verify_password_reset_code.sql index c2991488..e205dce9 100644 --- a/database/tests/functions/users/verify_password_reset_code.sql +++ b/database/tests/functions/users/verify_password_reset_code.sql @@ -1,9 +1,10 @@ -- Start transaction and plan tests begin; -select plan(2); +select plan(3); -- Declare some variables \set user1ID '00000000-0000-0000-0000-000000000001' +\set code1ID '00000000-0000-0000-0000-000000000001' \set user2ID '00000000-0000-0000-0000-000000000002' \set code2ID '00000000-0000-0000-0000-000000000002' @@ -13,8 +14,16 @@ values (:'user1ID', 'user1', 'user1@email.com'); insert into "user" (user_id, alias, email) values (:'user2ID', 'user2', 'user2@email.com'); insert into password_reset_code (password_reset_code_id, user_id, created_at) +values (sha512(:'code1ID'), :'user1ID', current_timestamp - '5 minute'::interval); +insert into password_reset_code (password_reset_code_id, user_id, created_at) values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval); +-- Password reset should succeed +select lives_ok( + $$ select verify_password_reset_code('00000000-0000-0000-0000-000000000001') $$, + 'Verify password reset code succeeded' +); + -- Password reset should fail in the following cases select throws_ok( $$ select verify_password_reset_code('00000000-0000-0000-0000-000000000003') $$, diff --git a/internal/user/manager.go b/internal/user/manager.go index d0f04b1d..7cd90475 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/sha512" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -31,11 +32,11 @@ const ( registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text)` registerSessionDBQ = `select register_session($1::jsonb)` registerUserDBQ = `select register_user($1::jsonb)` - resetUserPasswordDBQ = `select reset_user_password($1::uuid, $2::text)` + resetUserPasswordDBQ = `select reset_user_password($1::bytea, $2::text)` updateUserPasswordDBQ = `select update_user_password($1::uuid, $2::text, $3::text)` updateUserProfileDBQ = `select update_user_profile($1::uuid, $2::jsonb)` verifyEmailDBQ = `select verify_email($1::uuid)` - verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::uuid)` + verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::bytea)` ) var ( @@ -282,7 +283,7 @@ func (m *Manager) RegisterPasswordResetCode(ctx context.Context, userEmail, base } // Register password reset code in database - var code *string + var code []byte err := m.db.QueryRow(ctx, registerPasswordResetCodeDBQ, userEmail).Scan(&code) if err != nil { return err @@ -290,8 +291,9 @@ func (m *Manager) RegisterPasswordResetCode(ctx context.Context, userEmail, base // Send password reset email if code != nil && m.es != nil { + codeB64 := base64.URLEncoding.EncodeToString(code) templateData := map[string]string{ - "link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, *code), + "link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, codeB64), } var emailBody bytes.Buffer if err := passwordResetTmpl.Execute(&emailBody, templateData); err != nil { @@ -392,9 +394,9 @@ func (m *Manager) RegisterUser(ctx context.Context, user *hub.User, baseURL stri } // ResetPassword resets the user password in the database. -func (m *Manager) ResetPassword(ctx context.Context, code, newPassword, baseURL string) error { +func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseURL string) error { // Validate input - if code == "" { + if codeB64 == "" { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided") } if newPassword == "" { @@ -414,6 +416,10 @@ func (m *Manager) ResetPassword(ctx context.Context, code, newPassword, baseURL } // Reset user password in database + code, err := base64.URLEncoding.DecodeString(codeB64) + if err != nil { + return ErrInvalidPasswordResetCode + } var userEmail string err = m.db.QueryRow(ctx, resetUserPasswordDBQ, code, string(newHashed)).Scan(&userEmail) if err != nil { @@ -515,14 +521,18 @@ func (m *Manager) VerifyEmail(ctx context.Context, code string) (bool, error) { } // VerifyPasswordResetCode verifies if the provided code is valid. -func (m *Manager) VerifyPasswordResetCode(ctx context.Context, code string) error { +func (m *Manager) VerifyPasswordResetCode(ctx context.Context, codeB64 string) error { // Validate input - if code == "" { + if codeB64 == "" { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided") } // Verify password reset code in database - _, err := m.db.Exec(ctx, verifyPasswordResetCodeDBQ, code) + code, err := base64.URLEncoding.DecodeString(codeB64) + if err != nil { + return ErrInvalidPasswordResetCode + } + _, err = m.db.Exec(ctx, verifyPasswordResetCodeDBQ, code) if err != nil && err.Error() == errInvalidPasswordResetCodeDB.Error() { return ErrInvalidPasswordResetCode } diff --git a/internal/user/manager_test.go b/internal/user/manager_test.go index b9c02f0c..36b399ce 100644 --- a/internal/user/manager_test.go +++ b/internal/user/manager_test.go @@ -3,6 +3,7 @@ package user import ( "context" "crypto/sha512" + "encoding/base64" "errors" "fmt" "testing" @@ -637,7 +638,6 @@ func TestRegisterPasswordResetCode(t *testing.T) { }) t.Run("successful password reset code registration in database", func(t *testing.T) { - code := "passwordResetCode" testCases := []struct { description string emailSenderResponse error @@ -656,7 +656,7 @@ func TestRegisterPasswordResetCode(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(&code, nil) + db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return([]byte("code"), nil) es := &email.SenderMock{} es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse) m := NewManager(db, es) @@ -671,9 +671,8 @@ func TestRegisterPasswordResetCode(t *testing.T) { t.Run("database error registering password reset code", func(t *testing.T) { t.Parallel() - code := "" db := &tests.DBMock{} - db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(&code, tests.ErrFakeDB) + db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(nil, tests.ErrFakeDB) m := NewManager(db, nil) err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com") @@ -787,11 +786,13 @@ func TestRegisterUser(t *testing.T) { func TestResetPassword(t *testing.T) { ctx := context.Background() + code := []byte("code") + codeB64 := base64.URLEncoding.EncodeToString(code) t.Run("invalid input", func(t *testing.T) { testCases := []struct { errMsg string - code string + codeB64 string newPassword string baseURL string }{ @@ -820,7 +821,7 @@ func TestResetPassword(t *testing.T) { t.Parallel() es := &email.SenderMock{} m := NewManager(nil, es) - err := m.ResetPassword(ctx, tc.code, tc.newPassword, tc.baseURL) + err := m.ResetPassword(ctx, tc.codeB64, tc.newPassword, tc.baseURL) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) assert.Contains(t, err.Error(), tc.errMsg) }) @@ -846,10 +847,10 @@ func TestResetPassword(t *testing.T) { t.Run(tc.dbErr.Error(), func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, resetUserPasswordDBQ, "code", mock.Anything).Return("", tc.dbErr) + db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("", tc.dbErr) m := NewManager(db, nil) - err := m.ResetPassword(ctx, "code", "newPassword", "http://baseurl.com") + err := m.ResetPassword(ctx, codeB64, "newPassword", "http://baseurl.com") assert.Equal(t, tc.expectedErr, err) db.AssertExpectations(t) }) @@ -875,12 +876,12 @@ func TestResetPassword(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, resetUserPasswordDBQ, "code", mock.Anything).Return("email", nil) + db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("email", nil) es := &email.SenderMock{} es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse) m := NewManager(db, es) - err := m.ResetPassword(ctx, "code", "newPassword", "http://baseurl.com") + err := m.ResetPassword(ctx, codeB64, "newPassword", "http://baseurl.com") assert.Equal(t, tc.emailSenderResponse, err) db.AssertExpectations(t) es.AssertExpectations(t) @@ -1075,11 +1076,13 @@ func TestVerifyEmail(t *testing.T) { func TestVerifyPasswordResetCode(t *testing.T) { ctx := context.Background() + code := []byte("code") + codeB64 := base64.URLEncoding.EncodeToString(code) t.Run("invalid input", func(t *testing.T) { testCases := []struct { - errMsg string - code string + errMsg string + codeB64 string }{ { "code not provided", @@ -1091,7 +1094,7 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run(tc.errMsg, func(t *testing.T) { t.Parallel() m := NewManager(nil, nil) - err := m.VerifyPasswordResetCode(ctx, tc.code) + err := m.VerifyPasswordResetCode(ctx, tc.codeB64) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) assert.Contains(t, err.Error(), tc.errMsg) }) @@ -1117,10 +1120,10 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run(tc.dbErr.Error(), func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("Exec", ctx, verifyPasswordResetCodeDBQ, "code").Return(tc.dbErr) + db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(tc.dbErr) m := NewManager(db, nil) - err := m.VerifyPasswordResetCode(ctx, "code") + err := m.VerifyPasswordResetCode(ctx, codeB64) assert.Equal(t, tc.expectedErr, err) db.AssertExpectations(t) }) @@ -1130,10 +1133,10 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run("password code verified successfully in database", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("Exec", ctx, verifyPasswordResetCodeDBQ, "code").Return(nil) + db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(nil) m := NewManager(db, nil) - err := m.VerifyPasswordResetCode(ctx, "code") + err := m.VerifyPasswordResetCode(ctx, codeB64) assert.Equal(t, nil, err) db.AssertExpectations(t) })