package bdns

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"os"
	"regexp"
	"slices"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jmhodges/clock"
	"github.com/miekg/dns"
	"github.com/prometheus/client_golang/prometheus"

	blog "github.com/letsencrypt/boulder/log"
	"github.com/letsencrypt/boulder/metrics"
	"github.com/letsencrypt/boulder/test"
)

// dnsLoopbackAddr is the address the mock DoH server listens on and the
// address the tests hand to their resolver providers.
const dnsLoopbackAddr = "127.0.0.1:4053"

// mockDNSQuery is the HTTP handler behind the mock DoH endpoint. It requires
// the application/dns-message Content-Type and Accept headers, unpacks the
// DNS query from the request body, and answers from a fixed table keyed on
// the lower-cased question name and the query type.
//
// Malformed requests get a 400 and are not processed any further. A reply
// that cannot be packed is reported on stderr and answered with a 500.
func mockDNSQuery(w http.ResponseWriter, httpReq *http.Request) {
	const (
		mockIPv4 = "64.112.117.1"
		mockIPv6 = "2602:80a:6000:abad:cafe::1"
	)

	if httpReq.Header.Get("Content-Type") != "application/dns-message" {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "client didn't send Content-Type: application/dns-message")
		return
	}
	if httpReq.Header.Get("Accept") != "application/dns-message" {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "client didn't accept Content-Type: application/dns-message")
		return
	}

	requestBody, err := io.ReadAll(httpReq.Body)
	httpReq.Body.Close()
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "reading body: %s", err)
		return
	}

	r := new(dns.Msg)
	err = r.Unpack(requestBody)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "unpacking request: %s", err)
		return
	}

	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	appendAnswer := func(rr dns.RR) {
		m.Answer = append(m.Answer, rr)
	}
	// header builds the IN-class RR header used by every mock record.
	header := func(name string, rrtype uint16, ttl uint32) dns.RR_Header {
		return dns.RR_Header{Name: name, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
	}
	// newSOA builds the letsencrypt.org. SOA record that is returned both as
	// the SOA answer and as the authority record of TXT responses.
	newSOA := func() *dns.SOA {
		return &dns.SOA{
			Hdr:     header("letsencrypt.org.", dns.TypeSOA, 0),
			Ns:      "ns.letsencrypt.org.",
			Mbox:    "master.letsencrypt.org.",
			Serial:  1,
			Refresh: 1,
			Retry:   1,
			Expire:  1,
			Minttl:  1,
		}
	}
	// newCAA builds an `issue "letsencrypt.org"` CAA record owned by name.
	newCAA := func(name string) *dns.CAA {
		return &dns.CAA{
			Hdr:   header(name, dns.TypeCAA, 0),
			Flag:  1,
			Tag:   "issue",
			Value: "letsencrypt.org",
		}
	}

	for _, q := range r.Question {
		q.Name = strings.ToLower(q.Name)
		// Question names are fully qualified, so both names need the
		// trailing dot to ever match.
		if q.Name == "servfail.com." || q.Name == "servfailexception.example.com." {
			m.Rcode = dns.RcodeServerFailure
			break
		}
		switch q.Qtype {
		case dns.TypeSOA:
			appendAnswer(newSOA())
		case dns.TypeAAAA:
			switch q.Name {
			case "v6.letsencrypt.org.", "dualstack.letsencrypt.org.", "v4error.letsencrypt.org.":
				appendAnswer(&dns.AAAA{
					Hdr:  header(q.Name, dns.TypeAAAA, 0),
					AAAA: net.ParseIP(mockIPv6),
				})
			case "v6error.letsencrypt.org.", "dualstackerror.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNotImplemented)
			case "nxdomain.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNameError)
			}
		case dns.TypeA:
			switch q.Name {
			case "cps.letsencrypt.org.", "dualstack.letsencrypt.org.", "v6error.letsencrypt.org.":
				// The owner name is the queried name; previously the
				// v6error answer carried dualstack.letsencrypt.org.
				appendAnswer(&dns.A{
					Hdr: header(q.Name, dns.TypeA, 0),
					A:   net.ParseIP(mockIPv4),
				})
			case "v4error.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNotImplemented)
			case "nxdomain.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeNameError)
			case "dualstackerror.letsencrypt.org.":
				m.SetRcode(r, dns.RcodeRefused)
			}
		case dns.TypeCNAME:
			switch q.Name {
			case "cname.letsencrypt.org.":
				appendAnswer(&dns.CNAME{
					Hdr:    header(q.Name, dns.TypeCNAME, 30),
					Target: "cps.letsencrypt.org.",
				})
			case "cname.example.com.":
				appendAnswer(&dns.CNAME{
					Hdr:    header(q.Name, dns.TypeCNAME, 30),
					Target: "CAA.example.com.",
				})
			}
		case dns.TypeDNAME:
			if q.Name == "dname.letsencrypt.org." {
				appendAnswer(&dns.DNAME{
					Hdr:    header(q.Name, dns.TypeDNAME, 30),
					Target: "cps.letsencrypt.org.",
				})
			}
		case dns.TypeCAA:
			switch q.Name {
			case "bracewel.net.", "caa.example.com.":
				appendAnswer(newCAA(q.Name))
			case "cname.example.com.":
				// The answer is owned by the CNAME target, not by the
				// queried name.
				appendAnswer(newCAA("caa.example.com."))
			case "gonetld.":
				m.SetRcode(r, dns.RcodeNameError)
			}
		case dns.TypeTXT:
			if q.Name == "split-txt.letsencrypt.org." {
				appendAnswer(&dns.TXT{
					Hdr: header(q.Name, dns.TypeTXT, 0),
					Txt: []string{"a", "b", "c"},
				})
			} else {
				m.Ns = append(m.Ns, newSOA())
			}
			if q.Name == "nxdomain.letsencrypt.org." {
				m.SetRcode(r, dns.RcodeNameError)
			}
		}
	}

	body, err := m.Pack()
	if err != nil {
		fmt.Fprintf(os.Stderr, "packing reply: %s\n", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/dns-message")
	_, err = w.Write(body)
	if err != nil {
		panic(err) // running tests, so panic is OK
	}
}

// serveLoopResolver starts the mock DoH server on dnsLoopbackAddr in a
// background goroutine, serving mockDNSQuery at /dns-query over TLS. A second
// goroutine shuts the server down once a value arrives on stopChan.
func serveLoopResolver(stopChan chan bool) {
	m := http.NewServeMux()
	m.HandleFunc("/dns-query", mockDNSQuery)
	httpServer := &http.Server{
		Addr:         dnsLoopbackAddr,
		Handler:      m,
		ReadTimeout:  time.Second,
		WriteTimeout: time.Second,
	}
	go func() {
		cert := "../test/certs/ipki/localhost/cert.pem"
		key := "../test/certs/ipki/localhost/key.pem"
		err := httpServer.ListenAndServeTLS(cert, key)
		// ErrServerClosed is the normal result of Shutdown, not a failure.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			fmt.Println(err)
		}
	}()
	go func() {
		<-stopChan
		err := httpServer.Shutdown(context.Background())
		if err != nil {
			log.Fatal(err)
		}
	}()
}

// pollServer blocks until a TCP connection to dnsLoopbackAddr succeeds,
// retrying every 200ms. If no connection succeeds within five seconds it
// reports the timeout on stderr and exits the process with status 1.
func pollServer() {
	backoff := 200 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	ticker := time.NewTicker(backoff)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			fmt.Fprintln(os.Stderr, "Timeout reached while testing for the dns server to come up")
			os.Exit(1)
		case <-ticker.C:
			// The mock server is an HTTPS (TCP) listener. A UDP dial
			// succeeds whether or not anything is listening, so only a TCP
			// dial shows the server is accepting connections.
			conn, _ := dns.DialTimeout("tcp", dnsLoopbackAddr, backoff)
			if conn != nil {
				_ = conn.Close()
				return
			}
		}
	}
}

// tlsConfig is used for the TLS config of client instances that talk to the
// DoH server set up in TestMain.
var tlsConfig *tls.Config func TestMain(m *testing.M) { root, err := os.ReadFile("../test/certs/ipki/minica.pem") if err != nil { log.Fatal(err) } pool := x509.NewCertPool() pool.AppendCertsFromPEM(root) tlsConfig = &tls.Config{ RootCAs: pool, } stop := make(chan bool, 1) serveLoopResolver(stop) pollServer() ret := m.Run() stop <- true os.Exit(ret) } func TestDNSNoServers(t *testing.T) { staticProvider, err := NewStaticProvider([]string{}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Hour, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) _, resolvers, err := obj.LookupHost(context.Background(), "letsencrypt.org") test.AssertEquals(t, len(resolvers), 0) test.AssertError(t, err, "No servers") _, _, err = obj.LookupTXT(context.Background(), "letsencrypt.org") test.AssertError(t, err, "No servers") _, _, _, err = obj.LookupCAA(context.Background(), "letsencrypt.org") test.AssertError(t, err, "No servers") } func TestDNSOneServer(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) _, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org") test.AssertEquals(t, len(resolvers), 2) slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) test.AssertNotError(t, err, "No message") } func TestDNSDuplicateServers(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) _, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org") test.AssertEquals(t, len(resolvers), 2) 
slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) test.AssertNotError(t, err, "No message") } func TestDNSServFail(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) bad := "servfail.com" _, _, err = obj.LookupTXT(context.Background(), bad) test.AssertError(t, err, "LookupTXT didn't return an error") _, _, err = obj.LookupHost(context.Background(), bad) test.AssertError(t, err, "LookupHost didn't return an error") emptyCaa, _, _, err := obj.LookupCAA(context.Background(), bad) test.Assert(t, len(emptyCaa) == 0, "Query returned non-empty list of CAA records") test.AssertError(t, err, "LookupCAA should have returned an error") } func TestDNSLookupTXT(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) a, _, err := obj.LookupTXT(context.Background(), "letsencrypt.org") t.Logf("A: %v", a) test.AssertNotError(t, err, "No message") a, _, err = obj.LookupTXT(context.Background(), "split-txt.letsencrypt.org") t.Logf("A: %v ", a) test.AssertNotError(t, err, "No message") test.AssertEquals(t, len(a), 1) test.AssertEquals(t, a[0], "abc") } // TODO(#8213): Convert this to a table test. 
func TestDNSLookupHost(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) ip, resolvers, err := obj.LookupHost(context.Background(), "servfail.com") t.Logf("servfail.com - IP: %s, Err: %s", ip, err) test.AssertError(t, err, "Server failure") test.Assert(t, len(ip) == 0, "Should not have IPs") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) ip, resolvers, err = obj.LookupHost(context.Background(), "nonexistent.letsencrypt.org") t.Logf("nonexistent.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertError(t, err, "No valid A or AAAA records should error") test.Assert(t, len(ip) == 0, "Should not have IPs") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // Single IPv4 address ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org") t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 1, "Should have IP") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org") t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 1, "Should have IP") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // Single IPv6 address ip, resolvers, err = obj.LookupHost(context.Background(), "v6.letsencrypt.org") t.Logf("v6.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 1, "Should 
not have IPs") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // Both IPv6 and IPv4 address ip, resolvers, err = obj.LookupHost(context.Background(), "dualstack.letsencrypt.org") t.Logf("dualstack.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 2, "Should have 2 IPs") expected := netip.MustParseAddr("64.112.117.1") test.Assert(t, ip[0] == expected, "wrong ipv4 address") expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1") test.Assert(t, ip[1] == expected, "wrong ipv6 address") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // IPv6 error, IPv4 success ip, resolvers, err = obj.LookupHost(context.Background(), "v6error.letsencrypt.org") t.Logf("v6error.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 1, "Should have 1 IP") expected = netip.MustParseAddr("64.112.117.1") test.Assert(t, ip[0] == expected, "wrong ipv4 address") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // IPv6 success, IPv4 error ip, resolvers, err = obj.LookupHost(context.Background(), "v4error.letsencrypt.org") t.Logf("v4error.letsencrypt.org - IP: %s, Err: %s", ip, err) test.AssertNotError(t, err, "Not an error to exist") test.Assert(t, len(ip) == 1, "Should have 1 IP") expected = netip.MustParseAddr("2602:80a:6000:abad:cafe::1") test.Assert(t, ip[0] == expected, "wrong ipv6 address") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) // IPv6 error, IPv4 error // Should return both the IPv4 error (Refused) and the IPv6 error (NotImplemented) hostname := "dualstackerror.letsencrypt.org" ip, resolvers, err = obj.LookupHost(context.Background(), hostname) t.Logf("%s - IP: 
%s, Err: %s", hostname, ip, err) test.AssertError(t, err, "Should be an error") test.AssertContains(t, err.Error(), "REFUSED looking up A for") test.AssertContains(t, err.Error(), "NOTIMP looking up AAAA for") slices.Sort(resolvers) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"}) } func TestDNSNXDOMAIN(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) hostname := "nxdomain.letsencrypt.org" _, _, err = obj.LookupHost(context.Background(), hostname) test.AssertContains(t, err.Error(), "NXDOMAIN looking up A for") test.AssertContains(t, err.Error(), "NXDOMAIN looking up AAAA for") _, _, err = obj.LookupTXT(context.Background(), hostname) expected := Error{dns.TypeTXT, hostname, nil, dns.RcodeNameError, nil} test.AssertDeepEquals(t, err, expected) } func TestDNSLookupCAA(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") obj := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, "", blog.UseMock(), tlsConfig) removeIDExp := regexp.MustCompile(" id: [[:digit:]]+") caas, resp, resolvers, err := obj.LookupCAA(context.Background(), "bracewel.net") test.AssertNotError(t, err, "CAA lookup failed") test.Assert(t, len(caas) > 0, "Should have CAA records") test.AssertEquals(t, len(resolvers), 1) test.AssertDeepEquals(t, resolvers, ResolverAddrs{"127.0.0.1:4053"}) expectedResp := `;; opcode: QUERY, status: NOERROR, id: XXXX ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 ;; QUESTION SECTION: ;bracewel.net. IN CAA ;; ANSWER SECTION: bracewel.net. 
0 IN CAA 1 issue "letsencrypt.org" ` test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp) caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nonexistent.letsencrypt.org") test.AssertNotError(t, err, "CAA lookup failed") test.Assert(t, len(caas) == 0, "Shouldn't have CAA records") test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") expectedResp = "" test.AssertEquals(t, resp, expectedResp) caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nxdomain.letsencrypt.org") slices.Sort(resolvers) test.AssertNotError(t, err, "CAA lookup failed") test.Assert(t, len(caas) == 0, "Shouldn't have CAA records") test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") expectedResp = "" test.AssertEquals(t, resp, expectedResp) caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "cname.example.com") test.AssertNotError(t, err, "CAA lookup failed") test.Assert(t, len(caas) > 0, "Should follow CNAME to find CAA") test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") expectedResp = `;; opcode: QUERY, status: NOERROR, id: XXXX ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 ;; QUESTION SECTION: ;cname.example.com. IN CAA ;; ANSWER SECTION: caa.example.com. 
0 IN CAA 1 issue "letsencrypt.org" ` test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp) _, _, resolvers, err = obj.LookupCAA(context.Background(), "gonetld") test.AssertError(t, err, "should fail for TLD NXDOMAIN") test.AssertContains(t, err.Error(), "NXDOMAIN") test.AssertEquals(t, resolvers[0], "127.0.0.1:4053") } type testExchanger struct { sync.Mutex count int errs []error } var errTooManyRequests = errors.New("too many requests") func (te *testExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { te.Lock() defer te.Unlock() msg := &dns.Msg{ MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, } if len(te.errs) <= te.count { return nil, 0, errTooManyRequests } err := te.errs[te.count] te.count++ return msg, 2 * time.Millisecond, err } func TestRetry(t *testing.T) { isTempErr := &url.Error{Op: "read", Err: tempError(true)} nonTempErr := &url.Error{Op: "read", Err: tempError(false)} servFailError := errors.New("DNS problem: server failure at resolver looking up TXT for example.com") type testCase struct { name string maxTries int te *testExchanger expected error expectedCount int metricsAllRetries float64 } tests := []*testCase{ // The success on first try case { name: "success", maxTries: 3, te: &testExchanger{ errs: []error{nil}, }, expected: nil, expectedCount: 1, }, // Immediate non-OpError, error returns immediately { name: "non-operror", maxTries: 3, te: &testExchanger{ errs: []error{errors.New("nope")}, }, expected: servFailError, expectedCount: 1, }, // Temporary err, then non-OpError stops at two tries { name: "err-then-non-operror", maxTries: 3, te: &testExchanger{ errs: []error{isTempErr, errors.New("nope")}, }, expected: servFailError, expectedCount: 2, }, // Temporary error given always { name: "persistent-temp-error", maxTries: 3, te: &testExchanger{ errs: []error{ isTempErr, isTempErr, isTempErr, }, }, expected: servFailError, expectedCount: 3, metricsAllRetries: 1, }, // Even with maxTries at 0, we 
should still let a single request go // through { name: "zero-maxtries", maxTries: 0, te: &testExchanger{ errs: []error{nil}, }, expected: nil, expectedCount: 1, }, // Temporary error given just once causes two tries { name: "single-temp-error", maxTries: 3, te: &testExchanger{ errs: []error{ isTempErr, nil, }, }, expected: nil, expectedCount: 2, }, // Temporary error given twice causes three tries { name: "double-temp-error", maxTries: 3, te: &testExchanger{ errs: []error{ isTempErr, isTempErr, nil, }, }, expected: nil, expectedCount: 3, }, // Temporary error given thrice causes three tries and fails { name: "triple-temp-error", maxTries: 3, te: &testExchanger{ errs: []error{ isTempErr, isTempErr, isTempErr, }, }, expected: servFailError, expectedCount: 3, metricsAllRetries: 1, }, // temporary then non-Temporary error causes two retries { name: "temp-nontemp-error", maxTries: 3, te: &testExchanger{ errs: []error{ isTempErr, nonTempErr, }, }, expected: servFailError, expectedCount: 2, }, } for i, tc := range tests { t.Run(tc.name, func(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, "", blog.UseMock(), tlsConfig) dr := testClient.(*impl) dr.dnsClient = tc.te _, _, err = dr.LookupTXT(context.Background(), "example.com") if err == errTooManyRequests { t.Errorf("#%d, sent more requests than the test case handles", i) } expectedErr := tc.expected if (expectedErr == nil && err != nil) || (expectedErr != nil && err == nil) || (expectedErr != nil && expectedErr.Error() != err.Error()) { t.Errorf("#%d, error, expected %v, got %v", i, expectedErr, err) } if tc.expectedCount != tc.te.count { t.Errorf("#%d, error, expectedCount %v, got %v", i, tc.expectedCount, tc.te.count) } if tc.metricsAllRetries > 0 { test.AssertMetricWithLabelsEquals( t, dr.timeoutCounter, 
prometheus.Labels{ "qtype": "TXT", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false", }, tc.metricsAllRetries) } }) } staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") testClient := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, "", blog.UseMock(), tlsConfig) dr := testClient.(*impl) dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}} ctx, cancel := context.WithCancel(context.Background()) cancel() _, _, err = dr.LookupTXT(ctx, "example.com") if err == nil || err.Error() != "DNS problem: query timed out (and was canceled) looking up TXT for example.com" { t.Errorf("expected %s, got %s", context.Canceled, err) } dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}} ctx, cancel = context.WithTimeout(context.Background(), -10*time.Hour) defer cancel() _, _, err = dr.LookupTXT(ctx, "example.com") if err == nil || err.Error() != "DNS problem: query timed out looking up TXT for example.com" { t.Errorf("expected %s, got %s", context.DeadlineExceeded, err) } dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}} ctx, deadlineCancel := context.WithTimeout(context.Background(), -10*time.Hour) deadlineCancel() _, _, err = dr.LookupTXT(ctx, "example.com") if err == nil || err.Error() != "DNS problem: query timed out looking up TXT for example.com" { t.Errorf("expected %s, got %s", context.DeadlineExceeded, err) } test.AssertMetricWithLabelsEquals( t, dr.timeoutCounter, prometheus.Labels{ "qtype": "TXT", "type": "canceled", "resolver": "127.0.0.1", }, 1) test.AssertMetricWithLabelsEquals( t, dr.timeoutCounter, prometheus.Labels{ "qtype": "TXT", "type": "deadline exceeded", "resolver": "127.0.0.1", }, 2) } func TestIsTLD(t *testing.T) { if isTLD("com") != "true" { t.Errorf("expected 'com' to be a TLD, got %q", isTLD("com")) } if isTLD("example.com") != "false" { t.Errorf("expected 
'example.com' to not a TLD, got %q", isTLD("example.com")) } } type tempError bool func (t tempError) Temporary() bool { return bool(t) } func (t tempError) Error() string { return fmt.Sprintf("Temporary: %t", t) } // rotateFailureExchanger is a dns.Exchange implementation that tracks a count // of the number of calls to `Exchange` for a given address in the `lookups` // map. For all addresses in the `brokenAddresses` map, a retryable error is // returned from `Exchange`. This mock is used by `TestRotateServerOnErr`. type rotateFailureExchanger struct { sync.Mutex lookups map[string]int brokenAddresses map[string]bool } // Exchange for rotateFailureExchanger tracks the `a` argument in `lookups` and // if present in `brokenAddresses`, returns a temporary error. func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { e.Lock() defer e.Unlock() // Track that exchange was called for the given server e.lookups[a]++ // If its a broken server, return a retryable error if e.brokenAddresses[a] { isTempErr := &url.Error{Op: "read", Err: tempError(true)} return nil, 2 * time.Millisecond, isTempErr } return m, 2 * time.Millisecond, nil } // TestRotateServerOnErr ensures that a retryable error returned from a DNS // server will result in the retry being performed against the next server in // the list. func TestRotateServerOnErr(t *testing.T) { // Configure three DNS servers dnsServers := []string{ "a:53", "b:53", "[2606:4700:4700::1111]:53", } // Set up a DNS client using these servers that will retry queries up to // a maximum of 5 times. 
It's important to choose a maxTries value >= the // number of dnsServers to ensure we always get around to trying the one // working server staticProvider, err := NewStaticProvider(dnsServers) test.AssertNotError(t, err, "Got error creating StaticProvider") maxTries := 5 client := New(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), maxTries, "", blog.UseMock(), tlsConfig) // Configure a mock exchanger that will always return a retryable error for // servers A and B. This will force server "[2606:4700:4700::1111]:53" to do // all the work once retries reach it. mock := &rotateFailureExchanger{ brokenAddresses: map[string]bool{ "a:53": true, "b:53": true, }, lookups: make(map[string]int), } client.(*impl).dnsClient = mock // Perform a bunch of lookups. We choose the initial server randomly. Any time // A or B is chosen there should be an error and a retry using the next server // in the list. Since we configured maxTries to be larger than the number of // servers *all* queries should eventually succeed by being retried against // server "[2606:4700:4700::1111]:53". for range maxTries * 2 { _, resolvers, err := client.LookupTXT(context.Background(), "example.com") test.AssertEquals(t, len(resolvers), 1) test.AssertEquals(t, resolvers[0], "[2606:4700:4700::1111]:53") // Any errors are unexpected - server "[2606:4700:4700::1111]:53" should // have responded without error. test.AssertNotError(t, err, "Expected no error from eventual retry with functional server") } // We expect that the A and B servers had a non-zero number of lookups // attempted. test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts") test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts") // We expect that the server "[2606:4700:4700::1111]:53" eventually served // all of the lookups attempted. 
test.AssertEquals(t, mock.lookups["[2606:4700:4700::1111]:53"], maxTries*2) } type mockTempURLError struct{} func (m *mockTempURLError) Error() string { return "whoops, oh gosh" } func (m *mockTempURLError) Timeout() bool { return false } func (m *mockTempURLError) Temporary() bool { return true } type dohAlwaysRetryExchanger struct { sync.Mutex err error } func (dohE *dohAlwaysRetryExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { dohE.Lock() defer dohE.Unlock() tempURLerror := &url.Error{ Op: "GET", URL: "https://example.com", Err: &mockTempURLError{}, } return nil, time.Second, tempURLerror } func TestDOHMetric(t *testing.T) { staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr}) test.AssertNotError(t, err, "Got error creating StaticProvider") testClient := New(time.Second*11, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 0, "", blog.UseMock(), tlsConfig) resolver := testClient.(*impl) resolver.dnsClient = &dohAlwaysRetryExchanger{err: &url.Error{Op: "read", Err: tempError(true)}} // Starting out, we should count 0 "out of retries" errors. test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 0) // Trigger the error. _, _, _ = resolver.exchangeOne(context.Background(), "example.com", 0) // Now, we should count 1 "out of retries" errors. test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 1) }