Some refactoring in user manager and handlers (#1291)

Signed-off-by: Sergio Castaño Arteaga <tegioz@icloud.com>
This commit is contained in:
Sergio C. Arteaga 2021-05-06 22:40:37 +02:00 committed by GitHub
parent 9146bf6d60
commit 9726b926ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 416 additions and 392 deletions

View File

@ -1,9 +1,8 @@
-- add_api_key adds the provided api key to the database.
create or replace function add_api_key(p_api_key jsonb)
returns setof json as $$
returns uuid as $$
declare
v_api_key_id uuid;
v_api_key_secret text := encode(gen_random_bytes(32), 'base64');
begin
insert into api_key (
name,
@ -11,13 +10,10 @@ begin
user_id
) values (
p_api_key->>'name',
encode(sha512(v_api_key_secret::bytea), 'hex'),
p_api_key->>'secret',
(p_api_key->>'user_id')::uuid
) returning api_key_id into v_api_key_id;
return query select json_build_object(
'api_key_id', v_api_key_id,
'secret', v_api_key_secret
);
return v_api_key_id;
end
$$ language plpgsql;

View File

@ -1,5 +1,5 @@
-- approvee_session approves the provided session in the database.
create or replace function approve_session(p_session_id bytea, p_recovery_code text)
create or replace function approve_session(p_session_id text, p_recovery_code text)
returns void as $$
begin
-- Mark session as approved

View File

@ -1,18 +1,15 @@
-- register_password_reset_code registers a password reset code for the user
-- identified by the email provided.
create or replace function register_password_reset_code(p_email text)
returns bytea as $$
declare
v_code bytea := gen_random_bytes(32);
create or replace function register_password_reset_code(p_email text, p_code text)
returns void as $$
begin
insert into password_reset_code (password_reset_code_id, user_id)
select sha512(v_code), user_id from "user" where email = p_email and email_verified = true
select p_code, user_id from "user" where email = p_email and email_verified = true
on conflict (user_id) do update set
password_reset_code_id = sha512(v_code),
password_reset_code_id = p_code,
created_at = current_timestamp;
if not found then
raise 'invalid email';
end if;
return v_code;
end
$$ language plpgsql;

View File

@ -1,8 +1,7 @@
-- register_session registers the provided session in the database.
create or replace function register_session(p_session jsonb)
returns table (session_id bytea, approved boolean) as $$
returns boolean as $$
declare
v_session_id bytea := gen_random_bytes(32);
v_approved boolean;
begin
-- Check if the session requires approval or not. When the user has enabled
@ -22,13 +21,13 @@ begin
user_agent,
approved
) values (
sha512(v_session_id),
p_session->>'session_id',
(p_session->>'user_id')::uuid,
nullif(p_session->>'ip', '')::inet,
nullif(p_session->>'user_agent', ''),
v_approved
);
return query select v_session_id, v_approved;
return v_approved;
end
$$ language plpgsql;

View File

@ -1,6 +1,6 @@
-- reset_user_password resets the password of the user associated to the code
-- provided if it is still valid, returning the email of the user.
create or replace function reset_user_password(p_code bytea, p_new_password text)
create or replace function reset_user_password(p_code text, p_new_password text)
returns text as $$
declare
v_user_id uuid;
@ -13,7 +13,7 @@ begin
select u.user_id, u.email into v_user_id, v_email
from "user" u
join password_reset_code prc using (user_id)
where password_reset_code_id = sha512(p_code);
where password_reset_code_id = p_code;
-- Update user password
update "user" set password = p_new_password where user_id = v_user_id;

View File

@ -1,10 +1,10 @@
-- verify_password_reset_code verifies is the password reset code provided is
-- valid. The code must exist and not have expired.
create or replace function verify_password_reset_code(p_code bytea)
create or replace function verify_password_reset_code(p_code text)
returns void as $$
begin
perform from password_reset_code
where password_reset_code_id = sha512(p_code)
where password_reset_code_id = p_code
and created_at + '15 minute'::interval > current_timestamp;
if not found then
raise 'invalid password reset code';

View File

@ -1,6 +1,6 @@
delete from password_reset_code;
alter table password_reset_code alter column password_reset_code_id drop default;
alter table password_reset_code alter column password_reset_code_id type bytea USING password_reset_code_id::text::bytea;
alter table password_reset_code alter column password_reset_code_id type bytea using password_reset_code_id::text::bytea;
drop function if exists register_password_reset_code(text);
---- create above / drop below ----

View File

@ -0,0 +1,15 @@
delete from password_reset_code;
alter table password_reset_code alter column password_reset_code_id type text;
alter table session alter column session_id type text using substring(session_id::bytea::text from 3);
drop function if exists add_api_key(jsonb);
drop function if exists approve_session(bytea, text);
drop function if exists register_password_reset_code(text);
drop function if exists register_session(jsonb);
drop function if exists reset_user_password(bytea, text);
drop function if exists verify_password_reset_code(bytea);
---- create above / drop below ----
delete from password_reset_code;
alter table password_reset_code alter column password_reset_code_id type bytea;
alter table session alter column session_id type bytea using session_id::text::bytea;

View File

@ -13,6 +13,7 @@ values (:'user1ID', 'user1', 'user1@email.com');
select add_api_key('
{
"name": "apikey1",
"secret": "hashed-secret",
"user_id": "00000000-0000-0000-0000-000000000001"
}
'::jsonb);
@ -22,12 +23,14 @@ select results_eq(
$$
select
name,
secret,
user_id
from api_key
$$,
$$
values (
'apikey1',
'hashed-secret',
'00000000-0000-0000-0000-000000000001'::uuid
)
$$,

View File

@ -1,6 +1,6 @@
-- Start transaction and plan tests
begin;
select plan(5);
select plan(4);
-- Seed user
insert into "user" (user_id, alias, email, email_verified)
@ -9,10 +9,10 @@ insert into "user" (user_id, alias, email, email_verified)
values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', false);
-- Register password reset code
select register_password_reset_code('user1@email.com') as code1 \gset
select register_password_reset_code('user1@email.com', 'code1');
select is(
password_reset_code_id,
sha512(:'code1'),
'code1',
'Password reset code for user1 should be registered'
)
from password_reset_code
@ -20,24 +20,19 @@ join "user" using (user_id)
where alias = 'user1';
-- Register another password reset code for the same user
select register_password_reset_code('user1@email.com') as code2 \gset
select register_password_reset_code('user1@email.com', 'code2');
select is(
password_reset_code_id,
sha512(:'code2'),
'code2',
'Password reset code for user1 should have been updated'
)
from password_reset_code
join "user" using (user_id)
where alias = 'user1';
select isnt(
:'code1'::bytea,
:'code2'::bytea,
'Password reset code must have changed'
);
-- Try registering password reset code using non verified email
select throws_ok(
$$ select register_password_reset_code('user2@email.com') $$,
$$ select register_password_reset_code('user2@email.com', 'code') $$,
'P0001',
'invalid email',
'No password reset code should be registered for non verified email user2@email.com'
@ -45,7 +40,7 @@ select throws_ok(
-- Try registering password reset code using unregistered email
select throws_ok(
$$ select register_password_reset_code('user3@email.com') $$,
$$ select register_password_reset_code('user3@email.com', 'code') $$,
'P0001',
'invalid email',
'No password reset code should be registered for unregistered email user3@email.com'

View File

@ -1,6 +1,6 @@
-- Start transaction and plan tests
begin;
select plan(6);
select plan(2);
-- Seed user
insert into "user" (user_id, alias, email)
@ -9,18 +9,20 @@ insert into "user" (user_id, alias, email, tfa_enabled)
values ('00000000-0000-0000-0000-000000000002', 'user2', 'user2@email.com', true);
-- Register session for user with tfa disabled
select session_id, approved from register_session('
select register_session('
{
"session_id": "hashed-session-id-user1",
"user_id": "00000000-0000-0000-0000-000000000001",
"ip": "192.168.1.100",
"user_agent": "Safari 13.0.5"
}
') \gset
') as approved \gset
-- Check if session registration succeeded
select results_eq(
$$
select
session_id,
user_id,
ip,
user_agent,
@ -30,6 +32,7 @@ select results_eq(
$$,
$$
values (
'hashed-session-id-user1',
'00000000-0000-0000-0000-000000000001'::uuid,
'192.168.1.100'::inet,
'Safari 13.0.5',
@ -38,53 +41,32 @@ select results_eq(
$$,
'Session for user1 should exist'
);
select is(
session_id,
sha512(:'session_id'),
'Returned session id for user1 should match value stored'
)
from session where user_id = '00000000-0000-0000-0000-000000000001';
select is(
true,
:'approved',
'Returned approved value for user1 should be true'
)
from session where user_id = '00000000-0000-0000-0000-000000000001';
-- Register session for user with tfa enabled
select session_id, approved from register_session('
select register_session('
{
"session_id": "hashed-session-id-user2",
"user_id": "00000000-0000-0000-0000-000000000002"
}
') \gset
') as approved \gset
-- Check if session registration succeeded
select results_eq(
$$
select user_id, approved
select
session_id,
approved
from session
where user_id = '00000000-0000-0000-0000-000000000002'
$$,
$$
values (
'00000000-0000-0000-0000-000000000002'::uuid,
'hashed-session-id-user2',
false
)
$$,
'Session for user2 should exist'
);
select is(
session_id,
sha512(:'session_id'),
'Returned session id for user2 should match value stored'
)
from session where user_id = '00000000-0000-0000-0000-000000000002';
select is(
false,
:'approved',
'Returned approved value for user2 should be false'
)
from session where user_id = '00000000-0000-0000-0000-000000000002';
-- Finish tests and rollback transaction
select * from finish();

View File

@ -14,9 +14,9 @@ values (:'user1ID', 'user1', 'user1@email.com');
insert into "user" (user_id, alias, email)
values (:'user2ID', 'user2', 'user2@email.com');
insert into password_reset_code (password_reset_code_id, user_id, created_at)
values (sha512(:'code1ID'), :'user1ID', current_timestamp);
values (:'code1ID', :'user1ID', current_timestamp);
insert into password_reset_code (password_reset_code_id, user_id, created_at)
values (sha512(:'code2ID'), :'user2ID', current_timestamp - '30 minute'::interval);
values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval);
insert into session (session_id, user_id) values (gen_random_bytes(32), :'user1ID');
-- Password reset should fail in the following cases

View File

@ -14,7 +14,7 @@ values (:'user1ID', 'user1', 'user1@email.com');
insert into "user" (user_id, alias, email)
values (:'user2ID', 'user2', 'user2@email.com');
insert into password_reset_code (password_reset_code_id, user_id, created_at)
values (sha512(:'code1ID'), :'user1ID', current_timestamp - '5 minute'::interval);
values (:'code1ID', :'user1ID', current_timestamp - '5 minute'::interval);
insert into password_reset_code (password_reset_code_id, user_id, created_at)
values (:'code2ID', :'user2ID', current_timestamp - '30 minute'::interval);

View File

@ -2,21 +2,29 @@ package apikey
import (
"context"
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/util"
"github.com/jackc/pgx/v4"
"github.com/satori/uuid"
"golang.org/x/crypto/bcrypt"
)
const (
// Database queries
addAPIKeyDBQ = `select add_api_key($1::jsonb)`
deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)`
getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)`
getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)`
updateAPIKeyDBQ = `select update_api_key($1::jsonb)`
addAPIKeyDBQ = `select add_api_key($1::jsonb)`
deleteAPIKeyDBQ = `select delete_api_key($1::uuid, $2::uuid)`
getAPIKeyDBQ = `select get_api_key($1::uuid, $2::uuid)`
getAPIKeyUserIDDBQ = `select user_id, secret from api_key where api_key_id = $1`
getUserAPIKeysDBQ = `select get_user_api_keys($1::uuid)`
updateAPIKeyDBQ = `select update_api_key($1::jsonb)`
)
// Manager provides an API to manage api keys.
@ -32,7 +40,7 @@ func NewManager(db hub.DB) *Manager {
}
// Add adds the provided api key to the database.
func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) {
ak.UserID = ctx.Value(hub.UserIDKey).(string)
// Validate input
@ -40,9 +48,64 @@ func (m *Manager) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "name not provided")
}
// Generate API key secret
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return nil, err
}
apiKeySecret := base64.StdEncoding.EncodeToString(randomBytes)
apiKeySecretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret)))
// Add api key to the database
var apiKeyID string
ak.Secret = apiKeySecretHashed
akJSON, _ := json.Marshal(ak)
return util.DBQueryJSON(ctx, m.db, addAPIKeyDBQ, akJSON)
if err := m.db.QueryRow(ctx, addAPIKeyDBQ, akJSON).Scan(&apiKeyID); err != nil {
return nil, err
}
return &hub.APIKey{
APIKeyID: apiKeyID,
Secret: apiKeySecret,
}, nil
}
// Check checks if the api key provided is valid.
func (m *Manager) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
// Validate input
if apiKeyID == "" || apiKeySecret == "" {
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided")
}
// Get key's user id and secret from database
var userID, apiKeySecretHashed string
err := m.db.QueryRow(ctx, getAPIKeyUserIDDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
return nil, err
}
// Check if the secret provided is valid
switch {
case strings.HasPrefix(apiKeySecretHashed, "$2a$"):
// Bcrypt hash, will be deprecated soon
err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret))
if err != nil {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
default:
// SHA512 hash
if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
}
return &hub.CheckAPIKeyOutput{
Valid: true,
UserID: userID,
}, nil
}
// Delete deletes the provided api key from the database.

View File

@ -2,13 +2,18 @@ package apikey
import (
"context"
"crypto/sha512"
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/tests"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/bcrypt"
)
const apiKeyID = "00000000-0000-0000-0000-000000000001"
@ -60,14 +65,13 @@ func TestAdd(t *testing.T) {
Name: "apikey1",
UserID: "userID",
}
akJSON, _ := json.Marshal(ak)
db := &tests.DBMock{}
db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return(nil, tests.ErrFakeDB)
db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return(nil, tests.ErrFakeDB)
m := NewManager(db)
keyInfoJSON, err := m.Add(ctx, ak)
output, err := m.Add(ctx, ak)
assert.Equal(t, tests.ErrFakeDB, err)
assert.Nil(t, keyInfoJSON)
assert.Nil(t, output)
db.AssertExpectations(t)
})
@ -77,14 +81,100 @@ func TestAdd(t *testing.T) {
Name: "apikey1",
UserID: "userID",
}
akJSON, _ := json.Marshal(ak)
db := &tests.DBMock{}
db.On("QueryRow", ctx, addAPIKeyDBQ, akJSON).Return([]byte("keyInfoJSON"), nil)
db.On("QueryRow", ctx, addAPIKeyDBQ, mock.Anything).Return("apiKeyID", nil)
m := NewManager(db)
keyInfoJSON, err := m.Add(ctx, ak)
output, err := m.Add(ctx, ak)
assert.NoError(t, err)
assert.Equal(t, []byte("keyInfoJSON"), keyInfoJSON)
assert.Equal(t, "apiKeyID", output.APIKeyID)
assert.NotEmpty(t, output.Secret)
db.AssertExpectations(t)
})
}
func TestCheck(t *testing.T) {
ctx := context.Background()
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
apiKeyID string
apiKeySecret string
}{
{
"api key id or secret not provided",
"",
"secret",
},
{
"api key id or secret not provided",
"key",
"",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.errMsg, func(t *testing.T) {
t.Parallel()
m := NewManager(nil)
_, err := m.Check(ctx, tc.apiKeyID, tc.apiKeySecret)
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
assert.Contains(t, err.Error(), tc.errMsg)
})
}
})
t.Run("key info not found in database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, pgx.ErrNoRows)
m := NewManager(db)
output, err := m.Check(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.False(t, output.Valid)
assert.Empty(t, output.UserID)
db.AssertExpectations(t)
})
t.Run("error getting key info from database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return(nil, tests.ErrFakeDB)
m := NewManager(db)
output, err := m.Check(ctx, "keyID", "secret")
assert.Equal(t, tests.ErrFakeDB, err)
assert.Nil(t, output)
db.AssertExpectations(t)
})
t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost)
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil)
m := NewManager(db)
output, err := m.Check(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.True(t, output.Valid)
assert.Equal(t, "userID", output.UserID)
db.AssertExpectations(t)
})
t.Run("valid key (secret hashed with sha512)", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret")))
db.On("QueryRow", ctx, getAPIKeyUserIDDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil)
m := NewManager(db)
output, err := m.Check(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.True(t, output.Valid)
assert.Equal(t, "userID", output.UserID)
db.AssertExpectations(t)
})
}

View File

@ -13,9 +13,16 @@ type ManagerMock struct {
}
// Add implements the APIKeyManager interface.
func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) ([]byte, error) {
func (m *ManagerMock) Add(ctx context.Context, ak *hub.APIKey) (*hub.APIKey, error) {
args := m.Called(ctx, ak)
data, _ := args.Get(0).([]byte)
data, _ := args.Get(0).(*hub.APIKey)
return data, args.Error(1)
}
// Check implements the UserManager interface.
func (m *ManagerMock) Check(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
args := m.Called(ctx, apiKeyID, apiKeySecret)
data, _ := args.Get(0).(*hub.CheckAPIKeyOutput)
return data, args.Error(1)
}

View File

@ -28,19 +28,20 @@ func NewHandlers(apiKeyManager hub.APIKeyManager) *Handlers {
// Add is an http handler that adds the provided api key to the database.
func (h *Handlers) Add(w http.ResponseWriter, r *http.Request) {
ak := &hub.APIKey{}
if err := json.NewDecoder(r.Body).Decode(&ak); err != nil {
akIN := &hub.APIKey{}
if err := json.NewDecoder(r.Body).Decode(&akIN); err != nil {
h.logger.Error().Err(err).Str("method", "Add").Msg(hub.ErrInvalidInput.Error())
helpers.RenderErrorJSON(w, hub.ErrInvalidInput)
return
}
dataJSON, err := h.apiKeyManager.Add(r.Context(), ak)
akOUT, err := h.apiKeyManager.Add(r.Context(), akIN)
if err != nil {
h.logger.Error().Err(err).Str("method", "Add").Send()
helpers.RenderErrorJSON(w, err)
return
}
helpers.RenderJSON(w, dataJSON, 0, http.StatusCreated)
akOUTJSON, _ := json.Marshal(akOUT)
helpers.RenderJSON(w, akOUTJSON, 0, http.StatusCreated)
}
// Delete is an http handler that deletes the provided api key from the database.

View File

@ -100,7 +100,11 @@ func TestAdd(t *testing.T) {
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw := newHandlersWrapper()
hw.am.On("Add", r.Context(), ak).Return([]byte("keyInfoJSON"), nil)
akOUT := &hub.APIKey{
APIKeyID: "apiKeyID",
Secret: "secret",
}
hw.am.On("Add", r.Context(), ak).Return(akOUT, nil)
hw.h.Add(w, r)
resp := w.Result()
defer resp.Body.Close()
@ -110,7 +114,8 @@ func TestAdd(t *testing.T) {
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.Equal(t, "application/json", h.Get("Content-Type"))
assert.Equal(t, helpers.BuildCacheControlHeader(0), h.Get("Cache-Control"))
assert.Equal(t, []byte("keyInfoJSON"), data)
outputAKJSON, _ := json.Marshal(akOUT)
assert.Equal(t, outputAKJSON, data)
hw.am.AssertExpectations(t)
})
}

View File

@ -78,7 +78,7 @@ type Handlers struct {
// Setup creates a new Handlers instance.
func Setup(ctx context.Context, cfg *viper.Viper, svc *Services) (*Handlers, error) {
userHandlers, err := user.NewHandlers(ctx, svc.UserManager, cfg)
userHandlers, err := user.NewHandlers(ctx, svc.UserManager, svc.APIKeyManager, cfg)
if err != nil {
return nil, err
}

View File

@ -66,16 +66,22 @@ var (
// Handlers represents a group of http handlers in charge of handling
// users operations.
type Handlers struct {
userManager hub.UserManager
cfg *viper.Viper
sc *securecookie.SecureCookie
oauthConfig map[string]*oauth2.Config
oidcProvider *oidc.Provider
logger zerolog.Logger
userManager hub.UserManager
apiKeyManager hub.APIKeyManager
cfg *viper.Viper
sc *securecookie.SecureCookie
oauthConfig map[string]*oauth2.Config
oidcProvider *oidc.Provider
logger zerolog.Logger
}
// NewHandlers creates a new Handlers instance.
func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Viper) (*Handlers, error) {
func NewHandlers(
ctx context.Context,
userManager hub.UserManager,
apiKeyManager hub.APIKeyManager,
cfg *viper.Viper,
) (*Handlers, error) {
// Setup secure cookie instance
sc := securecookie.New([]byte(cfg.GetString("server.cookie.hashKey")), nil)
sc.MaxAge(int(sessionDuration.Seconds()))
@ -112,12 +118,13 @@ func NewHandlers(ctx context.Context, userManager hub.UserManager, cfg *viper.Vi
}
return &Handlers{
userManager: userManager,
cfg: cfg,
sc: sc,
oauthConfig: oauthConfig,
oidcProvider: oidcProvider,
logger: log.With().Str("handlers", "user").Logger(),
userManager: userManager,
apiKeyManager: apiKeyManager,
cfg: cfg,
sc: sc,
oauthConfig: oauthConfig,
oidcProvider: oidcProvider,
logger: log.With().Str("handlers", "user").Logger(),
}, nil
}
@ -137,7 +144,7 @@ func (h *Handlers) ApproveSession(w http.ResponseWriter, r *http.Request) {
passcode := input["passcode"]
// Extract sessionID from cookie
var sessionID []byte
var sessionID string
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
h.logger.Error().Err(err).Str("method", "ApproveSession").Msg("session cookie not found")
@ -290,7 +297,7 @@ func (h *Handlers) InjectUserID(next http.Handler) http.Handler {
if err != nil {
return
}
var sessionID []byte
var sessionID string
if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil {
return
}
@ -371,7 +378,7 @@ func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
// Delete user session
cookie, err := r.Cookie(sessionCookieName)
if err == nil {
var sessionID []byte
var sessionID string
err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
if err == nil {
err = h.userManager.DeleteSession(r.Context(), sessionID)
@ -749,7 +756,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler {
// Use API key based authentication if API key is provided
if apiKeyID != "" && apiKeySecret != "" {
// Check the API key provided is valid
checkAPIKeyOutput, err := h.userManager.CheckAPIKey(r.Context(), apiKeyID, apiKeySecret)
checkAPIKeyOutput, err := h.apiKeyManager.Check(r.Context(), apiKeyID, apiKeySecret)
if err != nil {
h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("checkAPIKey failed")
helpers.RenderErrorWithCodeJSON(w, nil, http.StatusInternalServerError)
@ -766,7 +773,7 @@ func (h *Handlers) RequireLogin(next http.Handler) http.Handler {
cookie, err := r.Cookie(sessionCookieName)
if err == nil {
// Extract and validate cookie from request
var sessionID []byte
var sessionID string
if err = h.sc.Decode(sessionCookieName, cookie.Value, &sessionID); err != nil {
h.logger.Error().Err(err).Str("method", "RequireLogin").Msg("sessionID decoding failed")
helpers.RenderErrorWithCodeJSON(w, errInvalidSession, http.StatusUnauthorized)

View File

@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/artifacthub/hub/internal/apikey"
"github.com/artifacthub/hub/internal/handlers/helpers"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/tests"
@ -31,7 +32,7 @@ func TestMain(m *testing.M) {
}
func TestApproveSession(t *testing.T) {
sessionID := []byte("sessionID")
sessionID := "sessionID"
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
@ -495,6 +496,8 @@ func TestGetProfile(t *testing.T) {
}
func TestInjectUserID(t *testing.T) {
sessionID := "sessionID"
checkUserID := func(expectedUserID interface{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if expectedUserID == nil {
@ -541,7 +544,7 @@ func TestInjectUserID(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
@ -565,7 +568,7 @@ func TestInjectUserID(t *testing.T) {
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
Return(&hub.CheckSessionOutput{UserID: "", Valid: false}, nil)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
@ -586,7 +589,7 @@ func TestInjectUserID(t *testing.T) {
hw := newHandlersWrapper()
hw.um.On("CheckSession", r.Context(), mock.Anything, mock.Anything).
Return(&hub.CheckSessionOutput{UserID: "userID", Valid: true}, nil)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, sessionID)
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
@ -601,6 +604,8 @@ func TestInjectUserID(t *testing.T) {
}
func TestLogin(t *testing.T) {
sessionID := "sessionID"
t.Run("invalid", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
@ -693,7 +698,7 @@ func TestLogin(t *testing.T) {
Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}).
Return(&hub.Session{
SessionID: []byte("sessionID"),
SessionID: sessionID,
Approved: true,
}, nil)
hw.h.Login(w, r)
@ -708,10 +713,10 @@ func TestLogin(t *testing.T) {
assert.Equal(t, "/", cookie.Path)
assert.True(t, cookie.HttpOnly)
assert.False(t, cookie.Secure)
var sessionID []byte
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
var cookieSessionID string
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID)
require.NoError(t, err)
assert.Equal(t, []byte("sessionID"), sessionID)
assert.Equal(t, sessionID, cookieSessionID)
assert.Equal(t, "true", h.Get(SessionApprovedHeader))
hw.um.AssertExpectations(t)
})
@ -727,7 +732,7 @@ func TestLogin(t *testing.T) {
Return(&hub.CheckCredentialsOutput{Valid: true, UserID: "userID"}, nil)
hw.um.On("RegisterSession", r.Context(), &hub.Session{UserID: "userID"}).
Return(&hub.Session{
SessionID: []byte("sessionID"),
SessionID: sessionID,
Approved: false,
}, nil)
hw.h.Login(w, r)
@ -742,10 +747,10 @@ func TestLogin(t *testing.T) {
assert.Equal(t, "/", cookie.Path)
assert.True(t, cookie.HttpOnly)
assert.False(t, cookie.Secure)
var sessionID []byte
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &sessionID)
var cookieSessionID string
err := hw.h.sc.Decode(sessionCookieName, cookie.Value, &cookieSessionID)
require.NoError(t, err)
assert.Equal(t, []byte("sessionID"), sessionID)
assert.Equal(t, sessionID, cookieSessionID)
assert.Equal(t, "false", h.Get(SessionApprovedHeader))
hw.um.AssertExpectations(t)
})
@ -815,8 +820,8 @@ func TestLogout(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
hw := newHandlersWrapper()
hw.um.On("DeleteSession", r.Context(), []byte("sessionID")).Return(tc.err)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, []byte("sessionID"))
hw.um.On("DeleteSession", r.Context(), "sessionID").Return(tc.err)
encodedSessionID, _ := hw.h.sc.Encode(sessionCookieName, "sessionID")
r.AddCookie(&http.Cookie{
Name: sessionCookieName,
Value: encodedSessionID,
@ -1092,7 +1097,7 @@ func TestRegisterUser(t *testing.T) {
}
func TestRequireLogin(t *testing.T) {
sessionID := []byte("sessionID")
sessionID := "sessionID"
t.Run("api key based authentication", func(t *testing.T) {
apiKeyID := "keyID"
@ -1143,7 +1148,7 @@ func TestRequireLogin(t *testing.T) {
r.Header.Add(APIKeySecretHeader, apiKeySecret)
hw := newHandlersWrapper()
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB)
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).Return(nil, tests.ErrFakeDB)
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
resp := w.Result()
defer resp.Body.Close()
@ -1164,7 +1169,7 @@ func TestRequireLogin(t *testing.T) {
r.Header.Add(APIKeySecretHeader, apiKeySecret)
hw := newHandlersWrapper()
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
Return(&hub.CheckAPIKeyOutput{UserID: "", Valid: false}, nil)
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
resp := w.Result()
@ -1186,7 +1191,7 @@ func TestRequireLogin(t *testing.T) {
r.Header.Add(APIKeySecretHeader, apiKeySecret)
hw := newHandlersWrapper()
hw.um.On("CheckAPIKey", r.Context(), apiKeyID, apiKeySecret).
hw.am.On("Check", r.Context(), apiKeyID, apiKeySecret).
Return(&hub.CheckAPIKeyOutput{UserID: "userID", Valid: true}, nil)
hw.h.RequireLogin(http.HandlerFunc(testsOK)).ServeHTTP(w, r)
resp := w.Result()
@ -1689,6 +1694,7 @@ func testsOK(w http.ResponseWriter, r *http.Request) {}
type handlersWrapper struct {
cfg *viper.Viper
um *user.ManagerMock
am *apikey.ManagerMock
h *Handlers
}
@ -1697,11 +1703,13 @@ func newHandlersWrapper() *handlersWrapper {
cfg.Set("server.baseURL", "baseURL")
cfg.Set("server.oauth.github", map[string]string{})
um := &user.ManagerMock{}
h, _ := NewHandlers(context.Background(), um, cfg)
am := &apikey.ManagerMock{}
h, _ := NewHandlers(context.Background(), um, am, cfg)
return &handlersWrapper{
cfg: cfg,
um: um,
am: am,
h: h,
}
}

View File

@ -6,6 +6,7 @@ import "context"
type APIKey struct {
APIKeyID string `json:"api_key_id"`
Name string `json:"name"`
Secret string `json:"secret"`
CreatedAt int64 `json:"created_at"`
UserID string `json:"user_id"`
}
@ -13,9 +14,16 @@ type APIKey struct {
// APIKeyManager describes the methods an APIKeyManager implementation must
// provide.
type APIKeyManager interface {
Add(ctx context.Context, ak *APIKey) ([]byte, error)
Add(ctx context.Context, ak *APIKey) (*APIKey, error)
Check(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error)
Delete(ctx context.Context, apiKeyID string) error
GetJSON(ctx context.Context, apiKeyID string) ([]byte, error)
GetOwnedByUserJSON(ctx context.Context) ([]byte, error)
Update(ctx context.Context, ak *APIKey) error
}
// CheckAPIKeyOutput represents the output returned by the CheckApiKey method.
type CheckAPIKeyOutput struct {
Valid bool `json:"valid"`
UserID string `json:"user_id"`
}

View File

@ -5,12 +5,6 @@ import (
"time"
)
// CheckAPIKeyOutput represents the output returned by the CheckApiKey method.
type CheckAPIKeyOutput struct {
Valid bool `json:"valid"`
UserID string `json:"user_id"`
}
// CheckCredentialsOutput represents the output returned by the
// CheckCredentials method.
type CheckCredentialsOutput struct {
@ -26,7 +20,7 @@ type CheckSessionOutput struct {
// Session represents some information about a user session.
type Session struct {
SessionID []byte `json:"session_id"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
IP string `json:"ip"`
UserAgent string `json:"user_agent"`
@ -68,12 +62,11 @@ var UserIDKey = userIDKey{}
// UserManager describes the methods a UserManager implementation must provide.
type UserManager interface {
ApproveSession(ctx context.Context, sessionID []byte, passcode string) error
CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*CheckAPIKeyOutput, error)
ApproveSession(ctx context.Context, sessionID, passcode string) error
CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error)
CheckCredentials(ctx context.Context, email, password string) (*CheckCredentialsOutput, error)
CheckSession(ctx context.Context, sessionID []byte, duration time.Duration) (*CheckSessionOutput, error)
DeleteSession(ctx context.Context, sessionID []byte) error
CheckSession(ctx context.Context, sessionID string, duration time.Duration) (*CheckSessionOutput, error)
DeleteSession(ctx context.Context, sessionID string) error
DisableTFA(ctx context.Context, passcode string) error
EnableTFA(ctx context.Context, passcode string) error
GetProfile(ctx context.Context) (*User, error)

View File

@ -3,6 +3,7 @@ package user
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"encoding/json"
@ -10,7 +11,6 @@ import (
"fmt"
"image/png"
"net/url"
"strings"
"time"
"github.com/artifacthub/hub/internal/email"
@ -26,13 +26,12 @@ import (
const (
// Database queries
approveSessionDBQ = `select approve_session($1::bytea, $2::text)`
approveSessionDBQ = `select approve_session($1::text, $2::text)`
checkUserAliasAvailDBQ = `select check_user_alias_availability($1::text)`
checkUserCredsDBQ = `select user_id, password from "user" where email = $1 and password is not null and email_verified = true`
deleteSessionDBQ = `delete from session where session_id = $1`
disableTFADBQ = `update "user" set tfa_enabled = false, tfa_url = null, tfa_recovery_codes = null where user_id = $1 and tfa_enabled = true`
enableTFADBQ = `update "user" set tfa_enabled = true where user_id = $1`
getAPIKeyInfoDBQ = `select user_id, secret from api_key where api_key_id = $1`
getSessionDBQ = `select user_id, floor(extract(epoch from created_at)), approved from session where session_id = $1`
getTFAConfigDBQ = `select get_user_tfa_config($1::uuid)`
getUserEmailDBQ = `select email from "user" where user_id = $1`
@ -40,15 +39,15 @@ const (
getUserIDFromSessionIDDBQ = `select user_id from session where session_id = $1`
getUserPasswordDBQ = `select password from "user" where user_id = $1 and password is not null`
getUserProfileDBQ = `select get_user_profile($1::uuid)`
registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text)`
registerSessionDBQ = `select session_id, approved from register_session($1::jsonb)`
registerPasswordResetCodeDBQ = `select register_password_reset_code($1::text, $2::text)`
registerSessionDBQ = `select register_session($1::jsonb)`
registerUserDBQ = `select register_user($1::jsonb)`
resetUserPasswordDBQ = `select reset_user_password($1::bytea, $2::text)`
resetUserPasswordDBQ = `select reset_user_password($1::text, $2::text)`
updateTFAInfoDBQ = `update "user" set tfa_url = $2, tfa_recovery_codes = $3 where user_id = $1`
updateUserPasswordDBQ = `select update_user_password($1::uuid, $2::text, $3::text)`
updateUserProfileDBQ = `select update_user_profile($1::uuid, $2::jsonb)`
verifyEmailDBQ = `select verify_email($1::uuid)`
verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::bytea)`
verifyPasswordResetCodeDBQ = `select verify_password_reset_code($1::text)`
numRecoveryCodes = 10
)
@ -94,7 +93,7 @@ func NewManager(db hub.DB, es hub.EmailSender) *Manager {
}
// ApproveSession approves a given session using the TFA passcode provided.
func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error {
func (m *Manager) ApproveSession(ctx context.Context, sessionID, passcode string) error {
// Validate input
if len(sessionID) == 0 {
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "sessionID not provided")
@ -105,7 +104,7 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode
// Get id of the user the session belongs to
var userID string
err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Scan(&userID)
err := m.db.QueryRow(ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Scan(&userID)
if err != nil {
return err
}
@ -131,48 +130,10 @@ func (m *Manager) ApproveSession(ctx context.Context, sessionID []byte, passcode
if validRecoveryCodeProvided {
recoveryCode = passcode
}
_, err = m.db.Exec(ctx, approveSessionDBQ, hashSessionID(sessionID), recoveryCode)
_, err = m.db.Exec(ctx, approveSessionDBQ, hash(sessionID), recoveryCode)
return err
}
// CheckAPIKey checks if the api key provided is valid.
func (m *Manager) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
// Validate input
if apiKeyID == "" || apiKeySecret == "" {
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "api key id or secret not provided")
}
// Get key's user id and secret from database
var userID, apiKeySecretHashed string
err := m.db.QueryRow(ctx, getAPIKeyInfoDBQ, apiKeyID).Scan(&userID, &apiKeySecretHashed)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
return nil, err
}
// Check if the secret provided is valid
switch {
case strings.HasPrefix(apiKeySecretHashed, "$2a$"):
// Bcrypt hash, will be deprecated soon
err = bcrypt.CompareHashAndPassword([]byte(apiKeySecretHashed), []byte(apiKeySecret))
if err != nil {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
default:
// SHA512 hash
if fmt.Sprintf("%x", sha512.Sum512([]byte(apiKeySecret))) != apiKeySecretHashed {
return &hub.CheckAPIKeyOutput{Valid: false}, nil
}
}
return &hub.CheckAPIKeyOutput{
Valid: true,
UserID: userID,
}, nil
}
// CheckAvailability checks the availability of a given value for the provided
// resource kind.
func (m *Manager) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) {
@ -247,7 +208,7 @@ func (m *Manager) CheckCredentials(
// CheckSession checks if the user session provided is valid.
func (m *Manager) CheckSession(
ctx context.Context,
sessionID []byte,
sessionID string,
duration time.Duration,
) (*hub.CheckSessionOutput, error) {
// Validate input
@ -262,7 +223,7 @@ func (m *Manager) CheckSession(
var userID string
var createdAt int64
var approved bool
err := m.db.QueryRow(ctx, getSessionDBQ, hashSessionID(sessionID)).Scan(&userID, &createdAt, &approved)
err := m.db.QueryRow(ctx, getSessionDBQ, hash(sessionID)).Scan(&userID, &createdAt, &approved)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return &hub.CheckSessionOutput{Valid: false}, nil
@ -288,14 +249,14 @@ func (m *Manager) CheckSession(
}
// DeleteSession deletes a user session from the database.
func (m *Manager) DeleteSession(ctx context.Context, sessionID []byte) error {
func (m *Manager) DeleteSession(ctx context.Context, sessionID string) error {
// Validate input
if len(sessionID) == 0 {
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "session id not provided")
}
// Delete session from database
_, err := m.db.Exec(ctx, deleteSessionDBQ, hashSessionID(sessionID))
_, err := m.db.Exec(ctx, deleteSessionDBQ, hash(sessionID))
return err
}
@ -461,17 +422,20 @@ func (m *Manager) RegisterPasswordResetCode(ctx context.Context, userEmail, base
}
// Register password reset code in database
var code []byte
err := m.db.QueryRow(ctx, registerPasswordResetCodeDBQ, userEmail).Scan(&code)
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return err
}
code := base64.URLEncoding.EncodeToString(randomBytes)
_, err := m.db.Exec(ctx, registerPasswordResetCodeDBQ, userEmail, hash(code))
if err != nil {
return err
}
// Send password reset email
if code != nil && m.es != nil {
codeB64 := base64.URLEncoding.EncodeToString(code)
if m.es != nil {
templateData := map[string]string{
"link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, codeB64),
"link": fmt.Sprintf("%s/reset-password?code=%s", baseURL, code),
}
var emailBody bytes.Buffer
if err := passwordResetTmpl.Execute(&emailBody, templateData); err != nil {
@ -500,14 +464,23 @@ func (m *Manager) RegisterSession(ctx context.Context, session *hub.Session) (*h
return nil, fmt.Errorf("%w: %s", hub.ErrInvalidInput, "invalid user id")
}
// Generate session id
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return nil, err
}
sessionID := base64.StdEncoding.EncodeToString(randomBytes)
sessionIDHashed := hash(sessionID)
// Register session in database
session.SessionID = sessionIDHashed
sessionJSON, _ := json.Marshal(session)
var sessionID []byte
var approved bool
err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&sessionID, &approved)
err := m.db.QueryRow(ctx, registerSessionDBQ, sessionJSON).Scan(&approved)
if err != nil {
return nil, err
}
return &hub.Session{
SessionID: sessionID,
UserID: session.UserID,
@ -585,9 +558,9 @@ func (m *Manager) RegisterUser(ctx context.Context, user *hub.User, baseURL stri
}
// ResetPassword resets the user password in the database.
func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseURL string) error {
func (m *Manager) ResetPassword(ctx context.Context, code, newPassword, baseURL string) error {
// Validate input
if codeB64 == "" {
if code == "" {
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided")
}
if newPassword == "" {
@ -610,12 +583,8 @@ func (m *Manager) ResetPassword(ctx context.Context, codeB64, newPassword, baseU
}
// Reset user password in database
code, err := base64.URLEncoding.DecodeString(codeB64)
if err != nil {
return ErrInvalidPasswordResetCode
}
var userEmail string
err = m.db.QueryRow(ctx, resetUserPasswordDBQ, code, string(newHashed)).Scan(&userEmail)
err = m.db.QueryRow(ctx, resetUserPasswordDBQ, hash(code), string(newHashed)).Scan(&userEmail)
if err != nil {
if err.Error() == errInvalidPasswordResetCodeDB.Error() {
return ErrInvalidPasswordResetCode
@ -772,28 +741,23 @@ func (m *Manager) VerifyEmail(ctx context.Context, code string) (bool, error) {
}
// VerifyPasswordResetCode verifies if the provided code is valid.
func (m *Manager) VerifyPasswordResetCode(ctx context.Context, codeB64 string) error {
func (m *Manager) VerifyPasswordResetCode(ctx context.Context, code string) error {
// Validate input
if codeB64 == "" {
if code == "" {
return fmt.Errorf("%w: %s", hub.ErrInvalidInput, "code not provided")
}
// Verify password reset code in database
code, err := base64.URLEncoding.DecodeString(codeB64)
if err != nil {
return ErrInvalidPasswordResetCode
}
_, err = m.db.Exec(ctx, verifyPasswordResetCodeDBQ, code)
_, err := m.db.Exec(ctx, verifyPasswordResetCodeDBQ, hash(code))
if err != nil && err.Error() == errInvalidPasswordResetCodeDB.Error() {
return ErrInvalidPasswordResetCode
}
return err
}
// hashSessionID is a helper function that creates a sha512 hash of the
// sessionID provided.
func hashSessionID(sessionID []byte) string {
return fmt.Sprintf("\\x%x", sha512.Sum512(sessionID))
// hash is a helper function that creates a sha512 hash of the text provided.
func hash(text string) string {
return fmt.Sprintf("%x", sha512.Sum512([]byte(text)))
}
// isValidRecoveryCode checks if the code provided is a valid recovery code.

View File

@ -2,8 +2,6 @@ package user
import (
"context"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -24,8 +22,8 @@ import (
func TestApproveSession(t *testing.T) {
ctx := context.Background()
sessionID := []byte("sessionID")
hashedSessionID := hashSessionID([]byte("sessionID"))
sessionID := "sessionID"
hashedSessionID := hash(sessionID)
opts := totp.GenerateOpts{
Issuer: "Artifact Hub",
AccountName: "test@email.com",
@ -42,22 +40,17 @@ func TestApproveSession(t *testing.T) {
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
sessionID []byte
sessionID string
passcode string
}{
{
"sessionID not provided",
nil,
"123456",
},
{
"sessionID not provided",
[]byte(""),
"",
"123456",
},
{
"passcode not provided",
[]byte("sessionID"),
"sessionID",
"",
},
}
@ -87,7 +80,7 @@ func TestApproveSession(t *testing.T) {
t.Run("error getting tfa config from database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(nil, tests.ErrFakeDB)
m := NewManager(db, nil)
@ -99,7 +92,7 @@ func TestApproveSession(t *testing.T) {
t.Run("invalid passcode provided", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
m := NewManager(db, nil)
@ -111,7 +104,7 @@ func TestApproveSession(t *testing.T) {
t.Run("session approved successfully", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, "").Return(nil)
m := NewManager(db, nil)
@ -125,7 +118,7 @@ func TestApproveSession(t *testing.T) {
t.Run("session approved successfully (using valid recovery code)", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hashSessionID(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getUserIDFromSessionIDDBQ, hash(sessionID)).Return("userID", nil)
db.On("QueryRow", ctx, getTFAConfigDBQ, "userID").Return(tfaConfigJSON, nil)
db.On("Exec", ctx, approveSessionDBQ, hashedSessionID, code1).Return(nil)
m := NewManager(db, nil)
@ -136,92 +129,6 @@ func TestApproveSession(t *testing.T) {
})
}
func TestCheckAPIKey(t *testing.T) {
ctx := context.Background()
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
apiKeyID string
apiKeySecret string
}{
{
"api key id or secret not provided",
"",
"secret",
},
{
"api key id or secret not provided",
"key",
"",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.errMsg, func(t *testing.T) {
t.Parallel()
m := NewManager(nil, nil)
_, err := m.CheckAPIKey(ctx, tc.apiKeyID, tc.apiKeySecret)
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
assert.Contains(t, err.Error(), tc.errMsg)
})
}
})
t.Run("key info not found in database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, pgx.ErrNoRows)
m := NewManager(db, nil)
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.False(t, output.Valid)
assert.Empty(t, output.UserID)
db.AssertExpectations(t)
})
t.Run("error getting key info from database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return(nil, tests.ErrFakeDB)
m := NewManager(db, nil)
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
assert.Equal(t, tests.ErrFakeDB, err)
assert.Nil(t, output)
db.AssertExpectations(t)
})
t.Run("valid key (secret hashed with bcrypt)", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
secretHashed, _ := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost)
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", string(secretHashed)}, nil)
m := NewManager(db, nil)
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.True(t, output.Valid)
assert.Equal(t, "userID", output.UserID)
db.AssertExpectations(t)
})
t.Run("valid key (secret hashed with sha512)", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
secretHashed := fmt.Sprintf("%x", sha512.Sum512([]byte("secret")))
db.On("QueryRow", ctx, getAPIKeyInfoDBQ, "keyID").Return([]interface{}{"userID", secretHashed}, nil)
m := NewManager(db, nil)
output, err := m.CheckAPIKey(ctx, "keyID", "secret")
assert.NoError(t, err)
assert.True(t, output.Valid)
assert.Equal(t, "userID", output.UserID)
db.AssertExpectations(t)
})
}
func TestCheckAvailability(t *testing.T) {
ctx := context.Background()
@ -385,27 +292,23 @@ func TestCheckCredentials(t *testing.T) {
func TestCheckSession(t *testing.T) {
ctx := context.Background()
hashedSessionID := hashSessionID([]byte("sessionID"))
sessionID := "sessionID"
hashedSessionID := hash(sessionID)
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
sessionID []byte
sessionID string
duration time.Duration
}{
{
"session id not provided",
nil,
10,
},
{
"session id not provided",
[]byte(""),
"",
10,
},
{
"duration not provided",
[]byte("sessionID"),
"sessionID",
0,
},
}
@ -427,7 +330,7 @@ func TestCheckSession(t *testing.T) {
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, pgx.ErrNoRows)
m := NewManager(db, nil)
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
assert.NoError(t, err)
assert.False(t, output.Valid)
assert.Empty(t, output.UserID)
@ -440,7 +343,7 @@ func TestCheckSession(t *testing.T) {
db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, tests.ErrFakeDB)
m := NewManager(db, nil)
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
assert.Equal(t, tests.ErrFakeDB, err)
assert.Nil(t, output)
db.AssertExpectations(t)
@ -456,7 +359,7 @@ func TestCheckSession(t *testing.T) {
}, nil)
m := NewManager(db, nil)
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
assert.NoError(t, err)
assert.False(t, output.Valid)
assert.Empty(t, output.UserID)
@ -473,7 +376,7 @@ func TestCheckSession(t *testing.T) {
}, nil)
m := NewManager(db, nil)
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
assert.NoError(t, err)
assert.False(t, output.Valid)
assert.Empty(t, output.UserID)
@ -490,7 +393,7 @@ func TestCheckSession(t *testing.T) {
}, nil)
m := NewManager(db, nil)
output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour)
output, err := m.CheckSession(ctx, sessionID, 1*time.Hour)
assert.NoError(t, err)
assert.True(t, output.Valid)
assert.Equal(t, "userID", output.UserID)
@ -500,22 +403,18 @@ func TestCheckSession(t *testing.T) {
func TestDeleteSession(t *testing.T) {
ctx := context.Background()
hashedSessionID := hashSessionID([]byte("sessionID"))
sessionID := "sessionID"
hashedSessionID := hash(sessionID)
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
sessionID []byte
sessionID string
duration time.Duration
}{
{
"session id not provided",
nil,
10,
},
{
"session id not provided",
[]byte(""),
"",
10,
},
}
@ -553,7 +452,7 @@ func TestDeleteSession(t *testing.T) {
db.On("Exec", ctx, deleteSessionDBQ, hashedSessionID).Return(tc.dbResponse)
m := NewManager(db, nil)
err := m.DeleteSession(ctx, []byte("sessionID"))
err := m.DeleteSession(ctx, sessionID)
assert.Equal(t, tc.dbResponse, err)
db.AssertExpectations(t)
})
@ -879,12 +778,7 @@ func TestGetUserID(t *testing.T) {
func TestRegisterSession(t *testing.T) {
ctx := context.Background()
s := &hub.Session{
UserID: "00000000-0000-0000-0000-000000000001",
IP: "192.168.1.100",
UserAgent: "Safari 13.0.5",
}
userID := "00000000-0000-0000-0000-000000000001"
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
@ -919,25 +813,29 @@ func TestRegisterSession(t *testing.T) {
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(nil, tests.ErrFakeDB)
m := NewManager(db, nil)
sessionID, err := m.RegisterSession(ctx, s)
sIN := &hub.Session{UserID: userID}
sOUT, err := m.RegisterSession(ctx, sIN)
assert.Equal(t, tests.ErrFakeDB, err)
assert.Nil(t, sessionID)
assert.Nil(t, sOUT)
db.AssertExpectations(t)
})
t.Run("successful session registration", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return([]interface{}{
[]byte("sessionID"),
true,
}, nil)
db.On("QueryRow", ctx, registerSessionDBQ, mock.Anything).Return(true, nil)
m := NewManager(db, nil)
session, err := m.RegisterSession(ctx, s)
sIN := &hub.Session{
UserID: userID,
IP: "192.168.1.100",
UserAgent: "Safari 13.0.5",
}
sOUT, err := m.RegisterSession(ctx, sIN)
assert.NoError(t, err)
assert.Equal(t, []byte("sessionID"), session.SessionID)
assert.True(t, session.Approved)
assert.NotEmpty(t, sOUT.SessionID)
assert.Equal(t, userID, sOUT.UserID)
assert.True(t, sOUT.Approved)
db.AssertExpectations(t)
})
}
@ -995,7 +893,7 @@ func TestRegisterPasswordResetCode(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return([]byte("code"), nil)
db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(nil)
es := &email.SenderMock{}
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
m := NewManager(db, es)
@ -1011,7 +909,7 @@ func TestRegisterPasswordResetCode(t *testing.T) {
t.Run("database error registering password reset code", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, registerPasswordResetCodeDBQ, "email@email.com").Return(nil, tests.ErrFakeDB)
db.On("Exec", ctx, registerPasswordResetCodeDBQ, "email@email.com", mock.Anything).Return(tests.ErrFakeDB)
m := NewManager(db, nil)
err := m.RegisterPasswordResetCode(ctx, "email@email.com", "http://baseurl.com")
@ -1131,15 +1029,15 @@ func TestRegisterUser(t *testing.T) {
func TestResetPassword(t *testing.T) {
ctx := context.Background()
code := []byte("code")
codeB64 := base64.URLEncoding.EncodeToString(code)
code := "code"
codeHashed := hash(code)
newPassword := "a66bV.Xp2" // #nosec
baseURL := "http://baseurl.com"
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
codeB64 string
code string
newPassword string
baseURL string
}{
@ -1151,19 +1049,19 @@ func TestResetPassword(t *testing.T) {
},
{
"new password not provided",
"code",
code,
"",
baseURL,
},
{
"invalid base url",
"code",
code,
newPassword,
"invalid",
},
{
"insecure password",
"code",
code,
"password",
baseURL,
},
@ -1174,7 +1072,7 @@ func TestResetPassword(t *testing.T) {
t.Parallel()
es := &email.SenderMock{}
m := NewManager(nil, es)
err := m.ResetPassword(ctx, tc.codeB64, tc.newPassword, tc.baseURL)
err := m.ResetPassword(ctx, tc.code, tc.newPassword, tc.baseURL)
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
assert.Contains(t, err.Error(), tc.errMsg)
})
@ -1200,10 +1098,10 @@ func TestResetPassword(t *testing.T) {
t.Run(tc.dbErr.Error(), func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("", tc.dbErr)
db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("", tc.dbErr)
m := NewManager(db, nil)
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
err := m.ResetPassword(ctx, code, newPassword, baseURL)
assert.Equal(t, tc.expectedErr, err)
db.AssertExpectations(t)
})
@ -1229,12 +1127,12 @@ func TestResetPassword(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("QueryRow", ctx, resetUserPasswordDBQ, code, mock.Anything).Return("email", nil)
db.On("QueryRow", ctx, resetUserPasswordDBQ, codeHashed, mock.Anything).Return("email", nil)
es := &email.SenderMock{}
es.On("SendEmail", mock.Anything).Return(tc.emailSenderResponse)
m := NewManager(db, es)
err := m.ResetPassword(ctx, codeB64, newPassword, baseURL)
err := m.ResetPassword(ctx, code, newPassword, baseURL)
assert.Equal(t, tc.emailSenderResponse, err)
db.AssertExpectations(t)
es.AssertExpectations(t)
@ -1494,13 +1392,13 @@ func TestVerifyEmail(t *testing.T) {
func TestVerifyPasswordResetCode(t *testing.T) {
ctx := context.Background()
code := []byte("code")
codeB64 := base64.URLEncoding.EncodeToString(code)
code := "code"
codeHashed := hash(code)
t.Run("invalid input", func(t *testing.T) {
testCases := []struct {
errMsg string
codeB64 string
errMsg string
code string
}{
{
"code not provided",
@ -1512,7 +1410,7 @@ func TestVerifyPasswordResetCode(t *testing.T) {
t.Run(tc.errMsg, func(t *testing.T) {
t.Parallel()
m := NewManager(nil, nil)
err := m.VerifyPasswordResetCode(ctx, tc.codeB64)
err := m.VerifyPasswordResetCode(ctx, tc.code)
assert.True(t, errors.Is(err, hub.ErrInvalidInput))
assert.Contains(t, err.Error(), tc.errMsg)
})
@ -1538,10 +1436,10 @@ func TestVerifyPasswordResetCode(t *testing.T) {
t.Run(tc.dbErr.Error(), func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(tc.dbErr)
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(tc.dbErr)
m := NewManager(db, nil)
err := m.VerifyPasswordResetCode(ctx, codeB64)
err := m.VerifyPasswordResetCode(ctx, code)
assert.Equal(t, tc.expectedErr, err)
db.AssertExpectations(t)
})
@ -1551,10 +1449,10 @@ func TestVerifyPasswordResetCode(t *testing.T) {
t.Run("password code verified successfully in database", func(t *testing.T) {
t.Parallel()
db := &tests.DBMock{}
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, code).Return(nil)
db.On("Exec", ctx, verifyPasswordResetCodeDBQ, codeHashed).Return(nil)
m := NewManager(db, nil)
err := m.VerifyPasswordResetCode(ctx, codeB64)
err := m.VerifyPasswordResetCode(ctx, code)
assert.Equal(t, nil, err)
db.AssertExpectations(t)
})

View File

@ -14,18 +14,11 @@ type ManagerMock struct {
}
// ApproveSession implements the UserManager interface.
func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID []byte, passcode string) error {
func (m *ManagerMock) ApproveSession(ctx context.Context, sessionID, passcode string) error {
args := m.Called(ctx, sessionID, passcode)
return args.Error(0)
}
// CheckAPIKey implements the UserManager interface.
func (m *ManagerMock) CheckAPIKey(ctx context.Context, apiKeyID, apiKeySecret string) (*hub.CheckAPIKeyOutput, error) {
args := m.Called(ctx, apiKeyID, apiKeySecret)
data, _ := args.Get(0).(*hub.CheckAPIKeyOutput)
return data, args.Error(1)
}
// CheckAvailability implements the UserManager interface.
func (m *ManagerMock) CheckAvailability(ctx context.Context, resourceKind, value string) (bool, error) {
args := m.Called(ctx, resourceKind, value)
@ -46,7 +39,7 @@ func (m *ManagerMock) CheckCredentials(
// CheckSession implements the UserManager interface.
func (m *ManagerMock) CheckSession(
ctx context.Context,
sessionID []byte,
sessionID string,
duration time.Duration,
) (*hub.CheckSessionOutput, error) {
args := m.Called(ctx, sessionID, duration)
@ -55,7 +48,7 @@ func (m *ManagerMock) CheckSession(
}
// DeleteSession implements the UserManager interface.
func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID []byte) error {
func (m *ManagerMock) DeleteSession(ctx context.Context, sessionID string) error {
args := m.Called(ctx, sessionID)
return args.Error(0)
}