mirror of https://github.com/artifacthub/hub.git
Some refactoring in user manager and handlers (#1291)
Signed-off-by: Sergio Castaño Arteaga <tegioz@icloud.com>
This commit is contained in:
parent
9146bf6d60
commit
9726b926ed
|
|
@ -1,9 +1,8 @@
|
|||
-- add_api_key adds the provided api key to the database.
|
||||
create or replace function add_api_key(p_api_key jsonb)
|
||||
returns setof json as $$
|
||||
returns uuid as $$
|
||||
declare
|
||||
v_api_key_id uuid;
|
||||
v_api_key_secret text := encode(gen_random_bytes(32), 'base64');
|
||||
begin
|
||||
insert into api_key (
|
||||
name,
|
||||
|
|
@ -11,13 +10,10 @@ begin
|
|||
user_id
|
||||
) values (
|
||||
p_api_key->>'name',
|
||||
encode(sha512(v_api_key_secret::bytea), 'hex'),
|
||||
p_api_key->>'secret',
|
||||
(p_api_key->>'user_id')::uuid
|
||||
) returning api_key_id into v_api_key_id;
|
||||
|
||||
return query select json_build_object(
|
||||
'api_key_id', v_api_key_id,
|
||||
'secret', v_api_key_secret
|
||||
);
|
||||
return v_api_key_id;
|
||||
end
|
||||
$$ language plpgsql;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
-- approvee_session approves the provided session in the database.
|
||||
create or replace function approve_session(p_session_id bytea, p_recovery_code text)
|
||||
create or replace function approve_session(p_session_id text, p_recovery_code text)
|
||||
returns void as $$
|
||||
begin
|
||||
-- Mark session as approved
|
||||
|
|
|
|||
|
|
@ -1,18 +1,15 @@
|
|||
-- register_password_reset_code registers a password reset code for the user
|
||||
-- identified by the email provided.
|
||||
create or replace function register_password_reset_code(p_email text)
|
||||
returns bytea as $$
|
||||
declare
|
||||
v_code bytea := gen_random_bytes(32);
|
||||
create or replace function register_password_reset_code(p_email text, p_code text)
|
||||
returns void as $$
|
||||
begin
|
||||
insert into password_reset_code (password_reset_code_id, user_id)
|
||||
select sha512(v_code), user_id from "user" where email = p_email and email_verified = true
|
||||
select p_code, user_id from "user" where email = p_email and email_verified = true
|
||||
on conflict (user_id) do update set
|
||||
password_reset_code_id = sha512(v_code),
|
||||
password_reset_code_id = p_code,
|
||||
created_at = current_timestamp;
|
||||
if not found then
|
||||
raise 'invalid email';
|
||||
end if;
|
||||
return v_code;
|
||||
end
|
||||
$$ language plpgsql;
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
-- register_session registers the provided session in the database.
|
||||
create or replace function register_session(p_session jsonb)
|
||||
returns table (session_id bytea, approved boolean) as $$
|
||||
returns boolean as $$
|
||||
declare
|
||||
v_session_id bytea := gen_random_bytes(32);
|
||||
v_approved boolean;
|
||||
begin
|
||||
-- Check if the session requires approval or not. When the user has enabled
|
||||
|
|
@ -22,13 +21,13 @@ begin
|
|||
user_agent,
|
||||
approved
|
||||
) values (
|
||||
sha512(v_session_id),
|
||||
p_session->>'session_id',
|
||||
(p_session->>'user_id')::uuid,
|
||||
nullif(p_session->>'ip', '')::inet,
|
||||
nullif(p_session->>'user_agent', ''),
|
||||
v_approved
|
||||
);
|
||||
|
||||
return query select v_session_id, v_approved;
|
||||
return v_approved;
|
||||
end
|
||||
$$ language plpgsql;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
-- reset_user_password resets the password of the user associated to the code
|
||||
-- provided if it is still valid, returning the email of the user.
|
||||
create or replace function reset_user_password(p_code bytea, p_new_password text)
|
||||
create or replace function reset_user_password(p_code text, p_new_password text)
|
||||
returns text as $$
|
||||
declare
|
||||
v_user_id uuid;
|
||||
|
|
@ -13,7 +13,7 @@ begin
|
|||
select u.user_id, u.email into v_user_id, v_email
|
||||
from "user" u
|
||||
join password_reset_code prc using (user_id)
|
||||
where password_reset_code_id = sha512(p_code);
|
||||
where password_reset_code_id = p_code;
|
||||
|
||||
-- Update user password
|
||||
update "user" set password = p_new_password where user_id = v_user_id;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
-- verify_password_reset_code verifies is the password reset code provided is
|
||||
-- valid. The code must exist and not have expired.
|
||||
create or replace function verify_password_reset_code(p_code bytea)
|
||||
create or replace function verify_password_reset_code(p_code text)
|
||||
returns void as $$
|
||||
begin
|
||||
perform from password_reset_code
|
||||
where password_reset_code_id = sha512(p_code)
|
||||
where password_reset_code_id = p_code
|
||||
and created_at + '15 minute'::interval > current_timestamp;
|
||||
if not found then
|
||||
raise 'invalid password reset code';
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
delete from password_reset_code;
|
||||
alter table password_reset_code alter column password_reset_code_id drop default;
|
||||
alter table password_reset_code alter column password_reset_code_id type bytea USING password_reset_code_id::text::bytea;
|
||||
alter table password_reset_code alter column password_reset_code_id type bytea using password_reset_code_id::text::bytea;
|
||||
drop function if exists register_password_reset_code(text);
|
||||
|
||||
---- create above / drop below ----
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
delete from password_reset_code;
|
||||
alter table password_reset_code alter column password_reset_code_id type text;
|
||||
alter table session alter column session_id type text using substring(session_id::bytea::text from 3);
|
||||
drop function if exists add_api_key(jsonb);
|
||||
drop function if exists approve_session(bytea, text);
|
||||
drop function if exists register_password_reset_code(text);
|
||||
drop function if exists register_session(jsonb);
|
||||
drop function if exists reset_user_password(bytea, text);
|
||||
drop function if exists verify_password_reset_code(bytea);
|
||||
|
||||
---- create above / drop below ----
|
||||
|
||||
delete from password_reset_code;
|
||||
alter table password_reset_code alter column password_reset_code_id type bytea;
|
||||
alter table session alter column session_id type bytea using session_id::text::bytea;
|
||||
|
|
@ -13,6 +13,7 @@ values (:'user1ID', 'user1', 'user1@email.com');
|
|||
select add_api_key('
|
||||
{
|
||||
"name": "apikey1",
|
||||
"secret": "hashed-secret",
|
||||
"user_id": "00000000-0000-0000-0000-000000000001"
|
||||
}
|
||||
'::jsonb);
|
||||
|
|
@ -22,12 +23,14 @@ select results_eq(
|
|||
$$
|
||||
select
|
||||
name,
|
||||
secret,
|
||||
user_id
|
||||
from api_key
|
||||
$$,
|
||||
$$
|
||||
values (
|
||||
'apikey1',
|
||||
'hashed-secret',
|
||||
'00000000-0000-0000-0000-000000000001'::uuid
|
||||
)
|
||||
$$,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
-- Start transaction and plan tests
|
||||
begin;
|
||||
select plan(5);
|
||||
select plan(4);
|
||||
|
||||
-- Seed user
|
||||
insert into "user" (user_id, alias, email, email_verified)
|
||||
|
|
@ -9,10 +9,10 @@ insert into "user" (user_id, alias, email, email_verified)
|
|||
values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', false);
|
||||
|
||||
-- Register password reset code
|
||||
select register_password_reset_code('user1@email.com') as code1 \gset
|
||||
select register_password_reset_code('user1@email.com', 'code1');
|
||||
select is(
|
||||
password_reset_code_id,
|
||||
sha512(:'code1'),
|
||||
'code1',
|
||||
'Password reset code for user1 should be registered'
|
||||
)
|
||||
from password_reset_code
|
||||
|
|
@ -20,24 +20,19 @@ join "user" using (user_id)
|
|||
where alias = 'user1';
|
||||
|
||||
-- Register another password reset code for the same user
|
||||
select register_password_reset_code('user1@email.com') as code2 \gset
|
||||
select register_password_reset_code('user1@email.com', 'code2');
|
||||
select is(
|
||||
password_reset_code_id,
|
||||
sha512(:'code2'),
|
||||
'code2',
|
||||
'Password reset code for user1 should have been updated'
|
||||
)
|
||||
from password_reset_code
|
||||
join "user" using (user_id)
|
||||
where alias = 'user1';
|
||||
select isnt(
|
||||
:'code1'::bytea,
|
||||
:'code2'::bytea,
|
||||
'Password reset code must have changed'
|
||||
);
|
||||
|
||||
-- Try registering password reset code using non verified email
|
||||
select throws_ok(
|
||||
$$ select register_password_reset_code('user2@email.com') $$,
|
||||
$$ select register_password_reset_code('user2@email.com', 'code') $$,
|
||||
'P0001',
|
||||
'invalid email',
|
||||
'No password reset code should be registered for non verified email user2@email.com'
|
||||
|
|
@ -45,7 +40,7 @@ select throws_ok(
|
|||
|
||||
-- Try registering password reset code using unregistered email
|
||||
select throws_ok(
|
||||
$$ select register_password_reset_code('user3@email.com') $$,
|
||||
$$ select register_password_reset_code('user3@email.com', 'code') $$,
|
||||
'P0001',
|
||||
'invalid email',
|
||||
'No password reset code should be registered for unregistered email user3@email.com'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
-- Start transaction and plan tests
|
||||
begin;
|
||||
select plan(6);
|
||||
select plan(2);
|
||||
|
||||
-- Seed user
|
||||
insert into "user" (user_id, alias, email)
|
||||
|
|
@ -9,18 +9,20 @@ insert into "user" (user_id, alias, email, tfa_enabled)
|
|||
values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', true);
|
||||
|
||||
-- Register session for user with tfa disabled
|
||||
select session_id, approved from register_session('
|
||||
select register_session('
|
||||
{
|
||||
"session_id": "hashed-session-id-user1",
|
||||
"user_id": "00000000-0000-0000-0000-000000000001",
|
||||
"ip": "192.168.1.100",
|
||||
"user_agent": "Safari 13.0.5"
|
||||
}
|
||||
') \gset
|
||||
') as approved \gset
|
||||
|
||||
-- Check if session registration succeeded
|
||||
select results_eq(
|
||||
$$
|
||||
select
|
||||
session_id,
|
||||
user_id,
|
||||
ip,
|
||||
user_agent,
|
||||
|
|
@ -30,6 +32,7 @@ select results_eq(
|
|||
$$,
|
||||
$$
|
||||
values (
|
||||
'hashed-session-id-user1',
|
||||
'00000000-0000-0000-0000-000000000001'::uuid,
|
||||
'192.168.1.100'::inet,
|
||||
'Safari 13.0.5',
|
||||
|
|
@ -38,53 +41,32 @@ select results_eq(
|
|||
$$,
|
||||
'Session for user1 should exist'
|
||||
);
|
||||
select is(
|
||||
session_id,
|
||||
sha512(:'session_id'),
|
||||
'Returned session id for user1 should match value stored'
|
||||
)
|
||||
from session where user_id = '00000000-0000-0000-0000-000000000001';
|
||||
select is(
|
||||
true,
|
||||
:'approved',
|
||||
'Returned approved value for user1 should be true'
|
||||
)
|
||||
from session where user_id = '00000000-0000-0000-0000-000000000001';
|
||||
|
||||
-- Register session for user with tfa enabled
|
||||
select session_id, approved from register_session('
|
||||
select register_session('
|
||||
{
|
||||
"session_id": "hashed-session-id-user2",
|
||||
"user_id": "00000000-0000-0000-0000-000000000002"
|
||||
}
|
||||
') \gset
|
||||
') as approved \gset
|
||||
|
||||
-- Check if session registration succeeded
|
||||
select results_eq(
|
||||
$$
|
||||
select user_id, approved
|
||||
select
|
||||
session_id,
|
||||
approved
|
||||
from session
|
||||
where user_id = '00000000-0000-0000-0000-000000000002'
|
||||
$$,
|
||||
$$
|
||||
values (
|
||||
'00000000-0000-0000-0000-000000000002'::uuid,
|
||||
'hashed-session-id-user2',
|
||||
false
|
||||
)
|
||||
$$,
|
||||
'Session for user2 should exist'
|
||||
);
|
||||
select is(
|
||||
session_id,
|
||||
sha512(:'session_id'),
|
||||
'Returned session id for user2 should match value stored'
|
||||
)
|
||||
from session where user_id = '00000000-0000-0000-0000-000000000002';
|
||||
select is(
|
||||
false,
|
||||
:'approved',
|
||||
'Returned approved value for user2 should be false'
|
||||
)
|
||||
from session where user_id = '00000000-0000-0000-0000-000000000002';
|
||||
|
||||
-- Finish tests and rollback transaction
|
||||
select * from finish();
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ values (:'user1ID', 'user1', 'user1@email.com');
|
|||
insert into "user" (user_id, alias, email)
|
||||
values (:'user2ID', 'user2', 'user2@email.com');
|
||||
insert into password_reset_code (password_reset_code_id, user_id, created_at)
|
||||
values (sha512(:'code1ID'), :'user1ID', current_timestamp);
|
||||
values (:'code1ID', :'user1ID', current_timestamp);
|
||||
insert into password_reset_code (password_reset_code_id, user_id, created_at)
|
||||
values (sha512(:'code2ID'), :'user2ID', current_timestamp - '30 minute'::interval);
|
||||
values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval);
|
||||
insert into session (session_id, user_id) values (gen_random_bytes(32), :'user1ID');
|
||||
|
||||
-- Password reset should fail in the following cases
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ values (:'user1ID', 'user1', 'user1@email.com');
|
|||
insert into "user" (user_id, alias, email)
|
||||
values (:'user2ID', 'user2', 'user2@email.com');
|
||||
insert into password_reset_code (password_reset_code_id, user_id, created_at)
|
||||
values (sha512(:'code1ID'), :'user1ID', current_timestamp - '5 minute'::interval);
|
||||
values (:'code1ID', :'user1ID', current_timestamp - '5 minute'::interval);
|
||||
insert into password_reset_code (password_reset_code_id, user_id, created_at)
|
||||
values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,21 +2,29 @@ package apikey
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/artifacthub/hub/internal/hub"
|
||||
"github.com/artifacthub/hub/internal/util"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/satori/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// Database queries
|
||||
addAPIKeyDBQ = `select add_api_key($1::jsonb)`
|
||||
deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)`
|
||||
getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)`
|
||||
getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)`
|
||||
updateAPIKeyDBQ = `select update_api_key($1::jsonb)`
|
||||
addAPIKeyDBQ = `select add_api_key($1::jsonb)`
|
||||
deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)`
|
||||
getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)`
|
||||
getAPIKeyUserIDDBQ = `select user_id, secret from api_key where api_key_id = $1`
|
||||
getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)`
|
||||
updateAPIKeyDBQ = `select update_api_key($1::jsonb)`
|
||||
)
|
||||
|
||||
// Manager provides an API to manage api keys.
|
||||
|
|
@ -32,7 +40,7 @@ func NewManager(db hub.DB) *Manager {
|
|||
}
|
||||
|
||||
// Add adds the provided api key to the database.
|
||||
func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
|
||||
func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) {
|
||||
ak.UserID = ctx.Value(hub.UserIDKey).(string)
|
||||
|
||||
// Validate input
|
||||
|
|
@ -40,9 +48,64 @@ func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
|
|||
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "name not provided")
|
||||
}
|
||||
|
||||
// Generate API key secret
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apiKeySecret := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
apiKeySecretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret)))
|
||||
|
||||
// Add api key to the database
|
||||
var apiKeyID string
|
||||
ak.Secret = apiKeySecretHashed
|
||||
akJSON, _ := json.Marshal(ak)
|
||||
return util.DBQueryJSON(ctx, m.db, addAPIKeyDBQ, akJSON)
|
||||
if err := m.db.QueryRow(ctx, addAPIKeyDBQ, akJSON).Scan(&apiKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &hub.APIKey{
|
||||
APIKeyID: apiKeyID,
|
||||
Secret: apiKeySecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check checks if the api key provided is valid.
|
||||
func (m *Manager) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
|
||||
// Validate input
|
||||
if apiKeyID == "" || apiKeySecret == "" {
|
||||
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided")
|
||||
}
|
||||
|
||||
// Get key's user id and secret from database
|
||||
var userID, apiKeySecretHashed string
|
||||
err := m.db.QueryRow(ctx, getAPIKeyUserIDDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the secret provided is valid
|
||||
switch {
|
||||
case strings.HasPrefix(apiKeySecretHashed, "$2a$"):
|
||||
// Bcrypt hash, will be deprecated soon
|
||||
err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret))
|
||||
if err != nil {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
default:
|
||||
// SHA512 hash
|
||||
if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &hub.CheckAPIKeyOutput{
|
||||
Valid: true,
|
||||
UserID: userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Delete deletes the provided api key from the database.
|
||||
|
|
|
|||
|
|
@ -2,13 +2,18 @@ package apikey
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/artifacthub/hub/internal/hub"
|
||||
"github.com/artifacthub/hub/internal/tests"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const apiKeyID = "00000000-0000-0000-0000-000000000001"
|
||||
|
|
@ -60,14 +65,13 @@ func TestAdd(t *testing.T) {
|
|||
Name: "apikey1",
|
||||
UserID: "userID",
|
||||
}
|
||||
akJSON, _ := json.Marshal(ak)
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return(nil, tests.ErrFakeDB)
|
||||
db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db)
|
||||
|
||||
keyInfoJSON, err := m.Add(ctx, ak)
|
||||
output, err := m.Add(ctx, ak)
|
||||
assert.Equal(t, tests.ErrFakeDB, err)
|
||||
assert.Nil(t, keyInfoJSON)
|
||||
assert.Nil(t, output)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
|
|
@ -77,14 +81,100 @@ func TestAdd(t *testing.T) {
|
|||
Name: "apikey1",
|
||||
UserID: "userID",
|
||||
}
|
||||
akJSON, _ := json.Marshal(ak)
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return([]byte("keyInfoJSON"), nil)
|
||||
db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return("apiKeyID", nil)
|
||||
m := NewManager(db)
|
||||
|
||||
keyInfoJSON, err := m.Add(ctx, ak)
|
||||
output, err := m.Add(ctx, ak)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("keyInfoJSON"), keyInfoJSON)
|
||||
assert.Equal(t, "apiKeyID", output.APIKeyID)
|
||||
assert.NotEmpty(t, output.Secret)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheck(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
apiKeyID string
|
||||
apiKeySecret string
|
||||
}{
|
||||
{
|
||||
"api key id or secret not provided",
|
||||
"",
|
||||
"secret",
|
||||
},
|
||||
{
|
||||
"api key id or secret not provided",
|
||||
"key",
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.errMsg, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewManager(nil)
|
||||
_, err := m.Check(ctx, tc.apiKeyID, tc.apiKeySecret)
|
||||
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
||||
assert.Contains(t, err.Error(), tc.errMsg)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("key info not found in database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, pgx.ErrNoRows)
|
||||
m := NewManager(db)
|
||||
|
||||
output, err := m.Check(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, output.Valid)
|
||||
assert.Empty(t, output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("error getting key info from database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db)
|
||||
|
||||
output, err := m.Check(ctx, "keyID", "secret")
|
||||
assert.Equal(t, tests.ErrFakeDB, err)
|
||||
assert.Nil(t, output)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost)
|
||||
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil)
|
||||
m := NewManager(db)
|
||||
|
||||
output, err := m.Check(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, output.Valid)
|
||||
assert.Equal(t, "userID", output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("valid key (secret hashed with sha512)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret")))
|
||||
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil)
|
||||
m := NewManager(db)
|
||||
|
||||
output, err := m.Check(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, output.Valid)
|
||||
assert.Equal(t, "userID", output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,9 +13,16 @@ type ManagerMock struct {
|
|||
}
|
||||
|
||||
// Add implements the APIKeyManager interface.
|
||||
func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
|
||||
func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) {
|
||||
args := m.Called(ctx, ak)
|
||||
data, _ := args.Get(0).([]byte)
|
||||
data, _ := args.Get(0).(*hub.APIKey)
|
||||
return data, args.Error(1)
|
||||
}
|
||||
|
||||
// Check implements the UserManager interface.
|
||||
func (m *ManagerMock) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
|
||||
args := m.Called(ctx, apiKeyID, apiKeySecret)
|
||||
data, _ := args.Get(0).(*hub.CheckAPIKeyOutput)
|
||||
return data, args.Error(1)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,19 +28,20 @@ func NewHandlers(apiKeyManager hub.APIKeyManager) *Handlers {
|
|||
|
||||
// Add is an http handler that adds the provided api key to the database.
|
||||
func (h *Handlers) Add(w http.ResponseWriter, r *http.Request) {
|
||||
ak := &hub.APIKey{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&ak); err != nil {
|
||||
akIN := &hub.APIKey{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&akIN); err != nil {
|
||||
h.logger.Error().Err(err).Str("method", "Add").Msg(hub.ErrInvalidInput.Error())
|
||||
helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
|
||||
return
|
||||
}
|
||||
dataJSON, err := h.apiKeyManager.Add(r.Context(), ak)
|
||||
akOUT, err := h.apiKeyManager.Add(r.Context(), akIN)
|
||||
if err != nil {
|
||||
h.logger.Error().Err(err).Str("method", "Add").Send()
|
||||
helpers.RenderErrorJSON(w, err)
|
||||
return
|
||||
}
|
||||
helpers.RenderJSON(w, dataJSON, 0, http.StatusCreated)
|
||||
akOUTJSON, _ := json.Marshal(akOUT)
|
||||
helpers.RenderJSON(w, akOUTJSON, 0, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Delete is an http handler that deletes the provided api key from the database.
|
||||
|
|
|
|||
|
|
@ -100,7 +100,11 @@ func TestAdd(t *testing.T) {
|
|||
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
hw.am.On("Add", r.Context(), ak).Return([]byte("keyInfoJSON"), nil)
|
||||
akOUT := &hub.APIKey{
|
||||
APIKeyID: "apiKeyID",
|
||||
Secret: "secret",
|
||||
}
|
||||
hw.am.On("Add", r.Context(), ak).Return(akOUT, nil)
|
||||
hw.h.Add(w, r)
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
|
@ -110,7 +114,8 @@ func TestAdd(t *testing.T) {
|
|||
assert.Equal(t, http.StatusCreated, resp.StatusCode)
|
||||
assert.Equal(t, "application/json", h.Get("Content-Type"))
|
||||
assert.Equal(t, helpers.BuildCacheControlHeader(0), h.Get("Cache-Control"))
|
||||
assert.Equal(t, []byte("keyInfoJSON"), data)
|
||||
outputAKJSON, _ := json.Marshal(akOUT)
|
||||
assert.Equal(t, outputAKJSON, data)
|
||||
hw.am.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ type Handlers struct {
|
|||
|
||||
// Setup creates a new Handlers instance.
|
||||
func Setup(ctx context.Context, cfg *viper.Viper, svc *Services) (*Handlers, error) {
|
||||
userHandlers, err := user.NewHandlers(ctx, svc.UserManager, cfg)
|
||||
userHandlers, err := user.NewHandlers(ctx, svc.UserManager, svc.APIKeyManager, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,16 +66,22 @@ var (
|
|||
// Handlers represents a group of http handlers in charge of handling
|
||||
// users operations.
|
||||
type Handlers struct {
|
||||
userManager hub.UserManager
|
||||
cfg *viper.Viper
|
||||
sc *securecookie.SecureCookie
|
||||
oauthConfig map[string]*oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
logger zerolog.Logger
|
||||
userManager hub.UserManager
|
||||
apiKeyManager hub.APIKeyManager
|
||||
cfg *viper.Viper
|
||||
sc *securecookie.SecureCookie
|
||||
oauthConfig map[string]*oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewHandlers creates a new Handlers instance.
|
||||
func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Viper) (*Handlers, error) {
|
||||
func NewHandlers(
|
||||
ctx context.Context,
|
||||
userManager hub.UserManager,
|
||||
apiKeyManager hub.APIKeyManager,
|
||||
cfg *viper.Viper,
|
||||
) (*Handlers, error) {
|
||||
// Setup secure cookie instance
|
||||
sc := securecookie.New([]byte(cfg.GetString("server.cookie.hashKey")), nil)
|
||||
sc.MaxAge(int(sessionDuration.Seconds()))
|
||||
|
|
@ -112,12 +118,13 @@ func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Vi
|
|||
}
|
||||
|
||||
return &Handlers{
|
||||
userManager: userManager,
|
||||
cfg: cfg,
|
||||
sc: sc,
|
||||
oauthConfig: oauthConfig,
|
||||
oidcProvider: oidcProvider,
|
||||
logger: log.With().Str("handlers", "user").Logger(),
|
||||
userManager: userManager,
|
||||
apiKeyManager: apiKeyManager,
|
||||
cfg: cfg,
|
||||
sc: sc,
|
||||
oauthConfig: oauthConfig,
|
||||
oidcProvider: oidcProvider,
|
||||
logger: log.With().Str("handlers", "user").Logger(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -137,7 +144,7 @@ func (h *Handlers) ApproveSession(w http.ResponseWriter, r *http.Request) {
|
|||
passcode := input["passcode"]
|
||||
|
||||
// Extract sessionID from cookie
|
||||
var sessionID []byte
|
||||
var sessionID string
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
h.logger.Error().Err(err).Str("method", "ApproveSession").Msg("session cookie not found")
|
||||
|
|
@ -290,7 +297,7 @@ func (h *Handlers) InjectUserID(next http.Handler) http.Handler {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
var sessionID []byte
|
||||
var sessionID string
|
||||
if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -371,7 +378,7 @@ func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
|
|||
// Delete user session
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
var sessionID []byte
|
||||
var sessionID string
|
||||
err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
|
||||
if err == nil {
|
||||
err = h.userManager.DeleteSession(r.Context(), sessionID)
|
||||
|
|
@ -749,7 +756,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler {
|
|||
// Use API key based authentication if API key is provided
|
||||
if apiKeyID != "" && apiKeySecret != "" {
|
||||
// Check the API key provided is valid
|
||||
checkAPIKeyOutput, err := h.userManager.CheckAPIKey(r.Context(), apiKeyID, apiKeySecret)
|
||||
checkAPIKeyOutput, err := h.apiKeyManager.Check(r.Context(), apiKeyID, apiKeySecret)
|
||||
if err != nil {
|
||||
h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("checkAPIKey failed")
|
||||
helpers.RenderErrorWithCodeJSON(w, nil, http.StatusInternalServerError)
|
||||
|
|
@ -766,7 +773,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler {
|
|||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
// Extract and validate cookie from request
|
||||
var sessionID []byte
|
||||
var sessionID string
|
||||
if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil {
|
||||
h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("sessionID decoding failed")
|
||||
helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/artifacthub/hub/internal/apikey"
|
||||
"github.com/artifacthub/hub/internal/handlers/helpers"
|
||||
"github.com/artifacthub/hub/internal/hub"
|
||||
"github.com/artifacthub/hub/internal/tests"
|
||||
|
|
@ -31,7 +32,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func TestApproveSession(t *testing.T) {
|
||||
sessionID := []byte("sessionID")
|
||||
sessionID := "sessionID"
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
|
|
@ -495,6 +496,8 @@ func TestGetProfile(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInjectUserID(t *testing.T) {
|
||||
sessionID := "sessionID"
|
||||
|
||||
checkUserID := func(expectedUserID interface{}) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if expectedUserID == nil {
|
||||
|
|
@ -541,7 +544,7 @@ func TestInjectUserID(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: encodedSessionID,
|
||||
|
|
@ -565,7 +568,7 @@ func TestInjectUserID(t *testing.T) {
|
|||
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
|
||||
Return(&hub.CheckSessionOutput{UserID: "", Valid: false}, nil)
|
||||
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: encodedSessionID,
|
||||
|
|
@ -586,7 +589,7 @@ func TestInjectUserID(t *testing.T) {
|
|||
hw := newHandlersWrapper()
|
||||
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
|
||||
Return(&hub.CheckSessionOutput{UserID: "userID", Valid: true}, nil)
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: encodedSessionID,
|
||||
|
|
@ -601,6 +604,8 @@ func TestInjectUserID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
sessionID := "sessionID"
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
|
@ -693,7 +698,7 @@ func TestLogin(t *testing.T) {
|
|||
Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
|
||||
hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}).
|
||||
Return(&hub.Session{
|
||||
SessionID: []byte("sessionID"),
|
||||
SessionID: sessionID,
|
||||
Approved: true,
|
||||
}, nil)
|
||||
hw.h.Login(w, r)
|
||||
|
|
@ -708,10 +713,10 @@ func TestLogin(t *testing.T) {
|
|||
assert.Equal(t, "/", cookie.Path)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.False(t, cookie.Secure)
|
||||
var sessionID []byte
|
||||
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
|
||||
var cookieSessionID string
|
||||
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("sessionID"), sessionID)
|
||||
assert.Equal(t, sessionID, cookieSessionID)
|
||||
assert.Equal(t, "true", h.Get(SessionApprovedHeader))
|
||||
hw.um.AssertExpectations(t)
|
||||
})
|
||||
|
|
@ -727,7 +732,7 @@ func TestLogin(t *testing.T) {
|
|||
Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
|
||||
hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}).
|
||||
Return(&hub.Session{
|
||||
SessionID: []byte("sessionID"),
|
||||
SessionID: sessionID,
|
||||
Approved: false,
|
||||
}, nil)
|
||||
hw.h.Login(w, r)
|
||||
|
|
@ -742,10 +747,10 @@ func TestLogin(t *testing.T) {
|
|||
assert.Equal(t, "/", cookie.Path)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.False(t, cookie.Secure)
|
||||
var sessionID []byte
|
||||
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
|
||||
var cookieSessionID string
|
||||
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("sessionID"), sessionID)
|
||||
assert.Equal(t, sessionID, cookieSessionID)
|
||||
assert.Equal(t, "false", h.Get(SessionApprovedHeader))
|
||||
hw.um.AssertExpectations(t)
|
||||
})
|
||||
|
|
@ -815,8 +820,8 @@ func TestLogout(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
hw.um.On("DeleteSession", r.Context(), []byte("sessionID")).Return(tc.err)
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
|
||||
hw.um.On("DeleteSession", r.Context(), "sessionID").Return(tc.err)
|
||||
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, "sessionID")
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: encodedSessionID,
|
||||
|
|
@ -1092,7 +1097,7 @@ func TestRegisterUser(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRequireLogin(t *testing.T) {
|
||||
sessionID := []byte("sessionID")
|
||||
sessionID := "sessionID"
|
||||
|
||||
t.Run("api key based authentication", func(t *testing.T) {
|
||||
apiKeyID := "keyID"
|
||||
|
|
@ -1143,7 +1148,7 @@ func TestRequireLogin(t *testing.T) {
|
|||
r.Header.Add(APIKeySecretHeader, apiKeySecret)
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB)
|
||||
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB)
|
||||
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
|
@ -1164,7 +1169,7 @@ func TestRequireLogin(t *testing.T) {
|
|||
r.Header.Add(APIKeySecretHeader, apiKeySecret)
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).
|
||||
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
|
||||
Return(&hub.CheckAPIKeyOutput{UserID: "", Valid: false}, nil)
|
||||
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
|
||||
resp := w.Result()
|
||||
|
|
@ -1186,7 +1191,7 @@ func TestRequireLogin(t *testing.T) {
|
|||
r.Header.Add(APIKeySecretHeader, apiKeySecret)
|
||||
|
||||
hw := newHandlersWrapper()
|
||||
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).
|
||||
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
|
||||
Return(&hub.CheckAPIKeyOutput{UserID: "userID", Valid: true}, nil)
|
||||
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
|
||||
resp := w.Result()
|
||||
|
|
@ -1689,6 +1694,7 @@ func testsOK(w http.ResponseWriter, r *http.Request) {}
|
|||
type handlersWrapper struct {
|
||||
cfg *viper.Viper
|
||||
um *user.ManagerMock
|
||||
am *apikey.ManagerMock
|
||||
h *Handlers
|
||||
}
|
||||
|
||||
|
|
@ -1697,11 +1703,13 @@ func newHandlersWrapper() *handlersWrapper {
|
|||
cfg.Set("server.baseURL", "baseURL")
|
||||
cfg.Set("server.oauth.github", map[string]string{})
|
||||
um := &user.ManagerMock{}
|
||||
h, _ := NewHandlers(context.Background(), um, cfg)
|
||||
am := &apikey.ManagerMock{}
|
||||
h, _ := NewHandlers(context.Background(), um, am, cfg)
|
||||
|
||||
return &handlersWrapper{
|
||||
cfg: cfg,
|
||||
um: um,
|
||||
am: am,
|
||||
h: h,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import "context"
|
|||
type APIKey struct {
|
||||
APIKeyID string `json:"api_key_id"`
|
||||
Name string `json:"name"`
|
||||
Secret string `json:"secret"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
|
@ -13,9 +14,16 @@ type APIKey struct {
|
|||
// APIKeyManager describes the methods an APIKeyManager implementation must
|
||||
// provide.
|
||||
type APIKeyManager interface {
|
||||
Add(ctx context.Context, ak *APIKey) ([]byte, error)
|
||||
Add(ctx context.Context, ak *APIKey) (*APIKey, error)
|
||||
Check(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error)
|
||||
Delete(ctx context.Context, apiKeyID string) error
|
||||
GetJSON(ctx context.Context, apiKeyID string) ([]byte, error)
|
||||
GetOwnedByUserJSON(ctx context.Context) ([]byte, error)
|
||||
Update(ctx context.Context, ak *APIKey) error
|
||||
}
|
||||
|
||||
// CheckAPIKeyOutput represents the output returned by the CheckApiKey method.
|
||||
type CheckAPIKeyOutput struct {
|
||||
Valid bool `json:"valid"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,12 +5,6 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// CheckAPIKeyOutput represents the output returned by the CheckApiKey method.
|
||||
type CheckAPIKeyOutput struct {
|
||||
Valid bool `json:"valid"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// CheckCredentialsOutput represents the output returned by the
|
||||
// CheckCredentials method.
|
||||
type CheckCredentialsOutput struct {
|
||||
|
|
@ -26,7 +20,7 @@ type CheckSessionOutput struct {
|
|||
|
||||
// Session represents some information about a user session.
|
||||
type Session struct {
|
||||
SessionID []byte `json:"session_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
IP string `json:"ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
|
|
@ -68,12 +62,11 @@ var UserIDKey = userIDKey{}
|
|||
|
||||
// UserManager describes the methods a UserManager implementation must provide.
|
||||
type UserManager interface {
|
||||
ApproveSession(ctx context.Context, sessionID []byte, passcode string) error
|
||||
CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error)
|
||||
ApproveSession(ctx context.Context, sessionID, passcode string) error
|
||||
CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error)
|
||||
CheckCredentials(ctx context.Context, email, password string) (*CheckCredentialsOutput, error)
|
||||
CheckSession(ctx context.Context, sessionID []byte, duration time.Duration) (*CheckSessionOutput, error)
|
||||
DeleteSession(ctx context.Context, sessionID []byte) error
|
||||
CheckSession(ctx context.Context, sessionID string, duration time.Duration) (*CheckSessionOutput, error)
|
||||
DeleteSession(ctx context.Context, sessionID string) error
|
||||
DisableTFA(ctx context.Context, passcode string) error
|
||||
EnableTFA(ctx context.Context, passcode string) error
|
||||
GetProfile(ctx context.Context) (*User, error)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package user
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
|
@ -10,7 +11,6 @@ import (
|
|||
"fmt"
|
||||
"image/png"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/artifacthub/hub/internal/email"
|
||||
|
|
@ -26,13 +26,12 @@ import (
|
|||
|
||||
const (
|
||||
// Database queries
|
||||
approveSessionDBQ = `select approve_session($1::bytea, $2::text)`
|
||||
approveSessionDBQ = `select approve_session($1::text, $2::text)`
|
||||
checkUserAliasAvailDBQ = `select check_user_alias_availability($1::text)`
|
||||
checkUserCredsDBQ = `select user_id, password from "user" where email = $1 and password is not null and email_verified = true`
|
||||
deleteSessionDBQ = `delete from session where session_id = $1`
|
||||
disableTFADBQ = `update "user" set tfa_enabled = false, tfa_url = null, tfa_recovery_codes = null where user_id = $1 and tfa_enabled = true`
|
||||
enableTFADBQ = `update "user" set tfa_enabled = true where user_id = $1`
|
||||
getAPIKeyInfoDBQ = `select user_id, secret from api_key where api_key_id = $1`
|
||||
getSessionDBQ = `select user_id, floor(extract(epoch from created_at)), approved from session where session_id = $1`
|
||||
getTFAConfigDBQ = `select get_user_tfa_config($1::uuid)`
|
||||
getUserEmailDBQ = `select email from "user" where user_id = $1`
|
||||
|
|
@ -40,15 +39,15 @@ const (
|
|||
getUserIDFromSessionIDDBQ = `select user_id from session where session_id = $1`
|
||||
getUserPasswordDBQ = `select password from "user" where user_id = $1 and password is not null`
|
||||
getUserProfileDBQ = `select get_user_profile($1::uuid)`
|
||||
registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text)`
|
||||
registerSessionDBQ = `select session_id, approved from register_session($1::jsonb)`
|
||||
registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text, $2::text)`
|
||||
registerSessionDBQ = `select register_session($1::jsonb)`
|
||||
registerUserDBQ = `select register_user($1::jsonb)`
|
||||
resetUserPasswordDBQ = `select reset_user_password($1::bytea, $2::text)`
|
||||
resetUserPasswordDBQ = `select reset_user_password($1::text, $2::text)`
|
||||
updateTFAInfoDBQ = `update "user" set tfa_url = $2, tfa_recovery_codes = $3 where user_id = $1`
|
||||
updateUserPasswordDBQ = `select update_user_password($1::uuid, $2::text, $3::text)`
|
||||
updateUserProfileDBQ = `select update_user_profile($1::uuid, $2::jsonb)`
|
||||
verifyEmailDBQ = `select verify_email($1::uuid)`
|
||||
verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::bytea)`
|
||||
verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::text)`
|
||||
|
||||
numRecoveryCodes = 10
|
||||
)
|
||||
|
|
@ -94,7 +93,7 @@ func NewManager(db hub.DB, es hub.EmailSender) *Manager {
|
|||
}
|
||||
|
||||
// ApproveSession approves a given session using the TFA passcode provided.
|
||||
func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error {
|
||||
func (m *Manager) ApproveSession(ctx context.Context, sessionID, passcode string) error {
|
||||
// Validate input
|
||||
if len(sessionID) == 0 {
|
||||
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "sessionID not provided")
|
||||
|
|
@ -105,7 +104,7 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode
|
|||
|
||||
// Get id of the user the session belongs to
|
||||
var userID string
|
||||
err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Scan(&userID)
|
||||
err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Scan(&userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -131,48 +130,10 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode
|
|||
if validRecoveryCodeProvided {
|
||||
recoveryCode = passcode
|
||||
}
|
||||
_, err = m.db.Exec(ctx, approveSessionDBQ, hashSessionID(sessionID), recoveryCode)
|
||||
_, err = m.db.Exec(ctx, approveSessionDBQ, hash(sessionID), recoveryCode)
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckAPIKey checks if the api key provided is valid.
|
||||
func (m *Manager) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
|
||||
// Validate input
|
||||
if apiKeyID == "" || apiKeySecret == "" {
|
||||
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided")
|
||||
}
|
||||
|
||||
// Get key's user id and secret from database
|
||||
var userID, apiKeySecretHashed string
|
||||
err := m.db.QueryRow(ctx, getAPIKeyInfoDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the secret provided is valid
|
||||
switch {
|
||||
case strings.HasPrefix(apiKeySecretHashed, "$2a$"):
|
||||
// Bcrypt hash, will be deprecated soon
|
||||
err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret))
|
||||
if err != nil {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
default:
|
||||
// SHA512 hash
|
||||
if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed {
|
||||
return &hub.CheckAPIKeyOutput{Valid: false}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &hub.CheckAPIKeyOutput{
|
||||
Valid: true,
|
||||
UserID: userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckAvailability checks the availability of a given value for the provided
|
||||
// resource kind.
|
||||
func (m *Manager) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) {
|
||||
|
|
@ -247,7 +208,7 @@ func (m *Manager) CheckCredentials(
|
|||
// CheckSession checks if the user session provided is valid.
|
||||
func (m *Manager) CheckSession(
|
||||
ctx context.Context,
|
||||
sessionID []byte,
|
||||
sessionID string,
|
||||
duration time.Duration,
|
||||
) (*hub.CheckSessionOutput, error) {
|
||||
// Validate input
|
||||
|
|
@ -262,7 +223,7 @@ func (m *Manager) CheckSession(
|
|||
var userID string
|
||||
var createdAt int64
|
||||
var approved bool
|
||||
err := m.db.QueryRow(ctx, getSessionDBQ, hashSessionID(sessionID)).Scan(&userID, &createdAt, &approved)
|
||||
err := m.db.QueryRow(ctx, getSessionDBQ, hash(sessionID)).Scan(&userID, &createdAt, &approved)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return &hub.CheckSessionOutput{Valid: false}, nil
|
||||
|
|
@ -288,14 +249,14 @@ func (m *Manager) CheckSession(
|
|||
}
|
||||
|
||||
// DeleteSession deletes a user session from the database.
|
||||
func (m *Manager) DeleteSession(ctx context.Context, sessionID []byte) error {
|
||||
func (m *Manager) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
// Validate input
|
||||
if len(sessionID) == 0 {
|
||||
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "session id not provided")
|
||||
}
|
||||
|
||||
// Delete session from database
|
||||
_, err := m.db.Exec(ctx, deleteSessionDBQ, hashSessionID(sessionID))
|
||||
_, err := m.db.Exec(ctx, deleteSessionDBQ, hash(sessionID))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -461,17 +422,20 @@ func (m *Manager) RegisterPasswordResetCode(ctx context.Context, userEmail, base
|
|||
}
|
||||
|
||||
// Register password reset code in database
|
||||
var code []byte
|
||||
err := m.db.QueryRow(ctx, registerPasswordResetCodeDBQ, userEmail).Scan(&code)
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
code := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
_, err := m.db.Exec(ctx, registerPasswordResetCodeDBQ, userEmail, hash(code))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send password reset email
|
||||
if code != nil && m.es != nil {
|
||||
codeB64 := base64.URLEncoding.EncodeToString(code)
|
||||
if m.es != nil {
|
||||
templateData := map[string]string{
|
||||
"link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, codeB64),
|
||||
"link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, code),
|
||||
}
|
||||
var emailBody bytes.Buffer
|
||||
if err := passwordResetTmpl.Execute(&emailBody, templateData); err != nil {
|
||||
|
|
@ -500,14 +464,23 @@ func (m *Manager) RegisterSession(ctx context.Context, session *hub.Session) (*h
|
|||
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "invalid user id")
|
||||
}
|
||||
|
||||
// Generate session id
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
sessionIDHashed := hash(sessionID)
|
||||
|
||||
// Register session in database
|
||||
session.SessionID = sessionIDHashed
|
||||
sessionJSON, _ := json.Marshal(session)
|
||||
var sessionID []byte
|
||||
var approved bool
|
||||
err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&sessionID, &approved)
|
||||
err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&approved)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &hub.Session{
|
||||
SessionID: sessionID,
|
||||
UserID: session.UserID,
|
||||
|
|
@ -585,9 +558,9 @@ func (m *Manager) RegisterUser(ctx context.Context, user *hub.User, baseURL stri
|
|||
}
|
||||
|
||||
// ResetPassword resets the user password in the database.
|
||||
func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseURL string) error {
|
||||
func (m *Manager) ResetPassword(ctx context.Context, code, newPassword, baseURL string) error {
|
||||
// Validate input
|
||||
if codeB64 == "" {
|
||||
if code == "" {
|
||||
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided")
|
||||
}
|
||||
if newPassword == "" {
|
||||
|
|
@ -610,12 +583,8 @@ func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseU
|
|||
}
|
||||
|
||||
// Reset user password in database
|
||||
code, err := base64.URLEncoding.DecodeString(codeB64)
|
||||
if err != nil {
|
||||
return ErrInvalidPasswordResetCode
|
||||
}
|
||||
var userEmail string
|
||||
err = m.db.QueryRow(ctx, resetUserPasswordDBQ, code, string(newHashed)).Scan(&userEmail)
|
||||
err = m.db.QueryRow(ctx, resetUserPasswordDBQ, hash(code), string(newHashed)).Scan(&userEmail)
|
||||
if err != nil {
|
||||
if err.Error() == errInvalidPasswordResetCodeDB.Error() {
|
||||
return ErrInvalidPasswordResetCode
|
||||
|
|
@ -772,28 +741,23 @@ func (m *Manager) VerifyEmail(ctx context.Context, code string) (bool, error) {
|
|||
}
|
||||
|
||||
// VerifyPasswordResetCode verifies if the provided code is valid.
|
||||
func (m *Manager) VerifyPasswordResetCode(ctx context.Context, codeB64 string) error {
|
||||
func (m *Manager) VerifyPasswordResetCode(ctx context.Context, code string) error {
|
||||
// Validate input
|
||||
if codeB64 == "" {
|
||||
if code == "" {
|
||||
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided")
|
||||
}
|
||||
|
||||
// Verify password reset code in database
|
||||
code, err := base64.URLEncoding.DecodeString(codeB64)
|
||||
if err != nil {
|
||||
return ErrInvalidPasswordResetCode
|
||||
}
|
||||
_, err = m.db.Exec(ctx, verifyPasswordResetCodeDBQ, code)
|
||||
_, err := m.db.Exec(ctx, verifyPasswordResetCodeDBQ, hash(code))
|
||||
if err != nil && err.Error() == errInvalidPasswordResetCodeDB.Error() {
|
||||
return ErrInvalidPasswordResetCode
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// hashSessionID is a helper function that creates a sha512 hash of the
|
||||
// sessionID provided.
|
||||
func hashSessionID(sessionID []byte) string {
|
||||
return fmt.Sprintf("\\x%x", sha512.Sum512(sessionID))
|
||||
// hash is a helper function that creates a sha512 hash of the text provided.
|
||||
func hash(text string) string {
|
||||
return fmt.Sprintf("%x", sha512.Sum512([]byte(text)))
|
||||
}
|
||||
|
||||
// isValidRecoveryCode checks if the code provided is a valid recovery code.
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ package user
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -24,8 +22,8 @@ import (
|
|||
|
||||
func TestApproveSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sessionID := []byte("sessionID")
|
||||
hashedSessionID := hashSessionID([]byte("sessionID"))
|
||||
sessionID := "sessionID"
|
||||
hashedSessionID := hash(sessionID)
|
||||
opts := totp.GenerateOpts{
|
||||
Issuer: "Artifact Hub",
|
||||
AccountName: "test@email.com",
|
||||
|
|
@ -42,22 +40,17 @@ func TestApproveSession(t *testing.T) {
|
|||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
sessionID []byte
|
||||
sessionID string
|
||||
passcode string
|
||||
}{
|
||||
{
|
||||
"sessionID not provided",
|
||||
nil,
|
||||
"123456",
|
||||
},
|
||||
{
|
||||
"sessionID not provided",
|
||||
[]byte(""),
|
||||
"",
|
||||
"123456",
|
||||
},
|
||||
{
|
||||
"passcode not provided",
|
||||
[]byte("sessionID"),
|
||||
"sessionID",
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
|
@ -87,7 +80,7 @@ func TestApproveSession(t *testing.T) {
|
|||
t.Run("error getting tfa config from database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
|
|
@ -99,7 +92,7 @@ func TestApproveSession(t *testing.T) {
|
|||
t.Run("invalid passcode provided", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
|
|
@ -111,7 +104,7 @@ func TestApproveSession(t *testing.T) {
|
|||
t.Run("session approved successfully", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
||||
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, "").Return(nil)
|
||||
m := NewManager(db, nil)
|
||||
|
|
@ -125,7 +118,7 @@ func TestApproveSession(t *testing.T) {
|
|||
t.Run("session approved successfully (using valid recovery code)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
|
||||
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
|
||||
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, code1).Return(nil)
|
||||
m := NewManager(db, nil)
|
||||
|
|
@ -136,92 +129,6 @@ func TestApproveSession(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestCheckAPIKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
apiKeyID string
|
||||
apiKeySecret string
|
||||
}{
|
||||
{
|
||||
"api key id or secret not provided",
|
||||
"",
|
||||
"secret",
|
||||
},
|
||||
{
|
||||
"api key id or secret not provided",
|
||||
"key",
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.errMsg, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewManager(nil, nil)
|
||||
_, err := m.CheckAPIKey(ctx, tc.apiKeyID, tc.apiKeySecret)
|
||||
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
||||
assert.Contains(t, err.Error(), tc.errMsg)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("key info not found in database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, pgx.ErrNoRows)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, output.Valid)
|
||||
assert.Empty(t, output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("error getting key info from database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
||||
assert.Equal(t, tests.ErrFakeDB, err)
|
||||
assert.Nil(t, output)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost)
|
||||
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, output.Valid)
|
||||
assert.Equal(t, "userID", output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("valid key (secret hashed with sha512)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret")))
|
||||
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, output.Valid)
|
||||
assert.Equal(t, "userID", output.UserID)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckAvailability(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
@ -385,27 +292,23 @@ func TestCheckCredentials(t *testing.T) {
|
|||
|
||||
func TestCheckSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
hashedSessionID := hashSessionID([]byte("sessionID"))
|
||||
sessionID := "sessionID"
|
||||
hashedSessionID := hash(sessionID)
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
sessionID []byte
|
||||
sessionID string
|
||||
duration time.Duration
|
||||
}{
|
||||
{
|
||||
"session id not provided",
|
||||
nil,
|
||||
10,
|
||||
},
|
||||
{
|
||||
"session id not provided",
|
||||
[]byte(""),
|
||||
"",
|
||||
10,
|
||||
},
|
||||
{
|
||||
"duration not provided",
|
||||
[]byte("sessionID"),
|
||||
"sessionID",
|
||||
0,
|
||||
},
|
||||
}
|
||||
|
|
@ -427,7 +330,7 @@ func TestCheckSession(t *testing.T) {
|
|||
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, pgx.ErrNoRows)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
||||
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, output.Valid)
|
||||
assert.Empty(t, output.UserID)
|
||||
|
|
@ -440,7 +343,7 @@ func TestCheckSession(t *testing.T) {
|
|||
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
||||
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
|
||||
assert.Equal(t, tests.ErrFakeDB, err)
|
||||
assert.Nil(t, output)
|
||||
db.AssertExpectations(t)
|
||||
|
|
@ -456,7 +359,7 @@ func TestCheckSession(t *testing.T) {
|
|||
}, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
||||
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, output.Valid)
|
||||
assert.Empty(t, output.UserID)
|
||||
|
|
@ -473,7 +376,7 @@ func TestCheckSession(t *testing.T) {
|
|||
}, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
||||
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, output.Valid)
|
||||
assert.Empty(t, output.UserID)
|
||||
|
|
@ -490,7 +393,7 @@ func TestCheckSession(t *testing.T) {
|
|||
}, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
|
||||
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, output.Valid)
|
||||
assert.Equal(t, "userID", output.UserID)
|
||||
|
|
@ -500,22 +403,18 @@ func TestCheckSession(t *testing.T) {
|
|||
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
hashedSessionID := hashSessionID([]byte("sessionID"))
|
||||
sessionID := "sessionID"
|
||||
hashedSessionID := hash(sessionID)
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
sessionID []byte
|
||||
sessionID string
|
||||
duration time.Duration
|
||||
}{
|
||||
{
|
||||
"session id not provided",
|
||||
nil,
|
||||
10,
|
||||
},
|
||||
{
|
||||
"session id not provided",
|
||||
[]byte(""),
|
||||
"",
|
||||
10,
|
||||
},
|
||||
}
|
||||
|
|
@ -553,7 +452,7 @@ func TestDeleteSession(t *testing.T) {
|
|||
db.On("Exec", ctx, deleteSessionDBQ, hashedSessionID).Return(tc.dbResponse)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
err := m.DeleteSession(ctx, []byte("sessionID"))
|
||||
err := m.DeleteSession(ctx, sessionID)
|
||||
assert.Equal(t, tc.dbResponse, err)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
|
@ -879,12 +778,7 @@ func TestGetUserID(t *testing.T) {
|
|||
|
||||
func TestRegisterSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
s := &hub.Session{
|
||||
UserID: "00000000-0000-0000-0000-000000000001",
|
||||
IP: "192.168.1.100",
|
||||
UserAgent: "Safari 13.0.5",
|
||||
}
|
||||
userID := "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
|
|
@ -919,25 +813,29 @@ func TestRegisterSession(t *testing.T) {
|
|||
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(nil, tests.ErrFakeDB)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
sessionID, err := m.RegisterSession(ctx, s)
|
||||
sIN := &hub.Session{UserID: userID}
|
||||
sOUT, err := m.RegisterSession(ctx, sIN)
|
||||
assert.Equal(t, tests.ErrFakeDB, err)
|
||||
assert.Nil(t, sessionID)
|
||||
assert.Nil(t, sOUT)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("successful session registration", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return([]interface{}{
|
||||
[]byte("sessionID"),
|
||||
true,
|
||||
}, nil)
|
||||
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(true, nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
session, err := m.RegisterSession(ctx, s)
|
||||
sIN := &hub.Session{
|
||||
UserID: userID,
|
||||
IP: "192.168.1.100",
|
||||
UserAgent: "Safari 13.0.5",
|
||||
}
|
||||
sOUT, err := m.RegisterSession(ctx, sIN)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("sessionID"), session.SessionID)
|
||||
assert.True(t, session.Approved)
|
||||
assert.NotEmpty(t, sOUT.SessionID)
|
||||
assert.Equal(t, userID, sOUT.UserID)
|
||||
assert.True(t, sOUT.Approved)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
|
@ -995,7 +893,7 @@ func TestRegisterPasswordResetCode(t *testing.T) {
|
|||
t.Run(tc.description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return([]byte("code"), nil)
|
||||
db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(nil)
|
||||
es := &email.SenderMock{}
|
||||
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
|
||||
m := NewManager(db, es)
|
||||
|
|
@ -1011,7 +909,7 @@ func TestRegisterPasswordResetCode(t *testing.T) {
|
|||
t.Run("database error registering password reset code", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(nil, tests.ErrFakeDB)
|
||||
db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(tests.ErrFakeDB)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com")
|
||||
|
|
@ -1131,15 +1029,15 @@ func TestRegisterUser(t *testing.T) {
|
|||
|
||||
func TestResetPassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
code := []byte("code")
|
||||
codeB64 := base64.URLEncoding.EncodeToString(code)
|
||||
code := "code"
|
||||
codeHashed := hash(code)
|
||||
newPassword := "a66bV.Xp2" // #nosec
|
||||
baseURL := "http://baseurl.com"
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
codeB64 string
|
||||
code string
|
||||
newPassword string
|
||||
baseURL string
|
||||
}{
|
||||
|
|
@ -1151,19 +1049,19 @@ func TestResetPassword(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"new password not provided",
|
||||
"code",
|
||||
code,
|
||||
"",
|
||||
baseURL,
|
||||
},
|
||||
{
|
||||
"invalid base url",
|
||||
"code",
|
||||
code,
|
||||
newPassword,
|
||||
"invalid",
|
||||
},
|
||||
{
|
||||
"insecure password",
|
||||
"code",
|
||||
code,
|
||||
"password",
|
||||
baseURL,
|
||||
},
|
||||
|
|
@ -1174,7 +1072,7 @@ func TestResetPassword(t *testing.T) {
|
|||
t.Parallel()
|
||||
es := &email.SenderMock{}
|
||||
m := NewManager(nil, es)
|
||||
err := m.ResetPassword(ctx, tc.codeB64, tc.newPassword, tc.baseURL)
|
||||
err := m.ResetPassword(ctx, tc.code, tc.newPassword, tc.baseURL)
|
||||
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
||||
assert.Contains(t, err.Error(), tc.errMsg)
|
||||
})
|
||||
|
|
@ -1200,10 +1098,10 @@ func TestResetPassword(t *testing.T) {
|
|||
t.Run(tc.dbErr.Error(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("", tc.dbErr)
|
||||
db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("", tc.dbErr)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
|
||||
err := m.ResetPassword(ctx, code, newPassword, baseURL)
|
||||
assert.Equal(t, tc.expectedErr, err)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
|
@ -1229,12 +1127,12 @@ func TestResetPassword(t *testing.T) {
|
|||
t.Run(tc.description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("email", nil)
|
||||
db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("email", nil)
|
||||
es := &email.SenderMock{}
|
||||
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
|
||||
m := NewManager(db, es)
|
||||
|
||||
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
|
||||
err := m.ResetPassword(ctx, code, newPassword, baseURL)
|
||||
assert.Equal(t, tc.emailSenderResponse, err)
|
||||
db.AssertExpectations(t)
|
||||
es.AssertExpectations(t)
|
||||
|
|
@ -1494,13 +1392,13 @@ func TestVerifyEmail(t *testing.T) {
|
|||
|
||||
func TestVerifyPasswordResetCode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
code := []byte("code")
|
||||
codeB64 := base64.URLEncoding.EncodeToString(code)
|
||||
code := "code"
|
||||
codeHashed := hash(code)
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errMsg string
|
||||
codeB64 string
|
||||
errMsg string
|
||||
code string
|
||||
}{
|
||||
{
|
||||
"code not provided",
|
||||
|
|
@ -1512,7 +1410,7 @@ func TestVerifyPasswordResetCode(t *testing.T) {
|
|||
t.Run(tc.errMsg, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewManager(nil, nil)
|
||||
err := m.VerifyPasswordResetCode(ctx, tc.codeB64)
|
||||
err := m.VerifyPasswordResetCode(ctx, tc.code)
|
||||
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
|
||||
assert.Contains(t, err.Error(), tc.errMsg)
|
||||
})
|
||||
|
|
@ -1538,10 +1436,10 @@ func TestVerifyPasswordResetCode(t *testing.T) {
|
|||
t.Run(tc.dbErr.Error(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(tc.dbErr)
|
||||
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(tc.dbErr)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
err := m.VerifyPasswordResetCode(ctx, codeB64)
|
||||
err := m.VerifyPasswordResetCode(ctx, code)
|
||||
assert.Equal(t, tc.expectedErr, err)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
|
@ -1551,10 +1449,10 @@ func TestVerifyPasswordResetCode(t *testing.T) {
|
|||
t.Run("password code verified successfully in database", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := &tests.DBMock{}
|
||||
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(nil)
|
||||
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(nil)
|
||||
m := NewManager(db, nil)
|
||||
|
||||
err := m.VerifyPasswordResetCode(ctx, codeB64)
|
||||
err := m.VerifyPasswordResetCode(ctx, code)
|
||||
assert.Equal(t, nil, err)
|
||||
db.AssertExpectations(t)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -14,18 +14,11 @@ type ManagerMock struct {
|
|||
}
|
||||
|
||||
// ApproveSession implements the UserManager interface.
|
||||
func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error {
|
||||
func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID, passcode string) error {
|
||||
args := m.Called(ctx, sessionID, passcode)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// CheckAPIKey implements the UserManager interface.
|
||||
func (m *ManagerMock) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
|
||||
args := m.Called(ctx, apiKeyID, apiKeySecret)
|
||||
data, _ := args.Get(0).(*hub.CheckAPIKeyOutput)
|
||||
return data, args.Error(1)
|
||||
}
|
||||
|
||||
// CheckAvailability implements the UserManager interface.
|
||||
func (m *ManagerMock) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) {
|
||||
args := m.Called(ctx, resourceKind, value)
|
||||
|
|
@ -46,7 +39,7 @@ func (m *ManagerMock) CheckCredentials(
|
|||
// CheckSession implements the UserManager interface.
|
||||
func (m *ManagerMock) CheckSession(
|
||||
ctx context.Context,
|
||||
sessionID []byte,
|
||||
sessionID string,
|
||||
duration time.Duration,
|
||||
) (*hub.CheckSessionOutput, error) {
|
||||
args := m.Called(ctx, sessionID, duration)
|
||||
|
|
@ -55,7 +48,7 @@ func (m *ManagerMock) CheckSession(
|
|||
}
|
||||
|
||||
// DeleteSession implements the UserManager interface.
|
||||
func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID []byte) error {
|
||||
func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue