51 lines
1.2 KiB
Go
51 lines
1.2 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
nethttp "net/http"
|
|
|
|
"github.com/go-logr/logr"
|
|
"github.com/kedacore/http-add-on/pkg/queue"
|
|
)
|
|
|
|
// getHost reports which host r is addressed to.
//
// An explicit "Host" entry in r.Header wins; otherwise r.Host is used.
// For requests received by a net/http server the Host header has already
// been promoted into r.Host (and removed from r.Header), so in that case
// the r.Host branch is the one that normally supplies the value; the
// header lookup matters for requests constructed by hand.
//
// It returns an error when neither source holds a non-empty host.
func getHost(r *nethttp.Request) (string, error) {
	switch hdr := r.Header.Get("Host"); {
	case hdr != "":
		return hdr, nil
	case r.Host != "":
		return r.Host, nil
	default:
		return "", fmt.Errorf("host not found")
	}
}
|
|
|
|
// countMiddleware adds 1 to the given queue counter, executes next
|
|
// (by calling ServeHTTP on it), then decrements the queue counter
|
|
func countMiddleware(
|
|
lggr logr.Logger,
|
|
q queue.Counter,
|
|
next nethttp.Handler,
|
|
) nethttp.Handler {
|
|
return nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) {
|
|
host, err := getHost(r)
|
|
if err != nil {
|
|
lggr.Error(err, "not forwarding request")
|
|
w.WriteHeader(400)
|
|
w.Write([]byte("Host not found, not forwarding request"))
|
|
return
|
|
}
|
|
if err := q.Resize(host, +1); err != nil {
|
|
log.Printf("Error incrementing queue for %q (%s)", r.RequestURI, err)
|
|
}
|
|
defer func() {
|
|
if err := q.Resize(host, -1); err != nil {
|
|
log.Printf("Error decrementing queue for %q (%s)", r.RequestURI, err)
|
|
}
|
|
}()
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|