From 7f04092e7297b4703e2dc5cf2fc9657a3a873670 Mon Sep 17 00:00:00 2001 From: Aaron Gable Date: Tue, 19 Mar 2024 08:39:00 -0700 Subject: [PATCH] Simplify streaming rows from the database (#7372) Create a new method on the gorm rows object which runs a small closure for every row retrieved from the database. Use this new method to remove 20 lines of boilerplate from five different SA methods and rocsp-tool. --- cmd/rocsp-tool/client.go | 20 +++--- db/gorm.go | 39 ++++++++++++ db/interfaces.go | 1 + sa/saro.go | 132 +++++---------------------------------- 4 files changed, 67 insertions(+), 125 deletions(-) diff --git a/cmd/rocsp-tool/client.go b/cmd/rocsp-tool/client.go index c7b1178ac..815973c7f 100644 --- a/cmd/rocsp-tool/client.go +++ b/cmd/rocsp-tool/client.go @@ -9,6 +9,9 @@ import ( "time" "github.com/jmhodges/clock" + "golang.org/x/crypto/ocsp" + "google.golang.org/protobuf/types/known/timestamppb" + capb "github.com/letsencrypt/boulder/ca/proto" "github.com/letsencrypt/boulder/core" "github.com/letsencrypt/boulder/db" @@ -16,8 +19,6 @@ import ( "github.com/letsencrypt/boulder/rocsp" "github.com/letsencrypt/boulder/sa" "github.com/letsencrypt/boulder/test/ocsp/helper" - "golang.org/x/crypto/ocsp" - "google.golang.org/protobuf/types/known/timestamppb" ) type client struct { @@ -169,21 +170,15 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc if err != nil { return -1, fmt.Errorf("scanning certificateStatus: %w", err) } - defer func() { - rerr := rows.Close() - if rerr != nil { - cl.logger.Infof("closing rows: %s", rerr) - } - }() var scanned int var previousID int64 - for rows.Next() { + err = rows.ForEach(func(row *sa.CertStatusMetadata) error { <-rowTicker.C status, err := rows.Get() if err != nil { - return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err) + return fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err) } scanned++ inflightIDs.add(uint64(status.ID)) @@ -195,7 +190,12 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc } output <- status previousID = status.ID + return nil + }) + if err != nil { + return -1, err } + return previousID, nil } diff --git a/db/gorm.go b/db/gorm.go index 112eddcff..477202faf 100644 --- a/db/gorm.go +++ b/db/gorm.go @@ -142,6 +142,45 @@ type rows[T any] struct { numCols int } +// ForEach calls the given function with each model object retrieved by +// repeatedly calling .Get(). It closes the rows object when it hits an error +// or finishes iterating over the rows, so it can only be called once. This is +// the intended way to use the result of QueryContext or QueryFrom; the other +// methods on this type are lower-level and intended for advanced use only. +func (r rows[T]) ForEach(do func(*T) error) (err error) { + defer func() { + // Close the row reader when we exit. Use the named error return to combine + // any error from normal execution with any error from closing. + closeErr := r.Close() + if closeErr != nil && err != nil { + err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr) + } else if closeErr != nil { + err = closeErr + } + // If closeErr is nil, then just leaving the existing named return alone + // will do the right thing. + }() + + for r.Next() { + row, err := r.Get() + if err != nil { + return fmt.Errorf("reading row: %w", err) + } + + err = do(row) + if err != nil { + return err + } + } + + err = r.Err() + if err != nil { + return fmt.Errorf("iterating over row reader: %w", err) + } + + return nil +} + // Next is a wrapper around sql.Rows.Next(). It must be called before every call // to Get(), including the first. func (r rows[T]) Next() bool { diff --git a/db/interfaces.go b/db/interfaces.go index d0b555ae7..f08e25888 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -95,6 +95,7 @@ type MappedSelector[T any] interface { // Rows is anything which lets you iterate over the result rows of a SELECT // query. It is similar to sql.Rows, but generic. type Rows[T any] interface { + ForEach(func(*T) error) error Next() bool Get() (*T, error) Err() error diff --git a/sa/saro.go b/sa/saro.go index 59de85b65..4ab2d5813 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1273,9 +1273,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden if err != nil { return fmt.Errorf("starting db query: %w", err) } - defer rows.Close() - for rows.Next() { + return rows.ForEach(func(row *incidentSerialModel) error { // Scan the row into the model. Note: the fields must be passed in the // same order as the columns returned by the query above. ism, err := rows.Get() @@ -1296,17 +1295,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden ispb.LastNoticeSent = timestamppb.New(*ism.LastNoticeSent) } - err = stream.Send(ispb) - if err != nil { - return err - } - } - - err = rows.Err() - if err != nil { - return err - } - return nil + return stream.Send(ispb) + }) } func (ssa *SQLStorageAuthority) SerialsForIncident(req *sapb.SerialsForIncidentRequest, stream sapb.StorageAuthority_SerialsForIncidentServer) error { @@ -1363,43 +1353,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re return fmt.Errorf("reading db: %w", err) } - defer func() { - err := rows.Close() - if err != nil { - ssa.log.AuditErrf("closing row reader: %s", err) - } - }() - - for rows.Next() { - row, err := rows.Get() - if err != nil { - return fmt.Errorf("reading row: %w", err) - } - + return rows.ForEach(func(row *revokedCertModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation // of CRLs. if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) { - continue + return nil } - err = stream.Send(&corepb.CRLEntry{ + return stream.Send(&corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), }) - if err != nil { - return fmt.Errorf("sending crl entry: %w", err) - } - } - - err = rows.Err() - if err != nil { - return fmt.Errorf("iterating over row reader: %w", err) - } - - return nil + }) } // getRevokedCertsFromCertificateStatusTable uses the new old certificateStatus @@ -1429,43 +1397,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req return fmt.Errorf("reading db: %w", err) } - defer func() { - err := rows.Close() - if err != nil { - ssa.log.AuditErrf("closing row reader: %s", err) - } - }() - - for rows.Next() { - row, err := rows.Get() - if err != nil { - return fmt.Errorf("reading row: %w", err) - } - + return rows.ForEach(func(row *crlEntryModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation // of CRLs. if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) { - continue + return nil } - err = stream.Send(&corepb.CRLEntry{ + return stream.Send(&corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), }) - if err != nil { - return fmt.Errorf("sending crl entry: %w", err) - } - } - - err = rows.Err() - if err != nil { - return fmt.Errorf("iterating over row reader: %w", err) - } - - return nil + }) } // GetMaxExpiration returns the timestamp of the farthest-future notAfter date @@ -1587,31 +1533,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByKey(req *sapb.SPKIHash, stream sap return fmt.Errorf("reading db: %w", err) } - defer func() { - err := rows.Close() - if err != nil { - ssa.log.AuditErrf("closing row reader: %s", err) - } - }() - - for rows.Next() { - row, err := rows.Get() - if err != nil { - return fmt.Errorf("reading row: %w", err) - } - - err = stream.Send(&sapb.Serial{Serial: row.CertSerial}) - if err != nil { - return fmt.Errorf("sending serial: %w", err) - } - } - - err = rows.Err() - if err != nil { - return fmt.Errorf("iterating over row reader: %w", err) - } - - return nil + return rows.ForEach(func(row *keyHashModel) error { + return stream.Send(&sapb.Serial{Serial: row.CertSerial}) + }) } func (ssa *SQLStorageAuthority) GetSerialsByKey(req *sapb.SPKIHash, stream sapb.StorageAuthority_GetSerialsByKeyServer) error { @@ -1639,31 +1563,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByAccount(req *sapb.RegistrationID, return fmt.Errorf("reading db: %w", err) } - defer func() { - err := rows.Close() - if err != nil { - ssa.log.AuditErrf("closing row reader: %s", err) - } - }() - - for rows.Next() { - row, err := rows.Get() - if err != nil { - return fmt.Errorf("reading row: %w", err) - } - - err = stream.Send(&sapb.Serial{Serial: row.Serial}) - if err != nil { - return fmt.Errorf("sending serial: %w", err) - } - } - - err = rows.Err() - if err != nil { - return fmt.Errorf("iterating over row reader: %w", err) - } - - return nil + return rows.ForEach(func(row *recordedSerialModel) error { + return stream.Send(&sapb.Serial{Serial: row.Serial}) + }) } func (ssa *SQLStorageAuthority) GetSerialsByAccount(req *sapb.RegistrationID, stream sapb.StorageAuthority_GetSerialsByAccountServer) error {