Use a context when dialing TLS for TLS-SNI (#3648)
This allows us to have fast-running unittests without modifying the global state in singleDialTimeout, which can become a const. Fixes #3628. Builds on top of #3629, review that first.
This commit is contained in:
parent
339ea954bd
commit
a1b98d9163
51
va/va.go
51
va/va.go
|
@ -47,7 +47,7 @@ const (
|
||||||
// before timing out. This timeout ignores the base RPC timeout and is strictly
|
// before timing out. This timeout ignores the base RPC timeout and is strictly
|
||||||
// used for the DialContext operations that take place during an
|
// used for the DialContext operations that take place during an
|
||||||
// HTTP-01/TLS-SNI-[01|02] challenge validation.
|
// HTTP-01/TLS-SNI-[01|02] challenge validation.
|
||||||
var singleDialTimeout = time.Second * 10
|
const singleDialTimeout = time.Second * 10
|
||||||
|
|
||||||
// RemoteVA wraps the core.ValidationAuthority interface and adds a field containing the addresses
|
// RemoteVA wraps the core.ValidationAuthority interface and adds a field containing the addresses
|
||||||
// of the remote gRPC server since the interface (and the underlying gRPC client) doesn't
|
// of the remote gRPC server since the interface (and the underlying gRPC client) doesn't
|
||||||
|
@ -508,7 +508,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
||||||
if !features.Enabled(features.IPv6First) {
|
if !features.Enabled(features.IPv6First) {
|
||||||
address := net.JoinHostPort(addresses[0].String(), thisRecord.Port)
|
address := net.JoinHostPort(addresses[0].String(), thisRecord.Port)
|
||||||
thisRecord.AddressUsed = addresses[0]
|
thisRecord.AddressUsed = addresses[0]
|
||||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||||
return certs, validationRecords, err
|
return certs, validationRecords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -518,7 +518,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
||||||
address := net.JoinHostPort(v6[0].String(), thisRecord.Port)
|
address := net.JoinHostPort(v6[0].String(), thisRecord.Port)
|
||||||
thisRecord.AddressUsed = v6[0]
|
thisRecord.AddressUsed = v6[0]
|
||||||
|
|
||||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||||
|
|
||||||
// If there is no error, return immediately
|
// If there is no error, return immediately
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -547,7 +547,7 @@ func (va *ValidationAuthorityImpl) tryGetTLSSNICerts(ctx context.Context, identi
|
||||||
// talking to the first IPv6 address, try the first IPv4 address
|
// talking to the first IPv6 address, try the first IPv4 address
|
||||||
address := net.JoinHostPort(v4[0].String(), thisRecord.Port)
|
address := net.JoinHostPort(v4[0].String(), thisRecord.Port)
|
||||||
thisRecord.AddressUsed = v4[0]
|
thisRecord.AddressUsed = v4[0]
|
||||||
certs, err := va.getTLSSNICerts(address, identifier, challenge, zName)
|
certs, err := va.getTLSSNICerts(ctx, address, identifier, challenge, zName)
|
||||||
return certs, validationRecords, err
|
return certs, validationRecords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -575,13 +575,15 @@ func (va *ValidationAuthorityImpl) validateTLSSNI01WithZName(ctx context.Context
|
||||||
return validationRecords, probs.Unauthorized(errText)
|
return validationRecords, probs.Unauthorized(errText)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]*x509.Certificate, *probs.ProblemDetails) {
|
func (va *ValidationAuthorityImpl) getTLSSNICerts(
|
||||||
|
ctx context.Context,
|
||||||
|
hostPort string,
|
||||||
|
identifier core.AcmeIdentifier,
|
||||||
|
challenge core.Challenge,
|
||||||
|
zName string,
|
||||||
|
) ([]*x509.Certificate, *probs.ProblemDetails) {
|
||||||
va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName))
|
va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName))
|
||||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: singleDialTimeout}, "tcp", hostPort, &tls.Config{
|
conn, err := tlsDial(ctx, hostPort, zName)
|
||||||
ServerName: zName,
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err))
|
va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err))
|
||||||
return nil, detailedError(err)
|
return nil, detailedError(err)
|
||||||
|
@ -604,6 +606,35 @@ func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier co
|
||||||
return certs, nil
|
return certs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tlsDial does the equivalent of tls.Dial, but obeying a context. Once
|
||||||
|
// tls.DialContextWithDialer is available, switch to that.
|
||||||
|
func tlsDial(ctx context.Context, hostPort, zName string) (*tls.Conn, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, singleDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
netConn, err := dialer.DialContext(ctx, "tcp", hostPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn := tls.Client(netConn, &tls.Config{
|
||||||
|
ServerName: zName,
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errChan <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) {
|
func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) {
|
||||||
if identifier.Type != core.IdentifierDNS {
|
if identifier.Type != core.IdentifierDNS {
|
||||||
va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier))
|
va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier))
|
||||||
|
|
|
@ -598,14 +598,6 @@ func slowTLSSrv() *httptest.Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
||||||
// Set a short dial timeout so this test can happen quickly. Note: It would be
|
|
||||||
// better to override this with a context, but that doesn't work right now:
|
|
||||||
// https://github.com/letsencrypt/boulder/issues/3628
|
|
||||||
oldSingleDialTimeout := singleDialTimeout
|
|
||||||
singleDialTimeout = 50 * time.Millisecond
|
|
||||||
defer func() {
|
|
||||||
singleDialTimeout = oldSingleDialTimeout
|
|
||||||
}()
|
|
||||||
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
||||||
hs := slowTLSSrv()
|
hs := slowTLSSrv()
|
||||||
va, _ := setup(hs, 0)
|
va, _ := setup(hs, 0)
|
||||||
|
@ -628,7 +620,8 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
||||||
t.Fatalf("TLSSNI returned before %s (%s) with %#v", timeout, took, prob)
|
t.Fatalf("TLSSNI returned before %s (%s) with %#v", timeout, took, prob)
|
||||||
}
|
}
|
||||||
if took > 2*timeout {
|
if took > 2*timeout {
|
||||||
t.Fatalf("TLSSNI didn't timeout after %s", timeout)
|
t.Fatalf("TLSSNI didn't timeout after %s (took %s to return %#v)", timeout,
|
||||||
|
took, prob)
|
||||||
}
|
}
|
||||||
if prob == nil {
|
if prob == nil {
|
||||||
t.Fatalf("Connection should've timed out")
|
t.Fatalf("Connection should've timed out")
|
||||||
|
@ -641,21 +634,16 @@ func TestTLSSNI01TimeoutAfterConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSSNI01DialTimeout(t *testing.T) {
|
func TestTLSSNI01DialTimeout(t *testing.T) {
|
||||||
// Set a short dial timeout so this test can happen quickly. Note: It would be
|
|
||||||
// better to override this with a context, but that doesn't work right now:
|
|
||||||
// https://github.com/letsencrypt/boulder/issues/3628
|
|
||||||
old := singleDialTimeout
|
|
||||||
singleDialTimeout = 50 * time.Millisecond
|
|
||||||
defer func() {
|
|
||||||
singleDialTimeout = old
|
|
||||||
}()
|
|
||||||
timeout := singleDialTimeout
|
|
||||||
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
chall := createChallenge(core.ChallengeTypeTLSSNI01)
|
||||||
hs := slowTLSSrv()
|
hs := slowTLSSrv()
|
||||||
va, _ := setup(hs, 0)
|
va, _ := setup(hs, 0)
|
||||||
va.dnsClient = dnsMockReturnsUnroutable{&bdns.MockDNSClient{}}
|
va.dnsClient = dnsMockReturnsUnroutable{&bdns.MockDNSClient{}}
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
|
|
||||||
|
timeout := 50 * time.Millisecond
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// The only method I've found so far to trigger a connect timeout is to
|
// The only method I've found so far to trigger a connect timeout is to
|
||||||
// connect to an unrouteable IP address. This usuall generates a connection
|
// connect to an unrouteable IP address. This usuall generates a connection
|
||||||
// timeout, but will rarely return "Network unreachable" instead. If we get
|
// timeout, but will rarely return "Network unreachable" instead. If we get
|
||||||
|
@ -734,10 +722,10 @@ func TestTLSSNI01TalkingToHTTP(t *testing.T) {
|
||||||
|
|
||||||
_, prob := va.validateTLSSNI01(ctx, dnsi("localhost"), chall)
|
_, prob := va.validateTLSSNI01(ctx, dnsi("localhost"), chall)
|
||||||
test.AssertError(t, prob, "TLS-SNI-01 validation passed when talking to a HTTP-only server")
|
test.AssertError(t, prob, "TLS-SNI-01 validation passed when talking to a HTTP-only server")
|
||||||
test.Assert(t, strings.HasSuffix(
|
expected := "Server only speaks HTTP, not TLS"
|
||||||
prob.Detail,
|
if !strings.HasSuffix(prob.Detail, expected) {
|
||||||
"Server only speaks HTTP, not TLS",
|
t.Errorf("Got wrong error detail. Expected %q, got %q", expected, prob.Detail)
|
||||||
), "validate TLS-SNI-01 didn't return useful error")
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func brokenTLSSrv() *httptest.Server {
|
func brokenTLSSrv() *httptest.Server {
|
||||||
|
|
Loading…
Reference in New Issue