From 673cc6eaa4f84509d1efc7c5844249a0309b99d3 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Tue, 8 Nov 2016 06:07:36 -0800 Subject: [PATCH] Fix max_statement_time and long_query_time (#2311) We try to set `max_statement_time` on new database connections so that long queries can be interrupted server side. However, the existing code is broken because of connection pooling. The `SET SESSION max_statement_time=...` gets executed on on connection, but subsequent queries won't necessarily be executed on the same connection. This changes fixed the problem by introducing an alternate DB driver that wraps the MySQL driver to prefix every query with `SET STATEMENT max_statement_time=...` This also changes `vars.go` to use the DSN form of a database name instead of the URL form. This allows using it directly in `prefixed_db`'s tests, and this is the direction we're moving all of our database URLs. We previously used a homebrewed URL syntax because it allowed us to extract certain fields and set config options, but now `mysql.Config` serves that need. Fixes #2251 --- prefixed_db/db.go | 47 ++++++++++++++++++++++++++++++++++ prefixed_db/db_test.go | 36 ++++++++++++++++++++++++++ sa/database.go | 57 +++++++++++++++++++++++++++--------------- test/vars/vars.go | 2 +- 4 files changed, 121 insertions(+), 21 deletions(-) create mode 100644 prefixed_db/db.go create mode 100644 prefixed_db/db_test.go diff --git a/prefixed_db/db.go b/prefixed_db/db.go new file mode 100644 index 000000000..415dadcd6 --- /dev/null +++ b/prefixed_db/db.go @@ -0,0 +1,47 @@ +package prefixed_db + +import "database/sql/driver" + +// New clones a database driver to create a new driver with the property that +// every statement executed will have the given prefix prepended. +// This is useful, for instance, to set statement-level variables like +// max_statement_time and long_query_time. +func New(prefix string, underlying driver.Driver) driver.Driver { + return &prefixedDB{ + prefix: prefix, + underlying: underlying, + } +} + +type prefixedDB struct { + prefix string + underlying driver.Driver +} + +func (p *prefixedDB) Open(name string) (driver.Conn, error) { + conn, err := p.underlying.Open(name) + if err != nil { + return nil, err + } + return &prefixedConn{ + prefix: p.prefix, + conn: conn, + }, nil +} + +type prefixedConn struct { + prefix string + conn driver.Conn +} + +func (c *prefixedConn) Prepare(query string) (driver.Stmt, error) { + return c.conn.Prepare(c.prefix + " " + query) +} + +func (c *prefixedConn) Close() error { + return c.conn.Close() +} + +func (c *prefixedConn) Begin() (driver.Tx, error) { + return c.conn.Begin() +} diff --git a/prefixed_db/db_test.go b/prefixed_db/db_test.go new file mode 100644 index 000000000..24d8ed245 --- /dev/null +++ b/prefixed_db/db_test.go @@ -0,0 +1,36 @@ +package prefixed_db + +import ( + "database/sql" + "log" + "strings" + "sync" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/letsencrypt/boulder/test/vars" +) + +func TestPrefixing(t *testing.T) { + sql.Register("prefixedmysql", New("SET STATEMENT max_statement_time=0.1 FOR", mysql.MySQLDriver{})) + db, err := sql.Open("prefixedmysql", vars.DBConnSA) + if err != nil { + log.Fatal(err) + } + if err := db.Ping(); err != nil { + log.Fatal(err) + } + var wg sync.WaitGroup + for i := 1; i < 10; i++ { + wg.Add(1) + go func(i int) { + _, err := db.Exec("SELECT 1 FROM (SELECT SLEEP(?)) as subselect;", i) + if err == nil || !strings.HasPrefix(err.Error(), "Error 1969:") { + t.Error("Expected to get Error 1969 (timeout), got", err) + } + wg.Done() + }(i) + } + wg.Wait() + _ = db.Close() +} diff --git a/sa/database.go b/sa/database.go index 08177682c..1737e9c8c 100644 --- a/sa/database.go +++ b/sa/database.go @@ -1,8 +1,11 @@ package sa import ( + "crypto/rand" "database/sql" "fmt" + "math" + "math/big" "net/url" "strings" "time" @@ -14,6 +17,7 @@ import ( "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" + "github.com/letsencrypt/boulder/prefixed_db" ) // NewDbMap creates the root gorp mapping object. Create one of these for each @@ -54,7 +58,39 @@ var setMaxOpenConns = func(db *sql.DB, maxOpenConns int) { func NewDbMapFromConfig(config *mysql.Config, maxOpenConns int) (*gorp.DbMap, error) { adjustMySQLConfig(config) - db, err := sqlOpen("mysql", config.FormatDSN()) + // We always want strict mode. Rather than leaving this up to DB config, we + // prefix each statement with it. + prefix := "SET STATEMENT sql_mode='STRICT_ALL_TABLES' FOR " + + // If a read timeout is set, we set max_statement_time to 95% of that, and + // long_query_time to 80% of that. That way we get logs of queries that are + // close to timing out but not yet doing so, and our queries get stopped by + // max_statement_time before timing out the read. This generates clearer + // errors, and avoids unnecessary reconnects. + if config.ReadTimeout != 0 { + // In MariaDB, max_statement_time and long_query_time are both seconds. + // Note: in MySQL (which we don't use), max_statement_time is millis. + readTimeout := config.ReadTimeout.Seconds() + prefix = fmt.Sprintf( + "SET STATEMENT max_statement_time=%g, long_query_time=%g, sql_mode='STRICT_ALL_TABLES' FOR ", + readTimeout*0.95, readTimeout*0.80) + } + + // The way we generate a customized database driver means that we need to + // choose a name to register the driver with. Because this function can be + // called multiple times with different parameters, we choose a random name + // each time we register to avoid conflicts with other DB instances. + // We use crypto/rand rather than math.Rand not out of a particular necessity + // for high-quality randomness, but simply because using non-crypto rand is a + // code smell. + driverNum, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, err + } + driverName := fmt.Sprintf("mysql-%d", driverNum) + sql.Register(driverName, prefixed_db.New(prefix, mysql.MySQLDriver{})) + + db, err := sqlOpen(driverName, config.FormatDSN()) if err != nil { return nil, err } @@ -67,25 +103,6 @@ func NewDbMapFromConfig(config *mysql.Config, maxOpenConns int) (*gorp.DbMap, er dbmap := &gorp.DbMap{Db: db, Dialect: dialect, TypeConverter: BoulderTypeConverter{}} initTables(dbmap) - _, err = dbmap.Exec("SET SESSION sql_mode = 'STRICT_ALL_TABLES';") - if err != nil { - return nil, err - } - // If a read timeout is set, we set max_statement_time to 95% of that, and - // long_query_time to 80% of that. That way we get logs of queries that are - // close to timing out but not yet doing so, and our queries get stopped by - // max_statement_time before timing out the read. This generates clearer - // errors, and avoids unnecessary reconnects. - if config.ReadTimeout != 0 { - // In MariaDB, max_statement_time and long_query_time are both seconds. - // Note: in MySQL (which we don't use), max_statement_time is millis. - readTimeout := config.ReadTimeout.Seconds() - _, err := dbmap.Exec("SET SESSION max_statement_time = ?, long_query_time = ?;", - readTimeout*0.95, readTimeout*0.80) - if err != nil { - return nil, err - } - } return dbmap, err } diff --git a/test/vars/vars.go b/test/vars/vars.go index 36c231a5a..f664c05f3 100644 --- a/test/vars/vars.go +++ b/test/vars/vars.go @@ -3,7 +3,7 @@ package vars import "fmt" const ( - dbURL = "mysql+tcp://%s@boulder-mysql:3306/%s" + dbURL = "%s@tcp(boulder-mysql:3306)/%s" ) var (