MySQL binding: allow passing parameters for queries (#2975)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2023-07-12 14:16:35 -07:00 committed by GitHub
parent fd8e3a2086
commit 1349fca858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 134 deletions

View File

@ -17,17 +17,17 @@ binding:
- name: query
description: "The query operation is used for SELECT statements, which returns the metadata along with data in a form of an array of row values."
- name: close
description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesnt have any response."
description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesn't have any response."
metadata:
- name: url
required: true
description: "Represent a DB connection in Data Source Name (DNS) format."
example: "user:password@tcp(localhost:3306)/dbname"
description: "Represent a DB connection in Data Source Name (DNS) format"
example: '"user:password@tcp(localhost:3306)/dbname"'
type: string
- name: pemPath
required: false
description: "Path to the PEM file. Used with SSL connection"
example: "path/to/pem/file"
example: '"path/to/pem/file"'
type: string
- name: maxIdleConns
required: false
@ -49,8 +49,3 @@ metadata:
description: "The max connection idel time."
example: "12s"
type: duration
- name: maxRetries
required: false
description: "MaxRetries is the maximum number of retries for a query."
example: "5"
type: number

View File

@ -25,6 +25,7 @@ import (
"os"
"reflect"
"strconv"
"sync/atomic"
"time"
"github.com/go-sql-driver/mysql"
@ -52,7 +53,8 @@ const (
// "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'.
// keys from request's metadata.
commandSQLKey = "sql"
commandSQLKey = "sql"
commandParamsKey = "params"
// keys from response's metadata.
respOpKey = "operation"
@ -67,6 +69,7 @@ const (
type Mysql struct {
db *sql.DB
logger logger.Logger
closed atomic.Bool
}
type mysqlMetadata struct {
@ -87,21 +90,22 @@ type mysqlMetadata struct {
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
ConnMaxIdleTime time.Duration `mapstructure:"connMaxIdleTime"`
// MaxRetries is the maximum number of retries for a query.
MaxRetries int `mapstructure:"maxRetries"`
}
// NewMysql returns a new MySQL output binding.
func NewMysql(logger logger.Logger) bindings.OutputBinding {
return &Mysql{logger: logger}
return &Mysql{
logger: logger,
}
}
// Init initializes the MySQL binding.
func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error {
m.logger.Debug("Initializing MySql binding")
if m.closed.Load() {
return errors.New("cannot initialize a previously-closed component")
}
// parse metadata
// Parse metadata
meta := mysqlMetadata{}
err := metadata.DecodeMetadata(md.Properties, &meta)
if err != nil {
@ -112,23 +116,29 @@ func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error {
return fmt.Errorf("missing MySql connection string")
}
db, err := initDB(meta.URL, meta.PemPath)
m.db, err = initDB(meta.URL, meta.PemPath)
if err != nil {
return err
}
db.SetMaxIdleConns(meta.MaxIdleConns)
db.SetMaxOpenConns(meta.MaxOpenConns)
db.SetConnMaxIdleTime(meta.ConnMaxIdleTime)
db.SetConnMaxLifetime(meta.ConnMaxLifetime)
if meta.MaxIdleConns > 0 {
m.db.SetMaxIdleConns(meta.MaxIdleConns)
}
if meta.MaxOpenConns > 0 {
m.db.SetMaxOpenConns(meta.MaxOpenConns)
}
if meta.ConnMaxIdleTime > 0 {
m.db.SetConnMaxIdleTime(meta.ConnMaxIdleTime)
}
if meta.ConnMaxLifetime > 0 {
m.db.SetConnMaxLifetime(meta.ConnMaxLifetime)
}
err = db.PingContext(ctx)
err = m.db.PingContext(ctx)
if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err)
}
m.db = db
return nil
}
@ -138,22 +148,38 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
return nil, errors.New("invoke request required")
}
// We let the "close" operation here succeed even if the component has been closed already
if req.Operation == closeOperation {
return nil, m.db.Close()
return nil, m.Close()
}
if m.closed.Load() {
return nil, errors.New("component is closed")
}
if req.Metadata == nil {
return nil, errors.New("metadata required")
}
m.logger.Debugf("operation: %v", req.Operation)
s, ok := req.Metadata[commandSQLKey]
if !ok || s == "" {
s := req.Metadata[commandSQLKey]
if s == "" {
return nil, fmt.Errorf("required metadata not set: %s", commandSQLKey)
}
startTime := time.Now()
// Metadata property "params" contains JSON-encoded parameters, and it's optional
// If present, it must be unserializable into a []any object
var (
params []any
err error
)
if paramsStr := req.Metadata[commandParamsKey]; paramsStr != "" {
err = json.Unmarshal([]byte(paramsStr), &params)
if err != nil {
return nil, fmt.Errorf("invalid metadata property %s: failed to unserialize into an array: %w", commandParamsKey, err)
}
}
startTime := time.Now().UTC()
resp := &bindings.InvokeResponse{
Metadata: map[string]string{
respOpKey: string(req.Operation),
@ -162,16 +188,16 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
},
}
switch req.Operation { //nolint:exhaustive
switch req.Operation {
case execOperation:
r, err := m.exec(ctx, s)
r, err := m.exec(ctx, s, params...)
if err != nil {
return nil, err
}
resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10)
case queryOperation:
d, err := m.query(ctx, s)
d, err := m.query(ctx, s, params...)
if err != nil {
return nil, err
}
@ -182,7 +208,7 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
req.Operation, execOperation, queryOperation, closeOperation)
}
endTime := time.Now()
endTime := time.Now().UTC()
resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano)
resp.Metadata[respDurationKey] = endTime.Sub(startTime).String()
@ -200,23 +226,26 @@ func (m *Mysql) Operations() []bindings.OperationKind {
// Close will close the DB.
func (m *Mysql) Close() error {
if !m.closed.CompareAndSwap(false, true) {
// If this failed, the component has already been closed
// We allow multiple calls to close
return nil
}
if m.db != nil {
return m.db.Close()
m.db.Close()
m.db = nil
}
return nil
}
func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) {
rows, err := m.db.QueryContext(ctx, sql)
func (m *Mysql) query(ctx context.Context, sql string, params ...any) ([]byte, error) {
rows, err := m.db.QueryContext(ctx, sql, params...)
if err != nil {
return nil, fmt.Errorf("error executing query: %w", err)
}
defer func() {
_ = rows.Close()
_ = rows.Err()
}()
defer rows.Close()
result, err := m.jsonify(rows)
if err != nil {
@ -226,10 +255,8 @@ func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) {
return result, nil
}
func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) {
m.logger.Debugf("exec: %s", sql)
res, err := m.db.ExecContext(ctx, sql)
func (m *Mysql) exec(ctx context.Context, sql string, params ...any) (int64, error) {
res, err := m.db.ExecContext(ctx, sql, params...)
if err != nil {
return 0, fmt.Errorf("error executing query: %w", err)
}
@ -238,13 +265,15 @@ func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) {
}
func initDB(url, pemPath string) (*sql.DB, error) {
if _, err := mysql.ParseDSN(url); err != nil {
conf, err := mysql.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey)
}
if pemPath != "" {
var pem []byte
rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(pemPath)
pem, err = os.ReadFile(pemPath)
if err != nil {
return nil, fmt.Errorf("error reading PEM file from %s: %w", pemPath, err)
}
@ -254,17 +283,25 @@ func initDB(url, pemPath string) (*sql.DB, error) {
return nil, fmt.Errorf("failed to append PEM")
}
err = mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool, MinVersion: tls.VersionTLS12})
err = mysql.RegisterTLSConfig("custom", &tls.Config{
RootCAs: rootCertPool,
MinVersion: tls.VersionTLS12,
})
if err != nil {
return nil, fmt.Errorf("error register TLS config: %w", err)
}
}
db, err := sql.Open("mysql", url)
// Required to correctly parse time columns
// See: https://stackoverflow.com/a/45040724
conf.ParseTime = true
connector, err := mysql.NewConnector(conf)
if err != nil {
return nil, fmt.Errorf("error opening DB connection: %w", err)
}
db := sql.OpenDB(connector)
return db, nil
}
@ -274,7 +311,7 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
return nil, err
}
var ret []interface{}
var ret []any
for rows.Next() {
values := prepareValues(columnTypes)
err := rows.Scan(values...)
@ -289,13 +326,13 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
return json.Marshal(ret)
}
func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
func prepareValues(columnTypes []*sql.ColumnType) []any {
types := make([]reflect.Type, len(columnTypes))
for i, tp := range columnTypes {
types[i] = tp.ScanType()
}
values := make([]interface{}, len(columnTypes))
values := make([]any, len(columnTypes))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
@ -303,8 +340,8 @@ func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
return values
}
func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} {
r := map[string]interface{}{}
func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []any) map[string]any {
r := map[string]any{}
for i, ct := range columnTypes {
value := values[i]
@ -312,7 +349,7 @@ func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map
switch v := values[i].(type) {
case driver.Valuer:
if vv, err := v.Value(); err == nil {
value = interface{}(vv)
value = any(vv)
} else {
m.logger.Warnf("error to convert value: %v", err)
}

View File

@ -22,36 +22,20 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
)
const (
// MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator.
mySQLDateTimeFormat = "2006-01-02 15:04:05"
testCreateTable = `CREATE TABLE IF NOT EXISTS foo (
id bigint NOT NULL,
v1 character varying(50) NOT NULL,
b BOOLEAN,
ts TIMESTAMP,
data LONGTEXT)`
testDropTable = `DROP TABLE foo`
testInsert = "INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')"
testDelete = "DELETE FROM foo"
testUpdate = "UPDATE foo SET ts = '%v' WHERE id = %d"
testSelect = "SELECT * FROM foo WHERE id < 3"
testSelectJSONExtract = "SELECT JSON_EXTRACT(data, '$.key') AS `key` FROM foo WHERE id < 3"
)
// MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator.
const mySQLDateTimeFormat = "2006-01-02 15:04:05"
func TestOperations(t *testing.T) {
t.Parallel()
t.Run("Get operation list", func(t *testing.T) {
t.Parallel()
b := NewMysql(nil)
assert.NotNil(t, b)
b := NewMysql(logger.NewLogger("test"))
require.NotNil(t, b)
l := b.Operations()
assert.Equal(t, 3, len(l))
assert.Contains(t, l, execOperation)
@ -70,123 +54,165 @@ func TestOperations(t *testing.T) {
func TestMysqlIntegration(t *testing.T) {
url := os.Getenv("MYSQL_TEST_CONN_URL")
if url == "" {
t.SkipNow()
t.Skip("Skipping because env var MYSQL_TEST_CONN_URL is empty")
}
b := NewMysql(logger.NewLogger("test")).(*Mysql)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err)
}
err := b.Init(context.Background(), m)
require.NoError(t, err)
defer b.Close()
req := &bindings.InvokeRequest{Metadata: map[string]string{}}
t.Run("Invoke create table", func(t *testing.T) {
req.Operation = execOperation
req.Metadata[commandSQLKey] = testCreateTable
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: `CREATE TABLE IF NOT EXISTS foo (
id bigint NOT NULL,
v1 character varying(50) NOT NULL,
b BOOLEAN,
ts TIMESTAMP,
data LONGTEXT)`,
},
})
assertResponse(t, res, err)
})
t.Run("Invoke delete", func(t *testing.T) {
req.Operation = execOperation
req.Metadata[commandSQLKey] = testDelete
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: "DELETE FROM foo",
},
})
assertResponse(t, res, err)
})
t.Run("Invoke insert", func(t *testing.T) {
req.Operation = execOperation
for i := 0; i < 10; i++ {
req.Metadata[commandSQLKey] = fmt.Sprintf(testInsert, i, i, true, time.Now().Format(mySQLDateTimeFormat), "{\"key\":\"val\"}")
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: fmt.Sprintf(
"INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')",
i, i, true, time.Now().Format(mySQLDateTimeFormat), `{"key":"val"}`),
},
})
assertResponse(t, res, err)
}
})
t.Run("Invoke update", func(t *testing.T) {
req.Operation = execOperation
date := time.Now().Add(time.Hour)
for i := 0; i < 10; i++ {
req.Metadata[commandSQLKey] = fmt.Sprintf(testUpdate, time.Now().Format(mySQLDateTimeFormat), i)
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: fmt.Sprintf(
"UPDATE foo SET ts = '%v' WHERE id = %d",
date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i),
},
})
assertResponse(t, res, err)
assert.Equal(t, "1", res.Metadata[respRowsAffectedKey])
}
})
t.Run("Invoke update with parameters", func(t *testing.T) {
date := time.Now().Add(2 * time.Hour)
for i := 0; i < 10; i++ {
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: "UPDATE foo SET ts = ? WHERE id = ?",
commandParamsKey: fmt.Sprintf(`[%q,%d]`, date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i),
},
})
assertResponse(t, res, err)
assert.Equal(t, "1", res.Metadata[respRowsAffectedKey])
}
})
t.Run("Invoke select", func(t *testing.T) {
req.Operation = queryOperation
req.Metadata[commandSQLKey] = testSelect
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: queryOperation,
Metadata: map[string]string{
commandSQLKey: "SELECT * FROM foo WHERE id < 3",
},
})
assertResponse(t, res, err)
t.Logf("received result: %s", res.Data)
// verify number, boolean and string
assert.Contains(t, string(res.Data), "\"id\":1")
assert.Contains(t, string(res.Data), "\"b\":1")
assert.Contains(t, string(res.Data), "\"v1\":\"test-1\"")
assert.Contains(t, string(res.Data), "\"data\":\"{\\\"key\\\":\\\"val\\\"}\"")
assert.Contains(t, string(res.Data), `"id":1`)
assert.Contains(t, string(res.Data), `"b":1`)
assert.Contains(t, string(res.Data), `"v1":"test-1"`)
assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`)
result := make([]interface{}, 0)
result := make([]any, 0)
err = json.Unmarshal(res.Data, &result)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, 3, len(result))
// verify timestamp
ts, ok := result[0].(map[string]interface{})["ts"].(string)
ts, ok := result[0].(map[string]any)["ts"].(string)
assert.True(t, ok)
// have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615
var tt time.Time
tt, err = time.Parse("2006-01-02T15:04:05Z", ts)
assert.Nil(t, err)
require.NoError(t, err)
t.Logf("time stamp is: %v", tt)
})
t.Run("Invoke select JSON_EXTRACT", func(t *testing.T) {
req.Operation = queryOperation
req.Metadata[commandSQLKey] = testSelectJSONExtract
res, err := b.Invoke(context.TODO(), req)
t.Run("Invoke select with parameters", func(t *testing.T) {
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: queryOperation,
Metadata: map[string]string{
commandSQLKey: "SELECT * FROM foo WHERE id = ?",
commandParamsKey: `[1]`,
},
})
assertResponse(t, res, err)
t.Logf("received result: %s", res.Data)
// verify json extract number
assert.Contains(t, string(res.Data), "{\"key\":\"\\\"val\\\"\"}")
// verify number, boolean and string
assert.Contains(t, string(res.Data), `"id":1`)
assert.Contains(t, string(res.Data), `"b":1`)
assert.Contains(t, string(res.Data), `"v1":"test-1"`)
assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`)
result := make([]interface{}, 0)
result := make([]any, 0)
err = json.Unmarshal(res.Data, &result)
assert.Nil(t, err)
assert.Equal(t, 3, len(result))
})
t.Run("Invoke delete", func(t *testing.T) {
req.Operation = execOperation
req.Metadata[commandSQLKey] = testDelete
req.Data = nil
res, err := b.Invoke(context.TODO(), req)
assertResponse(t, res, err)
require.NoError(t, err)
assert.Equal(t, 1, len(result))
})
t.Run("Invoke drop", func(t *testing.T) {
req.Operation = execOperation
req.Metadata[commandSQLKey] = testDropTable
res, err := b.Invoke(context.TODO(), req)
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: "DROP TABLE foo",
},
})
assertResponse(t, res, err)
})
t.Run("Invoke close", func(t *testing.T) {
req.Operation = closeOperation
req.Metadata = nil
req.Data = nil
_, err := b.Invoke(context.TODO(), req)
_, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: closeOperation,
})
assert.NoError(t, err)
})
}
func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) {
t.Helper()
assert.NoError(t, err)
assert.NotNil(t, res)
if res != nil {
assert.NotNil(t, res.Metadata)
assert.NotEmpty(t, res.Metadata)
}
}

View File

@ -42,7 +42,7 @@ func TestQuery(t *testing.T) {
assert.Nil(t, err)
t.Logf("query result: %s", ret)
assert.Contains(t, string(ret), "\"id\":1")
var result []interface{}
var result []any
err = json.Unmarshal(ret, &result)
assert.Nil(t, err)
assert.Equal(t, 3, len(result))
@ -65,13 +65,13 @@ func TestQuery(t *testing.T) {
assert.Contains(t, string(ret), "\"id\":1")
assert.Contains(t, string(ret), "\"value\":2.2")
var result []interface{}
var result []any
err = json.Unmarshal(ret, &result)
assert.Nil(t, err)
assert.Equal(t, 3, len(result))
// verify timestamp
ts, ok := result[0].(map[string]interface{})["timestamp"].(string)
ts, ok := result[0].(map[string]any)["timestamp"].(string)
assert.True(t, ok)
var tt time.Time
tt, err = time.Parse(time.RFC3339, ts)
@ -134,7 +134,7 @@ func TestInvoke(t *testing.T) {
}
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, err)
var data []interface{}
var data []any
err = json.Unmarshal(resp.Data, &data)
assert.Nil(t, err)
assert.Equal(t, 1, len(data))