mysql outbinding support (#615)
* mysql outbinding support * fix lint issues * use time duration string instead of number of seconds * correct typo * use 'addr' instead of 'server' and 'port' in order to align with url in dsn format * simplify configuration, and allow PEM configurable * jsonfy the query result * add more unit test for mysql binding * add unit test to verify timestamp * add type verify in integration test * add test to verify BOOLEAN * update comment Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
This commit is contained in:
parent
11454d7ccf
commit
97912e75c6
|
|
@ -0,0 +1,369 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// list of operations.
|
||||
execOperation bindings.OperationKind = "exec"
|
||||
queryOperation bindings.OperationKind = "query"
|
||||
closeOperation bindings.OperationKind = "close"
|
||||
|
||||
// configurations to connect to Mysql, either a data source name represent by URL
|
||||
connectionURLKey = "url"
|
||||
|
||||
// To connect to MySQL running in Azure over SSL you have to download a
|
||||
// SSL certificate. If this is provided the driver will connect using
|
||||
// SSL. If you have disable SSL you can leave this empty.
|
||||
// When the user provides a pem path their connection string must end with
|
||||
// &tls=custom
|
||||
// The connection string should be in the following format
|
||||
// "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'
|
||||
pemPathKey = "pemPath"
|
||||
|
||||
// other general settings for DB connections
|
||||
maxIdleConnsKey = "maxIdleConns"
|
||||
maxOpenConnsKey = "maxOpenConns"
|
||||
connMaxLifetimeKey = "connMaxLifetime"
|
||||
connMaxIdleTimeKey = "connMaxIdleTime"
|
||||
|
||||
// keys from request's metadata
|
||||
commandSQLKey = "sql"
|
||||
|
||||
// keys from response's metadata
|
||||
respOpKey = "operation"
|
||||
respSQLKey = "sql"
|
||||
respStartTimeKey = "start-time"
|
||||
respRowsAffectedKey = "rows-affected"
|
||||
respEndTimeKey = "end-time"
|
||||
respDurationKey = "duration"
|
||||
)
|
||||
|
||||
// Mysql represents MySQL output bindings
|
||||
type Mysql struct {
|
||||
db *sql.DB
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
var _ = bindings.OutputBinding(&Mysql{})
|
||||
|
||||
// NewMysql returns a new MySQL output binding
|
||||
func NewMysql(logger logger.Logger) *Mysql {
|
||||
return &Mysql{logger: logger}
|
||||
}
|
||||
|
||||
// Init initializes the MySQL binding
|
||||
func (m *Mysql) Init(metadata bindings.Metadata) error {
|
||||
m.logger.Debug("Initializing MySql binding")
|
||||
|
||||
p := metadata.Properties
|
||||
url, ok := p[connectionURLKey]
|
||||
if !ok || url == "" {
|
||||
return fmt.Errorf("missing MySql connection string")
|
||||
}
|
||||
|
||||
db, err := initDB(url, metadata.Properties[pemPathKey])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = propertyToInt(p, maxIdleConnsKey, db.SetMaxIdleConns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = propertyToInt(p, maxOpenConnsKey, db.SetMaxOpenConns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = propertyToDuration(p, connMaxIdleTimeKey, db.SetConnMaxIdleTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = propertyToDuration(p, connMaxLifetimeKey, db.SetConnMaxLifetime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unable to ping the DB")
|
||||
}
|
||||
|
||||
m.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invoke handles all invoke operations
|
||||
func (m *Mysql) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
|
||||
if req == nil {
|
||||
return nil, errors.Errorf("invoke request required")
|
||||
}
|
||||
|
||||
if req.Operation == closeOperation {
|
||||
return nil, m.db.Close()
|
||||
}
|
||||
|
||||
if req.Metadata == nil {
|
||||
return nil, errors.Errorf("metadata required")
|
||||
}
|
||||
m.logger.Debugf("operation: %v", req.Operation)
|
||||
|
||||
s, ok := req.Metadata[commandSQLKey]
|
||||
if !ok || s == "" {
|
||||
return nil, errors.Errorf("required metadata not set: %s", commandSQLKey)
|
||||
}
|
||||
|
||||
startTime := time.Now().UTC()
|
||||
|
||||
resp := &bindings.InvokeResponse{
|
||||
Metadata: map[string]string{
|
||||
respOpKey: string(req.Operation),
|
||||
respSQLKey: s,
|
||||
respStartTimeKey: startTime.Format(time.RFC3339Nano),
|
||||
},
|
||||
}
|
||||
|
||||
switch req.Operation { // nolint: exhaustive
|
||||
case execOperation:
|
||||
r, err := m.exec(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error executing %s with %v", s, err)
|
||||
}
|
||||
resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10)
|
||||
|
||||
case queryOperation:
|
||||
d, err := m.query(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error executing %s with %v", s, err)
|
||||
}
|
||||
resp.Data = d
|
||||
|
||||
default:
|
||||
return nil, errors.Errorf("invalid operation type: %s. Expected %s, %s, or %s",
|
||||
req.Operation, execOperation, queryOperation, closeOperation)
|
||||
}
|
||||
|
||||
endTime := time.Now().UTC()
|
||||
resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano)
|
||||
resp.Metadata[respDurationKey] = endTime.Sub(startTime).String()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Operations returns list of operations supported by Mysql binding
|
||||
func (m *Mysql) Operations() []bindings.OperationKind {
|
||||
return []bindings.OperationKind{
|
||||
execOperation,
|
||||
queryOperation,
|
||||
closeOperation,
|
||||
}
|
||||
}
|
||||
|
||||
// Close will close the DB
|
||||
func (m *Mysql) Close() error {
|
||||
if m.db != nil {
|
||||
return m.db.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) query(s string) ([]byte, error) {
|
||||
m.logger.Debugf("query: %s", s)
|
||||
|
||||
rows, err := m.db.Query(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error executing %s", s)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
_ = rows.Err()
|
||||
}()
|
||||
|
||||
result, err := jsonify(rows)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error marshalling query result for %s", s)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) exec(sql string) (int64, error) {
|
||||
m.logger.Debugf("exec: %s", sql)
|
||||
|
||||
res, err := m.db.Exec(sql)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "error executing %s", sql)
|
||||
}
|
||||
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func propertyToInt(props map[string]string, key string, setter func(int)) error {
|
||||
if v, ok := props[key]; ok {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
setter(i)
|
||||
} else {
|
||||
return errors.Wrapf(err, "error converitng %s:%s to int", key, v)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func propertyToDuration(props map[string]string, key string, setter func(time.Duration)) error {
|
||||
if v, ok := props[key]; ok {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
setter(d)
|
||||
} else {
|
||||
return errors.Wrapf(err, "error converitng %s:%s to time duration", key, v)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initDB(url, pemPath string) (*sql.DB, error) {
|
||||
if _, err := mysql.ParseDSN(url); err != nil {
|
||||
return nil, errors.Wrapf(err, "illegal Data Source Name (DNS) specified by %s", connectionURLKey)
|
||||
}
|
||||
|
||||
if pemPath != "" {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
pem, err := ioutil.ReadFile(pemPath)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Error reading PEM file from %s", pemPath)
|
||||
}
|
||||
|
||||
ok := rootCertPool.AppendCertsFromPEM(pem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to append PEM")
|
||||
}
|
||||
|
||||
err = mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool, MinVersion: tls.VersionTLS12})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error register TLS config")
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("mysql", url)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error opening DB connection")
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func jsonify(rows *sql.Rows) ([]byte, error) {
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []interface{}
|
||||
for rows.Next() {
|
||||
scanArgs := prepareScanArgs(columnTypes)
|
||||
err := rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := convertScanArgs(columnTypes, scanArgs)
|
||||
ret = append(ret, r)
|
||||
}
|
||||
|
||||
return json.Marshal(ret)
|
||||
}
|
||||
|
||||
func convertScanArgs(columnTypes []*sql.ColumnType, scanArgs []interface{}) map[string]interface{} {
|
||||
r := map[string]interface{}{}
|
||||
|
||||
for i, v := range columnTypes {
|
||||
if s, ok := (scanArgs[i]).(*sql.NullString); ok {
|
||||
r[v.Name()] = s.String
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullBool); ok {
|
||||
r[v.Name()] = s.Bool
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullInt32); ok {
|
||||
r[v.Name()] = s.Int32
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullInt64); ok {
|
||||
r[v.Name()] = s.Int64
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullFloat64); ok {
|
||||
r[v.Name()] = s.Float64
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullTime); ok {
|
||||
r[v.Name()] = s.Time
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// this won't happen since the default switch is sql.NullString
|
||||
r[v.Name()] = scanArgs[i]
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func prepareScanArgs(columnTypes []*sql.ColumnType) []interface{} {
|
||||
scanArgs := make([]interface{}, len(columnTypes))
|
||||
for i, v := range columnTypes {
|
||||
switch v.DatabaseTypeName() {
|
||||
case "BOOL":
|
||||
scanArgs[i] = new(sql.NullBool)
|
||||
case "INT", "MEDIUMINT", "SMALLINT", "CHAR", "TINYINT":
|
||||
scanArgs[i] = new(sql.NullInt32)
|
||||
case "BIGINT":
|
||||
scanArgs[i] = new(sql.NullInt64)
|
||||
case "DOUBLE", "FLOAT", "DECIMAL":
|
||||
scanArgs[i] = new(sql.NullFloat64)
|
||||
case "DATE", "TIME", "YEAR":
|
||||
scanArgs[i] = new(sql.NullTime)
|
||||
default:
|
||||
scanArgs[i] = new(sql.NullString)
|
||||
}
|
||||
}
|
||||
|
||||
return scanArgs
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
testCreateTable = `CREATE TABLE IF NOT EXISTS foo (
|
||||
id bigint NOT NULL,
|
||||
v1 character varying(50) NOT NULL,
|
||||
b BOOLEAN,
|
||||
ts TIMESTAMP)`
|
||||
testDropTable = `DROP TABLE foo`
|
||||
testInsert = "INSERT INTO foo (id, v1, b, ts) VALUES (%d, 'test-%d', %t, '%v')"
|
||||
testDelete = "DELETE FROM foo"
|
||||
testUpdate = "UPDATE foo SET ts = '%v' WHERE id = %d"
|
||||
testSelect = "SELECT * FROM foo WHERE id < 3"
|
||||
)
|
||||
|
||||
func TestOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Get operation list", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
b := NewMysql(nil)
|
||||
assert.NotNil(t, b)
|
||||
l := b.Operations()
|
||||
assert.Equal(t, 3, len(l))
|
||||
assert.Contains(t, l, execOperation)
|
||||
assert.Contains(t, l, closeOperation)
|
||||
assert.Contains(t, l, queryOperation)
|
||||
})
|
||||
}
|
||||
|
||||
// SETUP TESTS
|
||||
// 1. `CREATE DATABASE daprtest;`
|
||||
// 2. `CREATE USER daprtest;`
|
||||
// 3. `GRANT ALL PRIVILEGES ON daprtest.* to daprtest;`
|
||||
// 4. `export MYSQL_TEST_CONN_URL=daprtest@tcp(localhost:3306)/daprtest`
|
||||
// 5. `go test -v -count=1 ./bindings/mysql -run ^TestMysqlIntegrationWithURL`
|
||||
|
||||
func TestMysqlIntegration(t *testing.T) {
|
||||
url := os.Getenv("MYSQL_TEST_CONN_URL")
|
||||
if url == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
b := NewMysql(logger.NewLogger("test"))
|
||||
m := bindings.Metadata{Properties: map[string]string{connectionURLKey: url}}
|
||||
if err := b.Init(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer b.Close()
|
||||
|
||||
req := &bindings.InvokeRequest{Metadata: map[string]string{}}
|
||||
|
||||
t.Run("Invoke create table", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
req.Metadata[commandSQLKey] = testCreateTable
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
})
|
||||
|
||||
t.Run("Invoke delete", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
req.Metadata[commandSQLKey] = testDelete
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
})
|
||||
|
||||
t.Run("Invoke insert", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
for i := 0; i < 10; i++ {
|
||||
req.Metadata[commandSQLKey] = fmt.Sprintf(testInsert, i, i, true, time.Now().Format(time.RFC3339))
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invoke update", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
for i := 0; i < 10; i++ {
|
||||
req.Metadata[commandSQLKey] = fmt.Sprintf(testUpdate, time.Now().Format(time.RFC3339), i)
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invoke select", func(t *testing.T) {
|
||||
req.Operation = queryOperation
|
||||
req.Metadata[commandSQLKey] = testSelect
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
t.Logf("received result: %s", res.Data)
|
||||
|
||||
// verify number, boolean and string
|
||||
assert.Contains(t, string(res.Data), "\"id\":1")
|
||||
assert.Contains(t, string(res.Data), "\"b\":1")
|
||||
assert.Contains(t, string(res.Data), "\"v1\":\"test-1\"")
|
||||
|
||||
result := make([]interface{}, 0)
|
||||
err = json.Unmarshal(res.Data, &result)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(result))
|
||||
|
||||
// verify timestamp
|
||||
ts, ok := result[0].(map[string]interface{})["ts"].(string)
|
||||
assert.True(t, ok)
|
||||
// have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615
|
||||
var tt time.Time
|
||||
tt, err = time.Parse("2006-01-02 15:04:05", ts)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("time stamp is: %v", tt)
|
||||
})
|
||||
|
||||
t.Run("Invoke delete", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
req.Metadata[commandSQLKey] = testDelete
|
||||
req.Data = nil
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
})
|
||||
|
||||
t.Run("Invoke drop", func(t *testing.T) {
|
||||
req.Operation = execOperation
|
||||
req.Metadata[commandSQLKey] = testDropTable
|
||||
res, err := b.Invoke(req)
|
||||
assertResponse(t, res, err)
|
||||
})
|
||||
|
||||
t.Run("Invoke close", func(t *testing.T) {
|
||||
req.Operation = closeOperation
|
||||
req.Metadata = nil
|
||||
req.Data = nil
|
||||
_, err := b.Invoke(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.NotNil(t, res.Metadata)
|
||||
}
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
m, mock, _ := mockDatabase(t)
|
||||
defer m.Close()
|
||||
|
||||
t.Run("no dbType provided", func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "timestamp"}).
|
||||
AddRow(1, "value-1", time.Now()).
|
||||
AddRow(2, "value-2", time.Now().Add(1000)).
|
||||
AddRow(3, "value-3", time.Now().Add(2000))
|
||||
|
||||
mock.ExpectQuery("SELECT \\* FROM foo WHERE id < 4").WillReturnRows(rows)
|
||||
ret, err := m.query(`SELECT * FROM foo WHERE id < 4`)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("query result: %s", ret)
|
||||
assert.Contains(t, string(ret), "\"id\":\"1\"")
|
||||
var result []interface{}
|
||||
err = json.Unmarshal(ret, &result)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(result))
|
||||
})
|
||||
|
||||
t.Run("dbType provided", func(t *testing.T) {
|
||||
col1 := sqlmock.NewColumn("id").OfType("BIGINT", 1)
|
||||
col2 := sqlmock.NewColumn("value").OfType("FLOAT", 1.0)
|
||||
col3 := sqlmock.NewColumn("timestamp").OfType("TIME", time.Now())
|
||||
rows := sqlmock.NewRowsWithColumnDefinition(col1, col2, col3).
|
||||
AddRow(1, 1.1, time.Now()).
|
||||
AddRow(2, 2.2, time.Now().Add(1000)).
|
||||
AddRow(3, 3.3, time.Now().Add(2000))
|
||||
mock.ExpectQuery("SELECT \\* FROM foo WHERE id < 4").WillReturnRows(rows)
|
||||
ret, err := m.query("SELECT * FROM foo WHERE id < 4")
|
||||
assert.Nil(t, err)
|
||||
t.Logf("query result: %s", ret)
|
||||
|
||||
// verify number
|
||||
assert.Contains(t, string(ret), "\"id\":1")
|
||||
assert.Contains(t, string(ret), "\"value\":2.2")
|
||||
|
||||
var result []interface{}
|
||||
err = json.Unmarshal(ret, &result)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(result))
|
||||
|
||||
// verify timestamp
|
||||
ts, ok := result[0].(map[string]interface{})["timestamp"].(string)
|
||||
assert.True(t, ok)
|
||||
var tt time.Time
|
||||
tt, err = time.Parse(time.RFC3339, ts)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("time stamp is: %v", tt)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
m, mock, _ := mockDatabase(t)
|
||||
defer m.Close()
|
||||
mock.ExpectExec("INSERT INTO foo \\(id, v1, ts\\) VALUES \\(.*\\)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
i, err := m.exec("INSERT INTO foo (id, v1, ts) VALUES (1, 'test-1', '2021-01-22')")
|
||||
assert.Equal(t, int64(1), i)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestInvoke(t *testing.T) {
|
||||
m, mock, _ := mockDatabase(t)
|
||||
defer m.Close()
|
||||
|
||||
t.Run("exec operation succeeds", func(t *testing.T) {
|
||||
mock.ExpectExec("INSERT INTO foo \\(id, v1, ts\\) VALUES \\(.*\\)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
metadata := map[string]string{commandSQLKey: "INSERT INTO foo (id, v1, ts) VALUES (1, 'test-1', '2021-01-22')"}
|
||||
req := &bindings.InvokeRequest{
|
||||
Data: nil,
|
||||
Metadata: metadata,
|
||||
Operation: execOperation,
|
||||
}
|
||||
resp, err := m.Invoke(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "1", resp.Metadata[respRowsAffectedKey])
|
||||
})
|
||||
|
||||
t.Run("exec operation fails", func(t *testing.T) {
|
||||
mock.ExpectExec("INSERT INTO foo \\(id, v1, ts\\) VALUES \\(.*\\)").WillReturnError(errors.New("insert failed"))
|
||||
metadata := map[string]string{commandSQLKey: "INSERT INTO foo (id, v1, ts) VALUES (1, 'test-1', '2021-01-22')"}
|
||||
req := &bindings.InvokeRequest{
|
||||
Data: nil,
|
||||
Metadata: metadata,
|
||||
Operation: execOperation,
|
||||
}
|
||||
resp, err := m.Invoke(req)
|
||||
assert.Nil(t, resp)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("query operation succeeds", func(t *testing.T) {
|
||||
col1 := sqlmock.NewColumn("id").OfType("BIGINT", 1)
|
||||
col2 := sqlmock.NewColumn("value").OfType("FLOAT", 1.0)
|
||||
col3 := sqlmock.NewColumn("timestamp").OfType("TIME", time.Now())
|
||||
rows := sqlmock.NewRowsWithColumnDefinition(col1, col2, col3).AddRow(1, 1.1, time.Now())
|
||||
mock.ExpectQuery("SELECT \\* FROM foo WHERE id < \\d+").WillReturnRows(rows)
|
||||
|
||||
metadata := map[string]string{commandSQLKey: "SELECT * FROM foo WHERE id < 2"}
|
||||
req := &bindings.InvokeRequest{
|
||||
Data: nil,
|
||||
Metadata: metadata,
|
||||
Operation: queryOperation,
|
||||
}
|
||||
resp, err := m.Invoke(req)
|
||||
assert.Nil(t, err)
|
||||
var data []interface{}
|
||||
err = json.Unmarshal(resp.Data, &data)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(data))
|
||||
})
|
||||
|
||||
t.Run("query operation fails", func(t *testing.T) {
|
||||
mock.ExpectQuery("SELECT \\* FROM foo WHERE id < \\d+").WillReturnError(errors.New("query failed"))
|
||||
metadata := map[string]string{commandSQLKey: "SELECT * FROM foo WHERE id < 2"}
|
||||
req := &bindings.InvokeRequest{
|
||||
Data: nil,
|
||||
Metadata: metadata,
|
||||
Operation: queryOperation,
|
||||
}
|
||||
resp, err := m.Invoke(req)
|
||||
assert.Nil(t, resp)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("close operation", func(t *testing.T) {
|
||||
mock.ExpectClose()
|
||||
req := &bindings.InvokeRequest{
|
||||
Operation: closeOperation,
|
||||
}
|
||||
resp, _ := m.Invoke(req)
|
||||
assert.Nil(t, resp)
|
||||
})
|
||||
|
||||
t.Run("unsupported operation", func(t *testing.T) {
|
||||
req := &bindings.InvokeRequest{
|
||||
Data: nil,
|
||||
Metadata: map[string]string{},
|
||||
Operation: "unsupported",
|
||||
}
|
||||
resp, err := m.Invoke(req)
|
||||
assert.Nil(t, resp)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func mockDatabase(t *testing.T) (*Mysql, sqlmock.Sqlmock, error) {
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
|
||||
m := NewMysql(logger.NewLogger("test"))
|
||||
m.db = db
|
||||
|
||||
return m, mock, err
|
||||
}
|
||||
2
go.mod
2
go.mod
|
|
@ -16,6 +16,7 @@ require (
|
|||
github.com/Azure/go-autorest/autorest v0.11.12
|
||||
github.com/Azure/go-autorest/autorest/adal v0.9.5
|
||||
github.com/Azure/go-autorest/autorest/azure/auth v0.4.2
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/Shopify/sarama v1.23.1
|
||||
github.com/a8m/documentdb v1.2.1-0.20190920062420-efdd52fe0905
|
||||
github.com/aerospike/aerospike-client-go v2.7.0+incompatible
|
||||
|
|
@ -35,6 +36,7 @@ require (
|
|||
github.com/eclipse/paho.mqtt.golang v1.2.0
|
||||
github.com/fasthttp-contrib/sessions v0.0.0-20160905201309-74f6ac73d5d5
|
||||
github.com/go-redis/redis/v7 v7.0.1
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/gocql/gocql v0.0.0-20191018090344-07ace3bab0f8
|
||||
github.com/golang/mock v1.4.4
|
||||
github.com/golang/protobuf v1.4.3
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -154,6 +154,8 @@ github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBp
|
|||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
|
||||
github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||
github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798 h1:2T/jmrHeTezcCM58lvEQXs0UpQJCo5SoGAcg+mbSTIg=
|
||||
github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
|
||||
|
|
|
|||
Loading…
Reference in New Issue