fix: Mysql should support more data types. #923 (#926)

* fix: Mysql should support more data types. #923

* go fmt

* fix lint issue

* revise according to the review comment

Co-authored-by: Phil Kedy <phil.kedy@gmail.com>
Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
This commit is contained in:
Ian Luo 2021-07-09 01:57:18 +08:00 committed by GitHub
parent 7d2bc9bbdf
commit 0777a6a943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 65 deletions

View File

@ -9,9 +9,11 @@ import (
"crypto/tls"
"crypto/x509"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"io/ioutil"
"reflect"
"strconv"
"time"
@ -203,7 +205,7 @@ func (m *Mysql) query(s string) ([]byte, error) {
_ = rows.Err()
}()
result, err := jsonify(rows)
result, err := m.jsonify(rows)
if err != nil {
return nil, errors.Wrapf(err, "error marshalling query result for %s", s)
}
@ -277,7 +279,7 @@ func initDB(url, pemPath string) (*sql.DB, error) {
return db, nil
}
func jsonify(rows *sql.Rows) ([]byte, error) {
func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
@ -285,84 +287,58 @@ func jsonify(rows *sql.Rows) ([]byte, error) {
var ret []interface{}
for rows.Next() {
scanArgs := prepareScanArgs(columnTypes)
err := rows.Scan(scanArgs...)
values := prepareValues(columnTypes)
err := rows.Scan(values...)
if err != nil {
return nil, err
}
r := convertScanArgs(columnTypes, scanArgs)
r := m.convert(columnTypes, values)
ret = append(ret, r)
}
return json.Marshal(ret)
}
func convertScanArgs(columnTypes []*sql.ColumnType, scanArgs []interface{}) map[string]interface{} {
func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
types := make([]reflect.Type, len(columnTypes))
for i, tp := range columnTypes {
types[i] = tp.ScanType()
}
values := make([]interface{}, len(columnTypes))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
return values
}
func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} {
r := map[string]interface{}{}
for i, v := range columnTypes {
if s, ok := (scanArgs[i]).(*sql.NullString); ok {
r[v.Name()] = s.String
for i, ct := range columnTypes {
value := values[i]
continue
switch v := values[i].(type) {
case driver.Valuer:
if vv, err := v.Value(); err == nil {
value = interface{}(vv)
} else {
m.logger.Warnf("error to convert value: %v", err)
}
case *sql.RawBytes:
// special case for sql.RawBytes, see https://github.com/go-sql-driver/mysql/blob/master/fields.go#L178
switch ct.DatabaseTypeName() {
case "VARCHAR", "CHAR":
value = string(*v)
}
}
if s, ok := (scanArgs[i]).(*sql.NullBool); ok {
r[v.Name()] = s.Bool
continue
if value != nil {
r[ct.Name()] = value
}
if s, ok := (scanArgs[i]).(*sql.NullInt32); ok {
r[v.Name()] = s.Int32
continue
}
if s, ok := (scanArgs[i]).(*sql.NullInt64); ok {
r[v.Name()] = s.Int64
continue
}
if s, ok := (scanArgs[i]).(*sql.NullFloat64); ok {
r[v.Name()] = s.Float64
continue
}
if s, ok := (scanArgs[i]).(*sql.NullTime); ok {
r[v.Name()] = s.Time
continue
}
// this won't happen since the default switch is sql.NullString
r[v.Name()] = scanArgs[i]
}
return r
}
func prepareScanArgs(columnTypes []*sql.ColumnType) []interface{} {
scanArgs := make([]interface{}, len(columnTypes))
for i, v := range columnTypes {
switch v.DatabaseTypeName() {
case "BOOL":
scanArgs[i] = new(sql.NullBool)
case "INT", "MEDIUMINT", "SMALLINT", "CHAR", "TINYINT":
scanArgs[i] = new(sql.NullInt32)
case "BIGINT":
scanArgs[i] = new(sql.NullInt64)
case "DOUBLE", "FLOAT", "DECIMAL":
scanArgs[i] = new(sql.NullFloat64)
case "DATE", "TIME", "YEAR":
scanArgs[i] = new(sql.NullTime)
default:
scanArgs[i] = new(sql.NullString)
}
}
return scanArgs
}

View File

@ -121,7 +121,7 @@ func TestMysqlIntegration(t *testing.T) {
assert.True(t, ok)
// have to use custom layout to parse timestamp, see this: https://github.com/dapr/components-contrib/pull/615
var tt time.Time
tt, err = time.Parse("2006-01-02 15:04:05", ts)
tt, err = time.Parse("2006-01-02T15:04:05Z", ts)
assert.Nil(t, err)
t.Logf("time stamp is: %v", tt)
})

View File

@ -26,7 +26,7 @@ func TestQuery(t *testing.T) {
ret, err := m.query(`SELECT * FROM foo WHERE id < 4`)
assert.Nil(t, err)
t.Logf("query result: %s", ret)
assert.Contains(t, string(ret), "\"id\":\"1\"")
assert.Contains(t, string(ret), "\"id\":1")
var result []interface{}
err = json.Unmarshal(ret, &result)
assert.Nil(t, err)