Simplify streaming rows from the database (#7372)
Create a new method on the gorm rows object which runs a small closure for every row retrieved from the database. Use this new method to remove 20 lines of boilerplate from five different SA methods and rocsp-tool.
This commit is contained in:
parent
5e68cbe552
commit
7f04092e72
|
@ -9,6 +9,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jmhodges/clock"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
capb "github.com/letsencrypt/boulder/ca/proto"
|
||||
"github.com/letsencrypt/boulder/core"
|
||||
"github.com/letsencrypt/boulder/db"
|
||||
|
@ -16,8 +19,6 @@ import (
|
|||
"github.com/letsencrypt/boulder/rocsp"
|
||||
"github.com/letsencrypt/boulder/sa"
|
||||
"github.com/letsencrypt/boulder/test/ocsp/helper"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
|
@ -169,21 +170,15 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
|
|||
if err != nil {
|
||||
return -1, fmt.Errorf("scanning certificateStatus: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
rerr := rows.Close()
|
||||
if rerr != nil {
|
||||
cl.logger.Infof("closing rows: %s", rerr)
|
||||
}
|
||||
}()
|
||||
|
||||
var scanned int
|
||||
var previousID int64
|
||||
for rows.Next() {
|
||||
err = rows.ForEach(func(row *sa.CertStatusMetadata) error {
|
||||
<-rowTicker.C
|
||||
|
||||
status, err := rows.Get()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
|
||||
return fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
|
||||
}
|
||||
scanned++
|
||||
inflightIDs.add(uint64(status.ID))
|
||||
|
@ -195,7 +190,12 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
|
|||
}
|
||||
output <- status
|
||||
previousID = status.ID
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return previousID, nil
|
||||
}
|
||||
|
||||
|
|
39
db/gorm.go
39
db/gorm.go
|
@ -142,6 +142,45 @@ type rows[T any] struct {
|
|||
numCols int
|
||||
}
|
||||
|
||||
// ForEach calls the given function with each model object retrieved by
|
||||
// repeatedly calling .Get(). It closes the rows object when it hits an error
|
||||
// or finishes iterating over the rows, so it can only be called once. This is
|
||||
// the intended way to use the result of QueryContext or QueryFrom; the other
|
||||
// methods on this type are lower-level and intended for advanced use only.
|
||||
func (r rows[T]) ForEach(do func(*T) error) (err error) {
|
||||
defer func() {
|
||||
// Close the row reader when we exit. Use the named error return to combine
|
||||
// any error from normal execution with any error from closing.
|
||||
closeErr := r.Close()
|
||||
if closeErr != nil && err != nil {
|
||||
err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
|
||||
} else if closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
// If closeErr is nil, then just leaving the existing named return alone
|
||||
// will do the right thing.
|
||||
}()
|
||||
|
||||
for r.Next() {
|
||||
row, err := r.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading row: %w", err)
|
||||
}
|
||||
|
||||
err = do(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iterating over row reader: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Next is a wrapper around sql.Rows.Next(). It must be called before every call
|
||||
// to Get(), including the first.
|
||||
func (r rows[T]) Next() bool {
|
||||
|
|
|
@ -95,6 +95,7 @@ type MappedSelector[T any] interface {
|
|||
// Rows is anything which lets you iterate over the result rows of a SELECT
|
||||
// query. It is similar to sql.Rows, but generic.
|
||||
type Rows[T any] interface {
|
||||
ForEach(func(*T) error) error
|
||||
Next() bool
|
||||
Get() (*T, error)
|
||||
Err() error
|
||||
|
|
132
sa/saro.go
132
sa/saro.go
|
@ -1273,9 +1273,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
|
|||
if err != nil {
|
||||
return fmt.Errorf("starting db query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
return rows.ForEach(func(row *incidentSerialModel) error {
|
||||
// Scan the row into the model. Note: the fields must be passed in the
|
||||
// same order as the columns returned by the query above.
|
||||
ism, err := rows.Get()
|
||||
|
@ -1296,17 +1295,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
|
|||
ispb.LastNoticeSent = timestamppb.New(*ism.LastNoticeSent)
|
||||
}
|
||||
|
||||
err = stream.Send(ispb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return stream.Send(ispb)
|
||||
})
|
||||
}
|
||||
|
||||
func (ssa *SQLStorageAuthority) SerialsForIncident(req *sapb.SerialsForIncidentRequest, stream sapb.StorageAuthority_SerialsForIncidentServer) error {
|
||||
|
@ -1363,43 +1353,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re
|
|||
return fmt.Errorf("reading db: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
row, err := rows.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading row: %w", err)
|
||||
}
|
||||
|
||||
return rows.ForEach(func(row *revokedCertModel) error {
|
||||
// Double-check that the cert wasn't revoked between the time at which we're
|
||||
// constructing this snapshot CRL and right now. If the cert was revoked
|
||||
// at-or-after the "atTime", we'll just include it in the next generation
|
||||
// of CRLs.
|
||||
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
err = stream.Send(&corepb.CRLEntry{
|
||||
return stream.Send(&corepb.CRLEntry{
|
||||
Serial: row.Serial,
|
||||
Reason: int32(row.RevokedReason),
|
||||
RevokedAt: timestamppb.New(row.RevokedDate),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending crl entry: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iterating over row reader: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// getRevokedCertsFromCertificateStatusTable uses the new old certificateStatus
|
||||
|
@ -1429,43 +1397,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req
|
|||
return fmt.Errorf("reading db: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
row, err := rows.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading row: %w", err)
|
||||
}
|
||||
|
||||
return rows.ForEach(func(row *crlEntryModel) error {
|
||||
// Double-check that the cert wasn't revoked between the time at which we're
|
||||
// constructing this snapshot CRL and right now. If the cert was revoked
|
||||
// at-or-after the "atTime", we'll just include it in the next generation
|
||||
// of CRLs.
|
||||
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
err = stream.Send(&corepb.CRLEntry{
|
||||
return stream.Send(&corepb.CRLEntry{
|
||||
Serial: row.Serial,
|
||||
Reason: int32(row.RevokedReason),
|
||||
RevokedAt: timestamppb.New(row.RevokedDate),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending crl entry: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iterating over row reader: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetMaxExpiration returns the timestamp of the farthest-future notAfter date
|
||||
|
@ -1587,31 +1533,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByKey(req *sapb.SPKIHash, stream sap
|
|||
return fmt.Errorf("reading db: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
row, err := rows.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading row: %w", err)
|
||||
}
|
||||
|
||||
err = stream.Send(&sapb.Serial{Serial: row.CertSerial})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending serial: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iterating over row reader: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return rows.ForEach(func(row *keyHashModel) error {
|
||||
return stream.Send(&sapb.Serial{Serial: row.CertSerial})
|
||||
})
|
||||
}
|
||||
|
||||
func (ssa *SQLStorageAuthority) GetSerialsByKey(req *sapb.SPKIHash, stream sapb.StorageAuthority_GetSerialsByKeyServer) error {
|
||||
|
@ -1639,31 +1563,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByAccount(req *sapb.RegistrationID,
|
|||
return fmt.Errorf("reading db: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
row, err := rows.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading row: %w", err)
|
||||
}
|
||||
|
||||
err = stream.Send(&sapb.Serial{Serial: row.Serial})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending serial: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iterating over row reader: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return rows.ForEach(func(row *recordedSerialModel) error {
|
||||
return stream.Send(&sapb.Serial{Serial: row.Serial})
|
||||
})
|
||||
}
|
||||
|
||||
func (ssa *SQLStorageAuthority) GetSerialsByAccount(req *sapb.RegistrationID, stream sapb.StorageAuthority_GetSerialsByAccountServer) error {
|
||||
|
|
Loading…
Reference in New Issue