discourse-auth-proxy/internal/httpproxy/backend.go

266 lines
7.7 KiB
Go

package httpproxy
import (
"context"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
)
// director is a request-rewriting func with the same signature as the
// Director field of httputil.ReverseProxy (see initDirector).
type director func(*http.Request)
// directorID is the key type of the memoised directors map. Values are built
// from an SRV record's "target:port" string by newDirectorIDFromSRV.
type directorID string
// newDirectorIDFromSRV derives the memoisation key for srv: the record's
// "target:port" address, as rendered by srvToTCPAddr.
func newDirectorIDFromSRV(srv *net.SRV) directorID {
	addr := srvToTCPAddr(srv)
	return directorID(addr)
}
// DNSSRVBackend bestows httputil.ReverseProxy with the ability to forward a
// request to backends discovered via DNS SRV records. This type automatically
// degrades to a conventional 'single host' reverse proxy if DNS SRV records are
// not in use.
//
// Initialise with NewDNSSRVBackend, call the Lookup method, then assign a
// reference to the Director method in a ReverseProxy value. Refer to the
// documentation for the Lookup and Director methods for quirks.
//
// The DNSSRVBackend does not currently offer its own SRV-aware transport for
// SRV-compatible retries; this omission will almost certainly lead to pain and
// suffering if your DNS SRV records are subject to constant churn. Client-side
// retries may be used to mitigate these effects.
type DNSSRVBackend struct {
// target is the URL passed to NewDNSSRVBackend. Its host component is the
// name queried for SRV records, and the URL is the template from which
// per-record directors are built.
target *url.URL
// fallbackDirector proxies straight to target. findDirector returns it
// whenever no SRV records are cached.
fallbackDirector director
// srvsLock guards srvs (see getSrvs and setSrvs).
srvsLock sync.RWMutex
// srvs holds the top-priority records stored by the most recent
// successful lookup.
srvs []*net.SRV
// directors is a map of memoised director funcs. Used to optimise the
// common case of many requests to a mostly static set of backends.
directors map[directorID]director
}
// NewDNSSRVBackend constructs a DNSSRVBackend for target. The host in target
// will eventually be queried for DNS SRV records. When no records are found,
// or when the host is an IPv4/6 literal, the Director method proxies directly
// to the literal value of target and all SRV functionality is bypassed. Path
// and query semantics match those of httputil.NewSingleHostReverseProxy.
//
// SRV records have no influence on the Director method until the Lookup
// method has been called.
func NewDNSSRVBackend(target *url.URL) *DNSSRVBackend {
	backend := new(DNSSRVBackend)
	backend.target = target
	backend.fallbackDirector = initDirector(target)
	backend.directors = map[directorID]director{}
	return backend
}
// Lookup continuously polls the DNS for new target SRV records. One lookup is
// made shortly upon entry into this function, then every interval thereafter.
// A random value between zero and jitter is applied to each interval. The
// function will return if no SRV records are found after abandon, which may be
// set to zero in order to keep the function alive. Once any lookup succeeds
// the abandon deadline is dropped for good and polling continues until ctx is
// cancelled.
//
// Lookup returns immediately, without querying the DNS, when the target host
// is an IP literal.
//
// Results from the most recent successful lookup are used by the Director
// function. DNS TTLs are not honoured. This quirk is considered to be a bug,
// and may be fixed at a later date.
//
// Because the SRV priority field only makes sense within the context of
// retries, and because the DNSSRVBackend does not currently offer any mechanism
// that implements SRV-compatible retries, the Lookup method discards all but
// the most-preferred targets by priority. In other words, given the following
// set of SRV RRs:
//
//	_foobar._tcp    SRV 0 1 9 old-slow-box.example.com.
//	                SRV 0 3 9 new-fast-box.example.com.
//	                SRV 1 0 9 sysadmins-box.example.com.
//	                SRV 1 0 9 server.example.com.
//
// Lookup will discard 'sysadmins-box' and 'server'. Only 'old-slow-box' and
// 'new-fast-box' will be cached for later use. (The priority field need not be
// zero, nor need it be constant over time.)
//
// ctx may be used to cancel the execution of this function. The caller is
// expected to call this function in its own goroutine.
func (b *DNSSRVBackend) Lookup(ctx context.Context, interval, jitter, abandon time.Duration) {
	if _, resolvable := b.name(); !resolvable {
		return
	}

	// canAbandon stays true only while an abandon deadline is in force: it
	// starts false when abandon is zero and is cleared by the first
	// successful lookup.
	var (
		canAbandon = abandon != 0
		deadline   time.Time
	)
	if canAbandon {
		deadline = time.Now().Add(abandon)
	}

	for {
		if err := b.lookupOne(ctx); err == nil {
			// SRV records are in use. Hang around for future updates.
			canAbandon = false
		}
		if canAbandon && !time.Now().Before(deadline) {
			return
		}

		// An explicit timer (rather than time.After) lets us release it
		// promptly when ctx is cancelled mid-wait.
		timer := time.NewTimer(interval + randomJitter(jitter))
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
// Director is suitable for use as the Director of a httputil.ReverseProxy.
// If an SRV record has been found in the DNS, the host and port of req are
// replaced with values taken from that record. When several SRV records are
// available, each request is assigned a backend in keeping with SRV weights.
// No form of backend affinity is implemented.
//
// SRV lookups are handled outside the critical path: see the Lookup method.
func (b *DNSSRVBackend) Director(req *http.Request) {
	b.findDirector()(req)
}
// lookupOne issues a single SRV query for the target host. On success the
// most-preferred records (by priority) replace the cached set; on failure the
// cache is left untouched and the resolver's error is returned as-is.
func (b *DNSSRVBackend) lookupOne(ctx context.Context) error {
	host, _ := b.name()
	// Service and proto are left empty, so host is passed to the resolver
	// as the complete query name.
	_, records, err := net.DefaultResolver.LookupSRV(ctx, "", "", host)
	if err == nil {
		b.setSrvs(srvFilterTopPriority(records))
	}
	return err
}
// findDirector picks the director for one request. With no cached SRV records
// it returns the fallback director; otherwise it selects a record by weight
// and returns the memoised director for that record, creating and caching one
// on first use.
//
// Director may be invoked from many goroutines at once, so every access to the
// directors map is performed under srvsLock. Unsynchronised concurrent map
// reads and writes are a data race.
func (b *DNSSRVBackend) findDirector() director {
	srvs := b.getSrvs()
	if len(srvs) == 0 {
		return b.fallbackDirector
	}

	// getSrvs hands back a private copy, so shuffling in place cannot
	// disturb the cached slice.
	srvShuffleByWeight(srvs)
	srv := srvs[0]
	id := newDirectorIDFromSRV(srv)

	// Fast path: the director has already been built.
	b.srvsLock.RLock()
	d, ok := b.directors[id]
	b.srvsLock.RUnlock()
	if ok {
		return d
	}

	// Slow path: build outside the lock, then publish. Another goroutine
	// may have won the race in the meantime; if so, prefer its value so
	// that all callers share a single director per backend.
	d = initDirectorFromSRV(b.target, srv)
	b.srvsLock.Lock()
	if cached, found := b.directors[id]; found {
		d = cached
	} else {
		b.directors[id] = d
	}
	b.srvsLock.Unlock()
	return d
}
// getSrvs returns a copy of the cached SRV records, taken under the read
// lock. The result is nil when nothing has been cached yet. The copy is
// shallow: the slice is private to the caller, the records it points at are
// shared.
func (b *DNSSRVBackend) getSrvs() []*net.SRV {
	b.srvsLock.RLock()
	defer b.srvsLock.RUnlock()

	if b.srvs == nil {
		return nil
	}
	snapshot := make([]*net.SRV, len(b.srvs))
	copy(snapshot, b.srvs)
	return snapshot
}
// setSrvs replaces the cached SRV records with srvs while holding the write
// lock. The slice is stored as given, not copied.
func (b *DNSSRVBackend) setSrvs(srvs []*net.SRV) {
	b.srvsLock.Lock()
	defer b.srvsLock.Unlock()
	b.srvs = srvs
}
// name reports the host component of b.target, together with a flag saying
// whether that host is likely to be a resolvable name.
//
// Matching against the record naming patterns of RFC 2782 (further elaborated
// on in RFC 6763) would be possible, but being permissive is simpler here:
// anything that does not parse as an IPv4/6 address is treated as resolvable.
func (b *DNSSRVBackend) name() (string, bool) {
	hostname := b.target.Hostname()
	isIPLiteral := net.ParseIP(hostname) != nil
	return hostname, !isIPLiteral
}
// initDirector returns the Director func of a single-host reverse proxy
// built for target by httputil.NewSingleHostReverseProxy.
func initDirector(target *url.URL) director {
	return httputil.NewSingleHostReverseProxy(target).Director
}
// initDirectorFromSRV builds a director for a copy of target whose Host has
// been replaced with the "target:port" address of srv. The caller's URL is
// left unmodified.
func initDirectorFromSRV(target *url.URL, srv *net.SRV) director {
	backendURL := new(url.URL)
	*backendURL = *target
	backendURL.Host = srvToTCPAddr(srv)
	return initDirector(backendURL)
}
// srvToTCPAddr renders srv as "target:port", with the port in decimal. The
// target is used verbatim.
func srvToTCPAddr(srv *net.SRV) string {
	return srv.Target + ":" + fmt.Sprint(srv.Port)
}
// srvFilterTopPriority returns the leading run of records in srvs that share
// the priority of the first record, i.e. the most-preferred set whatever that
// priority value happens to be. Relative ordering within the set is kept.
// srvs is expected to be pre-sorted by priority (the net resolver sorts SRV
// RRs by priority automatically). The result is always a fresh, non-nil
// slice; srvs itself is not modified.
func srvFilterTopPriority(srvs []*net.SRV) []*net.SRV {
	if len(srvs) == 0 {
		return []*net.SRV{}
	}

	// Find the end of the leading run that matches the first priority.
	best := srvs[0].Priority
	end := 1
	for end < len(srvs) && srvs[end].Priority == best {
		end++
	}

	kept := make([]*net.SRV, end, len(srvs))
	copy(kept, srvs[:end])
	return kept
}
// srvShuffleByWeight reorders srvs in place using the weighted selection
// algorithm described in RFC 2782. All records are assumed to share the same
// priority. Nothing is moved when the weights sum to zero.
//
// This implementation was shamelessly clagged from the Go standard library.
func srvShuffleByWeight(srvs []*net.SRV) {
	total := 0
	for i := range srvs {
		total += int(srvs[i].Weight)
	}

	// Each pass picks one record from the not-yet-placed tail, with
	// probability proportional to its weight, and moves it to the front of
	// that tail.
	remaining := srvs
	for len(remaining) > 1 && total > 0 {
		running := 0
		threshold := rand.Intn(total)
		for idx, candidate := range remaining {
			running += int(candidate.Weight)
			if running <= threshold {
				continue
			}
			if idx != 0 {
				remaining[0], remaining[idx] = candidate, remaining[0]
			}
			break
		}
		total -= int(remaining[0].Weight)
		remaining = remaining[1:]
	}
}
// randomJitter returns a pseudo-random duration in the half-open range
// [0, max). A zero or negative max yields zero: rand.Int63n panics on
// non-positive arguments, and a bad jitter value must not be able to take
// down the polling goroutine.
func randomJitter(max time.Duration) time.Duration {
	if max <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(max)))
}