diff --git a/cmd/expiration-mailer/main.go b/cmd/expiration-mailer/main.go index e58fa6411..b8dc49725 100644 --- a/cmd/expiration-mailer/main.go +++ b/cmd/expiration-mailer/main.go @@ -212,16 +212,14 @@ func (m *mailer) updateLastNagTimestamps(ctx context.Context, certs []*x509.Cert // updateLastNagTimestampsChunk processes a single chunk (up to 65k) of certificates. func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509.Certificate) { - qmarks := make([]string, len(certs)) params := make([]interface{}, len(certs)+1) for i, cert := range certs { - qmarks[i] = "?" params[i+1] = core.SerialToString(cert.SerialNumber) } query := fmt.Sprintf( "UPDATE certificateStatus SET lastExpirationNagSent = ? WHERE serial IN (%s)", - strings.Join(qmarks, ","), + db.QuestionMarks(len(certs)), ) params[0] = m.clk.Now() diff --git a/db/interfaces.go b/db/interfaces.go index 7937ed36a..d01b42ac2 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -64,6 +64,14 @@ type Executor interface { Query(string, ...interface{}) (*sql.Rows, error) } +// Queryer offers the Query method. Note that this is not read-only (i.e. not +// Selector), since a Query can be `INSERT`, `UPDATE`, etc. The difference +// between Query and Exec is that Query can return rows. So for instance it is +// suitable for inserting rows and getting back ids. +type Queryer interface { + Query(string, ...interface{}) (*sql.Rows, error) +} + // Transaction extends an Executor and adds Rollback, Commit, and WithContext. type Transaction interface { Executor diff --git a/db/multi.go b/db/multi.go index 8f9b6a681..36f5730ac 100644 --- a/db/multi.go +++ b/db/multi.go @@ -18,20 +18,23 @@ type MultiInserter struct { } // NewMultiInserter creates a new MultiInserter, checking for reasonable table -// name and list of fields. -func NewMultiInserter(table string, fields string, retCol string) (*MultiInserter, error) { +// name and list of fields. returningColumn is the name of a column to be used +// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz` +// clause is used. If returningColumn is present, it must refer to a column +// that can be parsed into an int64. +func NewMultiInserter(table string, fields string, returningColumn string) (*MultiInserter, error) { numFields := len(strings.Split(fields, ",")) if len(table) == 0 || len(fields) == 0 || numFields == 0 { return nil, fmt.Errorf("empty table name or fields list") } - if strings.Contains(retCol, ",") { - return nil, fmt.Errorf("return column must be singular, but got %q", retCol) + if strings.Contains(returningColumn, ",") { + return nil, fmt.Errorf("return column must be singular, but got %q", returningColumn) } return &MultiInserter{ table: table, fields: fields, - retCol: retCol, + retCol: returningColumn, numFields: numFields, values: make([][]interface{}, 0), }, nil @@ -50,12 +53,10 @@ func (mi *MultiInserter) Add(row []interface{}) error { // for gorp to use in place of the query's question marks. Currently only // used by .Insert(), below. func (mi *MultiInserter) query() (string, []interface{}) { - questionsRow := strings.TrimRight(strings.Repeat("?,", mi.numFields), ",") - var questionsBuf strings.Builder var queryArgs []interface{} for _, row := range mi.values { - fmt.Fprintf(&questionsBuf, "(%s),", questionsRow) + fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(mi.numFields)) queryArgs = append(queryArgs, row...) } @@ -71,12 +72,12 @@ func (mi *MultiInserter) query() (string, []interface{}) { } // Insert performs the action represented by .query() on the provided database, -// which is assumed to already have a context attached. If a non-empty retCol -// was provided, then it returns the list of values from that column returned -// by the query. -func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) { +// which is assumed to already have a context attached. If a non-empty +// returningColumn was provided, then it returns the list of values from that +// column returned by the query. +func (mi *MultiInserter) Insert(queryer Queryer) ([]int64, error) { query, queryArgs := mi.query() - rows, err := exec.Query(query, queryArgs...) + rows, err := queryer.Query(query, queryArgs...) if err != nil { return nil, err } @@ -94,9 +95,15 @@ func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) { } } - err = rows.Close() - if err != nil { - return nil, err + // Hack: sometimes in unittests we make a mock Queryer that returns a nil + // `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()` + // on it will panic— but here we choose to treat it like an empty list, + // and skip calling `Close()` to avoid the panic. + if rows != nil { + err = rows.Close() + if err != nil { + return nil, err + } } return ids, nil diff --git a/db/qmarks.go b/db/qmarks.go new file mode 100644 index 000000000..2439e0a12 --- /dev/null +++ b/db/qmarks.go @@ -0,0 +1,21 @@ +package db + +import "strings" + +// QuestionMarks returns a string consisting of N question marks, joined by +// commas. If n is <= 0, panics. +func QuestionMarks(n int) string { + if n <= 0 { + panic("db.QuestionMarks called with n <=0") + } + var qmarks strings.Builder + qmarks.Grow(2 * n) + for i := 0; i < n; i++ { + if i == 0 { + qmarks.WriteString("?") + } else { + qmarks.WriteString(",?") + } + } + return qmarks.String() +} diff --git a/db/qmarks_test.go b/db/qmarks_test.go new file mode 100644 index 000000000..6ee9ebb10 --- /dev/null +++ b/db/qmarks_test.go @@ -0,0 +1,19 @@ +package db + +import ( + "testing" + + "github.com/letsencrypt/boulder/test" +) + +func TestQuestionMarks(t *testing.T) { + test.AssertEquals(t, QuestionMarks(1), "?") + test.AssertEquals(t, QuestionMarks(2), "?,?") + test.AssertEquals(t, QuestionMarks(3), "?,?,?") +} + +func TestQuestionMarksPanic(t *testing.T) { + defer func() { recover() }() + QuestionMarks(0) + t.Errorf("calling QuestionMarks(0) did not panic as expected") +} diff --git a/sa/model.go b/sa/model.go index 9fc9d7901..baa3386a0 100644 --- a/sa/model.go +++ b/sa/model.go @@ -765,22 +765,27 @@ func deleteOrderFQDNSet( return nil } -func addIssuedNames(db db.Execer, cert *x509.Certificate, isRenewal bool) error { +func addIssuedNames(queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error { if len(cert.DNSNames) == 0 { return berrors.InternalServerError("certificate has no DNSNames") } - var qmarks []string - var values []interface{} + + multiInserter, err := db.NewMultiInserter("issuedNames", "reversedName, serial, notBefore, renewal", "") + if err != nil { + return err + } for _, name := range cert.DNSNames { - values = append(values, + err = multiInserter.Add([]interface{}{ ReverseName(name), core.SerialToString(cert.SerialNumber), cert.NotBefore, - isRenewal) - qmarks = append(qmarks, "(?, ?, ?, ?)") + isRenewal, + }) + if err != nil { + return err + } } - query := `INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES ` + strings.Join(qmarks, ", ") + `;` - _, err := db.Exec(query, values...) + _, err = multiInserter.Insert(queryer) return err } @@ -932,10 +937,8 @@ type authzValidity struct { // status and expiration date of each of them. It assumes that the provided // database selector already has a context associated with it. func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, error) { - var qmarks []string var params []interface{} for _, id := range ids { - qmarks = append(qmarks, "?") params = append(params, id) } var validityInfo []struct { @@ -944,7 +947,8 @@ func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, erro } _, err := s.Select( &validityInfo, - fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)", strings.Join(qmarks, ",")), + fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)", + db.QuestionMarks(len(ids))), params..., ) if err != nil { diff --git a/sa/sa_test.go b/sa/sa_test.go index f2de8bb9b..4a0f9234a 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -122,7 +122,6 @@ func createPendingAuthorization(t *testing.T, sa *SQLStorageAuthority, domain st err = sa.dbMap.Insert(&am) test.AssertNotError(t, err, "creating test authorization") - t.Log(am.ID) return am.ID } @@ -717,12 +716,12 @@ func TestFQDNSetsExists(t *testing.T) { test.Assert(t, exists.Exists, "FQDN set does exist") } -type execRecorder struct { +type queryRecorder struct { query string args []interface{} } -func (e *execRecorder) Exec(query string, args ...interface{}) (sql.Result, error) { +func (e *queryRecorder) Query(query string, args ...interface{}) (*sql.Rows, error) { e.query = query e.args = args return nil, nil @@ -732,7 +731,7 @@ func TestAddIssuedNames(t *testing.T) { serial := big.NewInt(1) expectedSerial := "000000000000000000000000000000000001" notBefore := time.Date(2018, 2, 14, 12, 0, 0, 0, time.UTC) - placeholdersPerName := "(?, ?, ?, ?)" + placeholdersPerName := "(?,?,?,?)" baseQuery := "INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES" testCases := []struct { @@ -807,7 +806,7 @@ func TestAddIssuedNames(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - var e execRecorder + var e queryRecorder err := addIssuedNames( &e, &x509.Certificate{ @@ -819,7 +818,7 @@ func TestAddIssuedNames(t *testing.T) { test.AssertNotError(t, err, "addIssuedNames failed") expectedPlaceholders := placeholdersPerName for i := 0; i < len(tc.IssuedNames)-1; i++ { - expectedPlaceholders = fmt.Sprintf("%s, %s", expectedPlaceholders, placeholdersPerName) + expectedPlaceholders = fmt.Sprintf("%s,%s", expectedPlaceholders, placeholdersPerName) } expectedQuery := fmt.Sprintf("%s %s;", baseQuery, expectedPlaceholders) test.AssertEquals(t, e.query, expectedQuery) diff --git a/sa/saro.go b/sa/saro.go index 07a67c256..20261f5c5 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -760,10 +760,8 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s identifierTypeToUint[string(identifier.DNS)], } - qmarks := make([]string, len(req.Domains)) - for i, n := range req.Domains { - qmarks[i] = "?" - params = append(params, n) + for _, name := range req.Domains { + params = append(params, name) } query := fmt.Sprintf( @@ -775,7 +773,7 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s identifierType = ? AND identifierValue IN (%s)`, authzFields, - strings.Join(qmarks, ","), + db.QuestionMarks(len(req.Domains)), ) _, err := ssa.dbReadOnlyMap.Select( @@ -965,30 +963,31 @@ func (ssa *SQLStorageAuthorityRO) GetValidAuthorizations2(ctx context.Context, r return nil, errIncompleteRequest } - var authzModels []authzModel + query := fmt.Sprintf( + `SELECT %s FROM authz2 WHERE + registrationID = ? AND + status = ? AND + expires > ? AND + identifierType = ? AND + identifierValue IN (%s)`, + authzFields, + db.QuestionMarks(len(req.Domains)), + ) + params := []interface{}{ req.RegistrationID, statusUint(core.StatusValid), time.Unix(0, req.Now), identifierTypeToUint[string(identifier.DNS)], } - qmarks := make([]string, len(req.Domains)) - for i, n := range req.Domains { - qmarks[i] = "?" - params = append(params, n) + for _, domain := range req.Domains { + params = append(params, domain) } + + var authzModels []authzModel _, err := ssa.dbReadOnlyMap.Select( &authzModels, - fmt.Sprintf( - `SELECT %s FROM authz2 WHERE - registrationID = ? AND - status = ? AND - expires > ? AND - identifierType = ? AND - identifierValue IN (%s)`, - authzFields, - strings.Join(qmarks, ","), - ), + query, params..., ) if err != nil {