diff --git a/network/transports.go b/network/transports.go index d96eda117..c97d0052e 100644 --- a/network/transports.go +++ b/network/transports.go @@ -55,6 +55,9 @@ var backOffTemplate = wait.Backoff{ Steps: 15, } +// ErrTimeoutDialing when the timeout is reached after set amount of time. +var ErrTimeoutDialing = errors.New("timed out dialing") + // DialWithBackOff executes `net.Dialer.DialContext()` with exponentially increasing // dial timeouts. In addition it sleeps with random jitter between tries. var DialWithBackOff = NewBackoffDialer(backOffTemplate) @@ -110,7 +113,7 @@ func dialBackOffHelper(ctx context.Context, network, address string, bo wait.Bac return c, nil } elapsed := time.Since(start) - return nil, fmt.Errorf("timed out dialing after %.2fs", elapsed.Seconds()) + return nil, fmt.Errorf("%w after %.2fs", ErrTimeoutDialing, elapsed.Seconds()) } func newHTTPTransport(disableKeepAlives, disableCompression bool, maxIdle, maxIdlePerHost int) http.RoundTripper { diff --git a/network/transports_test.go b/network/transports_test.go index 7691e1628..4aeff2f08 100644 --- a/network/transports_test.go +++ b/network/transports_test.go @@ -20,18 +20,19 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" + "io" "net" "net/http" "net/http/httptest" "strings" + "syscall" "testing" + "time" "k8s.io/apimachinery/pkg/util/sets" -) - -const ( - timeoutErr = "timed out dialing" - connectionRefusedErr = "connection refused" + "k8s.io/apimachinery/pkg/util/wait" ) func TestHTTPRoundTripper(t *testing.T) { @@ -76,67 +77,187 @@ func TestHTTPRoundTripper(t *testing.T) { } } -func TestDialWithBackoff(t *testing.T) { - // Make the test short. - bo := backOffTemplate - bo.Steps = 2 +func TestDialWithBackoffConnectionRefused(t *testing.T) { + testDialWithBackoffConnectionRefused(nil, t) +} - // Nobody's listening on a random port. Usually. - c, err := dialBackOffHelper(context.Background(), "tcp4", "127.0.0.1:41482", bo, nil) - verifyFailedConnection(t, c, err, connectionRefusedErr) +func TestDialWithBackoffTimeout(t *testing.T) { + testDialWithBackoffTimeout(nil, t) +} - // Timeout. Use special testing IP address. - c, err = dialBackOffHelper(context.Background(), "tcp4", "198.18.0.254:8888", bo, nil) - verifyFailedConnection(t, c, err, timeoutErr) +func TestDialWithBackoffSuccess(t *testing.T) { + testDialWithBackoffSuccess(nil, t) +} - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) +func TestDialTLSWithBackoffConnectionRefused(t *testing.T) { + testDialWithBackoffConnectionRefused(exampleTlsConf(), t) +} + +func TestDialTLSWithBackoffTimeout(t *testing.T) { + testDialWithBackoffTimeout(exampleTlsConf(), t) +} + +func TestDialTLSWithBackoffSuccess(t *testing.T) { + testDialWithBackoffSuccess(exampleTlsConf(), t) +} + +func testDialWithBackoffConnectionRefused(tlsConf *tls.Config, t testingT) { + ctx := context.TODO() + port := findUnusedPortOrFail(t) + addr := fmt.Sprintf("127.0.0.1:%d", port) + dialer := newDialer(ctx, tlsConf) + c, err := dialer(addr) + closeOrFail(t, c) + if !errors.Is(err, syscall.ECONNREFUSED) { + t.Fatalf("Unexpected error: %+v", err) + } +} + +func testDialWithBackoffTimeout(tlsConf *tls.Config, t testingT) { + ctx := context.TODO() + closer, addr, err := listenOne() + if err != nil { + t.Fatal("Unable to create listener:", err) + } + defer closer() + c1, err := net.Dial("tcp4", addr.String()) + if err != nil { + t.Fatalf("Unable to connect to server on %s: %s", addr, err) + } + defer closeOrFail(t, c1) + + // Since the backlog is full, the next request must time out. + dialer := newDialer(ctx, tlsConf) + c, err := dialer(addr.String()) + if err == nil { + closeOrFail(t, c) + t.Fatal("Unexpected success dialing") + } + if !errors.Is(err, ErrTimeoutDialing) { + t.Fatalf("Unexpected error: %+v", err) + } +} + +func testDialWithBackoffSuccess(tlsConf *tls.Config, t testingT) { + //goland:noinspection HttpUrlsUsage + const ( + prefixHTTP = "http://" + prefixHTTPS = "https://" + ) + ctx := context.TODO() + var s *httptest.Server + servFn := httptest.NewServer + if tlsConf != nil { + servFn = httptest.NewTLSServer + } + s = servFn(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer s.Close() + prefix := prefixHTTP + if tlsConf != nil { + prefix = prefixHTTPS + rootCAs := x509.NewCertPool() + rootCAs.AddCert(s.Certificate()) + tlsConf.RootCAs = rootCAs + } + addr := strings.TrimPrefix(s.URL, prefix) - c, err = DialWithBackOff(context.Background(), "tcp4", strings.TrimPrefix(s.URL, "http://")) + dialer := newDialer(ctx, tlsConf) + c, err := dialer(addr) if err != nil { t.Fatal("Dial error =", err) } - c.Close() + closeOrFail(t, c) } -func TestDialTLSWithBackoff(t *testing.T) { - // Make the test short. - bo := backOffTemplate - bo.Steps = 2 - - tlsConf := &tls.Config{ +func exampleTlsConf() *tls.Config { + return &tls.Config{ InsecureSkipVerify: false, ServerName: "example.com", MinVersion: tls.VersionTLS12, } +} - // Nobody's listening on a random port. Usually. - c, err := dialBackOffHelper(context.Background(), "tcp4", "127.0.0.1:41482", bo, tlsConf) - verifyFailedConnection(t, c, err, connectionRefusedErr) +func newDialer(ctx context.Context, tlsConf *tls.Config) func(addr string) (net.Conn, error) { + // Make the test short. + bo := wait.Backoff{ + Duration: time.Millisecond, + Factor: 1.4, + Jitter: 0.1, // At most 10% jitter. + Steps: 1, + } - // Timeout. Use special testing IP address. - c, err = dialBackOffHelper(context.Background(), "tcp4", "198.18.0.254:8888", bo, tlsConf) - verifyFailedConnection(t, c, err, timeoutErr) + dialFn := func(addr string) (net.Conn, error) { + return NewBackoffDialer(bo)(ctx, "tcp4", addr) + } + if tlsConf != nil { + dialFn = func(addr string) (net.Conn, error) { + bo.Duration = 10 * time.Millisecond + bo.Steps = 3 + return NewTLSBackoffDialer(bo)(ctx, "tcp4", addr, tlsConf) + } + } + return dialFn +} - s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer s.Close() +func closeOrFail(t testingT, con io.Closer) { + if con == nil { + return + } + if err := con.Close(); err != nil { + t.Fatal(err) + } +} - rootCAs := x509.NewCertPool() - rootCAs.AddCert(s.Certificate()) - tlsConf.RootCAs = rootCAs - - c, err = DialTLSWithBackOff(context.Background(), "tcp4", strings.TrimPrefix(s.URL, "https://"), tlsConf) +func findUnusedPortOrFail(t testingT) int { + l, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatal("Dial error =", err) + t.Fatal(err) } - c.Close() + defer closeOrFail(t, l) + return l.Addr().(*net.TCPAddr).Port } -func verifyFailedConnection(t *testing.T, c net.Conn, err error, prefix string) { - if err == nil { - c.Close() - t.Error("Unexpected success dialing") - } else if !strings.Contains(err.Error(), prefix) { - t.Errorf("Error = %v, want: %s(...)", err, prefix) - } +var errTest = errors.New("testing") + +func newTestErr(msg string, err error) error { + return fmt.Errorf("%w: %s: %v", errTest, msg, err) +} + +// listenOne creates a socket with backlog of one, and use that socket, so +// any other connection will guarantee to timeout. +// +// Golang doesn't allow us to set the backlog argument on syscall.Listen from +// net.ListenTCP, so we need to get directly into syscall land. +func listenOne() (func(), *net.TCPAddr, error) { + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, nil, newTestErr("Couldn't get socket", err) + } + sa := &syscall.SockaddrInet4{ + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + if err = syscall.Bind(fd, sa); err != nil { + return nil, nil, newTestErr("Unable to bind", err) + } + if err = syscall.Listen(fd, 0); err != nil { + return nil, nil, newTestErr("Unable to Listen", err) + } + closer := func() { _ = syscall.Close(fd) } + listenaddr, err := syscall.Getsockname(fd) + if err != nil { + closer() + return nil, nil, newTestErr("Could not get sockname", err) + } + sa = listenaddr.(*syscall.SockaddrInet4) + addr := &net.TCPAddr{ + IP: sa.Addr[:], + Port: sa.Port, + } + return closer, addr, nil +} + +type testingT interface { + Fatal(args ...interface{}) + Fatalf(format string, args ...interface{}) }