Simplify streaming rows from the database (#7372)
Create a new method on the gorm rows object which runs a small closure for every row retrieved from the database. Use this new method to remove 20 lines of boilerplate from five different SA methods and rocsp-tool.
This commit is contained in:
parent
5e68cbe552
commit
7f04092e72
|
@ -9,6 +9,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jmhodges/clock"
|
"github.com/jmhodges/clock"
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
capb "github.com/letsencrypt/boulder/ca/proto"
|
capb "github.com/letsencrypt/boulder/ca/proto"
|
||||||
"github.com/letsencrypt/boulder/core"
|
"github.com/letsencrypt/boulder/core"
|
||||||
"github.com/letsencrypt/boulder/db"
|
"github.com/letsencrypt/boulder/db"
|
||||||
|
@ -16,8 +19,6 @@ import (
|
||||||
"github.com/letsencrypt/boulder/rocsp"
|
"github.com/letsencrypt/boulder/rocsp"
|
||||||
"github.com/letsencrypt/boulder/sa"
|
"github.com/letsencrypt/boulder/sa"
|
||||||
"github.com/letsencrypt/boulder/test/ocsp/helper"
|
"github.com/letsencrypt/boulder/test/ocsp/helper"
|
||||||
"golang.org/x/crypto/ocsp"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
|
@ -169,21 +170,15 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, fmt.Errorf("scanning certificateStatus: %w", err)
|
return -1, fmt.Errorf("scanning certificateStatus: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
rerr := rows.Close()
|
|
||||||
if rerr != nil {
|
|
||||||
cl.logger.Infof("closing rows: %s", rerr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var scanned int
|
var scanned int
|
||||||
var previousID int64
|
var previousID int64
|
||||||
for rows.Next() {
|
err = rows.ForEach(func(row *sa.CertStatusMetadata) error {
|
||||||
<-rowTicker.C
|
<-rowTicker.C
|
||||||
|
|
||||||
status, err := rows.Get()
|
status, err := rows.Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
|
return fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
|
||||||
}
|
}
|
||||||
scanned++
|
scanned++
|
||||||
inflightIDs.add(uint64(status.ID))
|
inflightIDs.add(uint64(status.ID))
|
||||||
|
@ -195,7 +190,12 @@ func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequenc
|
||||||
}
|
}
|
||||||
output <- status
|
output <- status
|
||||||
previousID = status.ID
|
previousID = status.ID
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return previousID, nil
|
return previousID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
39
db/gorm.go
39
db/gorm.go
|
@ -142,6 +142,45 @@ type rows[T any] struct {
|
||||||
numCols int
|
numCols int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForEach calls the given function with each model object retrieved by
|
||||||
|
// repeatedly calling .Get(). It closes the rows object when it hits an error
|
||||||
|
// or finishes iterating over the rows, so it can only be called once. This is
|
||||||
|
// the intended way to use the result of QueryContext or QueryFrom; the other
|
||||||
|
// methods on this type are lower-level and intended for advanced use only.
|
||||||
|
func (r rows[T]) ForEach(do func(*T) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
// Close the row reader when we exit. Use the named error return to combine
|
||||||
|
// any error from normal execution with any error from closing.
|
||||||
|
closeErr := r.Close()
|
||||||
|
if closeErr != nil && err != nil {
|
||||||
|
err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
|
||||||
|
} else if closeErr != nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
// If closeErr is nil, then just leaving the existing named return alone
|
||||||
|
// will do the right thing.
|
||||||
|
}()
|
||||||
|
|
||||||
|
for r.Next() {
|
||||||
|
row, err := r.Get()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = do(row)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.Err()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("iterating over row reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Next is a wrapper around sql.Rows.Next(). It must be called before every call
|
// Next is a wrapper around sql.Rows.Next(). It must be called before every call
|
||||||
// to Get(), including the first.
|
// to Get(), including the first.
|
||||||
func (r rows[T]) Next() bool {
|
func (r rows[T]) Next() bool {
|
||||||
|
|
|
@ -95,6 +95,7 @@ type MappedSelector[T any] interface {
|
||||||
// Rows is anything which lets you iterate over the result rows of a SELECT
|
// Rows is anything which lets you iterate over the result rows of a SELECT
|
||||||
// query. It is similar to sql.Rows, but generic.
|
// query. It is similar to sql.Rows, but generic.
|
||||||
type Rows[T any] interface {
|
type Rows[T any] interface {
|
||||||
|
ForEach(func(*T) error) error
|
||||||
Next() bool
|
Next() bool
|
||||||
Get() (*T, error)
|
Get() (*T, error)
|
||||||
Err() error
|
Err() error
|
||||||
|
|
132
sa/saro.go
132
sa/saro.go
|
@ -1273,9 +1273,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting db query: %w", err)
|
return fmt.Errorf("starting db query: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
return rows.ForEach(func(row *incidentSerialModel) error {
|
||||||
// Scan the row into the model. Note: the fields must be passed in the
|
// Scan the row into the model. Note: the fields must be passed in the
|
||||||
// same order as the columns returned by the query above.
|
// same order as the columns returned by the query above.
|
||||||
ism, err := rows.Get()
|
ism, err := rows.Get()
|
||||||
|
@ -1296,17 +1295,8 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
|
||||||
ispb.LastNoticeSent = timestamppb.New(*ism.LastNoticeSent)
|
ispb.LastNoticeSent = timestamppb.New(*ism.LastNoticeSent)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = stream.Send(ispb)
|
return stream.Send(ispb)
|
||||||
if err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ssa *SQLStorageAuthority) SerialsForIncident(req *sapb.SerialsForIncidentRequest, stream sapb.StorageAuthority_SerialsForIncidentServer) error {
|
func (ssa *SQLStorageAuthority) SerialsForIncident(req *sapb.SerialsForIncidentRequest, stream sapb.StorageAuthority_SerialsForIncidentServer) error {
|
||||||
|
@ -1363,43 +1353,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re
|
||||||
return fmt.Errorf("reading db: %w", err)
|
return fmt.Errorf("reading db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
return rows.ForEach(func(row *revokedCertModel) error {
|
||||||
err := rows.Close()
|
|
||||||
if err != nil {
|
|
||||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
row, err := rows.Get()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Double-check that the cert wasn't revoked between the time at which we're
|
// Double-check that the cert wasn't revoked between the time at which we're
|
||||||
// constructing this snapshot CRL and right now. If the cert was revoked
|
// constructing this snapshot CRL and right now. If the cert was revoked
|
||||||
// at-or-after the "atTime", we'll just include it in the next generation
|
// at-or-after the "atTime", we'll just include it in the next generation
|
||||||
// of CRLs.
|
// of CRLs.
|
||||||
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = stream.Send(&corepb.CRLEntry{
|
return stream.Send(&corepb.CRLEntry{
|
||||||
Serial: row.Serial,
|
Serial: row.Serial,
|
||||||
Reason: int32(row.RevokedReason),
|
Reason: int32(row.RevokedReason),
|
||||||
RevokedAt: timestamppb.New(row.RevokedDate),
|
RevokedAt: timestamppb.New(row.RevokedDate),
|
||||||
})
|
})
|
||||||
if err != nil {
|
})
|
||||||
return fmt.Errorf("sending crl entry: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iterating over row reader: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRevokedCertsFromCertificateStatusTable uses the new old certificateStatus
|
// getRevokedCertsFromCertificateStatusTable uses the new old certificateStatus
|
||||||
|
@ -1429,43 +1397,21 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req
|
||||||
return fmt.Errorf("reading db: %w", err)
|
return fmt.Errorf("reading db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
return rows.ForEach(func(row *crlEntryModel) error {
|
||||||
err := rows.Close()
|
|
||||||
if err != nil {
|
|
||||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
row, err := rows.Get()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Double-check that the cert wasn't revoked between the time at which we're
|
// Double-check that the cert wasn't revoked between the time at which we're
|
||||||
// constructing this snapshot CRL and right now. If the cert was revoked
|
// constructing this snapshot CRL and right now. If the cert was revoked
|
||||||
// at-or-after the "atTime", we'll just include it in the next generation
|
// at-or-after the "atTime", we'll just include it in the next generation
|
||||||
// of CRLs.
|
// of CRLs.
|
||||||
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
if row.RevokedDate.After(atTime) || row.RevokedDate.Equal(atTime) {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = stream.Send(&corepb.CRLEntry{
|
return stream.Send(&corepb.CRLEntry{
|
||||||
Serial: row.Serial,
|
Serial: row.Serial,
|
||||||
Reason: int32(row.RevokedReason),
|
Reason: int32(row.RevokedReason),
|
||||||
RevokedAt: timestamppb.New(row.RevokedDate),
|
RevokedAt: timestamppb.New(row.RevokedDate),
|
||||||
})
|
})
|
||||||
if err != nil {
|
})
|
||||||
return fmt.Errorf("sending crl entry: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iterating over row reader: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMaxExpiration returns the timestamp of the farthest-future notAfter date
|
// GetMaxExpiration returns the timestamp of the farthest-future notAfter date
|
||||||
|
@ -1587,31 +1533,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByKey(req *sapb.SPKIHash, stream sap
|
||||||
return fmt.Errorf("reading db: %w", err)
|
return fmt.Errorf("reading db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
return rows.ForEach(func(row *keyHashModel) error {
|
||||||
err := rows.Close()
|
return stream.Send(&sapb.Serial{Serial: row.CertSerial})
|
||||||
if err != nil {
|
})
|
||||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
row, err := rows.Get()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = stream.Send(&sapb.Serial{Serial: row.CertSerial})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sending serial: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iterating over row reader: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ssa *SQLStorageAuthority) GetSerialsByKey(req *sapb.SPKIHash, stream sapb.StorageAuthority_GetSerialsByKeyServer) error {
|
func (ssa *SQLStorageAuthority) GetSerialsByKey(req *sapb.SPKIHash, stream sapb.StorageAuthority_GetSerialsByKeyServer) error {
|
||||||
|
@ -1639,31 +1563,9 @@ func (ssa *SQLStorageAuthorityRO) GetSerialsByAccount(req *sapb.RegistrationID,
|
||||||
return fmt.Errorf("reading db: %w", err)
|
return fmt.Errorf("reading db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
return rows.ForEach(func(row *recordedSerialModel) error {
|
||||||
err := rows.Close()
|
return stream.Send(&sapb.Serial{Serial: row.Serial})
|
||||||
if err != nil {
|
})
|
||||||
ssa.log.AuditErrf("closing row reader: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
row, err := rows.Get()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = stream.Send(&sapb.Serial{Serial: row.Serial})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sending serial: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iterating over row reader: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ssa *SQLStorageAuthority) GetSerialsByAccount(req *sapb.RegistrationID, stream sapb.StorageAuthority_GetSerialsByAccountServer) error {
|
func (ssa *SQLStorageAuthority) GetSerialsByAccount(req *sapb.RegistrationID, stream sapb.StorageAuthority_GetSerialsByAccountServer) error {
|
||||||
|
|
Loading…
Reference in New Issue