diff --git a/.golangci.yml b/.golangci.yml
index b81ccb2d..34c9fda0 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -26,6 +26,8 @@ linters:
fast: false
linters-settings:
+ goconst:
+ min-occurrences: 5
golint:
min-confidence: 0
diff --git a/chart/templates/hub_secret.yaml b/chart/templates/hub_secret.yaml
index 58eb0e61..d256b869 100644
--- a/chart/templates/hub_secret.yaml
+++ b/chart/templates/hub_secret.yaml
@@ -15,6 +15,7 @@ stringData:
user: {{ .Values.db.user }}
password: {{ .Values.db.password }}
server:
+ baseURL: {{ .Values.hub.server.baseURL }}
addr: 0.0.0.0:8000
metricsAddr: 0.0.0.0:8001
shutdownTimeout: 30s
diff --git a/chart/values-production.yaml b/chart/values-production.yaml
index 80896543..724ae252 100644
--- a/chart/values-production.yaml
+++ b/chart/values-production.yaml
@@ -38,6 +38,7 @@ hub:
cpu: 2
memory: 8000Mi
server:
+ baseURL: https://artifacthub.io
oauth:
github:
redirectURL: https://artifacthub.io/oauth/github/callback
diff --git a/chart/values-staging.yaml b/chart/values-staging.yaml
index 8b7e847e..2169a001 100644
--- a/chart/values-staging.yaml
+++ b/chart/values-staging.yaml
@@ -39,6 +39,7 @@ hub:
cpu: 1
memory: 1000Mi
server:
+ baseURL: https://staging.artifacthub.io
oauth:
github:
redirectURL: https://staging.artifacthub.io/oauth/github/callback
diff --git a/chart/values.yaml b/chart/values.yaml
index 720eb701..77f6f104 100644
--- a/chart/values.yaml
+++ b/chart/values.yaml
@@ -32,6 +32,7 @@ hub:
cpu: 100m
memory: 500Mi
server:
+ baseURL: ""
basicAuth:
enabled: false
username: hub
diff --git a/cmd/hub/handlers/handlers.go b/cmd/hub/handlers/handlers.go
index 4ca2d12b..e6b21f4a 100644
--- a/cmd/hub/handlers/handlers.go
+++ b/cmd/hub/handlers/handlers.go
@@ -12,6 +12,7 @@ import (
"github.com/artifacthub/hub/cmd/hub/handlers/org"
"github.com/artifacthub/hub/cmd/hub/handlers/pkg"
"github.com/artifacthub/hub/cmd/hub/handlers/static"
+ "github.com/artifacthub/hub/cmd/hub/handlers/subscription"
"github.com/artifacthub/hub/cmd/hub/handlers/user"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/img"
@@ -29,6 +30,7 @@ type Services struct {
UserManager hub.UserManager
PackageManager hub.PackageManager
ChartRepositoryManager hub.ChartRepositoryManager
+ SubscriptionManager hub.SubscriptionManager
ImageStore img.Store
}
@@ -51,6 +53,7 @@ type Handlers struct {
Users *user.Handlers
Packages *pkg.Handlers
ChartRepositories *chartrepo.Handlers
+ Subscriptions *subscription.Handlers
Static *static.Handlers
}
@@ -65,6 +68,7 @@ func Setup(cfg *viper.Viper, svc *Services) *Handlers {
Organizations: org.NewHandlers(svc.OrganizationManager),
Users: user.NewHandlers(svc.UserManager, cfg),
Packages: pkg.NewHandlers(svc.PackageManager),
+ Subscriptions: subscription.NewHandlers(svc.SubscriptionManager),
ChartRepositories: chartrepo.NewHandlers(svc.ChartRepositoryManager),
Static: static.NewHandlers(cfg, svc.ImageStore),
}
@@ -126,11 +130,18 @@ func (h *Handlers) setupRouter() {
r.With(h.Users.RequireLogin).Put("/", h.Packages.ToggleStar)
})
})
+ r.Route("/subscriptions", func(r chi.Router) {
+ r.Use(h.Users.RequireLogin)
+ r.Get("/{packageID}", h.Subscriptions.GetByPackage)
+ r.Post("/", h.Subscriptions.Add)
+ r.Delete("/", h.Subscriptions.Delete)
+ })
r.Post("/users", h.Users.RegisterUser)
r.Route("/user", func(r chi.Router) {
r.Use(h.Users.RequireLogin)
r.Get("/", h.Users.GetProfile)
r.Get("/orgs", h.Organizations.GetByUser)
+ r.Get("/subscriptions", h.Subscriptions.GetByUser)
r.Put("/password", h.Users.UpdatePassword)
r.Put("/profile", h.Users.UpdateProfile)
r.Route("/chart-repositories", func(r chi.Router) {
diff --git a/cmd/hub/handlers/org/handlers_test.go b/cmd/hub/handlers/org/handlers_test.go
index ea530fae..5db4043d 100644
--- a/cmd/hub/handlers/org/handlers_test.go
+++ b/cmd/hub/handlers/org/handlers_test.go
@@ -29,7 +29,7 @@ func TestAdd(t *testing.T) {
t.Run("invalid organization provided", func(t *testing.T) {
testCases := []struct {
description string
- repoJSON string
+ orgJSON string
omErr error
}{
{
@@ -62,7 +62,7 @@ func TestAdd(t *testing.T) {
}
w := httptest.NewRecorder()
- r, _ := http.NewRequest("POST", "/", strings.NewReader(tc.repoJSON))
+ r, _ := http.NewRequest("POST", "/", strings.NewReader(tc.orgJSON))
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw.h.Add(w, r)
resp := w.Result()
@@ -486,7 +486,7 @@ func TestUpdate(t *testing.T) {
t.Run("invalid organization provided", func(t *testing.T) {
testCases := []struct {
description string
- repoJSON string
+ orgJSON string
omErr error
}{
{
@@ -514,7 +514,7 @@ func TestUpdate(t *testing.T) {
}
w := httptest.NewRecorder()
- r, _ := http.NewRequest("PUT", "/", strings.NewReader(tc.repoJSON))
+ r, _ := http.NewRequest("PUT", "/", strings.NewReader(tc.orgJSON))
r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
hw.h.Update(w, r)
resp := w.Result()
diff --git a/cmd/hub/handlers/subscription/handlers.go b/cmd/hub/handlers/subscription/handlers.go
new file mode 100644
index 00000000..e4200ad0
--- /dev/null
+++ b/cmd/hub/handlers/subscription/handlers.go
@@ -0,0 +1,97 @@
+package subscription
+
+import (
+ "encoding/json"
+ "errors"
+ "net/http"
+
+ "github.com/artifacthub/hub/cmd/hub/handlers/helpers"
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/subscription"
+ "github.com/go-chi/chi"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+)
+
+// Handlers represents a group of http handlers in charge of handling
+// subscriptions operations.
+type Handlers struct {
+ subscriptionManager hub.SubscriptionManager
+ logger zerolog.Logger
+}
+
+// NewHandlers creates a new Handlers instance.
+func NewHandlers(subscriptionManager hub.SubscriptionManager) *Handlers {
+ return &Handlers{
+ subscriptionManager: subscriptionManager,
+ logger: log.With().Str("handlers", "subscription").Logger(),
+ }
+}
+
+// Add is an http handler that adds the provided subscription to the database.
+func (h *Handlers) Add(w http.ResponseWriter, r *http.Request) {
+ s := &hub.Subscription{}
+ if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
+ h.logger.Error().Err(err).Str("method", "Add").Msg("invalid subscription")
+ http.Error(w, "subscription provided is not valid", http.StatusBadRequest)
+ return
+ }
+ if err := h.subscriptionManager.Add(r.Context(), s); err != nil {
+ h.logger.Error().Err(err).Str("method", "Add").Send()
+ if errors.Is(err, subscription.ErrInvalidInput) {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ } else {
+ http.Error(w, "", http.StatusInternalServerError)
+ }
+ return
+ }
+}
+
+// Delete is an http handler that removes the provided subscription from the
+// database.
+func (h *Handlers) Delete(w http.ResponseWriter, r *http.Request) {
+ s := &hub.Subscription{}
+ if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
+ h.logger.Error().Err(err).Str("method", "Delete").Msg("invalid subscription")
+ http.Error(w, "subscription provided is not valid", http.StatusBadRequest)
+ return
+ }
+ if err := h.subscriptionManager.Delete(r.Context(), s); err != nil {
+ h.logger.Error().Err(err).Str("method", "Delete").Send()
+ if errors.Is(err, subscription.ErrInvalidInput) {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ } else {
+ http.Error(w, "", http.StatusInternalServerError)
+ }
+ return
+ }
+}
+
+// GetByPackage is an http handler that returns the subscriptions a user has
+// for a given package.
+func (h *Handlers) GetByPackage(w http.ResponseWriter, r *http.Request) {
+ packageID := chi.URLParam(r, "packageID")
+ dataJSON, err := h.subscriptionManager.GetByPackageJSON(r.Context(), packageID)
+ if err != nil {
+ h.logger.Error().Err(err).Str("method", "GetByPackage").Send()
+ if errors.Is(err, subscription.ErrInvalidInput) {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ } else {
+ http.Error(w, "", http.StatusInternalServerError)
+ }
+ return
+ }
+ helpers.RenderJSON(w, dataJSON, 0)
+}
+
+// GetByUser is an http handler that returns the subscriptions of the user
+// doing the request.
+func (h *Handlers) GetByUser(w http.ResponseWriter, r *http.Request) {
+ dataJSON, err := h.subscriptionManager.GetByUserJSON(r.Context())
+ if err != nil {
+ h.logger.Error().Err(err).Str("method", "GetByUser").Send()
+ http.Error(w, "", http.StatusInternalServerError)
+ return
+ }
+ helpers.RenderJSON(w, dataJSON, 0)
+}
diff --git a/cmd/hub/handlers/subscription/handlers_test.go b/cmd/hub/handlers/subscription/handlers_test.go
new file mode 100644
index 00000000..ed7039a1
--- /dev/null
+++ b/cmd/hub/handlers/subscription/handlers_test.go
@@ -0,0 +1,304 @@
+package subscription
+
+import (
+ "context"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+
+ "github.com/artifacthub/hub/cmd/hub/handlers/helpers"
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/subscription"
+ "github.com/artifacthub/hub/internal/tests"
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+func TestMain(m *testing.M) {
+ zerolog.SetGlobalLevel(zerolog.Disabled)
+ os.Exit(m.Run())
+}
+
+func TestAdd(t *testing.T) {
+ t.Run("invalid subscription provided", func(t *testing.T) {
+ testCases := []struct {
+ description string
+ subscriptionJSON string
+ smErr error
+ }{
+ {
+ "no subscription provided",
+ "",
+ nil,
+ },
+ {
+ "invalid json",
+ "-",
+ nil,
+ },
+ {
+ "invalid package id",
+ `{"package_id": "invalid"}`,
+ subscription.ErrInvalidInput,
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.description, func(t *testing.T) {
+ hw := newHandlersWrapper()
+ if tc.smErr != nil {
+ hw.sm.On("Add", mock.Anything, mock.Anything).Return(tc.smErr)
+ }
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("POST", "/", strings.NewReader(tc.subscriptionJSON))
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.Add(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+ }
+ })
+
+ t.Run("valid subscription provided", func(t *testing.T) {
+ subscriptionJSON := `
+ {
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+ }
+ `
+ testCases := []struct {
+ description string
+ err error
+ expectedStatusCode int
+ }{
+ {
+ "add subscription succeeded",
+ nil,
+ http.StatusOK,
+ },
+ {
+ "error adding subscription",
+ tests.ErrFakeDatabaseFailure,
+ http.StatusInternalServerError,
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.description, func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("Add", mock.Anything, mock.Anything).Return(tc.err)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("POST", "/", strings.NewReader(subscriptionJSON))
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.Add(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+ }
+ })
+}
+
+func TestDelete(t *testing.T) {
+ t.Run("invalid subscription provided", func(t *testing.T) {
+ testCases := []struct {
+ description string
+ subscriptionJSON string
+ smErr error
+ }{
+ {
+ "no subscription provided",
+ "",
+ nil,
+ },
+ {
+ "invalid json",
+ "-",
+ nil,
+ },
+ {
+ "invalid package id",
+ `{"package_id": "invalid"}`,
+ subscription.ErrInvalidInput,
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.description, func(t *testing.T) {
+ hw := newHandlersWrapper()
+ if tc.smErr != nil {
+ hw.sm.On("Delete", mock.Anything, mock.Anything).Return(tc.smErr)
+ }
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("DELETE", "/", strings.NewReader(tc.subscriptionJSON))
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.Delete(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+ }
+ })
+
+ t.Run("valid subscription provided", func(t *testing.T) {
+ subscriptionJSON := `
+ {
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+ }
+ `
+ testCases := []struct {
+ description string
+ err error
+ expectedStatusCode int
+ }{
+ {
+ "delete subscription succeeded",
+ nil,
+ http.StatusOK,
+ },
+ {
+ "error deleting subscription",
+ tests.ErrFakeDatabaseFailure,
+ http.StatusInternalServerError,
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.description, func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("Delete", mock.Anything, mock.Anything).Return(tc.err)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("DELETE", "/", strings.NewReader(subscriptionJSON))
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.Delete(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+ }
+ })
+}
+
+func TestGetByPackage(t *testing.T) {
+ t.Run("error getting package subscriptions", func(t *testing.T) {
+ testCases := []struct {
+ smErr error
+ expectedStatusCode int
+ }{
+ {
+ subscription.ErrInvalidInput,
+ http.StatusBadRequest,
+ },
+ {
+ tests.ErrFakeDatabaseFailure,
+ http.StatusInternalServerError,
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.smErr.Error(), func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("GetByPackageJSON", mock.Anything, mock.Anything).Return(nil, tc.smErr)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("GET", "/", nil)
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.GetByPackage(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+ }
+ })
+
+ t.Run("get package subscriptions succeeded", func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("GetByPackageJSON", mock.Anything, mock.Anything).Return([]byte("dataJSON"), nil)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("GET", "/", nil)
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.GetByPackage(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+ h := resp.Header
+ data, _ := ioutil.ReadAll(resp.Body)
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "application/json", h.Get("Content-Type"))
+ assert.Equal(t, helpers.BuildCacheControlHeader(0), h.Get("Cache-Control"))
+ assert.Equal(t, []byte("dataJSON"), data)
+ hw.sm.AssertExpectations(t)
+ })
+}
+
+func TestGetByUser(t *testing.T) {
+ t.Run("error getting user subscriptions", func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("GetByUserJSON", mock.Anything).Return(nil, tests.ErrFakeDatabaseFailure)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("GET", "/", nil)
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.GetByUser(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+ hw.sm.AssertExpectations(t)
+ })
+
+ t.Run("get user subscriptions succeeded", func(t *testing.T) {
+ hw := newHandlersWrapper()
+ hw.sm.On("GetByUserJSON", mock.Anything).Return([]byte("dataJSON"), nil)
+
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("GET", "/", nil)
+ r = r.WithContext(context.WithValue(r.Context(), hub.UserIDKey, "userID"))
+ hw.h.GetByUser(w, r)
+ resp := w.Result()
+ defer resp.Body.Close()
+ h := resp.Header
+ data, _ := ioutil.ReadAll(resp.Body)
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "application/json", h.Get("Content-Type"))
+ assert.Equal(t, helpers.BuildCacheControlHeader(0), h.Get("Cache-Control"))
+ assert.Equal(t, []byte("dataJSON"), data)
+ hw.sm.AssertExpectations(t)
+ })
+}
+
+type handlersWrapper struct {
+ sm *subscription.ManagerMock
+ h *Handlers
+}
+
+func newHandlersWrapper() *handlersWrapper {
+ sm := &subscription.ManagerMock{}
+
+ return &handlersWrapper{
+ sm: sm,
+ h: NewHandlers(sm),
+ }
+}
diff --git a/cmd/hub/main.go b/cmd/hub/main.go
index 21ee7373..5589b3a3 100644
--- a/cmd/hub/main.go
+++ b/cmd/hub/main.go
@@ -5,6 +5,7 @@ import (
"net/http"
"os"
"os/signal"
+ "sync"
"syscall"
"time"
@@ -13,8 +14,10 @@ import (
"github.com/artifacthub/hub/internal/email"
"github.com/artifacthub/hub/internal/hub"
"github.com/artifacthub/hub/internal/img/pg"
+ "github.com/artifacthub/hub/internal/notification"
"github.com/artifacthub/hub/internal/org"
"github.com/artifacthub/hub/internal/pkg"
+ "github.com/artifacthub/hub/internal/subscription"
"github.com/artifacthub/hub/internal/user"
"github.com/artifacthub/hub/internal/util"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -32,7 +35,7 @@ func main() {
log.Fatal().Err(err).Msg("Logger setup failed")
}
- // Setup services required by the handlers to operate
+ // Setup database and email services
db, err := util.SetupDB(cfg)
if err != nil {
log.Fatal().Err(err).Msg("Database setup failed")
@@ -41,22 +44,23 @@ func main() {
if s := email.NewSender(cfg); s != nil {
es = s
}
- svc := &handlers.Services{
+
+ // Setup and launch server
+ hSvc := &handlers.Services{
OrganizationManager: org.NewManager(db, es),
UserManager: user.NewManager(db, es),
PackageManager: pkg.NewManager(db),
+ SubscriptionManager: subscription.NewManager(db),
ChartRepositoryManager: chartrepo.NewManager(db),
ImageStore: pg.NewImageStore(db),
}
-
- // Setup and launch server
addr := cfg.GetString("server.addr")
srv := &http.Server{
Addr: addr,
ReadTimeout: 5 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 1 * time.Minute,
- Handler: handlers.Setup(cfg, svc).Router,
+ Handler: handlers.Setup(cfg, hSvc).Router,
}
go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
@@ -74,11 +78,27 @@ func main() {
}
}()
+ // Setup and launch notifications dispatcher
+ nSvc := ¬ification.Services{
+ DB: db,
+ ES: es,
+ NotificationManager: notification.NewManager(),
+ SubscriptionManager: subscription.NewManager(db),
+ PackageManager: pkg.NewManager(db),
+ }
+ notificationsDispatcher := notification.NewDispatcher(cfg, nSvc)
+ ctx, stopNotificationsDispatcher := context.WithCancel(context.Background())
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go notificationsDispatcher.Run(ctx, &wg)
+
// Shutdown server gracefully when SIGINT or SIGTERM signal is received
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
<-shutdown
log.Info().Msg("Hub server shutting down..")
+ stopNotificationsDispatcher()
+ wg.Wait()
ctx, cancel := context.WithTimeout(context.Background(), cfg.GetDuration("server.shutdownTimeout"))
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
diff --git a/database/migrations/functions/001_load_functions.sql b/database/migrations/functions/001_load_functions.sql
index 2df47282..9e37d66b 100644
--- a/database/migrations/functions/001_load_functions.sql
+++ b/database/migrations/functions/001_load_functions.sql
@@ -39,6 +39,14 @@
{{ template "images/get_image.sql" }}
{{ template "images/register_image.sql" }}
+{{ template "subscriptions/add_subscription.sql" }}
+{{ template "subscriptions/delete_subscription.sql" }}
+{{ template "subscriptions/get_package_subscriptions.sql" }}
+{{ template "subscriptions/get_subscriptors.sql" }}
+{{ template "subscriptions/get_user_subscriptions.sql" }}
+
+{{ template "notifications/get_pending_notification.sql" }}
+
---- create above / drop below ----
-- Nothing to do
diff --git a/database/migrations/functions/notifications/get_pending_notification.sql b/database/migrations/functions/notifications/get_pending_notification.sql
new file mode 100644
index 00000000..4f69deb8
--- /dev/null
+++ b/database/migrations/functions/notifications/get_pending_notification.sql
@@ -0,0 +1,35 @@
+-- get_pending_notification returns a pending notification if available,
+-- updating its processed state if the notification is delivered successfully.
+-- This function should be called from a transaction that should be rolled back
+-- if the notification is not delivered successfully.
+create or replace function get_pending_notification()
+returns setof json as $$
+declare
+ v_notification_id uuid;
+ v_notification json;
+begin
+ -- Get pending notification if available
+ select notification_id, json_build_object(
+ 'notification_id', n.notification_id,
+ 'package_version', n.package_version,
+ 'package_id', n.package_id,
+ 'notification_kind', n.notification_kind_id
+ ) into v_notification_id, v_notification
+ from notification n
+ where n.processed = false
+ for update of n skip locked
+ limit 1;
+ if not found then
+ return;
+ end if;
+
+ -- Update notification processed state
+ -- (this will be committed once the notification is delivered successfully)
+ update notification set
+ processed = true,
+ processed_at = current_timestamp
+ where notification_id = v_notification_id;
+
+ return query select v_notification;
+end
+$$ language plpgsql;
diff --git a/database/migrations/functions/packages/get_package.sql b/database/migrations/functions/packages/get_package.sql
index 8f01f621..c954e704 100644
--- a/database/migrations/functions/packages/get_package.sql
+++ b/database/migrations/functions/packages/get_package.sql
@@ -7,7 +7,9 @@ declare
v_package_name text := p_input->>'package_name';
v_chart_repository_name text := p_input->>'chart_repository_name';
begin
- if v_chart_repository_name <> '' then
+ if p_input->>'package_id' <> '' then
+ v_package_id = p_input->>'package_id';
+ elsif v_chart_repository_name <> '' then
select p.package_id into v_package_id
from package p
join chart_repository r using (chart_repository_id)
diff --git a/database/migrations/functions/packages/get_packages_updates.sql b/database/migrations/functions/packages/get_packages_updates.sql
index a0ce8674..353c8f3a 100644
--- a/database/migrations/functions/packages/get_packages_updates.sql
+++ b/database/migrations/functions/packages/get_packages_updates.sql
@@ -102,7 +102,7 @@ returns setof json as $$
on p.organization_id = o.organization_id or r.organization_id = o.organization_id
where s.version = p.latest_version
and (s.deprecated is null or s.deprecated = false)
- order by updated_at desc limit 5
+ order by p.updated_at desc limit 5
) as pru
)
);
diff --git a/database/migrations/functions/packages/register_package.sql b/database/migrations/functions/packages/register_package.sql
index fb42249e..95825294 100644
--- a/database/migrations/functions/packages/register_package.sql
+++ b/database/migrations/functions/packages/register_package.sql
@@ -6,6 +6,7 @@
create or replace function register_package(p_pkg jsonb)
returns void as $$
declare
+ v_previous_latest_version text;
v_package_id uuid;
v_name text := p_pkg->>'name';
v_display_name text := nullif(p_pkg->>'display_name', '');
@@ -14,10 +15,16 @@ declare
select (array(select jsonb_array_elements_text(nullif(p_pkg->'keywords', 'null'::jsonb))))::text[]
);
v_chart_repository_id text := (p_pkg->'chart_repository')->>'chart_repository_id';
- v_package_latest_version_needs_update boolean := false;
v_maintainer jsonb;
v_maintainer_id uuid;
begin
+ -- Get package's latest version before registration, if available
+ select latest_version into v_previous_latest_version
+ from package
+ where package_kind_id = (p_pkg->>'kind')::int
+ and chart_repository_id = nullif(v_chart_repository_id, '')::uuid
+ and name = v_name;
+
-- Package
insert into package (
name,
@@ -131,6 +138,19 @@ begin
digest = excluded.digest,
readme = excluded.readme,
links = excluded.links,
- deprecated = excluded.deprecated;
+ deprecated = excluded.deprecated,
+ updated_at = current_timestamp;
+
+ -- Register new release notification if package's latest version has been
+ -- updated and there are subscriptors for this package and notification kind
+ if semver_gte(p_pkg->>'version', v_previous_latest_version) then
+ perform * from subscription
+ where notification_kind_id = 0 -- New package release
+ and package_id = v_package_id;
+ if found then
+ insert into notification (package_id, package_version, notification_kind_id)
+ values (v_package_id, p_pkg->>'version', 0);
+ end if;
+ end if;
end
$$ language plpgsql;
diff --git a/database/migrations/functions/subscriptions/add_subscription.sql b/database/migrations/functions/subscriptions/add_subscription.sql
new file mode 100644
index 00000000..3b470761
--- /dev/null
+++ b/database/migrations/functions/subscriptions/add_subscription.sql
@@ -0,0 +1,13 @@
+-- add_subscription adds the provided subscription to the database.
+create or replace function add_subscription(p_subscription jsonb)
+returns void as $$
+ insert into subscription (
+ user_id,
+ package_id,
+ notification_kind_id
+ ) values (
+ (p_subscription->>'user_id')::uuid,
+ (p_subscription->>'package_id')::uuid,
+ (p_subscription->>'notification_kind')::int
+ );
+$$ language sql;
diff --git a/database/migrations/functions/subscriptions/delete_subscription.sql b/database/migrations/functions/subscriptions/delete_subscription.sql
new file mode 100644
index 00000000..84057b3d
--- /dev/null
+++ b/database/migrations/functions/subscriptions/delete_subscription.sql
@@ -0,0 +1,8 @@
+-- delete_subscription deletes the provided subscription from the database.
+create or replace function delete_subscription(p_subscription jsonb)
+returns void as $$
+ delete from subscription
+ where user_id = (p_subscription->>'user_id')::uuid
+ and package_id = (p_subscription->>'package_id')::uuid
+ and notification_kind_id = (p_subscription->>'notification_kind')::int;
+$$ language sql;
diff --git a/database/migrations/functions/subscriptions/get_package_subscriptions.sql b/database/migrations/functions/subscriptions/get_package_subscriptions.sql
new file mode 100644
index 00000000..67407611
--- /dev/null
+++ b/database/migrations/functions/subscriptions/get_package_subscriptions.sql
@@ -0,0 +1,15 @@
+-- get_package_subscriptions returns the subscriptions the provided user has
+-- for a given package as a json array.
+create or replace function get_package_subscriptions(p_user_id uuid, p_package_id uuid)
+returns setof json as $$
+ select coalesce(json_agg(json_build_object(
+ 'notification_kind', notification_kind_id
+ )), '[]')
+ from (
+ select *
+ from subscription
+ where user_id = p_user_id
+ and package_id = p_package_id
+ order by notification_kind_id asc
+ ) s;
+$$ language sql;
diff --git a/database/migrations/functions/subscriptions/get_subscriptors.sql b/database/migrations/functions/subscriptions/get_subscriptors.sql
new file mode 100644
index 00000000..1d2f40de
--- /dev/null
+++ b/database/migrations/functions/subscriptions/get_subscriptors.sql
@@ -0,0 +1,12 @@
+-- get_subscriptors returns the users subscribed to the package provided for
+-- the given notification kind.
+create or replace function get_subscriptors(p_package_id uuid, p_notification_kind int)
+returns setof json as $$
+ select coalesce(json_agg(json_build_object(
+ 'email', u.email
+ )), '[]')
+ from subscription s
+ join "user" u using (user_id)
+ where s.package_id = p_package_id
+ and s.notification_kind_id = p_notification_kind;
+$$ language sql;
diff --git a/database/migrations/functions/subscriptions/get_user_subscriptions.sql b/database/migrations/functions/subscriptions/get_user_subscriptions.sql
new file mode 100644
index 00000000..d2784580
--- /dev/null
+++ b/database/migrations/functions/subscriptions/get_user_subscriptions.sql
@@ -0,0 +1,50 @@
+-- get_user_subscriptions returns all the subscriptions for the provided user
+-- as a json array.
+create or replace function get_user_subscriptions(p_user_id uuid)
+returns setof json as $$
+ select coalesce(json_agg(json_build_object(
+ 'package_id', package_id,
+ 'kind', package_kind_id,
+ 'name', name,
+ 'normalized_name', normalized_name,
+ 'logo_image_id', logo_image_id,
+ 'user_alias', user_alias,
+ 'organization_name', organization_name,
+ 'organization_display_name', organization_display_name,
+ 'chart_repository', (select nullif(
+ jsonb_build_object(
+ 'name', chart_repository_name,
+ 'display_name', chart_repository_display_name
+ ),
+ '{"name": null, "display_name": null}'::jsonb
+ )),
+ 'notification_kinds', (
+ select json_agg(distinct(notification_kind_id))
+ from subscription
+ where package_id = sp.package_id
+ and user_id = p_user_id
+ )
+ )), '[]')
+ from (
+ select
+ p.package_id,
+ p.package_kind_id,
+ p.name,
+ p.normalized_name,
+ p.logo_image_id,
+ u.alias as user_alias,
+ o.name as organization_name,
+ o.display_name as organization_display_name,
+ r.name as chart_repository_name,
+ r.display_name as chart_repository_display_name
+ from package p
+ left join chart_repository r using (chart_repository_id)
+ left join "user" u on p.user_id = u.user_id or r.user_id = u.user_id
+ left join organization o
+ on p.organization_id = o.organization_id or r.organization_id = o.organization_id
+ where p.package_id in (
+ select distinct(package_id) from subscription where user_id = p_user_id
+ )
+ order by p.normalized_name asc
+ ) sp;
+$$ language sql;
diff --git a/database/migrations/schema/001_initial_schema.sql b/database/migrations/schema/001_initial_schema.sql
index 1244120c..79c5e82a 100644
--- a/database/migrations/schema/001_initial_schema.sql
+++ b/database/migrations/schema/001_initial_schema.sql
@@ -111,6 +111,8 @@ create table if not exists snapshot (
links jsonb,
data jsonb,
deprecated boolean,
+ created_at timestamptz default current_timestamp not null,
+ updated_at timestamptz default current_timestamp not null,
primary key (package_id, version)
);
@@ -144,6 +146,34 @@ create table if not exists user_starred_package (
primary key (user_id, package_id)
);
+create table if not exists notification_kind (
+ notification_kind_id integer primary key,
+ name text not null check (name <> '')
+);
+
+insert into notification_kind values (0, 'New package release');
+insert into notification_kind values (1, 'Security alert');
+
+create table notification (
+ notification_id uuid primary key default gen_random_uuid(),
+ created_at timestamptz default current_timestamp not null,
+ processed boolean not null default false,
+ processed_at timestamptz,
+ package_version text not null check (package_version <> ''),
+ package_id uuid not null references package on delete cascade,
+ notification_kind_id integer not null references notification_kind on delete restrict,
+ unique (package_id, package_version)
+);
+
+create index notification_not_processed_idx on notification (notification_id) where processed = 'false';
+
+create table if not exists subscription (
+ user_id uuid not null references "user" on delete cascade,
+ package_id uuid not null references package on delete cascade,
+ notification_kind_id integer not null references notification_kind on delete restrict,
+ primary key (user_id, package_id, notification_kind_id)
+);
+
{{ if eq .loadSampleData "true" }}
{{ template "data/sample.sql" }}
{{ end }}
diff --git a/database/tests/functions/notifications/get_pending_notification.sql b/database/tests/functions/notifications/get_pending_notification.sql
new file mode 100644
index 00000000..e23b20ea
--- /dev/null
+++ b/database/tests/functions/notifications/get_pending_notification.sql
@@ -0,0 +1,66 @@
+-- Start transaction and plan tests
+begin;
+select plan(4);
+
+-- Declare some variables
+\set notification1ID '00000000-0000-0000-0000-000000000001'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+
+-- No pending notifications available yet
+select is_empty(
+ $$ select get_pending_notification()::jsonb $$,
+ 'Should not return a notification'
+);
+
+-- Seed some data
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ package_kind_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ 1
+);
+insert into notification (notification_id, package_version, package_id, notification_kind_id)
+values (:'notification1ID', '1.0.0', :'package1ID', 0);
+savepoint before_getting_notification;
+
+-- Run some tests
+select is(
+ get_pending_notification()::jsonb,
+ '{
+ "notification_id": "00000000-0000-0000-0000-000000000001",
+ "package_version": "1.0.0",
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+ }'::jsonb,
+ 'Notification should be returned'
+);
+select results_eq(
+ $$
+ select processed from notification
+ where notification_id = '00000000-0000-0000-0000-000000000001'
+ $$,
+ $$
+ values (true)
+ $$,
+ 'Notification should be marked as processed'
+);
+rollback to before_getting_notification;
+select results_eq(
+ $$
+ select processed from notification
+ where notification_id = '00000000-0000-0000-0000-000000000001'
+ $$,
+ $$
+ values (false)
+ $$,
+ 'Notification should not be marked as processed as transaction was rolled back'
+);
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/functions/packages/get_package.sql b/database/tests/functions/packages/get_package.sql
index 500b719f..d9d0c98d 100644
--- a/database/tests/functions/packages/get_package.sql
+++ b/database/tests/functions/packages/get_package.sql
@@ -1,6 +1,6 @@
-- Start transaction and plan tests
begin;
-select plan(4);
+select plan(5);
-- Declare some variables
\set org1ID '00000000-0000-0000-0000-000000000001'
@@ -76,7 +76,7 @@ insert into snapshot (
'12.1.0',
'digest-package1-1.0.0',
'readme-version-1.0.0',
- '{"link1": "https://link1", "link2": "https://link2"}',
+ '[{"name": "link1", "url": "https://link1"}, {"name": "link2", "url": "https://link2"}]',
'{"key": "value"}',
true
);
@@ -102,7 +102,7 @@ insert into snapshot (
'12.0.0',
'digest-package1-0.0.9',
'readme-version-0.0.9',
- '{"link1": "https://link1", "link2": "https://link2"}',
+ '[{"name": "link1", "url": "https://link1"}, {"name": "link2", "url": "https://link2"}]',
'{"key": "value"}'
);
insert into package (
@@ -139,6 +139,61 @@ insert into snapshot (
);
-- Run some tests
+select is(
+ get_package('{
+ "package_id": "00000000-0000-0000-0000-000000000001"
+ }')::jsonb,
+ '{
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "kind": 0,
+ "name": "Package 1",
+ "normalized_name": "package-1",
+ "logo_image_id": "00000000-0000-0000-0000-000000000001",
+ "display_name": "Package 1",
+ "description": "description",
+ "keywords": ["kw1", "kw2"],
+ "home_url": "home_url",
+ "readme": "readme-version-1.0.0",
+ "links": [
+ {
+ "name": "link1",
+ "url": "https://link1"
+ },
+ {
+ "name": "link2",
+ "url": "https://link2"
+ }
+ ],
+ "data": {
+ "key": "value"
+ },
+ "version": "1.0.0",
+ "available_versions": ["0.0.9", "1.0.0"],
+ "app_version": "12.1.0",
+ "digest": "digest-package1-1.0.0",
+ "deprecated": true,
+ "maintainers": [
+ {
+ "name": "name1",
+ "email": "email1"
+ },
+ {
+ "name": "name2",
+ "email": "email2"
+ }
+ ],
+ "user_alias": "user1",
+ "organization_name": null,
+ "organization_display_name": null,
+ "chart_repository": {
+ "chart_repository_id": "00000000-0000-0000-0000-000000000001",
+ "name": "repo1",
+ "display_name": "Repo 1",
+ "url": "https://repo1.com"
+ }
+ }'::jsonb,
+ 'Last package1 version is returned as a json object'
+);
select is(
get_package('{
"package_name": "package-1",
@@ -155,10 +210,16 @@ select is(
"keywords": ["kw1", "kw2"],
"home_url": "home_url",
"readme": "readme-version-1.0.0",
- "links": {
- "link1": "https://link1",
- "link2": "https://link2"
- },
+ "links": [
+ {
+ "name": "link1",
+ "url": "https://link1"
+ },
+ {
+ "name": "link2",
+ "url": "https://link2"
+ }
+ ],
"data": {
"key": "value"
},
@@ -206,10 +267,16 @@ select is(
"keywords": ["kw1", "kw2", "older"],
"home_url": "home_url (older)",
"readme": "readme-version-0.0.9",
- "links": {
- "link1": "https://link1",
- "link2": "https://link2"
- },
+ "links": [
+ {
+ "name": "link1",
+ "url": "https://link1"
+ },
+ {
+ "name": "link2",
+ "url": "https://link2"
+ }
+ ],
"data": {
"key": "value"
},
diff --git a/database/tests/functions/packages/get_packages_starred_by_user.sql b/database/tests/functions/packages/get_packages_starred_by_user.sql
index c9030316..b17f9061 100644
--- a/database/tests/functions/packages/get_packages_starred_by_user.sql
+++ b/database/tests/functions/packages/get_packages_starred_by_user.sql
@@ -40,8 +40,7 @@ insert into snapshot (
description,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'1.0.0',
@@ -49,8 +48,7 @@ insert into snapshot (
'description',
'12.1.0',
'digest-package1-1.0.0',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into snapshot (
package_id,
@@ -59,8 +57,7 @@ insert into snapshot (
description,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'0.0.9',
@@ -68,8 +65,7 @@ insert into snapshot (
'description',
'12.0.0',
'digest-package1-0.0.9',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into package (
package_id,
@@ -96,7 +92,6 @@ insert into snapshot (
app_version,
digest,
readme,
- links,
deprecated
) values (
:'package2ID',
@@ -106,7 +101,6 @@ insert into snapshot (
'12.1.0',
'digest-package2-1.0.0',
'readme',
- '{"link1": "https://link1", "link2": "https://link2"}',
true
);
insert into user_starred_package (user_id, package_id) values (:'user1ID', :'package1ID');
diff --git a/database/tests/functions/packages/get_packages_stats.sql b/database/tests/functions/packages/get_packages_stats.sql
index 71bb3e30..f81bf590 100644
--- a/database/tests/functions/packages/get_packages_stats.sql
+++ b/database/tests/functions/packages/get_packages_stats.sql
@@ -45,8 +45,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'1.0.0',
@@ -55,8 +54,7 @@ insert into snapshot (
'home_url',
'12.1.0',
'digest-package1-1.0.0',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into snapshot (
package_id,
@@ -66,8 +64,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'0.0.9',
@@ -76,8 +73,7 @@ insert into snapshot (
'home_url',
'12.0.0',
'digest-package1-0.0.9',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into package (
package_id,
@@ -102,8 +98,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package2ID',
'1.0.0',
@@ -112,8 +107,7 @@ insert into snapshot (
'home_url',
'12.1.0',
'digest-package2-1.0.0',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into snapshot (
package_id,
@@ -123,8 +117,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package2ID',
'0.0.9',
@@ -133,8 +126,7 @@ insert into snapshot (
'home_url',
'12.0.0',
'digest-package2-0.0.9',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
-- Some packages have just been seeded
diff --git a/database/tests/functions/packages/get_packages_updates.sql b/database/tests/functions/packages/get_packages_updates.sql
index ee0dd827..42ea4fd0 100644
--- a/database/tests/functions/packages/get_packages_updates.sql
+++ b/database/tests/functions/packages/get_packages_updates.sql
@@ -59,7 +59,6 @@ insert into snapshot (
keywords,
home_url,
readme,
- links,
deprecated
) values (
:'package1ID',
@@ -69,7 +68,6 @@ insert into snapshot (
'{"kw1", "kw2"}',
'home_url',
'readme',
- '{"link1": "https://link1", "link2": "https://link2"}',
false
);
insert into package (
@@ -102,8 +100,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package2ID',
'1.0.0',
@@ -113,8 +110,7 @@ insert into snapshot (
'home_url',
'12.1.0',
'digest-package2-1.0.0',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into package (
package_id,
@@ -146,7 +142,6 @@ insert into snapshot (
app_version,
digest,
readme,
- links,
deprecated
) values (
:'package3ID',
@@ -158,7 +153,6 @@ insert into snapshot (
'12.1.0',
'digest-package3-1.0.0',
'readme',
- '{"link1": "https://link1", "link2": "https://link2"}',
true
);
diff --git a/database/tests/functions/packages/register_package.sql b/database/tests/functions/packages/register_package.sql
index e2084edc..7d7700f5 100644
--- a/database/tests/functions/packages/register_package.sql
+++ b/database/tests/functions/packages/register_package.sql
@@ -1,16 +1,18 @@
-- Start transaction and plan tests
begin;
-select plan(11);
+select plan(16);
-- Declare some variables
\set org1ID '00000000-0000-0000-0000-000000000001'
\set repo1ID '00000000-0000-0000-0000-000000000001'
+\set user1ID '00000000-0000-0000-0000-000000000001'
-- Seed some data
insert into organization (organization_id, name, display_name, description, home_url)
values (:'org1ID', 'org1', 'Organization 1', 'Description 1', 'https://org1.com');
insert into chart_repository (chart_repository_id, name, display_name, url)
values (:'repo1ID', 'repo1', 'Repo 1', 'https://repo1.com');
+insert into "user" (user_id, alias, email) values (:'user1ID', 'user1', 'user1@email.com');
-- Register package
select register_package('
@@ -24,10 +26,16 @@ select register_package('
"keywords": ["kw1", "kw2"],
"home_url": "home_url",
"readme": "readme-version-1.0.0",
- "links": {
- "link1": "https://link1",
- "link2": "https://link2"
- },
+ "links": [
+ {
+ "name": "link1",
+ "url": "https://link1"
+ },
+ {
+ "name": "link2",
+ "url": "https://link2"
+ }
+ ],
"data": {
"key": "value"
},
@@ -105,7 +113,7 @@ select results_eq(
'12.1.0',
'digest-package1-1.0.0',
'readme-version-1.0.0',
- '{"link1": "https://link1", "link2": "https://link2"}'::jsonb,
+ '[{"name": "link1", "url": "https://link1"}, {"name": "link2", "url": "https://link2"}]'::jsonb,
'{"key": "value"}'::jsonb,
false
)
@@ -130,6 +138,20 @@ select results_eq(
$$,
'Maintainers should exist'
);
+select is_empty(
+ $$
+ select *
+ from notification n
+ join package p using (package_id)
+ where p.name = 'package1'
+ $$,
+ 'No new release notifications should exist for first version of package1'
+);
+
+-- Subscribe user1 to package1 new releases notifications
+insert into subscription (user_id, package_id, notification_kind_id)
+select :'user1ID', package_id, 0
+from package where name = 'package1';
-- Register a new version of the package previously registered
select register_package('
@@ -227,6 +249,16 @@ select is_empty(
$$,
'Orphan maintainers were deleted'
);
+select isnt_empty(
+ $$
+ select *
+ from notification n
+ join package p using (package_id)
+ where p.name = 'package1'
+ and n.package_version = '2.0.0'
+ $$,
+ 'New release notification should exist for package1 version 2.0.0'
+);
-- Register an old version of the package previously registered
select register_package('
@@ -318,6 +350,16 @@ select results_eq(
$$ values ('name1', 'email1') $$,
'Package maintainers should not have been updated'
);
+select is_empty(
+ $$
+ select *
+ from notification n
+ join package p using (package_id)
+ where p.name = 'package1'
+ and n.package_version = '0.0.9'
+ $$,
+ 'No new release notifications should exist for package1 version 0.0.9'
+);
-- Register package that belongs to an organization and check it succeeded
select register_package('
@@ -352,6 +394,38 @@ select results_eq(
$$,
'Package that belongs to organization should exist'
);
+select is_empty(
+ $$
+ select *
+ from notification n
+ join package p using (package_id)
+ where p.name = 'package3'
+ and n.package_version = '1.0.0'
+ $$,
+ 'No new release notifications should exist for first version of package3'
+);
+
+-- Register a new version of the package previously registered
+select register_package('
+{
+ "kind": 1,
+ "name": "package3",
+ "display_name": "Package 3",
+ "description": "description",
+ "version": "2.0.0",
+ "organization_id": "00000000-0000-0000-0000-000000000001"
+}
+');
+select is_empty(
+ $$
+ select *
+ from notification n
+ join package p using (package_id)
+ where p.name = 'package3'
+ and n.package_version = '2.0.0'
+ $$,
+ 'No new release notifications should exist for new version of package3 (no subscriptors)'
+);
-- Finish tests and rollback transaction
select * from finish();
diff --git a/database/tests/functions/packages/search_packages.sql b/database/tests/functions/packages/search_packages.sql
index f3c5359e..a036be1e 100644
--- a/database/tests/functions/packages/search_packages.sql
+++ b/database/tests/functions/packages/search_packages.sql
@@ -69,8 +69,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'1.0.0',
@@ -80,8 +79,7 @@ insert into snapshot (
'home_url',
'12.1.0',
'digest-package1-1.0.0',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into snapshot (
package_id,
@@ -92,8 +90,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package1ID',
'0.0.9',
@@ -103,8 +100,7 @@ insert into snapshot (
'home_url',
'12.0.0',
'digest-package1-0.0.9',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into package (
package_id,
@@ -135,7 +131,6 @@ insert into snapshot (
app_version,
digest,
readme,
- links,
deprecated
) values (
:'package2ID',
@@ -147,7 +142,6 @@ insert into snapshot (
'12.1.0',
'digest-package2-1.0.0',
'readme',
- '{"link1": "https://link1", "link2": "https://link2"}',
true
);
insert into snapshot (
@@ -159,8 +153,7 @@ insert into snapshot (
home_url,
app_version,
digest,
- readme,
- links
+ readme
) values (
:'package2ID',
'0.0.9',
@@ -170,8 +163,7 @@ insert into snapshot (
'home_url',
'12.0.0',
'digest-package2-0.0.9',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
insert into package (
package_id,
@@ -196,16 +188,14 @@ insert into snapshot (
display_name,
description,
keywords,
- readme,
- links
+ readme
) values (
:'package3ID',
'1.0.0',
'Package 3',
'description',
'{"kw3"}',
- 'readme',
- '{"link1": "https://link1", "link2": "https://link2"}'
+ 'readme'
);
-- Some packages have just been seeded
diff --git a/database/tests/functions/subscriptions/add_subscription.sql b/database/tests/functions/subscriptions/add_subscription.sql
new file mode 100644
index 00000000..ad09cf3e
--- /dev/null
+++ b/database/tests/functions/subscriptions/add_subscription.sql
@@ -0,0 +1,54 @@
+-- Start transaction and plan tests
+begin;
+select plan(1);
+
+-- Declare some variables
+\set user1ID '00000000-0000-0000-0000-000000000001'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+
+-- Seed some data
+insert into "user" (user_id, alias, email)
+values (:'user1ID', 'user1', 'user1@email.com');
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ package_kind_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ 1
+);
+
+-- Add subscription
+select add_subscription('
+{
+ "user_id": "00000000-0000-0000-0000-000000000001",
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+}
+'::jsonb);
+
+-- Check if subscription was added successfully
+select results_eq(
+ $$
+ select
+ user_id,
+ package_id,
+ notification_kind_id
+ from subscription
+ $$,
+ $$
+ values (
+ '00000000-0000-0000-0000-000000000001'::uuid,
+ '00000000-0000-0000-0000-000000000001'::uuid,
+ 0
+ )
+ $$,
+ 'Subscription should exist'
+);
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/functions/subscriptions/delete_subscription.sql b/database/tests/functions/subscriptions/delete_subscription.sql
new file mode 100644
index 00000000..8559e043
--- /dev/null
+++ b/database/tests/functions/subscriptions/delete_subscription.sql
@@ -0,0 +1,49 @@
+-- Start transaction and plan tests
+begin;
+select plan(1);
+
+-- Declare some variables
+\set user1ID '00000000-0000-0000-0000-000000000001'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+
+-- Seed some data
+insert into "user" (user_id, alias, email)
+values (:'user1ID', 'user1', 'user1@email.com');
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ package_kind_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ 1
+);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package1ID', 0);
+
+-- Delete subscription
+select delete_subscription('
+{
+ "user_id": "00000000-0000-0000-0000-000000000001",
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+}
+'::jsonb);
+
+-- Check if subscription was deleted successfully
+select is_empty(
+ $$
+ select *
+ from subscription
+ where user_id = '00000000-0000-0000-0000-000000000001'
+ and package_id = '00000000-0000-0000-0000-000000000001'
+ and notification_kind_id = 0
+ $$,
+ 'Subscription should not exist'
+);
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/functions/subscriptions/get_package_subscriptions.sql b/database/tests/functions/subscriptions/get_package_subscriptions.sql
new file mode 100644
index 00000000..ac0160c5
--- /dev/null
+++ b/database/tests/functions/subscriptions/get_package_subscriptions.sql
@@ -0,0 +1,52 @@
+-- Start transaction and plan tests
+begin;
+select plan(3);
+
+-- Declare some variables
+\set user1ID '00000000-0000-0000-0000-000000000001'
+\set user2ID '00000000-0000-0000-0000-000000000002'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+\set package2ID '00000000-0000-0000-0000-000000000002'
+
+-- Seed some data
+insert into "user" (user_id, alias, email)
+values (:'user1ID', 'user1', 'user1@email.com');
+insert into "user" (user_id, alias, email)
+values (:'user2ID', 'user2', 'user2@email.com');
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ package_kind_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ 1
+);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package1ID', 0);
+
+-- Run some tests
+select is(
+ get_package_subscriptions(:'user1ID', :'package1ID')::jsonb,
+ '[{
+ "notification_kind": 0
+ }]'::jsonb,
+ 'A subscription with notification kind 0 should be returned'
+);
+select is(
+ get_package_subscriptions(:'user2ID', :'package1ID')::jsonb,
+ '[]'::jsonb,
+ 'No subscriptions should be returned for user2 and package1'
+);
+select is(
+ get_package_subscriptions(:'user1ID', :'package2ID')::jsonb,
+ '[]'::jsonb,
+ 'No subscriptions should be returned for user1 and package2'
+);
+
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/functions/subscriptions/get_subscriptors.sql b/database/tests/functions/subscriptions/get_subscriptors.sql
new file mode 100644
index 00000000..f4aaeeac
--- /dev/null
+++ b/database/tests/functions/subscriptions/get_subscriptors.sql
@@ -0,0 +1,58 @@
+-- Start transaction and plan tests
+begin;
+select plan(2);
+
+-- Declare some variables
+\set user1ID '00000000-0000-0000-0000-000000000001'
+\set user2ID '00000000-0000-0000-0000-000000000002'
+\set user3ID '00000000-0000-0000-0000-000000000003'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+\set package2ID '00000000-0000-0000-0000-000000000002'
+
+-- Seed some data
+insert into "user" (user_id, alias, email)
+values (:'user1ID', 'user1', 'user1@email.com');
+insert into "user" (user_id, alias, email)
+values (:'user2ID', 'user2', 'user2@email.com');
+insert into "user" (user_id, alias, email)
+values (:'user3ID', 'user3', 'user3@email.com');
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ package_kind_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ 1
+);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package1ID', 0);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user2ID', :'package1ID', 0);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user3ID', :'package1ID', 1);
+
+-- Run some tests
+select is(
+ get_subscriptors(:'package1ID', 0)::jsonb,
+ '[
+ {
+ "email": "user1@email.com"
+ },
+ {
+ "email": "user2@email.com"
+ }
+ ]'::jsonb,
+ 'Two subscriptors expected for package1 and kind new releases'
+);
+select is(
+ get_subscriptors(:'package2ID', 0)::jsonb,
+ '[]'::jsonb,
+ 'No subscriptors expected for package2 and kind new releases'
+);
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/functions/subscriptions/get_user_subscriptions.sql b/database/tests/functions/subscriptions/get_user_subscriptions.sql
new file mode 100644
index 00000000..44cd32a3
--- /dev/null
+++ b/database/tests/functions/subscriptions/get_user_subscriptions.sql
@@ -0,0 +1,100 @@
+-- Start transaction and plan tests
+begin;
+select plan(2);
+
+-- Declare some variables
+\set org1ID '00000000-0000-0000-0000-000000000001'
+\set user1ID '00000000-0000-0000-0000-000000000001'
+\set user2ID '00000000-0000-0000-0000-000000000002'
+\set repo1ID '00000000-0000-0000-0000-000000000001'
+\set package1ID '00000000-0000-0000-0000-000000000001'
+\set package2ID '00000000-0000-0000-0000-000000000002'
+\set image1ID '00000000-0000-0000-0000-000000000001'
+\set image2ID '00000000-0000-0000-0000-000000000002'
+
+-- Seed some data
+insert into organization (organization_id, name, display_name, description, home_url)
+values (:'org1ID', 'org1', 'Organization 1', 'Description 1', 'https://org1.com');
+insert into "user" (user_id, alias, email)
+values (:'user1ID', 'user1', 'user1@email.com');
+insert into "user" (user_id, alias, email)
+values (:'user2ID', 'user2', 'user2@email.com');
+insert into chart_repository (chart_repository_id, name, display_name, url, user_id)
+values (:'repo1ID', 'repo1', 'Repo 1', 'https://repo1.com', :'user1ID');
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ logo_image_id,
+ package_kind_id,
+ chart_repository_id
+) values (
+ :'package1ID',
+ 'Package 1',
+ '1.0.0',
+ :'image1ID',
+ 0,
+ :'repo1ID'
+);
+insert into package (
+ package_id,
+ name,
+ latest_version,
+ logo_image_id,
+ package_kind_id,
+ organization_id
+) values (
+ :'package2ID',
+ 'Package 2',
+ '1.0.0',
+ :'image2ID',
+ 1,
+ :'org1ID'
+);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package1ID', 0);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package1ID', 1);
+insert into subscription (user_id, package_id, notification_kind_id)
+values (:'user1ID', :'package2ID', 0);
+
+-- Run some tests
+select is(
+ get_user_subscriptions(:'user1ID')::jsonb,
+ '[{
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "kind": 0,
+ "name": "Package 1",
+ "normalized_name": "package-1",
+ "logo_image_id": "00000000-0000-0000-0000-000000000001",
+ "user_alias": "user1",
+ "organization_name": null,
+ "organization_display_name": null,
+ "chart_repository": {
+ "name": "repo1",
+ "display_name": "Repo 1"
+ },
+ "notification_kinds": [0, 1]
+ }, {
+ "package_id": "00000000-0000-0000-0000-000000000002",
+ "kind": 1,
+ "name": "Package 2",
+ "normalized_name": "package-2",
+ "logo_image_id": "00000000-0000-0000-0000-000000000002",
+ "user_alias": null,
+ "organization_name": "org1",
+ "organization_display_name": "Organization 1",
+ "chart_repository": null,
+ "notification_kinds": [0]
+ }]'::jsonb,
+ 'Two subscriptions should be returned'
+);
+select is(
+ get_user_subscriptions(:'user2ID')::jsonb,
+ '[]',
+ 'No subscriptions expected for user2'
+);
+
+-- Finish tests and rollback transaction
+select * from finish();
+rollback;
diff --git a/database/tests/schema/schema.sql b/database/tests/schema/schema.sql
index f53508ea..52d554f3 100644
--- a/database/tests/schema/schema.sql
+++ b/database/tests/schema/schema.sql
@@ -1,6 +1,6 @@
-- Start transaction and plan tests
begin;
-select plan(62);
+select plan(82);
-- Check default_text_search_config is correct
select results_eq(
@@ -19,12 +19,15 @@ select tables_are(array[
'image',
'image_version',
'maintainer',
+ 'notification',
+ 'notification_kind',
'organization',
'package',
'package__maintainer',
'package_kind',
'session',
'snapshot',
+ 'subscription',
'user',
'user_starred_package',
'user__organization',
@@ -62,6 +65,19 @@ select columns_are('maintainer', array[
'name',
'email'
]);
+select columns_are('notification', array[
+ 'notification_id',
+ 'created_at',
+ 'processed',
+ 'processed_at',
+ 'package_version',
+ 'package_id',
+ 'notification_kind_id'
+]);
+select columns_are('notification_kind', array[
+ 'notification_kind_id',
+ 'name'
+]);
select columns_are('organization', array[
'organization_id',
'name',
@@ -114,7 +130,14 @@ select columns_are('snapshot', array[
'readme',
'links',
'data',
- 'deprecated'
+ 'deprecated',
+ 'created_at',
+ 'updated_at'
+]);
+select columns_are('subscription', array[
+ 'user_id',
+ 'package_id',
+ 'notification_kind_id'
]);
select columns_are('user', array[
'user_id',
@@ -148,10 +171,30 @@ select indexes_are('chart_repository', array[
'chart_repository_name_key',
'chart_repository_url_key'
]);
+select indexes_are('email_verification_code', array[
+ 'email_verification_code_pkey',
+ 'email_verification_code_user_id_key'
+]);
+select indexes_are('image', array[
+ 'image_pkey',
+ 'image_original_hash_key'
+]);
+select indexes_are('image_version', array[
+ 'image_version_pkey'
+]);
select indexes_are('maintainer', array[
'maintainer_pkey',
'maintainer_email_key'
]);
+select indexes_are('notification', array[
+ 'notification_pkey',
+ 'notification_not_processed_idx',
+ 'notification_package_id_package_version_key'
+]);
+select indexes_are('organization', array[
+ 'organization_pkey',
+ 'organization_name_key'
+]);
select indexes_are('package', array[
'package_pkey',
'package_package_kind_id_chart_repository_id_name_key',
@@ -170,10 +213,27 @@ select indexes_are('package__maintainer', array[
select indexes_are('package_kind', array[
'package_kind_pkey'
]);
+select indexes_are('session', array[
+ 'session_pkey'
+]);
select indexes_are('snapshot', array[
'snapshot_pkey',
'snapshot_digest_key'
]);
+select indexes_are('subscription', array[
+ 'subscription_pkey'
+]);
+select indexes_are('user', array[
+ 'user_pkey',
+ 'user_alias_key',
+ 'user_email_key'
+]);
+select indexes_are('user__organization', array[
+ 'user__organization_pkey'
+]);
+select indexes_are('user_starred_package', array[
+ 'user_starred_package_pkey'
+]);
-- Check expected functions exist
select has_function('add_organization');
@@ -217,6 +277,14 @@ select has_function('update_chart_repository');
select has_function('get_image');
select has_function('register_image');
+select has_function('add_subscription');
+select has_function('delete_subscription');
+select has_function('get_package_subscriptions');
+select has_function('get_subscriptors');
+select has_function('get_user_subscriptions');
+
+select has_function('get_pending_notification');
+
-- Check package kinds exist
select results_eq(
'select * from package_kind',
@@ -228,6 +296,16 @@ select results_eq(
'Package kinds should exist'
);
+-- Check notification kinds exist
+select results_eq(
+ 'select * from notification_kind',
+ $$ values
+ (0, 'New package release'),
+ (1, 'Security alert')
+ $$,
+ 'Package kinds should exist'
+);
+
-- Finish tests and rollback transaction
select * from finish();
rollback;
diff --git a/internal/hub/external.go b/internal/hub/external.go
index 9e96c774..63e14a8d 100644
--- a/internal/hub/external.go
+++ b/internal/hub/external.go
@@ -10,8 +10,9 @@ import (
// DB defines the methods the database handler must provide.
type DB interface {
- QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
+ Begin(ctx context.Context) (pgx.Tx, error)
Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
+ QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
}
// EmailSender defines the methods the email sender must provide.
diff --git a/internal/hub/notification.go b/internal/hub/notification.go
new file mode 100644
index 00000000..d09ff738
--- /dev/null
+++ b/internal/hub/notification.go
@@ -0,0 +1,33 @@
+package hub
+
+import (
+ "context"
+
+ "github.com/jackc/pgx/v4"
+)
+
+// Notification represents the details of a notification that will be sent to
+// a set of subscribers interested on it.
+type Notification struct {
+ NotificationID string `json:"notification_id"`
+ PackageVersion string `json:"package_version"`
+ PackageID string `json:"package_id"`
+ NotificationKind NotificationKind `json:"notification_kind"`
+}
+
+// NotificationKind represents the kind of a notification.
+type NotificationKind int64
+
+const (
+ // NewRelease represents a notification for a new package release.
+ NewRelease NotificationKind = 0
+
+ // SecurityAlert represents a notification for a security alert.
+ SecurityAlert NotificationKind = 1
+)
+
+// NotificationManager describes the methods a NotificationManager
+// implementation must provide.
+type NotificationManager interface {
+ GetPending(ctx context.Context, tx pgx.Tx) (*Notification, error)
+}
diff --git a/internal/hub/pkg.go b/internal/hub/pkg.go
index 60d50461..63f04284 100644
--- a/internal/hub/pkg.go
+++ b/internal/hub/pkg.go
@@ -4,6 +4,7 @@ import "context"
// GetPackageInput represents the input used to get a specific package.
type GetPackageInput struct {
+ PackageID string `json:"package_id"`
ChartRepositoryName string `json:"chart_repository_name"`
PackageName string `json:"package_name"`
Version string `json:"version"`
@@ -24,31 +25,31 @@ type Maintainer struct {
// Package represents a Kubernetes package.
type Package struct {
- PackageID string `json:"package_id"`
- Kind PackageKind `json:"kind"`
- Name string `json:"name"`
- NormalizedName string `json:"normalized_name"`
- LogoURL string `json:"logo_url"`
- LogoImageID string `json:"logo_image_id"`
- Stars int `json:"stars"`
- DisplayName string `json:"display_name"`
- Description string `json:"description"`
- Keywords []string `json:"keywords"`
- HomeURL string `json:"home_url"`
- Readme string `json:"readme"`
- Links []*Link `json:"links"`
- Data map[string]interface{} `json:"data"`
- Version string `json:"version"`
- AvailableVersions []string `json:"available_versions"`
- AppVersion string `json:"app_version"`
- Digest string `json:"digest"`
- Deprecated bool `json:"deprecated"`
- Maintainers []*Maintainer `json:"maintainers"`
- UserID string `json:"user_id"`
- UserAlias string `json:"user_alias"`
- OrganizationID string `json:"organization_id"`
- OrganizationName string `json:"organization_name"`
- ChartRepository *ChartRepository `json:"chart_repository"`
+ PackageID string `json:"package_id"`
+ Kind PackageKind `json:"kind"`
+ Name string `json:"name"`
+ NormalizedName string `json:"normalized_name"`
+ LogoURL string `json:"logo_url"`
+ LogoImageID string `json:"logo_image_id"`
+ DisplayName string `json:"display_name"`
+ Description string `json:"description"`
+ Keywords []string `json:"keywords"`
+ HomeURL string `json:"home_url"`
+ Readme string `json:"readme"`
+ Links []*Link `json:"links"`
+ Data map[string]interface{} `json:"data"`
+ Version string `json:"version"`
+ AvailableVersions []string `json:"available_versions"`
+ AppVersion string `json:"app_version"`
+ Digest string `json:"digest"`
+ Deprecated bool `json:"deprecated"`
+ Maintainers []*Maintainer `json:"maintainers"`
+ UserID string `json:"user_id"`
+ UserAlias string `json:"user_alias"`
+ OrganizationID string `json:"organization_id"`
+ OrganizationName string `json:"organization_name"`
+ OrganizationDisplayName string `json:"organization_display_name"`
+ ChartRepository *ChartRepository `json:"chart_repository"`
}
// PackageKind represents the kind of a given package.
@@ -68,6 +69,7 @@ const (
// PackageManager describes the methods a PackageManager implementation must
// provide.
type PackageManager interface {
+ Get(ctx context.Context, input *GetPackageInput) (*Package, error)
GetJSON(ctx context.Context, input *GetPackageInput) ([]byte, error)
GetStarredByUserJSON(ctx context.Context) ([]byte, error)
GetStarsJSON(ctx context.Context, packageID string) ([]byte, error)
diff --git a/internal/hub/subscription.go b/internal/hub/subscription.go
new file mode 100644
index 00000000..4966d900
--- /dev/null
+++ b/internal/hub/subscription.go
@@ -0,0 +1,21 @@
+package hub
+
+import "context"
+
+// Subscription represents a user's subscription to receive notifications about
+// a given package.
+type Subscription struct {
+ UserID string `json:"user_id"`
+ PackageID string `json:"package_id"`
+ NotificationKind NotificationKind `json:"notification_kind"`
+}
+
+// SubscriptionManager describes the methods a SubscriptionManager
+// implementation must provide.
+type SubscriptionManager interface {
+ Add(ctx context.Context, s *Subscription) error
+ Delete(ctx context.Context, s *Subscription) error
+ GetByPackageJSON(ctx context.Context, packageID string) ([]byte, error)
+ GetByUserJSON(ctx context.Context) ([]byte, error)
+ GetSubscriptors(ctx context.Context, packageID string, notificationKind NotificationKind) ([]*User, error)
+}
diff --git a/internal/notification/dispatcher.go b/internal/notification/dispatcher.go
new file mode 100644
index 00000000..68577b8f
--- /dev/null
+++ b/internal/notification/dispatcher.go
@@ -0,0 +1,72 @@
+package notification
+
+import (
+ "context"
+ "sync"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/spf13/viper"
+)
+
+const (
+ defaultNumWorkers = 2
+)
+
+// Services is a wrapper around several internal services used to handle
+// notifications deliveries.
+type Services struct {
+ DB hub.DB
+ ES hub.EmailSender
+ NotificationManager hub.NotificationManager
+ SubscriptionManager hub.SubscriptionManager
+ PackageManager hub.PackageManager
+}
+
+// Dispatcher handles a group of workers in charge of delivering notifications.
+type Dispatcher struct {
+ numWorkers int
+ workers []*Worker
+}
+
+// NewDispatcher creates a new Dispatcher instance.
+func NewDispatcher(cfg *viper.Viper, svc *Services, opts ...func(d *Dispatcher)) *Dispatcher {
+ d := &Dispatcher{
+ numWorkers: defaultNumWorkers,
+ }
+ for _, o := range opts {
+ o(d)
+ }
+ baseURL := cfg.GetString("server.baseURL")
+ d.workers = make([]*Worker, 0, d.numWorkers)
+ for i := 0; i < d.numWorkers; i++ {
+ d.workers = append(d.workers, NewWorker(svc, baseURL))
+ }
+ return d
+}
+
+// WithNumWorkers allows providing a specific number of workers for a
+// Dispatcher instance.
+func WithNumWorkers(n int) func(d *Dispatcher) {
+ return func(d *Dispatcher) {
+ d.numWorkers = n
+ }
+}
+
+// Run starts the workers and lets them run until the dispatcher is asked to
+// stop via the context provided.
+func (d *Dispatcher) Run(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+
+ // Start workers
+ wwg := &sync.WaitGroup{}
+ wctx, stopWorkers := context.WithCancel(context.Background())
+ for _, w := range d.workers {
+ wwg.Add(1)
+ go w.Run(wctx, wwg)
+ }
+
+ // Stop workers when dispatcher is asked to stop
+ <-ctx.Done()
+ stopWorkers()
+ wwg.Wait()
+}
diff --git a/internal/notification/dispatcher_test.go b/internal/notification/dispatcher_test.go
new file mode 100644
index 00000000..50eee6c8
--- /dev/null
+++ b/internal/notification/dispatcher_test.go
@@ -0,0 +1,31 @@
+package notification
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/spf13/viper"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDispatcher(t *testing.T) {
+ // Setup dispatcher
+ cfg := viper.New()
+ cfg.Set("server.baseURL", "http://localhost:8000")
+ d := NewDispatcher(cfg, nil, WithNumWorkers(0))
+
+ // Run it
+ ctx, stopDispatcher := context.WithCancel(context.Background())
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go d.Run(ctx, &wg)
+
+ // Check it stops as expected when asked to do so
+ stopDispatcher()
+ assert.Eventually(t, func() bool {
+ wg.Wait()
+ return true
+ }, 2*time.Second, 100*time.Millisecond)
+}
diff --git a/internal/notification/manager.go b/internal/notification/manager.go
new file mode 100644
index 00000000..cd7d27ae
--- /dev/null
+++ b/internal/notification/manager.go
@@ -0,0 +1,31 @@
+package notification
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/jackc/pgx/v4"
+)
+
+// Manager provides an API to manage notifications.
+type Manager struct{}
+
+// NewManager creates a new Manager instance.
+func NewManager() *Manager {
+ return &Manager{}
+}
+
+// GetPending returns a pending notification to be delivered if available.
+func (m *Manager) GetPending(ctx context.Context, tx pgx.Tx) (*hub.Notification, error) {
+ query := "select get_pending_notification()"
+ var dataJSON []byte
+ if err := tx.QueryRow(ctx, query).Scan(&dataJSON); err != nil {
+ return nil, err
+ }
+ var n *hub.Notification
+ if err := json.Unmarshal(dataJSON, &n); err != nil {
+ return nil, err
+ }
+ return n, nil
+}
diff --git a/internal/notification/manager_test.go b/internal/notification/manager_test.go
new file mode 100644
index 00000000..524bd112
--- /dev/null
+++ b/internal/notification/manager_test.go
@@ -0,0 +1,51 @@
+package notification
+
+import (
+ "context"
+ "testing"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/tests"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestGetPending(t *testing.T) {
+ dbQuery := "select get_pending_notification()"
+ ctx := context.Background()
+
+ t.Run("database error", func(t *testing.T) {
+ tx := &tests.TXMock{}
+ tx.On("QueryRow", dbQuery).Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager()
+
+ dataJSON, err := m.GetPending(ctx, tx)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, dataJSON)
+ tx.AssertExpectations(t)
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ expectedNotification := &hub.Notification{
+ NotificationID: "00000000-0000-0000-0000-000000000001",
+ PackageVersion: "1.0.0",
+ PackageID: "00000000-0000-0000-0000-000000000001",
+ NotificationKind: hub.NewRelease,
+ }
+
+ tx := &tests.TXMock{}
+ tx.On("QueryRow", dbQuery).Return([]byte(`
+ {
+ "notification_id": "00000000-0000-0000-0000-000000000001",
+ "package_version": "1.0.0",
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "notification_kind": 0
+ }
+ `), nil)
+ m := NewManager()
+
+ n, err := m.GetPending(context.Background(), tx)
+ assert.NoError(t, err)
+ assert.Equal(t, expectedNotification, n)
+ tx.AssertExpectations(t)
+ })
+}
diff --git a/internal/notification/mock.go b/internal/notification/mock.go
new file mode 100644
index 00000000..a18d20c8
--- /dev/null
+++ b/internal/notification/mock.go
@@ -0,0 +1,21 @@
+package notification
+
+import (
+ "context"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/jackc/pgx/v4"
+ "github.com/stretchr/testify/mock"
+)
+
+// ManagerMock is a mock implementation of the NotificationManager interface.
+type ManagerMock struct {
+ mock.Mock
+}
+
+// GetPending implements the NotificationManager interface.
+func (m *ManagerMock) GetPending(ctx context.Context, tx pgx.Tx) (*hub.Notification, error) {
+ args := m.Called(ctx, tx)
+ data, _ := args.Get(0).(*hub.Notification)
+ return data, args.Error(1)
+}
diff --git a/internal/notification/tmpl_new_release_email.go b/internal/notification/tmpl_new_release_email.go
new file mode 100644
index 00000000..9f53fddd
--- /dev/null
+++ b/internal/notification/tmpl_new_release_email.go
@@ -0,0 +1,177 @@
+package notification
+
+import "html/template"
+
+var newReleaseEmailTmpl = template.Must(template.New("").Parse(`
+
+
+
+
+
+ {{ .name }} new release
+
+
+
+
+
+ | |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ .name }}
+ {{ .publisher }}
+
+ Version {{ .version }} has been released
+
+
+
+
+
+
+ |
+ Or you can copy-paste this link: {{ .baseURL}}{{ .packagePath }}
+ |
+
+
+
+ |
+
+
+ |
+
+
+
+
+
+
+
+
+
+
+
+ |
+ |
+
+
+
+
+`))
diff --git a/internal/notification/worker.go b/internal/notification/worker.go
new file mode 100644
index 00000000..1c142e27
--- /dev/null
+++ b/internal/notification/worker.go
@@ -0,0 +1,157 @@
+package notification
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/artifacthub/hub/internal/email"
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/util"
+ "github.com/jackc/pgx/v4"
+ "github.com/rs/zerolog/log"
+)
+
+const (
+ pauseOnEmptyQueue = 1 * time.Minute
+ pauseOnError = 1 * time.Second
+)
+
+// Worker is in charge of delivering pending notifications to their intended
+// recipients.
+type Worker struct {
+ svc *Services
+ baseURL string
+}
+
+// NewWorker creates a new Worker instance.
+func NewWorker(
+ svc *Services,
+ baseURL string,
+) *Worker {
+ return &Worker{
+ svc: svc,
+ baseURL: baseURL,
+ }
+}
+
+// Run is the main loop of the worker. It calls deliverNotification periodically
+// until it's asked to stop via the context provided.
+func (w *Worker) Run(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+
+ for {
+ err := w.deliverNotification(ctx)
+ switch err {
+ case nil:
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+ case pgx.ErrNoRows:
+ select {
+ case <-time.After(pauseOnEmptyQueue):
+ case <-ctx.Done():
+ return
+ }
+ default:
+ select {
+ case <-time.After(pauseOnError):
+ case <-ctx.Done():
+ return
+ }
+ }
+ }
+}
+
+// deliverNotification gets a pending notification from the database and
+// delivers it.
+func (w *Worker) deliverNotification(ctx context.Context) error {
+ return util.DBTransact(ctx, w.svc.DB, func(tx pgx.Tx) error {
+ n, err := w.svc.NotificationManager.GetPending(ctx, tx)
+ if err != nil {
+ if !errors.Is(err, pgx.ErrNoRows) {
+ log.Error().Err(err).Msg("error getting pending notification")
+ }
+ return err
+ }
+ rcpts, err := w.svc.SubscriptionManager.GetSubscriptors(ctx, n.PackageID, n.NotificationKind)
+ if err != nil {
+ log.Error().Err(err).Msg("error getting notification subscriptors")
+ return err
+ }
+ if len(rcpts) == 0 {
+ return nil
+ }
+ emailData, err := w.prepareEmailData(ctx, n)
+ if err != nil {
+ log.Error().Err(err).Msg("error preparing email data")
+ return err
+ }
+ for _, u := range rcpts {
+ emailData.To = u.Email
+ if err := w.svc.ES.SendEmail(emailData); err != nil {
+ log.Error().
+ Err(err).
+ Str("notificationID", n.NotificationID).
+ Str("email", u.Email).
+ Msg("error sending notification email")
+ }
+ }
+ return nil
+ })
+}
+
+// prepareEmailData prepares the content of the notification email.
+func (w *Worker) prepareEmailData(ctx context.Context, n *hub.Notification) (*email.Data, error) {
+ var subject string
+ var emailBody bytes.Buffer
+
+ switch n.NotificationKind {
+ case hub.NewRelease:
+ p, err := w.svc.PackageManager.Get(ctx, &hub.GetPackageInput{PackageID: n.PackageID})
+ if err != nil {
+ return nil, err
+ }
+ subject = fmt.Sprintf("%s version %s released", p.Name, p.Version)
+ publisher := p.OrganizationName
+ if publisher == "" {
+ publisher = p.UserAlias
+ }
+ if p.ChartRepository != nil {
+ publisher += "/" + p.ChartRepository.Name
+ }
+ var packagePath string
+ switch p.Kind {
+ case hub.Chart:
+ packagePath = fmt.Sprintf("/package/chart/%s/%s", p.ChartRepository.Name, p.NormalizedName)
+ case hub.Falco:
+ packagePath = fmt.Sprintf("/package/falco/%s", p.NormalizedName)
+ case hub.OPA:
+ packagePath = fmt.Sprintf("/package/opa/%s", p.NormalizedName)
+ }
+ data := map[string]interface{}{
+ "publisher": publisher,
+ "kind": p.Kind,
+ "name": p.Name,
+ "version": n.PackageVersion,
+ "baseURL": w.baseURL,
+ "logoImageID": p.LogoImageID,
+ "packagePath": packagePath,
+ }
+ if err := newReleaseEmailTmpl.Execute(&emailBody, data); err != nil {
+ return nil, err
+ }
+ }
+
+ return &email.Data{
+ Subject: subject,
+ Body: emailBody.Bytes(),
+ }, nil
+}
+
+// - Publisher (/ Chart repo)
diff --git a/internal/notification/worker_test.go b/internal/notification/worker_test.go
new file mode 100644
index 00000000..591ca275
--- /dev/null
+++ b/internal/notification/worker_test.go
@@ -0,0 +1,189 @@
+package notification
+
+import (
+ "context"
+ "errors"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/artifacthub/hub/internal/email"
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/pkg"
+ "github.com/artifacthub/hub/internal/subscription"
+ "github.com/artifacthub/hub/internal/tests"
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+var errFake = errors.New("fake error for tests")
+
+func TestMain(m *testing.M) {
+ zerolog.SetGlobalLevel(zerolog.Disabled)
+ os.Exit(m.Run())
+}
+
+func TestWorker(t *testing.T) {
+ t.Run("error getting pending notification", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Rollback", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(nil, errFake)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+
+ t.Run("error getting notification subscriptors", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Rollback", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(&hub.Notification{
+ PackageID: "packageID",
+ NotificationKind: hub.NewRelease,
+ }, nil)
+ sw.sm.On("GetSubscriptors", mock.Anything, "packageID", hub.NewRelease).Return(nil, errFake)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+
+ t.Run("no subscriptors found", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Commit", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(&hub.Notification{
+ PackageID: "packageID",
+ NotificationKind: hub.NewRelease,
+ }, nil)
+ sw.sm.On("GetSubscriptors", mock.Anything, "packageID", hub.NewRelease).Return([]*hub.User{}, nil)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+
+ t.Run("error preparing email data", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Rollback", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(&hub.Notification{
+ PackageID: "packageID",
+ NotificationKind: hub.NewRelease,
+ }, nil)
+ sw.sm.On("GetSubscriptors", mock.Anything, "packageID", hub.NewRelease).Return([]*hub.User{
+ {
+ Email: "user1@email.com",
+ },
+ }, nil)
+ sw.pm.On("Get", mock.Anything, mock.Anything).Return(nil, errFake)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+
+ t.Run("error sending email", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Rollback", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(&hub.Notification{
+ PackageID: "packageID",
+ NotificationKind: hub.NewRelease,
+ }, nil)
+ sw.sm.On("GetSubscriptors", mock.Anything, "packageID", hub.NewRelease).Return([]*hub.User{
+ {
+ Email: "user1@email.com",
+ },
+ }, nil)
+ sw.pm.On("Get", mock.Anything, mock.Anything).Return(nil, errFake)
+ sw.es.On("SendEmail", mock.Anything).Return(errFake)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+
+ t.Run("notification delivered successfully", func(t *testing.T) {
+ sw := newServicesWrapper()
+ sw.db.On("Begin", mock.Anything).Return(sw.tx, nil)
+ sw.tx.On("Rollback", mock.Anything).Return(nil)
+ sw.nm.On("GetPending", mock.Anything, mock.Anything).Return(&hub.Notification{
+ PackageID: "packageID",
+ NotificationKind: hub.NewRelease,
+ }, nil)
+ sw.sm.On("GetSubscriptors", mock.Anything, "packageID", hub.NewRelease).Return([]*hub.User{
+ {
+ Email: "user1@email.com",
+ },
+ }, nil)
+ sw.pm.On("Get", mock.Anything, mock.Anything).Return(nil, errFake)
+ sw.es.On("SendEmail", mock.Anything).Return(nil)
+
+ w := NewWorker(sw.svc, "baseURL")
+ go w.Run(sw.ctx, sw.wg)
+ sw.assertExpectations(t)
+ })
+}
+
+type servicesWrapper struct {
+ ctx context.Context
+ stopWorker context.CancelFunc
+ wg *sync.WaitGroup
+ db *tests.DBMock
+ tx *tests.TXMock
+ es *email.SenderMock
+ nm *ManagerMock
+ sm *subscription.ManagerMock
+ pm *pkg.ManagerMock
+ svc *Services
+}
+
+func newServicesWrapper() *servicesWrapper {
+ // Context and wait group used for Worker.Run()
+ ctx, stopWorker := context.WithCancel(context.Background())
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ db := &tests.DBMock{}
+ tx := &tests.TXMock{}
+ es := &email.SenderMock{}
+ nm := &ManagerMock{}
+ sm := &subscription.ManagerMock{}
+ pm := &pkg.ManagerMock{}
+
+ return &servicesWrapper{
+ ctx: ctx,
+ stopWorker: stopWorker,
+ wg: &wg,
+ db: db,
+ tx: tx,
+ es: es,
+ nm: nm,
+ sm: sm,
+ pm: pm,
+ svc: &Services{
+ DB: db,
+ ES: es,
+ NotificationManager: nm,
+ SubscriptionManager: sm,
+ PackageManager: pm,
+ },
+ }
+}
+
+func (sw *servicesWrapper) assertExpectations(t *testing.T) {
+ sw.stopWorker()
+ assert.Eventually(t, func() bool {
+ sw.wg.Wait()
+ return true
+ }, 2*time.Second, 100*time.Millisecond)
+
+ sw.nm.AssertExpectations(t)
+ sw.sm.AssertExpectations(t)
+ sw.pm.AssertExpectations(t)
+}
diff --git a/internal/org/manager.go b/internal/org/manager.go
index b27eeff9..4ab4657e 100644
--- a/internal/org/manager.go
+++ b/internal/org/manager.go
@@ -187,6 +187,14 @@ func (m *Manager) DeleteMember(ctx context.Context, orgName, userAlias string) e
return err
}
+// GetByUserJSON returns the organizations the user doing the request belongs
+// to as a json object.
+func (m *Manager) GetByUserJSON(ctx context.Context) ([]byte, error) {
+ query := "select get_user_organizations($1::uuid)"
+ userID := ctx.Value(hub.UserIDKey).(string)
+ return m.dbQueryJSON(ctx, query, userID)
+}
+
// GetJSON returns the organization requested as a json object.
func (m *Manager) GetJSON(ctx context.Context, orgName string) ([]byte, error) {
// Validate input
@@ -199,14 +207,6 @@ func (m *Manager) GetJSON(ctx context.Context, orgName string) ([]byte, error) {
return m.dbQueryJSON(ctx, query, orgName)
}
-// GetByUserJSON returns the organizations the user doing the request belongs
-// to as a json object.
-func (m *Manager) GetByUserJSON(ctx context.Context) ([]byte, error) {
- query := "select get_user_organizations($1::uuid)"
- userID := ctx.Value(hub.UserIDKey).(string)
- return m.dbQueryJSON(ctx, query, userID)
-}
-
// GetMembersJSON returns the members of the provided organization as a json
// object.
func (m *Manager) GetMembersJSON(ctx context.Context, orgName string) ([]byte, error) {
diff --git a/internal/org/manager_test.go b/internal/org/manager_test.go
index ba8775c7..fd5e84d6 100644
--- a/internal/org/manager_test.go
+++ b/internal/org/manager_test.go
@@ -354,38 +354,6 @@ func TestDeleteMember(t *testing.T) {
})
}
-func TestGetJSON(t *testing.T) {
- dbQuery := `select get_organization($1::text)`
-
- t.Run("invalid input", func(t *testing.T) {
- m := NewManager(nil, nil)
- _, err := m.GetJSON(context.Background(), "")
- assert.True(t, errors.Is(err, ErrInvalidInput))
- })
-
- t.Run("database query succeeded", func(t *testing.T) {
- db := &tests.DBMock{}
- db.On("QueryRow", dbQuery, "orgName").Return([]byte("dataJSON"), nil)
- m := NewManager(db, nil)
-
- dataJSON, err := m.GetJSON(context.Background(), "orgName")
- assert.NoError(t, err)
- assert.Equal(t, []byte("dataJSON"), dataJSON)
- db.AssertExpectations(t)
- })
-
- t.Run("database error", func(t *testing.T) {
- db := &tests.DBMock{}
- db.On("QueryRow", dbQuery, "orgName").Return(nil, tests.ErrFakeDatabaseFailure)
- m := NewManager(db, nil)
-
- dataJSON, err := m.GetJSON(context.Background(), "orgName")
- assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
- assert.Nil(t, dataJSON)
- db.AssertExpectations(t)
- })
-}
-
func TestGetByUserJSON(t *testing.T) {
dbQuery := `select get_user_organizations($1::uuid)`
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
@@ -420,6 +388,38 @@ func TestGetByUserJSON(t *testing.T) {
})
}
+func TestGetJSON(t *testing.T) {
+ dbQuery := `select get_organization($1::text)`
+
+ t.Run("invalid input", func(t *testing.T) {
+ m := NewManager(nil, nil)
+ _, err := m.GetJSON(context.Background(), "")
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, "orgName").Return([]byte("dataJSON"), nil)
+ m := NewManager(db, nil)
+
+ dataJSON, err := m.GetJSON(context.Background(), "orgName")
+ assert.NoError(t, err)
+ assert.Equal(t, []byte("dataJSON"), dataJSON)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, "orgName").Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager(db, nil)
+
+ dataJSON, err := m.GetJSON(context.Background(), "orgName")
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, dataJSON)
+ db.AssertExpectations(t)
+ })
+}
+
func TestGetMembersJSON(t *testing.T) {
dbQuery := `select get_organization_members($1::uuid, $2::text)`
ctx := context.WithValue(context.Background(), hub.UserIDKey, "userID")
diff --git a/internal/pkg/manager.go b/internal/pkg/manager.go
index 48a7b946..1b6455a8 100644
--- a/internal/pkg/manager.go
+++ b/internal/pkg/manager.go
@@ -32,11 +32,24 @@ func NewManager(db hub.DB) *Manager {
}
}
+// Get returns the package identified by the input provided.
+func (m *Manager) Get(ctx context.Context, input *hub.GetPackageInput) (*hub.Package, error) {
+ dataJSON, err := m.GetJSON(ctx, input)
+ if err != nil {
+ return nil, err
+ }
+ p := &hub.Package{}
+ if err := json.Unmarshal(dataJSON, &p); err != nil {
+ return nil, err
+ }
+ return p, nil
+}
+
// GetJSON returns the package identified by the input provided as a json
// object. The json object is built by the database.
func (m *Manager) GetJSON(ctx context.Context, input *hub.GetPackageInput) ([]byte, error) {
// Validate input
- if input.PackageName == "" {
+ if input.PackageID == "" && input.PackageName == "" {
return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "package name not provided")
}
diff --git a/internal/pkg/manager_test.go b/internal/pkg/manager_test.go
index 2ec6f83f..acd09686 100644
--- a/internal/pkg/manager_test.go
+++ b/internal/pkg/manager_test.go
@@ -11,6 +11,138 @@ import (
"github.com/stretchr/testify/mock"
)
+func TestGet(t *testing.T) {
+ dbQuery := "select get_package($1::jsonb)"
+
+ t.Run("invalid input", func(t *testing.T) {
+ m := NewManager(nil)
+ _, err := m.Get(context.Background(), &hub.GetPackageInput{})
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ })
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, mock.Anything, mock.Anything).Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ p, err := m.Get(context.Background(), &hub.GetPackageInput{PackageName: "pkg1"})
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, p)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ expectedPackage := &hub.Package{
+ PackageID: "00000000-0000-0000-0000-000000000001",
+ Kind: hub.Chart,
+ Name: "Package 1",
+ NormalizedName: "package-1",
+ LogoImageID: "00000000-0000-0000-0000-000000000001",
+ DisplayName: "Package 1",
+ Description: "description",
+ Keywords: []string{"kw1", "kw2"},
+ HomeURL: "home_url",
+ Readme: "readme-version-1.0.0",
+ Links: []*hub.Link{
+ {
+ Name: "link1",
+ URL: "https://link1",
+ },
+ {
+ Name: "link2",
+ URL: "https://link2",
+ },
+ },
+ Data: map[string]interface{}{
+ "key": "value",
+ },
+ Version: "1.0.0",
+ AvailableVersions: []string{"0.0.9", "1.0.0"},
+ AppVersion: "12.1.0",
+ Digest: "digest-package1-1.0.0",
+ Deprecated: true,
+ Maintainers: []*hub.Maintainer{
+ {
+ Name: "name1",
+ Email: "email1",
+ },
+ {
+ Name: "name2",
+ Email: "email2",
+ },
+ },
+ UserAlias: "user1",
+ OrganizationName: "org1",
+ OrganizationDisplayName: "Organization 1",
+ ChartRepository: &hub.ChartRepository{
+ ChartRepositoryID: "00000000-0000-0000-0000-000000000001",
+ Name: "repo1",
+ DisplayName: "Repo 1",
+ URL: "https://repo1.com",
+ },
+ }
+
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, mock.Anything, mock.Anything).Return([]byte(`
+ {
+ "package_id": "00000000-0000-0000-0000-000000000001",
+ "kind": 0,
+ "name": "Package 1",
+ "normalized_name": "package-1",
+ "logo_image_id": "00000000-0000-0000-0000-000000000001",
+ "display_name": "Package 1",
+ "description": "description",
+ "keywords": ["kw1", "kw2"],
+ "home_url": "home_url",
+ "readme": "readme-version-1.0.0",
+ "links": [
+ {
+ "name": "link1",
+ "url": "https://link1"
+ },
+ {
+ "name": "link2",
+ "url": "https://link2"
+ }
+ ],
+ "data": {
+ "key": "value"
+ },
+ "version": "1.0.0",
+ "available_versions": ["0.0.9", "1.0.0"],
+ "app_version": "12.1.0",
+ "digest": "digest-package1-1.0.0",
+ "deprecated": true,
+ "maintainers": [
+ {
+ "name": "name1",
+ "email": "email1"
+ },
+ {
+ "name": "name2",
+ "email": "email2"
+ }
+ ],
+ "user_alias": "user1",
+ "organization_name": "org1",
+ "organization_display_name": "Organization 1",
+ "chart_repository": {
+ "chart_repository_id": "00000000-0000-0000-0000-000000000001",
+ "name": "repo1",
+ "display_name": "Repo 1",
+ "url": "https://repo1.com"
+ }
+ }
+ `), nil)
+ m := NewManager(db)
+
+ p, err := m.Get(context.Background(), &hub.GetPackageInput{PackageName: "package-1"})
+ assert.NoError(t, err)
+ assert.Equal(t, expectedPackage, p)
+ db.AssertExpectations(t)
+ })
+}
+
func TestGetJSON(t *testing.T) {
dbQuery := "select get_package($1::jsonb)"
diff --git a/internal/pkg/mock.go b/internal/pkg/mock.go
index 46450f0d..86ea88a0 100644
--- a/internal/pkg/mock.go
+++ b/internal/pkg/mock.go
@@ -12,6 +12,13 @@ type ManagerMock struct {
mock.Mock
}
+// Get implements the PackageManager interface.
+func (m *ManagerMock) Get(ctx context.Context, input *hub.GetPackageInput) (*hub.Package, error) {
+ args := m.Called(ctx, input)
+ data, _ := args.Get(0).(*hub.Package)
+ return data, args.Error(1)
+}
+
// GetJSON implements the PackageManager interface.
func (m *ManagerMock) GetJSON(ctx context.Context, input *hub.GetPackageInput) ([]byte, error) {
args := m.Called(ctx, input)
diff --git a/internal/subscription/manager.go b/internal/subscription/manager.go
new file mode 100644
index 00000000..931b3f67
--- /dev/null
+++ b/internal/subscription/manager.go
@@ -0,0 +1,120 @@
+package subscription
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/satori/uuid"
+)
+
+var (
+ // ErrInvalidInput indicates that the input provided is not valid.
+ ErrInvalidInput = errors.New("invalid input")
+)
+
+// Manager provides an API to manage subscriptions.
+type Manager struct {
+ db hub.DB
+}
+
+// NewManager creates a new Manager instance.
+func NewManager(db hub.DB) *Manager {
+ return &Manager{
+ db: db,
+ }
+}
+
+// Add adds the provided subscription to the database.
+func (m *Manager) Add(ctx context.Context, s *hub.Subscription) error {
+ userID := ctx.Value(hub.UserIDKey).(string)
+ s.UserID = userID
+ if err := validateSubscription(s); err != nil {
+ return err
+ }
+ query := "select add_subscription($1::jsonb)"
+ sJSON, _ := json.Marshal(s)
+ _, err := m.db.Exec(ctx, query, sJSON)
+ return err
+}
+
+// Delete removes a subscription from the database.
+func (m *Manager) Delete(ctx context.Context, s *hub.Subscription) error {
+ userID := ctx.Value(hub.UserIDKey).(string)
+ s.UserID = userID
+ if err := validateSubscription(s); err != nil {
+ return err
+ }
+ query := "select delete_subscription($1::jsonb)"
+ sJSON, _ := json.Marshal(s)
+ _, err := m.db.Exec(ctx, query, sJSON)
+ return err
+}
+
+// GetByPackageJSON returns the subscriptions the user has for a given package
+// as json array of objects.
+func (m *Manager) GetByPackageJSON(ctx context.Context, packageID string) ([]byte, error) {
+ userID := ctx.Value(hub.UserIDKey).(string)
+ if _, err := uuid.FromString(packageID); err != nil {
+ return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "invalid package id")
+ }
+ query := "select get_package_subscriptions($1::uuid, $2::uuid)"
+ var dataJSON []byte
+ if err := m.db.QueryRow(ctx, query, userID, packageID).Scan(&dataJSON); err != nil {
+ return nil, err
+ }
+ return dataJSON, nil
+}
+
+// GetByUserJSON returns all the subscriptions of the user doing the request as
+// as json array of objects.
+func (m *Manager) GetByUserJSON(ctx context.Context) ([]byte, error) {
+ userID := ctx.Value(hub.UserIDKey).(string)
+ query := "select get_user_subscriptions($1::uuid)"
+ var dataJSON []byte
+ if err := m.db.QueryRow(ctx, query, userID).Scan(&dataJSON); err != nil {
+ return nil, err
+ }
+ return dataJSON, nil
+}
+
+// GetSubscriptors returns the users subscribed to a package to receive certain
+// kind of notifications.
+func (m *Manager) GetSubscriptors(
+ ctx context.Context,
+ packageID string,
+ notificationKind hub.NotificationKind,
+) ([]*hub.User, error) {
+ if _, err := uuid.FromString(packageID); err != nil {
+ return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "invalid package id")
+ }
+ if notificationKind != hub.NewRelease {
+ return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "invalid notification kind")
+ }
+
+ query := "select get_subscriptors($1::uuid, $2::integer)"
+ var dataJSON []byte
+ err := m.db.QueryRow(ctx, query, packageID, notificationKind).Scan(&dataJSON)
+ if err != nil {
+ return nil, err
+ }
+ var subscriptors []*hub.User
+ if err := json.Unmarshal(dataJSON, &subscriptors); err != nil {
+ return nil, err
+ }
+ return subscriptors, nil
+}
+
+// validateSubscription checks if the subscription provided is valid to be used
+// as input for some database functions calls.
+func validateSubscription(s *hub.Subscription) error {
+ if _, err := uuid.FromString(s.PackageID); err != nil {
+ return fmt.Errorf("%w: %s", ErrInvalidInput, "invalid package id")
+ }
+ if s.NotificationKind != hub.NewRelease {
+ return fmt.Errorf("%w: %s", ErrInvalidInput, "invalid notification kind")
+ }
+ return nil
+}
diff --git a/internal/subscription/manager_test.go b/internal/subscription/manager_test.go
new file mode 100644
index 00000000..4ee7e2fe
--- /dev/null
+++ b/internal/subscription/manager_test.go
@@ -0,0 +1,299 @@
+package subscription
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/artifacthub/hub/internal/tests"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+const (
+ userID = "00000000-0000-0000-0000-000000000001"
+ packageID = "00000000-0000-0000-0000-000000000001"
+)
+
+func TestAdd(t *testing.T) {
+ dbQuery := "select add_subscription($1::jsonb)"
+ ctx := context.WithValue(context.Background(), hub.UserIDKey, userID)
+
+ t.Run("user id not found in ctx", func(t *testing.T) {
+ m := NewManager(nil)
+ assert.Panics(t, func() {
+ _ = m.Add(context.Background(), &hub.Subscription{})
+ })
+ })
+
+ t.Run("invalid input", func(t *testing.T) {
+ testCases := []struct {
+ errMsg string
+ s *hub.Subscription
+ }{
+ {
+ "invalid package id",
+ &hub.Subscription{
+ PackageID: "invalid",
+ },
+ },
+ {
+ "invalid notification kind",
+ &hub.Subscription{
+ PackageID: packageID,
+ NotificationKind: hub.NotificationKind(5),
+ },
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.errMsg, func(t *testing.T) {
+ m := NewManager(nil)
+ err := m.Add(ctx, tc.s)
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ assert.Contains(t, err.Error(), tc.errMsg)
+ })
+ }
+ })
+
+ s := &hub.Subscription{
+ PackageID: packageID,
+ NotificationKind: hub.NewRelease,
+ }
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("Exec", dbQuery, mock.Anything).Return(tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ err := m.Add(ctx, s)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("Exec", dbQuery, mock.Anything).Return(nil)
+ m := NewManager(db)
+
+ err := m.Add(ctx, s)
+ assert.NoError(t, err)
+ db.AssertExpectations(t)
+ })
+}
+
+func TestDelete(t *testing.T) {
+ dbQuery := "select delete_subscription($1::jsonb)"
+ ctx := context.WithValue(context.Background(), hub.UserIDKey, userID)
+
+ t.Run("user id not found in ctx", func(t *testing.T) {
+ m := NewManager(nil)
+ assert.Panics(t, func() {
+ _ = m.Delete(context.Background(), &hub.Subscription{})
+ })
+ })
+
+ t.Run("invalid input", func(t *testing.T) {
+ testCases := []struct {
+ errMsg string
+ s *hub.Subscription
+ }{
+ {
+ "invalid package id",
+ &hub.Subscription{
+ PackageID: "invalid",
+ },
+ },
+ {
+ "invalid notification kind",
+ &hub.Subscription{
+ PackageID: packageID,
+ NotificationKind: hub.NotificationKind(5),
+ },
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.errMsg, func(t *testing.T) {
+ m := NewManager(nil)
+ err := m.Delete(ctx, tc.s)
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ assert.Contains(t, err.Error(), tc.errMsg)
+ })
+ }
+ })
+
+ s := &hub.Subscription{
+ PackageID: packageID,
+ NotificationKind: hub.NewRelease,
+ }
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("Exec", dbQuery, mock.Anything).Return(tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ err := m.Delete(ctx, s)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("Exec", dbQuery, mock.Anything).Return(nil)
+ m := NewManager(db)
+
+ err := m.Delete(ctx, s)
+ assert.NoError(t, err)
+ db.AssertExpectations(t)
+ })
+}
+
+func TestGetByPackageJSON(t *testing.T) {
+ dbQuery := "select get_package_subscriptions($1::uuid, $2::uuid)"
+ ctx := context.WithValue(context.Background(), hub.UserIDKey, userID)
+
+ t.Run("user id not found in ctx", func(t *testing.T) {
+ m := NewManager(nil)
+ assert.Panics(t, func() {
+ _, _ = m.GetByPackageJSON(context.Background(), "")
+ })
+ })
+
+ t.Run("invalid input", func(t *testing.T) {
+ m := NewManager(nil)
+ _, err := m.GetByPackageJSON(ctx, "")
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ assert.Contains(t, err.Error(), "invalid package id")
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, userID, packageID).Return([]byte("dataJSON"), nil)
+ m := NewManager(db)
+
+ dataJSON, err := m.GetByPackageJSON(ctx, packageID)
+ assert.NoError(t, err)
+ assert.Equal(t, []byte("dataJSON"), dataJSON)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, userID, packageID).Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ dataJSON, err := m.GetByPackageJSON(ctx, packageID)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, dataJSON)
+ db.AssertExpectations(t)
+ })
+}
+
+func TestGetByUserJSON(t *testing.T) {
+ dbQuery := "select get_user_subscriptions($1::uuid)"
+ ctx := context.WithValue(context.Background(), hub.UserIDKey, userID)
+
+ t.Run("user id not found in ctx", func(t *testing.T) {
+ m := NewManager(nil)
+ assert.Panics(t, func() {
+ _, _ = m.GetByUserJSON(context.Background())
+ })
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, userID).Return([]byte("dataJSON"), nil)
+ m := NewManager(db)
+
+ dataJSON, err := m.GetByUserJSON(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, []byte("dataJSON"), dataJSON)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, userID).Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ dataJSON, err := m.GetByUserJSON(ctx)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, dataJSON)
+ db.AssertExpectations(t)
+ })
+}
+
+func TestGetSubscriptors(t *testing.T) {
+ dbQuery := "select get_subscriptors($1::uuid, $2::integer)"
+
+ t.Run("invalid input", func(t *testing.T) {
+ testCases := []struct {
+ errMsg string
+ packageID string
+ notificationKind hub.NotificationKind
+ }{
+ {
+ "invalid package id",
+ "invalid",
+ 0,
+ },
+ {
+ "invalid notification kind",
+ packageID,
+ hub.NotificationKind(5),
+ },
+ }
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.errMsg, func(t *testing.T) {
+ m := NewManager(nil)
+ dataJSON, err := m.GetSubscriptors(context.Background(), tc.packageID, tc.notificationKind)
+ assert.True(t, errors.Is(err, ErrInvalidInput))
+ assert.Contains(t, err.Error(), tc.errMsg)
+ assert.Nil(t, dataJSON)
+ })
+ }
+ })
+
+ t.Run("database error", func(t *testing.T) {
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, packageID, hub.NotificationKind(0)).Return(nil, tests.ErrFakeDatabaseFailure)
+ m := NewManager(db)
+
+ subscriptors, err := m.GetSubscriptors(context.Background(), packageID, hub.NewRelease)
+ assert.Equal(t, tests.ErrFakeDatabaseFailure, err)
+ assert.Nil(t, subscriptors)
+ db.AssertExpectations(t)
+ })
+
+ t.Run("database query succeeded", func(t *testing.T) {
+ expectedSubscriptors := []*hub.User{
+ {
+ Email: "user1@email.com",
+ },
+ {
+ Email: "user2@email.com",
+ },
+ }
+
+ db := &tests.DBMock{}
+ db.On("QueryRow", dbQuery, packageID, hub.NotificationKind(0)).Return([]byte(`
+ [
+ {
+ "email": "user1@email.com"
+ },
+ {
+ "email": "user2@email.com"
+ }
+ ]
+ `), nil)
+ m := NewManager(db)
+
+ subscriptors, err := m.GetSubscriptors(context.Background(), packageID, hub.NewRelease)
+ assert.NoError(t, err)
+ assert.Equal(t, expectedSubscriptors, subscriptors)
+ db.AssertExpectations(t)
+ })
+}
diff --git a/internal/subscription/mock.go b/internal/subscription/mock.go
new file mode 100644
index 00000000..cd8e7b8e
--- /dev/null
+++ b/internal/subscription/mock.go
@@ -0,0 +1,50 @@
+package subscription
+
+import (
+ "context"
+
+ "github.com/artifacthub/hub/internal/hub"
+ "github.com/stretchr/testify/mock"
+)
+
+// ManagerMock is a mock implementation of the SubscriptionManager interface.
+type ManagerMock struct {
+ mock.Mock
+}
+
+// Add implements the SubscriptionManager interface.
+func (m *ManagerMock) Add(ctx context.Context, s *hub.Subscription) error {
+ args := m.Called(ctx, s)
+ return args.Error(0)
+}
+
+// Delete implements the SubscriptionManager interface.
+func (m *ManagerMock) Delete(ctx context.Context, s *hub.Subscription) error {
+ args := m.Called(ctx, s)
+ return args.Error(0)
+}
+
+// GetByPackageJSON implements the SubscriptionManager interface.
+func (m *ManagerMock) GetByPackageJSON(ctx context.Context, packageID string) ([]byte, error) {
+ args := m.Called(ctx, packageID)
+ data, _ := args.Get(0).([]byte)
+ return data, args.Error(1)
+}
+
+// GetByUserJSON implements the SubscriptionManager interface.
+func (m *ManagerMock) GetByUserJSON(ctx context.Context) ([]byte, error) {
+ args := m.Called(ctx)
+ data, _ := args.Get(0).([]byte)
+ return data, args.Error(1)
+}
+
+// GetByUserJSON implements the SubscriptionManager interface.
+func (m *ManagerMock) GetSubscriptors(
+ ctx context.Context,
+ packageID string,
+ notificationKind hub.NotificationKind,
+) ([]*hub.User, error) {
+ args := m.Called(ctx, packageID, notificationKind)
+ data, _ := args.Get(0).([]*hub.User)
+ return data, args.Error(1)
+}
diff --git a/internal/tests/db.go b/internal/tests/db.go
index f8177049..f3a796c4 100644
--- a/internal/tests/db.go
+++ b/internal/tests/db.go
@@ -17,6 +17,19 @@ type DBMock struct {
mock.Mock
}
+// Begin implements the DB interface.
+func (m *DBMock) Begin(ctx context.Context) (pgx.Tx, error) {
+ args := m.Called(ctx)
+ tx, _ := args.Get(0).(pgx.Tx)
+ return tx, args.Error(1)
+}
+
+// Exec implements the DB interface.
+func (m *DBMock) Exec(ctx context.Context, query string, params ...interface{}) (pgconn.CommandTag, error) {
+ args := m.Called(append([]interface{}{query}, params...)...)
+ return nil, args.Error(0)
+}
+
// QueryRow implements the DB interface.
func (m *DBMock) QueryRow(ctx context.Context, query string, params ...interface{}) pgx.Row {
args := m.Called(append([]interface{}{query}, params...)...)
@@ -32,10 +45,89 @@ func (m *DBMock) QueryRow(ctx context.Context, query string, params ...interface
return rowMock
}
-// Exec implements the DB interface.
-func (m *DBMock) Exec(ctx context.Context, query string, params ...interface{}) (pgconn.CommandTag, error) {
+// TXMock is a mock implementation of the pgx.Tx interface.
+type TXMock struct {
+ mock.Mock
+}
+
+// Begin implements the pgx.Tx interface.
+func (m *TXMock) Begin(ctx context.Context) (pgx.Tx, error) {
+ // NOTE: not used
+ return nil, nil
+}
+
+// Commit implements the pgx.Tx interface.
+func (m *TXMock) Commit(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// Conn implements the pgx.Tx interface.
+func (m *TXMock) Conn() *pgx.Conn {
+ // NOTE: not used
+ return nil
+}
+
+// CopyFrom implements the pgx.Tx interface.
+func (m *TXMock) CopyFrom(
+ ctx context.Context,
+ tableName pgx.Identifier,
+ columnNames []string,
+ rowSrc pgx.CopyFromSource,
+) (int64, error) {
+ // NOTE: not used
+ return 0, nil
+}
+
+// Exec implements the pgx.Tx interface.
+func (m *TXMock) Exec(ctx context.Context, query string, params ...interface{}) (pgconn.CommandTag, error) {
+ // NOTE: not used
+ return nil, nil
+}
+
+// LargeObjects implements the pgx.Tx interface.
+func (m *TXMock) LargeObjects() pgx.LargeObjects {
+ // NOTE: not used
+ return pgx.LargeObjects{}
+}
+
+// Prepare implements the pgx.Tx interface.
+func (m *TXMock) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
+ // NOTE: not used
+ return nil, nil
+}
+
+// QueryRow implements the pgx.Tx interface.
+func (m *TXMock) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
+ // NOTE: not used
+ return nil, nil
+}
+
+// QueryRow implements the pgx.Tx interface.
+func (m *TXMock) QueryRow(ctx context.Context, query string, params ...interface{}) pgx.Row {
args := m.Called(append([]interface{}{query}, params...)...)
- return nil, args.Error(0)
+ rowMock := &RowMock{
+ err: args.Error(1),
+ }
+ switch v := args.Get(0).(type) {
+ case []interface{}:
+ rowMock.data = v
+ case interface{}:
+ rowMock.data = []interface{}{args.Get(0)}
+ }
+ return rowMock
+}
+
+// Rollback implements the pgx.Tx interface.
+func (m *TXMock) Rollback(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// SendBatch implements the pgx.Tx interface.
+func (m *TXMock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
+ // NOTE: not used
+ return nil
}
// RowMock is a mock implementation of the pgx.Row interface.
diff --git a/internal/util/db.go b/internal/util/db.go
index 9314d481..08988040 100644
--- a/internal/util/db.go
+++ b/internal/util/db.go
@@ -5,6 +5,7 @@ import (
"fmt"
"time"
+ "github.com/artifacthub/hub/internal/hub"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/log/zerologadapter"
"github.com/jackc/pgx/v4/pgxpool"
@@ -40,3 +41,24 @@ func SetupDB(cfg *viper.Viper) (*pgxpool.Pool, error) {
return pool, nil
}
+
+// DBTransact is a helper function that wraps some database transactions taking
+// care of committing and rolling back when needed.
+func DBTransact(ctx context.Context, db hub.DB, txFunc func(pgx.Tx) error) (err error) {
+ tx, err := db.Begin(ctx)
+ if err != nil {
+ return
+ }
+ defer func() {
+ if p := recover(); p != nil {
+ _ = tx.Rollback(ctx)
+ panic(p)
+ } else if err != nil {
+ _ = tx.Rollback(ctx)
+ } else {
+ err = tx.Commit(ctx)
+ }
+ }()
+ err = txFunc(tx)
+ return err
+}