Extract out `db.QuestionMarks` function (#6568)
We use this pattern in several places: there is a query that needs to have a variable number of placeholders (question marks) in it, depending on how many items we are inserting or querying for. For instance, when issuing a precertificate we add that precertificate's names to the "issuedNames" table. To make things more efficient, we do that in a single query, whether there is one name on the certificate or a hundred. That means interpolating into the query string with series of question marks that matches the number of names. We have a helper type MultiInserter that solves this problem for simple inserts, but it does not solve the problem for selects or more complex inserts, and we still have a number of places that generate their sequence of question marks manually. This change updates addIssuedNames to use MultiInserter. To enable that, it also narrows the interface required by MultiInserter.Insert, so it's easier to mock in tests. This change adds the new function db.QuestionMarks, which generates e.g. `?,?,?` depending on the input N. In a few places I had to rename a function parameter named `db` to avoid shadowing the `db` package.
This commit is contained in:
parent
1e7c64e5f2
commit
4be76afcaf
|
@ -212,16 +212,14 @@ func (m *mailer) updateLastNagTimestamps(ctx context.Context, certs []*x509.Cert
|
||||||
|
|
||||||
// updateLastNagTimestampsChunk processes a single chunk (up to 65k) of certificates.
|
// updateLastNagTimestampsChunk processes a single chunk (up to 65k) of certificates.
|
||||||
func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509.Certificate) {
|
func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509.Certificate) {
|
||||||
qmarks := make([]string, len(certs))
|
|
||||||
params := make([]interface{}, len(certs)+1)
|
params := make([]interface{}, len(certs)+1)
|
||||||
for i, cert := range certs {
|
for i, cert := range certs {
|
||||||
qmarks[i] = "?"
|
|
||||||
params[i+1] = core.SerialToString(cert.SerialNumber)
|
params[i+1] = core.SerialToString(cert.SerialNumber)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
"UPDATE certificateStatus SET lastExpirationNagSent = ? WHERE serial IN (%s)",
|
"UPDATE certificateStatus SET lastExpirationNagSent = ? WHERE serial IN (%s)",
|
||||||
strings.Join(qmarks, ","),
|
db.QuestionMarks(len(certs)),
|
||||||
)
|
)
|
||||||
params[0] = m.clk.Now()
|
params[0] = m.clk.Now()
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,14 @@ type Executor interface {
|
||||||
Query(string, ...interface{}) (*sql.Rows, error)
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Queryer offers the Query method. Note that this is not read-only (i.e. not
|
||||||
|
// Selector), since a Query can be `INSERT`, `UPDATE`, etc. The difference
|
||||||
|
// between Query and Exec is that Query can return rows. So for instance it is
|
||||||
|
// suitable for inserting rows and getting back ids.
|
||||||
|
type Queryer interface {
|
||||||
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Transaction extends an Executor and adds Rollback, Commit, and WithContext.
|
// Transaction extends an Executor and adds Rollback, Commit, and WithContext.
|
||||||
type Transaction interface {
|
type Transaction interface {
|
||||||
Executor
|
Executor
|
||||||
|
|
39
db/multi.go
39
db/multi.go
|
@ -18,20 +18,23 @@ type MultiInserter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMultiInserter creates a new MultiInserter, checking for reasonable table
|
// NewMultiInserter creates a new MultiInserter, checking for reasonable table
|
||||||
// name and list of fields.
|
// name and list of fields. returningColumn is the name of a column to be used
|
||||||
func NewMultiInserter(table string, fields string, retCol string) (*MultiInserter, error) {
|
// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz`
|
||||||
|
// clause is used. If returningColumn is present, it must refer to a column
|
||||||
|
// that can be parsed into an int64.
|
||||||
|
func NewMultiInserter(table string, fields string, returningColumn string) (*MultiInserter, error) {
|
||||||
numFields := len(strings.Split(fields, ","))
|
numFields := len(strings.Split(fields, ","))
|
||||||
if len(table) == 0 || len(fields) == 0 || numFields == 0 {
|
if len(table) == 0 || len(fields) == 0 || numFields == 0 {
|
||||||
return nil, fmt.Errorf("empty table name or fields list")
|
return nil, fmt.Errorf("empty table name or fields list")
|
||||||
}
|
}
|
||||||
if strings.Contains(retCol, ",") {
|
if strings.Contains(returningColumn, ",") {
|
||||||
return nil, fmt.Errorf("return column must be singular, but got %q", retCol)
|
return nil, fmt.Errorf("return column must be singular, but got %q", returningColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &MultiInserter{
|
return &MultiInserter{
|
||||||
table: table,
|
table: table,
|
||||||
fields: fields,
|
fields: fields,
|
||||||
retCol: retCol,
|
retCol: returningColumn,
|
||||||
numFields: numFields,
|
numFields: numFields,
|
||||||
values: make([][]interface{}, 0),
|
values: make([][]interface{}, 0),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -50,12 +53,10 @@ func (mi *MultiInserter) Add(row []interface{}) error {
|
||||||
// for gorp to use in place of the query's question marks. Currently only
|
// for gorp to use in place of the query's question marks. Currently only
|
||||||
// used by .Insert(), below.
|
// used by .Insert(), below.
|
||||||
func (mi *MultiInserter) query() (string, []interface{}) {
|
func (mi *MultiInserter) query() (string, []interface{}) {
|
||||||
questionsRow := strings.TrimRight(strings.Repeat("?,", mi.numFields), ",")
|
|
||||||
|
|
||||||
var questionsBuf strings.Builder
|
var questionsBuf strings.Builder
|
||||||
var queryArgs []interface{}
|
var queryArgs []interface{}
|
||||||
for _, row := range mi.values {
|
for _, row := range mi.values {
|
||||||
fmt.Fprintf(&questionsBuf, "(%s),", questionsRow)
|
fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(mi.numFields))
|
||||||
queryArgs = append(queryArgs, row...)
|
queryArgs = append(queryArgs, row...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,12 +72,12 @@ func (mi *MultiInserter) query() (string, []interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert performs the action represented by .query() on the provided database,
|
// Insert performs the action represented by .query() on the provided database,
|
||||||
// which is assumed to already have a context attached. If a non-empty retCol
|
// which is assumed to already have a context attached. If a non-empty
|
||||||
// was provided, then it returns the list of values from that column returned
|
// returningColumn was provided, then it returns the list of values from that
|
||||||
// by the query.
|
// column returned by the query.
|
||||||
func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) {
|
func (mi *MultiInserter) Insert(queryer Queryer) ([]int64, error) {
|
||||||
query, queryArgs := mi.query()
|
query, queryArgs := mi.query()
|
||||||
rows, err := exec.Query(query, queryArgs...)
|
rows, err := queryer.Query(query, queryArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -94,9 +95,15 @@ func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rows.Close()
|
// Hack: sometimes in unittests we make a mock Queryer that returns a nil
|
||||||
if err != nil {
|
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
|
||||||
return nil, err
|
// on it will panic— but here we choose to treat it like an empty list,
|
||||||
|
// and skip calling `Close()` to avoid the panic.
|
||||||
|
if rows != nil {
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ids, nil
|
return ids, nil
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// QuestionMarks returns a string consisting of N question marks, joined by
|
||||||
|
// commas. If n is <= 0, panics.
|
||||||
|
func QuestionMarks(n int) string {
|
||||||
|
if n <= 0 {
|
||||||
|
panic("db.QuestionMarks called with n <=0")
|
||||||
|
}
|
||||||
|
var qmarks strings.Builder
|
||||||
|
qmarks.Grow(2 * n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if i == 0 {
|
||||||
|
qmarks.WriteString("?")
|
||||||
|
} else {
|
||||||
|
qmarks.WriteString(",?")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return qmarks.String()
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/letsencrypt/boulder/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQuestionMarks(t *testing.T) {
|
||||||
|
test.AssertEquals(t, QuestionMarks(1), "?")
|
||||||
|
test.AssertEquals(t, QuestionMarks(2), "?,?")
|
||||||
|
test.AssertEquals(t, QuestionMarks(3), "?,?,?")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuestionMarksPanic(t *testing.T) {
|
||||||
|
defer func() { recover() }()
|
||||||
|
QuestionMarks(0)
|
||||||
|
t.Errorf("calling QuestionMarks(0) did not panic as expected")
|
||||||
|
}
|
26
sa/model.go
26
sa/model.go
|
@ -765,22 +765,27 @@ func deleteOrderFQDNSet(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addIssuedNames(db db.Execer, cert *x509.Certificate, isRenewal bool) error {
|
func addIssuedNames(queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error {
|
||||||
if len(cert.DNSNames) == 0 {
|
if len(cert.DNSNames) == 0 {
|
||||||
return berrors.InternalServerError("certificate has no DNSNames")
|
return berrors.InternalServerError("certificate has no DNSNames")
|
||||||
}
|
}
|
||||||
var qmarks []string
|
|
||||||
var values []interface{}
|
multiInserter, err := db.NewMultiInserter("issuedNames", "reversedName, serial, notBefore, renewal", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
for _, name := range cert.DNSNames {
|
for _, name := range cert.DNSNames {
|
||||||
values = append(values,
|
err = multiInserter.Add([]interface{}{
|
||||||
ReverseName(name),
|
ReverseName(name),
|
||||||
core.SerialToString(cert.SerialNumber),
|
core.SerialToString(cert.SerialNumber),
|
||||||
cert.NotBefore,
|
cert.NotBefore,
|
||||||
isRenewal)
|
isRenewal,
|
||||||
qmarks = append(qmarks, "(?, ?, ?, ?)")
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
query := `INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES ` + strings.Join(qmarks, ", ") + `;`
|
_, err = multiInserter.Insert(queryer)
|
||||||
_, err := db.Exec(query, values...)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -932,10 +937,8 @@ type authzValidity struct {
|
||||||
// status and expiration date of each of them. It assumes that the provided
|
// status and expiration date of each of them. It assumes that the provided
|
||||||
// database selector already has a context associated with it.
|
// database selector already has a context associated with it.
|
||||||
func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, error) {
|
func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, error) {
|
||||||
var qmarks []string
|
|
||||||
var params []interface{}
|
var params []interface{}
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
qmarks = append(qmarks, "?")
|
|
||||||
params = append(params, id)
|
params = append(params, id)
|
||||||
}
|
}
|
||||||
var validityInfo []struct {
|
var validityInfo []struct {
|
||||||
|
@ -944,7 +947,8 @@ func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, erro
|
||||||
}
|
}
|
||||||
_, err := s.Select(
|
_, err := s.Select(
|
||||||
&validityInfo,
|
&validityInfo,
|
||||||
fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)", strings.Join(qmarks, ",")),
|
fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)",
|
||||||
|
db.QuestionMarks(len(ids))),
|
||||||
params...,
|
params...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -122,7 +122,6 @@ func createPendingAuthorization(t *testing.T, sa *SQLStorageAuthority, domain st
|
||||||
err = sa.dbMap.Insert(&am)
|
err = sa.dbMap.Insert(&am)
|
||||||
test.AssertNotError(t, err, "creating test authorization")
|
test.AssertNotError(t, err, "creating test authorization")
|
||||||
|
|
||||||
t.Log(am.ID)
|
|
||||||
return am.ID
|
return am.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -717,12 +716,12 @@ func TestFQDNSetsExists(t *testing.T) {
|
||||||
test.Assert(t, exists.Exists, "FQDN set does exist")
|
test.Assert(t, exists.Exists, "FQDN set does exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
type execRecorder struct {
|
type queryRecorder struct {
|
||||||
query string
|
query string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *execRecorder) Exec(query string, args ...interface{}) (sql.Result, error) {
|
func (e *queryRecorder) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
e.query = query
|
e.query = query
|
||||||
e.args = args
|
e.args = args
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -732,7 +731,7 @@ func TestAddIssuedNames(t *testing.T) {
|
||||||
serial := big.NewInt(1)
|
serial := big.NewInt(1)
|
||||||
expectedSerial := "000000000000000000000000000000000001"
|
expectedSerial := "000000000000000000000000000000000001"
|
||||||
notBefore := time.Date(2018, 2, 14, 12, 0, 0, 0, time.UTC)
|
notBefore := time.Date(2018, 2, 14, 12, 0, 0, 0, time.UTC)
|
||||||
placeholdersPerName := "(?, ?, ?, ?)"
|
placeholdersPerName := "(?,?,?,?)"
|
||||||
baseQuery := "INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES"
|
baseQuery := "INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES"
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
@ -807,7 +806,7 @@ func TestAddIssuedNames(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
var e execRecorder
|
var e queryRecorder
|
||||||
err := addIssuedNames(
|
err := addIssuedNames(
|
||||||
&e,
|
&e,
|
||||||
&x509.Certificate{
|
&x509.Certificate{
|
||||||
|
@ -819,7 +818,7 @@ func TestAddIssuedNames(t *testing.T) {
|
||||||
test.AssertNotError(t, err, "addIssuedNames failed")
|
test.AssertNotError(t, err, "addIssuedNames failed")
|
||||||
expectedPlaceholders := placeholdersPerName
|
expectedPlaceholders := placeholdersPerName
|
||||||
for i := 0; i < len(tc.IssuedNames)-1; i++ {
|
for i := 0; i < len(tc.IssuedNames)-1; i++ {
|
||||||
expectedPlaceholders = fmt.Sprintf("%s, %s", expectedPlaceholders, placeholdersPerName)
|
expectedPlaceholders = fmt.Sprintf("%s,%s", expectedPlaceholders, placeholdersPerName)
|
||||||
}
|
}
|
||||||
expectedQuery := fmt.Sprintf("%s %s;", baseQuery, expectedPlaceholders)
|
expectedQuery := fmt.Sprintf("%s %s;", baseQuery, expectedPlaceholders)
|
||||||
test.AssertEquals(t, e.query, expectedQuery)
|
test.AssertEquals(t, e.query, expectedQuery)
|
||||||
|
|
39
sa/saro.go
39
sa/saro.go
|
@ -760,10 +760,8 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s
|
||||||
identifierTypeToUint[string(identifier.DNS)],
|
identifierTypeToUint[string(identifier.DNS)],
|
||||||
}
|
}
|
||||||
|
|
||||||
qmarks := make([]string, len(req.Domains))
|
for _, name := range req.Domains {
|
||||||
for i, n := range req.Domains {
|
params = append(params, name)
|
||||||
qmarks[i] = "?"
|
|
||||||
params = append(params, n)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
|
@ -775,7 +773,7 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s
|
||||||
identifierType = ? AND
|
identifierType = ? AND
|
||||||
identifierValue IN (%s)`,
|
identifierValue IN (%s)`,
|
||||||
authzFields,
|
authzFields,
|
||||||
strings.Join(qmarks, ","),
|
db.QuestionMarks(len(req.Domains)),
|
||||||
)
|
)
|
||||||
|
|
||||||
_, err := ssa.dbReadOnlyMap.Select(
|
_, err := ssa.dbReadOnlyMap.Select(
|
||||||
|
@ -965,30 +963,31 @@ func (ssa *SQLStorageAuthorityRO) GetValidAuthorizations2(ctx context.Context, r
|
||||||
return nil, errIncompleteRequest
|
return nil, errIncompleteRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
var authzModels []authzModel
|
query := fmt.Sprintf(
|
||||||
|
`SELECT %s FROM authz2 WHERE
|
||||||
|
registrationID = ? AND
|
||||||
|
status = ? AND
|
||||||
|
expires > ? AND
|
||||||
|
identifierType = ? AND
|
||||||
|
identifierValue IN (%s)`,
|
||||||
|
authzFields,
|
||||||
|
db.QuestionMarks(len(req.Domains)),
|
||||||
|
)
|
||||||
|
|
||||||
params := []interface{}{
|
params := []interface{}{
|
||||||
req.RegistrationID,
|
req.RegistrationID,
|
||||||
statusUint(core.StatusValid),
|
statusUint(core.StatusValid),
|
||||||
time.Unix(0, req.Now),
|
time.Unix(0, req.Now),
|
||||||
identifierTypeToUint[string(identifier.DNS)],
|
identifierTypeToUint[string(identifier.DNS)],
|
||||||
}
|
}
|
||||||
qmarks := make([]string, len(req.Domains))
|
for _, domain := range req.Domains {
|
||||||
for i, n := range req.Domains {
|
params = append(params, domain)
|
||||||
qmarks[i] = "?"
|
|
||||||
params = append(params, n)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var authzModels []authzModel
|
||||||
_, err := ssa.dbReadOnlyMap.Select(
|
_, err := ssa.dbReadOnlyMap.Select(
|
||||||
&authzModels,
|
&authzModels,
|
||||||
fmt.Sprintf(
|
query,
|
||||||
`SELECT %s FROM authz2 WHERE
|
|
||||||
registrationID = ? AND
|
|
||||||
status = ? AND
|
|
||||||
expires > ? AND
|
|
||||||
identifierType = ? AND
|
|
||||||
identifierValue IN (%s)`,
|
|
||||||
authzFields,
|
|
||||||
strings.Join(qmarks, ","),
|
|
||||||
),
|
|
||||||
params...,
|
params...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue