diff --git a/database/migrations/functions/users/register_session.sql b/database/migrations/functions/users/register_session.sql index 76e195cd..06bb7f26 100644 --- a/database/migrations/functions/users/register_session.sql +++ b/database/migrations/functions/users/register_session.sql @@ -1,13 +1,20 @@ -- register_session registers the provided session in the database. create or replace function register_session(p_session jsonb) returns bytea as $$ +declare + v_session_id bytea := gen_random_bytes(32); +begin insert into session ( + session_id, user_id, ip, user_agent ) values ( + sha512(v_session_id), (p_session->>'user_id')::uuid, nullif(p_session->>'ip', '')::inet, nullif(p_session->>'user_agent', '') - ) returning session_id; -$$ language sql; + ); + return v_session_id; +end +$$ language plpgsql; diff --git a/database/migrations/schema/010_hash_session_id.sql b/database/migrations/schema/010_hash_session_id.sql new file mode 100644 index 00000000..b55f7144 --- /dev/null +++ b/database/migrations/schema/010_hash_session_id.sql @@ -0,0 +1,5 @@ +alter table session alter column session_id drop default; + +---- create above / drop below ---- + +alter table session alter column session_id set default gen_random_bytes(32); diff --git a/database/tests/functions/users/register_session.sql b/database/tests/functions/users/register_session.sql index 7969234a..6e6a2456 100644 --- a/database/tests/functions/users/register_session.sql +++ b/database/tests/functions/users/register_session.sql @@ -36,7 +36,7 @@ select results_eq( ); select is( session_id, - :'session_id', + sha512(:'session_id'), 'Returned session_id returned should be registered' ) from session where user_id = '00000000-0000-0000-0000-000000000001'; diff --git a/database/tests/functions/users/reset_user_password.sql b/database/tests/functions/users/reset_user_password.sql index 9d0410fb..29cdc253 100644 --- a/database/tests/functions/users/reset_user_password.sql +++ b/database/tests/functions/users/reset_user_password.sql @@ -17,7 +17,7 @@ insert into password_reset_code (password_reset_code_id, user_id, created_at) values (sha512(:'code1ID'), :'user1ID', current_timestamp); insert into password_reset_code (password_reset_code_id, user_id, created_at) values (sha512(:'code2ID'), :'user2ID', current_timestamp - '30 minute'::interval); -insert into session (user_id) values (:'user1ID'); +insert into session (session_id, user_id) values (gen_random_bytes(32), :'user1ID'); -- Password reset should fail in the following cases select throws_ok( diff --git a/internal/user/manager.go b/internal/user/manager.go index 7cd90475..578db75d 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -195,7 +195,7 @@ func (m *Manager) CheckSession( // Get session details from database var userID string var createdAt int64 - err := m.db.QueryRow(ctx, getSessionDBQ, sessionID).Scan(&userID, &createdAt) + err := m.db.QueryRow(ctx, getSessionDBQ, hashSessionID(sessionID)).Scan(&userID, &createdAt) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return &hub.CheckSessionOutput{Valid: false}, nil @@ -222,7 +222,7 @@ func (m *Manager) DeleteSession(ctx context.Context, sessionID []byte) error { } // Delete session from database - _, err := m.db.Exec(ctx, deleteSessionDBQ, sessionID) + _, err := m.db.Exec(ctx, deleteSessionDBQ, hashSessionID(sessionID)) return err } @@ -538,3 +538,9 @@ func (m *Manager) VerifyPasswordResetCode(ctx context.Context, codeB64 string) e } return err } + +// hashSessionID is a helper function that creates a sha512 hash of the +// sessionID provided. +func hashSessionID(sessionID []byte) string { + return fmt.Sprintf("\\x%x", sha512.Sum512(sessionID)) +} diff --git a/internal/user/manager_test.go b/internal/user/manager_test.go index 36b399ce..e754a320 100644 --- a/internal/user/manager_test.go +++ b/internal/user/manager_test.go @@ -267,6 +267,7 @@ func TestCheckCredentials(t *testing.T) { func TestCheckSession(t *testing.T) { ctx := context.Background() + hashedSessionID := hashSessionID([]byte("sessionID")) t.Run("invalid input", func(t *testing.T) { testCases := []struct { @@ -305,7 +306,7 @@ func TestCheckSession(t *testing.T) { t.Run("session not found in database", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getSessionDBQ, []byte("sessionID")).Return(nil, pgx.ErrNoRows) + db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, pgx.ErrNoRows) m := NewManager(db, nil) output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) @@ -318,7 +319,7 @@ func TestCheckSession(t *testing.T) { t.Run("error getting session from database", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getSessionDBQ, []byte("sessionID")).Return(nil, tests.ErrFakeDB) + db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return(nil, tests.ErrFakeDB) m := NewManager(db, nil) output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) @@ -330,7 +331,7 @@ func TestCheckSession(t *testing.T) { t.Run("session has expired", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getSessionDBQ, []byte("sessionID")).Return([]interface{}{"userID", int64(1)}, nil) + db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return([]interface{}{"userID", int64(1)}, nil) m := NewManager(db, nil) output, err := m.CheckSession(ctx, []byte("sessionID"), 1*time.Hour) @@ -343,7 +344,7 @@ func TestCheckSession(t *testing.T) { t.Run("valid session", func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("QueryRow", ctx, getSessionDBQ, []byte("sessionID")).Return([]interface{}{ + db.On("QueryRow", ctx, getSessionDBQ, hashedSessionID).Return([]interface{}{ "userID", time.Now().Unix(), }, nil) @@ -359,6 +360,7 @@ func TestCheckSession(t *testing.T) { func TestDeleteSession(t *testing.T) { ctx := context.Background() + hashedSessionID := hashSessionID([]byte("sessionID")) t.Run("invalid input", func(t *testing.T) { testCases := []struct { @@ -408,7 +410,7 @@ func TestDeleteSession(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() db := &tests.DBMock{} - db.On("Exec", ctx, deleteSessionDBQ, []byte("sessionID")).Return(tc.dbResponse) + db.On("Exec", ctx, deleteSessionDBQ, hashedSessionID).Return(tc.dbResponse) m := NewManager(db, nil) err := m.DeleteSession(ctx, []byte("sessionID"))