Use a context when dialing TLS for TLS-SNI (#3648)
This allows us to have fast-running unittests without modifying the global state in singleDialTimeout, which can become a const. Fixes #3628. Builds on top of #3629, review that first.
This commit is contained in:
parent
339ea954bd
commit
a1b98d9163
51
va/va.go
51
va/va.go
|
@ -47,7 +47,7 @@ const (
|
|||
// before timing out. This timeout ignores the base RPC timeout and is strictly
|
||||
// used for the DialContext operations that take place during an
|
||||
// HTTP-01/TLS-SNI-[01|02] challenge validation.
|
||||
var singleDialTimeout = time.Second * 10
|
||||
const singleDialTimeout = time.Second * 10
|
||||
|
||||
// RemoteVA wraps the core.ValidationAuthority interface and adds a field containing the addresses
|
||||
// of the remote gRPC server since the interface (and the underlying gRPC client) doesn't
|
||||
|
@ -508,7 +508,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
|||
if !features.Enabled(features.IPv6First) {
|
||||
address := net.JoinHostPort(addresses[0].String(), thisRecord.Port)
|
||||
thisRecord.AddressUsed = addresses[0]
|
||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
||||
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||
return certs, validationRecords, err
|
||||
}
|
||||
|
||||
|
@ -518,7 +518,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
|||
address := net.JoinHostPort(v6[0].String(), thisRecord.Port)
|
||||
thisRecord.AddressUsed = v6[0]
|
||||
|
||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
||||
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||
|
||||
// If there is no error, return immediately
|
||||
if err == nil {
|
||||
|
@ -547,7 +547,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
|||
// talking to the first IPv6 address, try the first IPv4 address
|
||||
address := net.JoinHostPort(v4[0].String(), thisRecord.Port)
|
||||
thisRecord.AddressUsed = v4[0]
|
||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
||||
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||
return certs, validationRecords, err
|
||||
}
|
||||
|
||||
|
@ -575,13 +575,15 @@ func (va *ValidationAuthorityImpl) validateTLSSNI01WithZName(ctx context.Context
|
|||
return validationRecords, probs.Unauthorized(errText)
|
||||
}
|
||||
|
||||
func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]*x509.Certificate, *probs.ProblemDetails) {
|
||||
func (va *ValidationAuthorityImpl) getTLSSNICerts(
|
||||
ctx context.Context,
|
||||
hostPort string,
|
||||
identifier core.AcmeIdentifier,
|
||||
challenge core.Challenge,
|
||||
zName string,
|
||||
) ([]*x509.Certificate, *probs.ProblemDetails) {
|
||||
va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName))
|
||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: singleDialTimeout}, "tcp", hostPort, &tls.Config{
|
||||
ServerName: zName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
|
||||
conn, err := tlsDial(ctx, hostPort, zName)
|
||||
if err != nil {
|
||||
va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err))
|
||||
return nil, detailedError(err)
|
||||
|
@ -604,6 +606,35 @@ func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier co
|
|||
return certs, nil
|
||||
}
|
||||
|
||||
// tlsDial does the equivalent of tls.Dial, but obeying a context. Once
|
||||
// tls.DialContextWithDialer is available, switch to that.
|
||||
func tlsDial(ctx context.Context, hostPort, zName string) (*tls.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, singleDialTimeout)
|
||||
defer cancel()
|
||||
dialer := &net.Dialer{}
|
||||
netConn, err := dialer.DialContext(ctx, "tcp", hostPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := tls.Client(netConn, &tls.Config{
|
||||
ServerName: zName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- conn.Handshake()
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) {
|
||||
if identifier.Type != core.IdentifierDNS {
|
||||
va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier))
|
||||
|
|
|
@ -598,14 +598,6 @@ func slowTLSSrv() *httptest.Server {
|
|||
}
|
||||
|
||||
func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
||||
// Set a short dial timeout so this test can happen quickly. Note: It would be
|
||||
// better to override this with a context, but that doesn't work right now:
|
||||
// https://github.com/letsencrypt/boulder/issues/3628
|
||||
oldSingleDialTimeout := singleDialTimeout
|
||||
singleDialTimeout = 50 * time.Millisecond
|
||||
defer func() {
|
||||
singleDialTimeout = oldSingleDialTimeout
|
||||
}()
|
||||
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
||||
hs := slowTLSSrv()
|
||||
va, _ := setup(hs, 0)
|
||||
|
@ -628,7 +620,8 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
|||
t.Fatalf("TLSSNI returned before %s (%s) with %#v", timeout, took, prob)
|
||||
}
|
||||
if took > 2*timeout {
|
||||
t.Fatalf("TLSSNI didn't timeout after %s", timeout)
|
||||
t.Fatalf("TLSSNI didn't timeout after %s (took %s to return %#v)", timeout,
|
||||
took, prob)
|
||||
}
|
||||
if prob == nil {
|
||||
t.Fatalf("Connection should've timed out")
|
||||
|
@ -641,21 +634,16 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTLSSNI01DialTimeout(t *testing.T) {
|
||||
// Set a short dial timeout so this test can happen quickly. Note: It would be
|
||||
// better to override this with a context, but that doesn't work right now:
|
||||
// https://github.com/letsencrypt/boulder/issues/3628
|
||||
old := singleDialTimeout
|
||||
singleDialTimeout = 50 * time.Millisecond
|
||||
defer func() {
|
||||
singleDialTimeout = old
|
||||
}()
|
||||
timeout := singleDialTimeout
|
||||
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
||||
hs := slowTLSSrv()
|
||||
va, _ := setup(hs, 0)
|
||||
va.dnsClient = dnsMockReturnsUnroutable{&bdns.MockDNSClient{}}
|
||||
started := time.Now()
|
||||
|
||||
timeout := 50 * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// The only method I've found so far to trigger a connect timeout is to
|
||||
// connect to an unrouteable IP address. This usuall generates a connection
|
||||
// timeout, but will rarely return "Network unreachable" instead. If we get
|
||||
|
@ -734,10 +722,10 @@ func TestTLSSNI01TalkingToHTTP(t *testing.T) {
|
|||
|
||||
_, prob := va.validateTLSSNI01(ctx, dnsi("localhost"), chall)
|
||||
test.AssertError(t, prob, "TLS-SNI-01 validation passed when talking to a HTTP-only server")
|
||||
test.Assert(t, strings.HasSuffix(
|
||||
prob.Detail,
|
||||
"Server only speaks HTTP, not TLS",
|
||||
), "validate TLS-SNI-01 didn't return useful error")
|
||||
expected := "Server only speaks HTTP, not TLS"
|
||||
if !strings.HasSuffix(prob.Detail, expected) {
|
||||
t.Errorf("Got wrong error detail. Expected %q, got %q", expected, prob.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func brokenTLSSrv() *httptest.Server {
|
||||
|
|
Loading…
Reference in New Issue