Simplify streaming rows from the database (#7372)

Create a new method on the gorm rows object which runs a small closure
for every row retrieved from the database. Use this new method to remove
20 lines of boilerplate from five different SA methods and rocsp-tool.
This commit is contained in:
Aaron Gable 2024-03-19 08:39:00 -07:00 committed by GitHub
parent 5e68cbe552
commit 7f04092e72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 67 additions and 125 deletions

View File

@ -9,6 +9,9 @@ import (
"time"
"github.com/jmhodges/clock"
"golang.org/x/crypto/ocsp"
"google.golang.org/protobuf/types/known/timestamppb"
capb "github.com/letsencrypt/boulder/ca/proto"
"github.com/letsencrypt/boulder/core"
"github.com/letsencrypt/boulder/db"
@ -16,8 +19,6 @@ import (
"github.com/letsencrypt/boulder/rocsp"
"github.com/letsencrypt/boulder/sa"
"github.com/letsencrypt/boulder/test/ocsp/helper"
"golang.org/x/crypto/ocsp"
"google.golang.org/protobuf/types/known/timestamppb"
)
type client struct {
@ -169,21 +170,15 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
if err != nil {
return -1, fmt.Errorf("scanning certificateStatus: %w", err)
}
defer func() {
rerr := rows.Close()
if rerr != nil {
cl.logger.Infof("closing rows: %s", rerr)
}
}()
var scanned int
var previousID int64
for rows.Next() {
err = rows.ForEach(func(row *sa.CertStatusMetadata) error {
<-rowTicker.C
status, err := rows.Get()
if err != nil {
return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
return fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
}
scanned++
inflightIDs.add(uint64(status.ID))
@ -195,7 +190,12 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
}
output <- status
previousID = status.ID
return nil
})
if err != nil {
return -1, err
}
return previousID, nil
}

View File

@ -142,6 +142,45 @@ type rows[T any] struct {
numCols int
}
// ForEach calls the given function with each model object retrieved by
// repeatedly calling .Get(). It closes the rows object when it hits an error
// or finishes iterating over the rows, so it can only be called once. This is
// the intended way to use the result of QueryContext or QueryFrom; the other
// methods on this type are lower-level and intended for advanced use only.
func (r rows[T]) ForEach(do func(*T) error) (err error) {
defer func() {
// Close the row reader when we exit. Use the named error return to combine
// any error from normal execution with any error from closing.
closeErr := r.Close()
if closeErr != nil && err != nil {
err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
} else if closeErr != nil {
err = closeErr
}
// If closeErr is nil, then just leaving the existing named return alone
// will do the right thing.
}()
for r.Next() {
row, err := r.Get()
if err != nil {
return fmt.Errorf("reading row: %w", err)
}
err = do(row)
if err != nil {
return err
}
}
err = r.Err()
if err != nil {
return fmt.Errorf("iterating over row reader: %w", err)
}
return nil
}
// Next is a wrapper around sql.Rows.Next(). It must be called before every call
// to Get(), including the first.
func (r rows[T]) Next() bool {

View File

@ -95,6 +95,7 @@ type MappedSelector[T any] interface {
// Rows is anything which lets you iterate over the result rows of a SELECT
// query. It is similar to sql.Rows, but generic.
type Rows[T any] interface {
ForEach(func(*T) error) error
Next() bool
Get() (*T, error)
Err() error

View File

@ -1273,9 +1273,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
if err != nil {
return fmt.Errorf("starting db query: %w", err)
}
defer rows.Close()
for rows.Next() {
return rows.ForEach(func(row *incidentSerialModel) error {
// Scan the row into the model. Note: the fields must be passed in the
// same order as the columns returned by the query above.
ism, err := rows.Get()
@ -1296,17 +1295,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
ispb.LastNoticeSent = timestamppb.New(*ism.LastNoticeSent)
}
err = stream.Send(ispb)
if err != nil {
return err
}
}
err = rows.Err()
if err != nil {
return err
}
return nil
return stream.Send(ispb)
})
}
func (ssa *SQLStorageAuthority) SerialsForIncident(req *sapb.SerialsForIncidentRequest, stream sapb.StorageAuthority_SerialsForIncidentServer) error {
@ -1363,43 +1353,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re
return fmt.Errorf("reading db: %w", err)
}
defer func() {
err := rows.Close()
if err != nil {
ssa.log.AuditErrf("closing row reader: %s", err)
}
}()
for rows.Next() {
row, err := rows.Get()
if err != nil {
return fmt.Errorf("reading row: %w", err)
}
return rows.ForEach(func(row *revokedCertModel) error {
// Double-check that the cert wasn't revoked between the time at which we're
// constructing this snapshot CRL and right now. If the cert was revoked
// at-or-after the "atTime", we'll just include it in the next generation
// of CRLs.
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
continue
return nil
}
err = stream.Send(&corepb.CRLEntry{
return stream.Send(&corepb.CRLEntry{
Serial: row.Serial,
Reason: int32(row.RevokedReason),
RevokedAt: timestamppb.New(row.RevokedDate),
})
if err != nil {
return fmt.Errorf("sending crl entry: %w", err)
}
}
err = rows.Err()
if err != nil {
return fmt.Errorf("iterating over row reader: %w", err)
}
return nil
})
}
// getRevokedCertsFromCertificateStatusTable uses the new old certificateStatus
@ -1429,43 +1397,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req
return fmt.Errorf("reading db: %w", err)
}
defer func() {
err := rows.Close()
if err != nil {
ssa.log.AuditErrf("closing row reader: %s", err)
}
}()
for rows.Next() {
row, err := rows.Get()
if err != nil {
return fmt.Errorf("reading row: %w", err)
}
return rows.ForEach(func(row *crlEntryModel) error {
// Double-check that the cert wasn't revoked between the time at which we're
// constructing this snapshot CRL and right now. If the cert was revoked
// at-or-after the "atTime", we'll just include it in the next generation
// of CRLs.
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
continue
return nil
}
err = stream.Send(&corepb.CRLEntry{
return stream.Send(&corepb.CRLEntry{
Serial: row.Serial,
Reason: int32(row.RevokedReason),
RevokedAt: timestamppb.New(row.RevokedDate),
})
if err != nil {
return fmt.Errorf("sending crl entry: %w", err)
}
}
err = rows.Err()
if err != nil {
return fmt.Errorf("iterating over row reader: %w", err)
}
return nil
})
}
// GetMaxExpiration returns the timestamp of the farthest-future notAfter date
@ -1587,31 +1533,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByKey(req *sapb.SPKIHash, stream sap
return fmt.Errorf("reading db: %w", err)
}
defer func() {
err := rows.Close()
if err != nil {
ssa.log.AuditErrf("closing row reader: %s", err)
}
}()
for rows.Next() {
row, err := rows.Get()
if err != nil {
return fmt.Errorf("reading row: %w", err)
}
err = stream.Send(&sapb.Serial{Serial: row.CertSerial})
if err != nil {
return fmt.Errorf("sending serial: %w", err)
}
}
err = rows.Err()
if err != nil {
return fmt.Errorf("iterating over row reader: %w", err)
}
return nil
return rows.ForEach(func(row *keyHashModel) error {
return stream.Send(&sapb.Serial{Serial: row.CertSerial})
})
}
func (ssa *SQLStorageAuthority) GetSerialsByKey(req *sapb.SPKIHash, stream sapb.StorageAuthority_GetSerialsByKeyServer) error {
@ -1639,31 +1563,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByAccount(req *sapb.RegistrationID,
return fmt.Errorf("reading db: %w", err)
}
defer func() {
err := rows.Close()
if err != nil {
ssa.log.AuditErrf("closing row reader: %s", err)
}
}()
for rows.Next() {
row, err := rows.Get()
if err != nil {
return fmt.Errorf("reading row: %w", err)
}
err = stream.Send(&sapb.Serial{Serial: row.Serial})
if err != nil {
return fmt.Errorf("sending serial: %w", err)
}
}
err = rows.Err()
if err != nil {
return fmt.Errorf("iterating over row reader: %w", err)
}
return nil
return rows.ForEach(func(row *recordedSerialModel) error {
return stream.Send(&sapb.Serial{Serial: row.Serial})
})
}
func (ssa *SQLStorageAuthority) GetSerialsByAccount(req *sapb.RegistrationID, stream sapb.StorageAuthority_GetSerialsByAccountServer) error {