From b3202fc34bbe9f887d2551ff0254af3c0c1c9b3f Mon Sep 17 00:00:00 2001 From: "Sergio C. Arteaga" Date: Mon, 26 Apr 2021 08:55:14 +0200 Subject: [PATCH] Some improvements in HTTP clients used by backend (#1269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sergio CastaƱo Arteaga --- charts/artifact-hub/Chart.yaml | 2 +- charts/artifact-hub/templates/hub_secret.yaml | 1 + .../templates/scanner_secret.yaml | 1 + .../templates/tracker_secret.yaml | 1 + charts/artifact-hub/values-production.yaml | 2 + charts/artifact-hub/values-staging.yaml | 2 + charts/artifact-hub/values.schema.json | 8 +- charts/artifact-hub/values.yaml | 1 + cmd/hub/main.go | 8 +- cmd/scanner/main.go | 3 +- cmd/tracker/main.go | 5 +- internal/handlers/handlers.go | 5 +- internal/handlers/webhook/handlers.go | 6 +- internal/handlers/webhook/handlers_test.go | 2 +- internal/notification/dispatcher.go | 5 +- internal/notification/worker.go | 22 +- internal/notification/worker_test.go | 40 ++-- internal/repo/manager.go | 25 +- internal/repo/manager_test.go | 221 +++++++++--------- internal/tests/http.go | 12 - internal/util/http.go | 100 ++++++++ 21 files changed, 280 insertions(+), 192 deletions(-) create mode 100644 internal/util/http.go diff --git a/charts/artifact-hub/Chart.yaml b/charts/artifact-hub/Chart.yaml index 0c07a950..9975ec7d 100644 --- a/charts/artifact-hub/Chart.yaml +++ b/charts/artifact-hub/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: artifact-hub description: Artifact Hub is a web-based application that enables finding, installing, and publishing Kubernetes packages. type: application -version: 0.18.3 +version: 0.18.4 appVersion: 0.18.0 kubeVersion: ">= 1.14.0-0" home: https://artifacthub.io diff --git a/charts/artifact-hub/templates/hub_secret.yaml b/charts/artifact-hub/templates/hub_secret.yaml index b7ec3089..cb6cfafc 100644 --- a/charts/artifact-hub/templates/hub_secret.yaml +++ b/charts/artifact-hub/templates/hub_secret.yaml @@ -5,6 +5,7 @@ metadata: type: Opaque stringData: hub.yaml: |- + restrictedHTTPClient: {{ .Values.restrictedHTTPClient }} log: level: {{ .Values.log.level }} pretty: {{ .Values.log.pretty }} diff --git a/charts/artifact-hub/templates/scanner_secret.yaml b/charts/artifact-hub/templates/scanner_secret.yaml index 48c94346..ffc61db1 100644 --- a/charts/artifact-hub/templates/scanner_secret.yaml +++ b/charts/artifact-hub/templates/scanner_secret.yaml @@ -5,6 +5,7 @@ metadata: type: Opaque stringData: scanner.yaml: |- + restrictedHTTPClient: {{ .Values.restrictedHTTPClient }} log: level: {{ .Values.log.level }} pretty: {{ .Values.log.pretty }} diff --git a/charts/artifact-hub/templates/tracker_secret.yaml b/charts/artifact-hub/templates/tracker_secret.yaml index fc1187a4..798b036b 100644 --- a/charts/artifact-hub/templates/tracker_secret.yaml +++ b/charts/artifact-hub/templates/tracker_secret.yaml @@ -5,6 +5,7 @@ metadata: type: Opaque stringData: tracker.yaml: |- + restrictedHTTPClient: {{ .Values.restrictedHTTPClient }} log: level: {{ .Values.log.level }} pretty: {{ .Values.log.pretty }} diff --git a/charts/artifact-hub/values-production.yaml b/charts/artifact-hub/values-production.yaml index cf30cb57..2b4b9d7d 100644 --- a/charts/artifact-hub/values-production.yaml +++ b/charts/artifact-hub/values-production.yaml @@ -1,3 +1,5 @@ +restrictedHTTPClient: true + log: level: debug pretty: false diff --git a/charts/artifact-hub/values-staging.yaml b/charts/artifact-hub/values-staging.yaml index 2ce866b8..b0070e91 100644 --- a/charts/artifact-hub/values-staging.yaml +++ b/charts/artifact-hub/values-staging.yaml @@ -1,3 +1,5 @@ +restrictedHTTPClient: true + log: level: debug pretty: false diff --git a/charts/artifact-hub/values.schema.json b/charts/artifact-hub/values.schema.json index f791472b..08f6870e 100644 --- a/charts/artifact-hub/values.schema.json +++ b/charts/artifact-hub/values.schema.json @@ -558,6 +558,12 @@ "type": "string", "default": "IfNotPresent" }, + "restrictedHTTPClient": { + "type": "boolean", + "title": "Enable restricted HTTP client", + "description": "Artifact Hub makes external HTTP requests for several purposes, like getting repositories metadata, dispatching webhooks, etc. When this option is enabled, requests to the private network space as well as to some other special addresses won't be allowed.", + "default": false + }, "scanner": { "title": "Scanner configuration", "type": "object", @@ -738,5 +744,5 @@ "required": ["deploy", "persistence"] } }, - "required": ["db", "dbMigrator", "hub", "images", "log", "postgresql", "pullPolicy", "tracker", "trivy", "scanner"] + "required": ["db", "dbMigrator", "hub", "images", "log", "postgresql", "pullPolicy", "restrictedHTTPClient", "tracker", "trivy", "scanner"] } diff --git a/charts/artifact-hub/values.yaml b/charts/artifact-hub/values.yaml index ad52bf3c..eb762055 100644 --- a/charts/artifact-hub/values.yaml +++ b/charts/artifact-hub/values.yaml @@ -4,6 +4,7 @@ imagePullSecrets: [] imageTag: "" dynamicResourceNamePrefixEnabled: false pullPolicy: IfNotPresent +restrictedHTTPClient: false log: level: info diff --git a/cmd/hub/main.go b/cmd/hub/main.go index 60209a07..6d424658 100644 --- a/cmd/hub/main.go +++ b/cmd/hub/main.go @@ -54,14 +54,14 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("authorizer setup failed") } - hc := &http.Client{Timeout: 10 * time.Second} + hc := util.SetupHTTPClient(cfg.GetBool("restrictedHTTPClient")) // Setup and launch http server ctx, stop := context.WithCancel(context.Background()) hSvc := &handlers.Services{ OrganizationManager: org.NewManager(db, es, az), UserManager: user.NewManager(db, es), - RepositoryManager: repo.NewManager(cfg, db, az), + RepositoryManager: repo.NewManager(cfg, db, az, hc), PackageManager: pkg.NewManager(db), SubscriptionManager: subscription.NewManager(db), WebhookManager: webhook.NewManager(db), @@ -69,6 +69,7 @@ func main() { StatsManager: stats.NewManager(db), ImageStore: pg.NewImageStore(cfg, db, hc, nil), Authorizer: az, + HTTPClient: hc, } h, err := handlers.Setup(ctx, cfg, hSvc) if err != nil { @@ -117,8 +118,9 @@ func main() { ES: es, NotificationManager: notification.NewManager(), SubscriptionManager: subscription.NewManager(db), - RepositoryManager: repo.NewManager(cfg, db, az), + RepositoryManager: repo.NewManager(cfg, db, az, hc), PackageManager: pkg.NewManager(db), + HTTPClient: hc, } notificationsDispatcher := notification.NewDispatcher(cfg, nSvc) wg.Add(1) diff --git a/cmd/scanner/main.go b/cmd/scanner/main.go index 9734be20..20e34ff3 100644 --- a/cmd/scanner/main.go +++ b/cmd/scanner/main.go @@ -53,7 +53,8 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("authorizer setup failed") } - rm := repo.NewManager(cfg, db, az) + hc := util.SetupHTTPClient(cfg.GetBool("restrictedHTTPClient")) + rm := repo.NewManager(cfg, db, az, hc) pm := pkg.NewManager(db) ec := repo.NewErrorsCollector(rm, repo.Scanner) diff --git a/cmd/tracker/main.go b/cmd/tracker/main.go index 8f547692..67933997 100644 --- a/cmd/tracker/main.go +++ b/cmd/tracker/main.go @@ -3,7 +3,6 @@ package main import ( "context" "errors" - "net/http" "os" "os/exec" "os/signal" @@ -68,9 +67,9 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("authorizer setup failed") } - rm := repo.NewManager(cfg, db, az) + hc := util.SetupHTTPClient(cfg.GetBool("restrictedHTTPClient")) + rm := repo.NewManager(cfg, db, az, hc) pm := pkg.NewManager(db) - hc := &http.Client{Timeout: 10 * time.Second} githubMaxRequestsPerHour := githubMaxRequestsPerHourUnauthenticated if cfg.GetString("creds.githubToken") != "" { githubMaxRequestsPerHour = githubMaxRequestsPerHourAuthenticated diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index a705e4d5..0455699c 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -48,6 +48,7 @@ type Services struct { StatsManager hub.StatsManager ImageStore img.Store Authorizer hub.Authorizer + HTTPClient hub.HTTPClient } // Metrics groups some metrics collected from a Handlers instance. @@ -90,9 +91,9 @@ func Setup(ctx context.Context, cfg *viper.Viper, svc *Services) (*Handlers, err Organizations: org.NewHandlers(svc.OrganizationManager, svc.Authorizer, cfg), Users: userHandlers, Repositories: repo.NewHandlers(svc.RepositoryManager), - Packages: pkg.NewHandlers(svc.PackageManager, svc.RepositoryManager, cfg, &http.Client{}), + Packages: pkg.NewHandlers(svc.PackageManager, svc.RepositoryManager, cfg, svc.HTTPClient), Subscriptions: subscription.NewHandlers(svc.SubscriptionManager), - Webhooks: webhook.NewHandlers(svc.WebhookManager), + Webhooks: webhook.NewHandlers(svc.WebhookManager, svc.HTTPClient), APIKeys: apikey.NewHandlers(svc.APIKeyManager), Static: static.NewHandlers(cfg, svc.ImageStore), Stats: stats.NewHandlers(svc.StatsManager), diff --git a/internal/handlers/webhook/handlers.go b/internal/handlers/webhook/handlers.go index 4293d0cb..c62f3061 100644 --- a/internal/handlers/webhook/handlers.go +++ b/internal/handlers/webhook/handlers.go @@ -20,13 +20,15 @@ import ( type Handlers struct { webhookManager hub.WebhookManager logger zerolog.Logger + hc hub.HTTPClient } // NewHandlers creates a new Handlers instance. -func NewHandlers(webhookManager hub.WebhookManager) *Handlers { +func NewHandlers(webhookManager hub.WebhookManager, hc hub.HTTPClient) *Handlers { return &Handlers{ webhookManager: webhookManager, logger: log.With().Str("handlers", "webhook").Logger(), + hc: hc, } } @@ -134,7 +136,7 @@ func (h *Handlers) TriggerTest(w http.ResponseWriter, r *http.Request) { } req.Header.Set("Content-Type", contentType) req.Header.Set("X-ArtifactHub-Secret", wh.Secret) - resp, err := http.DefaultClient.Do(req) + resp, err := h.hc.Do(req) if err != nil { err = fmt.Errorf("error doing request: %w", err) helpers.RenderErrorWithCodeJSON(w, err, http.StatusBadRequest) diff --git a/internal/handlers/webhook/handlers_test.go b/internal/handlers/webhook/handlers_test.go index b435ca0b..ff75e120 100644 --- a/internal/handlers/webhook/handlers_test.go +++ b/internal/handlers/webhook/handlers_test.go @@ -702,7 +702,7 @@ func newHandlersWrapper() *handlersWrapper { return &handlersWrapper{ wm: wm, - h: NewHandlers(wm), + h: NewHandlers(wm, http.DefaultClient), } } diff --git a/internal/notification/dispatcher.go b/internal/notification/dispatcher.go index 7f8368d6..6988d9e9 100644 --- a/internal/notification/dispatcher.go +++ b/internal/notification/dispatcher.go @@ -2,7 +2,6 @@ package notification import ( "context" - "net/http" "sync" "time" @@ -26,6 +25,7 @@ type Services struct { SubscriptionManager hub.SubscriptionManager RepositoryManager hub.RepositoryManager PackageManager hub.PackageManager + HTTPClient hub.HTTPClient } // Dispatcher handles a group of workers in charge of delivering notifications. @@ -47,10 +47,9 @@ func NewDispatcher(cfg *viper.Viper, svc *Services, opts ...func(d *Dispatcher)) // Setup and launch workers c := cache.New(cacheDefaultExpiration, cacheCleanupInterval) baseURL := cfg.GetString("server.baseURL") - httpClient := &http.Client{Timeout: 10 * time.Second} d.workers = make([]*Worker, 0, d.numWorkers) for i := 0; i < d.numWorkers; i++ { - d.workers = append(d.workers, NewWorker(svc, c, baseURL, httpClient)) + d.workers = append(d.workers, NewWorker(svc, c, baseURL)) } return d diff --git a/internal/notification/worker.go b/internal/notification/worker.go index ab8d0129..81b26132 100644 --- a/internal/notification/worker.go +++ b/internal/notification/worker.go @@ -35,17 +35,11 @@ var ( ErrRetryable = errors.New("retryable error") ) -// HTTPClient defines the methods an HTTPClient implementation must provide. -type HTTPClient interface { - Do(req *http.Request) (*http.Response, error) -} - // Worker is in charge of delivering notifications to their intended recipients. type Worker struct { - svc *Services - cache *cache.Cache - baseURL string - httpClient HTTPClient + svc *Services + cache *cache.Cache + baseURL string } // NewWorker creates a new Worker instance. @@ -53,13 +47,11 @@ func NewWorker( svc *Services, c *cache.Cache, baseURL string, - httpClient HTTPClient, ) *Worker { return &Worker{ - svc: svc, - cache: c, - baseURL: baseURL, - httpClient: httpClient, + svc: svc, + cache: c, + baseURL: baseURL, } } @@ -185,7 +177,7 @@ func (w *Worker) deliverWebhookNotification(ctx context.Context, n *hub.Notifica req, _ := http.NewRequest("POST", n.Webhook.URL, &payload) req.Header.Set("Content-Type", contentType) req.Header.Set("X-ArtifactHub-Secret", n.Webhook.Secret) - resp, err := w.httpClient.Do(req) + resp, err := w.svc.HTTPClient.Do(req) if err != nil { return err } diff --git a/internal/notification/worker_test.go b/internal/notification/worker_test.go index 008d7ab4..606e75de 100644 --- a/internal/notification/worker_test.go +++ b/internal/notification/worker_test.go @@ -88,7 +88,7 @@ func TestWorker(t *testing.T) { sw.nm.On("GetPending", sw.ctx, sw.tx).Return(nil, tests.ErrFake) sw.tx.On("Rollback", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -101,7 +101,7 @@ func TestWorker(t *testing.T) { sw.pm.On("Get", sw.ctx, gpi).Return(nil, tests.ErrFake) sw.tx.On("Rollback", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -114,7 +114,7 @@ func TestWorker(t *testing.T) { sw.rm.On("GetByID", sw.ctx, "repositoryID", false).Return(nil, tests.ErrFake) sw.tx.On("Rollback", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -129,7 +129,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n1.NotificationID, true, tests.ErrFake).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -144,7 +144,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n3.NotificationID, true, tests.ErrFake).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -159,7 +159,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n1.NotificationID, true, nil).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -174,7 +174,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n3.NotificationID, true, nil).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -187,7 +187,7 @@ func TestWorker(t *testing.T) { sw.pm.On("Get", sw.ctx, gpi).Return(nil, tests.ErrFake) sw.tx.On("Rollback", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -202,7 +202,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n2.NotificationID, true, tests.ErrFake).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -220,7 +220,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n2.NotificationID, true, mock.Anything).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -238,7 +238,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n2.NotificationID, true, nil).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "", sw.hc) + w := NewWorker(sw.svc, sw.cache, "") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -307,6 +307,7 @@ func TestWorker(t *testing.T) { defer ts.Close() sw := newServicesWrapper() + sw.svc.HTTPClient = &http.Client{} sw.db.On("Begin", sw.ctx).Return(sw.tx, nil) sw.nm.On("GetPending", sw.ctx, sw.tx).Return(&hub.Notification{ NotificationID: "notificationID", @@ -322,7 +323,7 @@ func TestWorker(t *testing.T) { sw.nm.On("UpdateStatus", sw.ctx, sw.tx, n2.NotificationID, true, nil).Return(nil) sw.tx.On("Commit", sw.ctx).Return(nil) - w := NewWorker(sw.svc, sw.cache, "http://baseURL", http.DefaultClient) + w := NewWorker(sw.svc, sw.cache, "http://baseURL") go w.Run(sw.ctx, sw.wg) sw.assertExpectations(t) }) @@ -342,7 +343,7 @@ type servicesWrapper struct { rm *repo.ManagerMock pm *pkg.ManagerMock cache *cache.Cache - hc *httpClientMock + hc *tests.HTTPClientMock svc *Services } @@ -360,7 +361,7 @@ func newServicesWrapper() *servicesWrapper { rm := &repo.ManagerMock{} pm := &pkg.ManagerMock{} cache := cache.New(1*time.Minute, 5*time.Minute) - hc := &httpClientMock{} + hc := &tests.HTTPClientMock{} return &servicesWrapper{ ctx: ctx, @@ -382,6 +383,7 @@ func newServicesWrapper() *servicesWrapper { SubscriptionManager: sm, RepositoryManager: rm, PackageManager: pm, + HTTPClient: hc, }, } } @@ -402,13 +404,3 @@ func (sw *servicesWrapper) assertExpectations(t *testing.T) { sw.pm.AssertExpectations(t) sw.hc.AssertExpectations(t) } - -type httpClientMock struct { - mock.Mock -} - -func (m *httpClientMock) Do(req *http.Request) (*http.Response, error) { - args := m.Called(req) - resp, _ := args.Get(0).(*http.Response) - return resp, args.Error(1) -} diff --git a/internal/repo/manager.go b/internal/repo/manager.go index 41d6ce95..a079d50f 100644 --- a/internal/repo/manager.go +++ b/internal/repo/manager.go @@ -13,7 +13,6 @@ import ( "path/filepath" "regexp" "strings" - "time" "github.com/artifacthub/hub/internal/hub" "github.com/artifacthub/hub/internal/util" @@ -66,39 +65,36 @@ var ( GitRepoURLRE = regexp.MustCompile(`^(https:\/\/(github|gitlab)\.com\/[A-Za-z0-9_.-]+\/[A-Za-z0-9_.-]+)\/?(.*)$`) ) -// HTTPGetter defines the methods an HTTPGetter implementation must provide. -type HTTPGetter interface { - Get(url string) (*http.Response, error) -} - // Manager provides an API to manage repositories. type Manager struct { cfg *viper.Viper db hub.DB - hg HTTPGetter + hc hub.HTTPClient rc hub.RepositoryCloner helmIndexLoader hub.HelmIndexLoader az hub.Authorizer } // NewManager creates a new Manager instance. -func NewManager(cfg *viper.Viper, db hub.DB, az hub.Authorizer, opts ...func(m *Manager)) *Manager { +func NewManager( + cfg *viper.Viper, + db hub.DB, + az hub.Authorizer, + hc hub.HTTPClient, + opts ...func(m *Manager), +) *Manager { // Setup manager m := &Manager{ cfg: cfg, db: db, helmIndexLoader: &HelmIndexLoader{}, az: az, + hc: hc, } for _, o := range opts { o(m) } - // Setup HTTP getter - if m.hg == nil { - m.hg = &http.Client{Timeout: 10 * time.Second} - } - // Setup repository cloner if m.rc == nil { m.rc = &Cloner{} @@ -392,7 +388,8 @@ func (m *Manager) readMetadataFile(mdFile string) ([]byte, error) { return nil, fmt.Errorf("error reading repository metadata file: %w", err) } } else { - resp, err := m.hg.Get(mdFile) + req, _ := http.NewRequest("GET", mdFile, nil) + resp, err := m.hc.Do(req) if err != nil { return nil, fmt.Errorf("error downloading repository metadata file: %w", err) } diff --git a/internal/repo/manager_test.go b/internal/repo/manager_test.go index 3ea446db..ad4f2127 100644 --- a/internal/repo/manager_test.go +++ b/internal/repo/manager_test.go @@ -31,7 +31,7 @@ func TestAdd(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _ = m.Add(context.Background(), "orgName", nil) }) @@ -184,7 +184,7 @@ func TestAdd(t *testing.T) { } else { l.On("LoadIndex", tc.r).Return(nil, "", nil).Maybe() } - m := NewManager(cfg, nil, nil, WithHelmIndexLoader(l)) + m := NewManager(cfg, nil, nil, nil, WithHelmIndexLoader(l)) err := m.Add(ctx, tc.orgName, tc.r) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -210,7 +210,7 @@ func TestAdd(t *testing.T) { }).Return(tests.ErrFake) l := &HelmIndexLoaderMock{} l.On("LoadIndex", r).Return(nil, "", nil) - m := NewManager(cfg, nil, az, WithHelmIndexLoader(l)) + m := NewManager(cfg, nil, az, nil, WithHelmIndexLoader(l)) err := m.Add(ctx, "orgName", r) assert.Equal(t, tests.ErrFake, err) @@ -258,7 +258,7 @@ func TestAdd(t *testing.T) { }).Return(nil) l := &HelmIndexLoaderMock{} l.On("LoadIndex", tc.r).Return(nil, "", nil) - m := NewManager(cfg, db, az, WithHelmIndexLoader(l)) + m := NewManager(cfg, db, az, nil, WithHelmIndexLoader(l)) err := m.Add(ctx, "orgName", tc.r) assert.Equal(t, tc.expectedError, err) @@ -306,7 +306,7 @@ func TestAdd(t *testing.T) { if tc.r.Kind == hub.Helm { l.On("LoadIndex", tc.r).Return(nil, "", nil) } - m := NewManager(cfg, db, az, WithHelmIndexLoader(l)) + m := NewManager(cfg, db, az, nil, WithHelmIndexLoader(l)) err := m.Add(ctx, "orgName", tc.r) assert.NoError(t, err) @@ -342,7 +342,7 @@ func TestCheckAvailability(t *testing.T) { tc := tc t.Run(tc.errMsg, func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.CheckAvailability(context.Background(), tc.resourceKind, tc.value) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) assert.Contains(t, err.Error(), tc.errMsg) @@ -374,7 +374,7 @@ func TestCheckAvailability(t *testing.T) { tc.dbQuery = fmt.Sprintf("select not exists (%s)", tc.dbQuery) db := &tests.DBMock{} db.On("QueryRow", ctx, tc.dbQuery, "value").Return(tc.available, nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) available, err := m.CheckAvailability(ctx, tc.resourceKind, "value/") assert.NoError(t, err) @@ -389,7 +389,7 @@ func TestCheckAvailability(t *testing.T) { db := &tests.DBMock{} dbQuery := fmt.Sprintf(`select not exists (%s)`, checkRepoNameAvailDBQ) db.On("QueryRow", ctx, dbQuery, "value").Return(false, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) available, err := m.CheckAvailability(ctx, "repositoryName", "value") assert.Equal(t, tests.ErrFakeDB, err) @@ -407,10 +407,12 @@ func TestClaimOwnership(t *testing.T) { opaRepoJSON := []byte(`{"kind": 2, "url": "http://repo.url"}`) ociRepoJSON := []byte(`{"kind": 0, "url": "oci://registry.io/repo/pkg"}`) ctx := context.WithValue(context.Background(), hub.UserIDKey, userID) + mdYmlReq, _ := http.NewRequest("GET", "http://repo.url/artifacthub-repo.yml", nil) + mdYamlReq, _ := http.NewRequest("GET", "http://repo.url/artifacthub-repo.yaml", nil) t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _ = m.ClaimOwnership(context.Background(), "repo1", "") }) @@ -430,7 +432,7 @@ func TestClaimOwnership(t *testing.T) { tc := tc t.Run(tc.errMsg, func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.ClaimOwnership(ctx, tc.repoName, "") assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -442,7 +444,7 @@ func TestClaimOwnership(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.ClaimOwnership(ctx, "repo1", org) assert.Equal(t, tests.ErrFakeDB, err) @@ -453,16 +455,16 @@ func TestClaimOwnership(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(helmRepoJSON, nil) - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://repo.url/artifacthub-repo.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("")), StatusCode: http.StatusNotFound, }, nil) - hg.On("Get", "http://repo.url/artifacthub-repo.yaml").Return(&http.Response{ + hc.On("Do", mdYamlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("")), StatusCode: http.StatusNotFound, }, nil) - m := NewManager(cfg, db, nil, withHTTPGetter(hg)) + m := NewManager(cfg, db, nil, hc) err := m.ClaimOwnership(ctx, "repo1", org) assert.Error(t, err) @@ -473,7 +475,7 @@ func TestClaimOwnership(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(ociRepoJSON, nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.ClaimOwnership(ctx, "repo1", org) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -485,13 +487,13 @@ func TestClaimOwnership(t *testing.T) { db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(helmRepoJSON, nil) db.On("QueryRow", ctx, getUserEmailDBQ, userID).Return("", tests.ErrFakeDB) - hg := &tests.HTTPGetterMock{} mdFile, _ := os.Open("testdata/artifacthub-repo.yml") - hg.On("Get", "http://repo.url/artifacthub-repo.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: mdFile, StatusCode: http.StatusOK, }, nil) - m := NewManager(cfg, db, nil, withHTTPGetter(hg)) + m := NewManager(cfg, db, nil, hc) err := m.ClaimOwnership(ctx, "repo1", org) assert.Equal(t, tests.ErrFakeDB, err) @@ -503,13 +505,13 @@ func TestClaimOwnership(t *testing.T) { db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(helmRepoJSON, nil) db.On("QueryRow", ctx, getUserEmailDBQ, userID).Return("user1@email.com", nil) - hg := &tests.HTTPGetterMock{} mdFile, _ := os.Open("testdata/artifacthub-repo.yml") - hg.On("Get", "http://repo.url/artifacthub-repo.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: mdFile, StatusCode: http.StatusOK, }, nil) - m := NewManager(cfg, db, nil, withHTTPGetter(hg)) + m := NewManager(cfg, db, nil, hc) err := m.ClaimOwnership(ctx, "repo1", org) assert.Equal(t, hub.ErrInsufficientPrivilege, err) @@ -522,13 +524,13 @@ func TestClaimOwnership(t *testing.T) { db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(helmRepoJSON, nil) db.On("QueryRow", ctx, getUserEmailDBQ, userID).Return("owner1@email.com", nil) db.On("Exec", ctx, transferRepoDBQ, "repo1", userIDP, orgP, true).Return(nil) - hg := &tests.HTTPGetterMock{} mdFile, _ := os.Open("testdata/artifacthub-repo.yml") - hg.On("Get", "http://repo.url/artifacthub-repo.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: mdFile, StatusCode: http.StatusOK, }, nil) - m := NewManager(cfg, db, nil, withHTTPGetter(hg)) + m := NewManager(cfg, db, nil, hc) err := m.ClaimOwnership(ctx, "repo1", org) assert.Nil(t, err) @@ -543,7 +545,7 @@ func TestClaimOwnership(t *testing.T) { var r *hub.Repository _ = json.Unmarshal(opaRepoJSON, &r) rc.On("CloneRepository", ctx, r).Return("", "", tests.ErrFake) - m := NewManager(cfg, db, nil, withRepositoryCloner(rc)) + m := NewManager(cfg, db, nil, nil, withRepositoryCloner(rc)) err := m.ClaimOwnership(ctx, "repo1", org) assert.Equal(t, tests.ErrFake, err) @@ -560,7 +562,7 @@ func TestClaimOwnership(t *testing.T) { var r *hub.Repository _ = json.Unmarshal(opaRepoJSON, &r) rc.On("CloneRepository", ctx, r).Return(".", "testdata", nil) - m := NewManager(cfg, db, nil, withRepositoryCloner(rc)) + m := NewManager(cfg, db, nil, nil, withRepositoryCloner(rc)) err := m.ClaimOwnership(ctx, "repo1", org) assert.Nil(t, err) @@ -573,7 +575,7 @@ func TestDelete(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _ = m.Delete(context.Background(), "repo1") }) @@ -581,7 +583,7 @@ func TestDelete(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.Delete(ctx, "") assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -602,7 +604,7 @@ func TestDelete(t *testing.T) { UserID: "userID", Action: hub.DeleteOrganizationRepository, }).Return(tests.ErrFake) - m := NewManager(cfg, db, az) + m := NewManager(cfg, db, az, nil) err := m.Delete(ctx, "repo1") assert.Equal(t, tests.ErrFake, err) @@ -642,7 +644,7 @@ func TestDelete(t *testing.T) { UserID: "userID", Action: hub.DeleteOrganizationRepository, }).Return(nil) - m := NewManager(cfg, db, az) + m := NewManager(cfg, db, az, nil) err := m.Delete(ctx, "repo1") assert.Equal(t, tc.expectedError, err) @@ -663,7 +665,7 @@ func TestDelete(t *testing.T) { } `), nil) db.On("Exec", ctx, deleteRepoDBQ, "userID", "repo1").Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.Delete(ctx, "repo1") assert.NoError(t, err) @@ -703,7 +705,7 @@ func TestGetAll(t *testing.T) { "official": true }] `), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetAll(ctx, false) require.NoError(t, err) @@ -739,7 +741,7 @@ func TestGetAllJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getAllReposDBQ, false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetAllJSON(ctx, false) assert.Equal(t, tests.ErrFakeDB, err) @@ -751,7 +753,7 @@ func TestGetAllJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getAllReposDBQ, false).Return([]byte("dataJSON"), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetAllJSON(ctx, false) assert.NoError(t, err) @@ -781,7 +783,7 @@ func TestGetByID(t *testing.T) { tc := tc t.Run(tc.errStr, func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetByID(ctx, tc.repoID, false) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -794,7 +796,7 @@ func TestGetByID(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByIDDBQ, repoID, false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByID(context.Background(), repoID, false) assert.Equal(t, tests.ErrFakeDB, err) @@ -816,7 +818,7 @@ func TestGetByID(t *testing.T) { "official": true } `), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByID(context.Background(), repoID, false) require.NoError(t, err) @@ -855,7 +857,7 @@ func TestGetByKind(t *testing.T) { "official": true }] `), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByKind(ctx, hub.Helm, false) require.NoError(t, err) @@ -884,7 +886,7 @@ func TestGetByKindJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getReposByKindDBQ, hub.OLM, false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetByKindJSON(ctx, hub.OLM, false) assert.Equal(t, tests.ErrFakeDB, err) @@ -896,7 +898,7 @@ func TestGetByKindJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getReposByKindDBQ, hub.OLM, false).Return([]byte("dataJSON"), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetByKindJSON(ctx, hub.OLM, false) assert.NoError(t, err) @@ -910,7 +912,7 @@ func TestGetByName(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetByName(ctx, "", false) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -931,7 +933,7 @@ func TestGetByName(t *testing.T) { "organization_name": "" } `), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByName(context.Background(), "repo1", false) require.NoError(t, err) @@ -949,7 +951,7 @@ func TestGetByName(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByName(context.Background(), "repo1", false) assert.Equal(t, tests.ErrFakeDB, err) @@ -961,7 +963,7 @@ func TestGetByName(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getRepoByNameDBQ, "repo1", false).Return([]byte("invalid json"), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) r, err := m.GetByName(context.Background(), "repo1", false) assert.Error(t, err) @@ -971,9 +973,14 @@ func TestGetByName(t *testing.T) { } func TestGetMetadata(t *testing.T) { + mdYmlReq, _ := http.NewRequest("GET", "http://url.test/ok.yml", nil) + mdYamlReq, _ := http.NewRequest("GET", "http://url.test/ok.yaml", nil) + mdNotFoundYmlReq, _ := http.NewRequest("GET", "http://url.test/not-found.yml", nil) + mdNotFoundYamlReq, _ := http.NewRequest("GET", "http://url.test/not-found.yaml", nil) + t.Run("local file: error reading repository metadata file", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetMetadata("testdata/not-exists.yaml") assert.Error(t, err) assert.Contains(t, err.Error(), "error reading repository metadata file") @@ -981,10 +988,10 @@ func TestGetMetadata(t *testing.T) { t.Run("remote file: error downloading repository metadata file", func(t *testing.T) { t.Parallel() - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://url.test/not-found.yml").Return(nil, tests.ErrFake) - hg.On("Get", "http://url.test/not-found.yaml").Return(nil, tests.ErrFake) - m := NewManager(cfg, nil, nil, withHTTPGetter(hg)) + hc := &tests.HTTPClientMock{} + hc.On("Do", mdNotFoundYmlReq).Return(nil, tests.ErrFake) + hc.On("Do", mdNotFoundYamlReq).Return(nil, tests.ErrFake) + m := NewManager(cfg, nil, nil, hc) _, err := m.GetMetadata("http://url.test/not-found") assert.Error(t, err) assert.Contains(t, err.Error(), "error downloading repository metadata file") @@ -992,16 +999,16 @@ func TestGetMetadata(t *testing.T) { t.Run("remote file: unexpected status code received", func(t *testing.T) { t.Parallel() - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://url.test/not-found.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdNotFoundYmlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("")), StatusCode: http.StatusNotFound, }, nil) - hg.On("Get", "http://url.test/not-found.yaml").Return(&http.Response{ + hc.On("Do", mdNotFoundYamlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("")), StatusCode: http.StatusNotFound, }, nil) - m := NewManager(cfg, nil, nil, withHTTPGetter(hg)) + m := NewManager(cfg, nil, nil, hc) _, err := m.GetMetadata("http://url.test/not-found") assert.Error(t, err) assert.Contains(t, err.Error(), "unexpected status code received") @@ -1009,16 +1016,16 @@ func TestGetMetadata(t *testing.T) { t.Run("remote file: error reading repository metadata file", func(t *testing.T) { t.Parallel() - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://url.test/not-found.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdNotFoundYmlReq).Return(&http.Response{ Body: ioutil.NopCloser(tests.ErrReader(0)), StatusCode: http.StatusOK, }, nil) - hg.On("Get", "http://url.test/not-found.yaml").Return(&http.Response{ + hc.On("Do", mdNotFoundYamlReq).Return(&http.Response{ Body: ioutil.NopCloser(tests.ErrReader(0)), StatusCode: http.StatusOK, }, nil) - m := NewManager(cfg, nil, nil, withHTTPGetter(hg)) + m := NewManager(cfg, nil, nil, hc) _, err := m.GetMetadata("http://url.test/not-found") assert.Error(t, err) assert.Contains(t, err.Error(), "error reading repository metadata file") @@ -1026,7 +1033,7 @@ func TestGetMetadata(t *testing.T) { t.Run("error unmarshaling repository metadata file", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetMetadata("testdata/invalid") assert.Error(t, err) assert.Contains(t, err.Error(), "error unmarshaling repository metadata file") @@ -1034,7 +1041,7 @@ func TestGetMetadata(t *testing.T) { t.Run("invalid repository id", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetMetadata("testdata/invalid-repo-id") assert.Error(t, err) assert.Contains(t, err.Error(), "invalid repository id") @@ -1042,42 +1049,42 @@ func TestGetMetadata(t *testing.T) { t.Run("local file: success fetching .yml", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetMetadata("testdata/artifacthub-repo") assert.NoError(t, err) }) t.Run("local file: success .yaml", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetMetadata("testdata/test-yaml-repo") assert.NoError(t, err) }) t.Run("remote file: success", func(t *testing.T) { t.Parallel() - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://url.test/ok.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("repositoryID: 00000000-0000-0000-0000-000000000001")), StatusCode: http.StatusOK, }, nil) - m := NewManager(cfg, nil, nil, withHTTPGetter(hg)) + m := NewManager(cfg, nil, nil, hc) _, err := m.GetMetadata("http://url.test/ok") assert.NoError(t, err) }) t.Run("remote file: success on yaml", func(t *testing.T) { t.Parallel() - hg := &tests.HTTPGetterMock{} - hg.On("Get", "http://url.test/ok.yaml").Return(&http.Response{ - Body: ioutil.NopCloser(strings.NewReader("repositoryID: 00000000-0000-0000-0000-000000000001")), - StatusCode: http.StatusOK, - }, nil) - hg.On("Get", "http://url.test/ok.yml").Return(&http.Response{ + hc := &tests.HTTPClientMock{} + hc.On("Do", mdYmlReq).Return(&http.Response{ Body: ioutil.NopCloser(strings.NewReader("")), StatusCode: http.StatusNotFound, }, nil) - m := NewManager(cfg, nil, nil, withHTTPGetter(hg)) + hc.On("Do", mdYamlReq).Return(&http.Response{ + Body: ioutil.NopCloser(strings.NewReader("repositoryID: 00000000-0000-0000-0000-000000000001")), + StatusCode: http.StatusOK, + }, nil) + m := NewManager(cfg, nil, nil, hc) _, err := m.GetMetadata("http://url.test/ok") assert.NoError(t, err) }) @@ -1088,7 +1095,7 @@ func TestGetPackagesDigest(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetPackagesDigest(context.Background(), "invalid") assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -1104,7 +1111,7 @@ func TestGetPackagesDigest(t *testing.T) { "package2@0.0.9": "digest-package2-0.0.9" } `), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) pd, err := m.GetPackagesDigest(ctx, "00000000-0000-0000-0000-000000000001") require.NoError(t, err) @@ -1122,7 +1129,7 @@ func TestGetOwnedByOrgJSON(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _, _ = m.GetOwnedByOrgJSON(context.Background(), "orgName", false) }) @@ -1130,7 +1137,7 @@ func TestGetOwnedByOrgJSON(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) _, err := m.GetOwnedByOrgJSON(ctx, "", false) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -1139,7 +1146,7 @@ func TestGetOwnedByOrgJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getOrgReposDBQ, "userID", "orgName", false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetOwnedByOrgJSON(ctx, "orgName", false) assert.Equal(t, tests.ErrFakeDB, err) @@ -1151,7 +1158,7 @@ func TestGetOwnedByOrgJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getOrgReposDBQ, "userID", "orgName", false).Return([]byte("dataJSON"), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetOwnedByOrgJSON(ctx, "orgName", false) assert.NoError(t, err) @@ -1165,7 +1172,7 @@ func TestGetOwnedByUserJSON(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _, _ = m.GetOwnedByUserJSON(context.Background(), false) }) @@ -1175,7 +1182,7 @@ func TestGetOwnedByUserJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getUserReposDBQ, "userID", false).Return(nil, tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetOwnedByUserJSON(ctx, false) assert.Equal(t, tests.ErrFakeDB, err) @@ -1187,7 +1194,7 @@ func TestGetOwnedByUserJSON(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("QueryRow", ctx, getUserReposDBQ, "userID", false).Return([]byte("dataJSON"), nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) dataJSON, err := m.GetOwnedByUserJSON(ctx, false) assert.NoError(t, err) @@ -1208,7 +1215,7 @@ func TestGetRemoteDigest(t *testing.T) { t.Parallel() l := &HelmIndexLoaderMock{} l.On("LoadIndex", helmHTTP).Return(nil, "", tests.ErrFake) - m := NewManager(cfg, nil, nil, WithHelmIndexLoader(l)) + m := NewManager(cfg, nil, nil, nil, WithHelmIndexLoader(l)) digest, err := m.GetRemoteDigest(ctx, helmHTTP) assert.Empty(t, digest) @@ -1219,7 +1226,7 @@ func TestGetRemoteDigest(t *testing.T) { t.Parallel() l := &HelmIndexLoaderMock{} l.On("LoadIndex", helmHTTP).Return(nil, "digest", nil) - m := NewManager(cfg, nil, nil, WithHelmIndexLoader(l)) + m := NewManager(cfg, nil, nil, nil, WithHelmIndexLoader(l)) digest, err := m.GetRemoteDigest(ctx, helmHTTP) assert.Equal(t, "digest", digest) @@ -1232,7 +1239,7 @@ func TestSetLastScanningResults(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.SetLastScanningResults(ctx, "invalid", "errors") assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -1241,7 +1248,7 @@ func TestSetLastScanningResults(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setLastScanningResultsDBQ, repoID, "errors", false).Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetLastScanningResults(ctx, repoID, "errors") assert.NoError(t, err) @@ -1252,7 +1259,7 @@ func TestSetLastScanningResults(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setLastScanningResultsDBQ, repoID, "errors", false).Return(tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetLastScanningResults(ctx, repoID, "errors") assert.Equal(t, tests.ErrFakeDB, err) @@ -1265,7 +1272,7 @@ func TestSetLastTrackingResults(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.SetLastTrackingResults(ctx, "invalid", "errors") assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -1274,7 +1281,7 @@ func TestSetLastTrackingResults(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setLastTrackingResultsDBQ, repoID, "errors", false).Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetLastTrackingResults(ctx, repoID, "errors") assert.NoError(t, err) @@ -1285,7 +1292,7 @@ func TestSetLastTrackingResults(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setLastTrackingResultsDBQ, repoID, "errors", false).Return(tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetLastTrackingResults(ctx, repoID, "errors") assert.Equal(t, tests.ErrFakeDB, err) @@ -1298,7 +1305,7 @@ func TestSetVerifiedPublisher(t *testing.T) { t.Run("invalid input", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.SetVerifiedPublisher(ctx, "invalid", true) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) }) @@ -1307,7 +1314,7 @@ func TestSetVerifiedPublisher(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setVerifiedPublisherDBQ, repoID, true).Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetVerifiedPublisher(ctx, repoID, true) assert.NoError(t, err) @@ -1318,7 +1325,7 @@ func TestSetVerifiedPublisher(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, setVerifiedPublisherDBQ, repoID, true).Return(tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.SetVerifiedPublisher(ctx, repoID, true) assert.Equal(t, tests.ErrFakeDB, err) @@ -1335,7 +1342,7 @@ func TestTransfer(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _ = m.Transfer(context.Background(), "repo1", "", false) }) @@ -1355,7 +1362,7 @@ func TestTransfer(t *testing.T) { tc := tc t.Run(tc.errMsg, func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) err := m.Transfer(ctx, tc.repoName, "", false) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -1379,7 +1386,7 @@ func TestTransfer(t *testing.T) { UserID: "userID", Action: hub.TransferOrganizationRepository, }).Return(tests.ErrFake) - m := NewManager(cfg, db, az) + m := NewManager(cfg, db, az, nil) err := m.Transfer(ctx, "repo1", "orgDest", false) assert.Equal(t, tests.ErrFake, err) @@ -1419,7 +1426,7 @@ func TestTransfer(t *testing.T) { UserID: "userID", Action: hub.TransferOrganizationRepository, }).Return(nil) - m := NewManager(cfg, db, az) + m := NewManager(cfg, db, az, nil) err := m.Transfer(ctx, "repo1", org, false) assert.Equal(t, tc.expectedError, err) @@ -1440,7 +1447,7 @@ func TestTransfer(t *testing.T) { } `), nil) db.On("Exec", ctx, transferRepoDBQ, "repo1", userIDP, orgP, false).Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.Transfer(ctx, "repo1", org, false) assert.NoError(t, err) @@ -1453,7 +1460,7 @@ func TestUpdate(t *testing.T) { t.Run("user id not found in ctx", func(t *testing.T) { t.Parallel() - m := NewManager(cfg, nil, nil) + m := NewManager(cfg, nil, nil, nil) assert.Panics(t, func() { _ = m.Update(context.Background(), nil) }) @@ -1547,7 +1554,7 @@ func TestUpdate(t *testing.T) { } else { l.On("LoadIndex", tc.r).Return(nil, "", nil).Maybe() } - m := NewManager(cfg, nil, nil, WithHelmIndexLoader(l)) + m := NewManager(cfg, nil, nil, nil, WithHelmIndexLoader(l)) err := m.Update(ctx, tc.r) assert.True(t, errors.Is(err, hub.ErrInvalidInput)) @@ -1581,7 +1588,7 @@ func TestUpdate(t *testing.T) { }).Return(tests.ErrFake) l := &HelmIndexLoaderMock{} l.On("LoadIndex", r).Return(nil, "", nil) - m := NewManager(cfg, db, az, WithHelmIndexLoader(l)) + m := NewManager(cfg, db, az, nil, WithHelmIndexLoader(l)) err := m.Update(ctx, r) assert.Equal(t, tests.ErrFake, err) @@ -1638,7 +1645,7 @@ func TestUpdate(t *testing.T) { l := &HelmIndexLoaderMock{} l.On("LoadIndex", tc.r).Return(nil, "", nil) - m := NewManager(cfg, db, az, WithHelmIndexLoader(l)) + m := NewManager(cfg, db, az, nil, WithHelmIndexLoader(l)) err := m.Update(ctx, tc.r) assert.Equal(t, tc.expectedError, err) @@ -1682,7 +1689,7 @@ func TestUpdate(t *testing.T) { if tc.r.Kind == hub.Helm { l.On("LoadIndex", tc.r).Return(nil, "", nil) } - m := NewManager(cfg, db, nil, WithHelmIndexLoader(l)) + m := NewManager(cfg, db, nil, nil, WithHelmIndexLoader(l)) err := m.Update(ctx, tc.r) assert.NoError(t, err) @@ -1702,7 +1709,7 @@ func TestUpdateDigest(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, updateRepoDigestDBQ, repositoryID, digest).Return(nil) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.UpdateDigest(ctx, repositoryID, digest) assert.NoError(t, err) @@ -1713,7 +1720,7 @@ func TestUpdateDigest(t *testing.T) { t.Parallel() db := &tests.DBMock{} db.On("Exec", ctx, updateRepoDigestDBQ, repositoryID, digest).Return(tests.ErrFakeDB) - m := NewManager(cfg, db, nil) + m := NewManager(cfg, db, nil, nil) err := m.UpdateDigest(ctx, repositoryID, digest) assert.Equal(t, tests.ErrFakeDB, err) @@ -1721,12 +1728,6 @@ func TestUpdateDigest(t *testing.T) { }) } -func withHTTPGetter(hg HTTPGetter) func(m *Manager) { - return func(m *Manager) { - m.hg = hg - } -} - func withRepositoryCloner(rc hub.RepositoryCloner) func(m *Manager) { return func(m *Manager) { m.rc = rc diff --git a/internal/tests/http.go b/internal/tests/http.go index 7aff0318..b5c3e4d7 100644 --- a/internal/tests/http.go +++ b/internal/tests/http.go @@ -6,18 +6,6 @@ import ( "github.com/stretchr/testify/mock" ) -// HTTPGetterMock is a mock HTTPGetter implementation. -type HTTPGetterMock struct { - mock.Mock -} - -// Get implements the HTTPGetter interface. -func (m *HTTPGetterMock) Get(url string) (*http.Response, error) { - args := m.Called(url) - resp, _ := args.Get(0).(*http.Response) - return resp, args.Error(1) -} - // HTTPClientMock is a mock HTTPClient implementation. type HTTPClientMock struct { mock.Mock diff --git a/internal/util/http.go b/internal/util/http.go new file mode 100644 index 00000000..dcc4a8de --- /dev/null +++ b/internal/util/http.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "syscall" + "time" + + "github.com/artifacthub/hub/internal/hub" +) + +var ( + // ErrRestrictedConnection error indicates that connections to the provided + // address are restricted. + ErrRestrictedConnection = errors.New("restricted connection") +) + +// SetupHTTPClient is a helper that returns an http client. If restricted is +// set to true, the http client won't be able to make requests to a set of +// restricted addresses. +func SetupHTTPClient(restricted bool) hub.HTTPClient { + if restricted { + return setupRestrictedHTTPClient() + } + return &http.Client{ + Timeout: 10 * time.Second, + } +} + +// setupRestrictedHTTPClient returns an http client that is not allowed to make +// requests to a set of restricted addresses. +func setupRestrictedHTTPClient() hub.HTTPClient { + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + DualStack: true, + Control: checkRestrictions, + } + transport := &http.Transport{ + DialContext: dialer.DialContext, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, &tls.Config{ + MinVersion: tls.VersionTLS12, + }) + }, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + IdleConnTimeout: 90 * time.Second, + MaxIdleConns: 100, + Proxy: http.ProxyFromEnvironment, + TLSHandshakeTimeout: 10 * time.Second, + } + return &http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } +} + +// checkRestrictions checks if a connection to the provided network and address +// should be restricted. +func checkRestrictions(network string, address string, conn syscall.RawConn) error { + if !(network == "tcp4" || network == "tcp6") { + return ErrRestrictedConnection + } + host, _, err := net.SplitHostPort(address) + if err != nil { + return ErrRestrictedConnection + } + ip := net.ParseIP(host) + if ip == nil { + return ErrRestrictedConnection + } + if !ip.IsGlobalUnicast() || isPrivate(ip) { // TODO: use ip.IsPrivate() when available + return ErrRestrictedConnection + } + return nil +} + +// isPrivate reports whether ip is a private address, according to +// RFC 1918 (IPv4 addresses) and RFC 4193 (IPv6 addresses). +// +// Source: https://github.com/golang/go/blob/5963f0a332496a68f1eb2d0c6a5badd73c9f046d/src/net/ip.go#L131-L148 +func isPrivate(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + // Following RFC 1918, Section 3. Private Address Space which says: + // The Internet Assigned Numbers Authority (IANA) has reserved the + // following three blocks of the IP address space for private internets: + // 10.0.0.0 - 10.255.255.255 (10/8 prefix) + // 172.16.0.0 - 172.31.255.255 (172.16/12 prefix) + // 192.168.0.0 - 192.168.255.255 (192.168/16 prefix) + return ip4[0] == 10 || + (ip4[0] == 172 && ip4[1]&0xf0 == 16) || + (ip4[0] == 192 && ip4[1] == 168) + } + // Following RFC 4193, Section 8. IANA Considerations which says: + // The IANA has assigned the FC00::/7 prefix to "Unique Local Unicast". + return len(ip) == net.IPv6len && ip[0]&0xfe == 0xfc +}