diff --git a/va/caa_test.go b/va/caa_test.go index 208592cac..c6f00b0b7 100644 --- a/va/caa_test.go +++ b/va/caa_test.go @@ -959,19 +959,18 @@ func TestMultiCAARechecking(t *testing.T) { } func TestCAAFailure(t *testing.T) { - chall := createChallenge(core.ChallengeTypeHTTP01) - hs := httpSrv(t, chall.Token) + hs := httpSrv(t, expectedToken) defer hs.Close() va, _ := setup(hs, 0, "", nil, caaMockDNS{}) - _, err := va.validate(ctx, dnsi("reserved.com"), 1, chall, expectedKeyAuthorization) + err := va.checkCAA(ctx, dnsi("reserved.com"), &caaParams{1, core.ChallengeTypeHTTP01}) if err == nil { t.Fatalf("Expected CAA rejection for reserved.com, got success") } test.AssertErrorIs(t, err, berrors.CAA) - _, err = va.validate(ctx, dnsi("example.gonetld"), 1, chall, expectedKeyAuthorization) + err = va.checkCAA(ctx, dnsi("example.gonetld"), &caaParams{1, core.ChallengeTypeHTTP01}) if err == nil { t.Fatalf("Expected CAA rejection for gonetld, got success") } diff --git a/va/dns_test.go b/va/dns_test.go index f0c4c28b2..a545228a4 100644 --- a/va/dns_test.go +++ b/va/dns_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "strings" "testing" "time" @@ -91,44 +90,6 @@ func TestDNSValidationInvalid(t *testing.T) { test.AssertEquals(t, prob.Type, probs.MalformedProblem) } -func TestDNSValidationNotSane(t *testing.T) { - va, _ := setup(nil, 0, "", nil, nil) - - chall := createChallenge(core.ChallengeTypeDNS01) - chall.Token = "" - _, err := va.validateChallenge(ctx, dnsi("localhost"), chall, expectedKeyAuthorization) - prob := detailedError(err) - if prob.Type != probs.MalformedProblem { - t.Errorf("Got wrong error type: expected %s, got %s", - prob.Type, probs.MalformedProblem) - } - if !strings.Contains(prob.Error(), "Challenge failed consistency check:") { - t.Errorf("Got wrong error: %s", prob.Error()) - } - - chall.Token = "yfCBb-bRTLz8Wd1C0lTUQK3qlKj3-t2tYGwx5Hj7r_" - _, err = va.validateChallenge(ctx, dnsi("localhost"), chall, expectedKeyAuthorization) - prob = detailedError(err) - if prob.Type != probs.MalformedProblem { - t.Errorf("Got wrong error type: expected %s, got %s", - prob.Type, probs.MalformedProblem) - } - if !strings.Contains(prob.Error(), "Challenge failed consistency check:") { - t.Errorf("Got wrong error: %s", prob.Error()) - } - - _, err = va.validateChallenge(ctx, dnsi("localhost"), chall, "a") - prob = detailedError(err) - if prob.Type != probs.MalformedProblem { - t.Errorf("Got wrong error type: expected %s, got %s", - prob.Type, probs.MalformedProblem) - } - if !strings.Contains(prob.Error(), "Challenge failed consistency check:") { - t.Errorf("Got wrong error: %s", prob.Error()) - } - -} - func TestDNSValidationServFail(t *testing.T) { va, _ := setup(nil, 0, "", nil, nil) diff --git a/va/va.go b/va/va.go index 42b872fed..d43346bbc 100644 --- a/va/va.go +++ b/va/va.go @@ -120,7 +120,7 @@ func initMetrics(stats prometheus.Registerer) *vaMetrics { Help: "Time taken to remotely validate a challenge", Buckets: metrics.InternetFacingBuckets, }, - []string{"type", "result"}) + []string{"type"}) stats.MustRegister(remoteValidationTime) remoteValidationFailures := prometheus.NewCounter( prometheus.CounterOpts{ @@ -419,168 +419,122 @@ func detailedError(err error) *probs.ProblemDetails { return probs.Connection("Error getting validation data") } -// validate performs a challenge validation and, in parallel, -// checks CAA and GSB for the identifier. If any of those steps fails, it -// returns a ProblemDetails plus the validation records created during the -// validation attempt. -func (va *ValidationAuthorityImpl) validate( - ctx context.Context, - identifier identifier.ACMEIdentifier, - regid int64, - challenge core.Challenge, - keyAuthorization string, -) ([]core.ValidationRecord, error) { - - // If the identifier is a wildcard domain we need to validate the base - // domain by removing the "*." wildcard prefix. We create a separate - // `baseIdentifier` here before starting the `va.checkCAA` goroutine with the - // `identifier` to avoid a data race. - baseIdentifier := identifier - if strings.HasPrefix(identifier.Value, "*.") { - baseIdentifier.Value = strings.TrimPrefix(identifier.Value, "*.") - } - - validationRecords, err := va.validateChallenge(ctx, baseIdentifier, challenge, keyAuthorization) - if err != nil { - return validationRecords, err - } - - err = va.checkCAA(ctx, identifier, &caaParams{ - accountURIID: regid, - validationMethod: challenge.Type, - }) - if err != nil { - return validationRecords, err - } - - return validationRecords, nil -} - +// validateChallenge simply passes through to the appropriate validation method +// depending on the challenge type. func (va *ValidationAuthorityImpl) validateChallenge( ctx context.Context, - identifier identifier.ACMEIdentifier, - challenge core.Challenge, + ident identifier.ACMEIdentifier, + kind core.AcmeChallenge, + token string, keyAuthorization string, ) ([]core.ValidationRecord, error) { - err := challenge.CheckPending() - if err != nil { - return nil, berrors.MalformedError("Challenge failed consistency check: %s", err) - } - switch challenge.Type { + // Strip a (potential) leading wildcard token from the identifier. + ident.Value = strings.TrimPrefix(ident.Value, "*.") + + switch kind { case core.ChallengeTypeHTTP01: - return va.validateHTTP01(ctx, identifier, challenge.Token, keyAuthorization) + return va.validateHTTP01(ctx, ident, token, keyAuthorization) case core.ChallengeTypeDNS01: - return va.validateDNS01(ctx, identifier, keyAuthorization) + return va.validateDNS01(ctx, ident, keyAuthorization) case core.ChallengeTypeTLSALPN01: - return va.validateTLSALPN01(ctx, identifier, keyAuthorization) + return va.validateTLSALPN01(ctx, ident, keyAuthorization) } - return nil, berrors.MalformedError("invalid challenge type %s", challenge.Type) + return nil, berrors.MalformedError("invalid challenge type %s", kind) } -// performRemoteValidation calls `PerformValidation` for each of the configured -// remoteVAs in a random order. The provided `results` chan should have an equal -// size to the number of remote VAs. The validations will be performed in -// separate go-routines. If the result `error` from a remote -// `PerformValidation` RPC is nil or a nil `ProblemDetails` instance it is -// written directly to the `results` chan. If the err is a cancelled error it is -// treated as a nil error. Otherwise the error/problem is written to the results -// channel as-is. +// performRemoteValidation coordinates the whole process of kicking off and +// collecting results from calls to remote VAs' PerformValidation function. It +// returns a problem if too many remote perspectives failed to corroborate +// domain control, or nil if enough succeeded to surpass our corroboration +// threshold. func (va *ValidationAuthorityImpl) performRemoteValidation( ctx context.Context, req *vapb.PerformValidationRequest, - results chan<- *remoteVAResult) { - for _, i := range rand.Perm(len(va.remoteVAs)) { - remoteVA := va.remoteVAs[i] - go func(rva RemoteVA) { - result := &remoteVAResult{ - VAHostname: rva.Address, - } - res, err := rva.PerformValidation(ctx, req) - if err != nil && canceled.Is(err) { - // If the non-nil err was a canceled error, ignore it. That's fine: it - // just means we cancelled the remote VA request before it was - // finished because we didn't care about its result. Don't log to avoid - // spamming the logs. - result.Problem = probs.ServerInternal("Remote PerformValidation RPC canceled") - } else if err != nil { - // This is a real error, not just a problem with the validation. - va.log.Errf("Remote VA %q.PerformValidation failed: %s", rva.Address, err) - result.Problem = probs.ServerInternal("Remote PerformValidation RPC failed") - } else if res.Problems != nil { - prob, err := bgrpc.PBToProblemDetails(res.Problems) - if err != nil { - va.log.Infof("Remote VA %q.PerformValidation returned malformed problem: %s", rva.Address, err) - result.Problem = probs.ServerInternal( - fmt.Sprintf("Remote PerformValidation RPC returned malformed result: %s", err)) - } else { - va.log.Infof("Remote VA %q.PerformValidation returned problem: %s", rva.Address, prob) - result.Problem = prob - } - } - results <- result - }(remoteVA) +) *probs.ProblemDetails { + if len(va.remoteVAs) == 0 { + return nil } -} -// processRemoteValidationResults evaluates a primary VA result, and a channel -// of remote VA problems to produce a single overall validation result based on -// configured feature flags. The overall result is calculated based on the VA's -// configured `maxRemoteFailures` value, and the function returns as soon as -// that threshold has been exceeded or cannot possibly be exceeded. -func (va *ValidationAuthorityImpl) processRemoteValidationResults( - challengeType string, - remoteResultsChan <-chan *remoteVAResult) *probs.ProblemDetails { - - state := "failure" start := va.clk.Now() - defer func() { va.metrics.remoteValidationTime.With(prometheus.Labels{ - "type": challengeType, - "result": state, + "type": req.Challenge.Type, }).Observe(va.clk.Since(start).Seconds()) }() + type rvaResult struct { + hostname string + response *vapb.ValidationResult + err error + } + + results := make(chan *rvaResult) + + for _, i := range rand.Perm(len(va.remoteVAs)) { + remoteVA := va.remoteVAs[i] + go func(rva RemoteVA, out chan<- *rvaResult) { + res, err := rva.PerformValidation(ctx, req) + out <- &rvaResult{ + hostname: rva.Address, + response: res, + err: err, + } + }(remoteVA, results) + } + required := len(va.remoteVAs) - va.maxRemoteFailures good := 0 bad := 0 - - var remoteResults []*remoteVAResult var firstProb *probs.ProblemDetails - // Due to channel behavior this could block indefinitely and we rely on gRPC - // honoring the context deadline used in client calls to prevent that from - // happening. - for result := range remoteResultsChan { - // Add the result to the slice - remoteResults = append(remoteResults, result) - if result.Problem == nil { - good++ - } else { + + for res := range results { + var currProb *probs.ProblemDetails + + if res.err != nil { bad++ + + if canceled.Is(res.err) { + currProb = probs.ServerInternal("Remote PerformValidation RPC canceled") + } else { + va.log.Errf("Remote VA %q.PerformValidation failed: %s", res.hostname, res.err) + currProb = probs.ServerInternal("Remote PerformValidation RPC failed") + } + } else if res.response.Problems != nil { + bad++ + + var err error + currProb, err = bgrpc.PBToProblemDetails(res.response.Problems) + if err != nil { + va.log.Errf("Remote VA %q.PerformValidation returned malformed problem: %s", res.hostname, err) + currProb = probs.ServerInternal("Remote PerformValidation RPC returned malformed result") + } + } else { + good++ } - // Store the first non-nil problem to return later. - if firstProb == nil && result.Problem != nil { - firstProb = result.Problem + + if firstProb == nil && currProb != nil { + firstProb = currProb } + // Return as soon as we have enough successes or failures for a definitive result. if good >= required { - state = "success" return nil - } else if bad > va.maxRemoteFailures { - modifiedProblem := *result.Problem - modifiedProblem.Detail = "During secondary validation: " + firstProb.Detail - return &modifiedProblem + } + if bad > va.maxRemoteFailures { + va.metrics.remoteValidationFailures.Inc() + firstProb.Detail = fmt.Sprintf("During secondary validation: %s", firstProb.Detail) + return firstProb } // If we somehow haven't returned early, we need to break the loop once all // of the VAs have returned a result. - if len(remoteResults) == len(va.remoteVAs) { + if good+bad >= len(va.remoteVAs) { break } } - // This condition should not occur - it indicates the good/bad counts didn't - // meet either the required threshold or the maxRemoteFailures threshold. + // This condition should not occur - it indicates the good/bad counts neither + // met the required threshold nor the maxRemoteFailures threshold. return probs.ServerInternal("Too few remote PerformValidation RPC results") } @@ -641,6 +595,39 @@ type remoteVAResult struct { Problem *probs.ProblemDetails } +// performLocalValidation performs primary domain control validation and then +// checks CAA. If either step fails, it immediately returns a bare error so +// that our audit logging can include the underlying error. +func (va *ValidationAuthorityImpl) performLocalValidation( + ctx context.Context, + ident identifier.ACMEIdentifier, + regid int64, + kind core.AcmeChallenge, + token string, + keyAuthorization string, +) ([]core.ValidationRecord, error) { + // Do primary domain control validation. Any kind of error returned by this + // counts as a validation error, and will be converted into an appropriate + // probs.ProblemDetails by the calling function. + records, err := va.validateChallenge(ctx, ident, kind, token, keyAuthorization) + if err != nil { + return records, err + } + + // Do primary CAA checks. Any kind of error returned by this counts as not + // receiving permission to issue, and will be converted into an appropriate + // probs.ProblemDetails by the calling function. + err = va.checkCAA(ctx, ident, &caaParams{ + accountURIID: regid, + validationMethod: kind, + }) + if err != nil { + return records, err + } + + return records, nil +} + // PerformValidation validates the challenge for the domain in the request. // The returned result will always contain a list of validation records, even // when it also contains a problem. @@ -649,22 +636,15 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v if core.IsAnyNilOrZero(req, req.Domain, req.Challenge, req.Authz) { return nil, berrors.InternalServerError("Incomplete validation request") } - logEvent := verificationRequestEvent{ - ID: req.Authz.Id, - Requester: req.Authz.RegID, - Hostname: req.Domain, - } - vStart := va.clk.Now() - - var remoteResults chan *remoteVAResult - if remoteVACount := len(va.remoteVAs); remoteVACount > 0 { - remoteResults = make(chan *remoteVAResult, remoteVACount) - go va.performRemoteValidation(ctx, req, remoteResults) - } challenge, err := bgrpc.PBToChallenge(req.Challenge) if err != nil { - return nil, errors.New("Challenge failed to deserialize") + return nil, errors.New("challenge failed to deserialize") + } + + err = challenge.CheckPending() + if err != nil { + return nil, berrors.MalformedError("challenge failed consistency check: %s", err) } // TODO(#7514): Remove this fallback and belt-and-suspenders check. @@ -676,53 +656,63 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v return nil, errors.New("no expected keyAuthorization provided") } - records, err := va.validate(ctx, identifier.DNSIdentifier(req.Domain), req.Authz.RegID, challenge, keyAuthorization) - challenge.ValidationRecord = records - localValidationLatency := time.Since(vStart) + // Set up variables and a deferred closure to report validation latency + // metrics and log validation errors. Below here, do not use := to redeclare + // `prob`, or this will fail. + var prob *probs.ProblemDetails + var localLatency time.Duration + vStart := va.clk.Now() + logEvent := verificationRequestEvent{ + ID: req.Authz.Id, + Requester: req.Authz.RegID, + Hostname: req.Domain, + Challenge: challenge, + } + defer func() { + problemType := "" + if prob != nil { + problemType = string(prob.Type) + logEvent.Error = prob.Error() + logEvent.Challenge.Error = prob + logEvent.Challenge.Status = core.StatusInvalid + } else { + logEvent.Challenge.Status = core.StatusValid + } + + va.metrics.localValidationTime.With(prometheus.Labels{ + "type": string(logEvent.Challenge.Type), + "result": string(logEvent.Challenge.Status), + }).Observe(localLatency.Seconds()) + + va.metrics.validationTime.With(prometheus.Labels{ + "type": string(logEvent.Challenge.Type), + "result": string(logEvent.Challenge.Status), + "problem_type": problemType, + }).Observe(time.Since(vStart).Seconds()) + + logEvent.ValidationLatency = time.Since(vStart).Round(time.Millisecond).Seconds() + va.log.AuditObject("Validation result", logEvent) + }() + + // Do local validation. Note that we process the result in a couple ways + // *before* checking whether it returned an error. These few checks are + // carefully written to ensure that they work whether the local validation + // was successful or not, and cannot themselves fail. + records, err := va.performLocalValidation( + ctx, + identifier.DNSIdentifier(req.Domain), + req.Authz.RegID, + challenge.Type, + challenge.Token, + keyAuthorization) + localLatency = time.Since(vStart) // Check for malformed ValidationRecords - if !challenge.RecordsSane() && err == nil { - err = errors.New("Records for validation failed sanity check") + logEvent.Challenge.ValidationRecord = records + if err == nil && !logEvent.Challenge.RecordsSane() { + err = errors.New("records from local validation failed sanity check") } - var problemType string - var prob *probs.ProblemDetails - if err != nil { - prob = detailedError(err) - problemType = string(prob.Type) - challenge.Status = core.StatusInvalid - challenge.Error = prob - logEvent.Error = prob.Error() - logEvent.InternalError = err.Error() - } else if remoteResults != nil { - remoteProb := va.processRemoteValidationResults( - string(challenge.Type), - remoteResults) - - // If the remote result was a non-nil problem then fail the validation - if remoteProb != nil { - prob = remoteProb - challenge.Status = core.StatusInvalid - challenge.Error = remoteProb - // We only set .Error here, not .InternalError, because the - // remote VA doesn't send us the internal error. But that's ok, - // it got logged at the remote VA. - logEvent.Error = remoteProb.Error() - va.log.Infof("Validation failed due to remote failures: identifier=%v err=%s", - req.Domain, remoteProb) - va.metrics.remoteValidationFailures.Inc() - } else { - challenge.Status = core.StatusValid - } - } else { - challenge.Status = core.StatusValid - } - - logEvent.Challenge = challenge - - validationLatency := time.Since(vStart) - logEvent.ValidationLatency = validationLatency.Round(time.Millisecond).Seconds() - // Copy the "UsedRSAKEX" value from the last validationRecord into the log // event. Only the last record should have this bool set, because we only // record it if/when validation is finally successful, but we use the loop @@ -732,23 +722,19 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v logEvent.UsedRSAKEX = record.UsedRSAKEX || logEvent.UsedRSAKEX } - va.metrics.localValidationTime.With(prometheus.Labels{ - "type": string(challenge.Type), - "result": string(challenge.Status), - }).Observe(localValidationLatency.Seconds()) - va.metrics.validationTime.With(prometheus.Labels{ - "type": string(challenge.Type), - "result": string(challenge.Status), - "problem_type": problemType, - }).Observe(validationLatency.Seconds()) + if err != nil { + logEvent.InternalError = err.Error() + prob = detailedError(err) + return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob)) + } - va.log.AuditObject("Validation result", logEvent) - - // The ProblemDetails will be serialized through gRPC, which requires UTF-8. - // It will also later be serialized in JSON, which defaults to UTF-8. Make - // sure it is UTF-8 clean now. - prob = filterProblemDetails(prob) - return bgrpc.ValidationResultToPB(records, prob) + // Do remote validation. We do this after local validation is complete to + // avoid wasting work when validation will fail anyway. This only returns a + // singular problem, because the remote VAs have already audit-logged their + // own validation records, and it's not helpful to present multiple large + // errors to the end user. + prob = va.performRemoteValidation(ctx, req) + return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob)) } // usedRSAKEX returns true if the given cipher suite involves the use of an diff --git a/va/va_test.go b/va/va_test.go index ed6cde52e..a7ca0ee06 100644 --- a/va/va_test.go +++ b/va/va_test.go @@ -101,16 +101,6 @@ func createValidationRequest(domain string, challengeType core.AcmeChallenge) *v } } -func createChallenge(challengeType core.AcmeChallenge) core.Challenge { - return core.Challenge{ - Type: challengeType, - Status: core.StatusPending, - Token: expectedToken, - ValidationRecord: []core.ValidationRecord{}, - ProvidedKeyAuthorization: expectedKeyAuthorization, - } -} - // setup returns an in-memory VA and a mock logger. The default resolver client // is MockClient{}, but can be overridden. func setup(srv *httptest.Server, maxRemoteFailures int, userAgent string, remoteVAs []RemoteVA, mockDNSClientOverride bdns.Client) (*ValidationAuthorityImpl, *blog.Mock) { @@ -251,7 +241,7 @@ func (inmem inMemVA) IsCAAValid(ctx context.Context, req *vapb.IsCAAValidRequest func TestValidateMalformedChallenge(t *testing.T) { va, _ := setup(nil, 0, "", nil, nil) - _, err := va.validateChallenge(ctx, dnsi("example.com"), createChallenge("fake-type-01"), expectedKeyAuthorization) + _, err := va.validateChallenge(ctx, dnsi("example.com"), "fake-type-01", expectedToken, expectedKeyAuthorization) prob := detailedError(err) test.AssertEquals(t, prob.Type, probs.MalformedProblem)