894 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			894 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Go
		
	
	
	
| package bdns
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"regexp"
 | |
| 	"slices"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/jmhodges/clock"
 | |
| 	"github.com/miekg/dns"
 | |
| 	"github.com/prometheus/client_golang/prometheus"
 | |
| 
 | |
| 	"github.com/letsencrypt/boulder/features"
 | |
| 	blog "github.com/letsencrypt/boulder/log"
 | |
| 	"github.com/letsencrypt/boulder/metrics"
 | |
| 	"github.com/letsencrypt/boulder/test"
 | |
| )
 | |
| 
 | |
// dnsLoopbackAddr is the host:port on which the mock DNS servers started by
// serveLoopResolver listen, and which the tests in this file configure as
// their only resolver.
const dnsLoopbackAddr = "127.0.0.1:4053"
 | |
| 
 | |
| func mockDNSQuery(w dns.ResponseWriter, r *dns.Msg) {
 | |
| 	m := new(dns.Msg)
 | |
| 	m.SetReply(r)
 | |
| 	m.Compress = false
 | |
| 
 | |
| 	appendAnswer := func(rr dns.RR) {
 | |
| 		m.Answer = append(m.Answer, rr)
 | |
| 	}
 | |
| 	for _, q := range r.Question {
 | |
| 		q.Name = strings.ToLower(q.Name)
 | |
| 		if q.Name == "servfail.com." || q.Name == "servfailexception.example.com" {
 | |
| 			m.Rcode = dns.RcodeServerFailure
 | |
| 			break
 | |
| 		}
 | |
| 		switch q.Qtype {
 | |
| 		case dns.TypeSOA:
 | |
| 			record := new(dns.SOA)
 | |
| 			record.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0}
 | |
| 			record.Ns = "ns.letsencrypt.org."
 | |
| 			record.Mbox = "master.letsencrypt.org."
 | |
| 			record.Serial = 1
 | |
| 			record.Refresh = 1
 | |
| 			record.Retry = 1
 | |
| 			record.Expire = 1
 | |
| 			record.Minttl = 1
 | |
| 			appendAnswer(record)
 | |
| 		case dns.TypeAAAA:
 | |
| 			if q.Name == "v6.letsencrypt.org." {
 | |
| 				record := new(dns.AAAA)
 | |
| 				record.Hdr = dns.RR_Header{Name: "v6.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.AAAA = net.ParseIP("::1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "dualstack.letsencrypt.org." {
 | |
| 				record := new(dns.AAAA)
 | |
| 				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.AAAA = net.ParseIP("::1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "v4error.letsencrypt.org." {
 | |
| 				record := new(dns.AAAA)
 | |
| 				record.Hdr = dns.RR_Header{Name: "v4error.letsencrypt.org.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.AAAA = net.ParseIP("::1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "v6error.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNotImplemented)
 | |
| 			}
 | |
| 			if q.Name == "nxdomain.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNameError)
 | |
| 			}
 | |
| 			if q.Name == "dualstackerror.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNotImplemented)
 | |
| 			}
 | |
| 		case dns.TypeA:
 | |
| 			if q.Name == "cps.letsencrypt.org." {
 | |
| 				record := new(dns.A)
 | |
| 				record.Hdr = dns.RR_Header{Name: "cps.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.A = net.ParseIP("127.0.0.1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "dualstack.letsencrypt.org." {
 | |
| 				record := new(dns.A)
 | |
| 				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.A = net.ParseIP("127.0.0.1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "v6error.letsencrypt.org." {
 | |
| 				record := new(dns.A)
 | |
| 				record.Hdr = dns.RR_Header{Name: "dualstack.letsencrypt.org.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.A = net.ParseIP("127.0.0.1")
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "v4error.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNotImplemented)
 | |
| 			}
 | |
| 			if q.Name == "nxdomain.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNameError)
 | |
| 			}
 | |
| 			if q.Name == "dualstackerror.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeRefused)
 | |
| 			}
 | |
| 		case dns.TypeCNAME:
 | |
| 			if q.Name == "cname.letsencrypt.org." {
 | |
| 				record := new(dns.CNAME)
 | |
| 				record.Hdr = dns.RR_Header{Name: "cname.letsencrypt.org.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30}
 | |
| 				record.Target = "cps.letsencrypt.org."
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "cname.example.com." {
 | |
| 				record := new(dns.CNAME)
 | |
| 				record.Hdr = dns.RR_Header{Name: "cname.example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 30}
 | |
| 				record.Target = "CAA.example.com."
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 		case dns.TypeDNAME:
 | |
| 			if q.Name == "dname.letsencrypt.org." {
 | |
| 				record := new(dns.DNAME)
 | |
| 				record.Hdr = dns.RR_Header{Name: "dname.letsencrypt.org.", Rrtype: dns.TypeDNAME, Class: dns.ClassINET, Ttl: 30}
 | |
| 				record.Target = "cps.letsencrypt.org."
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 		case dns.TypeCAA:
 | |
| 			if q.Name == "bracewel.net." || q.Name == "caa.example.com." {
 | |
| 				record := new(dns.CAA)
 | |
| 				record.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.Tag = "issue"
 | |
| 				record.Value = "letsencrypt.org"
 | |
| 				record.Flag = 1
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "cname.example.com." {
 | |
| 				record := new(dns.CAA)
 | |
| 				record.Hdr = dns.RR_Header{Name: "caa.example.com.", Rrtype: dns.TypeCAA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.Tag = "issue"
 | |
| 				record.Value = "letsencrypt.org"
 | |
| 				record.Flag = 1
 | |
| 				appendAnswer(record)
 | |
| 			}
 | |
| 			if q.Name == "gonetld." {
 | |
| 				m.SetRcode(r, dns.RcodeNameError)
 | |
| 			}
 | |
| 		case dns.TypeTXT:
 | |
| 			if q.Name == "split-txt.letsencrypt.org." {
 | |
| 				record := new(dns.TXT)
 | |
| 				record.Hdr = dns.RR_Header{Name: "split-txt.letsencrypt.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
 | |
| 				record.Txt = []string{"a", "b", "c"}
 | |
| 				appendAnswer(record)
 | |
| 			} else {
 | |
| 				auth := new(dns.SOA)
 | |
| 				auth.Hdr = dns.RR_Header{Name: "letsencrypt.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0}
 | |
| 				auth.Ns = "ns.letsencrypt.org."
 | |
| 				auth.Mbox = "master.letsencrypt.org."
 | |
| 				auth.Serial = 1
 | |
| 				auth.Refresh = 1
 | |
| 				auth.Retry = 1
 | |
| 				auth.Expire = 1
 | |
| 				auth.Minttl = 1
 | |
| 				m.Ns = append(m.Ns, auth)
 | |
| 			}
 | |
| 			if q.Name == "nxdomain.letsencrypt.org." {
 | |
| 				m.SetRcode(r, dns.RcodeNameError)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err := w.WriteMsg(m)
 | |
| 	if err != nil {
 | |
| 		panic(err) // running tests, so panic is OK
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func serveLoopResolver(stopChan chan bool) {
 | |
| 	dns.HandleFunc(".", mockDNSQuery)
 | |
| 	tcpServer := &dns.Server{
 | |
| 		Addr:         dnsLoopbackAddr,
 | |
| 		Net:          "tcp",
 | |
| 		ReadTimeout:  time.Second,
 | |
| 		WriteTimeout: time.Second,
 | |
| 	}
 | |
| 	udpServer := &dns.Server{
 | |
| 		Addr:         dnsLoopbackAddr,
 | |
| 		Net:          "udp",
 | |
| 		ReadTimeout:  time.Second,
 | |
| 		WriteTimeout: time.Second,
 | |
| 	}
 | |
| 	go func() {
 | |
| 		err := tcpServer.ListenAndServe()
 | |
| 		if err != nil {
 | |
| 			fmt.Println(err)
 | |
| 		}
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		err := udpServer.ListenAndServe()
 | |
| 		if err != nil {
 | |
| 			fmt.Println(err)
 | |
| 		}
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		<-stopChan
 | |
| 		err := tcpServer.Shutdown()
 | |
| 		if err != nil {
 | |
| 			log.Fatal(err)
 | |
| 		}
 | |
| 		err = udpServer.Shutdown()
 | |
| 		if err != nil {
 | |
| 			log.Fatal(err)
 | |
| 		}
 | |
| 	}()
 | |
| }
 | |
| 
 | |
| func pollServer() {
 | |
| 	backoff := 200 * time.Millisecond
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | |
| 	defer cancel()
 | |
| 	ticker := time.NewTicker(backoff)
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			fmt.Fprintln(os.Stderr, "Timeout reached while testing for the dns server to come up")
 | |
| 			os.Exit(1)
 | |
| 		case <-ticker.C:
 | |
| 			conn, _ := dns.DialTimeout("udp", dnsLoopbackAddr, backoff)
 | |
| 			if conn != nil {
 | |
| 				_ = conn.Close()
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestMain(m *testing.M) {
 | |
| 	stop := make(chan bool, 1)
 | |
| 	serveLoopResolver(stop)
 | |
| 	pollServer()
 | |
| 	ret := m.Run()
 | |
| 	stop <- true
 | |
| 	os.Exit(ret)
 | |
| }
 | |
| 
 | |
| func TestDNSNoServers(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Hour, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	_, resolvers, err := obj.LookupHost(context.Background(), "letsencrypt.org")
 | |
| 	test.AssertEquals(t, len(resolvers), 0)
 | |
| 	test.AssertError(t, err, "No servers")
 | |
| 
 | |
| 	_, _, err = obj.LookupTXT(context.Background(), "letsencrypt.org")
 | |
| 	test.AssertError(t, err, "No servers")
 | |
| 
 | |
| 	_, _, _, err = obj.LookupCAA(context.Background(), "letsencrypt.org")
 | |
| 	test.AssertError(t, err, "No servers")
 | |
| }
 | |
| 
 | |
| func TestDNSOneServer(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	_, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org")
 | |
| 	test.AssertEquals(t, len(resolvers), 2)
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 	test.AssertNotError(t, err, "No message")
 | |
| }
 | |
| 
 | |
| func TestDNSDuplicateServers(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	_, resolvers, err := obj.LookupHost(context.Background(), "cps.letsencrypt.org")
 | |
| 	test.AssertEquals(t, len(resolvers), 2)
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 	test.AssertNotError(t, err, "No message")
 | |
| }
 | |
| 
 | |
| func TestDNSServFail(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 	bad := "servfail.com"
 | |
| 
 | |
| 	_, _, err = obj.LookupTXT(context.Background(), bad)
 | |
| 	test.AssertError(t, err, "LookupTXT didn't return an error")
 | |
| 
 | |
| 	_, _, err = obj.LookupHost(context.Background(), bad)
 | |
| 	test.AssertError(t, err, "LookupHost didn't return an error")
 | |
| 
 | |
| 	emptyCaa, _, _, err := obj.LookupCAA(context.Background(), bad)
 | |
| 	test.Assert(t, len(emptyCaa) == 0, "Query returned non-empty list of CAA records")
 | |
| 	test.AssertError(t, err, "LookupCAA should have returned an error")
 | |
| }
 | |
| 
 | |
| func TestDNSLookupTXT(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	a, _, err := obj.LookupTXT(context.Background(), "letsencrypt.org")
 | |
| 	t.Logf("A: %v", a)
 | |
| 	test.AssertNotError(t, err, "No message")
 | |
| 
 | |
| 	a, _, err = obj.LookupTXT(context.Background(), "split-txt.letsencrypt.org")
 | |
| 	t.Logf("A: %v ", a)
 | |
| 	test.AssertNotError(t, err, "No message")
 | |
| 	test.AssertEquals(t, len(a), 1)
 | |
| 	test.AssertEquals(t, a[0], "abc")
 | |
| }
 | |
| 
 | |
| func TestDNSLookupHost(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	ip, resolvers, err := obj.LookupHost(context.Background(), "servfail.com")
 | |
| 	t.Logf("servfail.com - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertError(t, err, "Server failure")
 | |
| 	test.Assert(t, len(ip) == 0, "Should not have IPs")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "nonexistent.letsencrypt.org")
 | |
| 	t.Logf("nonexistent.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertError(t, err, "No valid A or AAAA records should error")
 | |
| 	test.Assert(t, len(ip) == 0, "Should not have IPs")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// Single IPv4 address
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org")
 | |
| 	t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 1, "Should have IP")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "cps.letsencrypt.org")
 | |
| 	t.Logf("cps.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 1, "Should have IP")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// Single IPv6 address
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "v6.letsencrypt.org")
 | |
| 	t.Logf("v6.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 1, "Should not have IPs")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// Both IPv6 and IPv4 address
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "dualstack.letsencrypt.org")
 | |
| 	t.Logf("dualstack.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 2, "Should have 2 IPs")
 | |
| 	expected := net.ParseIP("127.0.0.1")
 | |
| 	test.Assert(t, ip[0].To4().Equal(expected), "wrong ipv4 address")
 | |
| 	expected = net.ParseIP("::1")
 | |
| 	test.Assert(t, ip[1].To16().Equal(expected), "wrong ipv6 address")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// IPv6 error, IPv4 success
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "v6error.letsencrypt.org")
 | |
| 	t.Logf("v6error.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 1, "Should have 1 IP")
 | |
| 	expected = net.ParseIP("127.0.0.1")
 | |
| 	test.Assert(t, ip[0].To4().Equal(expected), "wrong ipv4 address")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// IPv6 success, IPv4 error
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), "v4error.letsencrypt.org")
 | |
| 	t.Logf("v4error.letsencrypt.org - IP: %s, Err: %s", ip, err)
 | |
| 	test.AssertNotError(t, err, "Not an error to exist")
 | |
| 	test.Assert(t, len(ip) == 1, "Should have 1 IP")
 | |
| 	expected = net.ParseIP("::1")
 | |
| 	test.Assert(t, ip[0].To16().Equal(expected), "wrong ipv6 address")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| 
 | |
| 	// IPv6 error, IPv4 error
 | |
| 	// Should return both the IPv4 error (Refused) and the IPv6 error (NotImplemented)
 | |
| 	hostname := "dualstackerror.letsencrypt.org"
 | |
| 	ip, resolvers, err = obj.LookupHost(context.Background(), hostname)
 | |
| 	t.Logf("%s - IP: %s, Err: %s", hostname, ip, err)
 | |
| 	test.AssertError(t, err, "Should be an error")
 | |
| 	test.AssertContains(t, err.Error(), "REFUSED looking up A for")
 | |
| 	test.AssertContains(t, err.Error(), "NOTIMP looking up AAAA for")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"A:127.0.0.1:4053", "AAAA:127.0.0.1:4053"})
 | |
| }
 | |
| 
 | |
| func TestDNSNXDOMAIN(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 
 | |
| 	hostname := "nxdomain.letsencrypt.org"
 | |
| 	_, _, err = obj.LookupHost(context.Background(), hostname)
 | |
| 	test.AssertContains(t, err.Error(), "NXDOMAIN looking up A for")
 | |
| 	test.AssertContains(t, err.Error(), "NXDOMAIN looking up AAAA for")
 | |
| 
 | |
| 	_, _, err = obj.LookupTXT(context.Background(), hostname)
 | |
| 	expected := Error{dns.TypeTXT, hostname, nil, dns.RcodeNameError, nil}
 | |
| 	test.AssertDeepEquals(t, err, expected)
 | |
| }
 | |
| 
 | |
| func TestDNSLookupCAA(t *testing.T) {
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock(), nil)
 | |
| 	removeIDExp := regexp.MustCompile(" id: [[:digit:]]+")
 | |
| 
 | |
| 	caas, resp, resolvers, err := obj.LookupCAA(context.Background(), "bracewel.net")
 | |
| 	test.AssertNotError(t, err, "CAA lookup failed")
 | |
| 	test.Assert(t, len(caas) > 0, "Should have CAA records")
 | |
| 	test.AssertEquals(t, len(resolvers), 1)
 | |
| 	test.AssertDeepEquals(t, resolvers, ResolverAddrs{"127.0.0.1:4053"})
 | |
| 	expectedResp := `;; opcode: QUERY, status: NOERROR, id: XXXX
 | |
| ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
 | |
| 
 | |
| ;; QUESTION SECTION:
 | |
| ;bracewel.net.	IN	 CAA
 | |
| 
 | |
| ;; ANSWER SECTION:
 | |
| bracewel.net.	0	IN	CAA	1 issue "letsencrypt.org"
 | |
| `
 | |
| 	test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp)
 | |
| 
 | |
| 	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nonexistent.letsencrypt.org")
 | |
| 	test.AssertNotError(t, err, "CAA lookup failed")
 | |
| 	test.Assert(t, len(caas) == 0, "Shouldn't have CAA records")
 | |
| 	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
 | |
| 	expectedResp = ""
 | |
| 	test.AssertEquals(t, resp, expectedResp)
 | |
| 
 | |
| 	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "nxdomain.letsencrypt.org")
 | |
| 	slices.Sort(resolvers)
 | |
| 	test.AssertNotError(t, err, "CAA lookup failed")
 | |
| 	test.Assert(t, len(caas) == 0, "Shouldn't have CAA records")
 | |
| 	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
 | |
| 	expectedResp = ""
 | |
| 	test.AssertEquals(t, resp, expectedResp)
 | |
| 
 | |
| 	caas, resp, resolvers, err = obj.LookupCAA(context.Background(), "cname.example.com")
 | |
| 	test.AssertNotError(t, err, "CAA lookup failed")
 | |
| 	test.Assert(t, len(caas) > 0, "Should follow CNAME to find CAA")
 | |
| 	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
 | |
| 	expectedResp = `;; opcode: QUERY, status: NOERROR, id: XXXX
 | |
| ;; flags: qr rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
 | |
| 
 | |
| ;; QUESTION SECTION:
 | |
| ;cname.example.com.	IN	 CAA
 | |
| 
 | |
| ;; ANSWER SECTION:
 | |
| caa.example.com.	0	IN	CAA	1 issue "letsencrypt.org"
 | |
| `
 | |
| 	test.AssertEquals(t, removeIDExp.ReplaceAllString(resp, " id: XXXX"), expectedResp)
 | |
| 
 | |
| 	_, _, resolvers, err = obj.LookupCAA(context.Background(), "gonetld")
 | |
| 	test.AssertError(t, err, "should fail for TLD NXDOMAIN")
 | |
| 	test.AssertContains(t, err.Error(), "NXDOMAIN")
 | |
| 	test.AssertEquals(t, resolvers[0], "127.0.0.1:4053")
 | |
| }
 | |
| 
 | |
| func TestIsPrivateIP(t *testing.T) {
 | |
| 	test.Assert(t, isPrivateV4(net.ParseIP("127.0.0.1")), "should be private")
 | |
| 	test.Assert(t, isPrivateV4(net.ParseIP("192.168.254.254")), "should be private")
 | |
| 	test.Assert(t, isPrivateV4(net.ParseIP("10.255.0.3")), "should be private")
 | |
| 	test.Assert(t, isPrivateV4(net.ParseIP("172.16.255.255")), "should be private")
 | |
| 	test.Assert(t, isPrivateV4(net.ParseIP("172.31.255.255")), "should be private")
 | |
| 	test.Assert(t, !isPrivateV4(net.ParseIP("128.0.0.1")), "should be private")
 | |
| 	test.Assert(t, !isPrivateV4(net.ParseIP("192.169.255.255")), "should not be private")
 | |
| 	test.Assert(t, !isPrivateV4(net.ParseIP("9.255.0.255")), "should not be private")
 | |
| 	test.Assert(t, !isPrivateV4(net.ParseIP("172.32.255.255")), "should not be private")
 | |
| 
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("::0")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("::1")), "should be private")
 | |
| 	test.Assert(t, !isPrivateV6(net.ParseIP("::2")), "should not be private")
 | |
| 
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("fe80::1")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("febf::1")), "should be private")
 | |
| 	test.Assert(t, !isPrivateV6(net.ParseIP("fec0::1")), "should not be private")
 | |
| 	test.Assert(t, !isPrivateV6(net.ParseIP("feff::1")), "should not be private")
 | |
| 
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("ff00::1")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("ff10::1")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")), "should be private")
 | |
| 
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("2002::")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("2002:ffff:ffff:ffff:ffff:ffff:ffff:ffff")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("0100::")), "should be private")
 | |
| 	test.Assert(t, isPrivateV6(net.ParseIP("0100::0000:ffff:ffff:ffff:ffff")), "should be private")
 | |
| 	test.Assert(t, !isPrivateV6(net.ParseIP("0100::0001:0000:0000:0000:0000")), "should be private")
 | |
| }
 | |
| 
 | |
| type testExchanger struct {
 | |
| 	sync.Mutex
 | |
| 	count int
 | |
| 	errs  []error
 | |
| }
 | |
| 
 | |
| var errTooManyRequests = errors.New("too many requests")
 | |
| 
 | |
| func (te *testExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
 | |
| 	te.Lock()
 | |
| 	defer te.Unlock()
 | |
| 	msg := &dns.Msg{
 | |
| 		MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess},
 | |
| 	}
 | |
| 	if len(te.errs) <= te.count {
 | |
| 		return nil, 0, errTooManyRequests
 | |
| 	}
 | |
| 	err := te.errs[te.count]
 | |
| 	te.count++
 | |
| 
 | |
| 	return msg, 2 * time.Millisecond, err
 | |
| }
 | |
| 
 | |
| func TestRetry(t *testing.T) {
 | |
| 	isTempErr := &net.OpError{Op: "read", Err: tempError(true)}
 | |
| 	nonTempErr := &net.OpError{Op: "read", Err: tempError(false)}
 | |
| 	servFailError := errors.New("DNS problem: server failure at resolver looking up TXT for example.com")
 | |
| 	netError := errors.New("DNS problem: networking error looking up TXT for example.com")
 | |
| 	type testCase struct {
 | |
| 		name              string
 | |
| 		maxTries          int
 | |
| 		te                *testExchanger
 | |
| 		expected          error
 | |
| 		expectedCount     int
 | |
| 		metricsAllRetries float64
 | |
| 	}
 | |
| 	tests := []*testCase{
 | |
| 		// The success on first try case
 | |
| 		{
 | |
| 			name:     "success",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{nil},
 | |
| 			},
 | |
| 			expected:      nil,
 | |
| 			expectedCount: 1,
 | |
| 		},
 | |
| 		// Immediate non-OpError, error returns immediately
 | |
| 		{
 | |
| 			name:     "non-operror",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{errors.New("nope")},
 | |
| 			},
 | |
| 			expected:      servFailError,
 | |
| 			expectedCount: 1,
 | |
| 		},
 | |
| 		// Temporary err, then non-OpError stops at two tries
 | |
| 		{
 | |
| 			name:     "err-then-non-operror",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{isTempErr, errors.New("nope")},
 | |
| 			},
 | |
| 			expected:      servFailError,
 | |
| 			expectedCount: 2,
 | |
| 		},
 | |
| 		// Temporary error given always
 | |
| 		{
 | |
| 			name:     "persistent-temp-error",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{
 | |
| 					isTempErr,
 | |
| 					isTempErr,
 | |
| 					isTempErr,
 | |
| 				},
 | |
| 			},
 | |
| 			expected:          netError,
 | |
| 			expectedCount:     3,
 | |
| 			metricsAllRetries: 1,
 | |
| 		},
 | |
| 		// Even with maxTries at 0, we should still let a single request go
 | |
| 		// through
 | |
| 		{
 | |
| 			name:     "zero-maxtries",
 | |
| 			maxTries: 0,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{nil},
 | |
| 			},
 | |
| 			expected:      nil,
 | |
| 			expectedCount: 1,
 | |
| 		},
 | |
| 		// Temporary error given just once causes two tries
 | |
| 		{
 | |
| 			name:     "single-temp-error",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{
 | |
| 					isTempErr,
 | |
| 					nil,
 | |
| 				},
 | |
| 			},
 | |
| 			expected:      nil,
 | |
| 			expectedCount: 2,
 | |
| 		},
 | |
| 		// Temporary error given twice causes three tries
 | |
| 		{
 | |
| 			name:     "double-temp-error",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{
 | |
| 					isTempErr,
 | |
| 					isTempErr,
 | |
| 					nil,
 | |
| 				},
 | |
| 			},
 | |
| 			expected:      nil,
 | |
| 			expectedCount: 3,
 | |
| 		},
 | |
| 		// Temporary error given thrice causes three tries and fails
 | |
| 		{
 | |
| 			name:     "triple-temp-error",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{
 | |
| 					isTempErr,
 | |
| 					isTempErr,
 | |
| 					isTempErr,
 | |
| 				},
 | |
| 			},
 | |
| 			expected:          netError,
 | |
| 			expectedCount:     3,
 | |
| 			metricsAllRetries: 1,
 | |
| 		},
 | |
| 		// temporary then non-Temporary error causes two retries
 | |
| 		{
 | |
| 			name:     "temp-nontemp-error",
 | |
| 			maxTries: 3,
 | |
| 			te: &testExchanger{
 | |
| 				errs: []error{
 | |
| 					isTempErr,
 | |
| 					nonTempErr,
 | |
| 				},
 | |
| 			},
 | |
| 			expected:      netError,
 | |
| 			expectedCount: 2,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for i, tc := range tests {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 			test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 			testClient := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, blog.UseMock(), nil)
 | |
| 			dr := testClient.(*impl)
 | |
| 			dr.dnsClient = tc.te
 | |
| 			_, _, err = dr.LookupTXT(context.Background(), "example.com")
 | |
| 			if err == errTooManyRequests {
 | |
| 				t.Errorf("#%d, sent more requests than the test case handles", i)
 | |
| 			}
 | |
| 			expectedErr := tc.expected
 | |
| 			if (expectedErr == nil && err != nil) ||
 | |
| 				(expectedErr != nil && err == nil) ||
 | |
| 				(expectedErr != nil && expectedErr.Error() != err.Error()) {
 | |
| 				t.Errorf("#%d, error, expected %v, got %v", i, expectedErr, err)
 | |
| 			}
 | |
| 			if tc.expectedCount != tc.te.count {
 | |
| 				t.Errorf("#%d, error, expectedCount %v, got %v", i, tc.expectedCount, tc.te.count)
 | |
| 			}
 | |
| 			if tc.metricsAllRetries > 0 {
 | |
| 				test.AssertMetricWithLabelsEquals(
 | |
| 					t, dr.timeoutCounter, prometheus.Labels{
 | |
| 						"qtype":    "TXT",
 | |
| 						"type":     "out of retries",
 | |
| 						"resolver": "127.0.0.1",
 | |
| 						"isTLD":    "false",
 | |
| 					}, tc.metricsAllRetries)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	testClient := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, blog.UseMock(), nil)
 | |
| 	dr := testClient.(*impl)
 | |
| 	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
 | |
| 	ctx, cancel := context.WithCancel(context.Background())
 | |
| 	cancel()
 | |
| 	_, _, err = dr.LookupTXT(ctx, "example.com")
 | |
| 	if err == nil ||
 | |
| 		err.Error() != "DNS problem: query timed out (and was canceled) looking up TXT for example.com" {
 | |
| 		t.Errorf("expected %s, got %s", context.Canceled, err)
 | |
| 	}
 | |
| 
 | |
| 	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
 | |
| 	ctx, cancel = context.WithTimeout(context.Background(), -10*time.Hour)
 | |
| 	defer cancel()
 | |
| 	_, _, err = dr.LookupTXT(ctx, "example.com")
 | |
| 	if err == nil ||
 | |
| 		err.Error() != "DNS problem: query timed out looking up TXT for example.com" {
 | |
| 		t.Errorf("expected %s, got %s", context.DeadlineExceeded, err)
 | |
| 	}
 | |
| 
 | |
| 	dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
 | |
| 	ctx, deadlineCancel := context.WithTimeout(context.Background(), -10*time.Hour)
 | |
| 	deadlineCancel()
 | |
| 	_, _, err = dr.LookupTXT(ctx, "example.com")
 | |
| 	if err == nil ||
 | |
| 		err.Error() != "DNS problem: query timed out looking up TXT for example.com" {
 | |
| 		t.Errorf("expected %s, got %s", context.DeadlineExceeded, err)
 | |
| 	}
 | |
| 
 | |
| 	test.AssertMetricWithLabelsEquals(
 | |
| 		t, dr.timeoutCounter, prometheus.Labels{
 | |
| 			"qtype":    "TXT",
 | |
| 			"type":     "canceled",
 | |
| 			"resolver": "127.0.0.1",
 | |
| 		}, 1)
 | |
| 
 | |
| 	test.AssertMetricWithLabelsEquals(
 | |
| 		t, dr.timeoutCounter, prometheus.Labels{
 | |
| 			"qtype":    "TXT",
 | |
| 			"type":     "deadline exceeded",
 | |
| 			"resolver": "127.0.0.1",
 | |
| 		}, 2)
 | |
| }
 | |
| 
 | |
| func TestIsTLD(t *testing.T) {
 | |
| 	if isTLD("com") != "true" {
 | |
| 		t.Errorf("expected 'com' to be a TLD, got %q", isTLD("com"))
 | |
| 	}
 | |
| 	if isTLD("example.com") != "false" {
 | |
| 		t.Errorf("expected 'example.com' to not a TLD, got %q", isTLD("example.com"))
 | |
| 	}
 | |
| }
 | |
| 
 | |
// tempError is a bool-backed error used by the tests to build errors that
// report themselves as temporary (true) or not temporary (false).
type tempError bool

// Temporary returns the underlying boolean value.
func (te tempError) Temporary() bool {
	return bool(te)
}

// Error returns "Temporary: true" or "Temporary: false" according to the
// underlying value.
func (te tempError) Error() string {
	// Format the plain bool rather than the tempError value itself.
	return fmt.Sprintf("Temporary: %t", bool(te))
}
 | |
| 
 | |
| // rotateFailureExchanger is a dns.Exchange implementation that tracks a count
 | |
| // of the number of calls to `Exchange` for a given address in the `lookups`
 | |
| // map. For all addresses in the `brokenAddresses` map, a retryable error is
 | |
| // returned from `Exchange`. This mock is used by `TestRotateServerOnErr`.
 | |
| type rotateFailureExchanger struct {
 | |
| 	sync.Mutex
 | |
| 	lookups         map[string]int
 | |
| 	brokenAddresses map[string]bool
 | |
| }
 | |
| 
 | |
| // Exchange for rotateFailureExchanger tracks the `a` argument in `lookups` and
 | |
| // if present in `brokenAddresses`, returns a temporary error.
 | |
| func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
 | |
| 	e.Lock()
 | |
| 	defer e.Unlock()
 | |
| 
 | |
| 	// Track that exchange was called for the given server
 | |
| 	e.lookups[a]++
 | |
| 
 | |
| 	// If its a broken server, return a retryable error
 | |
| 	if e.brokenAddresses[a] {
 | |
| 		isTempErr := &net.OpError{Op: "read", Err: tempError(true)}
 | |
| 		return nil, 2 * time.Millisecond, isTempErr
 | |
| 	}
 | |
| 
 | |
| 	return m, 2 * time.Millisecond, nil
 | |
| }
 | |
| 
 | |
| // TestRotateServerOnErr ensures that a retryable error returned from a DNS
 | |
| // server will result in the retry being performed against the next server in
 | |
| // the list.
 | |
| func TestRotateServerOnErr(t *testing.T) {
 | |
| 	// Configure three DNS servers
 | |
| 	dnsServers := []string{
 | |
| 		"a:53", "b:53", "[2606:4700:4700::1111]:53",
 | |
| 	}
 | |
| 
 | |
| 	// Set up a DNS client using these servers that will retry queries up to
 | |
| 	// a maximum of 5 times. It's important to choose a maxTries value >= the
 | |
| 	// number of dnsServers to ensure we always get around to trying the one
 | |
| 	// working server
 | |
| 	staticProvider, err := NewStaticProvider(dnsServers)
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 	fmt.Println(staticProvider.servers)
 | |
| 
 | |
| 	maxTries := 5
 | |
| 	client := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), maxTries, blog.UseMock(), nil)
 | |
| 
 | |
| 	// Configure a mock exchanger that will always return a retryable error for
 | |
| 	// servers A and B. This will force server "[2606:4700:4700::1111]:53" to do
 | |
| 	// all the work once retries reach it.
 | |
| 	mock := &rotateFailureExchanger{
 | |
| 		brokenAddresses: map[string]bool{
 | |
| 			"a:53": true,
 | |
| 			"b:53": true,
 | |
| 		},
 | |
| 		lookups: make(map[string]int),
 | |
| 	}
 | |
| 	client.(*impl).dnsClient = mock
 | |
| 
 | |
| 	// Perform a bunch of lookups. We choose the initial server randomly. Any time
 | |
| 	// A or B is chosen there should be an error and a retry using the next server
 | |
| 	// in the list. Since we configured maxTries to be larger than the number of
 | |
| 	// servers *all* queries should eventually succeed by being retried against
 | |
| 	// server "[2606:4700:4700::1111]:53".
 | |
| 	for range maxTries * 2 {
 | |
| 		_, resolvers, err := client.LookupTXT(context.Background(), "example.com")
 | |
| 		test.AssertEquals(t, len(resolvers), 1)
 | |
| 		test.AssertEquals(t, resolvers[0], "[2606:4700:4700::1111]:53")
 | |
| 		// Any errors are unexpected - server "[2606:4700:4700::1111]:53" should
 | |
| 		// have responded without error.
 | |
| 		test.AssertNotError(t, err, "Expected no error from eventual retry with functional server")
 | |
| 	}
 | |
| 
 | |
| 	// We expect that the A and B servers had a non-zero number of lookups
 | |
| 	// attempted.
 | |
| 	test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts")
 | |
| 	test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts")
 | |
| 
 | |
| 	// We expect that the server "[2606:4700:4700::1111]:53" eventually served
 | |
| 	// all of the lookups attempted.
 | |
| 	test.AssertEquals(t, mock.lookups["[2606:4700:4700::1111]:53"], maxTries*2)
 | |
| 
 | |
| }
 | |
| 
 | |
// mockTempURLError is a stateless stub error that always describes itself as
// temporary and never as a timeout. dohAlwaysRetryExchanger wraps it in a
// *url.Error.
type mockTempURLError struct{}

// Temporary always reports true.
func (*mockTempURLError) Temporary() bool {
	return true
}

// Timeout always reports false.
func (*mockTempURLError) Timeout() bool {
	return false
}

// Error returns a fixed placeholder message.
func (*mockTempURLError) Error() string {
	return "whoops, oh gosh"
}
 | |
| 
 | |
| type dohAlwaysRetryExchanger struct {
 | |
| 	sync.Mutex
 | |
| 	err error
 | |
| }
 | |
| 
 | |
| func (dohE *dohAlwaysRetryExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
 | |
| 	dohE.Lock()
 | |
| 	defer dohE.Unlock()
 | |
| 
 | |
| 	tempURLerror := &url.Error{
 | |
| 		Op:  "GET",
 | |
| 		URL: "https://example.com",
 | |
| 		Err: &mockTempURLError{},
 | |
| 	}
 | |
| 
 | |
| 	return nil, time.Second, tempURLerror
 | |
| }
 | |
| 
 | |
| func TestDOHMetric(t *testing.T) {
 | |
| 	features.Set(features.Config{DOH: true})
 | |
| 	defer features.Reset()
 | |
| 
 | |
| 	staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
 | |
| 	test.AssertNotError(t, err, "Got error creating StaticProvider")
 | |
| 
 | |
| 	testClient := NewTest(time.Second*11, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 0, blog.UseMock(), nil)
 | |
| 	resolver := testClient.(*impl)
 | |
| 	resolver.dnsClient = &dohAlwaysRetryExchanger{err: &url.Error{Op: "read", Err: tempError(true)}}
 | |
| 
 | |
| 	// Starting out, we should count 0 "out of retries" errors.
 | |
| 	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 0)
 | |
| 
 | |
| 	// Trigger the error.
 | |
| 	_, _, _ = resolver.exchangeOne(context.Background(), "example.com", 0)
 | |
| 
 | |
| 	// Now, we should count 1 "out of retries" errors.
 | |
| 	test.AssertMetricWithLabelsEquals(t, resolver.timeoutCounter, prometheus.Labels{"qtype": "None", "type": "out of retries", "resolver": "127.0.0.1", "isTLD": "false"}, 1)
 | |
| }
 |