diff --git a/scaler/handlers.go b/scaler/handlers.go index c819fea..7c2587a 100644 --- a/scaler/handlers.go +++ b/scaler/handlers.go @@ -73,7 +73,9 @@ func (e *impl) IsActive( ) if !ok { err := fmt.Errorf("host '%s' not found in counts", host) - allCounts := mergeCountsWithRoutingTable(e.pinger.counts(), e.routingTable) + allCounts := e.pinger.mergeCountsWithRoutingTable( + e.routingTable, + ) lggr.Error(err, "Given host was not found in queue count map", "host", host, "allCounts", allCounts) return nil, err } @@ -173,7 +175,7 @@ func (e *impl) GetMetrics( hostCount = e.pinger.aggregate() } else { err := fmt.Errorf("host '%s' not found in counts", host) - allCounts := mergeCountsWithRoutingTable(e.pinger.counts(), e.routingTable) + allCounts := e.pinger.mergeCountsWithRoutingTable(e.routingTable) lggr.Error(err, "allCounts", allCounts) return nil, err } diff --git a/scaler/host_counts.go b/scaler/host_counts.go index b7059e0..0a88748 100644 --- a/scaler/host_counts.go +++ b/scaler/host_counts.go @@ -4,22 +4,6 @@ import ( "github.com/kedacore/http-add-on/pkg/routing" ) -// mergeCountsWithRoutingTable ensures that all hosts in routing table -// are present in combined counts, if count is not present value is set to 0 -func mergeCountsWithRoutingTable( - counts map[string]int, - table routing.TableReader, -) map[string]int { - mergedCounts := make(map[string]int) - for _, host := range table.Hosts() { - mergedCounts[host] = 0 - } - for key, value := range counts { - mergedCounts[key] = value - } - return mergedCounts -} - // getHostCount gets proper count for given host regardless whether // host is in counts or only in routerTable func getHostCount( diff --git a/scaler/host_counts_test.go b/scaler/host_counts_test.go index 72bceda..f5e1643 100644 --- a/scaler/host_counts_test.go +++ b/scaler/host_counts_test.go @@ -14,79 +14,66 @@ type testCase struct { retCounts map[string]int } -var cases = []testCase{ - { - name: "empty queue", - table: newRoutingTable([]hostAndTarget{ - { - host: "www.example.com", - target: routing.Target{}, +func cases() []testCase { + return []testCase{ + { + name: "empty queue", + table: newRoutingTable([]hostAndTarget{ + { + host: "www.example.com", + target: routing.Target{}, + }, + { + host: "www.example2.com", + target: routing.Target{}, + }, + }), + counts: make(map[string]int), + retCounts: map[string]int{ + "www.example.com": 0, + "www.example2.com": 0, }, - { - host: "www.example2.com", - target: routing.Target{}, + }, + { + name: "one entry in queue, same entry in routing table", + table: newRoutingTable([]hostAndTarget{ + { + host: "example.com", + target: routing.Target{}, + }, + }), + counts: map[string]int{ + "example.com": 1, }, - }), - counts: make(map[string]int), - retCounts: map[string]int{ - "www.example.com": 0, - "www.example2.com": 0, - }, - }, - { - name: "one entry in queue, same entry in routing table", - table: newRoutingTable([]hostAndTarget{ - { - host: "example.com", - target: routing.Target{}, + retCounts: map[string]int{ + "example.com": 1, }, - }), - counts: map[string]int{ - "example.com": 1, }, - retCounts: map[string]int{ - "example.com": 1, - }, - }, - { - name: "one entry in queue, two in routing table", - table: newRoutingTable([]hostAndTarget{ - { - host: "example.com", - target: routing.Target{}, + { + name: "one entry in queue, two in routing table", + table: newRoutingTable([]hostAndTarget{ + { + host: "example.com", + target: routing.Target{}, + }, + { + host: "example2.com", + target: routing.Target{}, + }, + }), + counts: map[string]int{ + "example.com": 1, }, - { - host: "example2.com", - target: routing.Target{}, + retCounts: map[string]int{ + "example.com": 1, + "example2.com": 0, }, - }), - counts: map[string]int{ - "example.com": 1, }, - retCounts: map[string]int{ - "example.com": 1, - "example2.com": 0, - }, - }, -} - -func TestMergeCountsWithRoutingTable(t *testing.T) { - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - r := require.New(t) - ret := mergeCountsWithRoutingTable( - tc.counts, - tc.table, - ) - r.Equal(tc.retCounts, ret) - }) } + } - func TestGetHostCount(t *testing.T) { - - for _, tc := range cases { + for _, tc := range cases() { for host, retCount := range tc.retCounts { t.Run(tc.name, func(t *testing.T) { r := require.New(t) diff --git a/scaler/queue_pinger.go b/scaler/queue_pinger.go index 9606394..8a1980a 100644 --- a/scaler/queue_pinger.go +++ b/scaler/queue_pinger.go @@ -11,6 +11,7 @@ import ( "github.com/go-logr/logr" "github.com/kedacore/http-add-on/pkg/k8s" "github.com/kedacore/http-add-on/pkg/queue" + "github.com/kedacore/http-add-on/pkg/routing" "github.com/pkg/errors" "golang.org/x/sync/errgroup" ) @@ -124,6 +125,23 @@ func (q *queuePinger) counts() map[string]int { return q.allCounts } +// mergeCountsWithRoutingTable ensures that all hosts in routing table +// are present in combined counts, if count is not present value is set to 0 +func (q *queuePinger) mergeCountsWithRoutingTable( + table routing.TableReader, +) map[string]int { + q.pingMut.RLock() + defer q.pingMut.RUnlock() + mergedCounts := make(map[string]int) + for _, host := range table.Hosts() { + mergedCounts[host] = 0 + } + for key, value := range q.allCounts { + mergedCounts[key] = value + } + return mergedCounts +} + func (q *queuePinger) aggregate() int { q.pingMut.RLock() defer q.pingMut.RUnlock() diff --git a/scaler/queue_pinger_test.go b/scaler/queue_pinger_test.go index 7ec4b96..2e9582c 100644 --- a/scaler/queue_pinger_test.go +++ b/scaler/queue_pinger_test.go @@ -9,6 +9,7 @@ import ( "github.com/kedacore/http-add-on/pkg/k8s" "github.com/kedacore/http-add-on/pkg/queue" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" v1 "k8s.io/api/core/v1" ) @@ -215,3 +216,48 @@ func TestFetchCounts(t *testing.T) { } r.Equal(expectedCounts, cts) } + +func TestMergeCountsWithRoutingTable(t *testing.T) { + for _, tc := range cases() { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + grp, ctx := errgroup.WithContext(ctx) + r := require.New(t) + const C = 100 + tickr, q, err := newFakeQueuePinger( + ctx, + logr.Discard(), + ) + r.NoError(err) + defer tickr.Stop() + q.allCounts = tc.counts + + retCh := make(chan map[string]int) + for i := 0; i < C; i++ { + grp.Go(func() error { + retCh <- q.mergeCountsWithRoutingTable(tc.table) + return nil + }) + } + + // ensure we receive from retCh C times + allRets := map[int]map[string]int{} + for i := 0; i < C; i++ { + allRets[i] = <-retCh + } + + r.NoError(grp.Wait()) + + // ensure that all returned maps are the + // same + prev := allRets[0] + for i := 1; i < C; i++ { + r.Equal(prev, allRets[i]) + prev = allRets[i] + } + // ensure that all the returned maps are + // equal to what we expected + r.Equal(tc.retCounts, prev) + }) + } +}