From 8f1de3b57ea30e03bb4b753e2e3a371734e582dd Mon Sep 17 00:00:00 2001 From: Daniel McCarney Date: Tue, 21 Mar 2017 19:57:28 -0400 Subject: [PATCH 1/5] Allows expiration-mailer to use template as subject. (#2613) This commit resolves #2599 by adding support to the expiration-mailer to treat the subject for email messages as a template. This allows for the dynamic subject lines from #2435 to be used with a prefix for staging emails. --- cmd/expiration-mailer/main.go | 98 +++++++++++++++++------------- cmd/expiration-mailer/main_test.go | 40 ++++++------ cmd/expiration-mailer/send_test.go | 2 +- 3 files changed, 79 insertions(+), 61 deletions(-) diff --git a/cmd/expiration-mailer/main.go b/cmd/expiration-mailer/main.go index 071ac57bd..b83528a89 100644 --- a/cmd/expiration-mailer/main.go +++ b/cmd/expiration-mailer/main.go @@ -34,29 +34,26 @@ import ( sapb "github.com/letsencrypt/boulder/sa/proto" ) -const defaultNagCheckInterval = 24 * time.Hour - -type emailContent struct { - ExpirationDate string - DaysToExpiration int - DNSNames string -} +const ( + defaultNagCheckInterval = 24 * time.Hour + defaultExpirationSubject = "Let's Encrypt certificate expiration notice for domain {{.ExpirationSubject}}" +) type regStore interface { GetRegistration(context.Context, int64) (core.Registration, error) } type mailer struct { - stats metrics.Scope - log blog.Logger - dbMap *gorp.DbMap - rs regStore - mailer bmail.Mailer - emailTemplate *template.Template - subject string - nagTimes []time.Duration - limit int - clk clock.Clock + stats metrics.Scope + log blog.Logger + dbMap *gorp.DbMap + rs regStore + mailer bmail.Mailer + emailTemplate *template.Template + subjectTemplate *template.Template + nagTimes []time.Duration + limit int + clk clock.Clock } func (m *mailer) sendNags(contacts []string, certs []*x509.Certificate) error { @@ -101,33 +98,42 @@ func (m *mailer) sendNags(contacts []string, certs []*x509.Certificate) error { sort.Strings(domains) m.log.Debug(fmt.Sprintf("Sending mail for %s (%s)", strings.Join(domains, ", "), strings.Join(serials, ", "))) - var subject string - if m.subject != "" { - // If there is a subject from the configuration file, we should use it as-is - // to preserve the "classic" behaviour before we added a domain name. - subject = m.subject - } else { - // Otherwise, when no subject is configured we should make one using the - // domain names in the expiring certificate - subject = fmt.Sprintf("Certificate expiration notice for domain %q", domains[0]) - if len(domains) > 1 { - subject += fmt.Sprintf(" (and %d more)", len(domains)-1) - } + // Construct the information about the expiring certificates for use in the + // subject template + expiringSubject := fmt.Sprintf("%q", domains[0]) + if len(domains) > 1 { + expiringSubject += fmt.Sprintf(" (and %d more)", len(domains)-1) } - email := emailContent{ + // Execute the subjectTemplate by filling in the ExpirationSubject + subjBuf := new(bytes.Buffer) + err := m.subjectTemplate.Execute(subjBuf, struct { + ExpirationSubject string + }{ + ExpirationSubject: expiringSubject, + }) + if err != nil { + m.stats.Inc("Errors.SendingNag.SubjectTemplateFailure", 1) + return err + } + + email := struct { + ExpirationDate string + DaysToExpiration int + DNSNames string + }{ ExpirationDate: expDate.UTC().Format(time.RFC822Z), DaysToExpiration: int(expiresIn.Hours() / 24), DNSNames: strings.Join(domains, "\n"), } msgBuf := new(bytes.Buffer) - err := m.emailTemplate.Execute(msgBuf, email) + err = m.emailTemplate.Execute(msgBuf, email) if err != nil { m.stats.Inc("Errors.SendingNag.TemplateFailure", 1) return err } startSending := m.clk.Now() - err = m.mailer.SendMail(emails, subject, msgBuf.String()) + err = m.mailer.SendMail(emails, subjBuf.String(), msgBuf.String()) if err != nil { return err } @@ -444,6 +450,14 @@ func main() { tmpl, err := template.New("expiry-email").Parse(string(emailTmpl)) cmd.FailOnError(err, "Could not parse email template") + // If there is no configured subject template, use a default + if c.Mailer.Subject == "" { + c.Mailer.Subject = defaultExpirationSubject + } + // Load subject template + subjTmpl, err := template.New("expiry-email-subject").Parse(c.Mailer.Subject) + cmd.FailOnError(err, fmt.Sprintf("Could not parse email subject template")) + fromAddress, err := netmail.ParseAddress(c.Mailer.From) cmd.FailOnError(err, fmt.Sprintf("Could not parse from address: %s", c.Mailer.From)) @@ -482,16 +496,16 @@ func main() { sort.Sort(nags) m := mailer{ - stats: scope, - subject: c.Mailer.Subject, - log: logger, - dbMap: dbMap, - rs: sac, - mailer: mailClient, - emailTemplate: tmpl, - nagTimes: nags, - limit: c.Mailer.CertLimit, - clk: cmd.Clock(), + stats: scope, + log: logger, + dbMap: dbMap, + rs: sac, + mailer: mailClient, + subjectTemplate: subjTmpl, + emailTemplate: tmpl, + nagTimes: nags, + limit: c.Mailer.CertLimit, + clk: cmd.Clock(), } go cmd.DebugServer(c.Mailer.DebugAddr) diff --git a/cmd/expiration-mailer/main_test.go b/cmd/expiration-mailer/main_test.go index b05f0eabf..3a265fe3e 100644 --- a/cmd/expiration-mailer/main_test.go +++ b/cmd/expiration-mailer/main_test.go @@ -94,9 +94,10 @@ var ( "n":"rFH5kUBZrlPj73epjJjyCxzVzZuV--JjKgapoqm9pOuOt20BUTdHqVfC2oDclqM7HFhkkX9OSJMTHgZ7WaVqZv9u1X2yjdx9oVmMLuspX7EytW_ZKDZSzL-sCOFCuQAuYKkLbsdcA3eHBK_lwc4zwdeHFMKIulNvLqckkqYB9s8GpgNXBDIQ8GjR5HuJke_WUNjYHSd8jY1LU9swKWsLQe2YoQUz_ekQvBvBCoaFEtrtRaSJKNLIVDObXFr2TLIiFiM0Em90kK01-eQ7ZiruZTKomll64bRFPoNo4_uwubddg3xTqur2vdF3NyhTrYdvAgTem4uC0PFjEQ1bK_djBQ", "e":"AQAB" }`) - log = blog.UseMock() - tmpl = template.Must(template.New("expiry-email").Parse(testTmpl)) - ctx = context.Background() + log = blog.UseMock() + tmpl = template.Must(template.New("expiry-email").Parse(testTmpl)) + subjTmpl = template.Must(template.New("expiry-email-subject").Parse("Testing: " + defaultExpirationSubject)) + ctx = context.Background() ) func TestSendNags(t *testing.T) { @@ -105,15 +106,17 @@ func TestSendNags(t *testing.T) { rs := newFakeRegStore() fc := newFakeClock(t) + staticTmpl := template.Must(template.New("expiry-email-subject-static").Parse(testEmailSubject)) + m := mailer{ stats: stats, log: log, mailer: &mc, emailTemplate: tmpl, // Explicitly override the default subject to use testEmailSubject - subject: testEmailSubject, - rs: rs, - clk: fc, + subjectTemplate: staticTmpl, + rs: rs, + clk: fc, } cert := &x509.Certificate{ @@ -222,14 +225,14 @@ func TestFindExpiringCertificates(t *testing.T) { To: emailARaw, // A certificate with only one domain should have only one domain listed in // the subject - Subject: "Certificate expiration notice for domain \"example-a.com\"", + Subject: "Testing: Let's Encrypt certificate expiration notice for domain \"example-a.com\"", Body: "hi, cert for DNS names example-a.com is going to expire in 0 days (03 Jan 06 14:04 +0000)", }, testCtx.mc.Messages[0]) test.AssertEquals(t, mocks.MailerMessage{ To: emailBRaw, // A certificate with two domains should have only one domain listed and an // additional count included - Subject: "Certificate expiration notice for domain \"another.example-c.com\" (and 1 more)", + Subject: "Testing: Let's Encrypt certificate expiration notice for domain \"another.example-c.com\" (and 1 more)", Body: "hi, cert for DNS names another.example-c.com\nexample-c.com is going to expire in 7 days (09 Jan 06 16:04 +0000)", }, testCtx.mc.Messages[1]) @@ -838,7 +841,7 @@ func TestDedupOnRegistration(t *testing.T) { To: emailARaw, // A certificate with three domain names should have one in the subject and // a count of '2 more' at the end - Subject: "Certificate expiration notice for domain \"example-a.com\" (and 2 more)", + Subject: "Testing: Let's Encrypt certificate expiration notice for domain \"example-a.com\" (and 2 more)", Body: fmt.Sprintf(`hi, cert for DNS names %s is going to expire in 1 days (%s)`, domains, rawCertB.NotAfter.Format(time.RFC822Z)), @@ -878,15 +881,16 @@ func setup(t *testing.T, nagTimes []time.Duration) *testCtx { } m := &mailer{ - log: log, - stats: stats, - mailer: mc, - emailTemplate: tmpl, - dbMap: dbMap, - rs: ssa, - nagTimes: offsetNags, - limit: 100, - clk: fc, + log: log, + stats: stats, + mailer: mc, + emailTemplate: tmpl, + subjectTemplate: subjTmpl, + dbMap: dbMap, + rs: ssa, + nagTimes: offsetNags, + limit: 100, + clk: fc, } return &testCtx{ dbMap: dbMap, diff --git a/cmd/expiration-mailer/send_test.go b/cmd/expiration-mailer/send_test.go index 2a59d3b75..dfd1ff9ed 100644 --- a/cmd/expiration-mailer/send_test.go +++ b/cmd/expiration-mailer/send_test.go @@ -45,7 +45,7 @@ func TestSendEarliestCertInfo(t *testing.T) { } domains := "example-a.com\nexample-b.com\nshared-example.com" expected := mocks.MailerMessage{ - Subject: "Certificate expiration notice for domain \"example-a.com\" (and 2 more)", + Subject: "Testing: Let's Encrypt certificate expiration notice for domain \"example-a.com\" (and 2 more)", Body: fmt.Sprintf(`hi, cert for DNS names %s is going to expire in 2 days (%s)`, domains, rawCertB.NotAfter.Format(time.RFC822Z)), From c71c3cff80b38deda78c7db33dfc56e3855d511f Mon Sep 17 00:00:00 2001 From: David Calavera Date: Wed, 22 Mar 2017 10:17:59 -0700 Subject: [PATCH 2/5] Implement TLS-SNI-02 challenge validations. (#2585) I think these are all the necessary changes to implement TLS-SNI-02 validations, according to the section 7.3 of draft 05: https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-7.3 I don't have much experience with this code, I'll really appreciate your feedback. Signed-off-by: David Calavera --- core/challenges.go | 9 +- core/core_test.go | 13 +++ core/objects.go | 6 ++ core/objects_test.go | 2 +- features/featureflag_string.go | 4 +- features/features.go | 2 + policy/pa.go | 4 + test/config-next/ca.json | 3 +- test/config-next/cert-checker.json | 6 +- test/config-next/ra.json | 3 +- va/va.go | 155 +++++++++++++++++++++++------ va/va_test.go | 120 +++++++++++++++++++--- 12 files changed, 271 insertions(+), 56 deletions(-) diff --git a/core/challenges.go b/core/challenges.go index 722439520..2bdab6242 100644 --- a/core/challenges.go +++ b/core/challenges.go @@ -13,12 +13,17 @@ func HTTPChallenge01() Challenge { return newChallenge(ChallengeTypeHTTP01) } -// TLSSNIChallenge01 constructs a random tls-sni-00 challenge +// TLSSNIChallenge01 constructs a random tls-sni-01 challenge func TLSSNIChallenge01() Challenge { return newChallenge(ChallengeTypeTLSSNI01) } -// DNSChallenge01 constructs a random DNS challenge +// TLSSNIChallenge02 constructs a random tls-sni-02 challenge +func TLSSNIChallenge02() Challenge { + return newChallenge(ChallengeTypeTLSSNI02) +} + +// DNSChallenge01 constructs a random dns-01 challenge func DNSChallenge01() Challenge { return newChallenge(ChallengeTypeDNS01) } diff --git a/core/core_test.go b/core/core_test.go index 821b84d3b..3e6f44c52 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/test" "gopkg.in/square/go-jose.v1" ) @@ -34,6 +35,11 @@ func TestChallenges(t *testing.T) { t.Errorf("New tls-sni-01 challenge is not sane: %v", tlssni01) } + tlssni02 := TLSSNIChallenge02() + if !tlssni02.IsSane(false) { + t.Errorf("New tls-sni-02 challenge is not sane: %v", tlssni02) + } + dns01 := DNSChallenge01() if !dns01.IsSane(false) { t.Errorf("New dns-01 challenge is not sane: %v", dns01) @@ -43,6 +49,13 @@ func TestChallenges(t *testing.T) { test.Assert(t, ValidChallenge(ChallengeTypeTLSSNI01), "Refused valid challenge") test.Assert(t, ValidChallenge(ChallengeTypeDNS01), "Refused valid challenge") test.Assert(t, !ValidChallenge("nonsense-71"), "Accepted invalid challenge") + + test.Assert(t, !ValidChallenge(ChallengeTypeTLSSNI02), "Accepted invalid challenge") + + _ = features.Set(map[string]bool{"AllowTLS02Challenges": true}) + defer features.Reset() + + test.Assert(t, ValidChallenge(ChallengeTypeTLSSNI02), "Refused valid challenge") } // objects.go diff --git a/core/objects.go b/core/objects.go index 77c44c717..6ee8734d5 100644 --- a/core/objects.go +++ b/core/objects.go @@ -12,6 +12,7 @@ import ( "gopkg.in/square/go-jose.v1" + "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/revocation" ) @@ -69,6 +70,7 @@ const ( const ( ChallengeTypeHTTP01 = "http-01" ChallengeTypeTLSSNI01 = "tls-sni-01" + ChallengeTypeTLSSNI02 = "tls-sni-02" ChallengeTypeDNS01 = "dns-01" ) @@ -81,6 +83,8 @@ func ValidChallenge(name string) bool { fallthrough case ChallengeTypeDNS01: return true + case ChallengeTypeTLSSNI02: + return features.Enabled(features.AllowTLS02Challenges) default: return false @@ -261,6 +265,8 @@ func (ch Challenge) RecordsSane() bool { } } case ChallengeTypeTLSSNI01: + fallthrough + case ChallengeTypeTLSSNI02: if len(ch.ValidationRecord) > 1 { return false } diff --git a/core/objects_test.go b/core/objects_test.go index c4efc5dc7..46afb0afc 100644 --- a/core/objects_test.go +++ b/core/objects_test.go @@ -57,7 +57,7 @@ func TestChallengeSanityCheck(t *testing.T) { }`), &accountKey) test.AssertNotError(t, err, "Error unmarshaling JWK") - types := []string{ChallengeTypeHTTP01, ChallengeTypeTLSSNI01, ChallengeTypeDNS01} + types := []string{ChallengeTypeHTTP01, ChallengeTypeTLSSNI01, ChallengeTypeTLSSNI02, ChallengeTypeDNS01} for _, challengeType := range types { chall := Challenge{ Type: challengeType, diff --git a/features/featureflag_string.go b/features/featureflag_string.go index 3e57e958c..8d243bd97 100644 --- a/features/featureflag_string.go +++ b/features/featureflag_string.go @@ -4,9 +4,9 @@ package features import "fmt" -const _FeatureFlag_name = "unusedIDNASupportAllowAccountDeactivationAllowKeyRolloverResubmitMissingSCTsOnlyGoogleSafeBrowsingV4UseAIAIssuerURL" +const _FeatureFlag_name = "unusedIDNASupportAllowAccountDeactivationAllowKeyRolloverResubmitMissingSCTsOnlyGoogleSafeBrowsingV4UseAIAIssuerURLAllowTLS02Challenges" -var _FeatureFlag_index = [...]uint8{0, 6, 17, 41, 57, 80, 100, 115} +var _FeatureFlag_index = [...]uint8{0, 6, 17, 41, 57, 80, 100, 115, 135} func (i FeatureFlag) String() string { if i < 0 || i >= FeatureFlag(len(_FeatureFlag_index)-1) { diff --git a/features/features.go b/features/features.go index b75e8cdee..bcb0c9f0d 100644 --- a/features/features.go +++ b/features/features.go @@ -18,6 +18,7 @@ const ( ResubmitMissingSCTsOnly GoogleSafeBrowsingV4 UseAIAIssuerURL + AllowTLS02Challenges ) // List of features and their default value, protected by fMu @@ -29,6 +30,7 @@ var features = map[FeatureFlag]bool{ ResubmitMissingSCTsOnly: false, GoogleSafeBrowsingV4: false, UseAIAIssuerURL: false, + AllowTLS02Challenges: false, } var fMu = new(sync.RWMutex) diff --git a/policy/pa.go b/policy/pa.go index 48817873d..62c49ae49 100644 --- a/policy/pa.go +++ b/policy/pa.go @@ -286,6 +286,10 @@ func (pa *AuthorityImpl) ChallengesFor(identifier core.AcmeIdentifier) ([]core.C challenges = append(challenges, core.TLSSNIChallenge01()) } + if features.Enabled(features.AllowTLS02Challenges) && pa.enabledChallenges[core.ChallengeTypeTLSSNI02] { + challenges = append(challenges, core.TLSSNIChallenge02()) + } + if pa.enabledChallenges[core.ChallengeTypeDNS01] { challenges = append(challenges, core.DNSChallenge01()) } diff --git a/test/config-next/ca.json b/test/config-next/ca.json index 4528a97c4..cf402993e 100644 --- a/test/config-next/ca.json +++ b/test/config-next/ca.json @@ -134,7 +134,8 @@ "serviceQueue": "CA.server" }, "features": { - "IDNASupport": true + "IDNASupport": true, + "AllowTLS02Challenges": true } }, diff --git a/test/config-next/cert-checker.json b/test/config-next/cert-checker.json index 55e1bbf27..efe84a669 100644 --- a/test/config-next/cert-checker.json +++ b/test/config-next/cert-checker.json @@ -3,7 +3,8 @@ "dbConnectFile": "test/secrets/cert_checker_dburl", "maxDBConns": 10, "features": { - "IDNASupport": true + "IDNASupport": true, + "AllowTLS02Challenges": true }, "hostnamePolicyFile": "test/hostname-policy.json" }, @@ -12,7 +13,8 @@ "challenges": { "http-01": true, "tls-sni-01": true, - "dns-01": true + "dns-01": true, + "tls-sni-02": true } }, diff --git a/test/config-next/ra.json b/test/config-next/ra.json index 5269bf0ae..c4eab5046 100644 --- a/test/config-next/ra.json +++ b/test/config-next/ra.json @@ -46,7 +46,8 @@ }, "features": { "IDNASupport": true, - "AllowKeyRollover": true + "AllowKeyRollover": true, + "AllowTLS02Challenges": true } }, diff --git a/va/va.go b/va/va.go index 49e1475fd..f82ec3c10 100644 --- a/va/va.go +++ b/va/va.go @@ -303,7 +303,7 @@ func certNames(cert *x509.Certificate) []string { return names } -func (va *ValidationAuthorityImpl) validateTLSWithZName(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]core.ValidationRecord, *probs.ProblemDetails) { +func (va *ValidationAuthorityImpl) validateTLSSNI01WithZName(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]core.ValidationRecord, *probs.ProblemDetails) { addr, allAddrs, problem := va.getAddr(ctx, identifier.Value) validationRecords := []core.ValidationRecord{ { @@ -320,32 +320,12 @@ func (va *ValidationAuthorityImpl) validateTLSWithZName(ctx context.Context, ide portString := strconv.Itoa(va.tlsPort) hostPort := net.JoinHostPort(addr.String(), portString) validationRecords[0].Port = portString - va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName)) - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: validationTimeout}, "tcp", hostPort, &tls.Config{ - ServerName: zName, - InsecureSkipVerify: true, - }) - if err != nil { - va.log.Info(fmt.Sprintf("TLS-01 connection failure for %s. err=[%#v] errStr=[%s]", identifier, err, err)) - return validationRecords, - parseHTTPConnError(fmt.Sprintf("Failed to connect to %s for TLS-SNI-01 challenge", hostPort), err) + certs, problem := va.getTLSSNICerts(hostPort, identifier, challenge, zName) + if problem != nil { + return validationRecords, problem } - // close errors are not important here - defer func() { - _ = conn.Close() - }() - // Check that zName is a dNSName SAN in the server's certificate - certs := conn.ConnectionState().PeerCertificates - if len(certs) == 0 { - va.log.Info(fmt.Sprintf("TLS-SNI-01 challenge for %s resulted in no certificates", identifier.Value)) - return validationRecords, probs.Unauthorized("No certs presented for TLS SNI challenge") - } - for i, cert := range certs { - va.log.AuditInfo(fmt.Sprintf("TLS-SNI-01 challenge for %s received certificate (%d of %d): cert=[%s]", - identifier.Value, i+1, len(certs), hex.EncodeToString(cert.Raw))) - } leafCert := certs[0] for _, name := range leafCert.DNSNames { if subtle.ConstantTimeCompare([]byte(name), []byte(zName)) == 1 { @@ -355,14 +335,100 @@ func (va *ValidationAuthorityImpl) validateTLSWithZName(ctx context.Context, ide names := certNames(leafCert) errText := fmt.Sprintf( - "Incorrect validation certificate for TLS-SNI-01 challenge. "+ + "Incorrect validation certificate for %s challenge. "+ "Requested %s from %s. Received %d certificate(s), "+ "first certificate had names %q", - zName, hostPort, len(certs), strings.Join(names, ", ")) - va.log.Info(fmt.Sprintf("Remote host failed to give TLS-01 challenge name. host: %s", identifier)) + challenge.Type, zName, hostPort, len(certs), strings.Join(names, ", ")) + va.log.Info(fmt.Sprintf("Remote host failed to give %s challenge name. host: %s", challenge.Type, identifier)) return validationRecords, probs.Unauthorized(errText) } +func (va *ValidationAuthorityImpl) validateTLSSNI02WithZNames(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge, sanAName, sanBName string) ([]core.ValidationRecord, *probs.ProblemDetails) { + addr, allAddrs, problem := va.getAddr(ctx, identifier.Value) + validationRecords := []core.ValidationRecord{ + { + Hostname: identifier.Value, + AddressesResolved: allAddrs, + AddressUsed: addr, + }, + } + if problem != nil { + return validationRecords, problem + } + + // Make a connection with SNI = nonceName + portString := strconv.Itoa(va.tlsPort) + hostPort := net.JoinHostPort(addr.String(), portString) + validationRecords[0].Port = portString + + certs, problem := va.getTLSSNICerts(hostPort, identifier, challenge, sanAName) + if problem != nil { + return validationRecords, problem + } + + leafCert := certs[0] + if len(leafCert.DNSNames) != 2 { + names := strings.Join(certNames(leafCert), ", ") + msg := fmt.Sprintf("%s challenge certificate doesn't include exactly 2 DNSName entries. Received %d certificate(s), first certificate had names %q", challenge.Type, len(certs), names) + return validationRecords, probs.Malformed(msg) + } + + var validSanAName, validSanBName bool + for _, name := range leafCert.DNSNames { + // Note: ConstantTimeCompare is not strictly necessary here, but can't hurt. + if subtle.ConstantTimeCompare([]byte(name), []byte(sanAName)) == 1 { + validSanAName = true + } + + if subtle.ConstantTimeCompare([]byte(name), []byte(sanBName)) == 1 { + validSanBName = true + } + } + + if validSanAName && validSanBName { + return validationRecords, nil + } + + names := certNames(leafCert) + errText := fmt.Sprintf( + "Incorrect validation certificate for %s challenge. "+ + "Requested %s from %s. Received %d certificate(s), "+ + "first certificate had names %q", + challenge.Type, sanAName, hostPort, len(certs), strings.Join(names, ", ")) + va.log.Info(fmt.Sprintf("Remote host failed to give %s challenge name. host: %s", challenge.Type, identifier)) + return validationRecords, probs.Unauthorized(errText) +} + +func (va *ValidationAuthorityImpl) getTLSSNICerts(hostPort string, identifier core.AcmeIdentifier, challenge core.Challenge, zName string) ([]*x509.Certificate, *probs.ProblemDetails) { + va.log.Info(fmt.Sprintf("%s [%s] Attempting to validate for %s %s", challenge.Type, identifier, hostPort, zName)) + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: validationTimeout}, "tcp", hostPort, &tls.Config{ + ServerName: zName, + InsecureSkipVerify: true, + }) + + if err != nil { + va.log.Info(fmt.Sprintf("%s connection failure for %s. err=[%#v] errStr=[%s]", challenge.Type, identifier, err, err)) + return nil, + parseHTTPConnError(fmt.Sprintf("Failed to connect to %s for %s challenge", hostPort, challenge.Type), err) + } + // close errors are not important here + defer func() { + _ = conn.Close() + }() + + // Check that zName is a dNSName SAN in the server's certificate + certs := conn.ConnectionState().PeerCertificates + if len(certs) == 0 { + va.log.Info(fmt.Sprintf("%s challenge for %s resulted in no certificates", challenge.Type, identifier.Value)) + return nil, probs.Unauthorized(fmt.Sprintf("No certs presented for %s challenge", challenge.Type)) + } + for i, cert := range certs { + va.log.AuditInfo(fmt.Sprintf("%s challenge for %s received certificate (%d of %d): cert=[%s]", + challenge.Type, identifier.Value, i+1, len(certs), hex.EncodeToString(cert.Raw))) + } + return certs, nil +} + func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) { if identifier.Type != core.IdentifierDNS { va.log.Info(fmt.Sprintf("Got non-DNS identifier for HTTP validation: %s", identifier)) @@ -390,17 +456,38 @@ func (va *ValidationAuthorityImpl) validateHTTP01(ctx context.Context, identifie func (va *ValidationAuthorityImpl) validateTLSSNI01(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) { if identifier.Type != "dns" { - va.log.Info(fmt.Sprintf("Identifier type for TLS-SNI was not DNS: %s", identifier)) - return nil, probs.Malformed("Identifier type for TLS-SNI was not DNS") + va.log.Info(fmt.Sprintf("Identifier type for TLS-SNI-01 was not DNS: %s", identifier)) + return nil, probs.Malformed("Identifier type for TLS-SNI-01 was not DNS") } // Compute the digest that will appear in the certificate - h := sha256.New() - h.Write([]byte(challenge.ProvidedKeyAuthorization)) - Z := hex.EncodeToString(h.Sum(nil)) + h := sha256.Sum256([]byte(challenge.ProvidedKeyAuthorization)) + Z := hex.EncodeToString(h[:]) ZName := fmt.Sprintf("%s.%s.%s", Z[:32], Z[32:], core.TLSSNISuffix) - return va.validateTLSWithZName(ctx, identifier, challenge, ZName) + return va.validateTLSSNI01WithZName(ctx, identifier, challenge, ZName) +} + +func (va *ValidationAuthorityImpl) validateTLSSNI02(ctx context.Context, identifier core.AcmeIdentifier, challenge core.Challenge) ([]core.ValidationRecord, *probs.ProblemDetails) { + if identifier.Type != "dns" { + va.log.Info(fmt.Sprintf("Identifier type for TLS-SNI-02 was not DNS: %s", identifier)) + return nil, probs.Malformed("Identifier type for TLS-SNI-02 was not DNS") + } + + const tlsSNITokenID = "token" + const tlsSNIKaID = "ka" + + // Compute the digest for the SAN b that will appear in the certificate + ha := sha256.Sum256([]byte(challenge.Token)) + za := hex.EncodeToString(ha[:]) + sanAName := fmt.Sprintf("%s.%s.%s.%s", za[:32], za[32:], tlsSNITokenID, core.TLSSNISuffix) + + // Compute the digest for the SAN B that will appear in the certificate + hb := sha256.Sum256([]byte(challenge.ProvidedKeyAuthorization)) + zb := hex.EncodeToString(hb[:]) + sanBName := fmt.Sprintf("%s.%s.%s.%s", zb[:32], zb[32:], tlsSNIKaID, core.TLSSNISuffix) + + return va.validateTLSSNI02WithZNames(ctx, identifier, challenge, sanAName, sanBName) } // badTLSHeader contains the string 'HTTP /' which is returned when @@ -549,6 +636,8 @@ func (va *ValidationAuthorityImpl) validateChallenge(ctx context.Context, identi return va.validateHTTP01(ctx, identifier, challenge) case core.ChallengeTypeTLSSNI01: return va.validateTLSSNI01(ctx, identifier, challenge) + case core.ChallengeTypeTLSSNI02: + return va.validateTLSSNI02(ctx, identifier, challenge) case core.ChallengeTypeDNS01: return va.validateDNS01(ctx, identifier, challenge) } diff --git a/va/va_test.go b/va/va_test.go index 5574836ef..7e59866d1 100644 --- a/va/va_test.go +++ b/va/va_test.go @@ -159,12 +159,27 @@ func httpSrv(t *testing.T, token string) *httptest.Server { return server } -func tlssniSrv(t *testing.T, chall core.Challenge) *httptest.Server { - h := sha256.New() - h.Write([]byte(chall.ProvidedKeyAuthorization)) - Z := hex.EncodeToString(h.Sum(nil)) +func tlssni01Srv(t *testing.T, chall core.Challenge) *httptest.Server { + h := sha256.Sum256([]byte(chall.ProvidedKeyAuthorization)) + Z := hex.EncodeToString(h[:]) ZName := fmt.Sprintf("%s.%s.acme.invalid", Z[:32], Z[32:]) + return tlssniSrvWithNames(t, chall, ZName) +} + +func tlssni02Srv(t *testing.T, chall core.Challenge) *httptest.Server { + ha := sha256.Sum256([]byte(chall.Token)) + za := hex.EncodeToString(ha[:]) + sanAName := fmt.Sprintf("%s.%s.token.acme.invalid", za[:32], za[32:]) + + hb := sha256.Sum256([]byte(chall.ProvidedKeyAuthorization)) + zb := hex.EncodeToString(hb[:]) + sanBName := fmt.Sprintf("%s.%s.ka.acme.invalid", zb[:32], zb[32:]) + + return tlssniSrvWithNames(t, chall, sanAName, sanBName) +} + +func tlssniSrvWithNames(t *testing.T, chall core.Challenge, names ...string) *httptest.Server { template := &x509.Certificate{ SerialNumber: big.NewInt(1337), Subject: pkix.Name{ @@ -177,7 +192,7 @@ func tlssniSrv(t *testing.T, chall core.Challenge) *httptest.Server { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, - DNSNames: []string{ZName}, + DNSNames: names, } certBytes, _ := x509.CreateCertificate(rand.Reader, template, template, &TheKey.PublicKey, &TheKey) @@ -190,7 +205,7 @@ func tlssniSrv(t *testing.T, chall core.Challenge) *httptest.Server { Certificates: []tls.Certificate{*cert}, ClientAuth: tls.NoClientCert, GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if clientHello.ServerName != ZName { + if clientHello.ServerName != names[0] { time.Sleep(time.Second * 10) return nil, nil } @@ -431,10 +446,10 @@ func getPort(hs *httptest.Server) (int, error) { return int(port), nil } -func TestTLSSNI(t *testing.T) { +func TestTLSSNI01(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) - hs := tlssniSrv(t, chall) + hs := tlssni01Srv(t, chall) port, err := getPort(hs) test.AssertNotError(t, err, "failed to get test server port") @@ -443,7 +458,7 @@ func TestTLSSNI(t *testing.T) { _, prob := va.validateTLSSNI01(ctx, ident, chall) if prob != nil { - t.Fatalf("Unexpected failure in validateTLSSNI01: %s", prob) + t.Fatalf("Unexpected failure in validate TLS-SNI-01: %s", prob) } test.AssertEquals(t, len(log.GetAllMatching(`Resolved addresses for localhost \[using 127.0.0.1\]: \[127.0.0.1\]`)), 1) if len(log.GetAllMatching(`challenge for localhost received certificate \(1 of 1\): cert=\[`)) != 1 { @@ -501,11 +516,88 @@ func TestTLSSNI(t *testing.T) { log.Clear() _, err = va.validateTLSSNI01(ctx, ident, chall) - test.AssertError(t, err, "TLS SNI validation passed when talking to a HTTP-only server") + test.AssertError(t, err, "TLS-SNI-01 validation passed when talking to a HTTP-only server") test.Assert(t, strings.HasSuffix( err.Error(), "Server only speaks HTTP, not TLS", - ), "validateTLSSNI01 didn't return useful error") + ), "validate TLS-SNI-01 didn't return useful error") +} + +func TestTLSSNI02(t *testing.T) { + chall := createChallenge(core.ChallengeTypeTLSSNI02) + + hs := tlssni02Srv(t, chall) + port, err := getPort(hs) + test.AssertNotError(t, err, "failed to get test server port") + + va, _, log := setup() + va.tlsPort = port + + _, prob := va.validateTLSSNI02(ctx, ident, chall) + if prob != nil { + t.Fatalf("Unexpected failure in validate TLS-SNI-02: %s", prob) + } + test.AssertEquals(t, len(log.GetAllMatching(`Resolved addresses for localhost \[using 127.0.0.1\]: \[127.0.0.1\]`)), 1) + if len(log.GetAllMatching(`challenge for localhost received certificate \(1 of 1\): cert=\[`)) != 1 { + t.Errorf("Didn't get log message with validated certificate. Instead got:\n%s", + strings.Join(log.GetAllMatching(".*"), "\n")) + } + + log.Clear() + _, prob = va.validateTLSSNI02(ctx, core.AcmeIdentifier{ + Type: core.IdentifierType("ip"), + Value: net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", port)), + }, chall) + if prob == nil { + t.Fatalf("IdentifierType IP shouldn't have worked.") + } + test.AssertEquals(t, prob.Type, probs.MalformedProblem) + + log.Clear() + _, prob = va.validateTLSSNI02(ctx, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: "always.invalid"}, chall) + if prob == nil { + t.Fatalf("Domain name was supposed to be invalid.") + } + test.AssertEquals(t, prob.Type, probs.UnknownHostProblem) + + // Need to create a new authorized keys object to get an unknown SNI (from the signature value) + chall.Token = core.NewToken() + chall.ProvidedKeyAuthorization = "invalid" + + log.Clear() + started := time.Now() + _, prob = va.validateTLSSNI02(ctx, ident, chall) + took := time.Since(started) + if prob == nil { + t.Fatalf("Validation should have failed") + } + test.AssertEquals(t, prob.Type, probs.ConnectionProblem) + // Check that the TLS connection times out after 5 seconds and doesn't block for 10 seconds + test.Assert(t, (took > (time.Second * 5)), "TLS returned before 5 seconds") + test.Assert(t, (took < (time.Second * 10)), "TLS connection didn't timeout after 5 seconds") + test.AssertEquals(t, len(log.GetAllMatching(`Resolved addresses for localhost \[using 127.0.0.1\]: \[127.0.0.1\]`)), 1) + + // Take down validation server and check that validation fails. + hs.Close() + _, err = va.validateTLSSNI02(ctx, ident, chall) + if err == nil { + t.Fatalf("Server's down; expected refusal. Where did we connect?") + } + test.AssertEquals(t, prob.Type, probs.ConnectionProblem) + + httpOnly := httpSrv(t, "") + defer httpOnly.Close() + port, err = getPort(httpOnly) + test.AssertNotError(t, err, "failed to get test server port") + va.tlsPort = port + + log.Clear() + _, err = va.validateTLSSNI02(ctx, ident, chall) + test.AssertError(t, err, "TLS-SNI-02 validation passed when talking to a HTTP-only server") + test.Assert(t, strings.HasSuffix( + err.Error(), + "Server only speaks HTTP, not TLS", + ), "validate TLS-SNI-02 didn't return useful error") } func brokenTLSSrv() *httptest.Server { @@ -664,7 +756,7 @@ func setChallengeToken(ch *core.Challenge, token string) { func TestValidateTLSSNI01(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) - hs := tlssniSrv(t, chall) + hs := tlssni01Srv(t, chall) defer hs.Close() port, err := getPort(hs) @@ -678,7 +770,7 @@ func TestValidateTLSSNI01(t *testing.T) { test.Assert(t, prob == nil, "validation failed") } -func TestValidateTLSSNINotSane(t *testing.T) { +func TestValidateTLSSNI01NotSane(t *testing.T) { va, _, _ := setup() chall := createChallenge(core.ChallengeTypeTLSSNI01) @@ -925,7 +1017,7 @@ func TestDNSValidationNoAuthorityOK(t *testing.T) { func TestCAAFailure(t *testing.T) { chall := createChallenge(core.ChallengeTypeTLSSNI01) - hs := tlssniSrv(t, chall) + hs := tlssni01Srv(t, chall) defer hs.Close() port, err := getPort(hs) From 194a55d7c72836fce426fa72958ffcdd711a9777 Mon Sep 17 00:00:00 2001 From: Roland Bracewell Shoemaker Date: Wed, 22 Mar 2017 12:43:43 -0700 Subject: [PATCH 3/5] Remove RabbitMQ + AMQP references from README (#2616) Fixes #2407. --- README.md | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 1588de7d1..69c40cfaa 100644 --- a/README.md +++ b/README.md @@ -83,13 +83,9 @@ We recommend setting git's [fsckObjects setting](https://groups.google.com/forum/#!topic/binary-transparency/f-BI4o8HZW0/discussion) for better integrity guarantees when getting updates. -Boulder requires an installation of RabbitMQ, libtool-ltdl, goose, and -MariaDB 10.1 to work correctly. On Ubuntu and CentOS, you may have to -install RabbitMQ from https://rabbitmq.com/download.html to get a -recent version. If you want to save some trouble installing MariaDB and RabbitMQ -you can run them using Docker: +Boulder requires an installation of libtool-ltdl, goose, SoftHSM, and MariaDB 10.1 to work correctly. If you want to save some trouble installing MariaDB and SoftHSM you can run them using Docker: - docker-compose up -d bmysql brabbitmq bhsm + docker-compose up -d bmysql bhsm Also, Boulder requires Go 1.5. As of September 2015 this version is not yet available in OS repositories, so you will have to install from https://golang.org/dl/. @@ -121,7 +117,7 @@ Edit /etc/hosts to add this line: 127.0.0.1 boulder boulder-rabbitmq boulder-mysql -Resolve Go-dependencies, set up a database and RabbitMQ: +Resolve Go-dependencies, set up a database: ./test/setup.sh @@ -198,7 +194,7 @@ Requests from ACME clients result in new objects and changes to objects. The St Objects are also passed from one component to another on change events. For example, when a client provides a successful response to a validation challenge, it results in a change to the corresponding validation object. The Validation Authority forwards the new validation object to the Storage Authority for storage, and to the Registration Authority for any updates to a related Authorization object. -Boulder uses AMQP as a message bus. For components that you want to be remote, it is necessary to instantiate a "client" and "server" for that component. The client implements the component's Go interface, while the server has the actual logic for the component. More details in `amqp-rpc.go`. +Boulder uses gRPC for inter-component communication. For components that you want to be remote, it is necessary to instantiate a "client" and "server" for that component. The client implements the component's Go interface, while the server has the actual logic for the component. More details on this communication model can be found in the [gRPC documentation](http://www.grpc.io/docs/). The full details of how the various ACME operations happen in Boulder are laid out in [DESIGN.md](https://github.com/letsencrypt/boulder/blob/master/DESIGN.md) @@ -208,9 +204,7 @@ Dependencies All Go dependencies are vendored under the vendor directory, to [make dependency management easier](https://golang.org/cmd/go/#hdr-Vendor_Directories). -Local development also requires a RabbitMQ installation and MariaDB -10 installation (see above). MariaDB should be run on port 3306 for the -default integration tests. +Local development also requires a MariaDB 10 installation. MariaDB should be run on port 3306 for the default integration tests. To update the Go dependencies: @@ -253,7 +247,6 @@ you will get conflicting types between our vendored version and the cfssl vendor Adding RPCs ----------- -Boulder is moving towards using gRPC for all RPCs. To add a new RPC method, add -it to the relevant .proto file, then run: +Boulder uses gRPC for all RPCs. To add a new RPC method, add it to the relevant .proto file, then run: docker-compose run boulder go generate ./path/to/pkg/... From e2b2511898f5d76ea68e26402521d628751d97f4 Mon Sep 17 00:00:00 2001 From: Roland Bracewell Shoemaker Date: Wed, 22 Mar 2017 23:27:31 -0700 Subject: [PATCH 4/5] Overhaul internal error usage (#2583) This patch removes all usages of the `core.XXXError` and almost all usages of `probs` outside of the WFE and VA and replaces them with a unified internal error type. Since the VA uses `probs.ProblemDetails` quite extensively in challenges, and currently stores them in the DB I've saved this change for another change (it'll also require a migration). Since `ProblemDetails` should only ever be exposed to end-users all of its related logic should be moved into the `WFE` but since it still needs to be exposed to the VA and SA I've left it in place for now. The new internal `errors` package offers the same convenience functions as `probs` does as well as a new simpler type testing method. A few small changes have also been made to error messages, mainly adding the library and function name to internal server errors for easier debugging (i.e. where a number of functions return the exact same errors and there is no other way to distinguish which method threw the error). Also adds proper encoding of internal errors transferred over gRPC (the current encoding scheme is kept for `core` and `probs` errors since it'll be ideally be removed after we deploy this and follow-up changes) using `grpc/metadata` instead of the gRPC status codes. Fixes #2507. Updates #2254 and #2505. --- bdns/mocks.go | 2 +- bdns/problem.go | 13 -- bdns/problem_test.go | 15 +-- ca/ca.go | 25 ++-- ca/ca_test.go | 20 ++- cmd/admin-revoker/main.go | 3 +- cmd/expiration-mailer/main_test.go | 4 +- cmd/orphan-finder/main.go | 5 +- cmd/orphan-finder/main_test.go | 5 +- core/util.go | 46 +------ core/util_test.go | 37 ------ docs/error-handling.md | 11 ++ errors/errors.go | 96 +++++++++++++++ goodkey/good_key.go | 29 +++-- grpc/bcodes.go | 48 +++++++- grpc/bcodes_test.go | 8 +- grpc/errors_test.go | 16 ++- grpc/interceptors.go | 12 +- mocks/mocks.go | 7 +- policy/pa.go | 34 +++--- ra/ra.go | 190 ++++++++++++++++------------- ra/ra_test.go | 39 +++--- rpc/amqp-rpc.go | 18 ++- rpc/amqp-rpc_test.go | 5 + sa/sa.go | 36 +++--- sa/sa_test.go | 13 +- va/va.go | 6 +- va/va_test.go | 2 +- wfe/jose.go | 21 ++-- wfe/probs.go | 80 ++++++++++++ wfe/probs_test.go | 55 +++++++++ wfe/wfe.go | 27 ++-- wfe/wfe_test.go | 11 +- 33 files changed, 588 insertions(+), 351 deletions(-) create mode 100644 docs/error-handling.md create mode 100644 errors/errors.go create mode 100644 wfe/probs.go create mode 100644 wfe/probs_test.go diff --git a/bdns/mocks.go b/bdns/mocks.go index 8940dcdb5..32d9865e4 100644 --- a/bdns/mocks.go +++ b/bdns/mocks.go @@ -134,7 +134,7 @@ func (mock *MockDNSResolver) LookupCAA(_ context.Context, domain string) ([]*dns record.Value = ";" results = append(results, &record) case "bad-local-resolver.com": - return nil, DNSError{underlying: MockTimeoutError()} + return nil, &DNSError{dns.TypeCAA, domain, MockTimeoutError(), -1} } return results, nil } diff --git a/bdns/problem.go b/bdns/problem.go index 25f97bc73..02eb1ddbf 100644 --- a/bdns/problem.go +++ b/bdns/problem.go @@ -4,7 +4,6 @@ import ( "fmt" "net" - "github.com/letsencrypt/boulder/probs" "github.com/miekg/dns" "golang.org/x/net/context" ) @@ -56,15 +55,3 @@ func (d DNSError) Timeout() bool { const detailDNSTimeout = "query timed out" const detailDNSNetFailure = "networking error" const detailServerFailure = "server failure at resolver" - -// ProblemDetailsFromDNSError checks the error returned from Lookup... methods -// and tests if the error was an underlying net.OpError or an error caused by -// resolver returning SERVFAIL or other invalid Rcodes and returns the relevant -// core.ProblemDetails. The detail string will contain a mention of the DNS -// record type and domain given. -func ProblemDetailsFromDNSError(err error) *probs.ProblemDetails { - if dnsErr, ok := err.(*DNSError); ok { - return probs.ConnectionFailure(dnsErr.Error()) - } - return probs.ConnectionFailure(detailServerFailure) -} diff --git a/bdns/problem_test.go b/bdns/problem_test.go index 086ecd870..09e317d54 100644 --- a/bdns/problem_test.go +++ b/bdns/problem_test.go @@ -7,11 +7,9 @@ import ( "github.com/miekg/dns" "golang.org/x/net/context" - - "github.com/letsencrypt/boulder/probs" ) -func TestProblemDetailsFromDNSError(t *testing.T) { +func TestDNSError(t *testing.T) { testCases := []struct { err error expected string @@ -19,9 +17,6 @@ func TestProblemDetailsFromDNSError(t *testing.T) { { &DNSError{dns.TypeA, "hostname", MockTimeoutError(), -1}, "DNS problem: query timed out looking up A for hostname", - }, { - errors.New("other failure"), - detailServerFailure, }, { &DNSError{dns.TypeMX, "hostname", &net.OpError{Err: errors.New("some net error")}, -1}, "DNS problem: networking error looking up MX for hostname", @@ -37,12 +32,8 @@ func TestProblemDetailsFromDNSError(t *testing.T) { }, } for _, tc := range testCases { - err := ProblemDetailsFromDNSError(tc.err) - if err.Type != probs.ConnectionProblem { - t.Errorf("ProblemDetailsFromDNSError(%q).Type = %q, expected %q", tc.err, err.Type, probs.ConnectionProblem) - } - if err.Detail != tc.expected { - t.Errorf("ProblemDetailsFromDNSError(%q).Detail = %q, expected %q", tc.err, err.Detail, tc.expected) + if tc.err.Error() != tc.expected { + t.Errorf("got %q, expected %q", tc.err.Error(), tc.expected) } } } diff --git a/ca/ca.go b/ca/ca.go index 0da20001b..e570cf60a 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -29,6 +29,7 @@ import ( "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" csrlib "github.com/letsencrypt/boulder/csr" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/goodkey" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" @@ -295,12 +296,10 @@ func (ca *CertificateAuthorityImpl) extensionsFromCSR(csr *x509.CertificateReque ca.stats.Inc(metricCSRExtensionTLSFeature, 1) value, ok := ext.Value.([]byte) if !ok { - msg := fmt.Sprintf("Malformed extension with OID %v", ext.Type) - return nil, core.MalformedRequestError(msg) + return nil, berrors.MalformedError("malformed extension with OID %v", ext.Type) } else if !bytes.Equal(value, mustStapleFeatureValue) { - msg := fmt.Sprintf("Unsupported value for extension with OID %v", ext.Type) ca.stats.Inc(metricCSRExtensionTLSFeatureInvalid, 1) - return nil, core.MalformedRequestError(msg) + return nil, berrors.MalformedError("unsupported value for extension with OID %v", ext.Type) } if ca.enableMustStaple { @@ -386,7 +385,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 regID, ); err != nil { ca.log.AuditErr(err.Error()) - return emptyCert, core.MalformedRequestError(err.Error()) + return emptyCert, berrors.MalformedError(err.Error()) } requestedExtensions, err := ca.extensionsFromCSR(&csr) @@ -398,7 +397,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 notAfter := ca.clk.Now().Add(ca.validityPeriod) if issuer.cert.NotAfter.Before(notAfter) { - err = core.InternalServerError("Cannot issue a certificate that expires after the issuer certificate.") + err = berrors.InternalServerError("cannot issue a certificate that expires after the issuer certificate") ca.log.AuditErr(err.Error()) return emptyCert, err } @@ -415,7 +414,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 serialBytes[0] = byte(ca.prefix) _, err = rand.Read(serialBytes[1:]) if err != nil { - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError("failed to generate serial: %s", err) ca.log.AuditErr(fmt.Sprintf("Serial randomness failed, err=[%v]", err)) return emptyCert, err } @@ -430,7 +429,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 case *ecdsa.PublicKey: profile = ca.ecdsaProfile default: - err = core.InternalServerError(fmt.Sprintf("unsupported key type %T", csr.PublicKey)) + err = berrors.InternalServerError("unsupported key type %T", csr.PublicKey) ca.log.AuditErr(err.Error()) return emptyCert, err } @@ -456,21 +455,21 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 certPEM, err := issuer.eeSigner.Sign(req) ca.noteSignError(err) if err != nil { - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError("failed to sign certificate: %s", err) ca.log.AuditErr(fmt.Sprintf("Signing failed: serial=[%s] err=[%v]", serialHex, err)) return emptyCert, err } ca.stats.Inc("Signatures.Certificate", 1) if len(certPEM) == 0 { - err = core.InternalServerError("No certificate returned by server") + err = berrors.InternalServerError("no certificate returned by server") ca.log.AuditErr(fmt.Sprintf("PEM empty from Signer: serial=[%s] err=[%v]", serialHex, err)) return emptyCert, err } block, _ := pem.Decode(certPEM) if block == nil || block.Type != "CERTIFICATE" { - err = core.InternalServerError("Invalid certificate value returned") + err = berrors.InternalServerError("invalid certificate value returned") ca.log.AuditErr(fmt.Sprintf("PEM decode error, aborting: serial=[%s] pem=[%s] err=[%v]", serialHex, certPEM, err)) return emptyCert, err @@ -487,7 +486,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 // This is one last check for uncaught errors if err != nil { - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError(err.Error()) ca.log.AuditErr(fmt.Sprintf("Uncaught error, aborting: serial=[%s] cert=[%s] err=[%v]", serialHex, hex.EncodeToString(certDER), err)) return emptyCert, err @@ -496,7 +495,7 @@ func (ca *CertificateAuthorityImpl) IssueCertificate(ctx context.Context, csr x5 // Store the cert with the certificate authority, if provided _, err = ca.SA.AddCertificate(ctx, certDER, regID) if err != nil { - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError(err.Error()) // Note: This log line is parsed by cmd/orphan-finder. If you make any // changes here, you should make sure they are reflected in orphan-finder. ca.log.AuditErr(fmt.Sprintf( diff --git a/ca/ca_test.go b/ca/ca_test.go index 41c36be37..6fd402619 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -21,6 +21,7 @@ import ( "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/goodkey" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" @@ -471,8 +472,7 @@ func TestNoHostnames(t *testing.T) { csr, _ := x509.ParseCertificateRequest(NoNamesCSR) _, err = ca.IssueCertificate(ctx, *csr, 1001) test.AssertError(t, err, "Issued certificate with no names") - _, ok := err.(core.MalformedRequestError) - test.Assert(t, ok, "Incorrect error type returned") + test.Assert(t, berrors.Is(err, berrors.Malformed), "Incorrect error type returned") } func TestRejectTooManyNames(t *testing.T) { @@ -493,8 +493,7 @@ func TestRejectTooManyNames(t *testing.T) { csr, _ := x509.ParseCertificateRequest(TooManyNameCSR) _, err = ca.IssueCertificate(ctx, *csr, 1001) test.AssertError(t, err, "Issued certificate with too many names") - _, ok := err.(core.MalformedRequestError) - test.Assert(t, ok, "Incorrect error type returned") + test.Assert(t, berrors.Is(err, berrors.Malformed), "Incorrect error type returned") } func TestRejectValidityTooLong(t *testing.T) { @@ -520,8 +519,7 @@ func TestRejectValidityTooLong(t *testing.T) { csr, _ := x509.ParseCertificateRequest(NoCNCSR) _, err = ca.IssueCertificate(ctx, *csr, 1) test.AssertError(t, err, "Cannot issue a certificate that expires after the intermediate certificate") - _, ok := err.(core.InternalServerError) - test.Assert(t, ok, "Incorrect error type returned") + test.Assert(t, berrors.Is(err, berrors.InternalServer), "Incorrect error type returned") } func TestShortKey(t *testing.T) { @@ -541,8 +539,7 @@ func TestShortKey(t *testing.T) { csr, _ := x509.ParseCertificateRequest(ShortKeyCSR) _, err = ca.IssueCertificate(ctx, *csr, 1001) test.AssertError(t, err, "Issued a certificate with too short a key.") - _, ok := err.(core.MalformedRequestError) - test.Assert(t, ok, "Incorrect error type returned") + test.Assert(t, berrors.Is(err, berrors.Malformed), "Incorrect error type returned") } func TestAllowNoCN(t *testing.T) { @@ -603,8 +600,7 @@ func TestLongCommonName(t *testing.T) { csr, _ := x509.ParseCertificateRequest(LongCNCSR) _, err = ca.IssueCertificate(ctx, *csr, 1001) test.AssertError(t, err, "Issued a certificate with a CN over 64 bytes.") - _, ok := err.(core.MalformedRequestError) - test.Assert(t, ok, "Incorrect error type returned") + test.Assert(t, berrors.Is(err, berrors.Malformed), "Incorrect error type returned") } func TestWrongSignature(t *testing.T) { @@ -746,9 +742,7 @@ func TestExtensions(t *testing.T) { stats.EXPECT().Inc(metricCSRExtensionTLSFeatureInvalid, int64(1)).Return(nil) _, err = ca.IssueCertificate(ctx, *tlsFeatureUnknownCSR, 1001) test.AssertError(t, err, "Allowed a CSR with an empty TLS feature extension") - if _, ok := err.(core.MalformedRequestError); !ok { - t.Errorf("Wrong error type when rejecting a CSR with empty TLS feature extension") - } + test.Assert(t, berrors.Is(err, berrors.Malformed), "Wrong error type when rejecting a CSR with empty TLS feature extension") // Unsupported extensions should be silently ignored, having the same // extensions as the TLS Feature cert above, minus the TLS Feature Extension diff --git a/cmd/admin-revoker/main.go b/cmd/admin-revoker/main.go index 879403175..bbcac13a1 100644 --- a/cmd/admin-revoker/main.go +++ b/cmd/admin-revoker/main.go @@ -16,6 +16,7 @@ import ( "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" bgrpc "github.com/letsencrypt/boulder/grpc" blog "github.com/letsencrypt/boulder/log" @@ -117,7 +118,7 @@ func revokeBySerial(ctx context.Context, serial string, reasonCode revocation.Re certObj, err := sa.SelectCertificate(tx, "WHERE serial = ?", serial) if err == sql.ErrNoRows { - return core.NotFoundError(fmt.Sprintf("No certificate found for %s", serial)) + return berrors.NotFoundError("certificate with serial %q not found", serial) } if err != nil { return err diff --git a/cmd/expiration-mailer/main_test.go b/cmd/expiration-mailer/main_test.go index 3a265fe3e..22549f927 100644 --- a/cmd/expiration-mailer/main_test.go +++ b/cmd/expiration-mailer/main_test.go @@ -23,6 +23,7 @@ import ( "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" "github.com/letsencrypt/boulder/mocks" @@ -50,8 +51,7 @@ type fakeRegStore struct { func (f fakeRegStore) GetRegistration(ctx context.Context, id int64) (core.Registration, error) { r, ok := f.RegByID[id] if !ok { - msg := fmt.Sprintf("no such registration %d", id) - return r, core.NoSuchRegistrationError(msg) + return r, berrors.NotFoundError("no registration found for %q", id) } return r, nil } diff --git a/cmd/orphan-finder/main.go b/cmd/orphan-finder/main.go index f6adc5797..e98ec5b70 100644 --- a/cmd/orphan-finder/main.go +++ b/cmd/orphan-finder/main.go @@ -17,6 +17,7 @@ import ( "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" bgrpc "github.com/letsencrypt/boulder/grpc" blog "github.com/letsencrypt/boulder/log" @@ -68,7 +69,9 @@ func checkDER(sai certificateStorage, der []byte) error { if err == nil { return errAlreadyExists } - if _, ok := err.(core.NotFoundError); ok { + // TODO(#2600): Remove core.NotFoundError check once boulder/errors + // code is deployed + if _, ok := err.(core.NotFoundError); ok || berrors.Is(err, berrors.NotFound) { return nil } return fmt.Errorf("Existing certificate lookup failed: %s", err) diff --git a/cmd/orphan-finder/main_test.go b/cmd/orphan-finder/main_test.go index ed8889741..8c836f02e 100644 --- a/cmd/orphan-finder/main_test.go +++ b/cmd/orphan-finder/main_test.go @@ -7,8 +7,9 @@ import ( "golang.org/x/net/context" "github.com/jmhodges/clock" - "github.com/letsencrypt/boulder/core" + "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/test" ) @@ -28,7 +29,7 @@ func (m *mockSA) GetCertificate(ctx context.Context, s string) (core.Certificate if m.certificate.DER != nil { return m.certificate, nil } - return core.Certificate{}, core.NotFoundError("no cert stored") + return core.Certificate{}, berrors.NotFoundError("no cert stored") } func checkNoErrors(t *testing.T) { diff --git a/core/util.go b/core/util.go index d01d1179b..ac64ab3c9 100644 --- a/core/util.go +++ b/core/util.go @@ -16,16 +16,15 @@ import ( "io/ioutil" "math/big" mrand "math/rand" - "net/http" "regexp" "sort" "strings" "time" "unicode" - blog "github.com/letsencrypt/boulder/log" - "github.com/letsencrypt/boulder/probs" jose "gopkg.in/square/go-jose.v1" + + blog "github.com/letsencrypt/boulder/log" ) // Package Variables Variables @@ -100,47 +99,6 @@ func (e RateLimitedError) Error() string { return string(e) } func (e TooManyRPCRequestsError) Error() string { return string(e) } func (e BadNonceError) Error() string { return string(e) } -// statusTooManyRequests is the HTTP status code meant for rate limiting -// errors. It's not currently in the net/http library so we add it here. -const statusTooManyRequests = 429 - -// ProblemDetailsForError turns an error into a ProblemDetails with the special -// case of returning the same error back if its already a ProblemDetails. If the -// error is of an type unknown to ProblemDetailsForError, it will return a -// ServerInternal ProblemDetails. -func ProblemDetailsForError(err error, msg string) *probs.ProblemDetails { - switch e := err.(type) { - case *probs.ProblemDetails: - return e - case MalformedRequestError: - return probs.Malformed(fmt.Sprintf("%s :: %s", msg, err)) - case NotSupportedError: - return &probs.ProblemDetails{ - Type: probs.ServerInternalProblem, - Detail: fmt.Sprintf("%s :: %s", msg, err), - HTTPStatus: http.StatusNotImplemented, - } - case UnauthorizedError: - return probs.Unauthorized(fmt.Sprintf("%s :: %s", msg, err)) - case NotFoundError: - return probs.NotFound(fmt.Sprintf("%s :: %s", msg, err)) - case LengthRequiredError: - prob := probs.Malformed("missing Content-Length header") - prob.HTTPStatus = http.StatusLengthRequired - return prob - case SignatureValidationError: - return probs.Malformed(fmt.Sprintf("%s :: %s", msg, err)) - case RateLimitedError: - return probs.RateLimited(fmt.Sprintf("%s :: %s", msg, err)) - case BadNonceError: - return probs.BadNonce(fmt.Sprintf("%s :: %s", msg, err)) - default: - // Internal server error messages may include sensitive data, so we do - // not include it. - return probs.ServerInternal(msg) - } -} - // Random stuff // RandomString returns a randomly generated string of the requested length. diff --git a/core/util_test.go b/core/util_test.go index b3fdc6c32..4cd363aad 100644 --- a/core/util_test.go +++ b/core/util_test.go @@ -5,13 +5,11 @@ import ( "fmt" "math" "math/big" - "reflect" "sort" "testing" "gopkg.in/square/go-jose.v1" - "github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/test" ) @@ -110,38 +108,3 @@ func TestUniqueLowerNames(t *testing.T) { sort.Strings(u) test.AssertDeepEquals(t, []string{"a.com", "bar.com", "baz.com", "foobar.com"}, u) } - -func TestProblemDetailsFromError(t *testing.T) { - testCases := []struct { - err error - statusCode int - problem probs.ProblemType - }{ - {InternalServerError("foo"), 500, probs.ServerInternalProblem}, - {NotSupportedError("foo"), 501, probs.ServerInternalProblem}, - {MalformedRequestError("foo"), 400, probs.MalformedProblem}, - {UnauthorizedError("foo"), 403, probs.UnauthorizedProblem}, - {NotFoundError("foo"), 404, probs.MalformedProblem}, - {SignatureValidationError("foo"), 400, probs.MalformedProblem}, - {RateLimitedError("foo"), 429, probs.RateLimitedProblem}, - {LengthRequiredError("foo"), 411, probs.MalformedProblem}, - {BadNonceError("foo"), 400, probs.BadNonceProblem}, - } - for _, c := range testCases { - p := ProblemDetailsForError(c.err, "k") - if p.HTTPStatus != c.statusCode { - t.Errorf("Incorrect status code for %s. Expected %d, got %d", reflect.TypeOf(c.err).Name(), c.statusCode, p.HTTPStatus) - } - if probs.ProblemType(p.Type) != c.problem { - t.Errorf("Expected problem urn %#v, got %#v", c.problem, p.Type) - } - } - - expected := &probs.ProblemDetails{ - Type: probs.MalformedProblem, - HTTPStatus: 200, - Detail: "gotcha", - } - p := ProblemDetailsForError(expected, "k") - test.AssertDeepEquals(t, expected, p) -} diff --git a/docs/error-handling.md b/docs/error-handling.md new file mode 100644 index 000000000..34ef01671 --- /dev/null +++ b/docs/error-handling.md @@ -0,0 +1,11 @@ +# Error Handling Guidance + +Previously Boulder has used a mix of various error types to represent errors internally, mainly the `core.XXXError` types and `probs.ProblemDetails`, without any guidance on which should be used when or where. + +We have switched away from this to using a single unified internal error type, `boulder/errors.BoulderError` which should be used anywhere we need to pass errors between components and need to be able to indicate and test the type of the error that was passed. `probs.ProblemDetails` should only be used in the WFE when creating a problem document to pass directly back to the user client. + +A mapping exists in the WFE to map all of the available `boulder/errors.ErrorType`s to the relevant `probs.ProblemType`s. Internally errors should be wrapped when doing so provides some further context to the error that aides in debugging or will be passed back to the user client. An error may be unwrapped, or a simple stdlib `error` may be used, but doing so means the `probs.ProblemType` mapping will always be `probs.ServerInternalProblem` so should only be used for errors that do not need to be presented back to the user client. + +`boulder/errors.BoulderError`s have two components: an internal type, `boulder/errors.ErrorType`, and a detail string. The internal type should be used for a. allowing the receiver to determine what caused the error, e.g. by using `boulder/errors.NotFound` to indicate a DB operation couldn't find the requested resource, and b. allowing the WFE to convert the error to the relevant `probs.ProblemType` for display to the user. The detail string should provide a user readable explanation of the issue to be presented to the user; the only exception to this is when the internal type is `boulder/errors.InternalServer` in which case the detail of the error will be stripped by the WFE and the only message presented to the user will be provided by the caller in the WFE. + +Error type testing should be done with `boulder/errors.Is` instead of locally doing a type cast test. diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 000000000..018b1808b --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,96 @@ +package errors + +import "fmt" + +// ErrorType provides a coarse category for BoulderErrors +type ErrorType int + +const ( + InternalServer ErrorType = iota + NotSupported + Malformed + Unauthorized + NotFound + SignatureValidation + RateLimit + TooManyRequests + RejectedIdentifier + UnsupportedIdentifier + InvalidEmail + ConnectionFailure +) + +// BoulderError represents internal Boulder errors +type BoulderError struct { + Type ErrorType + Detail string +} + +func (be *BoulderError) Error() string { + return be.Detail +} + +// New is a convenience function for creating a new BoulderError +func New(errType ErrorType, msg string, args ...interface{}) error { + return &BoulderError{ + Type: errType, + Detail: fmt.Sprintf(msg, args...), + } +} + +// Is is a convenience function for testing the internal type of an BoulderError +func Is(err error, errType ErrorType) bool { + bErr, ok := err.(*BoulderError) + if !ok { + return false + } + return bErr.Type == errType +} + +func InternalServerError(msg string, args ...interface{}) error { + return New(InternalServer, msg, args...) +} + +func NotSupportedError(msg string, args ...interface{}) error { + return New(NotSupported, msg, args...) +} + +func MalformedError(msg string, args ...interface{}) error { + return New(Malformed, msg, args...) +} + +func UnauthorizedError(msg string, args ...interface{}) error { + return New(Unauthorized, msg, args...) +} + +func NotFoundError(msg string, args ...interface{}) error { + return New(NotFound, msg, args...) +} + +func SignatureValidationError(msg string, args ...interface{}) error { + return New(SignatureValidation, msg, args...) +} + +func RateLimitError(msg string, args ...interface{}) error { + return New(RateLimit, msg, args...) +} + +func TooManyRequestsError(msg string, args ...interface{}) error { + return New(TooManyRequests, msg, args...) +} + +func RejectedIdentifierError(msg string, args ...interface{}) error { + return New(RejectedIdentifier, msg, args...) +} + +func UnsupportedIdentifierError(msg string, args ...interface{}) error { + return New(UnsupportedIdentifier, msg, args...) +} + +func InvalidEmailError(msg string, args ...interface{}) error { + return New(InvalidEmail, msg, args...) +} + +func ConnectionFailureError(msg string, args ...interface{}) error { + return New(ConnectionFailure, msg, args...) +} diff --git a/goodkey/good_key.go b/goodkey/good_key.go index 052219635..d65b18e98 100644 --- a/goodkey/good_key.go +++ b/goodkey/good_key.go @@ -5,12 +5,11 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" - "fmt" "math/big" "reflect" "sync" - "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" ) // To generate, run: primes 2 752 | tr '\n' , @@ -67,7 +66,7 @@ func (policy *KeyPolicy) GoodKey(key crypto.PublicKey) error { case *ecdsa.PublicKey: return policy.goodKeyECDSA(*t) default: - return core.MalformedRequestError(fmt.Sprintf("Unknown key type %s", reflect.TypeOf(key))) + return berrors.MalformedError("unknown key type %s", reflect.TypeOf(key)) } } @@ -97,7 +96,7 @@ func (policy *KeyPolicy) goodKeyECDSA(key ecdsa.PublicKey) (err error) { // This code assumes that the point at infinity is (0,0), which is the // case for all supported curves. if isPointAtInfinityNISTP(key.X, key.Y) { - return core.MalformedRequestError("Key x, y must not be the point at infinity") + return berrors.MalformedError("key x, y must not be the point at infinity") } // SP800-56A § 5.6.2.3.2 Step 2. @@ -114,11 +113,11 @@ func (policy *KeyPolicy) goodKeyECDSA(key ecdsa.PublicKey) (err error) { // correct representation of an element in the underlying field by verifying // that x and y are integers in [0, p-1]. if key.X.Sign() < 0 || key.Y.Sign() < 0 { - return core.MalformedRequestError("Key x, y must not be negative") + return berrors.MalformedError("key x, y must not be negative") } if key.X.Cmp(params.P) >= 0 || key.Y.Cmp(params.P) >= 0 { - return core.MalformedRequestError("Key x, y must not exceed P-1") + return berrors.MalformedError("key x, y must not exceed P-1") } // SP800-56A § 5.6.2.3.2 Step 3. @@ -136,7 +135,7 @@ func (policy *KeyPolicy) goodKeyECDSA(key ecdsa.PublicKey) (err error) { // This proves that the public key is on the correct elliptic curve. // But in practice, this test is provided by crypto/elliptic, so use that. if !key.Curve.IsOnCurve(key.X, key.Y) { - return core.MalformedRequestError("Key point is not on the curve") + return berrors.MalformedError("key point is not on the curve") } // SP800-56A § 5.6.2.3.2 Step 4. @@ -152,7 +151,7 @@ func (policy *KeyPolicy) goodKeyECDSA(key ecdsa.PublicKey) (err error) { // n*Q = O iff n*Q is the point at infinity (see step 1). ox, oy := key.Curve.ScalarMult(key.X, key.Y, params.N.Bytes()) if !isPointAtInfinityNISTP(ox, oy) { - return core.MalformedRequestError("Public key does not have correct order") + return berrors.MalformedError("public key does not have correct order") } // End of SP800-56A § 5.6.2.3.2 Public Key Validation Routine. @@ -178,14 +177,14 @@ func (policy *KeyPolicy) goodCurve(c elliptic.Curve) (err error) { case policy.AllowECDSANISTP384 && params == elliptic.P384().Params(): return nil default: - return core.MalformedRequestError(fmt.Sprintf("ECDSA curve %v not allowed", params.Name)) + return berrors.MalformedError("ECDSA curve %v not allowed", params.Name) } } // GoodKeyRSA determines if a RSA pubkey meets our requirements func (policy *KeyPolicy) goodKeyRSA(key rsa.PublicKey) (err error) { if !policy.AllowRSA { - return core.MalformedRequestError("RSA keys are not allowed") + return berrors.MalformedError("RSA keys are not allowed") } // Baseline Requirements Appendix A @@ -194,15 +193,15 @@ func (policy *KeyPolicy) goodKeyRSA(key rsa.PublicKey) (err error) { modulusBitLen := modulus.BitLen() const maxKeySize = 4096 if modulusBitLen < 2048 { - return core.MalformedRequestError(fmt.Sprintf("Key too small: %d", modulusBitLen)) + return berrors.MalformedError("key too small: %d", modulusBitLen) } if modulusBitLen > maxKeySize { - return core.MalformedRequestError(fmt.Sprintf("Key too large: %d > %d", modulusBitLen, maxKeySize)) + return berrors.MalformedError("key too large: %d > %d", modulusBitLen, maxKeySize) } // Bit lengths that are not a multiple of 8 may cause problems on some // client implementations. if modulusBitLen%8 != 0 { - return core.MalformedRequestError(fmt.Sprintf("Key length wasn't a multiple of 8: %d", modulusBitLen)) + return berrors.MalformedError("key length wasn't a multiple of 8: %d", modulusBitLen) } // The CA SHALL confirm that the value of the public exponent is an // odd number equal to 3 or more. Additionally, the public exponent @@ -211,13 +210,13 @@ func (policy *KeyPolicy) goodKeyRSA(key rsa.PublicKey) (err error) { // 2^32 - 1 or 2^64 - 1, because it stores E as an integer. So we // don't need to check the upper bound. if (key.E%2) == 0 || key.E < ((1<<16)+1) { - return core.MalformedRequestError(fmt.Sprintf("Key exponent should be odd and >2^16: %d", key.E)) + return berrors.MalformedError("key exponent should be odd and >2^16: %d", key.E) } // The modulus SHOULD also have the following characteristics: an odd // number, not the power of a prime, and have no factors smaller than 752. // TODO: We don't yet check for "power of a prime." if checkSmallPrimes(modulus) { - return core.MalformedRequestError("Key divisible by small prime") + return berrors.MalformedError("key divisible by small prime") } return nil diff --git a/grpc/bcodes.go b/grpc/bcodes.go index 3702950f4..0e1a563f6 100644 --- a/grpc/bcodes.go +++ b/grpc/bcodes.go @@ -3,17 +3,22 @@ package grpc import ( "encoding/json" "errors" + "strconv" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/probs" ) // gRPC error codes used by Boulder. While the gRPC codes // end at 16 we start at 100 to provide a little leeway // in case they ever decide to add more +// TODO(#2507): Deprecated, remove once boulder/errors code is deployed const ( MalformedRequestError = iota + 100 NotSupportedError @@ -62,10 +67,25 @@ func errorToCode(err error) codes.Code { } } -func wrapError(err error) error { +// wrapError wraps the internal error types we use for transport across the gRPC +// layer and appends an appropriate errortype to the gRPC trailer via the provided +// context. core.XXXError and probs.ProblemDetails error types are encoded using the gRPC +// error status code which has been deprecated (#2507). errors.BoulderError error types +// are encoded using the grpc/metadata in the context.Context for the RPC which is +// considered to be the 'proper' method of encoding custom error types (grpc/grpc#4543 +// and grpc/grpc-go#478) +func wrapError(ctx context.Context, err error) error { if err == nil { return nil } + if berr, ok := err.(*berrors.BoulderError); ok { + // Ignoring the error return here is safe because if setting the metadata + // fails, we'll still return an error, but it will be interpreted on the + // other side as an InternalServerError instead of a more specific one. + _ = grpc.SetTrailer(ctx, metadata.Pairs("errortype", strconv.Itoa(int(berr.Type)))) + return grpc.Errorf(codes.Unknown, err.Error()) + } + // TODO(2589): deprecated, remove once boulder/errors code has been deployed code := errorToCode(err) var body string if code == ProblemDetails { @@ -83,10 +103,34 @@ func wrapError(err error) error { return grpc.Errorf(code, body) } -func unwrapError(err error) error { +// unwrapError unwraps errors returned from gRPC client calls which were wrapped +// with wrapError to their proper internal error type. If the provided metadata +// object has an "errortype" field, that will be used to set the type of the +// error. If the error is a core.XXXError or a probs.ProblemDetails the type +// is determined using the gRPC error code which has been deprecated (#2507). +func unwrapError(err error, md metadata.MD) error { if err == nil { return nil } + if errTypeStrs, ok := md["errortype"]; ok { + unwrappedErr := grpc.ErrorDesc(err) + if len(errTypeStrs) != 1 { + return berrors.InternalServerError( + "multiple errorType metadata, wrapped error %q", + unwrappedErr, + ) + } + errType, decErr := strconv.Atoi(errTypeStrs[0]) + if decErr != nil { + return berrors.InternalServerError( + "failed to decode error type, decoding error %q, wrapped error %q", + decErr, + unwrappedErr, + ) + } + return berrors.New(berrors.ErrorType(errType), unwrappedErr) + } + // TODO(2589): deprecated, remove once boulder/errors code has been deployed code := grpc.Code(err) errBody := grpc.ErrorDesc(err) switch code { diff --git a/grpc/bcodes_test.go b/grpc/bcodes_test.go index 87de6eb8c..15cae596f 100644 --- a/grpc/bcodes_test.go +++ b/grpc/bcodes_test.go @@ -30,11 +30,11 @@ func TestErrors(t *testing.T) { } for _, tc := range testcases { - wrappedErr := wrapError(tc.err) + wrappedErr := wrapError(nil, tc.err) test.AssertEquals(t, grpc.Code(wrappedErr), tc.expectedCode) - test.AssertDeepEquals(t, tc.err, unwrapError(wrappedErr)) + test.AssertDeepEquals(t, tc.err, unwrapError(wrappedErr, nil)) } - test.AssertEquals(t, wrapError(nil), nil) - test.AssertEquals(t, unwrapError(nil), nil) + test.AssertEquals(t, wrapError(nil, nil), nil) + test.AssertEquals(t, unwrapError(nil, nil), nil) } diff --git a/grpc/errors_test.go b/grpc/errors_test.go index f03b54b60..80e27a2ac 100644 --- a/grpc/errors_test.go +++ b/grpc/errors_test.go @@ -4,12 +4,16 @@ import ( "fmt" "net" "testing" + "time" + "github.com/jmhodges/clock" "golang.org/x/net/context" "google.golang.org/grpc" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" testproto "github.com/letsencrypt/boulder/grpc/test_proto" + "github.com/letsencrypt/boulder/metrics" "github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/test" ) @@ -19,11 +23,15 @@ type errorServer struct { } func (s *errorServer) Chill(_ context.Context, _ *testproto.Time) (*testproto.Time, error) { - return nil, wrapError(s.err) + return nil, s.err } func TestErrorWrapping(t *testing.T) { - srv := grpc.NewServer() + fc := clock.NewFake() + stats := metrics.NewNoopScope() + si := serverInterceptor{stats, fc} + ci := clientInterceptor{stats, fc, time.Second} + srv := grpc.NewServer(grpc.UnaryInterceptor(si.intercept)) es := &errorServer{} testproto.RegisterChillerServer(srv, es) lis, err := net.Listen("tcp", ":") @@ -34,6 +42,7 @@ func TestErrorWrapping(t *testing.T) { conn, err := grpc.Dial( lis.Addr().String(), grpc.WithInsecure(), + grpc.WithUnaryInterceptor(ci.intercept), ) test.AssertNotError(t, err, "Failed to dial grpc test server") client := testproto.NewChillerClient(conn) @@ -41,10 +50,11 @@ func TestErrorWrapping(t *testing.T) { for _, tc := range []error{ core.MalformedRequestError("yup"), &probs.ProblemDetails{Type: probs.MalformedProblem, Detail: "yup"}, + berrors.MalformedError("yup"), } { es.err = tc _, err := client.Chill(context.Background(), &testproto.Time{}) test.Assert(t, err != nil, fmt.Sprintf("nil error returned, expected: %s", err)) - test.AssertDeepEquals(t, unwrapError(err), tc) + test.AssertDeepEquals(t, err, tc) } } diff --git a/grpc/interceptors.go b/grpc/interceptors.go index bbcec2b9f..364f9d8a9 100644 --- a/grpc/interceptors.go +++ b/grpc/interceptors.go @@ -1,7 +1,6 @@ package grpc import ( - "errors" "strings" "time" @@ -9,7 +8,9 @@ import ( "github.com/jmhodges/clock" "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/metrics" ) @@ -36,7 +37,7 @@ func cleanMethod(m string, trimService bool) string { func (si *serverInterceptor) intercept(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if info == nil { si.stats.Inc("NoInfo", 1) - return nil, errors.New("passed nil *grpc.UnaryServerInfo") + return nil, berrors.InternalServerError("passed nil *grpc.UnaryServerInfo") } s := si.clk.Now() methodScope := si.stats.NewScope(cleanMethod(info.FullMethod, true)) @@ -47,7 +48,7 @@ func (si *serverInterceptor) intercept(ctx context.Context, req interface{}, inf methodScope.GaugeDelta("InProgress", -1) if err != nil { methodScope.Inc("Failed", 1) - err = wrapError(err) + err = wrapError(ctx, err) } return resp, err } @@ -84,12 +85,15 @@ func (ci *clientInterceptor) intercept( // Disable fail-fast so RPCs will retry until deadline, even if all backends // are down. opts = append(opts, grpc.FailFast(false)) + // Create grpc/metadata.Metadata to encode internal error type if one is returned + md := metadata.New(nil) + opts = append(opts, grpc.Trailer(&md)) err := grpc_prometheus.UnaryClientInterceptor(localCtx, method, req, reply, cc, invoker, opts...) methodScope.TimingDuration("Latency", ci.clk.Since(s)) methodScope.GaugeDelta("InProgress", -1) if err != nil { methodScope.Inc("Failed", 1) - err = unwrapError(err) + err = unwrapError(err, md) } return err } diff --git a/mocks/mocks.go b/mocks/mocks.go index caa78acb7..d9698d8b8 100644 --- a/mocks/mocks.go +++ b/mocks/mocks.go @@ -19,6 +19,7 @@ import ( "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/revocation" ) @@ -145,12 +146,12 @@ func (sa *StorageAuthority) GetRegistrationByKey(_ context.Context, jwk *jose.Js if core.KeyDigestEquals(jwk, test2KeyPublic) { // No key found - return core.Registration{ID: 2}, core.NoSuchRegistrationError("reg not found") + return core.Registration{ID: 2}, berrors.NotFoundError("reg not found") } if core.KeyDigestEquals(jwk, test4KeyPublic) { // No key found - return core.Registration{ID: 5}, core.NoSuchRegistrationError("reg not found") + return core.Registration{ID: 5}, berrors.NotFoundError("reg not found") } if core.KeyDigestEquals(jwk, testE1KeyPublic) { @@ -158,7 +159,7 @@ func (sa *StorageAuthority) GetRegistrationByKey(_ context.Context, jwk *jose.Js } if core.KeyDigestEquals(jwk, testE2KeyPublic) { - return core.Registration{ID: 4}, core.NoSuchRegistrationError("reg not found") + return core.Registration{ID: 4}, berrors.NotFoundError("reg not found") } if core.KeyDigestEquals(jwk, test3KeyPublic) { diff --git a/policy/pa.go b/policy/pa.go index 62c49ae49..903965ea5 100644 --- a/policy/pa.go +++ b/policy/pa.go @@ -15,9 +15,9 @@ import ( "golang.org/x/net/idna" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" - "github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/reloader" ) @@ -127,22 +127,22 @@ func suffixMatch(labels []string, suffixSet map[string]bool, properSuffix bool) } var ( - errInvalidIdentifier = probs.Malformed("Invalid identifier type") - errNonPublic = probs.Malformed("Name does not end in a public suffix") - errICANNTLD = probs.Malformed("Name is an ICANN TLD") - errBlacklisted = probs.RejectedIdentifier("Policy forbids issuing for name") - errNotWhitelisted = probs.Malformed("Name is not whitelisted") - errInvalidDNSCharacter = probs.Malformed("Invalid character in DNS name") - errNameTooLong = probs.Malformed("DNS name too long") - errIPAddress = probs.Malformed("Issuance for IP addresses not supported") - errTooManyLabels = probs.Malformed("DNS name has too many labels") - errEmptyName = probs.Malformed("DNS name was empty") - errNameEndsInDot = probs.Malformed("DNS name ends in a period") - errTooFewLabels = probs.Malformed("DNS name does not have enough labels") - errLabelTooShort = probs.Malformed("DNS label is too short") - errLabelTooLong = probs.Malformed("DNS label is too long") - errIDNNotSupported = probs.UnsupportedIdentifier("Internationalized domain names (starting with xn--) not yet supported") - errMalformedIDN = probs.Malformed("DNS label contains malformed punycode") + errInvalidIdentifier = berrors.MalformedError("Invalid identifier type") + errNonPublic = berrors.MalformedError("Name does not end in a public suffix") + errICANNTLD = berrors.MalformedError("Name is an ICANN TLD") + errBlacklisted = berrors.RejectedIdentifierError("Policy forbids issuing for name") + errNotWhitelisted = berrors.MalformedError("Name is not whitelisted") + errInvalidDNSCharacter = berrors.MalformedError("Invalid character in DNS name") + errNameTooLong = berrors.MalformedError("DNS name too long") + errIPAddress = berrors.MalformedError("Issuance for IP addresses not supported") + errTooManyLabels = berrors.MalformedError("DNS name has too many labels") + errEmptyName = berrors.MalformedError("DNS name was empty") + errNameEndsInDot = berrors.MalformedError("DNS name ends in a period") + errTooFewLabels = berrors.MalformedError("DNS name does not have enough labels") + errLabelTooShort = berrors.MalformedError("DNS label is too short") + errLabelTooLong = berrors.MalformedError("DNS label is too long") + errIDNNotSupported = berrors.UnsupportedIdentifierError("Internationalized domain names (starting with xn--) not yet supported") + errMalformedIDN = berrors.MalformedError("DNS label contains malformed punycode") ) // WillingToIssue determines whether the CA is willing to issue for the provided diff --git a/ra/ra.go b/ra/ra.go index 0e2c5d044..9f5fd4dde 100644 --- a/ra/ra.go +++ b/ra/ra.go @@ -21,6 +21,7 @@ import ( "github.com/letsencrypt/boulder/bdns" "github.com/letsencrypt/boulder/core" csrlib "github.com/letsencrypt/boulder/csr" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/goodkey" "github.com/letsencrypt/boulder/grpc" @@ -163,10 +164,10 @@ func (ra *RegistrationAuthorityImpl) updateIssuedCount() error { return nil } -const ( - unparseableEmailDetail = "not a valid e-mail address" - emptyDNSResponseDetail = "empty DNS response" - multipleAddressDetail = "more than one e-mail address" +var ( + unparseableEmailError = berrors.InvalidEmailError("not a valid e-mail address") + emptyDNSResponseError = berrors.InvalidEmailError("empty DNS response") + multipleAddressError = berrors.InvalidEmailError("more than one e-mail address") ) func problemIsTimeout(err error) bool { @@ -177,13 +178,13 @@ func problemIsTimeout(err error) bool { return false } -func validateEmail(ctx context.Context, address string, resolver bdns.DNSResolver) (prob *probs.ProblemDetails) { +func validateEmail(ctx context.Context, address string, resolver bdns.DNSResolver) error { emails, err := mail.ParseAddressList(address) if err != nil { - return probs.InvalidEmail(unparseableEmailDetail) + return unparseableEmailError } if len(emails) > 1 { - return probs.InvalidEmail(multipleAddressDetail) + return multipleAddressError } splitEmail := strings.SplitN(emails[0].Address, "@", -1) domain := strings.ToLower(splitEmail[len(splitEmail)-1]) @@ -209,21 +210,17 @@ func validateEmail(ctx context.Context, address string, resolver bdns.DNSResolve } if errMX != nil { - prob := bdns.ProblemDetailsFromDNSError(errMX) - prob.Type = probs.InvalidEmailProblem - return prob + return berrors.InvalidEmailError(errMX.Error()) } else if len(resultMX) > 0 { return nil } if errA != nil { - prob := bdns.ProblemDetailsFromDNSError(errA) - prob.Type = probs.InvalidEmailProblem - return prob + return berrors.InvalidEmailError(errA.Error()) } else if len(resultA) > 0 { return nil } - return probs.InvalidEmail(emptyDNSResponseDetail) + return emptyDNSResponseError } type certificateRequestEvent struct { @@ -258,7 +255,7 @@ func (ra *RegistrationAuthorityImpl) checkRegistrationLimit(ctx context.Context, if count >= limit.GetThreshold(ip.String(), noRegistrationID) { ra.regByIPStats.Inc("Exceeded", 1) ra.log.Info(fmt.Sprintf("Rate limit exceeded, RegistrationsByIP, IP: %s", ip)) - return core.RateLimitedError("Too many registrations from this IP") + return berrors.RateLimitError("too many registrations for this IP") } ra.regByIPStats.Inc("Pass", 1) } @@ -268,7 +265,7 @@ func (ra *RegistrationAuthorityImpl) checkRegistrationLimit(ctx context.Context, // NewRegistration constructs a new Registration from a request. func (ra *RegistrationAuthorityImpl) NewRegistration(ctx context.Context, init core.Registration) (reg core.Registration, err error) { if err = ra.keyPolicy.GoodKey(init.Key.Key); err != nil { - return core.Registration{}, core.MalformedRequestError(fmt.Sprintf("Invalid public key: %s", err.Error())) + return core.Registration{}, berrors.MalformedError("invalid public key: %s", err.Error()) } if err = ra.checkRegistrationLimit(ctx, init.InitialIP); err != nil { return core.Registration{}, err @@ -292,9 +289,9 @@ func (ra *RegistrationAuthorityImpl) NewRegistration(ctx context.Context, init c // Store the authorization object, then return it reg, err = ra.SA.NewRegistration(ctx, reg) if err != nil { - // InternalServerError since the user-data was validated before being + // berrors.InternalServerError since the user-data was validated before being // passed to the SA. - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError(err.Error()) } ra.stats.Inc("NewRegistrations", 1) @@ -306,33 +303,38 @@ func (ra *RegistrationAuthorityImpl) validateContacts(ctx context.Context, conta return nil // Nothing to validate } if ra.maxContactsPerReg > 0 && len(*contacts) > ra.maxContactsPerReg { - return core.MalformedRequestError(fmt.Sprintf("Too many contacts provided: %d > %d", - len(*contacts), ra.maxContactsPerReg)) + return berrors.MalformedError( + "too many contacts provided: %d > %d", + len(*contacts), + ra.maxContactsPerReg, + ) } for _, contact := range *contacts { if contact == "" { - return core.MalformedRequestError("Empty contact") + return berrors.MalformedError("empty contact") } parsed, err := url.Parse(contact) if err != nil { - return core.MalformedRequestError("Invalid contact") + return berrors.MalformedError("invalid contact") } if parsed.Scheme != "mailto" { - return core.MalformedRequestError(fmt.Sprintf("Contact method %s is not supported", parsed.Scheme)) + return berrors.MalformedError("contact method %s is not supported", parsed.Scheme) } if !core.IsASCII(contact) { - return core.MalformedRequestError( - fmt.Sprintf("Contact email [%s] contains non-ASCII characters", contact)) + return berrors.MalformedError( + "contact email [%s] contains non-ASCII characters", + contact, + ) } start := ra.clk.Now() ra.stats.Inc("ValidateEmail.Calls", 1) - problem := validateEmail(ctx, parsed.Opaque, ra.DNSResolver) + err = validateEmail(ctx, parsed.Opaque, ra.DNSResolver) ra.stats.TimingDuration("ValidateEmail.Latency", ra.clk.Now().Sub(start)) - if problem != nil { + if err != nil { ra.stats.Inc("ValidateEmail.Errors", 1) - return problem + return err } ra.stats.Inc("ValidateEmail.Successes", 1) } @@ -353,7 +355,7 @@ func (ra *RegistrationAuthorityImpl) checkPendingAuthorizationLimit(ctx context. if count >= limit.GetThreshold(noKey, regID) { ra.pendAuthByRegIDStats.Inc("Exceeded", 1) ra.log.Info(fmt.Sprintf("Rate limit exceeded, PendingAuthorizationsByRegID, regID: %d", regID)) - return core.RateLimitedError("Too many currently pending authorizations.") + return berrors.RateLimitError("too many currently pending authorizations") } ra.pendAuthByRegIDStats.Inc("Pass", 1) } @@ -420,22 +422,27 @@ func (ra *RegistrationAuthorityImpl) NewAuthorization(ctx context.Context, reque if identifier.Type == core.IdentifierDNS { isSafeResp, err := ra.VA.IsSafeDomain(ctx, &vaPB.IsSafeDomainRequest{Domain: &identifier.Value}) if err != nil { - outErr := core.InternalServerError("unable to determine if domain was safe") - ra.log.Warning(fmt.Sprintf("%s: %s", string(outErr), err)) + outErr := berrors.InternalServerError("unable to determine if domain was safe") + ra.log.Warning(fmt.Sprintf("%s: %s", outErr, err)) return authz, outErr } if !isSafeResp.GetIsSafe() { - return authz, core.UnauthorizedError(fmt.Sprintf("%#v was considered an unsafe domain by a third-party API", identifier.Value)) + return authz, berrors.UnauthorizedError( + "%q was considered an unsafe domain by a third-party API", + identifier.Value, + ) } } if ra.reuseValidAuthz { auths, err := ra.SA.GetValidAuthorizations(ctx, regID, []string{identifier.Value}, ra.clk.Now()) if err != nil { - outErr := core.InternalServerError( - fmt.Sprintf("unable to get existing validations for regID: %d, identifier: %s", - regID, identifier.Value)) - ra.log.Warning(string(outErr)) + outErr := berrors.InternalServerError( + "unable to get existing validations for regID: %d, identifier: %s", + regID, + identifier.Value, + ) + ra.log.Warning(outErr.Error()) return authz, outErr } @@ -445,10 +452,11 @@ func (ra *RegistrationAuthorityImpl) NewAuthorization(ctx context.Context, reque // `Challenge` values that the client expects in the result. populatedAuthz, err := ra.SA.GetAuthorization(ctx, existingAuthz.ID) if err != nil { - outErr := core.InternalServerError( - fmt.Sprintf("unable to get existing authorization for auth ID: %s", - existingAuthz.ID)) - ra.log.Warning(fmt.Sprintf("%s: %s", string(outErr), existingAuthz.ID)) + outErr := berrors.InternalServerError( + "unable to get existing authorization for auth ID: %s", + existingAuthz.ID, + ) + ra.log.Warning(fmt.Sprintf("%s: %s", outErr.Error(), existingAuthz.ID)) return authz, outErr } @@ -480,18 +488,18 @@ func (ra *RegistrationAuthorityImpl) NewAuthorization(ctx context.Context, reque // Get a pending Auth first so we can get our ID back, then update with challenges authz, err = ra.SA.NewPendingAuthorization(ctx, authz) if err != nil { - // InternalServerError since the user-data was validated before being + // berrors.InternalServerError since the user-data was validated before being // passed to the SA. - err = core.InternalServerError(fmt.Sprintf("Invalid authorization request: %s", err)) + err = berrors.InternalServerError("invalid authorization request: %s", err) return core.Authorization{}, err } // Check each challenge for sanity. for _, challenge := range authz.Challenges { if !challenge.IsSaneForClientOffer() { - // InternalServerError because we generated these challenges, they should + // berrors.InternalServerError because we generated these challenges, they should // be OK. - err = core.InternalServerError(fmt.Sprintf("Challenge didn't pass sanity check: %+v", challenge)) + err = berrors.InternalServerError("challenge didn't pass sanity check: %+v", challenge) return core.Authorization{}, err } } @@ -523,12 +531,12 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509 hostNames = core.UniqueLowerNames(hostNames) if !core.KeyDigestEquals(parsedCertificate.PublicKey, csr.PublicKey) { - err = core.InternalServerError("Generated certificate public key doesn't match CSR public key") + err = berrors.InternalServerError("generated certificate public key doesn't match CSR public key") return } if !ra.forceCNFromSAN && len(csr.Subject.CommonName) > 0 && parsedCertificate.Subject.CommonName != strings.ToLower(csr.Subject.CommonName) { - err = core.InternalServerError("Generated certificate CommonName doesn't match CSR CommonName") + err = berrors.InternalServerError("generated certificate CommonName doesn't match CSR CommonName") return } // Sort both slices of names before comparison. @@ -536,39 +544,39 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509 sort.Strings(parsedNames) sort.Strings(hostNames) if !reflect.DeepEqual(parsedNames, hostNames) { - err = core.InternalServerError("Generated certificate DNSNames don't match CSR DNSNames") + err = berrors.InternalServerError("generated certificate DNSNames don't match CSR DNSNames") return } if !reflect.DeepEqual(parsedCertificate.IPAddresses, csr.IPAddresses) { - err = core.InternalServerError("Generated certificate IPAddresses don't match CSR IPAddresses") + err = berrors.InternalServerError("generated certificate IPAddresses don't match CSR IPAddresses") return } if !reflect.DeepEqual(parsedCertificate.EmailAddresses, csr.EmailAddresses) { - err = core.InternalServerError("Generated certificate EmailAddresses don't match CSR EmailAddresses") + err = berrors.InternalServerError("generated certificate EmailAddresses don't match CSR EmailAddresses") return } if len(parsedCertificate.Subject.Country) > 0 || len(parsedCertificate.Subject.Organization) > 0 || len(parsedCertificate.Subject.OrganizationalUnit) > 0 || len(parsedCertificate.Subject.Locality) > 0 || len(parsedCertificate.Subject.Province) > 0 || len(parsedCertificate.Subject.StreetAddress) > 0 || len(parsedCertificate.Subject.PostalCode) > 0 { - err = core.InternalServerError("Generated certificate Subject contains fields other than CommonName, or SerialNumber") + err = berrors.InternalServerError("generated certificate Subject contains fields other than CommonName, or SerialNumber") return } now := ra.clk.Now() if now.Sub(parsedCertificate.NotBefore) > time.Hour*24 { - err = core.InternalServerError(fmt.Sprintf("Generated certificate is back dated %s", now.Sub(parsedCertificate.NotBefore))) + err = berrors.InternalServerError("generated certificate is back dated %s", now.Sub(parsedCertificate.NotBefore)) return } if !parsedCertificate.BasicConstraintsValid { - err = core.InternalServerError("Generated certificate doesn't have basic constraints set") + err = berrors.InternalServerError("generated certificate doesn't have basic constraints set") return } if parsedCertificate.IsCA { - err = core.InternalServerError("Generated certificate can sign other certificates") + err = berrors.InternalServerError("generated certificate can sign other certificates") return } if !reflect.DeepEqual(parsedCertificate.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) { - err = core.InternalServerError("Generated certificate doesn't have correct key usage extensions") + err = berrors.InternalServerError("generated certificate doesn't have correct key usage extensions") return } @@ -592,16 +600,17 @@ func (ra *RegistrationAuthorityImpl) checkAuthorizations(ctx context.Context, na if authz == nil { badNames = append(badNames, name) } else if authz.Expires == nil { - return fmt.Errorf("Found an authorization with a nil Expires field: id %s", authz.ID) + return berrors.InternalServerError("found an authorization with a nil Expires field: id %s", authz.ID) } else if authz.Expires.Before(now) { badNames = append(badNames, name) } } if len(badNames) > 0 { - return core.UnauthorizedError(fmt.Sprintf( - "Authorizations for these names not found or expired: %s", - strings.Join(badNames, ", "))) + return berrors.UnauthorizedError( + "authorizations for these names not found or expired: %s", + strings.Join(badNames, ", "), + ) } return nil } @@ -628,7 +637,7 @@ func (ra *RegistrationAuthorityImpl) NewCertificate(ctx context.Context, req cor }() if regID <= 0 { - err = core.MalformedRequestError(fmt.Sprintf("Invalid registration ID: %d", regID)) + err = berrors.MalformedError("invalid registration ID: %d", regID) return emptyCert, err } @@ -641,8 +650,7 @@ func (ra *RegistrationAuthorityImpl) NewCertificate(ctx context.Context, req cor // Verify the CSR csr := req.CSR if err := csrlib.VerifyCSR(csr, ra.maxNames, &ra.keyPolicy, ra.PA, ra.forceCNFromSAN, regID); err != nil { - err = core.MalformedRequestError(err.Error()) - return emptyCert, err + return emptyCert, berrors.MalformedError(err.Error()) } logEvent.CommonName = csr.Subject.CommonName @@ -653,13 +661,13 @@ func (ra *RegistrationAuthorityImpl) NewCertificate(ctx context.Context, req cor copy(names, csr.DNSNames) if len(names) == 0 { - err = core.UnauthorizedError("CSR has no names in it") + err = berrors.UnauthorizedError("CSR has no names in it") logEvent.Error = err.Error() return emptyCert, err } if core.KeyDigestEquals(csr.PublicKey, registration.Key) { - err = core.MalformedRequestError("Certificate public key must be different than account key") + err = berrors.MalformedError("certificate public key must be different than account key") return emptyCert, err } @@ -703,9 +711,9 @@ func (ra *RegistrationAuthorityImpl) NewCertificate(ctx context.Context, req cor parsedCertificate, err := x509.ParseCertificate([]byte(cert.DER)) if err != nil { - // InternalServerError because the certificate from the CA should be + // berrors.InternalServerError because the certificate from the CA should be // parseable. - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError("failed to parse certificate: %s", err.Error()) logEvent.Error = err.Error() return emptyCert, err } @@ -785,8 +793,10 @@ func (ra *RegistrationAuthorityImpl) checkCertificatesPerNameLimit(ctx context.C domains := strings.Join(badNames, ", ") ra.certsForDomainStats.Inc("Exceeded", 1) ra.log.Info(fmt.Sprintf("Rate limit exceeded, CertificatesForDomain, regID: %d, domains: %s", regID, domains)) - return core.RateLimitedError(fmt.Sprintf( - "Too many certificates already issued for: %s", domains)) + return berrors.RateLimitError( + "too many certificates already issued for: %s", + domains, + ) } ra.certsForDomainStats.Inc("Pass", 1) @@ -801,10 +811,10 @@ func (ra *RegistrationAuthorityImpl) checkCertificatesPerFQDNSetLimit(ctx contex } names = core.UniqueLowerNames(names) if int(count) > limit.GetThreshold(strings.Join(names, ","), regID) { - return core.RateLimitedError(fmt.Sprintf( - "Too many certificates already issued for exact set of domains: %s", + return berrors.RateLimitError( + "too many certificates already issued for exact set of domains: %s", strings.Join(names, ","), - )) + ) } return nil } @@ -817,12 +827,15 @@ func (ra *RegistrationAuthorityImpl) checkTotalCertificatesLimit() error { // or not yet updated, fail. if ra.clk.Now().After(ra.totalIssuedLastUpdate.Add(5*time.Minute)) || ra.totalIssuedLastUpdate.IsZero() { - return core.InternalServerError(fmt.Sprintf("Total certificate count out of date: updated %s", ra.totalIssuedLastUpdate)) + return berrors.InternalServerError( + "Total certificate count out of date: updated %s", + ra.totalIssuedLastUpdate, + ) } if ra.totalIssuedCount >= totalCertLimits.Threshold { ra.totalCertsStats.Inc("Exceeded", 1) ra.log.Info(fmt.Sprintf("Rate limit exceeded, TotalCertificates, totalIssued: %d, lastUpdated %s", ra.totalIssuedCount, ra.totalIssuedLastUpdate)) - return core.RateLimitedError("Global certificate issuance limit reached. Try again in an hour.") + return berrors.RateLimitError("global certificate issuance limit reached. Try again in an hour") } ra.totalCertsStats.Inc("Pass", 1) return nil @@ -873,9 +886,9 @@ func (ra *RegistrationAuthorityImpl) UpdateRegistration(ctx context.Context, bas err = ra.SA.UpdateRegistration(ctx, base) if err != nil { - // InternalServerError since the user-data was validated before being + // berrors.InternalServerError since the user-data was validated before being // passed to the SA. - err = core.InternalServerError(fmt.Sprintf("Could not update registration: %s", err)) + err = berrors.InternalServerError("Could not update registration: %s", err) return core.Registration{}, err } @@ -948,13 +961,13 @@ func mergeUpdate(r *core.Registration, input core.Registration) bool { func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, base core.Authorization, challengeIndex int, response core.Challenge) (authz core.Authorization, err error) { // Refuse to update expired authorizations if base.Expires == nil || base.Expires.Before(ra.clk.Now()) { - err = core.NotFoundError("Expired authorization") + err = berrors.MalformedError("expired authorization") return } authz = base if challengeIndex >= len(authz.Challenges) { - err = core.MalformedRequestError(fmt.Sprintf("Invalid challenge index: %d", challengeIndex)) + err = berrors.MalformedError("invalid challenge index '%d'", challengeIndex) return } @@ -963,8 +976,11 @@ func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, ba if response.Type != "" && ch.Type != response.Type { // TODO(riking): Check the rate on this, uncomment error return if negligible ra.stats.Inc("StartChallengeWrongType", 1) - // err = core.MalformedRequestError(fmt.Sprintf("Invalid update to challenge - provided type was %s but actual type is %s", response.Type, ch.Type)) - // return + // return authz, berrors.MalformedError( + // "invalid challenge update: provided type was %s but actual type is %s", + // response.Type, + // ch.Type, + // ) } // When configured with `reuseValidAuthz` we can expect some clients to try @@ -980,7 +996,7 @@ func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, ba // Look up the account key for this authorization reg, err := ra.SA.GetRegistration(ctx, authz.RegistrationID) if err != nil { - err = core.InternalServerError(err.Error()) + err = berrors.InternalServerError(err.Error()) return } @@ -988,11 +1004,11 @@ func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, ba // check it against the value provided expectedKeyAuthorization, err := ch.ExpectedKeyAuthorization(reg.Key) if err != nil { - err = core.InternalServerError("Could not compute expected key authorization value") + err = berrors.InternalServerError("could not compute expected key authorization value") return } if expectedKeyAuthorization != response.ProvidedKeyAuthorization { - err = core.MalformedRequestError("Provided key authorization was incorrect") + err = berrors.MalformedError("provided key authorization was incorrect") return } @@ -1001,7 +1017,7 @@ func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, ba // Double check before sending to VA if !ch.IsSaneForValidation() { - err = core.MalformedRequestError("Response does not complete challenge") + err = berrors.MalformedError("response does not complete challenge") return } @@ -1009,7 +1025,7 @@ func (ra *RegistrationAuthorityImpl) UpdateAuthorization(ctx context.Context, ba if err = ra.SA.UpdatePendingAuthorization(ctx, authz); err != nil { ra.log.Warning(fmt.Sprintf( "Error calling ra.SA.UpdatePendingAuthorization: %s\n", err.Error())) - err = core.InternalServerError("Could not update pending authorization") + err = berrors.InternalServerError("could not update pending authorization") return } ra.stats.Inc("NewPendingAuthorizations", 1) @@ -1172,11 +1188,11 @@ func (ra *RegistrationAuthorityImpl) onValidationUpdate(ctx context.Context, aut // DeactivateRegistration deactivates a valid registration func (ra *RegistrationAuthorityImpl) DeactivateRegistration(ctx context.Context, reg core.Registration) error { if reg.Status != core.StatusValid { - return core.MalformedRequestError("Only valid registrations can be deactivated") + return berrors.MalformedError("only valid registrations can be deactivated") } err := ra.SA.DeactivateRegistration(ctx, reg.ID) if err != nil { - return core.InternalServerError(err.Error()) + return berrors.InternalServerError(err.Error()) } return nil } @@ -1184,11 +1200,11 @@ func (ra *RegistrationAuthorityImpl) DeactivateRegistration(ctx context.Context, // DeactivateAuthorization deactivates a currently valid authorization func (ra *RegistrationAuthorityImpl) DeactivateAuthorization(ctx context.Context, auth core.Authorization) error { if auth.Status != core.StatusValid && auth.Status != core.StatusPending { - return core.MalformedRequestError("Only valid and pending authorizations can be deactivated") + return berrors.MalformedError("only valid and pending authorizations can be deactivated") } err := ra.SA.DeactivateAuthorization(ctx, auth.ID) if err != nil { - return core.InternalServerError(err.Error()) + return berrors.InternalServerError(err.Error()) } return nil } diff --git a/ra/ra_test.go b/ra/ra_test.go index 7c29060ce..fa3d0183b 100644 --- a/ra/ra_test.go +++ b/ra/ra_test.go @@ -23,6 +23,7 @@ import ( "github.com/letsencrypt/boulder/bdns" "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/goodkey" blog "github.com/letsencrypt/boulder/log" @@ -324,9 +325,9 @@ func TestValidateEmail(t *testing.T) { input string expected string }{ - {"an email`", unparseableEmailDetail}, - {"a@always.invalid", emptyDNSResponseDetail}, - {"a@email.com, b@email.com", multipleAddressDetail}, + {"an email`", unparseableEmailError.Error()}, + {"a@always.invalid", emptyDNSResponseError.Error()}, + {"a@email.com, b@email.com", multipleAddressError.Error()}, {"a@always.error", "DNS problem: networking error looking up A for always.error"}, } testSuccesses := []string{ @@ -339,20 +340,21 @@ func TestValidateEmail(t *testing.T) { } for _, tc := range testFailures { - problem := validateEmail(context.Background(), tc.input, &bdns.MockDNSResolver{}) - if problem.Type != probs.InvalidEmailProblem { - t.Errorf("validateEmail(%q): got problem type %#v, expected %#v", tc.input, problem.Type, probs.InvalidEmailProblem) + err := validateEmail(context.Background(), tc.input, &bdns.MockDNSResolver{}) + if !berrors.Is(err, berrors.InvalidEmail) { + t.Errorf("validateEmail(%q): got error %#v, expected type berrors.InvalidEmail", tc.input, err) } - if problem.Detail != tc.expected { + + if err.Error() != tc.expected { t.Errorf("validateEmail(%q): got %#v, expected %#v", - tc.input, problem.Detail, tc.expected) + tc.input, err.Error(), tc.expected) } } for _, addr := range testSuccesses { - if prob := validateEmail(context.Background(), addr, &bdns.MockDNSResolver{}); prob != nil { - t.Errorf("validateEmail(%q): expected success, but it failed: %s", - addr, prob) + if err := validateEmail(context.Background(), addr, &bdns.MockDNSResolver{}); err != nil { + t.Errorf("validateEmail(%q): expected success, but it failed: %#v", + addr, err) } } } @@ -680,11 +682,8 @@ func TestNewAuthorizationInvalidName(t *testing.T) { if err == nil { t.Fatalf("NewAuthorization succeeded for 127.0.0.1, should have failed") } - if _, ok := err.(*probs.ProblemDetails); !ok { - t.Errorf("Wrong type for NewAuthorization error: expected *probs.ProblemDetails, got %T", err) - } - if err.(*probs.ProblemDetails).Type != probs.MalformedProblem { - t.Errorf("Incorrect problem type. Expected %s got %s", probs.MalformedProblem, err.(*probs.ProblemDetails).Type) + if !berrors.Is(err, berrors.Malformed) { + t.Errorf("expected berrors.BoulderError with internal type berrors.Malformed, got %T", err) } } @@ -806,7 +805,7 @@ func TestCertificateKeyNotEqualAccountKey(t *testing.T) { // Registration has key == AccountKeyA _, err = ra.NewCertificate(ctx, certRequest, Registration.ID) test.AssertError(t, err, "Should have rejected cert with key = account key") - test.AssertEquals(t, err.Error(), "Certificate public key must be different than account key") + test.AssertEquals(t, err.Error(), "certificate public key must be different than account key") t.Log("DONE TestCertificateKeyNotEqualAccountKey") } @@ -1108,7 +1107,7 @@ func TestCheckCertificatesPerNameLimit(t *testing.T) { mockSA.nameCounts["example.com"] = 10 err = ra.checkCertificatesPerNameLimit(ctx, []string{"www.example.com", "example.com"}, rlp, 99) test.AssertError(t, err, "incorrectly failed to rate limit example.com") - if _, ok := err.(core.RateLimitedError); !ok { + if !berrors.Is(err, berrors.RateLimit) { t.Errorf("Incorrect error type %#v", err) } @@ -1127,7 +1126,7 @@ func TestCheckCertificatesPerNameLimit(t *testing.T) { mockSA.nameCounts["bigissuer.com"] = 100 err = ra.checkCertificatesPerNameLimit(ctx, []string{"www.example.com", "subdomain.bigissuer.com"}, rlp, 99) test.AssertError(t, err, "incorrectly failed to rate limit bigissuer") - if _, ok := err.(core.RateLimitedError); !ok { + if !berrors.Is(err, berrors.RateLimit) { t.Errorf("Incorrect error type") } @@ -1135,7 +1134,7 @@ func TestCheckCertificatesPerNameLimit(t *testing.T) { mockSA.nameCounts["smallissuer.co.uk"] = 1 err = ra.checkCertificatesPerNameLimit(ctx, []string{"www.smallissuer.co.uk"}, rlp, 99) test.AssertError(t, err, "incorrectly failed to rate limit smallissuer") - if _, ok := err.(core.RateLimitedError); !ok { + if !berrors.Is(err, berrors.RateLimit) { t.Errorf("Incorrect error type %#v", err) } } diff --git a/rpc/amqp-rpc.go b/rpc/amqp-rpc.go index 89672f9b2..51d2beed8 100644 --- a/rpc/amqp-rpc.go +++ b/rpc/amqp-rpc.go @@ -10,6 +10,7 @@ import ( "fmt" "io/ioutil" "os" + "strconv" "strings" "sync" "sync/atomic" @@ -21,6 +22,7 @@ import ( "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" "github.com/letsencrypt/boulder/probs" @@ -200,6 +202,9 @@ func wrapError(err error) *rpcError { wrapped.Type = string(terr.Type) wrapped.Value = terr.Detail wrapped.HTTPStatus = terr.HTTPStatus + case *berrors.BoulderError: + wrapped.Type = fmt.Sprintf("berr:%d", terr.Type) + wrapped.Value = terr.Detail } return wrapped } @@ -236,6 +241,17 @@ func unwrapError(rpcError *rpcError) error { HTTPStatus: rpcError.HTTPStatus, } } + if strings.HasPrefix(rpcError.Type, "berr:") { + errType, decErr := strconv.Atoi(rpcError.Type[5:]) + if decErr != nil { + return berrors.InternalServerError( + "failed to decode error type, decoding error %q, wrapped error %q", + decErr, + rpcError.Value, + ) + } + return berrors.New(berrors.ErrorType(errType), rpcError.Value) + } return errors.New(rpcError.Value) } } @@ -388,7 +404,7 @@ func (rpc *AmqpRPCServer) replyTooManyRequests(msg amqp.Delivery) error { // remaining messages are processed. func (rpc *AmqpRPCServer) Start(c *cmd.AMQPConfig) error { tooManyGoroutines := rpcResponse{ - Error: wrapError(core.TooManyRPCRequestsError("RPC server has spawned too many Goroutines")), + Error: wrapError(berrors.TooManyRequestsError("RPC server has spawned too many Goroutines")), } tooManyRequestsResponse, err := json.Marshal(tooManyGoroutines) if err != nil { diff --git a/rpc/amqp-rpc_test.go b/rpc/amqp-rpc_test.go index d1452eb4e..eaed05f8c 100644 --- a/rpc/amqp-rpc_test.go +++ b/rpc/amqp-rpc_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/test" ) @@ -56,6 +57,10 @@ func TestWrapError(t *testing.T) { errors.New(""), errors.New(""), }, + { + berrors.MalformedError("foo"), + berrors.MalformedError("foo"), + }, } for i, tc := range complicated { actual := unwrapError(wrapError(tc.given)) diff --git a/sa/sa.go b/sa/sa.go index f16572e9b..3dfb86852 100644 --- a/sa/sa.go +++ b/sa/sa.go @@ -5,7 +5,6 @@ import ( "crypto/x509" "database/sql" "encoding/json" - "errors" "fmt" "math/big" "net" @@ -18,6 +17,7 @@ import ( jose "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/revocation" @@ -122,9 +122,7 @@ func (ssa *SQLStorageAuthority) GetRegistration(ctx context.Context, id int64) ( model, err = selectRegistration(ssa.dbMap, query, id) } if err == sql.ErrNoRows { - return core.Registration{}, core.NoSuchRegistrationError( - fmt.Sprintf("No registrations with ID %d", id), - ) + return core.Registration{}, berrors.NotFoundError("registration with ID '%d' not found", id) } if err != nil { return core.Registration{}, err @@ -150,8 +148,7 @@ func (ssa *SQLStorageAuthority) GetRegistrationByKey(ctx context.Context, key *j model, err = selectRegistration(ssa.dbMap, query, sha) } if err == sql.ErrNoRows { - msg := fmt.Sprintf("No registrations with public key sha256 %s", sha) - return core.Registration{}, core.NoSuchRegistrationError(msg) + return core.Registration{}, berrors.NotFoundError("no registrations with public key sha256 %q", sha) } if err != nil { return core.Registration{}, err @@ -218,7 +215,7 @@ func (ssa *SQLStorageAuthority) GetAuthorization(ctx context.Context, id string) // domain names from the parameters that the account has authorizations for. func (ssa *SQLStorageAuthority) GetValidAuthorizations(ctx context.Context, registrationID int64, names []string, now time.Time) (map[string]*core.Authorization, error) { if len(names) == 0 { - return nil, errors.New("GetValidAuthorizations: no names received") + return nil, berrors.InternalServerError("no names received") } params := make([]interface{}, len(names)) @@ -421,7 +418,7 @@ func (ssa *SQLStorageAuthority) GetCertificate(ctx context.Context, serial strin cert, err := SelectCertificate(ssa.dbMap, "WHERE serial = ?", serial) if err == sql.ErrNoRows { - return core.Certificate{}, core.NotFoundError(fmt.Sprintf("No certificate found for %s", serial)) + return core.Certificate{}, berrors.NotFoundError("certificate with serial %q not found", serial) } if err != nil { return core.Certificate{}, err @@ -520,7 +517,7 @@ func (ssa *SQLStorageAuthority) MarkCertificateRevoked(ctx context.Context, seri return err } if n == 0 { - err = errors.New("No certificate updated. Maybe the lock column was off?") + err = berrors.InternalServerError("no certificate updated") err = Rollback(tx, err) return err } @@ -539,8 +536,7 @@ func (ssa *SQLStorageAuthority) UpdateRegistration(ctx context.Context, reg core model, err = selectRegistration(ssa.dbMap, query, reg.ID) } if err == sql.ErrNoRows { - msg := fmt.Sprintf("No registrations with ID %d", reg.ID) - return core.NoSuchRegistrationError(msg) + return berrors.NotFoundError("registration with ID '%d' not found", reg.ID) } updatedRegModel, err := registrationToModel(®) @@ -569,8 +565,7 @@ func (ssa *SQLStorageAuthority) UpdateRegistration(ctx context.Context, reg core return err } if n == 0 { - msg := fmt.Sprintf("Requested registration not found %d", reg.ID) - return core.NoSuchRegistrationError(msg) + return berrors.NotFoundError("registration with ID '%d' not found", reg.ID) } return nil @@ -636,23 +631,24 @@ func (ssa *SQLStorageAuthority) UpdatePendingAuthorization(ctx context.Context, } if !statusIsPending(authz.Status) { - err = errors.New("Use FinalizeAuthorization() to update to a final status") + err = berrors.InternalServerError("authorization is not pending") return Rollback(tx, err) } if existingFinal(tx, authz.ID) { - err = errors.New("Cannot update a final authorization") + err = berrors.InternalServerError("cannot update a finalized authorization") return Rollback(tx, err) } if !existingPending(tx, authz.ID) { - err = errors.New("Requested authorization not found " + authz.ID) + err = berrors.InternalServerError("authorization with ID '%d' not found", authz.ID) return Rollback(tx, err) } pa, err := selectPendingAuthz(tx, "WHERE id = ?", authz.ID) if err == sql.ErrNoRows { - return Rollback(tx, fmt.Errorf("No pending authorization with ID %s", authz.ID)) + err = berrors.InternalServerError("authorization with ID '%d' not found", authz.ID) + return Rollback(tx, err) } if err != nil { return Rollback(tx, err) @@ -680,18 +676,18 @@ func (ssa *SQLStorageAuthority) FinalizeAuthorization(ctx context.Context, authz // Check that a pending authz exists if !existingPending(tx, authz.ID) { - err = errors.New("Cannot finalize an authorization that is not pending") + err = berrors.InternalServerError("authorization with ID %q not found", authz.ID) return Rollback(tx, err) } if statusIsPending(authz.Status) { - err = errors.New("Cannot finalize to a non-final status") + err = berrors.InternalServerError("authorization with ID %q is not pending", authz.ID) return Rollback(tx, err) } auth := &authzModel{authz} pa, err := selectPendingAuthz(tx, "WHERE id = ?", authz.ID) if err == sql.ErrNoRows { - return Rollback(tx, fmt.Errorf("No pending authorization with ID %s", authz.ID)) + return Rollback(tx, berrors.InternalServerError("authorization with ID %q not found", authz.ID)) } if err != nil { return Rollback(tx, err) diff --git a/sa/sa_test.go b/sa/sa_test.go index f0a24dc58..1cda30c37 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -21,6 +21,7 @@ import ( jose "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/revocation" @@ -122,19 +123,19 @@ func TestNoSuchRegistrationErrors(t *testing.T) { defer cleanUp() _, err := sa.GetRegistration(ctx, 100) - if _, ok := err.(core.NoSuchRegistrationError); !ok { - t.Errorf("GetRegistration: expected NoSuchRegistrationError, got %T type error (%s)", err, err) + if !berrors.Is(err, berrors.NotFound) { + t.Errorf("GetRegistration: expected a berrors.NotFound type error, got %T type error (%s)", err, err) } jwk := satest.GoodJWK() _, err = sa.GetRegistrationByKey(ctx, jwk) - if _, ok := err.(core.NoSuchRegistrationError); !ok { - t.Errorf("GetRegistrationByKey: expected a NoSuchRegistrationError, got %T type error (%s)", err, err) + if !berrors.Is(err, berrors.NotFound) { + t.Errorf("GetRegistrationByKey: expected a berrors.NotFound type error, got %T type error (%s)", err, err) } err = sa.UpdateRegistration(ctx, core.Registration{ID: 100, Key: jwk}) - if _, ok := err.(core.NoSuchRegistrationError); !ok { - t.Errorf("UpdateRegistration: expected a NoSuchRegistrationError, got %T type error (%v)", err, err) + if !berrors.Is(err, berrors.NotFound) { + t.Errorf("UpdateRegistration: expected a berrors.NotFound type error, got %T type error (%v)", err, err) } } diff --git a/va/va.go b/va/va.go index f82ec3c10..34cc60112 100644 --- a/va/va.go +++ b/va/va.go @@ -105,7 +105,7 @@ func (va ValidationAuthorityImpl) getAddr(ctx context.Context, hostname string) addrs, err := va.dnsResolver.LookupHost(ctx, hostname) if err != nil { va.log.Debug(fmt.Sprintf("%s DNS failure: %s", hostname, err)) - problem := bdns.ProblemDetailsFromDNSError(err) + problem := probs.ConnectionFailure(err.Error()) return net.IP{}, nil, problem } @@ -538,7 +538,7 @@ func (va *ValidationAuthorityImpl) validateDNS01(ctx context.Context, identifier if err != nil { va.log.Info(fmt.Sprintf("Failed to lookup txt records for %s. err=[%#v] errStr=[%s]", identifier, err, err)) - return nil, bdns.ProblemDetailsFromDNSError(err) + return nil, probs.ConnectionFailure(err.Error()) } // If there weren't any TXT records return a distinct error message to allow @@ -572,7 +572,7 @@ func (va *ValidationAuthorityImpl) checkCAA(ctx context.Context, identifier core func (va *ValidationAuthorityImpl) checkCAAInternal(ctx context.Context, ident core.AcmeIdentifier) *probs.ProblemDetails { present, valid, err := va.checkCAARecords(ctx, ident) if err != nil { - return bdns.ProblemDetailsFromDNSError(err) + return probs.ConnectionFailure(err.Error()) } va.log.AuditInfo(fmt.Sprintf( "Checked CAA records for %s, [Present: %t, Valid for issuance: %t]", diff --git a/va/va_test.go b/va/va_test.go index 7e59866d1..9ffc6c875 100644 --- a/va/va_test.go +++ b/va/va_test.go @@ -1097,7 +1097,7 @@ func TestCheckCAAFallback(t *testing.T) { prob = va.checkCAA(ctx, core.AcmeIdentifier{Value: "bad-local-resolver.com", Type: "dns"}) test.Assert(t, prob != nil, "returned ProblemDetails was nil") test.AssertEquals(t, prob.Type, probs.ConnectionProblem) - test.AssertEquals(t, prob.Detail, "server failure at resolver") + test.AssertEquals(t, prob.Detail, "DNS problem: query timed out looking up CAA for bad-local-resolver.com") } func TestParseResults(t *testing.T) { diff --git a/wfe/jose.go b/wfe/jose.go index 6dda904ba..36ca1960f 100644 --- a/wfe/jose.go +++ b/wfe/jose.go @@ -3,10 +3,10 @@ package wfe import ( "crypto/ecdsa" "crypto/rsa" - "fmt" - "github.com/letsencrypt/boulder/core" "gopkg.in/square/go-jose.v1" + + berrors "github.com/letsencrypt/boulder/errors" ) func algorithmForKey(key *jose.JsonWebKey) (string, error) { @@ -23,7 +23,7 @@ func algorithmForKey(key *jose.JsonWebKey) (string, error) { return string(jose.ES512), nil } } - return "", core.SignatureValidationError("no signature algorithms suitable for given key type") + return "", berrors.SignatureValidationError("no signature algorithms suitable for given key type") } const ( @@ -44,15 +44,16 @@ func checkAlgorithm(key *jose.JsonWebKey, parsedJws *jose.JsonWebSignature) (str } jwsAlgorithm := parsedJws.Signatures[0].Header.Algorithm if jwsAlgorithm != algorithm { - return invalidJWSAlgorithm, - core.SignatureValidationError(fmt.Sprintf( - "signature type '%s' in JWS header is not supported, expected one of RS256, ES256, ES384 or ES512", - jwsAlgorithm)) + return invalidJWSAlgorithm, berrors.SignatureValidationError( + "signature type '%s' in JWS header is not supported, expected one of RS256, ES256, ES384 or ES512", + jwsAlgorithm, + ) } if key.Algorithm != "" && key.Algorithm != algorithm { - return invalidAlgorithmOnKey, - core.SignatureValidationError(fmt.Sprintf( - "algorithm '%s' on JWK is unacceptable", key.Algorithm)) + return invalidAlgorithmOnKey, berrors.SignatureValidationError( + "algorithm '%s' on JWK is unacceptable", + key.Algorithm, + ) } return "", nil } diff --git a/wfe/probs.go b/wfe/probs.go new file mode 100644 index 000000000..7b5c01e9b --- /dev/null +++ b/wfe/probs.go @@ -0,0 +1,80 @@ +package wfe + +import ( + "fmt" + "net/http" + + "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" + "github.com/letsencrypt/boulder/probs" +) + +func problemDetailsForBoulderError(err *berrors.BoulderError, msg string) *probs.ProblemDetails { + switch err.Type { + case berrors.NotSupported: + return &probs.ProblemDetails{ + Type: probs.ServerInternalProblem, + Detail: fmt.Sprintf("%s :: %s", msg, err), + HTTPStatus: http.StatusNotImplemented, + } + case berrors.Malformed, berrors.SignatureValidation: + return probs.Malformed(fmt.Sprintf("%s :: %s", msg, err)) + case berrors.Unauthorized: + return probs.Unauthorized(fmt.Sprintf("%s :: %s", msg, err)) + case berrors.NotFound: + return probs.NotFound(fmt.Sprintf("%s :: %s", msg, err)) + case berrors.RateLimit: + return probs.RateLimited(fmt.Sprintf("%s :: %s", msg, err)) + case berrors.InternalServer, berrors.TooManyRequests: + // Internal server error messages may include sensitive data, so we do + // not include it. + return probs.ServerInternal(msg) + case berrors.RejectedIdentifier: + return probs.RejectedIdentifier(msg) + case berrors.UnsupportedIdentifier: + return probs.UnsupportedIdentifier(msg) + default: + // Internal server error messages may include sensitive data, so we do + // not include it. + return probs.ServerInternal(msg) + } +} + +// problemDetailsForError turns an error into a ProblemDetails with the special +// case of returning the same error back if its already a ProblemDetails. If the +// error is of an type unknown to ProblemDetailsForError, it will return a +// ServerInternal ProblemDetails. +func problemDetailsForError(err error, msg string) *probs.ProblemDetails { + switch e := err.(type) { + case *probs.ProblemDetails: + return e + case *berrors.BoulderError: + return problemDetailsForBoulderError(e, msg) + case core.MalformedRequestError: + return probs.Malformed(fmt.Sprintf("%s :: %s", msg, err)) + case core.NotSupportedError: + return &probs.ProblemDetails{ + Type: probs.ServerInternalProblem, + Detail: fmt.Sprintf("%s :: %s", msg, err), + HTTPStatus: http.StatusNotImplemented, + } + case core.UnauthorizedError: + return probs.Unauthorized(fmt.Sprintf("%s :: %s", msg, err)) + case core.NotFoundError: + return probs.NotFound(fmt.Sprintf("%s :: %s", msg, err)) + case core.LengthRequiredError: + prob := probs.Malformed("missing Content-Length header") + prob.HTTPStatus = http.StatusLengthRequired + return prob + case core.SignatureValidationError: + return probs.Malformed(fmt.Sprintf("%s :: %s", msg, err)) + case core.RateLimitedError: + return probs.RateLimited(fmt.Sprintf("%s :: %s", msg, err)) + case core.BadNonceError: + return probs.BadNonce(fmt.Sprintf("%s :: %s", msg, err)) + default: + // Internal server error messages may include sensitive data, so we do + // not include it. + return probs.ServerInternal(msg) + } +} diff --git a/wfe/probs_test.go b/wfe/probs_test.go new file mode 100644 index 000000000..52124ae23 --- /dev/null +++ b/wfe/probs_test.go @@ -0,0 +1,55 @@ +package wfe + +import ( + "reflect" + "testing" + + "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" + "github.com/letsencrypt/boulder/probs" + "github.com/letsencrypt/boulder/test" +) + +func TestProblemDetailsFromError(t *testing.T) { + testCases := []struct { + err error + statusCode int + problem probs.ProblemType + }{ + // boulder/core error types + {core.InternalServerError("foo"), 500, probs.ServerInternalProblem}, + {core.NotSupportedError("foo"), 501, probs.ServerInternalProblem}, + {core.MalformedRequestError("foo"), 400, probs.MalformedProblem}, + {core.UnauthorizedError("foo"), 403, probs.UnauthorizedProblem}, + {core.NotFoundError("foo"), 404, probs.MalformedProblem}, + {core.SignatureValidationError("foo"), 400, probs.MalformedProblem}, + {core.RateLimitedError("foo"), 429, probs.RateLimitedProblem}, + {core.LengthRequiredError("foo"), 411, probs.MalformedProblem}, + {core.BadNonceError("foo"), 400, probs.BadNonceProblem}, + // boulder/errors error types + {berrors.InternalServerError("foo"), 500, probs.ServerInternalProblem}, + {berrors.NotSupportedError("foo"), 501, probs.ServerInternalProblem}, + {berrors.MalformedError("foo"), 400, probs.MalformedProblem}, + {berrors.UnauthorizedError("foo"), 403, probs.UnauthorizedProblem}, + {berrors.NotFoundError("foo"), 404, probs.MalformedProblem}, + {berrors.SignatureValidationError("foo"), 400, probs.MalformedProblem}, + {berrors.RateLimitError("foo"), 429, probs.RateLimitedProblem}, + } + for _, c := range testCases { + p := problemDetailsForError(c.err, "k") + if p.HTTPStatus != c.statusCode { + t.Errorf("Incorrect status code for %s. Expected %d, got %d", reflect.TypeOf(c.err).Name(), c.statusCode, p.HTTPStatus) + } + if probs.ProblemType(p.Type) != c.problem { + t.Errorf("Expected problem urn %#v, got %#v", c.problem, p.Type) + } + } + + expected := &probs.ProblemDetails{ + Type: probs.MalformedProblem, + HTTPStatus: 200, + Detail: "gotcha", + } + p := problemDetailsForError(expected, "k") + test.AssertDeepEquals(t, expected, p) +} diff --git a/wfe/wfe.go b/wfe/wfe.go index 9a356d684..4f493bd53 100644 --- a/wfe/wfe.go +++ b/wfe/wfe.go @@ -22,6 +22,7 @@ import ( jose "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/goodkey" blog "github.com/letsencrypt/boulder/log" @@ -464,7 +465,9 @@ func (wfe *WebFrontEndImpl) verifyPOST(ctx context.Context, logEvent *requestEve // Special case: If no registration was found, but regCheck is false, use an // empty registration and the submitted key. The caller is expected to do some // validation on the returned key. - if _, ok := err.(core.NoSuchRegistrationError); ok && !regCheck { + // TODO(#2600): Remove core.NoSuchRegistrationError check once boulder/errors + // code is deployed + if _, ok := err.(core.NoSuchRegistrationError); (ok || berrors.Is(err, berrors.NotFound)) && !regCheck { // When looking up keys from the registrations DB, we can be confident they // are "good". But when we are verifying against any submitted key, we want // to check its quality before doing the verify. @@ -478,7 +481,9 @@ func (wfe *WebFrontEndImpl) verifyPOST(ctx context.Context, logEvent *requestEve // For all other errors, or if regCheck is true, return error immediately. wfe.stats.Inc("Errors.UnableToGetRegistrationByKey", 1) logEvent.AddError("unable to fetch registration by the given JWK: %s", err) - if _, ok := err.(core.NoSuchRegistrationError); ok { + // TODO(#2600): Remove core.NoSuchRegistrationError check once boulder/errors + // code is deployed + if _, ok := err.(core.NoSuchRegistrationError); ok || berrors.Is(err, berrors.NotFound) { return nil, nil, reg, probs.Unauthorized(unknownKey) } @@ -630,7 +635,7 @@ func (wfe *WebFrontEndImpl) NewRegistration(ctx context.Context, logEvent *reque reg, err := wfe.RA.NewRegistration(ctx, init) if err != nil { logEvent.AddError("unable to create new registration: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Error creating new registration"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Error creating new registration"), err) return } logEvent.Requester = reg.ID @@ -686,7 +691,7 @@ func (wfe *WebFrontEndImpl) NewAuthorization(ctx context.Context, logEvent *requ authz, err := wfe.RA.NewAuthorization(ctx, init, currReg.ID) if err != nil { logEvent.AddError("unable to create new authz: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Error creating new authz"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Error creating new authz"), err) return } logEvent.Extra["AuthzID"] = authz.ID @@ -816,7 +821,7 @@ func (wfe *WebFrontEndImpl) RevokeCertificate(ctx context.Context, logEvent *req err = wfe.RA.RevokeCertificateWithReg(ctx, *parsedCertificate, reason, registration.ID) if err != nil { logEvent.AddError("failed to revoke certificate: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Failed to revoke certificate"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Failed to revoke certificate"), err) } else { wfe.log.Debug(fmt.Sprintf("Revoked %v", serial)) response.WriteHeader(http.StatusOK) @@ -911,7 +916,7 @@ func (wfe *WebFrontEndImpl) NewCertificate(ctx context.Context, logEvent *reques cert, err := wfe.RA.NewCertificate(ctx, certificateRequest, reg.ID) if err != nil { logEvent.AddError("unable to create new cert: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Error creating new cert"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Error creating new cert"), err) return } @@ -1103,7 +1108,7 @@ func (wfe *WebFrontEndImpl) postChallenge( updatedAuthorization, err := wfe.RA.UpdateAuthorization(ctx, authz, challengeIndex, challengeUpdate) if err != nil { logEvent.AddError("unable to update challenge: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Unable to update challenge"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Unable to update challenge"), err) return } @@ -1205,7 +1210,7 @@ func (wfe *WebFrontEndImpl) Registration(ctx context.Context, logEvent *requestE updatedReg, err := wfe.RA.UpdateRegistration(ctx, currReg, update) if err != nil { logEvent.AddError("unable to update registration: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Unable to update registration"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Unable to update registration"), err) return } @@ -1251,7 +1256,7 @@ func (wfe *WebFrontEndImpl) deactivateAuthorization(ctx context.Context, authz * err = wfe.RA.DeactivateAuthorization(ctx, *authz) if err != nil { logEvent.AddError("unable to deactivate authorization", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Error deactivating authorization"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Error deactivating authorization"), err) return false } // Since the authorization passed to DeactivateAuthorization isn't @@ -1501,7 +1506,7 @@ func (wfe *WebFrontEndImpl) KeyRollover(ctx context.Context, logEvent *requestEv updatedReg, err := wfe.RA.UpdateRegistration(ctx, reg, core.Registration{Key: newKey}) if err != nil { logEvent.AddError("unable to update registration: %s", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Unable to update registration"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Unable to update registration"), err) return } @@ -1520,7 +1525,7 @@ func (wfe *WebFrontEndImpl) deactivateRegistration(ctx context.Context, reg core err := wfe.RA.DeactivateRegistration(ctx, reg) if err != nil { logEvent.AddError("unable to deactivate registration", err) - wfe.sendError(response, logEvent, core.ProblemDetailsForError(err, "Error deactivating registration"), err) + wfe.sendError(response, logEvent, problemDetailsForError(err, "Error deactivating registration"), err) return } reg.Status = core.StatusDeactivated diff --git a/wfe/wfe_test.go b/wfe/wfe_test.go index 2d745dcf5..5fb1914bd 100644 --- a/wfe/wfe_test.go +++ b/wfe/wfe_test.go @@ -25,6 +25,7 @@ import ( "gopkg.in/square/go-jose.v1" "github.com/letsencrypt/boulder/core" + berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/goodkey" blog "github.com/letsencrypt/boulder/log" @@ -809,7 +810,7 @@ func TestIssueCertificate(t *testing.T) { }`, wfe.nonceService))) assertJSONEquals(t, responseWriter.Body.String(), - `{"type":"urn:acme:error:unauthorized","detail":"Error creating new cert :: Authorizations for these names not found or expired: meep.com","status":403}`) + `{"type":"urn:acme:error:unauthorized","detail":"Error creating new cert :: authorizations for these names not found or expired: meep.com","status":403}`) assertCsrLogged(t, mockLog) mockLog.Clear() @@ -1209,16 +1210,16 @@ func makeRevokeRequestJSON(reason *revocation.Reason) ([]byte, error) { return revokeRequestJSON, nil } -// An SA mock that always returns NoSuchRegistrationError. This is necessary +// An SA mock that always returns a berrors.NotFound type error. This is necessary // because the standard mock in our mocks package always returns a given test // registration when GetRegistrationByKey is called, and we want to get a -// NoSuchRegistrationError for tests that pass regCheck = false to verifyPOST. +// berrors.NotFound type error for tests that pass regCheck = false to verifyPOST. type mockSANoSuchRegistration struct { core.StorageGetter } func (msa mockSANoSuchRegistration) GetRegistrationByKey(ctx context.Context, jwk *jose.JsonWebKey) (core.Registration, error) { - return core.Registration{}, core.NoSuchRegistrationError("reg not found") + return core.Registration{}, berrors.NotFoundError("reg not found") } // Valid revocation request for existing, non-revoked cert, signed with cert @@ -1825,7 +1826,7 @@ func TestBadKeyCSR(t *testing.T) { assertJSONEquals(t, responseWriter.Body.String(), - `{"type":"urn:acme:error:malformed","detail":"Invalid key in certificate request :: Key too small: 512","status":400}`) + `{"type":"urn:acme:error:malformed","detail":"Invalid key in certificate request :: key too small: 512","status":400}`) } // This uses httptest.NewServer because ServeMux.ServeHTTP won't prevent the From acbd9ed3a73828b626724f3f2f32b05dac12f855 Mon Sep 17 00:00:00 2001 From: Roland Bracewell Shoemaker Date: Fri, 24 Mar 2017 11:04:35 -0700 Subject: [PATCH 5/5] Purge both pending and finalized authorizations as well as challenges (#2149) Fixes #2148. Instead of just doing a blanket `DELETE FROM ...` this changes the `expired-authz-purger` to select all of the expired IDs (for both pending and finalized authorizations) then loop over them deleting each and its associated challenges from their respective tables. Local testing indicates the performance of this is not awful but we should do a test run on staging to verify. If it ends up taking way too long to run there the easiest optimization would be to turn the slice of IDs into a channel and run multiple workers looping over the channel deleting stuff instead of just a single one. Makes a few small integration test changes in order to facilitate deleting both pending and finalized authorizations. --- cmd/expired-authz-purger/main.go | 95 ++++++++++++++++++--------- cmd/expired-authz-purger/main_test.go | 41 +++++++++--- test/integration-test.py | 22 ++++--- test/sa_db_users.sql | 2 + 4 files changed, 111 insertions(+), 49 deletions(-) diff --git a/cmd/expired-authz-purger/main.go b/cmd/expired-authz-purger/main.go index aa547b4ad..b83b769a8 100644 --- a/cmd/expired-authz-purger/main.go +++ b/cmd/expired-authz-purger/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "database/sql" "encoding/json" "flag" "fmt" @@ -45,19 +46,45 @@ type expiredAuthzPurger struct { batchSize int64 } -func (p *expiredAuthzPurger) purgeAuthzs(purgeBefore time.Time, yes bool) (int64, error) { - if !yes { - var count int - err := p.db.SelectOne(&count, `SELECT COUNT(1) FROM pendingAuthorizations AS pa WHERE expires <= ?`, purgeBefore) - if err != nil { - return 0, err +func (p *expiredAuthzPurger) purge(table string, yes bool, purgeBefore time.Time) error { + var ids []string + for { + var idBatch []string + var query string + switch table { + case "pendingAuthorizations": + query = "SELECT id FROM pendingAuthorizations WHERE expires <= ? LIMIT ? OFFSET ?" + case "authz": + query = "SELECT id FROM authz WHERE expires <= ? LIMIT ? OFFSET ?" } + _, err := p.db.Select( + &idBatch, + query, + purgeBefore, + p.batchSize, + len(ids), + ) + if err != nil && err != sql.ErrNoRows { + return err + } + if len(idBatch) == 0 { + break + } + ids = append(ids, idBatch...) + } + + if !yes { reader := bufio.NewReader(os.Stdin) for { - fmt.Fprintf(os.Stdout, "\nAbout to purge %d pending authorizations, proceed? [y/N]: ", count) + fmt.Fprintf( + os.Stdout, + "\nAbout to purge %d authorizations from %s and all associated challenges, proceed? [y/N]: ", + len(ids), + table, + ) text, err := reader.ReadString('\n') if err != nil { - return 0, err + return err } text = strings.ToLower(text) if text != "y\n" && text != "n\n" && text != "\n" { @@ -71,33 +98,39 @@ func (p *expiredAuthzPurger) purgeAuthzs(purgeBefore time.Time, yes bool) (int64 } } - rowsAffected := int64(0) - for { - result, err := p.db.Exec(` - DELETE FROM pendingAuthorizations - WHERE expires <= ? - LIMIT ? - `, - purgeBefore, - p.batchSize, - ) + for _, id := range ids { + // Delete challenges + authorization. We delete challenges first and fail out + // if that doesn't succeed so that we don't ever orphan challenges which would + // require a relatively expensive join to then find. + _, err := p.db.Exec("DELETE FROM challenges WHERE authorizationID = ?", id) if err != nil { - return rowsAffected, err + return err } - rows, err := result.RowsAffected() + var query string + switch table { + case "pendingAuthorizations": + query = "DELETE FROM pendingAuthorizations WHERE id = ?" + case "authz": + query = "DELETE FROM authz WHERE id = ?" + } + _, err = p.db.Exec(query, id) if err != nil { - return rowsAffected, err - } - - p.stats.Inc("PendingAuthzDeleted", rows) - rowsAffected += rows - p.log.Info(fmt.Sprintf("Progress: Deleted %d (%d total) expired pending authorizations", rows, rowsAffected)) - - if rows < p.batchSize { - p.log.Info(fmt.Sprintf("Deleted a total of %d expired pending authorizations", rowsAffected)) - return rowsAffected, nil + return err } } + + p.log.Info(fmt.Sprintf("Deleted a total of %d expired authorizations from %s", len(ids), table)) + return nil +} + +func (p *expiredAuthzPurger) purgeAuthzs(purgeBefore time.Time, yes bool) error { + for _, table := range []string{"pendingAuthorizations", "authz"} { + err := p.purge(table, yes, purgeBefore) + if err != nil { + return err + } + } + return nil } func main() { @@ -144,6 +177,6 @@ func main() { os.Exit(1) } purgeBefore := purger.clk.Now().Add(-config.ExpiredAuthzPurger.GracePeriod.Duration) - _, err = purger.purgeAuthzs(purgeBefore, *yes) + err = purger.purgeAuthzs(purgeBefore, *yes) cmd.FailOnError(err, "Failed to purge authorizations") } diff --git a/cmd/expired-authz-purger/main_test.go b/cmd/expired-authz-purger/main_test.go index 47cf60b68..d82040c4b 100644 --- a/cmd/expired-authz-purger/main_test.go +++ b/cmd/expired-authz-purger/main_test.go @@ -34,24 +34,47 @@ func TestPurgeAuthzs(t *testing.T) { p := expiredAuthzPurger{stats, log, fc, dbMap, 1} - rows, err := p.purgeAuthzs(time.Time{}, true) + err = p.purgeAuthzs(time.Time{}, true) test.AssertNotError(t, err, "purgeAuthzs failed") - test.AssertEquals(t, rows, int64(0)) old, new := fc.Now().Add(-time.Hour), fc.Now().Add(time.Hour) reg := satest.CreateWorkingRegistration(t, ssa) - _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{RegistrationID: reg.ID, Expires: &old}) + _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{ + RegistrationID: reg.ID, + Expires: &old, + Challenges: []core.Challenge{{ID: 1}}, + }) test.AssertNotError(t, err, "NewPendingAuthorization failed") - _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{RegistrationID: reg.ID, Expires: &old}) + _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{ + RegistrationID: reg.ID, + Expires: &old, + Challenges: []core.Challenge{{ID: 2}}, + }) test.AssertNotError(t, err, "NewPendingAuthorization failed") - _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{RegistrationID: reg.ID, Expires: &new}) + _, err = ssa.NewPendingAuthorization(context.Background(), core.Authorization{ + RegistrationID: reg.ID, + Expires: &new, + Challenges: []core.Challenge{{ID: 3}}, + }) test.AssertNotError(t, err, "NewPendingAuthorization failed") - rows, err = p.purgeAuthzs(fc.Now(), true) + err = p.purgeAuthzs(fc.Now(), true) test.AssertNotError(t, err, "purgeAuthzs failed") - test.AssertEquals(t, rows, int64(2)) - rows, err = p.purgeAuthzs(fc.Now().Add(time.Hour), true) + count, err := dbMap.SelectInt("SELECT COUNT(1) FROM pendingAuthorizations") + test.AssertNotError(t, err, "dbMap.SelectInt failed") + test.AssertEquals(t, count, int64(1)) + count, err = dbMap.SelectInt("SELECT COUNT(1) FROM challenges") + test.AssertNotError(t, err, "dbMap.SelectInt failed") + test.AssertEquals(t, count, int64(1)) + + err = p.purgeAuthzs(fc.Now().Add(time.Hour), true) test.AssertNotError(t, err, "purgeAuthzs failed") - test.AssertEquals(t, rows, int64(1)) + count, err = dbMap.SelectInt("SELECT COUNT(1) FROM pendingAuthorizations") + test.AssertNotError(t, err, "dbMap.SelectInt failed") + test.AssertEquals(t, count, int64(0)) + count, err = dbMap.SelectInt("SELECT COUNT(1) FROM challenges") + test.AssertNotError(t, err, "dbMap.SelectInt failed") + test.AssertEquals(t, count, int64(0)) + } diff --git a/test/integration-test.py b/test/integration-test.py index 985726578..2def39aa1 100644 --- a/test/integration-test.py +++ b/test/integration-test.py @@ -281,14 +281,13 @@ def get_future_output(cmd, date): return run(cmd, env={'FAKECLOCK': date.strftime("%a %b %d %H:%M:%S UTC %Y")}) def test_expired_authz_purger(): - def expect(target_time, num): - expected_output = '' - if num is not None: - expected_output = 'Deleted a total of %d expired pending authorizations' % num - + def expect(target_time, num, table): out = get_future_output("./bin/expired-authz-purger --config cmd/expired-authz-purger/config.json --yes", target_time) if 'via FAKECLOCK' not in out: raise Exception("expired-authz-purger was not built with `integration` build tag") + if num is None: + return + expected_output = 'Deleted a total of %d expired authorizations from %s' % (num, table) if expected_output not in out: raise Exception("expired-authz-purger did not print '%s'. Output:\n%s" % ( expected_output, out)) @@ -296,7 +295,7 @@ def test_expired_authz_purger(): now = datetime.datetime.utcnow() # Run the purger once to clear out any backlog so we have a clean slate. - expect(now, None) + expect(now, None, "") # Make an authz, but don't attempt its challenges. chisel.make_client().request_domain_challenges("eap-test.com") @@ -304,8 +303,13 @@ def test_expired_authz_purger(): # Run the authz twice: Once immediate, expecting nothing to be purged, and # once as if it were the future, expecting one purged authz. after_grace_period = now + datetime.timedelta(days=+14, minutes=+3) - expect(now, 0) - expect(after_grace_period, 1) + expect(now, 0, "pendingAuthorizations") + expect(after_grace_period, 1, "pendingAuthorizations") + + auth_and_issue([random_domain()]) + after_grace_period = now + datetime.timedelta(days=+67, minutes=+3) + expect(now, 0, "authz") + expect(after_grace_period, 1, "authz") def test_certificates_per_name(): chisel.expect_problem("urn:acme:error:rateLimited", @@ -394,9 +398,9 @@ def main(): def run_chisel(): # TODO(https://github.com/letsencrypt/boulder/issues/2521): Add TLS-SNI test. + test_expired_authz_purger() test_ct_submission() test_gsb_lookups() - test_expired_authz_purger() test_multidomain() test_expiration_mailer() test_caa() diff --git a/test/sa_db_users.sql b/test/sa_db_users.sql index 059b8551c..a75024efb 100644 --- a/test/sa_db_users.sql +++ b/test/sa_db_users.sql @@ -69,6 +69,8 @@ GRANT SELECT ON certificates TO 'cert_checker'@'localhost'; -- Expired authorization purger GRANT SELECT,DELETE ON pendingAuthorizations TO 'purger'@'localhost'; +GRANT SELECT,DELETE ON authz TO 'purger'@'localhost'; +GRANT SELECT,DELETE ON challenges TO 'purger'@'localhost'; -- Test setup and teardown GRANT ALL PRIVILEGES ON * to 'test_setup'@'localhost';