http-add-on/scaler/queue_pinger.go

148 lines
3.1 KiB
Go

// This file contains the queuePinger, which periodically polls the HTTP
// request queues of the interceptors for their pending-request counts. It is
// used by the KEDA external scaler implementation.
package main
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-logr/logr"
"github.com/kedacore/http-add-on/pkg/k8s"
"github.com/kedacore/http-add-on/pkg/queue"
"golang.org/x/sync/errgroup"
)
// queuePinger fetches the pending-request counts from every interceptor
// endpoint of a Kubernetes service and caches the per-host and aggregate
// totals for readers of counts() and aggregate().
type queuePinger struct {
	// getEndpointsFn is passed to k8s.EndpointsForService to look up the
	// endpoints of svcName in namespace ns.
	getEndpointsFn k8s.GetEndpointsFunc
	ns             string
	svcName        string
	// adminPort is passed to k8s.EndpointsForService when building the
	// endpoint URLs that are queried for counts (presumably the interceptors'
	// admin port — confirm).
	adminPort string
	// pingMut guards lastPingTime, allCounts and aggregateCount.
	pingMut *sync.RWMutex
	// lastPingTime records when the cached counts were last replaced.
	lastPingTime time.Time
	// allCounts maps each host to its pending-request count, summed across
	// all interceptor endpoints.
	allCounts map[string]int
	// aggregateCount is the sum of every per-host count in allCounts.
	aggregateCount int
	lggr           logr.Logger
}
// newQueuePinger creates a queuePinger for the interceptor endpoints of
// svcName in namespace ns (queried on adminPort). It starts a background
// goroutine that refreshes the cached counts on every tick of pingTicker.
//
// The goroutine exits, and pingTicker is stopped, once ctx is canceled.
// Errors from individual refreshes are logged and polling continues on the
// next tick.
func newQueuePinger(
	ctx context.Context,
	lggr logr.Logger,
	getEndpointsFn k8s.GetEndpointsFunc,
	ns,
	svcName,
	adminPort string,
	pingTicker *time.Ticker,
) *queuePinger {
	pinger := &queuePinger{
		getEndpointsFn: getEndpointsFn,
		ns:             ns,
		svcName:        svcName,
		adminPort:      adminPort,
		pingMut:        new(sync.RWMutex),
		lggr:           lggr,
		allCounts:      map[string]int{},
	}
	go func() {
		defer pingTicker.Stop()
		// A ticker's channel is never closed (not even by Stop), so ranging
		// over it would never terminate. Also watch ctx so the goroutine and
		// the ticker are released on shutdown.
		for {
			select {
			case <-ctx.Done():
				return
			case <-pingTicker.C:
				if err := pinger.requestCounts(ctx); err != nil {
					lggr.Error(err, "getting request counts")
				}
			}
		}
	}()
	return pinger
}
// counts returns a snapshot of the most recently fetched per-host request
// counts. The map is copied while holding the read lock. Callers can
// therefore read or modify it without racing against other callers or
// against later refreshes of the cached state.
func (q *queuePinger) counts() map[string]int {
	q.pingMut.RLock()
	defer q.pingMut.RUnlock()
	snapshot := make(map[string]int, len(q.allCounts))
	for host, count := range q.allCounts {
		snapshot[host] = count
	}
	return snapshot
}
// aggregate reports the total pending-request count across all hosts, as
// recorded by the most recent successful refresh. The value is read under
// the shared read lock.
func (q *queuePinger) aggregate() int {
	q.pingMut.RLock()
	total := q.aggregateCount
	q.pingMut.RUnlock()
	return total
}
// requestCounts fetches the pending-request counts from every interceptor
// endpoint of q.svcName in q.ns. It sums them per host and in aggregate,
// then stores the totals on q.
//
// All endpoints are queried concurrently. If any fetch fails, the remaining
// in-flight fetches are canceled and the first error is returned. In that
// case the previously cached counts are left untouched rather than being
// replaced with partial totals. On success, the cached counts have been
// updated by the time requestCounts returns.
func (q *queuePinger) requestCounts(ctx context.Context) error {
	lggr := q.lggr.WithName("queuePinger.requestCounts")

	endpointURLs, err := k8s.EndpointsForService(
		ctx,
		q.ns,
		q.svcName,
		q.adminPort,
		q.getEndpointsFn,
	)
	if err != nil {
		return err
	}

	// One buffered slot per endpoint lets every fetcher deliver its result
	// without a concurrent reader. We can then wait for the whole group
	// before looking at any of the results.
	countsCh := make(chan *queue.Counts, len(endpointURLs))
	// Use the group's context so the first failure cancels the other
	// in-flight requests instead of letting them run to completion.
	fetchGrp, fetchCtx := errgroup.WithContext(ctx)
	for _, endpoint := range endpointURLs {
		u := endpoint // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		fetchGrp.Go(func() error {
			counts, err := queue.GetCounts(
				fetchCtx,
				lggr,
				http.DefaultClient,
				*u,
			)
			if err != nil {
				lggr.Error(
					err,
					"getting queue counts from interceptor",
					"interceptorAddress",
					u.String(),
				)
				return err
			}
			countsCh <- counts
			return nil
		})
	}
	if err := fetchGrp.Wait(); err != nil {
		lggr.Error(err, "fetching all counts failed")
		return err
	}
	// Every sender has returned, so the channel can be closed and drained.
	close(countsCh)

	agg := 0
	totalCounts := make(map[string]int)
	// Each endpoint reports one count per host. Sum them per host and overall.
	for counts := range countsCh {
		for host, val := range counts.Counts {
			agg += val
			totalCounts[host] += val
		}
	}

	q.pingMut.Lock()
	defer q.pingMut.Unlock()
	q.allCounts = totalCounts
	q.aggregateCount = agg
	q.lastPingTime = time.Now()
	return nil
}