Ensure SelectOne queries never return more than 1 row (#6900)

As a follow-up to https://github.com/letsencrypt/boulder/issues/5467, I
did an audit of all places where we call SelectOne to ensure that those
queries can never return more than one result. These four functions were
the only places that weren't already constrained to a single result
through the use of "SELECT COUNT", "LIMIT 1", "WHERE uniqueKey =", or
similar. Limit these functions' queries to always only return a single
result, now that their underlying tables no longer have unique key
constraints.

Additionally, slightly refactor selectRegistration to just take a single
column name rather than a whole WHERE clause.

Fixes https://github.com/letsencrypt/boulder/issues/6521
This commit is contained in:
Aaron Gable 2023-05-17 14:13:21 -07:00 committed by GitHub
parent f91aa1d57d
commit 56f8537e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 12 deletions

View File

@ -59,11 +59,15 @@ func badJSONError(msg string, jsonData []byte, err error) error {
const regFields = "id, jwk, jwk_sha256, contact, agreement, initialIP, createdAt, LockCol, status"
// selectRegistration selects all fields of one registration model
func selectRegistration(s db.OneSelector, q string, args ...interface{}) (*regModel, error) {
func selectRegistration(s db.OneSelector, whereCol string, args ...interface{}) (*regModel, error) {
if whereCol != "id" && whereCol != "jwk_sha256" {
return nil, fmt.Errorf("column name %q invalid for registrations table WHERE clause", whereCol)
}
var model regModel
err := s.SelectOne(
&model,
"SELECT "+regFields+" FROM registrations "+q,
"SELECT "+regFields+" FROM registrations WHERE "+whereCol+" = ? LIMIT 1",
args...,
)
return &model, err
@ -92,7 +96,7 @@ func SelectPrecertificate(s db.OneSelector, serial string) (core.Certificate, er
var model precertificateModel
err := s.SelectOne(
&model,
"SELECT "+precertFields+" FROM precertificates WHERE serial = ?",
"SELECT "+precertFields+" FROM precertificates WHERE serial = ? LIMIT 1",
serial)
return core.Certificate{
RegistrationID: model.RegistrationID,
@ -147,7 +151,7 @@ func SelectCertificateStatus(s db.OneSelector, serial string) (core.CertificateS
var model core.CertificateStatus
err := s.SelectOne(
&model,
"SELECT "+certStatusFields+" FROM certificateStatus WHERE serial = ?",
"SELECT "+certStatusFields+" FROM certificateStatus WHERE serial = ? LIMIT 1",
serial,
)
return model, err
@ -168,7 +172,7 @@ func SelectRevocationStatus(s db.OneSelector, serial string) (*sapb.RevocationSt
var model RevocationStatusModel
err := s.SelectOne(
&model,
"SELECT status, revokedDate, revokedReason FROM certificateStatus WHERE serial = ?",
"SELECT status, revokedDate, revokedReason FROM certificateStatus WHERE serial = ? LIMIT 1",
serial,
)
if err != nil {

View File

@ -125,8 +125,7 @@ func (ssa *SQLStorageAuthority) UpdateRegistration(ctx context.Context, req *cor
return nil, errIncompleteRequest
}
const query = "WHERE id = ?"
curr, err := selectRegistration(ssa.dbMap.WithContext(ctx), query, req.Id)
curr, err := selectRegistration(ssa.dbMap.WithContext(ctx), "id", req.Id)
if err != nil {
if db.IsNoRows(err) {
return nil, berrors.NotFoundError("registration with ID '%d' not found", req.Id)

View File

@ -233,6 +233,33 @@ func TestNoSuchRegistrationErrors(t *testing.T) {
test.AssertErrorIs(t, err, berrors.NotFound)
}
func TestSelectRegistration(t *testing.T) {
sa, _, cleanUp := initSA(t)
defer cleanUp()
var ctx = context.Background()
var ssaCtx = sa.dbMap.WithContext(ctx)
jwk := goodTestJWK()
jwkJSON, _ := jwk.MarshalJSON()
sha, err := core.KeyDigestB64(jwk.Key)
test.AssertNotError(t, err, "couldn't parse jwk.Key")
initialIP, _ := net.ParseIP("43.34.43.34").MarshalText()
reg, err := sa.NewRegistration(ctx, &corepb.Registration{
Key: jwkJSON,
Contact: []string{"mailto:foo@example.com"},
InitialIP: initialIP,
})
test.AssertNotError(t, err, fmt.Sprintf("couldn't create new registration: %s", err))
test.Assert(t, reg.Id != 0, "ID shouldn't be 0")
_, err = selectRegistration(ssaCtx, "id", reg.Id)
test.AssertNotError(t, err, "selecting by id should work")
_, err = selectRegistration(ssaCtx, "jwk_sha256", sha)
test.AssertNotError(t, err, "selecting by jwk_sha256 should work")
_, err = selectRegistration(ssaCtx, "initialIP", reg.Id)
test.AssertError(t, err, "selecting by any other column should not work")
}
func TestReplicationLagRetries(t *testing.T) {
sa, clk, cleanUp := initSA(t)
defer cleanUp()

View File

@ -107,14 +107,13 @@ func (ssa *SQLStorageAuthorityRO) GetRegistration(ctx context.Context, req *sapb
return nil, errIncompleteRequest
}
const query = "WHERE id = ?"
model, err := selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), query, req.Id)
model, err := selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), "id", req.Id)
if db.IsNoRows(err) && ssa.lagFactor != 0 {
// GetRegistration is often called to validate a JWK belonging to a brand
// new account whose registrations table row hasn't propagated to the read
// replica yet. If we get a NoRows, wait a little bit and retry, once.
ssa.clk.Sleep(ssa.lagFactor)
model, err = selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), query, req.Id)
model, err = selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), "id", req.Id)
if err != nil {
if db.IsNoRows(err) {
ssa.lagFactorCounter.WithLabelValues("GetRegistration", "notfound").Inc()
@ -151,12 +150,11 @@ func (ssa *SQLStorageAuthorityRO) GetRegistrationByKey(ctx context.Context, req
return nil, err
}
const query = "WHERE jwk_sha256 = ?"
sha, err := core.KeyDigestB64(jwk.Key)
if err != nil {
return nil, err
}
model, err := selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), query, sha)
model, err := selectRegistration(ssa.dbReadOnlyMap.WithContext(ctx), "jwk_sha256", sha)
if err != nil {
if db.IsNoRows(err) {
return nil, berrors.NotFoundError("no registrations with public key sha256 %q", sha)