Use a context when dialing TLS for TLS-SNI (#3648)

This allows us to have fast-running unittests without modifying the global state in singleDialTimeout,
which can become a const.

Fixes #3628.

Builds on top of #3629, review that first.
This commit is contained in:
Jacob Hoffman-Andrews 2018-04-16 15:06:56 -07:00 committed by GitHub
parent 339ea954bd
commit a1b98d9163
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 32 deletions

View File

@ -47,7 +47,7 @@ const (
// before timing out. This timeout ignores the base RPC timeout and is strictly
// used for the DialContext operations that take place during an
// HTTP-01/TLS-SNI-[01|02] challenge validation.
var singleDialTimeout = time.Second * 10
const singleDialTimeout = time.Second * 10
// RemoteVA wraps the core.ValidationAuthority interface and adds a field containing the addresses
// of the remote gRPC server since the interface (and the underlying gRPC client) doesn't
@ -508,7 +508,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
if !features.Enabled(features.IPv6First) {
address := net.JoinHostPort(addresses[0].String(), thisRecord.Port)
thisRecord.AddressUsed = addresses[0]
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
return certs, validationRecords, err
}
@ -518,7 +518,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
address := net.JoinHostPort(v6[0].String(), thisRecord.Port)
thisRecord.AddressUsed = v6[0]
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
// If there is no error, return immediately
if err == nil {
@ -547,7 +547,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
// talking to the first IPv6 address, try the first IPv4 address
address := net.JoinHostPort(v4[0].String(), thisRecord.Port)
thisRecord.AddressUsed = v4[0]
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
return certs, validationRecords, err
}
@ -575,13 +575,15 @@ func (va *ValidationAuthorityImpl) validateTLSSNI01WithZName(ctx context.Context
return validationRecords, probs.Unauthorized(errText)
}
func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]*x509.Certificate, *probs.ProblemDetails) {
func (va *ValidationAuthorityImpl) getTLSSNICerts(
ctx context.Context,
hostPort string,
identifier core.AcmeIdentifier,
challenge core.Challenge,
zName string,
) ([]*x509.Certificate, *probs.ProblemDetails) {
va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName))
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: singleDialTimeout}, "tcp", hostPort, &tls.Config{
ServerName: zName,
InsecureSkipVerify: true,
})
conn, err := tlsDial(ctx, hostPort, zName)
if err != nil {
va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err))
return nil, detailedError(err)
@ -604,6 +606,35 @@ func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier co
return certs, nil
}
// tlsDial does the equivalent of tls.Dial, but obeying a context. Once
// tls.DialContextWithDialer is available, switch to that.
func tlsDial(ctx context.Context, hostPort, zName string) (*tls.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, singleDialTimeout)
defer cancel()
dialer := &net.Dialer{}
netConn, err := dialer.DialContext(ctx, "tcp", hostPort)
if err != nil {
return nil, err
}
conn := tls.Client(netConn, &tls.Config{
ServerName: zName,
InsecureSkipVerify: true,
})
errChan := make(chan error)
go func() {
errChan <- conn.Handshake()
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errChan:
if err != nil {
return nil, err
}
}
return conn, nil
}
func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) {
if identifier.Type != core.IdentifierDNS {
va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier))

View File

@ -598,14 +598,6 @@ func slowTLSSrv() *httptest.Server {
}
func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
// Set a short dial timeout so this test can happen quickly. Note: It would be
// better to override this with a context, but that doesn't work right now:
// https://github.com/letsencrypt/boulder/issues/3628
oldSingleDialTimeout := singleDialTimeout
singleDialTimeout = 50 * time.Millisecond
defer func() {
singleDialTimeout = oldSingleDialTimeout
}()
chall := createChallenge(core.ChallengeTypeTLSSNI01)
hs := slowTLSSrv()
va, _ := setup(hs, 0)
@ -628,7 +620,8 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
t.Fatalf("TLSSNI returned before %s (%s) with %#v", timeout, took, prob)
}
if took > 2*timeout {
t.Fatalf("TLSSNI didn't timeout after %s", timeout)
t.Fatalf("TLSSNI didn't timeout after %s (took %s to return %#v)", timeout,
took, prob)
}
if prob == nil {
t.Fatalf("Connection should've timed out")
@ -641,21 +634,16 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
}
func TestTLSSNI01DialTimeout(t *testing.T) {
// Set a short dial timeout so this test can happen quickly. Note: It would be
// better to override this with a context, but that doesn't work right now:
// https://github.com/letsencrypt/boulder/issues/3628
old := singleDialTimeout
singleDialTimeout = 50 * time.Millisecond
defer func() {
singleDialTimeout = old
}()
timeout := singleDialTimeout
chall := createChallenge(core.ChallengeTypeTLSSNI01)
hs := slowTLSSrv()
va, _ := setup(hs, 0)
va.dnsClient = dnsMockReturnsUnroutable{&bdns.MockDNSClient{}}
started := time.Now()
timeout := 50 * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// The only method I've found so far to trigger a connect timeout is to
// connect to an unrouteable IP address. This usuall generates a connection
// timeout, but will rarely return "Network unreachable" instead. If we get
@ -734,10 +722,10 @@ func TestTLSSNI01TalkingToHTTP(t *testing.T) {
_, prob := va.validateTLSSNI01(ctx, dnsi("localhost"), chall)
test.AssertError(t, prob, "TLS-SNI-01 validation passed when talking to a HTTP-only server")
test.Assert(t, strings.HasSuffix(
prob.Detail,
"Server only speaks HTTP, not TLS",
), "validate TLS-SNI-01 didn't return useful error")
expected := "Server only speaks HTTP, not TLS"
if !strings.HasSuffix(prob.Detail, expected) {
t.Errorf("Got wrong error detail. Expected %q, got %q", expected, prob.Detail)
}
}
func brokenTLSSrv() *httptest.Server {