diff --git a/va/va_test.go b/va/va_test.go index b69c091e9..0a8b2cfef 100644 --- a/va/va_test.go +++ b/va/va_test.go @@ -235,8 +235,7 @@ func TestHTTP(t *testing.T) { // TODO(#1989): close hs hs := httpSrv(t, chall.Token) - goodPort, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") + va, goodPort, _, log := setup(t, hs) // Attempt to fail a challenge by telling the VA to connect to a port we are // not listening on. @@ -244,7 +243,6 @@ func TestHTTP(t *testing.T) { if badPort == 65536 { badPort = goodPort - 1 } - va, _, log := setup() va.httpPort = badPort _, prob := va.validateHTTP01(ctx, ident, chall) @@ -331,10 +329,7 @@ func TestHTTPRedirectLookup(t *testing.T) { hs := httpSrv(t, expectedToken) defer hs.Close() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, log := setup() - va.httpPort = port + va, _, _, log := setup(t, hs) setChallengeToken(&chall, pathMoved) _, prob := va.validateHTTP01(ctx, ident, chall) @@ -356,7 +351,7 @@ func TestHTTPRedirectLookup(t *testing.T) { log.Clear() setChallengeToken(&chall, pathReLookupInvalid) - _, err = va.validateHTTP01(ctx, ident, chall) + _, err := va.validateHTTP01(ctx, ident, chall) test.AssertError(t, err, chall.Token) test.AssertEquals(t, len(log.GetAllMatching(`Resolved addresses for localhost \[using 127.0.0.1\]: \[127.0.0.1\]`)), 1) test.AssertEquals(t, len(log.GetAllMatching(`No valid IP addresses found for invalid.invalid`)), 1) @@ -395,10 +390,7 @@ func TestHTTPRedirectLoop(t *testing.T) { hs := httpSrv(t, expectedToken) defer hs.Close() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() - va.httpPort = port + va, _, _, _ := setup(t, hs) _, prob := va.validateHTTP01(ctx, ident, chall) if prob == nil { @@ -412,11 +404,8 @@ func TestHTTPRedirectUserAgent(t *testing.T) { hs := httpSrv(t, expectedToken) defer hs.Close() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() + va, _, _, _ := setup(t, hs) va.userAgent = rejectUserAgent - va.httpPort = port setChallengeToken(&chall, pathMoved) _, prob := va.validateHTTP01(ctx, ident, chall) @@ -451,11 +440,8 @@ func TestTLSSNI01(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) hs := tlssni01Srv(t, chall) - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, log := setup() - va.tlsPort = port + va, port, _, log := setup(t, hs) _, prob := va.validateTLSSNI01(ctx, ident, chall) if prob != nil { @@ -503,16 +489,15 @@ func TestTLSSNI01(t *testing.T) { // Take down validation server and check that validation fails. hs.Close() - _, err = va.validateTLSSNI01(ctx, ident, chall) + _, err := va.validateTLSSNI01(ctx, ident, chall) if err == nil { t.Fatalf("Server's down; expected refusal. Where did we connect?") } test.AssertEquals(t, prob.Type, probs.ConnectionProblem) httpOnly := httpSrv(t, "") - defer httpOnly.Close() - port, err = getPort(httpOnly) - test.AssertNotError(t, err, "failed to get test server port") + port, portErr := getPort(httpOnly) + test.AssertNotError(t, portErr, "failed to get test server port") va.tlsPort = port log.Clear() @@ -528,11 +513,8 @@ func TestTLSSNI02(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI02) hs := tlssni02Srv(t, chall) - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, log := setup() - va.tlsPort = port + va, port, _, log := setup(t, hs) _, prob := va.validateTLSSNI02(ctx, ident, chall) if prob != nil { @@ -580,7 +562,7 @@ func TestTLSSNI02(t *testing.T) { // Take down validation server and check that validation fails. hs.Close() - _, err = va.validateTLSSNI02(ctx, ident, chall) + _, err := va.validateTLSSNI02(ctx, ident, chall) if err == nil { t.Fatalf("Server's down; expected refusal. Where did we connect?") } @@ -588,8 +570,8 @@ func TestTLSSNI02(t *testing.T) { httpOnly := httpSrv(t, "") defer httpOnly.Close() - port, err = getPort(httpOnly) - test.AssertNotError(t, err, "failed to get test server port") + port, portErr := getPort(httpOnly) + test.AssertNotError(t, portErr, "failed to get test server port") va.tlsPort = port log.Clear() @@ -616,10 +598,7 @@ func TestTLSError(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) hs := brokenTLSSrv() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() - va.tlsPort = port + va, _, _, _ := setup(t, hs) _, prob := va.validateTLSSNI01(ctx, ident, chall) if prob == nil { @@ -700,10 +679,7 @@ func TestSNIErrInvalidChain(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) hs := misconfiguredTLSSrv() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() - va.tlsPort = port + va, _, _, _ := setup(t, hs) // Validate the SNI challenge with the test server, expecting it to fail _, prob := va.validateTLSSNI01(ctx, ident, chall) @@ -724,13 +700,10 @@ func TestValidateHTTP(t *testing.T) { setChallengeToken(&chall, core.NewToken()) hs := httpSrv(t, chall.Token) - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() - va.httpPort = port - defer hs.Close() + va, _, _, _ := setup(t, hs) + _, prob := va.validateChallenge(ctx, ident, chall) test.Assert(t, prob == nil, "validation failed") } @@ -760,11 +733,7 @@ func TestValidateTLSSNI01(t *testing.T) { hs := tlssni01Srv(t, chall) defer hs.Close() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - - va, _, _ := setup() - va.tlsPort = port + va, _, _, _ := setup(t, hs) _, prob := va.validateChallenge(ctx, ident, chall) @@ -772,7 +741,7 @@ func TestValidateTLSSNI01(t *testing.T) { } func TestValidateTLSSNI01NotSane(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) chall := createChallenge(core.ChallengeTypeTLSSNI01) @@ -784,7 +753,7 @@ func TestValidateTLSSNI01NotSane(t *testing.T) { } func TestCAATimeout(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) err := va.checkCAA(ctx, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: "caa-timeout.com"}) if err.Type != probs.ConnectionProblem { t.Errorf("Expected timeout error type %s, got %s", probs.ConnectionProblem, err.Type) @@ -826,7 +795,7 @@ func TestCAAChecking(t *testing.T) { {"unsatisfiable.com", true, false}, } - va, _, _ := setup() + va, _, _, _ := setup(t, nil) for _, caaTest := range tests { present, valid, err := va.checkCAARecords(ctx, core.AcmeIdentifier{Type: "dns", Value: caaTest.Domain}) if err != nil { @@ -862,7 +831,7 @@ func TestCAAChecking(t *testing.T) { } func TestPerformValidationInvalid(t *testing.T) { - va, stats, _ := setup() + va, _, stats, _ := setup(t, nil) chalDNS := createChallenge(core.ChallengeTypeDNS01) _, prob := va.PerformValidation(context.Background(), "foo.com", chalDNS, core.Authorization{}) test.Assert(t, prob != nil, "validation succeeded") @@ -870,7 +839,7 @@ func TestPerformValidationInvalid(t *testing.T) { } func TestDNSValidationEmpty(t *testing.T) { - va, stats, _ := setup() + va, _, stats, _ := setup(t, nil) chalDNS := createChallenge(core.ChallengeTypeDNS01) _, prob := va.PerformValidation( context.Background(), @@ -882,7 +851,7 @@ func TestDNSValidationEmpty(t *testing.T) { } func TestPerformValidationValid(t *testing.T) { - va, stats, _ := setup() + va, _, stats, _ := setup(t, nil) // create a challenge with well known token chalDNS := core.DNSChallenge01() chalDNS.Token = expectedToken @@ -893,7 +862,7 @@ func TestPerformValidationValid(t *testing.T) { } func TestDNSValidationFailure(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) chalDNS := createChallenge(core.ChallengeTypeDNS01) @@ -911,7 +880,7 @@ func TestDNSValidationInvalid(t *testing.T) { chalDNS := core.DNSChallenge01() chalDNS.ProvidedKeyAuthorization = expectedKeyAuthorization - va, _, _ := setup() + va, _, _, _ := setup(t, nil) _, prob := va.validateChallenge(ctx, notDNS, chalDNS) @@ -919,7 +888,7 @@ func TestDNSValidationInvalid(t *testing.T) { } func TestDNSValidationNotSane(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) chal0 := core.DNSChallenge01() chal0.Token = "" @@ -950,7 +919,7 @@ func TestDNSValidationNotSane(t *testing.T) { } func TestDNSValidationServFail(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) chalDNS := createChallenge(core.ChallengeTypeDNS01) @@ -964,7 +933,7 @@ func TestDNSValidationServFail(t *testing.T) { } func TestDNSValidationNoServer(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) va.dnsResolver = bdns.NewTestDNSResolverImpl( time.Second*5, nil, @@ -980,7 +949,7 @@ func TestDNSValidationNoServer(t *testing.T) { } func TestDNSValidationOK(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) // create a challenge with well known token chalDNS := core.DNSChallenge01() @@ -998,7 +967,7 @@ func TestDNSValidationOK(t *testing.T) { } func TestDNSValidationNoAuthorityOK(t *testing.T) { - va, _, _ := setup() + va, _, _, _ := setup(t, nil) // create a challenge with well known token chalDNS := core.DNSChallenge01() @@ -1021,11 +990,7 @@ func TestCAAFailure(t *testing.T) { hs := tlssni01Srv(t, chall) defer hs.Close() - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - - va, _, _ := setup() - va.tlsPort = port + va, _, _, _ := setup(t, hs) ident.Value = "reserved.com" _, prob := va.validateChallengeAndCAA(ctx, ident, chall) @@ -1038,10 +1003,7 @@ func TestLimitedReader(t *testing.T) { ident.Value = "localhost" hs := httpSrv(t, "01234567890123456789012345678901234567890123456789012345678901234567890123456789") - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va, _, _ := setup() - va.httpPort = port + va, _, _, _ := setup(t, hs) defer hs.Close() @@ -1052,12 +1014,27 @@ func TestLimitedReader(t *testing.T) { "Expected failure due to truncation") } -func setup() (*ValidationAuthorityImpl, *mocks.Statter, *blog.Mock) { +func setup(t *testing.T, srv *httptest.Server) (*ValidationAuthorityImpl, int, *mocks.Statter, *blog.Mock) { stats := mocks.NewStatter() scope := metrics.NewStatsdScope(stats, "VA") logger := blog.NewMock() + + var portConfig cmd.PortConfig + if srv != nil { + port, err := getPort(srv) + if err != nil { + // We never expect to fail to get the port for an http server, so treat it + // as a fatal test failure and halt immediately + t.Fatalf("Unable to get port for test server: %s\n", err.Error()) + } + portConfig = cmd.PortConfig{ + HTTPPort: port, + TLSPort: port, + } + } va := NewValidationAuthorityImpl( - &cmd.PortConfig{}, + // Use the test server's port as both the HTTPPort and the TLSPort for the VA + &portConfig, nil, nil, &bdns.MockDNSResolver{}, @@ -1066,7 +1043,7 @@ func setup() (*ValidationAuthorityImpl, *mocks.Statter, *blog.Mock) { scope, clock.Default(), logger) - return va, stats, logger + return va, portConfig.HTTPPort, stats, logger } func TestCheckCAAFallback(t *testing.T) { @@ -1205,9 +1182,6 @@ func TestAvailableAddresses(t *testing.T) { } func TestFallbackDialer(t *testing.T) { - // Create a test VA - va, _, _ := setup() - // Create a new challenge to use for the httpSrv chall := core.HTTPChallenge01() setChallengeToken(&chall, core.NewToken()) @@ -1216,11 +1190,8 @@ func TestFallbackDialer(t *testing.T) { hs := httpSrv(t, chall.Token) defer hs.Close() - // Figure out what port the test server is on, and configure the VA to use - // that for HTTP challenges - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va.httpPort = port + // Create a test VA + va, _, _, _ := setup(t, hs) // Create an identifier for a host that has an IPv6 and an IPv4 address. // Since the IPv6First feature flag is not enabled we expect that the IPv4 @@ -1264,9 +1235,6 @@ func TestFallbackDialer(t *testing.T) { } func TestFallbackTLS(t *testing.T) { - // Create a test VA - va, _, _ := setup() - // Create a new challenge to use for the httpSrv chall := createChallenge(core.ChallengeTypeTLSSNI01) @@ -1275,11 +1243,8 @@ func TestFallbackTLS(t *testing.T) { hs := tlssni01Srv(t, chall) defer hs.Close() - // Figure out what port the test server is on, and configure the VA to use - // that for TLS challenges - port, err := getPort(hs) - test.AssertNotError(t, err, "failed to get test server port") - va.tlsPort = port + // Create a test VA + va, _, _, _ := setup(t, hs) // Create an identifier for a host that has an IPv6 and an IPv4 address. // Since the IPv6First feature flag is not enabled we expect that the IPv4