diff --git a/prefixed_db/db.go b/prefixed_db/db.go new file mode 100644 index 000000000..415dadcd6 --- /dev/null +++ b/prefixed_db/db.go @@ -0,0 +1,47 @@ +package prefixed_db + +import "database/sql/driver" + +// New clones a database driver to create a new driver with the property that +// every statement executed will have the given prefix prepended. +// This is useful, for instance, to set statement-level variables like +// max_statement_time and long_query_time. +func New(prefix string, underlying driver.Driver) driver.Driver { + return &prefixedDB{ + prefix: prefix, + underlying: underlying, + } +} + +type prefixedDB struct { + prefix string + underlying driver.Driver +} + +func (p *prefixedDB) Open(name string) (driver.Conn, error) { + conn, err := p.underlying.Open(name) + if err != nil { + return nil, err + } + return &prefixedConn{ + prefix: p.prefix, + conn: conn, + }, nil +} + +type prefixedConn struct { + prefix string + conn driver.Conn +} + +func (c *prefixedConn) Prepare(query string) (driver.Stmt, error) { + return c.conn.Prepare(c.prefix + " " + query) +} + +func (c *prefixedConn) Close() error { + return c.conn.Close() +} + +func (c *prefixedConn) Begin() (driver.Tx, error) { + return c.conn.Begin() +} diff --git a/prefixed_db/db_test.go b/prefixed_db/db_test.go new file mode 100644 index 000000000..24d8ed245 --- /dev/null +++ b/prefixed_db/db_test.go @@ -0,0 +1,36 @@ +package prefixed_db + +import ( + "database/sql" + "log" + "strings" + "sync" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/letsencrypt/boulder/test/vars" +) + +func TestPrefixing(t *testing.T) { + sql.Register("prefixedmysql", New("SET STATEMENT max_statement_time=0.1 FOR", mysql.MySQLDriver{})) + db, err := sql.Open("prefixedmysql", vars.DBConnSA) + if err != nil { + log.Fatal(err) + } + if err := db.Ping(); err != nil { + log.Fatal(err) + } + var wg sync.WaitGroup + for i := 1; i < 10; i++ { + wg.Add(1) + go func(i int) { + _, err := db.Exec("SELECT 1 FROM (SELECT SLEEP(?)) as subselect;", i) + if err == nil || !strings.HasPrefix(err.Error(), "Error 1969:") { + t.Error("Expected to get Error 1969 (timeout), got", err) + } + wg.Done() + }(i) + } + wg.Wait() + _ = db.Close() +} diff --git a/sa/database.go b/sa/database.go index 08177682c..1737e9c8c 100644 --- a/sa/database.go +++ b/sa/database.go @@ -1,8 +1,11 @@ package sa import ( + "crypto/rand" "database/sql" "fmt" + "math" + "math/big" "net/url" "strings" "time" @@ -14,6 +17,7 @@ import ( "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" + "github.com/letsencrypt/boulder/prefixed_db" ) // NewDbMap creates the root gorp mapping object. Create one of these for each @@ -54,7 +58,39 @@ var setMaxOpenConns = func(db *sql.DB, maxOpenConns int) { func NewDbMapFromConfig(config *mysql.Config, maxOpenConns int) (*gorp.DbMap, error) { adjustMySQLConfig(config) - db, err := sqlOpen("mysql", config.FormatDSN()) + // We always want strict mode. Rather than leaving this up to DB config, we + // prefix each statement with it. + prefix := "SET STATEMENT sql_mode='STRICT_ALL_TABLES' FOR " + + // If a read timeout is set, we set max_statement_time to 95% of that, and + // long_query_time to 80% of that. That way we get logs of queries that are + // close to timing out but not yet doing so, and our queries get stopped by + // max_statement_time before timing out the read. This generates clearer + // errors, and avoids unnecessary reconnects. + if config.ReadTimeout != 0 { + // In MariaDB, max_statement_time and long_query_time are both seconds. + // Note: in MySQL (which we don't use), max_statement_time is millis. + readTimeout := config.ReadTimeout.Seconds() + prefix = fmt.Sprintf( + "SET STATEMENT max_statement_time=%g, long_query_time=%g, sql_mode='STRICT_ALL_TABLES' FOR ", + readTimeout*0.95, readTimeout*0.80) + } + + // The way we generate a customized database driver means that we need to + // choose a name to register the driver with. Because this function can be + // called multiple times with different parameters, we choose a random name + // each time we register to avoid conflicts with other DB instances. + // We use crypto/rand rather than math.Rand not out of a particular necessity + // for high-quality randomness, but simply because using non-crypto rand is a + // code smell. + driverNum, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, err + } + driverName := fmt.Sprintf("mysql-%d", driverNum) + sql.Register(driverName, prefixed_db.New(prefix, mysql.MySQLDriver{})) + + db, err := sqlOpen(driverName, config.FormatDSN()) if err != nil { return nil, err } @@ -67,25 +103,6 @@ func NewDbMapFromConfig(config *mysql.Config, maxOpenConns int) (*gorp.DbMap, er dbmap := &gorp.DbMap{Db: db, Dialect: dialect, TypeConverter: BoulderTypeConverter{}} initTables(dbmap) - _, err = dbmap.Exec("SET SESSION sql_mode = 'STRICT_ALL_TABLES';") - if err != nil { - return nil, err - } - // If a read timeout is set, we set max_statement_time to 95% of that, and - // long_query_time to 80% of that. That way we get logs of queries that are - // close to timing out but not yet doing so, and our queries get stopped by - // max_statement_time before timing out the read. This generates clearer - // errors, and avoids unnecessary reconnects. - if config.ReadTimeout != 0 { - // In MariaDB, max_statement_time and long_query_time are both seconds. - // Note: in MySQL (which we don't use), max_statement_time is millis. - readTimeout := config.ReadTimeout.Seconds() - _, err := dbmap.Exec("SET SESSION max_statement_time = ?, long_query_time = ?;", - readTimeout*0.95, readTimeout*0.80) - if err != nil { - return nil, err - } - } return dbmap, err } diff --git a/test/vars/vars.go b/test/vars/vars.go index 36c231a5a..f664c05f3 100644 --- a/test/vars/vars.go +++ b/test/vars/vars.go @@ -3,7 +3,7 @@ package vars import "fmt" const ( - dbURL = "mysql+tcp://%s@boulder-mysql:3306/%s" + dbURL = "%s@tcp(boulder-mysql:3306)/%s" ) var (