boulder/bdns/dns_test.go

892 lines
30 KiB
Go

package bdns
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"regexp"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/jmhodges/clock"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/metrics"
"github.com/letsencrypt/boulder/test"
)
// dnsLoopbackAddr is the host:port the mock DoH server listens on and the
// address the resolvers under test are configured to query.
const dnsLoopbackAddr = "127.0.0.1:4053"
// mockDNSQuery is the HTTP handler behind the mock DNS-over-HTTPS server used
// by the tests in this package. It unpacks the DNS message carried in the
// request body and answers every question from a fixed set of canned records
// and response codes, keyed on query type and lowercased query name.
//
// A request with the wrong Content-Type or Accept header, an unreadable body,
// or a body that is not a valid DNS message gets a 400 response and nothing
// else. A reply that cannot be packed gets a 500.
func mockDNSQuery(w http.ResponseWriter, httpReq *http.Request) {
	defer httpReq.Body.Close()

	// reject answers a malformed request with a 400 and a plain-text reason.
	// Callers must return afterwards so no DNS payload is appended to it.
	reject := func(format string, args ...any) {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, format, args...)
	}

	if httpReq.Header.Get("Content-Type") != "application/dns-message" {
		reject("client didn't send Content-Type: application/dns-message")
		return
	}
	if httpReq.Header.Get("Accept") != "application/dns-message" {
		reject("client didn't accept Content-Type: application/dns-message")
		return
	}

	requestBody, err := io.ReadAll(httpReq.Body)
	if err != nil {
		reject("reading body: %s", err)
		return
	}

	r := new(dns.Msg)
	err = r.Unpack(requestBody)
	if err != nil {
		reject("unpacking request: %s", err)
		return
	}

	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	// hdr builds an IN-class record header.
	hdr := func(name string, rrtype uint16, ttl uint32) dns.RR_Header {
		return dns.RR_Header{Name: name, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
	}
	appendAnswer := func(rr dns.RR) {
		m.Answer = append(m.Answer, rr)
	}
	// soa builds the letsencrypt.org SOA record, used both as an SOA answer
	// and as the authority record for TXT queries that have no answer.
	soa := func() *dns.SOA {
		return &dns.SOA{
			Hdr:     hdr("letsencrypt.org.", dns.TypeSOA, 0),
			Ns:      "ns.letsencrypt.org.",
			Mbox:    "master.letsencrypt.org.",
			Serial:  1,
			Refresh: 1,
			Retry:   1,
			Expire:  1,
			Minttl:  1,
		}
	}

	for _, q := range r.Question {
		name := strings.ToLower(q.Name)

		// Names unpacked from a DNS message are fully qualified, so both
		// comparisons need the trailing dot to be able to match.
		if name == "servfail.com." || name == "servfailexception.example.com." {
			m.Rcode = dns.RcodeServerFailure
			break
		}

		switch q.Qtype {
		case dns.TypeSOA:
			appendAnswer(soa())

		case dns.TypeAAAA:
			switch name {
			case "v6.letsencrypt.org.", "dualstack.letsencrypt.org.", "v4error.letsencrypt.org.":
				appendAnswer(&dns.AAAA{
					Hdr:  hdr(name, dns.TypeAAAA, 0),
					AAAA: net.ParseIP("2602:80a:6000:abad:cafe::1"),
				})
			case "v6error.letsencrypt.org.", "dualstackerror.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNotImplemented)
			case "nxdomain.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNameError)
			}

		case dns.TypeA:
			switch name {
			case "cps.letsencrypt.org.", "dualstack.letsencrypt.org.":
				appendAnswer(&dns.A{
					Hdr: hdr(name, dns.TypeA, 0),
					A:   net.ParseIP("64.112.117.1"),
				})
			case "v6error.letsencrypt.org.":
				// NOTE(review): the owner name here is dualstack, not
				// v6error. It looks like a copy/paste leftover; kept as-is
				// until confirmed that no test relies on it.
				appendAnswer(&dns.A{
					Hdr: hdr("dualstack.letsencrypt.org.", dns.TypeA, 0),
					A:   net.ParseIP("64.112.117.1"),
				})
			case "v4error.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNotImplemented)
			case "nxdomain.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNameError)
			case "dualstackerror.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeRefused)
			}

		case dns.TypeCNAME:
			switch name {
			case "cname.letsencrypt.org.":
				appendAnswer(&dns.CNAME{
					Hdr:    hdr(name, dns.TypeCNAME, 30),
					Target: "cps.letsencrypt.org.",
				})
			case "cname.example.com.":
				appendAnswer(&dns.CNAME{
					Hdr:    hdr(name, dns.TypeCNAME, 30),
					Target: "CAA.example.com.",
				})
			}

		case dns.TypeDNAME:
			if name == "dname.letsencrypt.org." {
				appendAnswer(&dns.DNAME{
					Hdr:    hdr(name, dns.TypeDNAME, 30),
					Target: "cps.letsencrypt.org.",
				})
			}

		case dns.TypeCAA:
			switch name {
			case "bracewel.net.", "caa.example.com.":
				appendAnswer(&dns.CAA{
					Hdr:   hdr(name, dns.TypeCAA, 0),
					Flag:  1,
					Tag:   "issue",
					Value: "letsencrypt.org",
				})
			case "cname.example.com.":
				// The answer is owned by the CNAME target, not the name
				// that was asked for.
				appendAnswer(&dns.CAA{
					Hdr:   hdr("caa.example.com.", dns.TypeCAA, 0),
					Flag:  1,
					Tag:   "issue",
					Value: "letsencrypt.org",
				})
			case "gonetld.":
				m.SetRcode(r, dns.RcodeNameError)
			}

		case dns.TypeTXT:
			if name == "split-txt.letsencrypt.org." {
				appendAnswer(&dns.TXT{
					Hdr: hdr(name, dns.TypeTXT, 0),
					Txt: []string{"a", "b", "c"},
				})
			} else {
				m.Ns = append(m.Ns, soa())
			}
			// nxdomain gets the authority record above and the rcode.
			if name == "nxdomain.letsencrypt.org." {
				m.SetRcode(r, dns.RcodeNameError)
			}
		}
	}

	body, err := m.Pack()
	if err != nil {
		fmt.Fprintf(os.Stderr, "packing reply: %s\n", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/dns-message")
	_, err = w.Write(body)
	if err != nil {
		panic(err) // running tests, so panic is OK
	}
}
// serveLoopResolver starts the mock DoH server on dnsLoopbackAddr, serving
// mockDNSQuery at /dns-query over TLS with the test localhost certificate.
// It returns immediately: the server runs in one goroutine, and a second
// goroutine shuts it down once a value is received on stopChan.
func serveLoopResolver(stopChan chan bool) {
	mux := http.NewServeMux()
	mux.HandleFunc("/dns-query", mockDNSQuery)
	httpServer := &http.Server{
		Addr:         dnsLoopbackAddr,
		Handler:      mux,
		ReadTimeout:  time.Second,
		WriteTimeout: time.Second,
	}
	go func() {
		cert := "../test/certs/ipki/localhost/cert.pem"
		key := "../test/certs/ipki/localhost/key.pem"
		err := httpServer.ListenAndServeTLS(cert, key)
		// ListenAndServeTLS always returns a non-nil error. ErrServerClosed
		// only means Shutdown was called, so it is not worth reporting.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			fmt.Println(err)
		}
	}()
	go func() {
		<-stopChan
		err := httpServer.Shutdown(context.Background())
		if err != nil {
			log.Fatal(err)
		}
	}()
}
func pollServer() {
backoff := 200 * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ticker := time.NewTicker(backoff)
for {
select {
case <-ctx.Done():
fmt.Fprintln(os.Stderr, "Timeout reached while testing for the dns server to come up")
os.Exit(1)
case <-ticker.C:
conn, _ := dns.DialTimeout("udp", dnsLoopbackAddr, backoff)
if conn != nil {
_ = conn.Close()
return
}
}
}
}
// tlsConfig is the TLS configuration handed to client instances that talk to
// the DoH server set up in TestMain. TestMain assigns it before running any
// test.
var tlsConfig *tls.Config
// TestMain prepares the shared fixtures for this package's tests: it loads
// the test root CA into tlsConfig, starts the mock DoH server, waits for it
// to come up, runs the tests, and then signals the server to stop before
// exiting with the test result code.
func TestMain(m *testing.M) {
	const rootPath = "../test/certs/ipki/minica.pem"

	root, err := os.ReadFile(rootPath)
	if err != nil {
		log.Fatal(err)
	}
	pool := x509.NewCertPool()
	// An empty pool would make every TLS handshake with the mock server
	// fail, so stop here with a message that names the actual problem.
	if !pool.AppendCertsFromPEM(root) {
		log.Fatalf("no certificates could be parsed from %s", rootPath)
	}
	tlsConfig = &tls.Config{
		RootCAs: pool,
	}

	stop := make(chan bool, 1)
	serveLoopResolver(stop)
	pollServer()

	ret := m.Run()
	stop <- true
	os.Exit(ret)
}
// TestDNSNoServers checks that host, TXT and CAA lookups all return an error
// when the client is built from a provider with no server addresses, and
// that the host lookup reports no resolvers.
func TestDNSNoServers(t *testing.T) {
	provider, err := NewStaticProvider([]string{})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(time.Hour, provider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	ctx := context.Background()

	_, usedResolvers, hostErr := client.LookupHost(ctx, "letsencrypt.org")
	test.AssertEquals(t, len(usedResolvers), 0)
	test.AssertError(t, hostErr, "No servers")

	_, _, txtErr := client.LookupTXT(ctx, "letsencrypt.org")
	test.AssertError(t, txtErr, "No servers")

	_, _, _, caaErr := client.LookupCAA(ctx, "letsencrypt.org")
	test.AssertError(t, caaErr, "No servers")
}
// TestDNSOneServer checks that a host lookup through a single configured
// server succeeds and reports that server once for the A query and once for
// the AAAA query.
func TestDNSOneServer(t *testing.T) {
	provider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(10*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)

	_, usedResolvers, lookupErr := client.LookupHost(context.Background(), "cps.letsencrypt.org")
	test.AssertEquals(t, len(usedResolvers), 2)

	want := ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}
	slices.Sort(usedResolvers)
	test.AssertDeepEquals(t, usedResolvers, want)
	test.AssertNotError(t, lookupErr, "No message")
}
func TestDNSDuplicateServers(t *testing.T) {
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
_, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org")
test.AssertEquals(t, len(resolvers), 2)
slices.Sort(resolvers)
test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
test.AssertNotError(t, err, "No message")
}
// TestDNSServFail checks that TXT, host and CAA lookups each return an error
// for a name the mock server answers with SERVFAIL, and that the CAA lookup
// returns no records alongside its error.
func TestDNSServFail(t *testing.T) {
	provider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(10*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	ctx := context.Background()
	failingName := "servfail.com"

	_, _, err = client.LookupTXT(ctx, failingName)
	test.AssertError(t, err, "LookupTXT didn't return an error")

	_, _, err = client.LookupHost(ctx, failingName)
	test.AssertError(t, err, "LookupHost didn't return an error")

	caaRecords, _, _, err := client.LookupCAA(ctx, failingName)
	test.Assert(t, len(caaRecords) == 0, "Query returned non-empty list of CAA records")
	test.AssertError(t, err, "LookupCAA should have returned an error")
}
// TestDNSLookupTXT checks that a TXT lookup with no answer records succeeds,
// and that a TXT record the mock server returns as the three strings "a",
// "b", "c" comes back as the single joined value "abc".
func TestDNSLookupTXT(t *testing.T) {
	provider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(10*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	ctx := context.Background()

	records, _, err := client.LookupTXT(ctx, "letsencrypt.org")
	t.Logf("A: %v", records)
	test.AssertNotError(t, err, "No message")

	records, _, err = client.LookupTXT(ctx, "split-txt.letsencrypt.org")
	t.Logf("A: %v ", records)
	test.AssertNotError(t, err, "No message")
	test.AssertEquals(t, len(records), 1)
	test.AssertEquals(t, records[0], "abc")
}
// TestDNSLookupHost exercises LookupHost against the mock DoH server for
// names that fail outright, names with only an A or only an AAAA record, a
// dual-stack name, and names where one or both of the A/AAAA queries return
// an error code. Every lookup is expected to report both the A and the AAAA
// resolver.
//
// TODO(#8213): Convert this to a table test.
func TestDNSLookupHost(t *testing.T) {
	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	ctx := context.Background()

	// assertBothResolvers checks that the loopback server was reported for
	// both record types. It sorts first so the comparison does not depend on
	// the order in which the two are reported.
	assertBothResolvers := func(got ResolverAddrs) {
		t.Helper()
		slices.Sort(got)
		test.AssertDeepEquals(t, got, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
	}

	ip, resolvers, err := obj.LookupHost(ctx, "servfail.com")
	t.Logf("servfail.com - IP: %s, Err: %v", ip, err)
	test.AssertError(t, err, "Server failure")
	test.Assert(t, len(ip) == 0, "Should not have IPs")
	assertBothResolvers(resolvers)

	ip, resolvers, err = obj.LookupHost(ctx, "nonexistent.letsencrypt.org")
	t.Logf("nonexistent.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertError(t, err, "No valid A or AAAA records should error")
	test.Assert(t, len(ip) == 0, "Should not have IPs")
	assertBothResolvers(resolvers)

	// Single IPv4 address
	ip, resolvers, err = obj.LookupHost(ctx, "cps.letsencrypt.org")
	t.Logf("cps.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 1, "Should have IP")
	assertBothResolvers(resolvers)

	// Same lookup a second time; the result must not change.
	ip, resolvers, err = obj.LookupHost(ctx, "cps.letsencrypt.org")
	t.Logf("cps.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 1, "Should have IP")
	assertBothResolvers(resolvers)

	// Single IPv6 address
	ip, resolvers, err = obj.LookupHost(ctx, "v6.letsencrypt.org")
	t.Logf("v6.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 1, "Should have 1 IP")
	assertBothResolvers(resolvers)

	// Both IPv6 and IPv4 address
	ip, resolvers, err = obj.LookupHost(ctx, "dualstack.letsencrypt.org")
	t.Logf("dualstack.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 2, "Should have 2 IPs")
	expected := netip.MustParseAddr("64.112.117.1")
	test.Assert(t, ip[0] == expected, "wrong ipv4 address")
	expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1")
	test.Assert(t, ip[1] == expected, "wrong ipv6 address")
	assertBothResolvers(resolvers)

	// IPv6 error, IPv4 success
	ip, resolvers, err = obj.LookupHost(ctx, "v6error.letsencrypt.org")
	t.Logf("v6error.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 1, "Should have 1 IP")
	expected = netip.MustParseAddr("64.112.117.1")
	test.Assert(t, ip[0] == expected, "wrong ipv4 address")
	assertBothResolvers(resolvers)

	// IPv6 success, IPv4 error
	ip, resolvers, err = obj.LookupHost(ctx, "v4error.letsencrypt.org")
	t.Logf("v4error.letsencrypt.org - IP: %s, Err: %v", ip, err)
	test.AssertNotError(t, err, "Not an error to exist")
	test.Assert(t, len(ip) == 1, "Should have 1 IP")
	expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1")
	test.Assert(t, ip[0] == expected, "wrong ipv6 address")
	assertBothResolvers(resolvers)

	// IPv6 error, IPv4 error
	// Should return both the IPv4 error (Refused) and the IPv6 error (NotImplemented)
	hostname := "dualstackerror.letsencrypt.org"
	ip, resolvers, err = obj.LookupHost(ctx, hostname)
	t.Logf("%s - IP: %s, Err: %v", hostname, ip, err)
	test.AssertError(t, err, "Should be an error")
	test.AssertContains(t, err.Error(), "REFUSED looking up A for")
	test.AssertContains(t, err.Error(), "NOTIMP looking up AAAA for")
	assertBothResolvers(resolvers)
}
// TestDNSNXDOMAIN checks how NXDOMAIN responses are reported: a host lookup
// must fail with an error mentioning NXDOMAIN for both the A and the AAAA
// query, and a TXT lookup must return an Error value carrying the
// RcodeNameError code.
func TestDNSNXDOMAIN(t *testing.T) {
	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	hostname := "nxdomain.letsencrypt.org"

	_, _, err = obj.LookupHost(context.Background(), hostname)
	// Fail cleanly rather than dereferencing a nil error below if the
	// lookup unexpectedly succeeds.
	if err == nil {
		t.Fatalf("LookupHost(%q) succeeded, expected an NXDOMAIN error", hostname)
	}
	test.AssertContains(t, err.Error(), "NXDOMAIN looking up A for")
	test.AssertContains(t, err.Error(), "NXDOMAIN looking up AAAA for")

	_, _, err = obj.LookupTXT(context.Background(), hostname)
	expected := Error{dns.TypeTXT, hostname, nil, dns.RcodeNameError, nil}
	test.AssertDeepEquals(t, err, expected)
}
// TestDNSLookupCAA exercises LookupCAA against the mock DoH server: a name
// with a CAA record, two names without one (one of which is NXDOMAIN for
// other query types), a name whose CAA answer is owned by a CNAME target,
// and a TLD that returns NXDOMAIN. Where records are found, the textual
// response is compared after masking the message id.
func TestDNSLookupCAA(t *testing.T) {
	provider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(10*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig)
	ctx := context.Background()

	// The message id differs per query, so replace it with a fixed
	// placeholder before comparing response text.
	idPattern := regexp.MustCompile(" id: [[:digit:]]+")
	maskID := func(response string) string {
		return idPattern.ReplaceAllString(response, " id: XXXX")
	}

	records, response, usedResolvers, err := client.LookupCAA(ctx, "bracewel.net")
	test.AssertNotError(t, err, "CAA lookup failed")
	test.Assert(t, len(records) > 0, "Should have CAA records")
	test.AssertEquals(t, len(usedResolvers), 1)
	test.AssertDeepEquals(t, usedResolvers, ResolverAddrs{"127.0.0.1:4053"})
	wantResponse := `;; opcode: QUERY, status: NOERROR, id: XXXX
;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
;; QUESTION SECTION:
;bracewel.net. IN CAA
;; ANSWER SECTION:
bracewel.net. 0 IN CAA 1 issue "letsencrypt.org"
`
	test.AssertEquals(t, maskID(response), wantResponse)

	records, response, usedResolvers, err = client.LookupCAA(ctx, "nonexistent.letsencrypt.org")
	test.AssertNotError(t, err, "CAA lookup failed")
	test.Assert(t, len(records) == 0, "Shouldn't have CAA records")
	test.AssertEquals(t, usedResolvers[0], "127.0.0.1:4053")
	wantResponse = ""
	test.AssertEquals(t, response, wantResponse)

	records, response, usedResolvers, err = client.LookupCAA(ctx, "nxdomain.letsencrypt.org")
	slices.Sort(usedResolvers)
	test.AssertNotError(t, err, "CAA lookup failed")
	test.Assert(t, len(records) == 0, "Shouldn't have CAA records")
	test.AssertEquals(t, usedResolvers[0], "127.0.0.1:4053")
	wantResponse = ""
	test.AssertEquals(t, response, wantResponse)

	records, response, usedResolvers, err = client.LookupCAA(ctx, "cname.example.com")
	test.AssertNotError(t, err, "CAA lookup failed")
	test.Assert(t, len(records) > 0, "Should follow CNAME to find CAA")
	test.AssertEquals(t, usedResolvers[0], "127.0.0.1:4053")
	wantResponse = `;; opcode: QUERY, status: NOERROR, id: XXXX
;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
;; QUESTION SECTION:
;cname.example.com. IN CAA
;; ANSWER SECTION:
caa.example.com. 0 IN CAA 1 issue "letsencrypt.org"
`
	test.AssertEquals(t, maskID(response), wantResponse)

	_, _, usedResolvers, err = client.LookupCAA(ctx, "gonetld")
	test.AssertError(t, err, "should fail for TLD NXDOMAIN")
	test.AssertContains(t, err.Error(), "NXDOMAIN")
	test.AssertEquals(t, usedResolvers[0], "127.0.0.1:4053")
}
// testExchanger is a scripted stand-in for the DNS client. Each call to
// Exchange consumes the next entry of errs and returns it as the call's
// error; count records how many entries have been consumed.
type testExchanger struct {
	sync.Mutex
	count int
	errs  []error
}

// errTooManyRequests is returned by testExchanger.Exchange once every entry
// in its script has been consumed.
var errTooManyRequests = errors.New("too many requests")

// Exchange returns a bare successful reply together with the next scripted
// error. When the script is exhausted it returns errTooManyRequests and
// leaves count unchanged.
func (te *testExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
	te.Lock()
	defer te.Unlock()

	if te.count >= len(te.errs) {
		return nil, 0, errTooManyRequests
	}

	scripted := te.errs[te.count]
	te.count++

	reply := &dns.Msg{
		MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess},
	}
	return reply, 2 * time.Millisecond, scripted
}
// TestRetry drives LookupTXT with a scripted testExchanger and checks, for
// each script, the text of the resulting error, the number of scripted
// responses consumed, and (where set) the "out of retries" timeout counter.
// It then checks that an already-cancelled context and already-expired
// contexts produce the corresponding timeout errors and counter values.
func TestRetry(t *testing.T) {
	isTempErr := &url.Error{Op: "read", Err: tempError(true)}
	nonTempErr := &url.Error{Op: "read", Err: tempError(false)}
	servFailError := errors.New("DNS problem: server failure at resolver looking up TXT for example.com")

	type testCase struct {
		name              string
		maxTries          int
		te                *testExchanger
		expected          error
		expectedCount     int
		metricsAllRetries float64
	}
	tests := []*testCase{
		// The success on first try case
		{
			name:          "success",
			maxTries:      3,
			te:            &testExchanger{errs: []error{nil}},
			expected:      nil,
			expectedCount: 1,
		},
		// Immediate non-OpError, error returns immediately
		{
			name:          "non-operror",
			maxTries:      3,
			te:            &testExchanger{errs: []error{errors.New("nope")}},
			expected:      servFailError,
			expectedCount: 1,
		},
		// Temporary err, then non-OpError stops at two tries
		{
			name:          "err-then-non-operror",
			maxTries:      3,
			te:            &testExchanger{errs: []error{isTempErr, errors.New("nope")}},
			expected:      servFailError,
			expectedCount: 2,
		},
		// Temporary error given always
		{
			name:              "persistent-temp-error",
			maxTries:          3,
			te:                &testExchanger{errs: []error{isTempErr, isTempErr, isTempErr}},
			expected:          servFailError,
			expectedCount:     3,
			metricsAllRetries: 1,
		},
		// Even with maxTries at 0, we should still let a single request go
		// through
		{
			name:          "zero-maxtries",
			maxTries:      0,
			te:            &testExchanger{errs: []error{nil}},
			expected:      nil,
			expectedCount: 1,
		},
		// Temporary error given just once causes two tries
		{
			name:          "single-temp-error",
			maxTries:      3,
			te:            &testExchanger{errs: []error{isTempErr, nil}},
			expected:      nil,
			expectedCount: 2,
		},
		// Temporary error given twice causes three tries
		{
			name:          "double-temp-error",
			maxTries:      3,
			te:            &testExchanger{errs: []error{isTempErr, isTempErr, nil}},
			expected:      nil,
			expectedCount: 3,
		},
		// Temporary error given thrice causes three tries and fails
		{
			name:              "triple-temp-error",
			maxTries:          3,
			te:                &testExchanger{errs: []error{isTempErr, isTempErr, isTempErr}},
			expected:          servFailError,
			expectedCount:     3,
			metricsAllRetries: 1,
		},
		// temporary then non-Temporary error causes two retries
		{
			name:          "temp-nontemp-error",
			maxTries:      3,
			te:            &testExchanger{errs: []error{isTempErr, nonTempErr}},
			expected:      servFailError,
			expectedCount: 2,
		},
	}

	for i, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
			test.AssertNotError(t, err, "Got error creating StaticProvider")

			testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, "", blog.UseMock(), tlsConfig)
			dr := testClient.(*impl)
			dr.dnsClient = tc.te
			_, _, err = dr.LookupTXT(context.Background(), "example.com")

			// errors.Is also catches the sentinel when the lookup wraps
			// it, which a plain == comparison would miss.
			if errors.Is(err, errTooManyRequests) {
				t.Errorf("#%d, sent more requests than the test case handles", i)
			}

			// Errors are compared by message: both must be nil, or both
			// non-nil with identical text.
			mismatch := (tc.expected == nil) != (err == nil) ||
				(tc.expected != nil && tc.expected.Error() != err.Error())
			if mismatch {
				t.Errorf("#%d, error, expected %v, got %v", i, tc.expected, err)
			}

			if tc.expectedCount != tc.te.count {
				t.Errorf("#%d, error, expectedCount %v, got %v", i, tc.expectedCount, tc.te.count)
			}
			if tc.metricsAllRetries > 0 {
				test.AssertMetricWithLabelsEquals(
					t, dr.timeoutCounter, prometheus.Labels{
						"qtype":    "TXT",
						"type":     "out of retries",
						"resolver": "127.0.0.1",
						"isTLD":    "false",
					}, tc.metricsAllRetries)
			}
		})
	}

	const (
		canceledMsg = "DNS problem: query timed out (and was canceled) looking up TXT for example.com"
		deadlineMsg = "DNS problem: query timed out looking up TXT for example.com"
	)

	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, "", blog.UseMock(), tlsConfig)
	dr := testClient.(*impl)

	// Context cancelled before the lookup starts.
	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, _, err = dr.LookupTXT(ctx, "example.com")
	if err == nil || err.Error() != canceledMsg {
		t.Errorf("expected %q, got %v", canceledMsg, err)
	}

	// Context whose deadline is already in the past.
	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
	ctx, cancel = context.WithTimeout(context.Background(), -10*time.Hour)
	defer cancel()
	_, _, err = dr.LookupTXT(ctx, "example.com")
	if err == nil || err.Error() != deadlineMsg {
		t.Errorf("expected %q, got %v", deadlineMsg, err)
	}

	// Context that is both past its deadline and explicitly cancelled; the
	// deadline error is the one expected.
	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
	ctx, deadlineCancel := context.WithTimeout(context.Background(), -10*time.Hour)
	deadlineCancel()
	_, _, err = dr.LookupTXT(ctx, "example.com")
	if err == nil || err.Error() != deadlineMsg {
		t.Errorf("expected %q, got %v", deadlineMsg, err)
	}

	test.AssertMetricWithLabelsEquals(
		t, dr.timeoutCounter, prometheus.Labels{
			"qtype":    "TXT",
			"type":     "canceled",
			"resolver": "127.0.0.1",
		}, 1)

	test.AssertMetricWithLabelsEquals(
		t, dr.timeoutCounter, prometheus.Labels{
			"qtype":    "TXT",
			"type":     "deadline exceeded",
			"resolver": "127.0.0.1",
		}, 2)
}
// TestIsTLD checks the string isTLD returns for a bare top-level label and
// for a two-label name.
func TestIsTLD(t *testing.T) {
	if got := isTLD("com"); got != "true" {
		t.Errorf("expected 'com' to be a TLD, got %q", got)
	}
	if got := isTLD("example.com"); got != "false" {
		t.Errorf("expected 'example.com' to not a TLD, got %q", got)
	}
}
// tempError is an error whose Temporary method returns the underlying
// boolean value.
type tempError bool

// Temporary returns the boolean the error was created from.
func (te tempError) Temporary() bool {
	return bool(te)
}

// Error renders the error as "Temporary: true" or "Temporary: false".
func (te tempError) Error() string {
	return fmt.Sprintf("Temporary: %t", bool(te))
}
// rotateFailureExchanger is a dns.Exchange implementation used by
// `TestRotateServerOnErr`. It counts, per server address, how many times
// `Exchange` was called (in `lookups`), and fails every call made to an
// address listed in `brokenAddresses` with a retryable error.
type rotateFailureExchanger struct {
	sync.Mutex
	lookups         map[string]int
	brokenAddresses map[string]bool
}

// Exchange records one lookup against address `a`. For a broken address it
// returns a temporary error and no message; otherwise it hands the query
// message back as the reply.
func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
	const rtt = 2 * time.Millisecond

	e.Lock()
	defer e.Unlock()

	// Count the attempt whether or not the server is broken.
	e.lookups[a]++

	if broken := e.brokenAddresses[a]; !broken {
		return m, rtt, nil
	}

	// A url.Error wrapping a temporary error is what makes this retryable.
	return nil, rtt, &url.Error{Op: "read", Err: tempError(true)}
}
// TestRotateServerOnErr verifies that when a DNS server returns a retryable
// error, the retry goes to the next server in the list rather than to the
// same one.
func TestRotateServerOnErr(t *testing.T) {
	// Three servers: two that will always fail and one that works.
	const workingServer = "[2606:4700:4700::1111]:53"
	servers := []string{"a:53", "b:53", workingServer}

	// maxTries has to be at least the number of servers so that every
	// query can reach the one working server before retries run out.
	provider, err := NewStaticProvider(servers)
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	maxTries := 5
	client := New(10*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), maxTries, "", blog.UseMock(), tlsConfig)

	// The mock answers every query to A or B with a retryable error, so the
	// working server ends up answering everything once retries reach it.
	mock := &rotateFailureExchanger{
		lookups: make(map[string]int),
		brokenAddresses: map[string]bool{
			"a:53": true,
			"b:53": true,
		},
	}
	client.(*impl).dnsClient = mock

	// Run a batch of lookups. The first server for each is chosen randomly;
	// whenever it is A or B the query should fail and be retried on the next
	// server. With maxTries above the server count, *all* queries should end
	// up answered by the working server.
	totalLookups := maxTries * 2
	for range totalLookups {
		_, usedResolvers, lookupErr := client.LookupTXT(context.Background(), "example.com")
		test.AssertEquals(t, len(usedResolvers), 1)
		test.AssertEquals(t, usedResolvers[0], "[2606:4700:4700::1111]:53")
		// No error is acceptable here: the working server should have
		// answered.
		test.AssertNotError(t, lookupErr, "Expected no error from eventual retry with functional server")
	}

	// Both broken servers should have been tried at least once.
	test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts")
	test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts")

	// The working server should have served every lookup exactly once.
	test.AssertEquals(t, mock.lookups[workingServer], maxTries*2)
}
// mockTempURLError is an error that reports itself as temporary but not as
// a timeout.
type mockTempURLError struct{}

// Error returns a fixed message.
func (e *mockTempURLError) Error() string {
	return "whoops, oh gosh"
}

// Timeout always reports false.
func (e *mockTempURLError) Timeout() bool {
	return false
}

// Temporary always reports true.
func (e *mockTempURLError) Temporary() bool {
	return true
}
// dohAlwaysRetryExchanger is a stand-in for the DNS client whose Exchange
// always fails with a temporary url.Error. The err field is not consulted
// by Exchange.
type dohAlwaysRetryExchanger struct {
	sync.Mutex
	err error
}

// Exchange never returns a message: every call yields a one-second duration
// and a url.Error wrapping a mockTempURLError.
func (dohE *dohAlwaysRetryExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
	dohE.Lock()
	defer dohE.Unlock()

	failure := new(url.Error)
	failure.Op = "GET"
	failure.URL = "https://example.com"
	failure.Err = &mockTempURLError{}

	return nil, time.Second, failure
}
// TestDOHMetric checks that a single exchange through a client whose DNS
// client always fails with a temporary error moves the "out of retries"
// timeout counter from 0 to 1.
func TestDOHMetric(t *testing.T) {
	provider, err := NewStaticProvider([]string{dnsLoopbackAddr})
	test.AssertNotError(t, err, "Got error creating StaticProvider")

	client := New(11*time.Second, provider, metrics.NoopRegisterer, clock.NewFake(), 0, "", blog.UseMock(), tlsConfig)
	resolver := client.(*impl)
	resolver.dnsClient = &dohAlwaysRetryExchanger{err: &url.Error{Op: "read", Err: tempError(true)}}

	// outOfRetries builds a fresh copy of the label set identifying the
	// counter under test.
	outOfRetries := func() prometheus.Labels {
		return prometheus.Labels{
			"qtype":    "None",
			"type":     "out of retries",
			"resolver": "127.0.0.1",
			"isTLD":    "false",
		}
	}

	// Nothing has failed yet, so the counter starts at zero.
	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, outOfRetries(), 0)

	// Run one exchange; its results are irrelevant, only the metric matters.
	_, _, _ = resolver.exchangeOne(context.Background(), "example.com", 0)

	// The failed exchange should have been counted exactly once.
	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, outOfRetries(), 1)
}