326 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			326 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
| package bdns
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"math/rand/v2"
 | |
| 	"net"
 | |
| 	"strconv"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/miekg/dns"
 | |
| 	"github.com/prometheus/client_golang/prometheus"
 | |
| 
 | |
| 	"github.com/letsencrypt/boulder/cmd"
 | |
| )
 | |
| 
 | |
// ServerProvider is implemented by anything able to tell bdns which DNS
// resolvers to query. Implementations may differ both in how they obtain
// those addresses (a fixed configured list, periodic DNS discovery, ...) and
// in the form the addresses take (e.g. host:port combos vs bare IPs).
type ServerProvider interface {
	// Addrs returns the resolver addresses bdns should use right now.
	Addrs() ([]string, error)
	// Stop halts any background work the provider may have started.
	Stop()
}
 | |
| 
 | |
| // staticProvider stores a list of host:port combos, and provides that whole
 | |
| // list in randomized order when asked for addresses. This replicates the old
 | |
| // behavior of the bdns.impl's servers field.
 | |
| type staticProvider struct {
 | |
| 	servers []string
 | |
| }
 | |
| 
 | |
| var _ ServerProvider = &staticProvider{}
 | |
| 
 | |
// validateServerAddress checks that address is a dialable "host:port"
// string. Both a host and a port must be present, separated by a colon, and
// a literal IPv6 host must be enclosed in square brackets (the form accepted
// by net.SplitHostPort and net.Dial).
// (https://golang.org/src/net/dial.go?s=9833:9881#L281)
//
// The port must be a decimal integer in the range 1-65535. The host must be
// either an IP literal or a syntactically valid domain name.
func validateServerAddress(address string) error {
	// SplitHostPort rejects a missing separator, unbracketed IPv6 literals,
	// and other malformed inputs.
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return err
	}

	// SplitHostPort happily returns empty halves (e.g. ":53" or "host:"),
	// so check each one explicitly and report which is missing.
	if host == "" {
		return errors.New("host cannot be missing")
	}
	if port == "" {
		return errors.New("port cannot be missing")
	}

	portNum, err := strconv.Atoi(port)
	if err != nil {
		return fmt.Errorf("parsing port number: %w", err)
	}
	if portNum < 1 || portNum > 65535 {
		return errors.New("port must be an integer between 1 - 65535")
	}

	// Any IP literal is an acceptable host.
	if net.ParseIP(host) != nil {
		return nil
	}
	// Otherwise require a well-formed domain name. dns.IsDomainName is
	// intentionally liberal (it mainly enforces the 63-octet label and
	// 255-octet name limits), but unlike dns.IsFqdn(dns.Fqdn(host)) it can
	// actually return false: Fqdn appends the very trailing dot IsFqdn tests
	// for, so that combination accepts every string.
	if _, ok := dns.IsDomainName(host); !ok {
		return errors.New("host is not an FQDN or IP address")
	}
	return nil
}
 | |
| 
 | |
| func NewStaticProvider(servers []string) (*staticProvider, error) {
 | |
| 	var serverAddrs []string
 | |
| 	for _, server := range servers {
 | |
| 		err := validateServerAddress(server)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("server address %q invalid: %s", server, err)
 | |
| 		}
 | |
| 		serverAddrs = append(serverAddrs, server)
 | |
| 	}
 | |
| 	return &staticProvider{servers: serverAddrs}, nil
 | |
| }
 | |
| 
 | |
// Addrs returns every configured server address in a newly shuffled order.
// The result is a fresh slice, so callers may modify it without affecting
// the provider. It fails when no servers were configured.
func (sp *staticProvider) Addrs() ([]string, error) {
	n := len(sp.servers)
	if n == 0 {
		return nil, fmt.Errorf("no servers configured")
	}
	shuffled := make([]string, n)
	for dst, src := range rand.Perm(n) {
		shuffled[dst] = sp.servers[src]
	}
	return shuffled, nil
}

// Stop does nothing: a staticProvider has no background work to halt.
func (sp *staticProvider) Stop() {}
 | |
| 
 | |
| // dynamicProvider uses DNS to look up the set of IP addresses which correspond
 | |
| // to its single host. It returns this list in random order when asked for
 | |
| // addresses, and refreshes it regularly using a goroutine started by its
 | |
| // constructor.
 | |
| type dynamicProvider struct {
 | |
| 	// dnsAuthority is the single <hostname|IPv4|[IPv6]>:<port> of the DNS
 | |
| 	// server to be used for resolution of DNS backends. If the address contains
 | |
| 	// a hostname it will be resolved via the system DNS. If the port is left
 | |
| 	// unspecified it will default to '53'. If this field is left unspecified
 | |
| 	// the system DNS will be used for resolution of DNS backends.
 | |
| 	dnsAuthority string
 | |
| 	// service is the service name to look up SRV records for within the domain.
 | |
| 	// If this field is left unspecified 'dns' will be used as the service name.
 | |
| 	service string
 | |
| 	// proto is the IP protocol (tcp or udp) to look up SRV records for.
 | |
| 	proto string
 | |
| 	// domain is the name to look up SRV records within.
 | |
| 	domain string
 | |
| 	// A map of IP addresses (results of A record lookups for SRV Targets) to
 | |
| 	// ports (Port fields in SRV records) associated with those addresses.
 | |
| 	addrs map[string][]uint16
 | |
| 	// Other internal bookkeeping state.
 | |
| 	cancel        chan interface{}
 | |
| 	mu            sync.RWMutex
 | |
| 	refresh       time.Duration
 | |
| 	updateCounter *prometheus.CounterVec
 | |
| }
 | |
| 
 | |
// ParseTarget splits a user-supplied target into a host and a port, falling
// back to defaultPort when the target does not name one. Square brackets
// around an IPv6 host are stripped, and an empty host (as in ":80") is
// replaced by "localhost", mirroring net.Dial.
//
// Examples:
//   - target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
//   - target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
//   - target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
//   - target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
//
// An empty target, or one that ends in a port-separator colon with no port
// after it, is an error.
//
// Adapted with minimal changes from:
// https://github.com/grpc/grpc-go/blob/master/internal/resolver/dns/dns_resolver.go
func ParseTarget(target, defaultPort string) (host, port string, err error) {
	if target == "" {
		return "", "", errors.New("missing address")
	}

	// A bare IP literal (IPv4, or IPv6 without brackets) never carries a
	// port, so it is used verbatim.
	if net.ParseIP(target) != nil {
		return target, defaultPort, nil
	}

	// The target already names a port: ipv4-host:port, [ipv6-host]:port,
	// or host-name:port.
	if h, p, splitErr := net.SplitHostPort(target); splitErr == nil {
		switch {
		case p == "":
			return "", "", errors.New("missing port after port-separator colon")
		case h == "":
			// Same convention as net.Dial: an empty host means this machine.
			h = "localhost"
		}
		return h, p, nil
	}

	// No port in the target; retry with the default port appended.
	h, p, withDefaultErr := net.SplitHostPort(target + ":" + defaultPort)
	if withDefaultErr != nil {
		return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, withDefaultErr)
	}
	return h, p, nil
}
 | |
| 
 | |
| var _ ServerProvider = &dynamicProvider{}
 | |
| 
 | |
// StartDynamicProvider builds a dynamicProvider from c and starts its
// background refresh goroutine. Each refresh queries the SRV records for
// _<service>._<proto>.<domain>, sending every query to c.DNSAuthority. It
// then resolves each SRV target to its IP addresses and replaces the set
// returned by Addrs with the resulting IP/port combinations. The Priority
// and Weight fields of the SRV records are ignored.
//
// proto is the SRV protocol label to query (tcp or udp). refresh is the
// interval between background updates and must be positive; each update is
// given refresh/2 to complete. c.SRVLookup.Domain is required. An empty
// c.SRVLookup.Service defaults to "dns", and a c.DNSAuthority without a port
// defaults to port 53.
//
// One update runs synchronously before the goroutine is started, and its
// failure is returned as an error, so a provider returned without error has
// completed at least one successful lookup. Call Stop to end the goroutine.
func StartDynamicProvider(c *cmd.DNSProvider, refresh time.Duration, proto string) (*dynamicProvider, error) {
	if c.SRVLookup.Domain == "" {
		return nil, fmt.Errorf("'domain' cannot be empty")
	}
	// A non-positive interval would make time.NewTicker panic in run, and it
	// gives every update an already-expired deadline. Reject it up front
	// instead of surfacing a confusing lookup timeout.
	if refresh <= 0 {
		return nil, fmt.Errorf("refresh interval must be positive, got %s", refresh)
	}

	service := c.SRVLookup.Service
	if service == "" {
		// "dns" is the conventional SRV service name for DNS servers.
		service = "dns"
	}

	host, port, err := ParseTarget(c.DNSAuthority, "53")
	if err != nil {
		return nil, err
	}

	dnsAuthority := net.JoinHostPort(host, port)
	err = validateServerAddress(dnsAuthority)
	if err != nil {
		return nil, err
	}

	// Build the provider as a pointer from the start so the struct, which
	// contains a mutex, is never copied.
	dp := &dynamicProvider{
		dnsAuthority: dnsAuthority,
		service:      service,
		proto:        proto,
		domain:       c.SRVLookup.Domain,
		addrs:        make(map[string][]uint16),
		cancel:       make(chan interface{}),
		refresh:      refresh,
		// NOTE(review): this CounterVec is never registered with a
		// Prometheus registry, so its values are not exported. Registering
		// it would need a registerer to be passed in; confirm intent.
		updateCounter: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "dns_update",
				Help: "Counter of attempts to update a dynamic provider",
			},
			[]string{"success"},
		),
	}

	// Update once synchronously so startup failures are reported to the
	// caller, then hand off to the long-running refresh goroutine.
	if err := dp.update(); err != nil {
		return nil, fmt.Errorf("failed to start dynamic provider: %w", err)
	}
	go dp.run()

	return dp, nil
}
 | |
| 
 | |
// run calls dp.update every dp.refresh until dp.cancel is closed. Each
// attempt increments dp.updateCounter with the "success" label set to "true"
// or "false". A failed update leaves the previously stored addresses in
// place. run blocks, so it is meant to be started in its own goroutine.
func (dp *dynamicProvider) run() {
	t := time.NewTicker(dp.refresh)
	// Release the ticker when the loop exits. Before Go 1.23 an unstopped
	// ticker is never garbage collected.
	defer t.Stop()
	for {
		select {
		case <-t.C:
			err := dp.update()
			dp.updateCounter.With(prometheus.Labels{
				"success": strconv.FormatBool(err == nil),
			}).Inc()
		case <-dp.cancel:
			return
		}
	}
}
 | |
| 
 | |
| // update performs the SRV and A record queries necessary to map the given DNS
 | |
| // domain name to a set of cacheable IP addresses and ports, and stores the
 | |
| // results in dp.addrs.
 | |
| func (dp *dynamicProvider) update() error {
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), dp.refresh/2)
 | |
| 	defer cancel()
 | |
| 
 | |
| 	resolver := &net.Resolver{
 | |
| 		PreferGo: true,
 | |
| 		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
 | |
| 			d := &net.Dialer{}
 | |
| 			return d.DialContext(ctx, network, dp.dnsAuthority)
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	// RFC 2782 formatted SRV record being queried e.g. "_service._proto.name."
 | |
| 	record := fmt.Sprintf("_%s._%s.%s.", dp.service, dp.proto, dp.domain)
 | |
| 
 | |
| 	_, srvs, err := resolver.LookupSRV(ctx, dp.service, dp.proto, dp.domain)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("during SRV lookup of %q: %w", record, err)
 | |
| 	}
 | |
| 	if len(srvs) == 0 {
 | |
| 		return fmt.Errorf("SRV lookup of %q returned 0 results", record)
 | |
| 	}
 | |
| 
 | |
| 	addrPorts := make(map[string][]uint16)
 | |
| 	for _, srv := range srvs {
 | |
| 		addrs, err := resolver.LookupHost(ctx, srv.Target)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("during A/AAAA lookup of target %q from SRV record %q: %w", srv.Target, record, err)
 | |
| 		}
 | |
| 		for _, addr := range addrs {
 | |
| 			joinedHostPort := net.JoinHostPort(addr, fmt.Sprint(srv.Port))
 | |
| 			err := validateServerAddress(joinedHostPort)
 | |
| 			if err != nil {
 | |
| 				return fmt.Errorf("invalid addr %q from SRV record %q: %w", joinedHostPort, record, err)
 | |
| 			}
 | |
| 			addrPorts[addr] = append(addrPorts[addr], srv.Port)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	dp.mu.Lock()
 | |
| 	dp.addrs = addrPorts
 | |
| 	dp.mu.Unlock()
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Addrs returns a shuffled list of IP/port pairs, with the guarantee that no
 | |
| // two IP/port pairs will share the same IP.
 | |
| func (dp *dynamicProvider) Addrs() ([]string, error) {
 | |
| 	var r []string
 | |
| 	dp.mu.RLock()
 | |
| 	for ip, ports := range dp.addrs {
 | |
| 		port := fmt.Sprint(ports[rand.IntN(len(ports))])
 | |
| 		addr := net.JoinHostPort(ip, port)
 | |
| 		r = append(r, addr)
 | |
| 	}
 | |
| 	dp.mu.RUnlock()
 | |
| 	rand.Shuffle(len(r), func(i, j int) {
 | |
| 		r[i], r[j] = r[j], r[i]
 | |
| 	})
 | |
| 	return r, nil
 | |
| }
 | |
| 
 | |
| // Stop tells the background update goroutine to cease. It does not wait for
 | |
| // confirmation that it has done so.
 | |
| func (dp *dynamicProvider) Stop() {
 | |
| 	close(dp.cancel)
 | |
| }
 |