MySQL binding: allow passing parameters for queries (#2975)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2023-07-12 14:16:35 -07:00 committed by GitHub
parent fd8e3a2086
commit 1349fca858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 134 deletions

View File

@ -17,17 +17,17 @@ binding:
- name: query - name: query
description: "The query operation is used for SELECT statements, which returns the metadata along with data in a form of an array of row values." description: "The query operation is used for SELECT statements, which returns the metadata along with data in a form of an array of row values."
- name: close - name: close
description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesnt have any response." description: "The close operation can be used to explicitly close the DB connection and return it to the pool. This operation doesn't have any response."
metadata: metadata:
- name: url - name: url
required: true required: true
description: "Represent a DB connection in Data Source Name (DNS) format." description: "Represent a DB connection in Data Source Name (DNS) format"
example: "user:password@tcp(localhost:3306)/dbname" example: '"user:password@tcp(localhost:3306)/dbname"'
type: string type: string
- name: pemPath - name: pemPath
required: false required: false
description: "Path to the PEM file. Used with SSL connection" description: "Path to the PEM file. Used with SSL connection"
example: "path/to/pem/file" example: '"path/to/pem/file"'
type: string type: string
- name: maxIdleConns - name: maxIdleConns
required: false required: false
@ -49,8 +49,3 @@ metadata:
description: "The max connection idel time." description: "The max connection idel time."
example: "12s" example: "12s"
type: duration type: duration
- name: maxRetries
required: false
description: "MaxRetries is the maximum number of retries for a query."
example: "5"
type: number

View File

@ -25,6 +25,7 @@ import (
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
"sync/atomic"
"time" "time"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
@ -52,7 +53,8 @@ const (
// "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'. // "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'.
// keys from request's metadata. // keys from request's metadata.
commandSQLKey = "sql" commandSQLKey = "sql"
commandParamsKey = "params"
// keys from response's metadata. // keys from response's metadata.
respOpKey = "operation" respOpKey = "operation"
@ -67,6 +69,7 @@ const (
type Mysql struct { type Mysql struct {
db *sql.DB db *sql.DB
logger logger.Logger logger logger.Logger
closed atomic.Bool
} }
type mysqlMetadata struct { type mysqlMetadata struct {
@ -87,21 +90,22 @@ type mysqlMetadata struct {
// ConnMaxIdleTime is the maximum amount of time a connection may be idle. // ConnMaxIdleTime is the maximum amount of time a connection may be idle.
ConnMaxIdleTime time.Duration `mapstructure:"connMaxIdleTime"` ConnMaxIdleTime time.Duration `mapstructure:"connMaxIdleTime"`
// MaxRetries is the maximum number of retries for a query.
MaxRetries int `mapstructure:"maxRetries"`
} }
// NewMysql returns a new MySQL output binding. // NewMysql returns a new MySQL output binding.
func NewMysql(logger logger.Logger) bindings.OutputBinding { func NewMysql(logger logger.Logger) bindings.OutputBinding {
return &Mysql{logger: logger} return &Mysql{
logger: logger,
}
} }
// Init initializes the MySQL binding. // Init initializes the MySQL binding.
func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error { func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error {
m.logger.Debug("Initializing MySql binding") if m.closed.Load() {
return errors.New("cannot initialize a previously-closed component")
}
// parse metadata // Parse metadata
meta := mysqlMetadata{} meta := mysqlMetadata{}
err := metadata.DecodeMetadata(md.Properties, &meta) err := metadata.DecodeMetadata(md.Properties, &meta)
if err != nil { if err != nil {
@ -112,23 +116,29 @@ func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error {
return fmt.Errorf("missing MySql connection string") return fmt.Errorf("missing MySql connection string")
} }
db, err := initDB(meta.URL, meta.PemPath) m.db, err = initDB(meta.URL, meta.PemPath)
if err != nil { if err != nil {
return err return err
} }
db.SetMaxIdleConns(meta.MaxIdleConns) if meta.MaxIdleConns > 0 {
db.SetMaxOpenConns(meta.MaxOpenConns) m.db.SetMaxIdleConns(meta.MaxIdleConns)
db.SetConnMaxIdleTime(meta.ConnMaxIdleTime) }
db.SetConnMaxLifetime(meta.ConnMaxLifetime) if meta.MaxOpenConns > 0 {
m.db.SetMaxOpenConns(meta.MaxOpenConns)
}
if meta.ConnMaxIdleTime > 0 {
m.db.SetConnMaxIdleTime(meta.ConnMaxIdleTime)
}
if meta.ConnMaxLifetime > 0 {
m.db.SetConnMaxLifetime(meta.ConnMaxLifetime)
}
err = db.PingContext(ctx) err = m.db.PingContext(ctx)
if err != nil { if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err) return fmt.Errorf("unable to ping the DB: %w", err)
} }
m.db = db
return nil return nil
} }
@ -138,22 +148,38 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
return nil, errors.New("invoke request required") return nil, errors.New("invoke request required")
} }
// We let the "close" operation here succeed even if the component has been closed already
if req.Operation == closeOperation { if req.Operation == closeOperation {
return nil, m.db.Close() return nil, m.Close()
}
if m.closed.Load() {
return nil, errors.New("component is closed")
} }
if req.Metadata == nil { if req.Metadata == nil {
return nil, errors.New("metadata required") return nil, errors.New("metadata required")
} }
m.logger.Debugf("operation: %v", req.Operation)
s, ok := req.Metadata[commandSQLKey] s := req.Metadata[commandSQLKey]
if !ok || s == "" { if s == "" {
return nil, fmt.Errorf("required metadata not set: %s", commandSQLKey) return nil, fmt.Errorf("required metadata not set: %s", commandSQLKey)
} }
startTime := time.Now() // Metadata property "params" contains JSON-encoded parameters, and it's optional
// If present, it must be unserializable into a []any object
var (
params []any
err error
)
if paramsStr := req.Metadata[commandParamsKey]; paramsStr != "" {
err = json.Unmarshal([]byte(paramsStr), &params)
if err != nil {
return nil, fmt.Errorf("invalid metadata property %s: failed to unserialize into an array: %w", commandParamsKey, err)
}
}
startTime := time.Now().UTC()
resp := &bindings.InvokeResponse{ resp := &bindings.InvokeResponse{
Metadata: map[string]string{ Metadata: map[string]string{
respOpKey: string(req.Operation), respOpKey: string(req.Operation),
@ -162,16 +188,16 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
}, },
} }
switch req.Operation { //nolint:exhaustive switch req.Operation {
case execOperation: case execOperation:
r, err := m.exec(ctx, s) r, err := m.exec(ctx, s, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10) resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10)
case queryOperation: case queryOperation:
d, err := m.query(ctx, s) d, err := m.query(ctx, s, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,7 +208,7 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
req.Operation, execOperation, queryOperation, closeOperation) req.Operation, execOperation, queryOperation, closeOperation)
} }
endTime := time.Now() endTime := time.Now().UTC()
resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano) resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano)
resp.Metadata[respDurationKey] = endTime.Sub(startTime).String() resp.Metadata[respDurationKey] = endTime.Sub(startTime).String()
@ -200,23 +226,26 @@ func (m *Mysql) Operations() []bindings.OperationKind {
// Close will close the DB. // Close will close the DB.
func (m *Mysql) Close() error { func (m *Mysql) Close() error {
if !m.closed.CompareAndSwap(false, true) {
// If this failed, the component has already been closed
// We allow multiple calls to close
return nil
}
if m.db != nil { if m.db != nil {
return m.db.Close() m.db.Close()
m.db = nil
} }
return nil return nil
} }
func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) { func (m *Mysql) query(ctx context.Context, sql string, params ...any) ([]byte, error) {
rows, err := m.db.QueryContext(ctx, sql) rows, err := m.db.QueryContext(ctx, sql, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("error executing query: %w", err) return nil, fmt.Errorf("error executing query: %w", err)
} }
defer rows.Close()
defer func() {
_ = rows.Close()
_ = rows.Err()
}()
result, err := m.jsonify(rows) result, err := m.jsonify(rows)
if err != nil { if err != nil {
@ -226,10 +255,8 @@ func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) {
return result, nil return result, nil
} }
func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) { func (m *Mysql) exec(ctx context.Context, sql string, params ...any) (int64, error) {
m.logger.Debugf("exec: %s", sql) res, err := m.db.ExecContext(ctx, sql, params...)
res, err := m.db.ExecContext(ctx, sql)
if err != nil { if err != nil {
return 0, fmt.Errorf("error executing query: %w", err) return 0, fmt.Errorf("error executing query: %w", err)
} }
@ -238,13 +265,15 @@ func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) {
} }
func initDB(url, pemPath string) (*sql.DB, error) { func initDB(url, pemPath string) (*sql.DB, error) {
if _, err := mysql.ParseDSN(url); err != nil { conf, err := mysql.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey) return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey)
} }
if pemPath != "" { if pemPath != "" {
var pem []byte
rootCertPool := x509.NewCertPool() rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(pemPath) pem, err = os.ReadFile(pemPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("error reading PEM file from %s: %w", pemPath, err) return nil, fmt.Errorf("error reading PEM file from %s: %w", pemPath, err)
} }
@ -254,17 +283,25 @@ func initDB(url, pemPath string) (*sql.DB, error) {
return nil, fmt.Errorf("failed to append PEM") return nil, fmt.Errorf("failed to append PEM")
} }
err = mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool, MinVersion: tls.VersionTLS12}) err = mysql.RegisterTLSConfig("custom", &tls.Config{
RootCAs: rootCertPool,
MinVersion: tls.VersionTLS12,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("error register TLS config: %w", err) return nil, fmt.Errorf("error register TLS config: %w", err)
} }
} }
db, err := sql.Open("mysql", url) // Required to correctly parse time columns
// See: https://stackoverflow.com/a/45040724
conf.ParseTime = true
connector, err := mysql.NewConnector(conf)
if err != nil { if err != nil {
return nil, fmt.Errorf("error opening DB connection: %w", err) return nil, fmt.Errorf("error opening DB connection: %w", err)
} }
db := sql.OpenDB(connector)
return db, nil return db, nil
} }
@ -274,7 +311,7 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
return nil, err return nil, err
} }
var ret []interface{} var ret []any
for rows.Next() { for rows.Next() {
values := prepareValues(columnTypes) values := prepareValues(columnTypes)
err := rows.Scan(values...) err := rows.Scan(values...)
@ -289,13 +326,13 @@ func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
return json.Marshal(ret) return json.Marshal(ret)
} }
func prepareValues(columnTypes []*sql.ColumnType) []interface{} { func prepareValues(columnTypes []*sql.ColumnType) []any {
types := make([]reflect.Type, len(columnTypes)) types := make([]reflect.Type, len(columnTypes))
for i, tp := range columnTypes { for i, tp := range columnTypes {
types[i] = tp.ScanType() types[i] = tp.ScanType()
} }
values := make([]interface{}, len(columnTypes)) values := make([]any, len(columnTypes))
for i := range values { for i := range values {
values[i] = reflect.New(types[i]).Interface() values[i] = reflect.New(types[i]).Interface()
} }
@ -303,8 +340,8 @@ func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
return values return values
} }
func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} { func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []any) map[string]any {
r := map[string]interface{}{} r := map[string]any{}
for i, ct := range columnTypes { for i, ct := range columnTypes {
value := values[i] value := values[i]
@ -312,7 +349,7 @@ func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map
switch v := values[i].(type) { switch v := values[i].(type) {
case driver.Valuer: case driver.Valuer:
if vv, err := v.Value(); err == nil { if vv, err := v.Value(); err == nil {
value = interface{}(vv) value = any(vv)
} else { } else {
m.logger.Warnf("error to convert value: %v", err) m.logger.Warnf("error to convert value: %v", err)
} }

View File

@ -22,36 +22,20 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/bindings" "github.com/dapr/components-contrib/bindings"
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
const ( // MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator.
// MySQL doesn't accept RFC3339 formatted time, rejects trailing 'Z' for UTC indicator. const mySQLDateTimeFormat = "2006-01-02 15:04:05"
mySQLDateTimeFormat = "2006-01-02 15:04:05"
testCreateTable = `CREATE TABLE IF NOT EXISTS foo (
id bigint NOT NULL,
v1 character varying(50) NOT NULL,
b BOOLEAN,
ts TIMESTAMP,
data LONGTEXT)`
testDropTable = `DROP TABLE foo`
testInsert = "INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')"
testDelete = "DELETE FROM foo"
testUpdate = "UPDATE foo SET ts = '%v' WHERE id = %d"
testSelect = "SELECT * FROM foo WHERE id < 3"
testSelectJSONExtract = "SELECT JSON_EXTRACT(data, '$.key') AS `key` FROM foo WHERE id < 3"
)
func TestOperations(t *testing.T) { func TestOperations(t *testing.T) {
t.Parallel()
t.Run("Get operation list", func(t *testing.T) { t.Run("Get operation list", func(t *testing.T) {
t.Parallel() b := NewMysql(logger.NewLogger("test"))
b := NewMysql(nil) require.NotNil(t, b)
assert.NotNil(t, b)
l := b.Operations() l := b.Operations()
assert.Equal(t, 3, len(l)) assert.Equal(t, 3, len(l))
assert.Contains(t, l, execOperation) assert.Contains(t, l, execOperation)
@ -70,123 +54,165 @@ func TestOperations(t *testing.T) {
func TestMysqlIntegration(t *testing.T) { func TestMysqlIntegration(t *testing.T) {
url := os.Getenv("MYSQL_TEST_CONN_URL") url := os.Getenv("MYSQL_TEST_CONN_URL")
if url == "" { if url == "" {
t.SkipNow() t.Skip("Skipping because env var MYSQL_TEST_CONN_URL is empty")
} }
b := NewMysql(logger.NewLogger("test")).(*Mysql) b := NewMysql(logger.NewLogger("test")).(*Mysql)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err) err := b.Init(context.Background(), m)
} require.NoError(t, err)
defer b.Close() defer b.Close()
req := &bindings.InvokeRequest{Metadata: map[string]string{}}
t.Run("Invoke create table", func(t *testing.T) { t.Run("Invoke create table", func(t *testing.T) {
req.Operation = execOperation res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata[commandSQLKey] = testCreateTable Operation: execOperation,
res, err := b.Invoke(context.TODO(), req) Metadata: map[string]string{
commandSQLKey: `CREATE TABLE IF NOT EXISTS foo (
id bigint NOT NULL,
v1 character varying(50) NOT NULL,
b BOOLEAN,
ts TIMESTAMP,
data LONGTEXT)`,
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
}) })
t.Run("Invoke delete", func(t *testing.T) { t.Run("Invoke delete", func(t *testing.T) {
req.Operation = execOperation res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata[commandSQLKey] = testDelete Operation: execOperation,
res, err := b.Invoke(context.TODO(), req) Metadata: map[string]string{
commandSQLKey: "DELETE FROM foo",
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
}) })
t.Run("Invoke insert", func(t *testing.T) { t.Run("Invoke insert", func(t *testing.T) {
req.Operation = execOperation
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
req.Metadata[commandSQLKey] = fmt.Sprintf(testInsert, i, i, true, time.Now().Format(mySQLDateTimeFormat), "{\"key\":\"val\"}") res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
res, err := b.Invoke(context.TODO(), req) Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: fmt.Sprintf(
"INSERT INTO foo (id, v1, b, ts, data) VALUES (%d, 'test-%d', %t, '%v', '%s')",
i, i, true, time.Now().Format(mySQLDateTimeFormat), `{"key":"val"}`),
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
} }
}) })
t.Run("Invoke update", func(t *testing.T) { t.Run("Invoke update", func(t *testing.T) {
req.Operation = execOperation date := time.Now().Add(time.Hour)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
req.Metadata[commandSQLKey] = fmt.Sprintf(testUpdate, time.Now().Format(mySQLDateTimeFormat), i) res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
res, err := b.Invoke(context.TODO(), req) Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: fmt.Sprintf(
"UPDATE foo SET ts = '%v' WHERE id = %d",
date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i),
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
assert.Equal(t, "1", res.Metadata[respRowsAffectedKey])
}
})
t.Run("Invoke update with parameters", func(t *testing.T) {
date := time.Now().Add(2 * time.Hour)
for i := 0; i < 10; i++ {
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Operation: execOperation,
Metadata: map[string]string{
commandSQLKey: "UPDATE foo SET ts = ? WHERE id = ?",
commandParamsKey: fmt.Sprintf(`[%q,%d]`, date.Add(10*time.Duration(i)*time.Second).Format(mySQLDateTimeFormat), i),
},
})
assertResponse(t, res, err)
assert.Equal(t, "1", res.Metadata[respRowsAffectedKey])
} }
}) })
t.Run("Invoke select", func(t *testing.T) { t.Run("Invoke select", func(t *testing.T) {
req.Operation = queryOperation res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata[commandSQLKey] = testSelect Operation: queryOperation,
res, err := b.Invoke(context.TODO(), req) Metadata: map[string]string{
commandSQLKey: "SELECT * FROM foo WHERE id < 3",
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
t.Logf("received result: %s", res.Data) t.Logf("received result: %s", res.Data)
// verify number, boolean and string // verify number, boolean and string
assert.Contains(t, string(res.Data), "\"id\":1") assert.Contains(t, string(res.Data), `"id":1`)
assert.Contains(t, string(res.Data), "\"b\":1") assert.Contains(t, string(res.Data), `"b":1`)
assert.Contains(t, string(res.Data), "\"v1\":\"test-1\"") assert.Contains(t, string(res.Data), `"v1":"test-1"`)
assert.Contains(t, string(res.Data), "\"data\":\"{\\\"key\\\":\\\"val\\\"}\"") assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`)
result := make([]interface{}, 0) result := make([]any, 0)
err = json.Unmarshal(res.Data, &result) err = json.Unmarshal(res.Data, &result)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 3, len(result)) assert.Equal(t, 3, len(result))
// verify timestamp // verify timestamp
ts, ok := result[0].(map[string]interface{})["ts"].(string) ts, ok := result[0].(map[string]any)["ts"].(string)
assert.True(t, ok) assert.True(t, ok)
// have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615 // have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615
var tt time.Time var tt time.Time
tt, err = time.Parse("2006-01-02T15:04:05Z", ts) tt, err = time.Parse("2006-01-02T15:04:05Z", ts)
assert.Nil(t, err) require.NoError(t, err)
t.Logf("time stamp is: %v", tt) t.Logf("time stamp is: %v", tt)
}) })
t.Run("Invoke select JSON_EXTRACT", func(t *testing.T) { t.Run("Invoke select with parameters", func(t *testing.T) {
req.Operation = queryOperation res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata[commandSQLKey] = testSelectJSONExtract Operation: queryOperation,
res, err := b.Invoke(context.TODO(), req) Metadata: map[string]string{
commandSQLKey: "SELECT * FROM foo WHERE id = ?",
commandParamsKey: `[1]`,
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
t.Logf("received result: %s", res.Data) t.Logf("received result: %s", res.Data)
// verify json extract number // verify number, boolean and string
assert.Contains(t, string(res.Data), "{\"key\":\"\\\"val\\\"\"}") assert.Contains(t, string(res.Data), `"id":1`)
assert.Contains(t, string(res.Data), `"b":1`)
assert.Contains(t, string(res.Data), `"v1":"test-1"`)
assert.Contains(t, string(res.Data), `"data":"{\"key\":\"val\"}"`)
result := make([]interface{}, 0) result := make([]any, 0)
err = json.Unmarshal(res.Data, &result) err = json.Unmarshal(res.Data, &result)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 3, len(result)) assert.Equal(t, 1, len(result))
})
t.Run("Invoke delete", func(t *testing.T) {
req.Operation = execOperation
req.Metadata[commandSQLKey] = testDelete
req.Data = nil
res, err := b.Invoke(context.TODO(), req)
assertResponse(t, res, err)
}) })
t.Run("Invoke drop", func(t *testing.T) { t.Run("Invoke drop", func(t *testing.T) {
req.Operation = execOperation res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata[commandSQLKey] = testDropTable Operation: execOperation,
res, err := b.Invoke(context.TODO(), req) Metadata: map[string]string{
commandSQLKey: "DROP TABLE foo",
},
})
assertResponse(t, res, err) assertResponse(t, res, err)
}) })
t.Run("Invoke close", func(t *testing.T) { t.Run("Invoke close", func(t *testing.T) {
req.Operation = closeOperation _, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
req.Metadata = nil Operation: closeOperation,
req.Data = nil })
_, err := b.Invoke(context.TODO(), req)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) {
t.Helper()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, res) assert.NotNil(t, res)
if res != nil { if res != nil {
assert.NotNil(t, res.Metadata) assert.NotEmpty(t, res.Metadata)
} }
} }

View File

@ -42,7 +42,7 @@ func TestQuery(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
t.Logf("query result: %s", ret) t.Logf("query result: %s", ret)
assert.Contains(t, string(ret), "\"id\":1") assert.Contains(t, string(ret), "\"id\":1")
var result []interface{} var result []any
err = json.Unmarshal(ret, &result) err = json.Unmarshal(ret, &result)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 3, len(result)) assert.Equal(t, 3, len(result))
@ -65,13 +65,13 @@ func TestQuery(t *testing.T) {
assert.Contains(t, string(ret), "\"id\":1") assert.Contains(t, string(ret), "\"id\":1")
assert.Contains(t, string(ret), "\"value\":2.2") assert.Contains(t, string(ret), "\"value\":2.2")
var result []interface{} var result []any
err = json.Unmarshal(ret, &result) err = json.Unmarshal(ret, &result)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 3, len(result)) assert.Equal(t, 3, len(result))
// verify timestamp // verify timestamp
ts, ok := result[0].(map[string]interface{})["timestamp"].(string) ts, ok := result[0].(map[string]any)["timestamp"].(string)
assert.True(t, ok) assert.True(t, ok)
var tt time.Time var tt time.Time
tt, err = time.Parse(time.RFC3339, ts) tt, err = time.Parse(time.RFC3339, ts)
@ -134,7 +134,7 @@ func TestInvoke(t *testing.T) {
} }
resp, err := m.Invoke(context.Background(), req) resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, err) assert.Nil(t, err)
var data []interface{} var data []any
err = json.Unmarshal(resp.Data, &data) err = json.Unmarshal(resp.Data, &data)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(data)) assert.Equal(t, 1, len(data))