* fix: Mysql should support more data types. #923 * go fmt * fix lint issue * revise according to the review comment Co-authored-by: Phil Kedy <phil.kedy@gmail.com> Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
This commit is contained in:
parent
7d2bc9bbdf
commit
0777a6a943
|
|
@ -9,9 +9,11 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
|
@ -203,7 +205,7 @@ func (m *Mysql) query(s string) ([]byte, error) {
|
|||
_ = rows.Err()
|
||||
}()
|
||||
|
||||
result, err := jsonify(rows)
|
||||
result, err := m.jsonify(rows)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error marshalling query result for %s", s)
|
||||
}
|
||||
|
|
@ -277,7 +279,7 @@ func initDB(url, pemPath string) (*sql.DB, error) {
|
|||
return db, nil
|
||||
}
|
||||
|
||||
func jsonify(rows *sql.Rows) ([]byte, error) {
|
||||
func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -285,84 +287,58 @@ func jsonify(rows *sql.Rows) ([]byte, error) {
|
|||
|
||||
var ret []interface{}
|
||||
for rows.Next() {
|
||||
scanArgs := prepareScanArgs(columnTypes)
|
||||
err := rows.Scan(scanArgs...)
|
||||
values := prepareValues(columnTypes)
|
||||
err := rows.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := convertScanArgs(columnTypes, scanArgs)
|
||||
r := m.convert(columnTypes, values)
|
||||
ret = append(ret, r)
|
||||
}
|
||||
|
||||
return json.Marshal(ret)
|
||||
}
|
||||
|
||||
func convertScanArgs(columnTypes []*sql.ColumnType, scanArgs []interface{}) map[string]interface{} {
|
||||
func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
|
||||
types := make([]reflect.Type, len(columnTypes))
|
||||
for i, tp := range columnTypes {
|
||||
types[i] = tp.ScanType()
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columnTypes))
|
||||
for i := range values {
|
||||
values[i] = reflect.New(types[i]).Interface()
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} {
|
||||
r := map[string]interface{}{}
|
||||
|
||||
for i, v := range columnTypes {
|
||||
if s, ok := (scanArgs[i]).(*sql.NullString); ok {
|
||||
r[v.Name()] = s.String
|
||||
for i, ct := range columnTypes {
|
||||
value := values[i]
|
||||
|
||||
continue
|
||||
switch v := values[i].(type) {
|
||||
case driver.Valuer:
|
||||
if vv, err := v.Value(); err == nil {
|
||||
value = interface{}(vv)
|
||||
} else {
|
||||
m.logger.Warnf("error to convert value: %v", err)
|
||||
}
|
||||
case *sql.RawBytes:
|
||||
// special case for sql.RawBytes, see https://github.com/go-sql-driver/mysql/blob/master/fields.go#L178
|
||||
switch ct.DatabaseTypeName() {
|
||||
case "VARCHAR", "CHAR":
|
||||
value = string(*v)
|
||||
}
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullBool); ok {
|
||||
r[v.Name()] = s.Bool
|
||||
|
||||
continue
|
||||
if value != nil {
|
||||
r[ct.Name()] = value
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullInt32); ok {
|
||||
r[v.Name()] = s.Int32
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullInt64); ok {
|
||||
r[v.Name()] = s.Int64
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullFloat64); ok {
|
||||
r[v.Name()] = s.Float64
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := (scanArgs[i]).(*sql.NullTime); ok {
|
||||
r[v.Name()] = s.Time
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// this won't happen since the default switch is sql.NullString
|
||||
r[v.Name()] = scanArgs[i]
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func prepareScanArgs(columnTypes []*sql.ColumnType) []interface{} {
|
||||
scanArgs := make([]interface{}, len(columnTypes))
|
||||
for i, v := range columnTypes {
|
||||
switch v.DatabaseTypeName() {
|
||||
case "BOOL":
|
||||
scanArgs[i] = new(sql.NullBool)
|
||||
case "INT", "MEDIUMINT", "SMALLINT", "CHAR", "TINYINT":
|
||||
scanArgs[i] = new(sql.NullInt32)
|
||||
case "BIGINT":
|
||||
scanArgs[i] = new(sql.NullInt64)
|
||||
case "DOUBLE", "FLOAT", "DECIMAL":
|
||||
scanArgs[i] = new(sql.NullFloat64)
|
||||
case "DATE", "TIME", "YEAR":
|
||||
scanArgs[i] = new(sql.NullTime)
|
||||
default:
|
||||
scanArgs[i] = new(sql.NullString)
|
||||
}
|
||||
}
|
||||
|
||||
return scanArgs
|
||||
}
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ func TestMysqlIntegration(t *testing.T) {
|
|||
assert.True(t, ok)
|
||||
// have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615
|
||||
var tt time.Time
|
||||
tt, err = time.Parse("2006-01-02 15:04:05", ts)
|
||||
tt, err = time.Parse("2006-01-02T15:04:05Z", ts)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("time stamp is: %v", tt)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
|
|||
ret, err := m.query(`SELECT * FROM foo WHERE id < 4`)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("query result: %s", ret)
|
||||
assert.Contains(t, string(ret), "\"id\":\"1\"")
|
||||
assert.Contains(t, string(ret), "\"id\":1")
|
||||
var result []interface{}
|
||||
err = json.Unmarshal(ret, &result)
|
||||
assert.Nil(t, err)
|
||||
|
|
|
|||
Loading…
Reference in New Issue