Perform remote validation after primary validation (#7522)

Change the VA to perform remote validation wholly after local validation
and CAA checks, and to do so only if those local checks pass. This will
likely increase the latency of our successful validations, by making
them less parallel. However, it will reduce the amount of work we do on
unsuccessful validations, and reduce their latency, by not kicking off
and waiting for remote results.

Fixes https://github.com/letsencrypt/boulder/issues/7509
This commit is contained in:
Aaron Gable 2024-06-10 14:16:44 -07:00 committed by GitHub
parent e198d3529d
commit 5b647072b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 186 additions and 250 deletions

View File

@ -959,19 +959,18 @@ func TestMultiCAARechecking(t *testing.T) {
}
func TestCAAFailure(t *testing.T) {
chall := createChallenge(core.ChallengeTypeHTTP01)
hs := httpSrv(t, chall.Token)
hs := httpSrv(t, expectedToken)
defer hs.Close()
va, _ := setup(hs, 0, "", nil, caaMockDNS{})
_, err := va.validate(ctx, dnsi("reserved.com"), 1, chall, expectedKeyAuthorization)
err := va.checkCAA(ctx, dnsi("reserved.com"), &caaParams{1, core.ChallengeTypeHTTP01})
if err == nil {
t.Fatalf("Expected CAA rejection for reserved.com, got success")
}
test.AssertErrorIs(t, err, berrors.CAA)
_, err = va.validate(ctx, dnsi("example.gonetld"), 1, chall, expectedKeyAuthorization)
err = va.checkCAA(ctx, dnsi("example.gonetld"), &caaParams{1, core.ChallengeTypeHTTP01})
if err == nil {
t.Fatalf("Expected CAA rejection for gonetld, got success")
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"strings"
"testing"
"time"
@ -91,44 +90,6 @@ func TestDNSValidationInvalid(t *testing.T) {
test.AssertEquals(t, prob.Type, probs.MalformedProblem)
}
// TestDNSValidationNotSane checks that validateChallenge rejects DNS-01
// challenges that fail the challenge consistency check. Every case must
// produce a malformed problem whose text mentions the failed consistency
// check.
func TestDNSValidationNotSane(t *testing.T) {
	va, _ := setup(nil, 0, "", nil, nil)

	// NOTE(review): presumably this token is malformed on purpose; confirm
	// against the token rules enforced by the consistency check.
	const badToken = "yfCBb-bRTLz8Wd1C0lTUQK3qlKj3-t2tYGwx5Hj7r_"

	cases := []struct {
		token            string
		keyAuthorization string
	}{
		// Empty token.
		{"", expectedKeyAuthorization},
		// Non-empty token that is still expected to fail the check.
		{badToken, expectedKeyAuthorization},
		// Same token, paired with a bogus key authorization.
		{badToken, "a"},
	}

	chall := createChallenge(core.ChallengeTypeDNS01)
	for _, tc := range cases {
		chall.Token = tc.token
		_, err := va.validateChallenge(ctx, dnsi("localhost"), chall, tc.keyAuthorization)
		prob := detailedError(err)
		if prob.Type != probs.MalformedProblem {
			// The expected value goes first so the message reads correctly.
			t.Errorf("Got wrong error type: expected %s, got %s",
				probs.MalformedProblem, prob.Type)
		}
		if !strings.Contains(prob.Error(), "Challenge failed consistency check:") {
			t.Errorf("Got wrong error: %s", prob.Error())
		}
	}
}
func TestDNSValidationServFail(t *testing.T) {
va, _ := setup(nil, 0, "", nil, nil)

378
va/va.go
View File

@ -120,7 +120,7 @@ func initMetrics(stats prometheus.Registerer) *vaMetrics {
Help: "Time taken to remotely validate a challenge",
Buckets: metrics.InternetFacingBuckets,
},
[]string{"type", "result"})
[]string{"type"})
stats.MustRegister(remoteValidationTime)
remoteValidationFailures := prometheus.NewCounter(
prometheus.CounterOpts{
@ -419,168 +419,122 @@ func detailedError(err error) *probs.ProblemDetails {
return probs.Connection("Error getting validation data")
}
// validate runs challenge validation for the identifier and then, only if
// that succeeds, the CAA check. The two steps run one after the other. On
// failure of either step it returns the error together with whatever
// validation records the challenge validation produced.
func (va *ValidationAuthorityImpl) validate(
	ctx context.Context,
	identifier identifier.ACMEIdentifier,
	regid int64,
	challenge core.Challenge,
	keyAuthorization string,
) ([]core.ValidationRecord, error) {
	// Challenge validation targets the base domain, so work on a copy with any
	// leading "*." removed. The original identifier is kept untouched because
	// the CAA check below receives it as-is.
	target := identifier
	target.Value = strings.TrimPrefix(identifier.Value, "*.")

	records, err := va.validateChallenge(ctx, target, challenge, keyAuthorization)
	if err != nil {
		return records, err
	}

	params := &caaParams{
		accountURIID:     regid,
		validationMethod: challenge.Type,
	}
	return records, va.checkCAA(ctx, identifier, params)
}
// validateChallenge simply passes through to the appropriate validation method
// depending on the challenge type: HTTP-01, DNS-01 and TLS-ALPN-01 are
// dispatched to validateHTTP01, validateDNS01 and validateTLSALPN01
// respectively. Any other challenge type yields a berrors.MalformedError and
// no validation records.
func (va *ValidationAuthorityImpl) validateChallenge(
ctx context.Context,
identifier identifier.ACMEIdentifier,
challenge core.Challenge,
ident identifier.ACMEIdentifier,
kind core.AcmeChallenge,
token string,
keyAuthorization string,
) ([]core.ValidationRecord, error) {
// A challenge that fails CheckPending is rejected as malformed before any
// validation method is invoked.
err := challenge.CheckPending()
if err != nil {
return nil, berrors.MalformedError("Challenge failed consistency check: %s", err)
}
switch challenge.Type {
// Strip a (potential) leading wildcard token from the identifier.
// NOTE(review): ident is received as a non-pointer parameter, so presumably
// this only changes the local copy and not the caller's identifier; confirm
// against the definition of identifier.ACMEIdentifier.
ident.Value = strings.TrimPrefix(ident.Value, "*.")
switch kind {
case core.ChallengeTypeHTTP01:
return va.validateHTTP01(ctx, identifier, challenge.Token, keyAuthorization)
return va.validateHTTP01(ctx, ident, token, keyAuthorization)
case core.ChallengeTypeDNS01:
return va.validateDNS01(ctx, identifier, keyAuthorization)
return va.validateDNS01(ctx, ident, keyAuthorization)
case core.ChallengeTypeTLSALPN01:
return va.validateTLSALPN01(ctx, identifier, keyAuthorization)
return va.validateTLSALPN01(ctx, ident, keyAuthorization)
}
// No case matched: the challenge type is not one this function dispatches.
return nil, berrors.MalformedError("invalid challenge type %s", challenge.Type)
return nil, berrors.MalformedError("invalid challenge type %s", kind)
}
// performRemoteValidation calls `PerformValidation` for each of the configured
// remoteVAs in a random order. The provided `results` chan should have an equal
// size to the number of remote VAs. The validations will be performed in
// separate go-routines. If the result `error` from a remote
// `PerformValidation` RPC is nil or a nil `ProblemDetails` instance it is
// written directly to the `results` chan. If the err is a cancelled error it is
// treated as a nil error. Otherwise the error/problem is written to the results
// channel as-is.
// performRemoteValidation coordinates the whole process of kicking off and
// collecting results from calls to remote VAs' PerformValidation function. It
// returns a problem if too many remote perspectives failed to corroborate
// domain control, or nil if enough succeeded to surpass our corroboration
// threshold.
//
// With no remote VAs configured it returns nil without doing any work. When it
// fails, the returned problem is the first one recorded, with its Detail
// prefixed by "During secondary validation: ".
func (va *ValidationAuthorityImpl) performRemoteValidation(
ctx context.Context,
req *vapb.PerformValidationRequest,
results chan<- *remoteVAResult) {
for _, i := range rand.Perm(len(va.remoteVAs)) {
remoteVA := va.remoteVAs[i]
go func(rva RemoteVA) {
result := &remoteVAResult{
VAHostname: rva.Address,
}
res, err := rva.PerformValidation(ctx, req)
if err != nil && canceled.Is(err) {
// If the non-nil err was a canceled error, ignore it. That's fine: it
// just means we cancelled the remote VA request before it was
// finished because we didn't care about its result. Don't log to avoid
// spamming the logs.
result.Problem = probs.ServerInternal("Remote PerformValidation RPC canceled")
} else if err != nil {
// This is a real error, not just a problem with the validation.
va.log.Errf("Remote VA %q.PerformValidation failed: %s", rva.Address, err)
result.Problem = probs.ServerInternal("Remote PerformValidation RPC failed")
} else if res.Problems != nil {
prob, err := bgrpc.PBToProblemDetails(res.Problems)
if err != nil {
va.log.Infof("Remote VA %q.PerformValidation returned malformed problem: %s", rva.Address, err)
result.Problem = probs.ServerInternal(
fmt.Sprintf("Remote PerformValidation RPC returned malformed result: %s", err))
} else {
va.log.Infof("Remote VA %q.PerformValidation returned problem: %s", rva.Address, prob)
result.Problem = prob
}
}
results <- result
}(remoteVA)
) *probs.ProblemDetails {
if len(va.remoteVAs) == 0 {
return nil
}
}
// processRemoteValidationResults evaluates a primary VA result, and a channel
// of remote VA problems to produce a single overall validation result based on
// configured feature flags. The overall result is calculated based on the VA's
// configured `maxRemoteFailures` value, and the function returns as soon as
// that threshold has been exceeded or cannot possibly be exceeded.
func (va *ValidationAuthorityImpl) processRemoteValidationResults(
challengeType string,
remoteResultsChan <-chan *remoteVAResult) *probs.ProblemDetails {
state := "failure"
start := va.clk.Now()
// Record how long remote validation took, whichever return path is taken.
defer func() {
va.metrics.remoteValidationTime.With(prometheus.Labels{
"type": challengeType,
"result": state,
"type": req.Challenge.Type,
}).Observe(va.clk.Since(start).Seconds())
}()
// rvaResult pairs one remote VA's raw RPC outcome with the address it came
// from, so the collector below can log and classify it.
type rvaResult struct {
hostname string
response *vapb.ValidationResult
err error
}
// NOTE(review): this channel is unbuffered, and the receive loop below
// returns as soon as either threshold is reached. Any goroutine that has not
// yet delivered its result will then block forever on its send. Consider
// giving the channel a capacity of len(va.remoteVAs) so late senders can
// always complete.
results := make(chan *rvaResult)
// Fan out one RPC per remote VA, visiting them in a random order.
for _, i := range rand.Perm(len(va.remoteVAs)) {
remoteVA := va.remoteVAs[i]
go func(rva RemoteVA, out chan<- *rvaResult) {
res, err := rva.PerformValidation(ctx, req)
out <- &rvaResult{
hostname: rva.Address,
response: res,
err: err,
}
}(remoteVA, results)
}
// required is the number of successful remote results needed to return nil.
required := len(va.remoteVAs) - va.maxRemoteFailures
good := 0
bad := 0
var remoteResults []*remoteVAResult
var firstProb *probs.ProblemDetails
// Due to channel behavior this could block indefinitely and we rely on gRPC
// honoring the context deadline used in client calls to prevent that from
// happening.
for result := range remoteResultsChan {
// Add the result to the slice
remoteResults = append(remoteResults, result)
if result.Problem == nil {
good++
} else {
for res := range results {
var currProb *probs.ProblemDetails
if res.err != nil {
bad++
// Canceled RPCs count as failures but are not logged; every other RPC
// error is logged before being turned into a generic internal problem.
if canceled.Is(res.err) {
currProb = probs.ServerInternal("Remote PerformValidation RPC canceled")
} else {
va.log.Errf("Remote VA %q.PerformValidation failed: %s", res.hostname, res.err)
currProb = probs.ServerInternal("Remote PerformValidation RPC failed")
}
} else if res.response.Problems != nil {
bad++
var err error
currProb, err = bgrpc.PBToProblemDetails(res.response.Problems)
if err != nil {
va.log.Errf("Remote VA %q.PerformValidation returned malformed problem: %s", res.hostname, err)
currProb = probs.ServerInternal("Remote PerformValidation RPC returned malformed result")
}
} else {
good++
}
// Store the first non-nil problem to return later.
if firstProb == nil && result.Problem != nil {
firstProb = result.Problem
if firstProb == nil && currProb != nil {
firstProb = currProb
}
// Return as soon as we have enough successes or failures for a definitive result.
if good >= required {
state = "success"
return nil
} else if bad > va.maxRemoteFailures {
modifiedProblem := *result.Problem
modifiedProblem.Detail = "During secondary validation: " + firstProb.Detail
return &modifiedProblem
}
if bad > va.maxRemoteFailures {
va.metrics.remoteValidationFailures.Inc()
// firstProb is modified in place before being returned.
firstProb.Detail = fmt.Sprintf("During secondary validation: %s", firstProb.Detail)
return firstProb
}
// If we somehow haven't returned early, we need to break the loop once all
// of the VAs have returned a result.
if len(remoteResults) == len(va.remoteVAs) {
if good+bad >= len(va.remoteVAs) {
break
}
}
// This condition should not occur - it indicates the good/bad counts didn't
// meet either the required threshold or the maxRemoteFailures threshold.
// This condition should not occur - it indicates the good/bad counts neither
// met the required threshold nor the maxRemoteFailures threshold.
return probs.ServerInternal("Too few remote PerformValidation RPC results")
}
@ -641,6 +595,39 @@ type remoteVAResult struct {
Problem *probs.ProblemDetails
}
// performLocalValidation runs the primary perspective's two checks in order:
// domain control validation, then CAA. It stops at the first failure and
// returns that error unconverted, together with any validation records the
// first step produced, so the caller can both audit-log the underlying error
// and turn it into a probs.ProblemDetails.
func (va *ValidationAuthorityImpl) performLocalValidation(
	ctx context.Context,
	ident identifier.ACMEIdentifier,
	regid int64,
	kind core.AcmeChallenge,
	token string,
	keyAuthorization string,
) ([]core.ValidationRecord, error) {
	// Step one: primary domain control validation. Any error here is a
	// validation failure.
	records, err := va.validateChallenge(ctx, ident, kind, token, keyAuthorization)
	if err != nil {
		return records, err
	}

	// Step two: primary CAA check, run against the identifier exactly as it was
	// given to this function. Any error here means we did not receive
	// permission to issue.
	params := caaParams{
		accountURIID:     regid,
		validationMethod: kind,
	}
	return records, va.checkCAA(ctx, ident, &params)
}
// PerformValidation validates the challenge for the domain in the request.
// The returned result will always contain a list of validation records, even
// when it also contains a problem.
@ -649,22 +636,15 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
if core.IsAnyNilOrZero(req, req.Domain, req.Challenge, req.Authz) {
return nil, berrors.InternalServerError("Incomplete validation request")
}
logEvent := verificationRequestEvent{
ID: req.Authz.Id,
Requester: req.Authz.RegID,
Hostname: req.Domain,
}
vStart := va.clk.Now()
var remoteResults chan *remoteVAResult
if remoteVACount := len(va.remoteVAs); remoteVACount > 0 {
remoteResults = make(chan *remoteVAResult, remoteVACount)
go va.performRemoteValidation(ctx, req, remoteResults)
}
challenge, err := bgrpc.PBToChallenge(req.Challenge)
if err != nil {
return nil, errors.New("Challenge failed to deserialize")
return nil, errors.New("challenge failed to deserialize")
}
err = challenge.CheckPending()
if err != nil {
return nil, berrors.MalformedError("challenge failed consistency check: %s", err)
}
// TODO(#7514): Remove this fallback and belt-and-suspenders check.
@ -676,53 +656,63 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
return nil, errors.New("no expected keyAuthorization provided")
}
records, err := va.validate(ctx, identifier.DNSIdentifier(req.Domain), req.Authz.RegID, challenge, keyAuthorization)
challenge.ValidationRecord = records
localValidationLatency := time.Since(vStart)
// Set up variables and a deferred closure to report validation latency
// metrics and log validation errors. Below here, do not use := to redeclare
// `prob`, or this will fail.
var prob *probs.ProblemDetails
var localLatency time.Duration
vStart := va.clk.Now()
logEvent := verificationRequestEvent{
ID: req.Authz.Id,
Requester: req.Authz.RegID,
Hostname: req.Domain,
Challenge: challenge,
}
defer func() {
problemType := ""
if prob != nil {
problemType = string(prob.Type)
logEvent.Error = prob.Error()
logEvent.Challenge.Error = prob
logEvent.Challenge.Status = core.StatusInvalid
} else {
logEvent.Challenge.Status = core.StatusValid
}
va.metrics.localValidationTime.With(prometheus.Labels{
"type": string(logEvent.Challenge.Type),
"result": string(logEvent.Challenge.Status),
}).Observe(localLatency.Seconds())
va.metrics.validationTime.With(prometheus.Labels{
"type": string(logEvent.Challenge.Type),
"result": string(logEvent.Challenge.Status),
"problem_type": problemType,
}).Observe(time.Since(vStart).Seconds())
logEvent.ValidationLatency = time.Since(vStart).Round(time.Millisecond).Seconds()
va.log.AuditObject("Validation result", logEvent)
}()
// Do local validation. Note that we process the result in a couple ways
// *before* checking whether it returned an error. These few checks are
// carefully written to ensure that they work whether the local validation
// was successful or not, and cannot themselves fail.
records, err := va.performLocalValidation(
ctx,
identifier.DNSIdentifier(req.Domain),
req.Authz.RegID,
challenge.Type,
challenge.Token,
keyAuthorization)
localLatency = time.Since(vStart)
// Check for malformed ValidationRecords
if !challenge.RecordsSane() && err == nil {
err = errors.New("Records for validation failed sanity check")
logEvent.Challenge.ValidationRecord = records
if err == nil && !logEvent.Challenge.RecordsSane() {
err = errors.New("records from local validation failed sanity check")
}
var problemType string
var prob *probs.ProblemDetails
if err != nil {
prob = detailedError(err)
problemType = string(prob.Type)
challenge.Status = core.StatusInvalid
challenge.Error = prob
logEvent.Error = prob.Error()
logEvent.InternalError = err.Error()
} else if remoteResults != nil {
remoteProb := va.processRemoteValidationResults(
string(challenge.Type),
remoteResults)
// If the remote result was a non-nil problem then fail the validation
if remoteProb != nil {
prob = remoteProb
challenge.Status = core.StatusInvalid
challenge.Error = remoteProb
// We only set .Error here, not .InternalError, because the
// remote VA doesn't send us the internal error. But that's ok,
// it got logged at the remote VA.
logEvent.Error = remoteProb.Error()
va.log.Infof("Validation failed due to remote failures: identifier=%v err=%s",
req.Domain, remoteProb)
va.metrics.remoteValidationFailures.Inc()
} else {
challenge.Status = core.StatusValid
}
} else {
challenge.Status = core.StatusValid
}
logEvent.Challenge = challenge
validationLatency := time.Since(vStart)
logEvent.ValidationLatency = validationLatency.Round(time.Millisecond).Seconds()
// Copy the "UsedRSAKEX" value from the last validationRecord into the log
// event. Only the last record should have this bool set, because we only
// record it if/when validation is finally successful, but we use the loop
@ -732,23 +722,19 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
logEvent.UsedRSAKEX = record.UsedRSAKEX || logEvent.UsedRSAKEX
}
va.metrics.localValidationTime.With(prometheus.Labels{
"type": string(challenge.Type),
"result": string(challenge.Status),
}).Observe(localValidationLatency.Seconds())
va.metrics.validationTime.With(prometheus.Labels{
"type": string(challenge.Type),
"result": string(challenge.Status),
"problem_type": problemType,
}).Observe(validationLatency.Seconds())
if err != nil {
logEvent.InternalError = err.Error()
prob = detailedError(err)
return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob))
}
va.log.AuditObject("Validation result", logEvent)
// The ProblemDetails will be serialized through gRPC, which requires UTF-8.
// It will also later be serialized in JSON, which defaults to UTF-8. Make
// sure it is UTF-8 clean now.
prob = filterProblemDetails(prob)
return bgrpc.ValidationResultToPB(records, prob)
// Do remote validation. We do this after local validation is complete to
// avoid wasting work when validation will fail anyway. This only returns a
// singular problem, because the remote VAs have already audit-logged their
// own validation records, and it's not helpful to present multiple large
// errors to the end user.
prob = va.performRemoteValidation(ctx, req)
return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob))
}
// usedRSAKEX returns true if the given cipher suite involves the use of an

View File

@ -101,16 +101,6 @@ func createValidationRequest(domain string, challengeType core.AcmeChallenge) *v
}
}
// createChallenge builds a pending challenge of the given type for use in
// tests. It is populated with the shared expectedToken and
// expectedKeyAuthorization fixtures and an empty, non-nil list of validation
// records.
func createChallenge(challengeType core.AcmeChallenge) core.Challenge {
	var chall core.Challenge
	chall.Type = challengeType
	chall.Status = core.StatusPending
	chall.Token = expectedToken
	chall.ValidationRecord = []core.ValidationRecord{}
	chall.ProvidedKeyAuthorization = expectedKeyAuthorization
	return chall
}
// setup returns an in-memory VA and a mock logger. The default resolver client
// is MockClient{}, but can be overridden.
func setup(srv *httptest.Server, maxRemoteFailures int, userAgent string, remoteVAs []RemoteVA, mockDNSClientOverride bdns.Client) (*ValidationAuthorityImpl, *blog.Mock) {
@ -251,7 +241,7 @@ func (inmem inMemVA) IsCAAValid(ctx context.Context, req *vapb.IsCAAValidRequest
func TestValidateMalformedChallenge(t *testing.T) {
va, _ := setup(nil, 0, "", nil, nil)
_, err := va.validateChallenge(ctx, dnsi("example.com"), createChallenge("fake-type-01"), expectedKeyAuthorization)
_, err := va.validateChallenge(ctx, dnsi("example.com"), "fake-type-01", expectedToken, expectedKeyAuthorization)
prob := detailedError(err)
test.AssertEquals(t, prob.Type, probs.MalformedProblem)