BDNS: Ensure DNS server addresses are dialable (#5520)

- Add function `validateServerAddress()` to `bdns/servers.go` which ensures that
  DNS server addresses are TCP/ UDP dial-able per: https://golang.org/src/net/dial.go?#L281
- Add unit test for `validateServerAddress()` in `bdns/servers_test.go`
- Update `cmd/boulder-va/main.go` to handle `bdns.NewStaticProvider()`
  potentially returning an error.
- Update unit tests in `bdns/dns_test.go`:
  - Handle `bdns.NewStaticProvider()` potentially returning an error
  - Add an IPv6 address to `TestRotateServerOnErr`
- Ensure DNS server addresses are validated by `validateServerAddress` whenever:
  - `dynamicProvider.update() is called`
  - `staticProvider` is constructed
- Construct server addresses using `net.JoinHostPost()` when
  `dynamicProvider.Addrs()` is called

Fixes #5463
This commit is contained in:
Samantha 2021-07-20 10:11:11 -07:00 committed by GitHub
parent b59f4386f5
commit 6eee230d69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 198 additions and 37 deletions

View File

@ -243,9 +243,12 @@ func TestMain(m *testing.M) {
}
func TestDNSNoServers(t *testing.T) {
obj := NewTest(time.Hour, NewStaticProvider([]string{}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{})
test.AssertNotError(t, err, "Got error creating StaticProvider")
_, err := obj.LookupHost(context.Background(), "letsencrypt.org")
obj := NewTest(time.Hour, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
_, err = obj.LookupHost(context.Background(), "letsencrypt.org")
test.AssertError(t, err, "No servers")
_, err = obj.LookupTXT(context.Background(), "letsencrypt.org")
@ -256,26 +259,35 @@ func TestDNSNoServers(t *testing.T) {
}
func TestDNSOneServer(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
_, err := obj.LookupHost(context.Background(), "letsencrypt.org")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
_, err = obj.LookupHost(context.Background(), "letsencrypt.org")
test.AssertNotError(t, err, "No message")
}
func TestDNSDuplicateServers(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr, dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
_, err := obj.LookupHost(context.Background(), "letsencrypt.org")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
_, err = obj.LookupHost(context.Background(), "letsencrypt.org")
test.AssertNotError(t, err, "No message")
}
func TestDNSServFail(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
bad := "servfail.com"
_, err := obj.LookupTXT(context.Background(), bad)
_, err = obj.LookupTXT(context.Background(), bad)
test.AssertError(t, err, "LookupTXT didn't return an error")
_, err = obj.LookupHost(context.Background(), bad)
@ -287,7 +299,10 @@ func TestDNSServFail(t *testing.T) {
}
func TestDNSLookupTXT(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
a, err := obj.LookupTXT(context.Background(), "letsencrypt.org")
t.Logf("A: %v", a)
@ -301,7 +316,10 @@ func TestDNSLookupTXT(t *testing.T) {
}
func TestDNSLookupHost(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
ip, err := obj.LookupHost(context.Background(), "servfail.com")
t.Logf("servfail.com - IP: %s, Err: %s", ip, err)
@ -366,10 +384,13 @@ func TestDNSLookupHost(t *testing.T) {
}
func TestDNSNXDOMAIN(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
hostname := "nxdomain.letsencrypt.org"
_, err := obj.LookupHost(context.Background(), hostname)
_, err = obj.LookupHost(context.Background(), hostname)
expected := &Error{dns.TypeA, hostname, nil, dns.RcodeNameError}
test.AssertDeepEquals(t, err, expected)
@ -379,7 +400,10 @@ func TestDNSNXDOMAIN(t *testing.T) {
}
func TestDNSLookupCAA(t *testing.T) {
obj := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
obj := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 1, blog.UseMock())
removeIDExp := regexp.MustCompile(" id: [[:digit:]]+")
caas, resp, err := obj.LookupCAA(context.Background(), "bracewel.net")
@ -600,10 +624,13 @@ func TestRetry(t *testing.T) {
for i, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
testClient := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
testClient := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), tc.maxTries, blog.UseMock())
dr := testClient.(*impl)
dr.dnsClient = tc.te
_, err := dr.LookupTXT(context.Background(), "example.com")
_, err = dr.LookupTXT(context.Background(), "example.com")
if err == errTooManyRequests {
t.Errorf("#%d, sent more requests than the test case handles", i)
}
@ -627,12 +654,15 @@ func TestRetry(t *testing.T) {
})
}
testClient := NewTest(time.Second*10, NewStaticProvider([]string{dnsLoopbackAddr}), metrics.NoopRegisterer, clock.NewFake(), 3, blog.UseMock())
staticProvider, err := NewStaticProvider([]string{dnsLoopbackAddr})
test.AssertNotError(t, err, "Got error creating StaticProvider")
testClient := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), 3, blog.UseMock())
dr := testClient.(*impl)
dr.dnsClient = &testExchanger{errs: []error{isTempErr, isTempErr, nil}}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := dr.LookupTXT(ctx, "example.com")
_, err = dr.LookupTXT(ctx, "example.com")
if err == nil ||
err.Error() != "DNS problem: query timed out (and was canceled) looking up TXT for example.com" {
t.Errorf("expected %s, got %s", context.Canceled, err)
@ -710,21 +740,29 @@ func (e *rotateFailureExchanger) Exchange(m *dns.Msg, a string) (*dns.Msg, time.
func TestRotateServerOnErr(t *testing.T) {
// Configure three DNS servers
dnsServers := []string{
"a", "b", "c",
"a:53", "b:53", "[2606:4700:4700::1111]:53",
}
// Set up a DNS client using these servers that will retry queries up to
// a maximum of 5 times. It's important to choose a maxTries value >= the
// number of dnsServers to ensure we always get around to trying the one
// working server
staticProvider, err := NewStaticProvider(dnsServers)
test.AssertNotError(t, err, "Got error creating StaticProvider")
fmt.Println(staticProvider.servers)
maxTries := 5
client := NewTest(time.Second*10, NewStaticProvider(dnsServers), metrics.NoopRegisterer, clock.NewFake(), maxTries, blog.UseMock())
client := NewTest(time.Second*10, staticProvider, metrics.NoopRegisterer, clock.NewFake(), maxTries, blog.UseMock())
// Configure a mock exchanger that will always return a retryable error for
// the A and B servers. This will force the C server to do all the work once
// retries reach it.
// servers A and B. This will force server "[2606:4700:4700::1111]:53" to do
// all the work once retries reach it.
mock := &rotateFailureExchanger{
brokenAddresses: map[string]bool{"a": true, "b": true},
lookups: make(map[string]int),
brokenAddresses: map[string]bool{
"a:53": true,
"b:53": true,
},
lookups: make(map[string]int),
}
client.(*impl).dnsClient = mock
@ -732,16 +770,21 @@ func TestRotateServerOnErr(t *testing.T) {
// A or B is chosen there should be an error and a retry using the next server
// in the list. Since we configured maxTries to be larger than the number of
// servers *all* queries should eventually succeed by being retried against
// the C server.
// server "[2606:4700:4700::1111]:53".
for i := 0; i < maxTries*2; i++ {
_, err := client.LookupTXT(context.Background(), "example.com")
// Any errors are unexpected - the C server should have responded without error.
// Any errors are unexpected - server "[2606:4700:4700::1111]:53" should
// have responded without error.
test.AssertNotError(t, err, "Expected no error from eventual retry with functional server")
}
// We expect that the A and B servers had a non-zero number of lookups attempted
test.Assert(t, mock.lookups["a"] > 0, "Expected A server to have non-zero lookup attempts")
test.Assert(t, mock.lookups["b"] > 0, "Expected B server to have non-zero lookup attempts")
// We expect that the C server eventually served all of the lookups attempted
test.AssertEquals(t, mock.lookups["c"], maxTries*2)
// We expect that the A and B servers had a non-zero number of lookups
// attempted.
test.Assert(t, mock.lookups["a:53"] > 0, "Expected A server to have non-zero lookup attempts")
test.Assert(t, mock.lookups["b:53"] > 0, "Expected B server to have non-zero lookup attempts")
// We expect that the server "[2606:4700:4700::1111]:53" eventually served
// all of the lookups attempted.
test.AssertEquals(t, mock.lookups["[2606:4700:4700::1111]:53"], maxTries*2)
}

View File

@ -2,12 +2,15 @@ package bdns
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"sync"
"time"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
)
@ -29,8 +32,52 @@ type staticProvider struct {
var _ ServerProvider = &staticProvider{}
func NewStaticProvider(servers []string) *staticProvider {
return &staticProvider{servers: servers}
// validateServerAddress ensures that a given server address is formatted in
// such a way that it can be dialed. The provided server address must include a
// host/IP and port separated by colon. Additionally, if the host is a literal
// IPv6 address, it must be enclosed in square brackets.
// (https://golang.org/src/net/dial.go?s=9833:9881#L281)
func validateServerAddress(address string) error {
// Ensure the host and port portions of `address` can be split.
host, port, err := net.SplitHostPort(address)
if err != nil {
return err
}
// Ensure `address` contains both a `host` and `port` portion.
if host == "" || port == "" {
return errors.New("port cannot be missing")
}
// Ensure the `port` portion of `address` is a valid port.
portNum, err := strconv.Atoi(port)
if err != nil {
return errors.New("port must be an integer: %s")
}
if portNum <= 0 || portNum > 65535 {
return errors.New("port must be an integer between 0 - 65535")
}
// Ensure the `host` portion of `address` is a valid FQDN or IP address.
IPv6 := net.ParseIP(host).To16()
IPv4 := net.ParseIP(host).To4()
FQDN := dns.IsFqdn(dns.Fqdn(host))
if IPv6 == nil && IPv4 == nil && !FQDN {
return errors.New("host is not an FQDN or IP address")
}
return nil
}
func NewStaticProvider(servers []string) (*staticProvider, error) {
var serverAddrs []string
for _, server := range servers {
err := validateServerAddress(server)
if err != nil {
return nil, fmt.Errorf("server address %q invalid: %s", server, err)
}
serverAddrs = append(serverAddrs, server)
}
return &staticProvider{servers: serverAddrs}, nil
}
func (sp *staticProvider) Addrs() ([]string, error) {
@ -137,7 +184,7 @@ func (dp *dynamicProvider) update() error {
if err != nil {
return fmt.Errorf("failed to lookup SRV records for %q: %w", dp.name, err)
}
if srvs == nil || len(srvs) == 0 {
if len(srvs) == 0 {
return fmt.Errorf("no SRV records found for %q", dp.name)
}
@ -148,6 +195,11 @@ func (dp *dynamicProvider) update() error {
return fmt.Errorf("failed to resolve SRV Target %q: %w", srv.Target, err)
}
for _, addr := range addrs {
joinedHostPort := net.JoinHostPort(addr, fmt.Sprint(srv.Port))
err := validateServerAddress(joinedHostPort)
if err != nil {
return fmt.Errorf("invalid SRV addr %q: %w", joinedHostPort, err)
}
addrPorts[addr] = append(addrPorts[addr], srv.Port)
}
}
@ -164,8 +216,8 @@ func (dp *dynamicProvider) Addrs() ([]string, error) {
var r []string
dp.mu.RLock()
for ip, ports := range dp.addrs {
port := ports[rand.Intn(len(ports))]
addr := fmt.Sprintf("%s:%d", ip, port)
port := fmt.Sprint(ports[rand.Intn(len(ports))])
addr := net.JoinHostPort(ip, port)
r = append(r, addr)
}
dp.mu.RUnlock()

62
bdns/servers_test.go Normal file
View File

@ -0,0 +1,62 @@
package bdns
import (
"testing"
)
func Test_validateServerAddress(t *testing.T) {
type args struct {
server string
}
tests := []struct {
name string
args args
wantErr bool
}{
// ipv4 cases
{"ipv4 with port", args{"1.1.1.1:53"}, false},
// sad path
{"ipv4 without port", args{"1.1.1.1"}, true},
{"ipv4 port num missing", args{"1.1.1.1:"}, true},
{"ipv4 string for port", args{"1.1.1.1:foo"}, true},
{"ipv4 port out of range high", args{"1.1.1.1:65536"}, true},
{"ipv4 port out of range low", args{"1.1.1.1:0"}, true},
// ipv6 cases
{"ipv6 with port", args{"[2606:4700:4700::1111]:53"}, false},
// sad path
{"ipv6 sans brackets", args{"2606:4700:4700::1111:53"}, true},
{"ipv6 without port", args{"[2606:4700:4700::1111]"}, true},
{"ipv6 port num missing", args{"[2606:4700:4700::1111]:"}, true},
{"ipv6 string for port", args{"[2606:4700:4700::1111]:foo"}, true},
{"ipv6 port out of range high", args{"[2606:4700:4700::1111]:65536"}, true},
{"ipv6 port out of range low", args{"[2606:4700:4700::1111]:0"}, true},
// hostname cases
{"hostname with port", args{"foo:53"}, false},
// sad path
{"hostname without port", args{"foo"}, true},
{"hostname port num missing", args{"foo:"}, true},
{"hostname string for port", args{"foo:bar"}, true},
{"hostname port out of range high", args{"foo:65536"}, true},
{"hostname port out of range low", args{"foo:0"}, true},
// fqdn cases
{"fqdn with port", args{"bar.foo.baz:53"}, false},
// sad path
{"fqdn without port", args{"bar.foo.baz"}, true},
{"fqdn port num missing", args{"bar.foo.baz:"}, true},
{"fqdn string for port", args{"bar.foo.baz:bar"}, true},
{"fqdn port out of range high", args{"bar.foo.baz:65536"}, true},
{"fqdn port out of range low", args{"bar.foo.baz:0"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateServerAddress(tt.args.server)
if (err != nil) != tt.wantErr {
t.Errorf("formatServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View File

@ -128,7 +128,8 @@ func main() {
servers, err = bdns.StartDynamicProvider(c.VA.DNSResolver, 60*time.Second)
cmd.FailOnError(err, "Couldn't start dynamic DNS server resolver")
} else {
servers = bdns.NewStaticProvider(c.VA.DNSResolvers)
servers, err = bdns.NewStaticProvider(c.VA.DNSResolvers)
cmd.FailOnError(err, "Couldn't parse static DNS server(s)")
}
var resolver bdns.Client

View File

@ -136,9 +136,12 @@ func TestDNSValidationServFail(t *testing.T) {
func TestDNSValidationNoServer(t *testing.T) {
va, log := setup(nil, 0, "", nil)
staticProvider, err := bdns.NewStaticProvider([]string{})
test.AssertNotError(t, err, "Couldn't make new static provider")
va.dnsClient = bdns.NewTest(
time.Second*5,
bdns.NewStaticProvider([]string{}),
staticProvider,
metrics.NoopRegisterer,
clock.New(),
1,