Perform remote validation after primary validation (#7522)

Change the VA to perform remote validation wholly after local validation
and CAA checks, and to do so only if those local checks pass. This will
likely increase the latency of our successful validations, by making
them less parallel. However, it will reduce the amount of work we do on
unsuccessful validations, and reduce their latency, by no longer kicking off
remote validations and waiting for their results.

Fixes https://github.com/letsencrypt/boulder/issues/7509
This commit is contained in:
Aaron Gable 2024-06-10 14:16:44 -07:00 committed by GitHub
parent e198d3529d
commit 5b647072b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 186 additions and 250 deletions

View File

@ -959,19 +959,18 @@ func TestMultiCAARechecking(t *testing.T) {
} }
func TestCAAFailure(t *testing.T) { func TestCAAFailure(t *testing.T) {
chall := createChallenge(core.ChallengeTypeHTTP01) hs := httpSrv(t, expectedToken)
hs := httpSrv(t, chall.Token)
defer hs.Close() defer hs.Close()
va, _ := setup(hs, 0, "", nil, caaMockDNS{}) va, _ := setup(hs, 0, "", nil, caaMockDNS{})
_, err := va.validate(ctx, dnsi("reserved.com"), 1, chall, expectedKeyAuthorization) err := va.checkCAA(ctx, dnsi("reserved.com"), &caaParams{1, core.ChallengeTypeHTTP01})
if err == nil { if err == nil {
t.Fatalf("Expected CAA rejection for reserved.com, got success") t.Fatalf("Expected CAA rejection for reserved.com, got success")
} }
test.AssertErrorIs(t, err, berrors.CAA) test.AssertErrorIs(t, err, berrors.CAA)
_, err = va.validate(ctx, dnsi("example.gonetld"), 1, chall, expectedKeyAuthorization) err = va.checkCAA(ctx, dnsi("example.gonetld"), &caaParams{1, core.ChallengeTypeHTTP01})
if err == nil { if err == nil {
t.Fatalf("Expected CAA rejection for gonetld, got success") t.Fatalf("Expected CAA rejection for gonetld, got success")
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"strings"
"testing" "testing"
"time" "time"
@ -91,44 +90,6 @@ func TestDNSValidationInvalid(t *testing.T) {
test.AssertEquals(t, prob.Type, probs.MalformedProblem) test.AssertEquals(t, prob.Type, probs.MalformedProblem)
} }
func TestDNSValidationNotSane(t *testing.T) {
va, _ := setup(nil, 0, "", nil, nil)
chall := createChallenge(core.ChallengeTypeDNS01)
chall.Token = ""
_, err := va.validateChallenge(ctx, dnsi("localhost"), chall, expectedKeyAuthorization)
prob := detailedError(err)
if prob.Type != probs.MalformedProblem {
t.Errorf("Got wrong error type: expected %s, got %s",
prob.Type, probs.MalformedProblem)
}
if !strings.Contains(prob.Error(), "Challenge failed consistency check:") {
t.Errorf("Got wrong error: %s", prob.Error())
}
chall.Token = "yfCBb-bRTLz8Wd1C0lTUQK3qlKj3-t2tYGwx5Hj7r_"
_, err = va.validateChallenge(ctx, dnsi("localhost"), chall, expectedKeyAuthorization)
prob = detailedError(err)
if prob.Type != probs.MalformedProblem {
t.Errorf("Got wrong error type: expected %s, got %s",
prob.Type, probs.MalformedProblem)
}
if !strings.Contains(prob.Error(), "Challenge failed consistency check:") {
t.Errorf("Got wrong error: %s", prob.Error())
}
_, err = va.validateChallenge(ctx, dnsi("localhost"), chall, "a")
prob = detailedError(err)
if prob.Type != probs.MalformedProblem {
t.Errorf("Got wrong error type: expected %s, got %s",
prob.Type, probs.MalformedProblem)
}
if !strings.Contains(prob.Error(), "Challenge failed consistency check:") {
t.Errorf("Got wrong error: %s", prob.Error())
}
}
func TestDNSValidationServFail(t *testing.T) { func TestDNSValidationServFail(t *testing.T) {
va, _ := setup(nil, 0, "", nil, nil) va, _ := setup(nil, 0, "", nil, nil)

378
va/va.go
View File

@ -120,7 +120,7 @@ func initMetrics(stats prometheus.Registerer) *vaMetrics {
Help: "Time taken to remotely validate a challenge", Help: "Time taken to remotely validate a challenge",
Buckets: metrics.InternetFacingBuckets, Buckets: metrics.InternetFacingBuckets,
}, },
[]string{"type", "result"}) []string{"type"})
stats.MustRegister(remoteValidationTime) stats.MustRegister(remoteValidationTime)
remoteValidationFailures := prometheus.NewCounter( remoteValidationFailures := prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
@ -419,168 +419,122 @@ func detailedError(err error) *probs.ProblemDetails {
return probs.Connection("Error getting validation data") return probs.Connection("Error getting validation data")
} }
// validate performs a challenge validation and, in parallel, // validateChallenge simply passes through to the appropriate validation method
// checks CAA and GSB for the identifier. If any of those steps fails, it // depending on the challenge type.
// returns a ProblemDetails plus the validation records created during the
// validation attempt.
func (va *ValidationAuthorityImpl) validate(
ctx context.Context,
identifier identifier.ACMEIdentifier,
regid int64,
challenge core.Challenge,
keyAuthorization string,
) ([]core.ValidationRecord, error) {
// If the identifier is a wildcard domain we need to validate the base
// domain by removing the "*." wildcard prefix. We create a separate
// `baseIdentifier` here before starting the `va.checkCAA` goroutine with the
// `identifier` to avoid a data race.
baseIdentifier := identifier
if strings.HasPrefix(identifier.Value, "*.") {
baseIdentifier.Value = strings.TrimPrefix(identifier.Value, "*.")
}
validationRecords, err := va.validateChallenge(ctx, baseIdentifier, challenge, keyAuthorization)
if err != nil {
return validationRecords, err
}
err = va.checkCAA(ctx, identifier, &caaParams{
accountURIID: regid,
validationMethod: challenge.Type,
})
if err != nil {
return validationRecords, err
}
return validationRecords, nil
}
func (va *ValidationAuthorityImpl) validateChallenge( func (va *ValidationAuthorityImpl) validateChallenge(
ctx context.Context, ctx context.Context,
identifier identifier.ACMEIdentifier, ident identifier.ACMEIdentifier,
challenge core.Challenge, kind core.AcmeChallenge,
token string,
keyAuthorization string, keyAuthorization string,
) ([]core.ValidationRecord, error) { ) ([]core.ValidationRecord, error) {
err := challenge.CheckPending() // Strip a (potential) leading wildcard token from the identifier.
if err != nil { ident.Value = strings.TrimPrefix(ident.Value, "*.")
return nil, berrors.MalformedError("Challenge failed consistency check: %s", err)
} switch kind {
switch challenge.Type {
case core.ChallengeTypeHTTP01: case core.ChallengeTypeHTTP01:
return va.validateHTTP01(ctx, identifier, challenge.Token, keyAuthorization) return va.validateHTTP01(ctx, ident, token, keyAuthorization)
case core.ChallengeTypeDNS01: case core.ChallengeTypeDNS01:
return va.validateDNS01(ctx, identifier, keyAuthorization) return va.validateDNS01(ctx, ident, keyAuthorization)
case core.ChallengeTypeTLSALPN01: case core.ChallengeTypeTLSALPN01:
return va.validateTLSALPN01(ctx, identifier, keyAuthorization) return va.validateTLSALPN01(ctx, ident, keyAuthorization)
} }
return nil, berrors.MalformedError("invalid challenge type %s", challenge.Type) return nil, berrors.MalformedError("invalid challenge type %s", kind)
} }
// performRemoteValidation calls `PerformValidation` for each of the configured // performRemoteValidation coordinates the whole process of kicking off and
// remoteVAs in a random order. The provided `results` chan should have an equal // collecting results from calls to remote VAs' PerformValidation function. It
// size to the number of remote VAs. The validations will be performed in // returns a problem if too many remote perspectives failed to corroborate
// separate go-routines. If the result `error` from a remote // domain control, or nil if enough succeeded to surpass our corroboration
// `PerformValidation` RPC is nil or a nil `ProblemDetails` instance it is // threshold.
// written directly to the `results` chan. If the err is a cancelled error it is
// treated as a nil error. Otherwise the error/problem is written to the results
// channel as-is.
func (va *ValidationAuthorityImpl) performRemoteValidation( func (va *ValidationAuthorityImpl) performRemoteValidation(
ctx context.Context, ctx context.Context,
req *vapb.PerformValidationRequest, req *vapb.PerformValidationRequest,
results chan<- *remoteVAResult) { ) *probs.ProblemDetails {
for _, i := range rand.Perm(len(va.remoteVAs)) { if len(va.remoteVAs) == 0 {
remoteVA := va.remoteVAs[i] return nil
go func(rva RemoteVA) {
result := &remoteVAResult{
VAHostname: rva.Address,
} }
res, err := rva.PerformValidation(ctx, req)
if err != nil && canceled.Is(err) {
// If the non-nil err was a canceled error, ignore it. That's fine: it
// just means we cancelled the remote VA request before it was
// finished because we didn't care about its result. Don't log to avoid
// spamming the logs.
result.Problem = probs.ServerInternal("Remote PerformValidation RPC canceled")
} else if err != nil {
// This is a real error, not just a problem with the validation.
va.log.Errf("Remote VA %q.PerformValidation failed: %s", rva.Address, err)
result.Problem = probs.ServerInternal("Remote PerformValidation RPC failed")
} else if res.Problems != nil {
prob, err := bgrpc.PBToProblemDetails(res.Problems)
if err != nil {
va.log.Infof("Remote VA %q.PerformValidation returned malformed problem: %s", rva.Address, err)
result.Problem = probs.ServerInternal(
fmt.Sprintf("Remote PerformValidation RPC returned malformed result: %s", err))
} else {
va.log.Infof("Remote VA %q.PerformValidation returned problem: %s", rva.Address, prob)
result.Problem = prob
}
}
results <- result
}(remoteVA)
}
}
// processRemoteValidationResults evaluates a primary VA result, and a channel
// of remote VA problems to produce a single overall validation result based on
// configured feature flags. The overall result is calculated based on the VA's
// configured `maxRemoteFailures` value, and the function returns as soon as
// that threshold has been exceeded or cannot possibly be exceeded.
func (va *ValidationAuthorityImpl) processRemoteValidationResults(
challengeType string,
remoteResultsChan <-chan *remoteVAResult) *probs.ProblemDetails {
state := "failure"
start := va.clk.Now() start := va.clk.Now()
defer func() { defer func() {
va.metrics.remoteValidationTime.With(prometheus.Labels{ va.metrics.remoteValidationTime.With(prometheus.Labels{
"type": challengeType, "type": req.Challenge.Type,
"result": state,
}).Observe(va.clk.Since(start).Seconds()) }).Observe(va.clk.Since(start).Seconds())
}() }()
// rvaResult pairs a single remote VA's PerformValidation outcome (response
// and/or error) with the address of the remote VA that produced it.
type rvaResult struct {
	hostname string
	response *vapb.ValidationResult
	err      error
}

// Give the channel one slot per remote VA. The collection loop below returns
// as soon as the outcome is decided, so with an unbuffered channel every
// goroutine that finishes after that point would block forever on its send
// and be leaked. With a full-size buffer each goroutine can always deliver
// its result and exit, whether or not anyone is still receiving.
results := make(chan *rvaResult, len(va.remoteVAs))

// Start one goroutine per remote VA, visiting them in a random order.
for _, i := range rand.Perm(len(va.remoteVAs)) {
	remoteVA := va.remoteVAs[i]
	go func(rva RemoteVA, out chan<- *rvaResult) {
		res, err := rva.PerformValidation(ctx, req)
		out <- &rvaResult{
			hostname: rva.Address,
			response: res,
			err:      err,
		}
	}(remoteVA, results)
}
required := len(va.remoteVAs) - va.maxRemoteFailures required := len(va.remoteVAs) - va.maxRemoteFailures
good := 0 good := 0
bad := 0 bad := 0
var remoteResults []*remoteVAResult
var firstProb *probs.ProblemDetails var firstProb *probs.ProblemDetails
// Due to channel behavior this could block indefinitely and we rely on gRPC
// honoring the context deadline used in client calls to prevent that from for res := range results {
// happening. var currProb *probs.ProblemDetails
for result := range remoteResultsChan {
// Add the result to the slice if res.err != nil {
remoteResults = append(remoteResults, result)
if result.Problem == nil {
good++
} else {
bad++ bad++
if canceled.Is(res.err) {
currProb = probs.ServerInternal("Remote PerformValidation RPC canceled")
} else {
va.log.Errf("Remote VA %q.PerformValidation failed: %s", res.hostname, res.err)
currProb = probs.ServerInternal("Remote PerformValidation RPC failed")
} }
// Store the first non-nil problem to return later. } else if res.response.Problems != nil {
if firstProb == nil && result.Problem != nil { bad++
firstProb = result.Problem
var err error
currProb, err = bgrpc.PBToProblemDetails(res.response.Problems)
if err != nil {
va.log.Errf("Remote VA %q.PerformValidation returned malformed problem: %s", res.hostname, err)
currProb = probs.ServerInternal("Remote PerformValidation RPC returned malformed result")
} }
} else {
good++
}
if firstProb == nil && currProb != nil {
firstProb = currProb
}
// Return as soon as we have enough successes or failures for a definitive result. // Return as soon as we have enough successes or failures for a definitive result.
if good >= required { if good >= required {
state = "success"
return nil return nil
} else if bad > va.maxRemoteFailures { }
modifiedProblem := *result.Problem if bad > va.maxRemoteFailures {
modifiedProblem.Detail = "During secondary validation: " + firstProb.Detail va.metrics.remoteValidationFailures.Inc()
return &modifiedProblem firstProb.Detail = fmt.Sprintf("During secondary validation: %s", firstProb.Detail)
return firstProb
} }
// If we somehow haven't returned early, we need to break the loop once all // If we somehow haven't returned early, we need to break the loop once all
// of the VAs have returned a result. // of the VAs have returned a result.
if len(remoteResults) == len(va.remoteVAs) { if good+bad >= len(va.remoteVAs) {
break break
} }
} }
// This condition should not occur - it indicates the good/bad counts didn't // This condition should not occur - it indicates the good/bad counts neither
// meet either the required threshold or the maxRemoteFailures threshold. // met the required threshold nor the maxRemoteFailures threshold.
return probs.ServerInternal("Too few remote PerformValidation RPC results") return probs.ServerInternal("Too few remote PerformValidation RPC results")
} }
@ -641,6 +595,39 @@ type remoteVAResult struct {
Problem *probs.ProblemDetails Problem *probs.ProblemDetails
} }
// performLocalValidation runs the primary (local) checks for a validation
// request: domain control validation first, then the CAA check. It stops at
// the first step that fails and returns that step's bare error, together with
// whatever validation records were produced, so that audit logging can include
// the underlying error.
func (va *ValidationAuthorityImpl) performLocalValidation(
	ctx context.Context,
	ident identifier.ACMEIdentifier,
	regid int64,
	kind core.AcmeChallenge,
	token string,
	keyAuthorization string,
) ([]core.ValidationRecord, error) {
	// Primary domain control validation. Any error here counts as a validation
	// failure; the caller is responsible for turning it into a suitable
	// probs.ProblemDetails.
	validationRecords, dcvErr := va.validateChallenge(ctx, ident, kind, token, keyAuthorization)
	if dcvErr != nil {
		return validationRecords, dcvErr
	}

	// Primary CAA check. Any error here means we did not receive permission to
	// issue; the caller likewise converts it into a probs.ProblemDetails.
	params := &caaParams{
		accountURIID:     regid,
		validationMethod: kind,
	}
	if caaErr := va.checkCAA(ctx, ident, params); caaErr != nil {
		return validationRecords, caaErr
	}

	return validationRecords, nil
}
// PerformValidation validates the challenge for the domain in the request. // PerformValidation validates the challenge for the domain in the request.
// The returned result will always contain a list of validation records, even // The returned result will always contain a list of validation records, even
// when it also contains a problem. // when it also contains a problem.
@ -649,22 +636,15 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
if core.IsAnyNilOrZero(req, req.Domain, req.Challenge, req.Authz) { if core.IsAnyNilOrZero(req, req.Domain, req.Challenge, req.Authz) {
return nil, berrors.InternalServerError("Incomplete validation request") return nil, berrors.InternalServerError("Incomplete validation request")
} }
logEvent := verificationRequestEvent{
ID: req.Authz.Id,
Requester: req.Authz.RegID,
Hostname: req.Domain,
}
vStart := va.clk.Now()
var remoteResults chan *remoteVAResult
if remoteVACount := len(va.remoteVAs); remoteVACount > 0 {
remoteResults = make(chan *remoteVAResult, remoteVACount)
go va.performRemoteValidation(ctx, req, remoteResults)
}
challenge, err := bgrpc.PBToChallenge(req.Challenge) challenge, err := bgrpc.PBToChallenge(req.Challenge)
if err != nil { if err != nil {
return nil, errors.New("Challenge failed to deserialize") return nil, errors.New("challenge failed to deserialize")
}
err = challenge.CheckPending()
if err != nil {
return nil, berrors.MalformedError("challenge failed consistency check: %s", err)
} }
// TODO(#7514): Remove this fallback and belt-and-suspenders check. // TODO(#7514): Remove this fallback and belt-and-suspenders check.
@ -676,53 +656,63 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
return nil, errors.New("no expected keyAuthorization provided") return nil, errors.New("no expected keyAuthorization provided")
} }
records, err := va.validate(ctx, identifier.DNSIdentifier(req.Domain), req.Authz.RegID, challenge, keyAuthorization) // Set up variables and a deferred closure to report validation latency
challenge.ValidationRecord = records // metrics and log validation errors. Below here, do not use := to redeclare
localValidationLatency := time.Since(vStart) // `prob`, or this will fail.
var prob *probs.ProblemDetails
var localLatency time.Duration
vStart := va.clk.Now()
logEvent := verificationRequestEvent{
ID: req.Authz.Id,
Requester: req.Authz.RegID,
Hostname: req.Domain,
Challenge: challenge,
}
defer func() {
problemType := ""
if prob != nil {
problemType = string(prob.Type)
logEvent.Error = prob.Error()
logEvent.Challenge.Error = prob
logEvent.Challenge.Status = core.StatusInvalid
} else {
logEvent.Challenge.Status = core.StatusValid
}
va.metrics.localValidationTime.With(prometheus.Labels{
"type": string(logEvent.Challenge.Type),
"result": string(logEvent.Challenge.Status),
}).Observe(localLatency.Seconds())
va.metrics.validationTime.With(prometheus.Labels{
"type": string(logEvent.Challenge.Type),
"result": string(logEvent.Challenge.Status),
"problem_type": problemType,
}).Observe(time.Since(vStart).Seconds())
logEvent.ValidationLatency = time.Since(vStart).Round(time.Millisecond).Seconds()
va.log.AuditObject("Validation result", logEvent)
}()
// Do local validation. Note that we process the result in a couple ways
// *before* checking whether it returned an error. These few checks are
// carefully written to ensure that they work whether the local validation
// was successful or not, and cannot themselves fail.
records, err := va.performLocalValidation(
ctx,
identifier.DNSIdentifier(req.Domain),
req.Authz.RegID,
challenge.Type,
challenge.Token,
keyAuthorization)
localLatency = time.Since(vStart)
// Check for malformed ValidationRecords // Check for malformed ValidationRecords
if !challenge.RecordsSane() && err == nil { logEvent.Challenge.ValidationRecord = records
err = errors.New("Records for validation failed sanity check") if err == nil && !logEvent.Challenge.RecordsSane() {
err = errors.New("records from local validation failed sanity check")
} }
var problemType string
var prob *probs.ProblemDetails
if err != nil {
prob = detailedError(err)
problemType = string(prob.Type)
challenge.Status = core.StatusInvalid
challenge.Error = prob
logEvent.Error = prob.Error()
logEvent.InternalError = err.Error()
} else if remoteResults != nil {
remoteProb := va.processRemoteValidationResults(
string(challenge.Type),
remoteResults)
// If the remote result was a non-nil problem then fail the validation
if remoteProb != nil {
prob = remoteProb
challenge.Status = core.StatusInvalid
challenge.Error = remoteProb
// We only set .Error here, not .InternalError, because the
// remote VA doesn't send us the internal error. But that's ok,
// it got logged at the remote VA.
logEvent.Error = remoteProb.Error()
va.log.Infof("Validation failed due to remote failures: identifier=%v err=%s",
req.Domain, remoteProb)
va.metrics.remoteValidationFailures.Inc()
} else {
challenge.Status = core.StatusValid
}
} else {
challenge.Status = core.StatusValid
}
logEvent.Challenge = challenge
validationLatency := time.Since(vStart)
logEvent.ValidationLatency = validationLatency.Round(time.Millisecond).Seconds()
// Copy the "UsedRSAKEX" value from the last validationRecord into the log // Copy the "UsedRSAKEX" value from the last validationRecord into the log
// event. Only the last record should have this bool set, because we only // event. Only the last record should have this bool set, because we only
// record it if/when validation is finally successful, but we use the loop // record it if/when validation is finally successful, but we use the loop
@ -732,23 +722,19 @@ func (va *ValidationAuthorityImpl) PerformValidation(ctx context.Context, req *v
logEvent.UsedRSAKEX = record.UsedRSAKEX || logEvent.UsedRSAKEX logEvent.UsedRSAKEX = record.UsedRSAKEX || logEvent.UsedRSAKEX
} }
va.metrics.localValidationTime.With(prometheus.Labels{ if err != nil {
"type": string(challenge.Type), logEvent.InternalError = err.Error()
"result": string(challenge.Status), prob = detailedError(err)
}).Observe(localValidationLatency.Seconds()) return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob))
va.metrics.validationTime.With(prometheus.Labels{ }
"type": string(challenge.Type),
"result": string(challenge.Status),
"problem_type": problemType,
}).Observe(validationLatency.Seconds())
va.log.AuditObject("Validation result", logEvent) // Do remote validation. We do this after local validation is complete to
// avoid wasting work when validation will fail anyway. This only returns a
// The ProblemDetails will be serialized through gRPC, which requires UTF-8. // singular problem, because the remote VAs have already audit-logged their
// It will also later be serialized in JSON, which defaults to UTF-8. Make // own validation records, and it's not helpful to present multiple large
// sure it is UTF-8 clean now. // errors to the end user.
prob = filterProblemDetails(prob) prob = va.performRemoteValidation(ctx, req)
return bgrpc.ValidationResultToPB(records, prob) return bgrpc.ValidationResultToPB(records, filterProblemDetails(prob))
} }
// usedRSAKEX returns true if the given cipher suite involves the use of an // usedRSAKEX returns true if the given cipher suite involves the use of an

View File

@ -101,16 +101,6 @@ func createValidationRequest(domain string, challengeType core.AcmeChallenge) *v
} }
} }
func createChallenge(challengeType core.AcmeChallenge) core.Challenge {
return core.Challenge{
Type: challengeType,
Status: core.StatusPending,
Token: expectedToken,
ValidationRecord: []core.ValidationRecord{},
ProvidedKeyAuthorization: expectedKeyAuthorization,
}
}
// setup returns an in-memory VA and a mock logger. The default resolver client // setup returns an in-memory VA and a mock logger. The default resolver client
// is MockClient{}, but can be overridden. // is MockClient{}, but can be overridden.
func setup(srv *httptest.Server, maxRemoteFailures int, userAgent string, remoteVAs []RemoteVA, mockDNSClientOverride bdns.Client) (*ValidationAuthorityImpl, *blog.Mock) { func setup(srv *httptest.Server, maxRemoteFailures int, userAgent string, remoteVAs []RemoteVA, mockDNSClientOverride bdns.Client) (*ValidationAuthorityImpl, *blog.Mock) {
@ -251,7 +241,7 @@ func (inmem inMemVA) IsCAAValid(ctx context.Context, req *vapb.IsCAAValidRequest
func TestValidateMalformedChallenge(t *testing.T) { func TestValidateMalformedChallenge(t *testing.T) {
va, _ := setup(nil, 0, "", nil, nil) va, _ := setup(nil, 0, "", nil, nil)
_, err := va.validateChallenge(ctx, dnsi("example.com"), createChallenge("fake-type-01"), expectedKeyAuthorization) _, err := va.validateChallenge(ctx, dnsi("example.com"), "fake-type-01", expectedToken, expectedKeyAuthorization)
prob := detailedError(err) prob := detailedError(err)
test.AssertEquals(t, prob.Type, probs.MalformedProblem) test.AssertEquals(t, prob.Type, probs.MalformedProblem)