From 46306b07b9e6ecc0ddf770a4498208357bc49468 Mon Sep 17 00:00:00 2001 From: Daniel McCarney Date: Wed, 19 Oct 2016 13:44:37 -0400 Subject: [PATCH] Adds "SelectFoo" functions for each DB type. (#2259) In #2178 we moved to explicit `SELECT` statements using a set of `const` fields for each type to support db migrations and forward compatibility. This commit removes the temptation to interpolate queries by providing convenience `SelectFoo` functions for each type allowing the caller to provide the `WHERE` clause and arguments. Resolves #2214. --- cmd/admin-revoker/main.go | 3 +- cmd/cert-checker/main.go | 7 +- cmd/ocsp-updater/main.go | 61 +++++++------ sa/authz.go | 13 ++- sa/model.go | 183 +++++++++++++++++++++++++++++++++++--- sa/sa.go | 138 ++++++++-------------------- 6 files changed, 251 insertions(+), 154 deletions(-) diff --git a/cmd/admin-revoker/main.go b/cmd/admin-revoker/main.go index 34ceb8ab5..4c3ce3de2 100644 --- a/cmd/admin-revoker/main.go +++ b/cmd/admin-revoker/main.go @@ -80,8 +80,7 @@ func revokeBySerial(ctx context.Context, serial string, reasonCode revocation.Re panic(fmt.Sprintf("Invalid reason code: %d", reasonCode)) } - var certObj core.Certificate - err = tx.SelectOne(&certObj, fmt.Sprintf("SELECT %s FROM certificates WHERE serial = ?", sa.CertificateFields), serial) + certObj, err := sa.SelectCertificate(tx, "WHERE serial = ?", serial) if err == sql.ErrNoRows { return core.NotFoundError(fmt.Sprintf("No certificate found for %s", serial)) } diff --git a/cmd/cert-checker/main.go b/cmd/cert-checker/main.go index 1f3c43dd2..02b5861be 100644 --- a/cmd/cert-checker/main.go +++ b/cmd/cert-checker/main.go @@ -116,10 +116,9 @@ func (c *certChecker) getCerts(unexpiredOnly bool) error { args["limit"] = batchSize args["lastSerial"] = "" for offset := 0; offset < count; { - var certs []core.Certificate - _, err = c.dbMap.Select( - &certs, - fmt.Sprintf("SELECT %s FROM certificates WHERE issued >= :issued AND expires >= :now AND serial > :lastSerial LIMIT :limit", sa.CertificateFields), + certs, err := sa.SelectCertificates( + c.dbMap, + "WHERE issued >= :issued AND expires >= :now AND serial > :lastSerial LIMIT :limit", args, ) if err != nil { diff --git a/cmd/ocsp-updater/main.go b/cmd/ocsp-updater/main.go index 48fcfb6a6..200749cb0 100644 --- a/cmd/ocsp-updater/main.go +++ b/cmd/ocsp-updater/main.go @@ -232,22 +232,22 @@ func (updater *OCSPUpdater) findStaleOCSPResponses(oldestLastUpdatedTime time.Ti } func (updater *OCSPUpdater) getCertificatesWithMissingResponses(batchSize int) ([]core.CertificateStatus, error) { + const query = "WHERE ocspLastUpdated = 0 LIMIT ?" var statuses []core.CertificateStatus - var fields string + var err error if features.Enabled(features.CertStatusOptimizationsMigrated) { - fields = sa.CertificateStatusFieldsv2 + statuses, err = sa.SelectCertificateStatusesv2( + updater.dbMap, + query, + batchSize, + ) } else { - fields = sa.CertificateStatusFields + statuses, err = sa.SelectCertificateStatuses( + updater.dbMap, + query, + batchSize, + ) } - _, err := updater.dbMap.Select( - &statuses, - fmt.Sprintf(`SELECT %s FROM certificateStatus - WHERE ocspLastUpdated = 0 - LIMIT :limit`, fields), - map[string]interface{}{ - "limit": batchSize, - }, - ) if err == sql.ErrNoRows { return statuses, nil } @@ -260,11 +260,10 @@ type responseMeta struct { } func (updater *OCSPUpdater) generateResponse(ctx context.Context, status core.CertificateStatus) (*core.CertificateStatus, error) { - var cert core.Certificate - err := updater.dbMap.SelectOne( - &cert, - fmt.Sprintf("SELECT %s FROM certificates WHERE serial = :serial", sa.CertificateFields), - map[string]interface{}{"serial": status.Serial}, + cert, err := sa.SelectCertificate( + updater.dbMap, + "WHERE serial = ?", + status.Serial, ) if err != nil { return nil, err @@ -362,24 +361,24 @@ func (updater *OCSPUpdater) newCertificateTick(ctx context.Context, batchSize in } func (updater *OCSPUpdater) findRevokedCertificatesToUpdate(batchSize int) ([]core.CertificateStatus, error) { + const query = "WHERE status = ? AND ocspLastUpdated <= revokedDate LIMIT ?" var statuses []core.CertificateStatus - var fields string + var err error if features.Enabled(features.CertStatusOptimizationsMigrated) { - fields = sa.CertificateStatusFieldsv2 + statuses, err = sa.SelectCertificateStatusesv2( + updater.dbMap, + query, + string(core.OCSPStatusRevoked), + batchSize, + ) } else { - fields = sa.CertificateStatusFields + statuses, err = sa.SelectCertificateStatuses( + updater.dbMap, + query, + string(core.OCSPStatusRevoked), + batchSize, + ) } - _, err := updater.dbMap.Select( - &statuses, - fmt.Sprintf(`SELECT %s FROM certificateStatus - WHERE status = :revoked - AND ocspLastUpdated <= revokedDate - LIMIT :limit`, fields), - map[string]interface{}{ - "revoked": string(core.OCSPStatusRevoked), - "limit": batchSize, - }, - ) return statuses, err } diff --git a/sa/authz.go b/sa/authz.go index 577d620e5..0b5b1f999 100644 --- a/sa/authz.go +++ b/sa/authz.go @@ -63,13 +63,16 @@ func authzIdExists(tx *gorp.Transaction, id string) bool { * [0] - https://github.com/letsencrypt/boulder/issues/2162 */ func getAuthz(tx *gorp.Transaction, id string) (core.Authorization, string, error) { + const query = "WHERE ID = ?" var authz core.Authorization var table string // First try to find a row from the `pendingAuthorizations` table with // a `pendingauthzModel{}`. - var pa pendingauthzModel - err := tx.SelectOne(&pa, fmt.Sprintf("SELECT %s FROM pendingAuthorizations WHERE id = ?", pendingAuthzFields), id) + pa, err := selectPendingAuthz( + tx, + query, + id) // If there was an error other than "no rows", abort if err != nil && err != sql.ErrNoRows { err = Rollback(tx, err) @@ -83,8 +86,10 @@ func getAuthz(tx *gorp.Transaction, id string) (core.Authorization, string, erro // But if the err was ErrNoRows, then we need to try looking in the `authz` // table using a `authzModel` since there wasn't a `pendingAuthorization` // row - var fa authzModel - err = tx.SelectOne(&fa, fmt.Sprintf("SELECT %s FROM authz WHERE id = ?", authzFields), id) + fa, err := selectAuthz( + tx, + query, + id) // If there *still* was no rows, we're out of options. Nothing found if err == sql.ErrNoRows { err = fmt.Errorf("No pendingAuthorization or authz with ID %s", id) diff --git a/sa/model.go b/sa/model.go index ca2e709c1..0538fae66 100644 --- a/sa/model.go +++ b/sa/model.go @@ -12,23 +12,156 @@ import ( "github.com/letsencrypt/boulder/core" "github.com/letsencrypt/boulder/features" "github.com/letsencrypt/boulder/probs" + "github.com/letsencrypt/boulder/revocation" ) -const ( - regV1Fields string = "id, jwk, jwk_sha256, contact, agreement, initialIP, createdAt, LockCol" - regV2Fields string = regV1Fields + ", status" - pendingAuthzFields string = "id, identifier, registrationID, status, expires, combinations, LockCol" - authzFields string = "id, identifier, registrationID, status, expires, combinations" - sctFields string = "id, sctVersion, logID, timestamp, extensions, signature, certificateSerial, LockCol" +// A `dbOneSelector` is anything that provides a `SelectOne` function. +type dbOneSelector interface { + SelectOne(interface{}, string, ...interface{}) error +} - // CertificateFields and CertificateStatusFields are also used by cert-checker and ocsp-updater - CertificateFields string = "registrationID, serial, digest, der, issued, expires" - CertificateStatusFields string = "serial, subscriberApproved, status, ocspLastUpdated, revokedDate, revokedReason, lastExpirationNagSent, ocspResponse, LockCol" +// A `dbSelector` is anything that provides a `Select` function. +type dbSelector interface { + Select(interface{}, string, ...interface{}) ([]interface{}, error) +} - // CertificateStatusFieldsv2 is used when the CertStatusOptimizationsMigrated - // feature flag is enabled and includes "notAfter" and "isExpired" fields - CertificateStatusFieldsv2 string = CertificateStatusFields + ", notAfter, isExpired" -) +const regFields = "id, jwk, jwk_sha256, contact, agreement, initialIP, createdAt, LockCol" +const regFieldsv2 = regFields + ", status" + +// selectRegistration selects all fields of one registration model +func selectRegistration(s dbOneSelector, q string, args ...interface{}) (*regModelv1, error) { + var model regModelv1 + err := s.SelectOne( + &model, + "SELECT "+regFields+" FROM registrations "+q, + args..., + ) + return &model, err +} + +// selectRegistrationv2 selects all fields (including v2 migrated fields) of one registration model +func selectRegistrationv2(s dbOneSelector, q string, args ...interface{}) (*regModelv2, error) { + var model regModelv2 + err := s.SelectOne( + &model, + "SELECT "+regFieldsv2+" FROM registrations "+q, args...) + return &model, err +} + +// selectPendingAuthz selects all fields of one pending authorization model +func selectPendingAuthz(s dbOneSelector, q string, args ...interface{}) (*pendingauthzModel, error) { + var model pendingauthzModel + err := s.SelectOne( + &model, + "SELECT id, identifier, registrationID, status, expires, combinations, LockCol FROM pendingAuthorizations "+q, + args..., + ) + return &model, err +} + +const authzFields = "id, identifier, registrationID, status, expires, combinations" + +// selectAuthz selects all fields of one authorization model +func selectAuthz(s dbOneSelector, q string, args ...interface{}) (*authzModel, error) { + var model authzModel + err := s.SelectOne( + &model, + "SELECT "+authzFields+" FROM authz "+q, + args..., + ) + return &model, err +} + +// selectAuthzs selects all fields of multiple authorization objects +func selectAuthzs(s dbSelector, q string, args ...interface{}) ([]*core.Authorization, error) { + var models []*core.Authorization + _, err := s.Select( + &models, + "SELECT "+authzFields+" FROM authz "+q, + args..., + ) + return models, err +} + +// selectSctReceipt selects all fields of one SignedCertificateTimestamp object +func selectSctReceipt(s dbOneSelector, q string, args ...interface{}) (core.SignedCertificateTimestamp, error) { + var model core.SignedCertificateTimestamp + err := s.SelectOne( + &model, + "SELECT id, sctVersion, logID, timestamp, extensions, signature, certificateSerial, LockCol FROM sctReceipts "+q, + args..., + ) + return model, err +} + +const certFields = "registrationID, serial, digest, der, issued, expires" + +// SelectCertificate selects all fields of one certificate object +func SelectCertificate(s dbOneSelector, q string, args ...interface{}) (core.Certificate, error) { + var model core.Certificate + err := s.SelectOne( + &model, + "SELECT "+certFields+" FROM certificates "+q, + args..., + ) + return model, err +} + +// SelectCertificates selects all fields of multiple certificate objects +func SelectCertificates(s dbSelector, q string, args map[string]interface{}) ([]core.Certificate, error) { + var models []core.Certificate + _, err := s.Select( + &models, + "SELECT "+certFields+" FROM certificates "+q, args) + return models, err +} + +const certStatusFields = "serial, subscriberApproved, status, ocspLastUpdated, revokedDate, revokedReason, lastExpirationNagSent, ocspResponse, LockCol" +const certStatusFieldsv2 = certStatusFields + ", notAfter, isExpired" + +// SelectCertificateStatus selects all fields of one certificate status model +func SelectCertificateStatus(s dbOneSelector, q string, args ...interface{}) (certStatusModelv1, error) { + var model certStatusModelv1 + err := s.SelectOne( + &model, + "SELECT "+certStatusFields+" FROM certificateStatus "+q, + args..., + ) + return model, err +} + +// SelectCertificateStatusv2 selects all fields (including the v2 migrated fields) of one certificate status model +func SelectCertificateStatusv2(s dbOneSelector, q string, args ...interface{}) (certStatusModelv2, error) { + var model certStatusModelv2 + err := s.SelectOne( + &model, + "SELECT "+certStatusFieldsv2+" FROM certificateStatus "+q, + args..., + ) + return model, err +} + +// SelectCertificateStatuses selects all fields of multiple certificate status objects +func SelectCertificateStatuses(s dbSelector, q string, args ...interface{}) ([]core.CertificateStatus, error) { + var models []core.CertificateStatus + _, err := s.Select( + &models, + "SELECT "+certStatusFields+" FROM certificateStatus "+q, + args..., + ) + return models, err +} + +// SelectCertificateStatusesv2 selects all fields (including the v2 migrated fields) of multiple certificate status objects +func SelectCertificateStatusesv2(s dbSelector, q string, args ...interface{}) ([]core.CertificateStatus, error) { + var models []core.CertificateStatus + _, err := s.Select( + &models, + "SELECT "+certStatusFieldsv2+" FROM certificateStatus "+q, + args..., + ) + return models, err +} var mediumBlobSize = int(math.Pow(2, 24)) @@ -61,6 +194,30 @@ type regModelv2 struct { Status string `db:"status"` } +// We need two certStatus model structs, one for when boulder does *not* have +// the 20160817143417_CertStatusOptimizations.sql migration applied +// (certStatusModelv1) and one for when it does (certStatusModelv2) +// +// TODO(@cpu): Collapse into one struct once the migration has been applied +// & feature flag set. +type certStatusModelv1 struct { + Serial string `db:"serial"` + SubscriberApproved bool `db:"subscriberApproved"` + Status core.OCSPStatus `db:"status"` + OCSPLastUpdated time.Time `db:"ocspLastUpdated"` + RevokedDate time.Time `db:"revokedDate"` + RevokedReason revocation.Reason `db:"revokedReason"` + LastExpirationNagSent time.Time `db:"lastExpirationNagSent"` + OCSPResponse []byte `db:"ocspResponse"` + LockCol int64 `json:"-"` +} + +type certStatusModelv2 struct { + certStatusModelv1 + NotAfter time.Time `db:"notAfter"` + IsExpired bool `db:"isExpired"` +} + // challModel is the description of a core.Challenge in the database // // The Validation field is a stub; the column is only there for backward compatibility. diff --git a/sa/sa.go b/sa/sa.go index 2487e52d9..c89865a4d 100644 --- a/sa/sa.go +++ b/sa/sa.go @@ -47,30 +47,6 @@ type authzModel struct { core.Authorization } -// We need two certStatus model structs, one for when boulder does *not* have -// the 20160817143417_CertStatusOptimizations.sql migration applied -// (certStatusModelv1) and one for when it does (certStatusModelv2) -// -// TODO(@cpu): Collapse into one struct once the migration has been applied -// & feature flag set. -type certStatusModelv1 struct { - Serial string `db:"serial"` - SubscriberApproved bool `db:"subscriberApproved"` - Status core.OCSPStatus `db:"status"` - OCSPLastUpdated time.Time `db:"ocspLastUpdated"` - RevokedDate time.Time `db:"revokedDate"` - RevokedReason revocation.Reason `db:"revokedReason"` - LastExpirationNagSent time.Time `db:"lastExpirationNagSent"` - OCSPResponse []byte `db:"ocspResponse"` - LockCol int64 `json:"-"` -} - -type certStatusModelv2 struct { - certStatusModelv1 - NotAfter time.Time `db:"notAfter"` - IsExpired bool `db:"isExpired"` -} - // NewSQLStorageAuthority provides persistence using a SQL backend for // Boulder. It will modify the given gorp.DbMap by adding relevant tables. func NewSQLStorageAuthority(dbMap *gorp.DbMap, clk clock.Clock, logger blog.Logger) (*SQLStorageAuthority, error) { @@ -120,20 +96,14 @@ func updateChallenges(authID string, challenges []core.Challenge, tx *gorp.Trans // GetRegistration obtains a Registration by ID func (ssa *SQLStorageAuthority) GetRegistration(ctx context.Context, id int64) (core.Registration, error) { - var reg interface{} - var fields string + const query = "WHERE id = ?" + var model interface{} + var err error if features.Enabled(features.AllowAccountDeactivation) { - reg = ®Modelv2{} - fields = regV2Fields + model, err = selectRegistrationv2(ssa.dbMap, query, id) } else { - reg = ®Modelv1{} - fields = regV1Fields + model, err = selectRegistration(ssa.dbMap, query, id) } - err := ssa.dbMap.SelectOne( - reg, - fmt.Sprintf("SELECT %s FROM registrations WHERE id = ?", fields), - id, - ) if err == sql.ErrNoRows { return core.Registration{}, core.NoSuchRegistrationError( fmt.Sprintf("No registrations with ID %d", id), @@ -142,30 +112,23 @@ func (ssa *SQLStorageAuthority) GetRegistration(ctx context.Context, id int64) ( if err != nil { return core.Registration{}, err } - return modelToRegistration(reg) + return modelToRegistration(model) } // GetRegistrationByKey obtains a Registration by JWK func (ssa *SQLStorageAuthority) GetRegistrationByKey(ctx context.Context, key jose.JsonWebKey) (core.Registration, error) { - var reg interface{} - var fields string - if features.Enabled(features.AllowAccountDeactivation) { - reg = ®Modelv2{} - fields = regV2Fields - } else { - reg = ®Modelv1{} - fields = regV1Fields - } + const query = "WHERE jwk_sha256 = ?" + var model interface{} + var err error sha, err := core.KeyDigest(key.Key) if err != nil { return core.Registration{}, err } - err = ssa.dbMap.SelectOne( - reg, - fmt.Sprintf("SELECT %s FROM registrations WHERE jwk_sha256 = :key", fields), - map[string]interface{}{"key": sha}, - ) - + if features.Enabled(features.AllowAccountDeactivation) { + model, err = selectRegistrationv2(ssa.dbMap, query, sha) + } else { + model, err = selectRegistration(ssa.dbMap, query, sha) + } if err == sql.ErrNoRows { msg := fmt.Sprintf("No registrations with public key sha256 %s", sha) return core.Registration{}, core.NoSuchRegistrationError(msg) @@ -174,7 +137,7 @@ func (ssa *SQLStorageAuthority) GetRegistrationByKey(ctx context.Context, key jo return core.Registration{}, err } - return modelToRegistration(reg) + return modelToRegistration(model) } // GetAuthorization obtains an Authorization by ID @@ -213,18 +176,12 @@ func (ssa *SQLStorageAuthority) GetValidAuthorizations(ctx context.Context, regi qmarks[i] = "?" } - var auths []*core.Authorization - _, err = ssa.dbMap.Select( - &auths, - fmt.Sprintf(` - SELECT %s FROM authz - WHERE registrationID = ? - AND expires > ? - AND identifier IN (`+strings.Join(qmarks, ",")+`) - AND status = 'valid' - `, authzFields), - append([]interface{}{registrationID, now}, params...)..., - ) + auths, err := selectAuthzs(ssa.dbMap, + "WHERE registrationID = ? "+ + "AND expires > ? "+ + "AND identifier IN ("+strings.Join(qmarks, ",")+") "+ + "AND status = 'valid'", + append([]interface{}{registrationID, now}, params...)...) if err != nil { return nil, err } @@ -397,8 +354,7 @@ func (ssa *SQLStorageAuthority) GetCertificate(ctx context.Context, serial strin return core.Certificate{}, err } - var cert core.Certificate - err := ssa.dbMap.SelectOne(&cert, fmt.Sprintf("SELECT %s FROM certificates WHERE serial = ?", CertificateFields), serial) + cert, err := SelectCertificate(ssa.dbMap, "WHERE serial = ?", serial) if err == sql.ErrNoRows { return core.Certificate{}, core.NotFoundError(fmt.Sprintf("No certificate found for %s", serial)) } @@ -498,20 +454,14 @@ func (ssa *SQLStorageAuthority) MarkCertificateRevoked(ctx context.Context, seri return err } + const statusQuery = "WHERE serial = ?" var statusObj interface{} - var fields string if features.Enabled(features.CertStatusOptimizationsMigrated) { - statusObj = &certStatusModelv2{} - fields = CertificateStatusFieldsv2 + statusObj, err = SelectCertificateStatusv2(tx, statusQuery, serial) } else { - statusObj = &certStatusModelv1{} - fields = CertificateStatusFields + statusObj, err = SelectCertificateStatus(tx, statusQuery, serial) } - err = tx.SelectOne( - statusObj, - fmt.Sprintf("SELECT %s FROM certificateStatus WHERE serial = ?", fields), - serial) if err == sql.ErrNoRows { err = fmt.Errorf("No certificate with serial %s", serial) err = Rollback(tx, err) @@ -522,22 +472,21 @@ func (ssa *SQLStorageAuthority) MarkCertificateRevoked(ctx context.Context, seri return err } + var n int64 now := ssa.clk.Now() if features.Enabled(features.CertStatusOptimizationsMigrated) { - status := statusObj.(*certStatusModelv2) + status := statusObj.(certStatusModelv2) status.Status = core.OCSPStatusRevoked status.RevokedDate = now status.RevokedReason = reasonCode - statusObj = status + n, err = tx.Update(&status) } else { - status := statusObj.(*certStatusModelv1) + status := statusObj.(certStatusModelv1) status.Status = core.OCSPStatusRevoked status.RevokedDate = now status.RevokedReason = reasonCode - statusObj = status + n, err = tx.Update(&status) } - - n, err := tx.Update(statusObj) if err != nil { err = Rollback(tx, err) return err @@ -553,16 +502,14 @@ func (ssa *SQLStorageAuthority) MarkCertificateRevoked(ctx context.Context, seri // UpdateRegistration stores an updated Registration func (ssa *SQLStorageAuthority) UpdateRegistration(ctx context.Context, reg core.Registration) error { - var regType interface{} - var fields string + const query = "WHERE id = ?" + var model interface{} + var err error if features.Enabled(features.AllowAccountDeactivation) { - regType = ®Modelv2{} - fields = regV2Fields + model, err = selectRegistrationv2(ssa.dbMap, query, reg.ID) } else { - regType = ®Modelv1{} - fields = regV1Fields + model, err = selectRegistration(ssa.dbMap, query, reg.ID) } - err := ssa.dbMap.SelectOne(regType, fmt.Sprintf("SELECT %s FROM registrations WHERE id = ?", fields), reg.ID) if err == sql.ErrNoRows { msg := fmt.Sprintf("No registrations with ID %d", reg.ID) return core.NoSuchRegistrationError(msg) @@ -578,12 +525,12 @@ func (ssa *SQLStorageAuthority) UpdateRegistration(ctx context.Context, reg core // so that we can copy over the LockCol from one to the other. Once we have copied // that field we reassign to the interface so gorp can properly update it. if features.Enabled(features.AllowAccountDeactivation) { - erm := regType.(*regModelv2) + erm := model.(*regModelv2) urm := updatedRegModel.(*regModelv2) urm.LockCol = erm.LockCol updatedRegModel = urm } else { - erm := regType.(*regModelv1) + erm := model.(*regModelv1) urm := updatedRegModel.(*regModelv1) urm.LockCol = erm.LockCol updatedRegModel = urm @@ -985,20 +932,11 @@ func (e ErrNoReceipt) Error() string { // GetSCTReceipt gets a specific SCT receipt for a given certificate serial and // CT log ID func (ssa *SQLStorageAuthority) GetSCTReceipt(ctx context.Context, serial string, logID string) (receipt core.SignedCertificateTimestamp, err error) { - err = ssa.dbMap.SelectOne( - &receipt, - fmt.Sprintf("SELECT %s FROM sctReceipts WHERE certificateSerial = :serial AND logID = :logID", sctFields), - map[string]interface{}{ - "serial": serial, - "logID": logID, - }, - ) - + receipt, err = selectSctReceipt(ssa.dbMap, "WHERE certificateSerial = ? AND logID = ?", serial, logID) if err == sql.ErrNoRows { err = ErrNoReceipt(err.Error()) return } - return }