diff --git a/va/va.go b/va/va.go index efb86735b..a1d66d5aa 100644 --- a/va/va.go +++ b/va/va.go @@ -47,7 +47,7 @@ const ( // before timing out. This timeout ignores the base RPC timeout and is strictly // used for the DialContext operations that take place during an // HTTP-01/TLS-SNI-[01|02] challenge validation. -var singleDialTimeout = time.Second * 10 +const singleDialTimeout = time.Second * 10 // RemoteVA wraps the core.ValidationAuthority interface and adds a field containing the addresses // of the remote gRPC server since the interface (and the underlying gRPC client) doesn't @@ -508,7 +508,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi if !features.Enabled(features.IPv6First) { address := net.JoinHostPort(addresses[0].String(), thisRecord.Port) thisRecord.AddressUsed = addresses[0] - certs, err := va.getTLSSNICerts(address, identifier, challenge, zName) + certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName) return certs, validationRecords, err } @@ -518,7 +518,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi address := net.JoinHostPort(v6[0].String(), thisRecord.Port) thisRecord.AddressUsed = v6[0] - certs, err := va.getTLSSNICerts(address, identifier, challenge, zName) + certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName) // If there is no error, return immediately if err == nil { @@ -547,7 +547,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi // talking to the first IPv6 address, try the first IPv4 address address := net.JoinHostPort(v4[0].String(), thisRecord.Port) thisRecord.AddressUsed = v4[0] - certs, err := va.getTLSSNICerts(address, identifier, challenge, zName) + certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName) return certs, validationRecords, err } @@ -575,13 +575,15 @@ func (va *ValidationAuthorityImpl) validateTLSSNI01WithZName(ctx context.Context return validationRecords, probs.Unauthorized(errText) } -func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]*x509.Certificate, *probs.ProblemDetails) { +func (va *ValidationAuthorityImpl) getTLSSNICerts( + ctx context.Context, + hostPort string, + identifier core.AcmeIdentifier, + challenge core.Challenge, + zName string, +) ([]*x509.Certificate, *probs.ProblemDetails) { va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName)) - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: singleDialTimeout}, "tcp", hostPort, &tls.Config{ - ServerName: zName, - InsecureSkipVerify: true, - }) - + conn, err := tlsDial(ctx, hostPort, zName) if err != nil { va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err)) return nil, detailedError(err) @@ -604,6 +606,35 @@ func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier co return certs, nil } +// tlsDial does the equivalent of tls.Dial, but obeying a context. Once +// tls.DialContextWithDialer is available, switch to that. +func tlsDial(ctx context.Context, hostPort, zName string) (*tls.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, singleDialTimeout) + defer cancel() + dialer := &net.Dialer{} + netConn, err := dialer.DialContext(ctx, "tcp", hostPort) + if err != nil { + return nil, err + } + conn := tls.Client(netConn, &tls.Config{ + ServerName: zName, + InsecureSkipVerify: true, + }) + errChan := make(chan error) + go func() { + errChan <- conn.Handshake() + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errChan: + if err != nil { + return nil, err + } + } + return conn, nil +} + func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) { if identifier.Type != core.IdentifierDNS { va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier)) diff --git a/va/va_test.go b/va/va_test.go index dcbcc5c3c..ed156acd4 100644 --- a/va/va_test.go +++ b/va/va_test.go @@ -598,14 +598,6 @@ func slowTLSSrv() *httptest.Server { } func TestTLSSNI01TimeoutAfterConnect(t *testing.T) { - // Set a short dial timeout so this test can happen quickly. Note: It would be - // better to override this with a context, but that doesn't work right now: - // https://github.com/letsencrypt/boulder/issues/3628 - oldSingleDialTimeout := singleDialTimeout - singleDialTimeout = 50 * time.Millisecond - defer func() { - singleDialTimeout = oldSingleDialTimeout - }() chall := createChallenge(core.ChallengeTypeTLSSNI01) hs := slowTLSSrv() va, _ := setup(hs, 0) @@ -628,7 +620,8 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) { t.Fatalf("TLSSNI returned before %s (%s) with %#v", timeout, took, prob) } if took > 2*timeout { - t.Fatalf("TLSSNI didn't timeout after %s", timeout) + t.Fatalf("TLSSNI didn't timeout after %s (took %s to return %#v)", timeout, + took, prob) } if prob == nil { t.Fatalf("Connection should've timed out") @@ -641,21 +634,16 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) { } func TestTLSSNI01DialTimeout(t *testing.T) { - // Set a short dial timeout so this test can happen quickly. Note: It would be - // better to override this with a context, but that doesn't work right now: - // https://github.com/letsencrypt/boulder/issues/3628 - old := singleDialTimeout - singleDialTimeout = 50 * time.Millisecond - defer func() { - singleDialTimeout = old - }() - timeout := singleDialTimeout chall := createChallenge(core.ChallengeTypeTLSSNI01) hs := slowTLSSrv() va, _ := setup(hs, 0) va.dnsClient = dnsMockReturnsUnroutable{&bdns.MockDNSClient{}} started := time.Now() + timeout := 50 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // The only method I've found so far to trigger a connect timeout is to // connect to an unrouteable IP address. This usuall generates a connection // timeout, but will rarely return "Network unreachable" instead. If we get @@ -734,10 +722,10 @@ func TestTLSSNI01TalkingToHTTP(t *testing.T) { _, prob := va.validateTLSSNI01(ctx, dnsi("localhost"), chall) test.AssertError(t, prob, "TLS-SNI-01 validation passed when talking to a HTTP-only server") - test.Assert(t, strings.HasSuffix( - prob.Detail, - "Server only speaks HTTP, not TLS", - ), "validate TLS-SNI-01 didn't return useful error") + expected := "Server only speaks HTTP, not TLS" + if !strings.HasSuffix(prob.Detail, expected) { + t.Errorf("Got wrong error detail. Expected %q, got %q", expected, prob.Detail) + } } func brokenTLSSrv() *httptest.Server {