From 9726b926edfa76d5853d4479ff1f5c4c354164ce Mon Sep 17 00:00:00 2001 From: "Sergio C. Arteaga" Date: Thu, 6 May 2021 22:40:37 +0200 Subject: [PATCH] Some refactoring in user manager and handlers (#1291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sergio CastaƱo Arteaga --- .../functions/api_keys/add_api_key.sql | 10 +- .../functions/users/approve_session.sql | 2 +- .../users/register_password_reset_code.sql | 11 +- .../functions/users/register_session.sql | 7 +- .../functions/users/reset_user_password.sql | 4 +- .../users/verify_password_reset_code.sql | 4 +- .../schema/009_hash_password_reset_code.sql | 2 +- .../migrations/schema/015_refactoring.sql | 15 ++ .../tests/functions/api_keys/add_api_key.sql | 3 + .../users/register_password_reset_code.sql | 19 +- .../functions/users/register_session.sql | 44 +--- .../functions/users/reset_user_password.sql | 4 +- .../users/verify_password_reset_code.sql | 2 +- internal/apikey/manager.go | 77 +++++- internal/apikey/manager_test.go | 106 ++++++++- internal/apikey/mock.go | 11 +- internal/handlers/apikey/handlers.go | 9 +- internal/handlers/apikey/handlers_test.go | 9 +- internal/handlers/handlers.go | 2 +- internal/handlers/user/handlers.go | 43 ++-- internal/handlers/user/handlers_test.go | 46 ++-- internal/hub/apikey.go | 10 +- internal/hub/user.go | 15 +- internal/user/manager.go | 118 ++++------ internal/user/manager_test.go | 222 +++++------------- internal/user/mock.go | 13 +- 26 files changed, 416 insertions(+), 392 deletions(-) create mode 100644 database/migrations/schema/015_refactoring.sql diff --git a/database/migrations/functions/api_keys/add_api_key.sql b/database/migrations/functions/api_keys/add_api_key.sql index 58fe61a1..b5c983e6 100644 --- a/database/migrations/functions/api_keys/add_api_key.sql +++ b/database/migrations/functions/api_keys/add_api_key.sql @@ -1,9 +1,8 @@ -- add_api_key adds the provided api key to the database. create or replace function add_api_key(p_api_key jsonb) -returns setof json as $$ +returns uuid as $$ declare v_api_key_id uuid; - v_api_key_secret text := encode(gen_random_bytes(32), 'base64'); begin insert into api_key ( name, @@ -11,13 +10,10 @@ begin user_id ) values ( p_api_key->>'name', - encode(sha512(v_api_key_secret::bytea), 'hex'), + p_api_key->>'secret', (p_api_key->>'user_id')::uuid ) returning api_key_id into v_api_key_id; - return query select json_build_object( - 'api_key_id', v_api_key_id, - 'secret', v_api_key_secret - ); + return v_api_key_id; end $$ language plpgsql; diff --git a/database/migrations/functions/users/approve_session.sql b/database/migrations/functions/users/approve_session.sql index 7c5cc72a..d70223bd 100644 --- a/database/migrations/functions/users/approve_session.sql +++ b/database/migrations/functions/users/approve_session.sql @@ -1,5 +1,5 @@ -- approvee_session approves the provided session in the database. -create or replace function approve_session(p_session_id bytea, p_recovery_code text) +create or replace function approve_session(p_session_id text, p_recovery_code text) returns void as $$ begin -- Mark session as approved diff --git a/database/migrations/functions/users/register_password_reset_code.sql b/database/migrations/functions/users/register_password_reset_code.sql index 347eb48b..15ca943c 100644 --- a/database/migrations/functions/users/register_password_reset_code.sql +++ b/database/migrations/functions/users/register_password_reset_code.sql @@ -1,18 +1,15 @@ -- register_password_reset_code registers a password reset code for the user -- identified by the email provided. -create or replace function register_password_reset_code(p_email text) -returns bytea as $$ -declare - v_code bytea := gen_random_bytes(32); +create or replace function register_password_reset_code(p_email text, p_code text) +returns void as $$ begin insert into password_reset_code (password_reset_code_id, user_id) - select sha512(v_code), user_id from "user" where email = p_email and email_verified = true + select p_code, user_id from "user" where email = p_email and email_verified = true on conflict (user_id) do update set - password_reset_code_id = sha512(v_code), + password_reset_code_id = p_code, created_at = current_timestamp; if not found then raise 'invalid email'; end if; - return v_code; end $$ language plpgsql; diff --git a/database/migrations/functions/users/register_session.sql b/database/migrations/functions/users/register_session.sql index 371de28e..cd30d025 100644 --- a/database/migrations/functions/users/register_session.sql +++ b/database/migrations/functions/users/register_session.sql @@ -1,8 +1,7 @@ -- register_session registers the provided session in the database. create or replace function register_session(p_session jsonb) -returns table (session_id bytea, approved boolean) as $$ +returns boolean as $$ declare - v_session_id bytea := gen_random_bytes(32); v_approved boolean; begin -- Check if the session requires approval or not. When the user has enabled @@ -22,13 +21,13 @@ begin user_agent, approved ) values ( - sha512(v_session_id), + p_session->>'session_id', (p_session->>'user_id')::uuid, nullif(p_session->>'ip', '')::inet, nullif(p_session->>'user_agent', ''), v_approved ); - return query select v_session_id, v_approved; + return v_approved; end $$ language plpgsql; diff --git a/database/migrations/functions/users/reset_user_password.sql b/database/migrations/functions/users/reset_user_password.sql index 4fcc7d32..7c4aa139 100644 --- a/database/migrations/functions/users/reset_user_password.sql +++ b/database/migrations/functions/users/reset_user_password.sql @@ -1,6 +1,6 @@ -- reset_user_password resets the password of the user associated to the code -- provided if it is still valid, returning the email of the user. -create or replace function reset_user_password(p_code bytea, p_new_password text) +create or replace function reset_user_password(p_code text, p_new_password text) returns text as $$ declare v_user_id uuid; @@ -13,7 +13,7 @@ begin select u.user_id, u.email into v_user_id, v_email from "user" u join password_reset_code prc using (user_id) - where password_reset_code_id = sha512(p_code); + where password_reset_code_id = p_code; -- Update user password update "user" set password = p_new_password where user_id = v_user_id; diff --git a/database/migrations/functions/users/verify_password_reset_code.sql b/database/migrations/functions/users/verify_password_reset_code.sql index 964237f4..ff4ab807 100644 --- a/database/migrations/functions/users/verify_password_reset_code.sql +++ b/database/migrations/functions/users/verify_password_reset_code.sql @@ -1,10 +1,10 @@ -- verify_password_reset_code verifies is the password reset code provided is -- valid. The code must exist and not have expired. -create or replace function verify_password_reset_code(p_code bytea) +create or replace function verify_password_reset_code(p_code text) returns void as $$ begin perform from password_reset_code - where password_reset_code_id = sha512(p_code) + where password_reset_code_id = p_code and created_at + '15 minute'::interval > current_timestamp; if not found then raise 'invalid password reset code'; diff --git a/database/migrations/schema/009_hash_password_reset_code.sql b/database/migrations/schema/009_hash_password_reset_code.sql index 57425132..376982e2 100644 --- a/database/migrations/schema/009_hash_password_reset_code.sql +++ b/database/migrations/schema/009_hash_password_reset_code.sql @@ -1,6 +1,6 @@ delete from password_reset_code; alter table password_reset_code alter column password_reset_code_id drop default; -alter table password_reset_code alter column password_reset_code_id type bytea USING password_reset_code_id::text::bytea; +alter table password_reset_code alter column password_reset_code_id type bytea using password_reset_code_id::text::bytea; drop function if exists register_password_reset_code(text); ---- create above / drop below ---- diff --git a/database/migrations/schema/015_refactoring.sql b/database/migrations/schema/015_refactoring.sql new file mode 100644 index 00000000..b90a0dcf --- /dev/null +++ b/database/migrations/schema/015_refactoring.sql @@ -0,0 +1,15 @@ +delete from password_reset_code; +alter table password_reset_code alter column password_reset_code_id type text; +alter table session alter column session_id type text using substring(session_id::bytea::text from 3); +drop function if exists add_api_key(jsonb); +drop function if exists approve_session(bytea, text); +drop function if exists register_password_reset_code(text); +drop function if exists register_session(jsonb); +drop function if exists reset_user_password(bytea, text); +drop function if exists verify_password_reset_code(bytea); + +---- create above / drop below ---- + +delete from password_reset_code; +alter table password_reset_code alter column password_reset_code_id type bytea; +alter table session alter column session_id type bytea using session_id::text::bytea; diff --git a/database/tests/functions/api_keys/add_api_key.sql b/database/tests/functions/api_keys/add_api_key.sql index aeea90df..4e6ff76e 100644 --- a/database/tests/functions/api_keys/add_api_key.sql +++ b/database/tests/functions/api_keys/add_api_key.sql @@ -13,6 +13,7 @@ values (:'user1ID', 'user1', 'user1@email.com'); select add_api_key(' { "name": "apikey1", + "secret": "hashed-secret", "user_id": "00000000-0000-0000-0000-000000000001" } '::jsonb); @@ -22,12 +23,14 @@ select results_eq( $$ select name, + secret, user_id from api_key $$, $$ values ( 'apikey1', + 'hashed-secret', '00000000-0000-0000-0000-000000000001'::uuid ) $$, diff --git a/database/tests/functions/users/register_password_reset_code.sql b/database/tests/functions/users/register_password_reset_code.sql index 38248e9b..c2d96fde 100644 --- a/database/tests/functions/users/register_password_reset_code.sql +++ b/database/tests/functions/users/register_password_reset_code.sql @@ -1,6 +1,6 @@ -- Start transaction and plan tests begin; -select plan(5); +select plan(4); -- Seed user insert into "user" (user_id, alias, email, email_verified) @@ -9,10 +9,10 @@ insert into "user" (user_id, alias, email, email_verified) values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', false); -- Register password reset code -select register_password_reset_code('user1@email.com') as code1 \gset +select register_password_reset_code('user1@email.com', 'code1'); select is( password_reset_code_id, - sha512(:'code1'), + 'code1', 'Password reset code for user1 should be registered' ) from password_reset_code @@ -20,24 +20,19 @@ join "user" using (user_id) where alias = 'user1'; -- Register another password reset code for the same user -select register_password_reset_code('user1@email.com') as code2 \gset +select register_password_reset_code('user1@email.com', 'code2'); select is( password_reset_code_id, - sha512(:'code2'), + 'code2', 'Password reset code for user1 should have been updated' ) from password_reset_code join "user" using (user_id) where alias = 'user1'; -select isnt( - :'code1'::bytea, - :'code2'::bytea, - 'Password reset code must have changed' -); -- Try registering password reset code using non verified email select throws_ok( - $$ select register_password_reset_code('user2@email.com') $$, + $$ select register_password_reset_code('user2@email.com', 'code') $$, 'P0001', 'invalid email', 'No password reset code should be registered for non verified email user2@email.com' @@ -45,7 +40,7 @@ select throws_ok( -- Try registering password reset code using unregistered email select throws_ok( - $$ select register_password_reset_code('user3@email.com') $$, + $$ select register_password_reset_code('user3@email.com', 'code') $$, 'P0001', 'invalid email', 'No password reset code should be registered for unregistered email user3@email.com' diff --git a/database/tests/functions/users/register_session.sql b/database/tests/functions/users/register_session.sql index 38cd672a..85fb241f 100644 --- a/database/tests/functions/users/register_session.sql +++ b/database/tests/functions/users/register_session.sql @@ -1,6 +1,6 @@ -- Start transaction and plan tests begin; -select plan(6); +select plan(2); -- Seed user insert into "user" (user_id, alias, email) @@ -9,18 +9,20 @@ insert into "user" (user_id, alias, email, tfa_enabled) values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', true); -- Register session for user with tfa disabled -select session_id, approved from register_session(' +select register_session(' { + "session_id": "hashed-session-id-user1", "user_id": "00000000-0000-0000-0000-000000000001", "ip": "192.168.1.100", "user_agent": "Safari 13.0.5" } -') \gset +') as approved \gset -- Check if session registration succeeded select results_eq( $$ select + session_id, user_id, ip, user_agent, @@ -30,6 +32,7 @@ select results_eq( $$, $$ values ( + 'hashed-session-id-user1', '00000000-0000-0000-0000-000000000001'::uuid, '192.168.1.100'::inet, 'Safari 13.0.5', @@ -38,53 +41,32 @@ select results_eq( $$, 'Session for user1 should exist' ); -select is( - session_id, - sha512(:'session_id'), - 'Returned session id for user1 should match value stored' -) -from session where user_id = '00000000-0000-0000-0000-000000000001'; -select is( - true, - :'approved', - 'Returned approved value for user1 should be true' -) -from session where user_id = '00000000-0000-0000-0000-000000000001'; -- Register session for user with tfa enabled -select session_id, approved from register_session(' +select register_session(' { + "session_id": "hashed-session-id-user2", "user_id": "00000000-0000-0000-0000-000000000002" } -') \gset +') as approved \gset -- Check if session registration succeeded select results_eq( $$ - select user_id, approved + select + session_id, + approved from session where user_id = '00000000-0000-0000-0000-000000000002' $$, $$ values ( - '00000000-0000-0000-0000-000000000002'::uuid, + 'hashed-session-id-user2', false ) $$, 'Session for user2 should exist' ); -select is( - session_id, - sha512(:'session_id'), - 'Returned session id for user2 should match value stored' -) -from session where user_id = '00000000-0000-0000-0000-000000000002'; -select is( - false, - :'approved', - 'Returned approved value for user2 should be false' -) -from session where user_id = '00000000-0000-0000-0000-000000000002'; -- Finish tests and rollback transaction select * from finish(); diff --git a/database/tests/functions/users/reset_user_password.sql b/database/tests/functions/users/reset_user_password.sql index 29cdc253..0fd12be7 100644 --- a/database/tests/functions/users/reset_user_password.sql +++ b/database/tests/functions/users/reset_user_password.sql @@ -14,9 +14,9 @@ values (:'user1ID', 'user1', 'user1@email.com'); insert into "user" (user_id, alias, email) values (:'user2ID', 'user2', 'user2@email.com'); insert into password_reset_code (password_reset_code_id, user_id, created_at) -values (sha512(:'code1ID'), :'user1ID', current_timestamp); +values (:'code1ID', :'user1ID', current_timestamp); insert into password_reset_code (password_reset_code_id, user_id, created_at) -values (sha512(:'code2ID'), :'user2ID', current_timestamp - '30 minute'::interval); +values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval); insert into session (session_id, user_id) values (gen_random_bytes(32), :'user1ID'); -- Password reset should fail in the following cases diff --git a/database/tests/functions/users/verify_password_reset_code.sql b/database/tests/functions/users/verify_password_reset_code.sql index e205dce9..0ccc62ff 100644 --- a/database/tests/functions/users/verify_password_reset_code.sql +++ b/database/tests/functions/users/verify_password_reset_code.sql @@ -14,7 +14,7 @@ values (:'user1ID', 'user1', 'user1@email.com'); insert into "user" (user_id, alias, email) values (:'user2ID', 'user2', 'user2@email.com'); insert into password_reset_code (password_reset_code_id, user_id, created_at) -values (sha512(:'code1ID'), :'user1ID', current_timestamp - '5 minute'::interval); +values (:'code1ID', :'user1ID', current_timestamp - '5 minute'::interval); insert into password_reset_code (password_reset_code_id, user_id, created_at) values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval); diff --git a/internal/apikey/manager.go b/internal/apikey/manager.go index bd6b290c..a1af8e23 100644 --- a/internal/apikey/manager.go +++ b/internal/apikey/manager.go @@ -2,21 +2,29 @@ package apikey import ( "context" + "crypto/rand" + "crypto/sha512" + "encoding/base64" "encoding/json" + "errors" "fmt" + "strings" "github.com/artifacthub/hub/internal/hub" "github.com/artifacthub/hub/internal/util" + "github.com/jackc/pgx/v4" "github.com/satori/uuid" + "golang.org/x/crypto/bcrypt" ) const ( // Database queries - addAPIKeyDBQ = `select add_api_key($1::jsonb)` - deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)` - getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)` - getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)` - updateAPIKeyDBQ = `select update_api_key($1::jsonb)` + addAPIKeyDBQ = `select add_api_key($1::jsonb)` + deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)` + getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)` + getAPIKeyUserIDDBQ = `select user_id, secret from api_key where api_key_id = $1` + getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)` + updateAPIKeyDBQ = `select update_api_key($1::jsonb)` ) // Manager provides an API to manage api keys. @@ -32,7 +40,7 @@ func NewManager(db hub.DB) *Manager { } // Add adds the provided api key to the database. -func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) { +func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) { ak.UserID = ctx.Value(hub.UserIDKey).(string) // Validate input @@ -40,9 +48,64 @@ func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) { return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "name not provided") } + // Generate API key secret + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return nil, err + } + apiKeySecret := base64.StdEncoding.EncodeToString(randomBytes) + apiKeySecretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) + // Add api key to the database + var apiKeyID string + ak.Secret = apiKeySecretHashed akJSON, _ := json.Marshal(ak) - return util.DBQueryJSON(ctx, m.db, addAPIKeyDBQ, akJSON) + if err := m.db.QueryRow(ctx, addAPIKeyDBQ, akJSON).Scan(&apiKeyID); err != nil { + return nil, err + } + + return &hub.APIKey{ + APIKeyID: apiKeyID, + Secret: apiKeySecret, + }, nil +} + +// Check checks if the api key provided is valid. +func (m *Manager) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) { + // Validate input + if apiKeyID == "" || apiKeySecret == "" { + return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided") + } + + // Get key's user id and secret from database + var userID, apiKeySecretHashed string + err := m.db.QueryRow(ctx, getAPIKeyUserIDDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return &hub.CheckAPIKeyOutput{Valid: false}, nil + } + return nil, err + } + + // Check if the secret provided is valid + switch { + case strings.HasPrefix(apiKeySecretHashed, "$2a$"): + // Bcrypt hash, will be deprecated soon + err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret)) + if err != nil { + return &hub.CheckAPIKeyOutput{Valid: false}, nil + } + default: + // SHA512 hash + if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed { + return &hub.CheckAPIKeyOutput{Valid: false}, nil + } + } + + return &hub.CheckAPIKeyOutput{ + Valid: true, + UserID: userID, + }, nil } // Delete deletes the provided api key from the database. diff --git a/internal/apikey/manager_test.go b/internal/apikey/manager_test.go index 30a4aa48..96cd89bf 100644 --- a/internal/apikey/manager_test.go +++ b/internal/apikey/manager_test.go @@ -2,13 +2,18 @@ package apikey import ( "context" + "crypto/sha512" "encoding/json" "errors" + "fmt" "testing" "github.com/artifacthub/hub/internal/hub" "github.com/artifacthub/hub/internal/tests" + "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/bcrypt" ) const apiKeyID = "00000000-0000-0000-0000-000000000001" @@ -60,14 +65,13 @@ func TestAdd(t *testing.T) { Name: "apikey1", UserID: "userID", } - akJSON, _ := json.Marshal(ak) db := &tests.DBMock{} - db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return(nil, tests.ErrFakeDB) + db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return(nil, tests.ErrFakeDB) m := NewManager(db) - keyInfoJSON, err := m.Add(ctx, ak) + output, err := m.Add(ctx, ak) assert.Equal(t, tests.ErrFakeDB, err) - assert.Nil(t, keyInfoJSON) + assert.Nil(t, output) db.AssertExpectations(t) }) @@ -77,14 +81,100 @@ func TestAdd(t *testing.T) { Name: "apikey1", UserID: "userID", } - akJSON, _ := json.Marshal(ak) db := &tests.DBMock{} - db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return([]byte("keyInfoJSON"), nil) + db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return("apiKeyID", nil) m := NewManager(db) - keyInfoJSON, err := m.Add(ctx, ak) + output, err := m.Add(ctx, ak) assert.NoError(t, err) - assert.Equal(t, []byte("keyInfoJSON"), keyInfoJSON) + assert.Equal(t, "apiKeyID", output.APIKeyID) + assert.NotEmpty(t, output.Secret) + db.AssertExpectations(t) + }) +} + +func TestCheck(t *testing.T) { + ctx := context.Background() + + t.Run("invalid input", func(t *testing.T) { + testCases := []struct { + errMsg string + apiKeyID string + apiKeySecret string + }{ + { + "api key id or secret not provided", + "", + "secret", + }, + { + "api key id or secret not provided", + "key", + "", + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.errMsg, func(t *testing.T) { + t.Parallel() + m := NewManager(nil) + _, err := m.Check(ctx, tc.apiKeyID, tc.apiKeySecret) + assert.True(t, errors.Is(err, hub.ErrInvalidInput)) + assert.Contains(t, err.Error(), tc.errMsg) + }) + } + }) + + t.Run("key info not found in database", func(t *testing.T) { + t.Parallel() + db := &tests.DBMock{} + db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, pgx.ErrNoRows) + m := NewManager(db) + + output, err := m.Check(ctx, "keyID", "secret") + assert.NoError(t, err) + assert.False(t, output.Valid) + assert.Empty(t, output.UserID) + db.AssertExpectations(t) + }) + + t.Run("error getting key info from database", func(t *testing.T) { + t.Parallel() + db := &tests.DBMock{} + db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, tests.ErrFakeDB) + m := NewManager(db) + + output, err := m.Check(ctx, "keyID", "secret") + assert.Equal(t, tests.ErrFakeDB, err) + assert.Nil(t, output) + db.AssertExpectations(t) + }) + + t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) { + t.Parallel() + db := &tests.DBMock{} + secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil) + m := NewManager(db) + + output, err := m.Check(ctx, "keyID", "secret") + assert.NoError(t, err) + assert.True(t, output.Valid) + assert.Equal(t, "userID", output.UserID) + db.AssertExpectations(t) + }) + + t.Run("valid key (secret hashed with sha512)", func(t *testing.T) { + t.Parallel() + db := &tests.DBMock{} + secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret"))) + db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil) + m := NewManager(db) + + output, err := m.Check(ctx, "keyID", "secret") + assert.NoError(t, err) + assert.True(t, output.Valid) + assert.Equal(t, "userID", output.UserID) db.AssertExpectations(t) }) } diff --git a/internal/apikey/mock.go b/internal/apikey/mock.go index a85a1985..ccc017cf 100644 --- a/internal/apikey/mock.go +++ b/internal/apikey/mock.go @@ -13,9 +13,16 @@ type ManagerMock struct { } // Add implements the APIKeyManager interface. -func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) { +func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) { args := m.Called(ctx, ak) - data, _ := args.Get(0).([]byte) + data, _ := args.Get(0).(*hub.APIKey) + return data, args.Error(1) +} + +// Check implements the UserManager interface. +func (m *ManagerMock) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) { + args := m.Called(ctx, apiKeyID, apiKeySecret) + data, _ := args.Get(0).(*hub.CheckAPIKeyOutput) return data, args.Error(1) } diff --git a/internal/handlers/apikey/handlers.go b/internal/handlers/apikey/handlers.go index 2cebf937..a9df1933 100644 --- a/internal/handlers/apikey/handlers.go +++ b/internal/handlers/apikey/handlers.go @@ -28,19 +28,20 @@ func NewHandlers(apiKeyManager hub.APIKeyManager) *Handlers { // Add is an http handler that adds the provided api key to the database. func (h *Handlers) Add(w http.ResponseWriter, r *http.Request) { - ak := &hub.APIKey{} - if err := json.NewDecoder(r.Body).Decode(&ak); err != nil { + akIN := &hub.APIKey{} + if err := json.NewDecoder(r.Body).Decode(&akIN); err != nil { h.logger.Error().Err(err).Str("method", "Add").Msg(hub.ErrInvalidInput.Error()) helpers.RenderErrorJSON(w, hub.ErrInvalidInput) return } - dataJSON, err := h.apiKeyManager.Add(r.Context(), ak) + akOUT, err := h.apiKeyManager.Add(r.Context(), akIN) if err != nil { h.logger.Error().Err(err).Str("method", "Add").Send() helpers.RenderErrorJSON(w, err) return } - helpers.RenderJSON(w, dataJSON, 0, http.StatusCreated) + akOUTJSON, _ := json.Marshal(akOUT) + helpers.RenderJSON(w, akOUTJSON, 0, http.StatusCreated) } // Delete is an http handler that deletes the provided api key from the database. diff --git a/internal/handlers/apikey/handlers_test.go b/internal/handlers/apikey/handlers_test.go index c5fe0170..a177acfa 100644 --- a/internal/handlers/apikey/handlers_test.go +++ b/internal/handlers/apikey/handlers_test.go @@ -100,7 +100,11 @@ func TestAdd(t *testing.T) { r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID")) hw := newHandlersWrapper() - hw.am.On("Add", r.Context(), ak).Return([]byte("keyInfoJSON"), nil) + akOUT := &hub.APIKey{ + APIKeyID: "apiKeyID", + Secret: "secret", + } + hw.am.On("Add", r.Context(), ak).Return(akOUT, nil) hw.h.Add(w, r) resp := w.Result() defer resp.Body.Close() @@ -110,7 +114,8 @@ func TestAdd(t *testing.T) { assert.Equal(t, http.StatusCreated, resp.StatusCode) assert.Equal(t, "application/json", h.Get("Content-Type")) assert.Equal(t, helpers.BuildCacheControlHeader(0), h.Get("Cache-Control")) - assert.Equal(t, []byte("keyInfoJSON"), data) + outputAKJSON, _ := json.Marshal(akOUT) + assert.Equal(t, outputAKJSON, data) hw.am.AssertExpectations(t) }) } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 6da80619..728abb8f 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -78,7 +78,7 @@ type Handlers struct { // Setup creates a new Handlers instance. func Setup(ctx context.Context, cfg *viper.Viper, svc *Services) (*Handlers, error) { - userHandlers, err := user.NewHandlers(ctx, svc.UserManager, cfg) + userHandlers, err := user.NewHandlers(ctx, svc.UserManager, svc.APIKeyManager, cfg) if err != nil { return nil, err } diff --git a/internal/handlers/user/handlers.go b/internal/handlers/user/handlers.go index c9359710..b6958d3c 100644 --- a/internal/handlers/user/handlers.go +++ b/internal/handlers/user/handlers.go @@ -66,16 +66,22 @@ var ( // Handlers represents a group of http handlers in charge of handling // users operations. type Handlers struct { - userManager hub.UserManager - cfg *viper.Viper - sc *securecookie.SecureCookie - oauthConfig map[string]*oauth2.Config - oidcProvider *oidc.Provider - logger zerolog.Logger + userManager hub.UserManager + apiKeyManager hub.APIKeyManager + cfg *viper.Viper + sc *securecookie.SecureCookie + oauthConfig map[string]*oauth2.Config + oidcProvider *oidc.Provider + logger zerolog.Logger } // NewHandlers creates a new Handlers instance. -func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Viper) (*Handlers, error) { +func NewHandlers( + ctx context.Context, + userManager hub.UserManager, + apiKeyManager hub.APIKeyManager, + cfg *viper.Viper, +) (*Handlers, error) { // Setup secure cookie instance sc := securecookie.New([]byte(cfg.GetString("server.cookie.hashKey")), nil) sc.MaxAge(int(sessionDuration.Seconds())) @@ -112,12 +118,13 @@ func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Vi } return &Handlers{ - userManager: userManager, - cfg: cfg, - sc: sc, - oauthConfig: oauthConfig, - oidcProvider: oidcProvider, - logger: log.With().Str("handlers", "user").Logger(), + userManager: userManager, + apiKeyManager: apiKeyManager, + cfg: cfg, + sc: sc, + oauthConfig: oauthConfig, + oidcProvider: oidcProvider, + logger: log.With().Str("handlers", "user").Logger(), }, nil } @@ -137,7 +144,7 @@ func (h *Handlers) ApproveSession(w http.ResponseWriter, r *http.Request) { passcode := input["passcode"] // Extract sessionID from cookie - var sessionID []byte + var sessionID string cookie, err := r.Cookie(sessionCookieName) if err != nil { h.logger.Error().Err(err).Str("method", "ApproveSession").Msg("session cookie not found") @@ -290,7 +297,7 @@ func (h *Handlers) InjectUserID(next http.Handler) http.Handler { if err != nil { return } - var sessionID []byte + var sessionID string if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil { return } @@ -371,7 +378,7 @@ func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) { // Delete user session cookie, err := r.Cookie(sessionCookieName) if err == nil { - var sessionID []byte + var sessionID string err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID) if err == nil { err = h.userManager.DeleteSession(r.Context(), sessionID) @@ -749,7 +756,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler { // Use API key based authentication if API key is provided if apiKeyID != "" && apiKeySecret != "" { // Check the API key provided is valid - checkAPIKeyOutput, err := h.userManager.CheckAPIKey(r.Context(), apiKeyID, apiKeySecret) + checkAPIKeyOutput, err := h.apiKeyManager.Check(r.Context(), apiKeyID, apiKeySecret) if err != nil { h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("checkAPIKey failed") helpers.RenderErrorWithCodeJSON(w, nil, http.StatusInternalServerError) @@ -766,7 +773,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler { cookie, err := r.Cookie(sessionCookieName) if err == nil { // Extract and validate cookie from request - var sessionID []byte + var sessionID string if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil { h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("sessionID decoding failed") helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized) diff --git a/internal/handlers/user/handlers_test.go b/internal/handlers/user/handlers_test.go index c5cf2df1..d236540b 100644 --- a/internal/handlers/user/handlers_test.go +++ b/internal/handlers/user/handlers_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/artifacthub/hub/internal/apikey" "github.com/artifacthub/hub/internal/handlers/helpers" "github.com/artifacthub/hub/internal/hub" "github.com/artifacthub/hub/internal/tests" @@ -31,7 +32,7 @@ func TestMain(m *testing.M) { } func TestApproveSession(t *testing.T) { - sessionID := []byte("sessionID") + sessionID := "sessionID" t.Run("invalid input", func(t *testing.T) { testCases := []struct { @@ -495,6 +496,8 @@ func TestGetProfile(t *testing.T) { } func TestInjectUserID(t *testing.T) { + sessionID := "sessionID" + checkUserID := func(expectedUserID interface{}) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if expectedUserID == nil { @@ -541,7 +544,7 @@ func TestInjectUserID(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) hw := newHandlersWrapper() - encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID")) + encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID) r.AddCookie(&http.Cookie{ Name: sessionCookieName, Value: encodedSessionID, @@ -565,7 +568,7 @@ func TestInjectUserID(t *testing.T) { hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything). Return(&hub.CheckSessionOutput{UserID: "", Valid: false}, nil) - encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID")) + encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID) r.AddCookie(&http.Cookie{ Name: sessionCookieName, Value: encodedSessionID, @@ -586,7 +589,7 @@ func TestInjectUserID(t *testing.T) { hw := newHandlersWrapper() hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything). Return(&hub.CheckSessionOutput{UserID: "userID", Valid: true}, nil) - encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID")) + encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID) r.AddCookie(&http.Cookie{ Name: sessionCookieName, Value: encodedSessionID, @@ -601,6 +604,8 @@ func TestInjectUserID(t *testing.T) { } func TestLogin(t *testing.T) { + sessionID := "sessionID" + t.Run("invalid", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() @@ -693,7 +698,7 @@ func TestLogin(t *testing.T) { Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil) hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}). Return(&hub.Session{ - SessionID: []byte("sessionID"), + SessionID: sessionID, Approved: true, }, nil) hw.h.Login(w, r) @@ -708,10 +713,10 @@ func TestLogin(t *testing.T) { assert.Equal(t, "/", cookie.Path) assert.True(t, cookie.HttpOnly) assert.False(t, cookie.Secure) - var sessionID []byte - err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID) + var cookieSessionID string + err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID) require.NoError(t, err) - assert.Equal(t, []byte("sessionID"), sessionID) + assert.Equal(t, sessionID, cookieSessionID) assert.Equal(t, "true", h.Get(SessionApprovedHeader)) hw.um.AssertExpectations(t) }) @@ -727,7 +732,7 @@ func TestLogin(t *testing.T) { Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil) hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}). Return(&hub.Session{ - SessionID: []byte("sessionID"), + SessionID: sessionID, Approved: false, }, nil) hw.h.Login(w, r) @@ -742,10 +747,10 @@ func TestLogin(t *testing.T) { assert.Equal(t, "/", cookie.Path) assert.True(t, cookie.HttpOnly) assert.False(t, cookie.Secure) - var sessionID []byte - err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID) + var cookieSessionID string + err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID) require.NoError(t, err) - assert.Equal(t, []byte("sessionID"), sessionID) + assert.Equal(t, sessionID, cookieSessionID) assert.Equal(t, "false", h.Get(SessionApprovedHeader)) hw.um.AssertExpectations(t) }) @@ -815,8 +820,8 @@ func TestLogout(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) hw := newHandlersWrapper() - hw.um.On("DeleteSession", r.Context(), []byte("sessionID")).Return(tc.err) - encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID")) + hw.um.On("DeleteSession", r.Context(), "sessionID").Return(tc.err) + encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, "sessionID") r.AddCookie(&http.Cookie{ Name: sessionCookieName, Value: encodedSessionID, @@ -1092,7 +1097,7 @@ func TestRegisterUser(t *testing.T) { } func TestRequireLogin(t *testing.T) { - sessionID := []byte("sessionID") + sessionID := "sessionID" t.Run("api key based authentication", func(t *testing.T) { apiKeyID := "keyID" @@ -1143,7 +1148,7 @@ func TestRequireLogin(t *testing.T) { r.Header.Add(APIKeySecretHeader, apiKeySecret) hw := newHandlersWrapper() - hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB) + hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB) hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r) resp := w.Result() defer resp.Body.Close() @@ -1164,7 +1169,7 @@ func TestRequireLogin(t *testing.T) { r.Header.Add(APIKeySecretHeader, apiKeySecret) hw := newHandlersWrapper() - hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret). + hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret). Return(&hub.CheckAPIKeyOutput{UserID: "", Valid: false}, nil) hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r) resp := w.Result() @@ -1186,7 +1191,7 @@ func TestRequireLogin(t *testing.T) { r.Header.Add(APIKeySecretHeader, apiKeySecret) hw := newHandlersWrapper() - hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret). + hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret). Return(&hub.CheckAPIKeyOutput{UserID: "userID", Valid: true}, nil) hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r) resp := w.Result() @@ -1689,6 +1694,7 @@ func testsOK(w http.ResponseWriter, r *http.Request) {} type handlersWrapper struct { cfg *viper.Viper um *user.ManagerMock + am *apikey.ManagerMock h *Handlers } @@ -1697,11 +1703,13 @@ func newHandlersWrapper() *handlersWrapper { cfg.Set("server.baseURL", "baseURL") cfg.Set("server.oauth.github", map[string]string{}) um := &user.ManagerMock{} - h, _ := NewHandlers(context.Background(), um, cfg) + am := &apikey.ManagerMock{} + h, _ := NewHandlers(context.Background(), um, am, cfg) return &handlersWrapper{ cfg: cfg, um: um, + am: am, h: h, } } diff --git a/internal/hub/apikey.go b/internal/hub/apikey.go index 592e5a4a..75b59703 100644 --- a/internal/hub/apikey.go +++ b/internal/hub/apikey.go @@ -6,6 +6,7 @@ import "context" type APIKey struct { APIKeyID string `json:"api_key_id"` Name string `json:"name"` + Secret string `json:"secret"` CreatedAt int64 `json:"created_at"` UserID string `json:"user_id"` } @@ -13,9 +14,16 @@ type APIKey struct { // APIKeyManager describes the methods an APIKeyManager implementation must // provide. type APIKeyManager interface { - Add(ctx context.Context, ak *APIKey) ([]byte, error) + Add(ctx context.Context, ak *APIKey) (*APIKey, error) + Check(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error) Delete(ctx context.Context, apiKeyID string) error GetJSON(ctx context.Context, apiKeyID string) ([]byte, error) GetOwnedByUserJSON(ctx context.Context) ([]byte, error) Update(ctx context.Context, ak *APIKey) error } + +// CheckAPIKeyOutput represents the output returned by the CheckApiKey method. +type CheckAPIKeyOutput struct { + Valid bool `json:"valid"` + UserID string `json:"user_id"` +} diff --git a/internal/hub/user.go b/internal/hub/user.go index 8fe030bc..17e1e500 100644 --- a/internal/hub/user.go +++ b/internal/hub/user.go @@ -5,12 +5,6 @@ import ( "time" ) -// CheckAPIKeyOutput represents the output returned by the CheckApiKey method. -type CheckAPIKeyOutput struct { - Valid bool `json:"valid"` - UserID string `json:"user_id"` -} - // CheckCredentialsOutput represents the output returned by the // CheckCredentials method. type CheckCredentialsOutput struct { @@ -26,7 +20,7 @@ type CheckSessionOutput struct { // Session represents some information about a user session. type Session struct { - SessionID []byte `json:"session_id"` + SessionID string `json:"session_id"` UserID string `json:"user_id"` IP string `json:"ip"` UserAgent string `json:"user_agent"` @@ -68,12 +62,11 @@ var UserIDKey = userIDKey{} // UserManager describes the methods a UserManager implementation must provide. type UserManager interface { - ApproveSession(ctx context.Context, sessionID []byte, passcode string) error - CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error) + ApproveSession(ctx context.Context, sessionID, passcode string) error CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) CheckCredentials(ctx context.Context, email, password string) (*CheckCredentialsOutput, error) - CheckSession(ctx context.Context, sessionID []byte, duration time.Duration) (*CheckSessionOutput, error) - DeleteSession(ctx context.Context, sessionID []byte) error + CheckSession(ctx context.Context, sessionID string, duration time.Duration) (*CheckSessionOutput, error) + DeleteSession(ctx context.Context, sessionID string) error DisableTFA(ctx context.Context, passcode string) error EnableTFA(ctx context.Context, passcode string) error GetProfile(ctx context.Context) (*User, error) diff --git a/internal/user/manager.go b/internal/user/manager.go index 2d775303..a8d8811e 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -3,6 +3,7 @@ package user import ( "bytes" "context" + "crypto/rand" "crypto/sha512" "encoding/base64" "encoding/json" @@ -10,7 +11,6 @@ import ( "fmt" "image/png" "net/url" - "strings" "time" "github.com/artifacthub/hub/internal/email" @@ -26,13 +26,12 @@ import ( const ( // Database queries - approveSessionDBQ = `select approve_session($1::bytea, $2::text)` + approveSessionDBQ = `select approve_session($1::text, $2::text)` checkUserAliasAvailDBQ = `select check_user_alias_availability($1::text)` checkUserCredsDBQ = `select user_id, password from "user" where email = $1 and password is not null and email_verified = true` deleteSessionDBQ = `delete from session where session_id = $1` disableTFADBQ = `update "user" set tfa_enabled = false, tfa_url = null, tfa_recovery_codes = null where user_id = $1 and tfa_enabled = true` enableTFADBQ = `update "user" set tfa_enabled = true where user_id = $1` - getAPIKeyInfoDBQ = `select user_id, secret from api_key where api_key_id = $1` getSessionDBQ = `select user_id, floor(extract(epoch from created_at)), approved from session where session_id = $1` getTFAConfigDBQ = `select get_user_tfa_config($1::uuid)` getUserEmailDBQ = `select email from "user" where user_id = $1` @@ -40,15 +39,15 @@ const ( getUserIDFromSessionIDDBQ = `select user_id from session where session_id = $1` getUserPasswordDBQ = `select password from "user" where user_id = $1 and password is not null` getUserProfileDBQ = `select get_user_profile($1::uuid)` - registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text)` - registerSessionDBQ = `select session_id, approved from register_session($1::jsonb)` + registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text, $2::text)` + registerSessionDBQ = `select register_session($1::jsonb)` registerUserDBQ = `select register_user($1::jsonb)` - resetUserPasswordDBQ = `select reset_user_password($1::bytea, $2::text)` + resetUserPasswordDBQ = `select reset_user_password($1::text, $2::text)` updateTFAInfoDBQ = `update "user" set tfa_url = $2, tfa_recovery_codes = $3 where user_id = $1` updateUserPasswordDBQ = `select update_user_password($1::uuid, $2::text, $3::text)` updateUserProfileDBQ = `select update_user_profile($1::uuid, $2::jsonb)` verifyEmailDBQ = `select verify_email($1::uuid)` - verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::bytea)` + verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::text)` numRecoveryCodes = 10 ) @@ -94,7 +93,7 @@ func NewManager(db hub.DB, es hub.EmailSender) *Manager { } // ApproveSession approves a given session using the TFA passcode provided. -func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error { +func (m *Manager) ApproveSession(ctx context.Context, sessionID, passcode string) error { // Validate input if len(sessionID) == 0 { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "sessionID not provided") @@ -105,7 +104,7 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode // Get id of the user the session belongs to var userID string - err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Scan(&userID) + err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Scan(&userID) if err != nil { return err } @@ -131,48 +130,10 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode if validRecoveryCodeProvided { recoveryCode = passcode } - _, err = m.db.Exec(ctx, approveSessionDBQ, hashSessionID(sessionID), recoveryCode) + _, err = m.db.Exec(ctx, approveSessionDBQ, hash(sessionID), recoveryCode) return err } -// CheckAPIKey checks if the api key provided is valid. -func (m *Manager) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) { - // Validate input - if apiKeyID == "" || apiKeySecret == "" { - return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided") - } - - // Get key's user id and secret from database - var userID, apiKeySecretHashed string - err := m.db.QueryRow(ctx, getAPIKeyInfoDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return &hub.CheckAPIKeyOutput{Valid: false}, nil - } - return nil, err - } - - // Check if the secret provided is valid - switch { - case strings.HasPrefix(apiKeySecretHashed, "$2a$"): - // Bcrypt hash, will be deprecated soon - err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret)) - if err != nil { - return &hub.CheckAPIKeyOutput{Valid: false}, nil - } - default: - // SHA512 hash - if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed { - return &hub.CheckAPIKeyOutput{Valid: false}, nil - } - } - - return &hub.CheckAPIKeyOutput{ - Valid: true, - UserID: userID, - }, nil -} - // CheckAvailability checks the availability of a given value for the provided // resource kind. func (m *Manager) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) { @@ -247,7 +208,7 @@ func (m *Manager) CheckCredentials( // CheckSession checks if the user session provided is valid. func (m *Manager) CheckSession( ctx context.Context, - sessionID []byte, + sessionID string, duration time.Duration, ) (*hub.CheckSessionOutput, error) { // Validate input @@ -262,7 +223,7 @@ func (m *Manager) CheckSession( var userID string var createdAt int64 var approved bool - err := m.db.QueryRow(ctx, getSessionDBQ, hashSessionID(sessionID)).Scan(&userID, &createdAt, &approved) + err := m.db.QueryRow(ctx, getSessionDBQ, hash(sessionID)).Scan(&userID, &createdAt, &approved) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return &hub.CheckSessionOutput{Valid: false}, nil @@ -288,14 +249,14 @@ func (m *Manager) CheckSession( } // DeleteSession deletes a user session from the database. -func (m *Manager) DeleteSession(ctx context.Context, sessionID []byte) error { +func (m *Manager) DeleteSession(ctx context.Context, sessionID string) error { // Validate input if len(sessionID) == 0 { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "session id not provided") } // Delete session from database - _, err := m.db.Exec(ctx, deleteSessionDBQ, hashSessionID(sessionID)) + _, err := m.db.Exec(ctx, deleteSessionDBQ, hash(sessionID)) return err } @@ -461,17 +422,20 @@ func (m *Manager) RegisterPasswordResetCode(ctx context.Context, userEmail, base } // Register password reset code in database - var code []byte - err := m.db.QueryRow(ctx, registerPasswordResetCodeDBQ, userEmail).Scan(&code) + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return err + } + code := base64.URLEncoding.EncodeToString(randomBytes) + _, err := m.db.Exec(ctx, registerPasswordResetCodeDBQ, userEmail, hash(code)) if err != nil { return err } // Send password reset email - if code != nil && m.es != nil { - codeB64 := base64.URLEncoding.EncodeToString(code) + if m.es != nil { templateData := map[string]string{ - "link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, codeB64), + "link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, code), } var emailBody bytes.Buffer if err := passwordResetTmpl.Execute(&emailBody, templateData); err != nil { @@ -500,14 +464,23 @@ func (m *Manager) RegisterSession(ctx context.Context, session *hub.Session) (*h return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "invalid user id") } + // Generate session id + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return nil, err + } + sessionID := base64.StdEncoding.EncodeToString(randomBytes) + sessionIDHashed := hash(sessionID) + // Register session in database + session.SessionID = sessionIDHashed sessionJSON, _ := json.Marshal(session) - var sessionID []byte var approved bool - err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&sessionID, &approved) + err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&approved) if err != nil { return nil, err } + return &hub.Session{ SessionID: sessionID, UserID: session.UserID, @@ -585,9 +558,9 @@ func (m *Manager) RegisterUser(ctx context.Context, user *hub.User, baseURL stri } // ResetPassword resets the user password in the database. -func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseURL string) error { +func (m *Manager) ResetPassword(ctx context.Context, code, newPassword, baseURL string) error { // Validate input - if codeB64 == "" { + if code == "" { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided") } if newPassword == "" { @@ -610,12 +583,8 @@ func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseU } // Reset user password in database - code, err := base64.URLEncoding.DecodeString(codeB64) - if err != nil { - return ErrInvalidPasswordResetCode - } var userEmail string - err = m.db.QueryRow(ctx, resetUserPasswordDBQ, code, string(newHashed)).Scan(&userEmail) + err = m.db.QueryRow(ctx, resetUserPasswordDBQ, hash(code), string(newHashed)).Scan(&userEmail) if err != nil { if err.Error() == errInvalidPasswordResetCodeDB.Error() { return ErrInvalidPasswordResetCode @@ -772,28 +741,23 @@ func (m *Manager) VerifyEmail(ctx context.Context, code string) (bool, error) { } // VerifyPasswordResetCode verifies if the provided code is valid. -func (m *Manager) VerifyPasswordResetCode(ctx context.Context, codeB64 string) error { +func (m *Manager) VerifyPasswordResetCode(ctx context.Context, code string) error { // Validate input - if codeB64 == "" { + if code == "" { return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided") } // Verify password reset code in database - code, err := base64.URLEncoding.DecodeString(codeB64) - if err != nil { - return ErrInvalidPasswordResetCode - } - _, err = m.db.Exec(ctx, verifyPasswordResetCodeDBQ, code) + _, err := m.db.Exec(ctx, verifyPasswordResetCodeDBQ, hash(code)) if err != nil && err.Error() == errInvalidPasswordResetCodeDB.Error() { return ErrInvalidPasswordResetCode } return err } -// hashSessionID is a helper function that creates a sha512 hash of the -// sessionID provided. -func hashSessionID(sessionID []byte) string { - return fmt.Sprintf("\\x%x", sha512.Sum512(sessionID)) +// hash is a helper function that creates a sha512 hash of the text provided. +func hash(text string) string { + return fmt.Sprintf("%x", sha512.Sum512([]byte(text))) } // isValidRecoveryCode checks if the code provided is a valid recovery code. diff --git a/internal/user/manager_test.go b/internal/user/manager_test.go index 869e2463..2aa617c4 100644 --- a/internal/user/manager_test.go +++ b/internal/user/manager_test.go @@ -2,8 +2,6 @@ package user import ( "context" - "crypto/sha512" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -24,8 +22,8 @@ import ( func TestApproveSession(t *testing.T) { ctx := context.Background() - sessionID := []byte("sessionID") - hashedSessionID := hashSessionID([]byte("sessionID")) + sessionID := "sessionID" + hashedSessionID := hash(sessionID) opts := totp.GenerateOpts{ Issuer: "Artifact Hub", AccountName: "test@email.com", @@ -42,22 +40,17 @@ func TestApproveSession(t *testing.T) { t.Run("invalid input", func(t *testing.T) { testCases := []struct { errMsg string - sessionID []byte + sessionID string passcode string }{ { "sessionID not provided", - nil, - "123456", - }, - { - "sessionID not provided", - []byte(""), + "", "123456", }, { "passcode not provided", - []byte("sessionID"), + "sessionID", "", }, } @@ -87,7 +80,7 @@ func TestApproveSession(t *testing.T) { t.Run("error getting tfa config from database", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil) + db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil) db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB) m := NewManager(db, nil) @@ -99,7 +92,7 @@ func TestApproveSession(t *testing.T) { t.Run("invalid passcode provided", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil) + db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil) db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil) m := NewManager(db, nil) @@ -111,7 +104,7 @@ func TestApproveSession(t *testing.T) { t.Run("session approved successfully", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil) + db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil) db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil) db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, "").Return(nil) m := NewManager(db, nil) @@ -125,7 +118,7 @@ func TestApproveSession(t *testing.T) { t.Run("session approved successfully (using valid recovery code)", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil) + db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil) db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil) db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, code1).Return(nil) m := NewManager(db, nil) @@ -136,92 +129,6 @@ func TestApproveSession(t *testing.T) { }) } -func TestCheckAPIKey(t *testing.T) { - ctx := context.Background() - - t.Run("invalid input", func(t *testing.T) { - testCases := []struct { - errMsg string - apiKeyID string - apiKeySecret string - }{ - { - "api key id or secret not provided", - "", - "secret", - }, - { - "api key id or secret not provided", - "key", - "", - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.errMsg, func(t *testing.T) { - t.Parallel() - m := NewManager(nil, nil) - _, err := m.CheckAPIKey(ctx, tc.apiKeyID, tc.apiKeySecret) - assert.True(t, errors.Is(err, hub.ErrInvalidInput)) - assert.Contains(t, err.Error(), tc.errMsg) - }) - } - }) - - t.Run("key info not found in database", func(t *testing.T) { - t.Parallel() - db := &tests.DBMock{} - db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, pgx.ErrNoRows) - m := NewManager(db, nil) - - output, err := m.CheckAPIKey(ctx, "keyID", "secret") - assert.NoError(t, err) - assert.False(t, output.Valid) - assert.Empty(t, output.UserID) - db.AssertExpectations(t) - }) - - t.Run("error getting key info from database", func(t *testing.T) { - t.Parallel() - db := &tests.DBMock{} - db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, tests.ErrFakeDB) - m := NewManager(db, nil) - - output, err := m.CheckAPIKey(ctx, "keyID", "secret") - assert.Equal(t, tests.ErrFakeDB, err) - assert.Nil(t, output) - db.AssertExpectations(t) - }) - - t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) { - t.Parallel() - db := &tests.DBMock{} - secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) - db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil) - m := NewManager(db, nil) - - output, err := m.CheckAPIKey(ctx, "keyID", "secret") - assert.NoError(t, err) - assert.True(t, output.Valid) - assert.Equal(t, "userID", output.UserID) - db.AssertExpectations(t) - }) - - t.Run("valid key (secret hashed with sha512)", func(t *testing.T) { - t.Parallel() - db := &tests.DBMock{} - secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret"))) - db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil) - m := NewManager(db, nil) - - output, err := m.CheckAPIKey(ctx, "keyID", "secret") - assert.NoError(t, err) - assert.True(t, output.Valid) - assert.Equal(t, "userID", output.UserID) - db.AssertExpectations(t) - }) -} - func TestCheckAvailability(t *testing.T) { ctx := context.Background() @@ -385,27 +292,23 @@ func TestCheckCredentials(t *testing.T) { func TestCheckSession(t *testing.T) { ctx := context.Background() - hashedSessionID := hashSessionID([]byte("sessionID")) + sessionID := "sessionID" + hashedSessionID := hash(sessionID) t.Run("invalid input", func(t *testing.T) { testCases := []struct { errMsg string - sessionID []byte + sessionID string duration time.Duration }{ { "session id not provided", - nil, - 10, - }, - { - "session id not provided", - []byte(""), + "", 10, }, { "duration not provided", - []byte("sessionID"), + "sessionID", 0, }, } @@ -427,7 +330,7 @@ func TestCheckSession(t *testing.T) { db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, pgx.ErrNoRows) m := NewManager(db, nil) - output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) + output, err := m.CheckSession(ctx, sessionID, 1*time.Hour) assert.NoError(t, err) assert.False(t, output.Valid) assert.Empty(t, output.UserID) @@ -440,7 +343,7 @@ func TestCheckSession(t *testing.T) { db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, tests.ErrFakeDB) m := NewManager(db, nil) - output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) + output, err := m.CheckSession(ctx, sessionID, 1*time.Hour) assert.Equal(t, tests.ErrFakeDB, err) assert.Nil(t, output) db.AssertExpectations(t) @@ -456,7 +359,7 @@ func TestCheckSession(t *testing.T) { }, nil) m := NewManager(db, nil) - output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) + output, err := m.CheckSession(ctx, sessionID, 1*time.Hour) assert.NoError(t, err) assert.False(t, output.Valid) assert.Empty(t, output.UserID) @@ -473,7 +376,7 @@ func TestCheckSession(t *testing.T) { }, nil) m := NewManager(db, nil) - output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) + output, err := m.CheckSession(ctx, sessionID, 1*time.Hour) assert.NoError(t, err) assert.False(t, output.Valid) assert.Empty(t, output.UserID) @@ -490,7 +393,7 @@ func TestCheckSession(t *testing.T) { }, nil) m := NewManager(db, nil) - output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) + output, err := m.CheckSession(ctx, sessionID, 1*time.Hour) assert.NoError(t, err) assert.True(t, output.Valid) assert.Equal(t, "userID", output.UserID) @@ -500,22 +403,18 @@ func TestCheckSession(t *testing.T) { func TestDeleteSession(t *testing.T) { ctx := context.Background() - hashedSessionID := hashSessionID([]byte("sessionID")) + sessionID := "sessionID" + hashedSessionID := hash(sessionID) t.Run("invalid input", func(t *testing.T) { testCases := []struct { errMsg string - sessionID []byte + sessionID string duration time.Duration }{ { "session id not provided", - nil, - 10, - }, - { - "session id not provided", - []byte(""), + "", 10, }, } @@ -553,7 +452,7 @@ func TestDeleteSession(t *testing.T) { db.On("Exec", ctx, deleteSessionDBQ, hashedSessionID).Return(tc.dbResponse) m := NewManager(db, nil) - err := m.DeleteSession(ctx, []byte("sessionID")) + err := m.DeleteSession(ctx, sessionID) assert.Equal(t, tc.dbResponse, err) db.AssertExpectations(t) }) @@ -879,12 +778,7 @@ func TestGetUserID(t *testing.T) { func TestRegisterSession(t *testing.T) { ctx := context.Background() - - s := &hub.Session{ - UserID: "00000000-0000-0000-0000-000000000001", - IP: "192.168.1.100", - UserAgent: "Safari 13.0.5", - } + userID := "00000000-0000-0000-0000-000000000001" t.Run("invalid input", func(t *testing.T) { testCases := []struct { @@ -919,25 +813,29 @@ func TestRegisterSession(t *testing.T) { db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(nil, tests.ErrFakeDB) m := NewManager(db, nil) - sessionID, err := m.RegisterSession(ctx, s) + sIN := &hub.Session{UserID: userID} + sOUT, err := m.RegisterSession(ctx, sIN) assert.Equal(t, tests.ErrFakeDB, err) - assert.Nil(t, sessionID) + assert.Nil(t, sOUT) db.AssertExpectations(t) }) t.Run("successful session registration", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return([]interface{}{ - []byte("sessionID"), - true, - }, nil) + db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(true, nil) m := NewManager(db, nil) - session, err := m.RegisterSession(ctx, s) + sIN := &hub.Session{ + UserID: userID, + IP: "192.168.1.100", + UserAgent: "Safari 13.0.5", + } + sOUT, err := m.RegisterSession(ctx, sIN) assert.NoError(t, err) - assert.Equal(t, []byte("sessionID"), session.SessionID) - assert.True(t, session.Approved) + assert.NotEmpty(t, sOUT.SessionID) + assert.Equal(t, userID, sOUT.UserID) + assert.True(t, sOUT.Approved) db.AssertExpectations(t) }) } @@ -995,7 +893,7 @@ func TestRegisterPasswordResetCode(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return([]byte("code"), nil) + db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(nil) es := &email.SenderMock{} es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse) m := NewManager(db, es) @@ -1011,7 +909,7 @@ func TestRegisterPasswordResetCode(t *testing.T) { t.Run("database error registering password reset code", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(nil, tests.ErrFakeDB) + db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(tests.ErrFakeDB) m := NewManager(db, nil) err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com") @@ -1131,15 +1029,15 @@ func TestRegisterUser(t *testing.T) { func TestResetPassword(t *testing.T) { ctx := context.Background() - code := []byte("code") - codeB64 := base64.URLEncoding.EncodeToString(code) + code := "code" + codeHashed := hash(code) newPassword := "a66bV.Xp2" // #nosec baseURL := "http://baseurl.com" t.Run("invalid input", func(t *testing.T) { testCases := []struct { errMsg string - codeB64 string + code string newPassword string baseURL string }{ @@ -1151,19 +1049,19 @@ func TestResetPassword(t *testing.T) { }, { "new password not provided", - "code", + code, "", baseURL, }, { "invalid base url", - "code", + code, newPassword, "invalid", }, { "insecure password", - "code", + code, "password", baseURL, }, @@ -1174,7 +1072,7 @@ func TestResetPassword(t *testing.T) { t.Parallel() es := &email.SenderMock{} m := NewManager(nil, es) - err := m.ResetPassword(ctx, tc.codeB64, tc.newPassword, tc.baseURL) + err := m.ResetPassword(ctx, tc.code, tc.newPassword, tc.baseURL) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) assert.Contains(t, err.Error(), tc.errMsg) }) @@ -1200,10 +1098,10 @@ func TestResetPassword(t *testing.T) { t.Run(tc.dbErr.Error(), func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("", tc.dbErr) + db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("", tc.dbErr) m := NewManager(db, nil) - err := m.ResetPassword(ctx, codeB64, newPassword, baseURL) + err := m.ResetPassword(ctx, code, newPassword, baseURL) assert.Equal(t, tc.expectedErr, err) db.AssertExpectations(t) }) @@ -1229,12 +1127,12 @@ func TestResetPassword(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("email", nil) + db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("email", nil) es := &email.SenderMock{} es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse) m := NewManager(db, es) - err := m.ResetPassword(ctx, codeB64, newPassword, baseURL) + err := m.ResetPassword(ctx, code, newPassword, baseURL) assert.Equal(t, tc.emailSenderResponse, err) db.AssertExpectations(t) es.AssertExpectations(t) @@ -1494,13 +1392,13 @@ func TestVerifyEmail(t *testing.T) { func TestVerifyPasswordResetCode(t *testing.T) { ctx := context.Background() - code := []byte("code") - codeB64 := base64.URLEncoding.EncodeToString(code) + code := "code" + codeHashed := hash(code) t.Run("invalid input", func(t *testing.T) { testCases := []struct { - errMsg string - codeB64 string + errMsg string + code string }{ { "code not provided", @@ -1512,7 +1410,7 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run(tc.errMsg, func(t *testing.T) { t.Parallel() m := NewManager(nil, nil) - err := m.VerifyPasswordResetCode(ctx, tc.codeB64) + err := m.VerifyPasswordResetCode(ctx, tc.code) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) assert.Contains(t, err.Error(), tc.errMsg) }) @@ -1538,10 +1436,10 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run(tc.dbErr.Error(), func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(tc.dbErr) + db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(tc.dbErr) m := NewManager(db, nil) - err := m.VerifyPasswordResetCode(ctx, codeB64) + err := m.VerifyPasswordResetCode(ctx, code) assert.Equal(t, tc.expectedErr, err) db.AssertExpectations(t) }) @@ -1551,10 +1449,10 @@ func TestVerifyPasswordResetCode(t *testing.T) { t.Run("password code verified successfully in database", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(nil) + db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(nil) m := NewManager(db, nil) - err := m.VerifyPasswordResetCode(ctx, codeB64) + err := m.VerifyPasswordResetCode(ctx, code) assert.Equal(t, nil, err) db.AssertExpectations(t) }) diff --git a/internal/user/mock.go b/internal/user/mock.go index b9fc9ff5..c293923c 100644 --- a/internal/user/mock.go +++ b/internal/user/mock.go @@ -14,18 +14,11 @@ type ManagerMock struct { } // ApproveSession implements the UserManager interface. -func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error { +func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID, passcode string) error { args := m.Called(ctx, sessionID, passcode) return args.Error(0) } -// CheckAPIKey implements the UserManager interface. -func (m *ManagerMock) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) { - args := m.Called(ctx, apiKeyID, apiKeySecret) - data, _ := args.Get(0).(*hub.CheckAPIKeyOutput) - return data, args.Error(1) -} - // CheckAvailability implements the UserManager interface. func (m *ManagerMock) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) { args := m.Called(ctx, resourceKind, value) @@ -46,7 +39,7 @@ func (m *ManagerMock) CheckCredentials( // CheckSession implements the UserManager interface. func (m *ManagerMock) CheckSession( ctx context.Context, - sessionID []byte, + sessionID string, duration time.Duration, ) (*hub.CheckSessionOutput, error) { args := m.Called(ctx, sessionID, duration) @@ -55,7 +48,7 @@ func (m *ManagerMock) CheckSession( } // DeleteSession implements the UserManager interface. -func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID []byte) error { +func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID string) error { args := m.Called(ctx, sessionID) return args.Error(0) }