From ad6a26bf43f45cecab720de9eb202ecbf7a9abf2 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Thu, 13 Apr 2023 12:33:04 -0700 Subject: [PATCH] Add native BulkGet to state.mysql (#2762) Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- state/mysql/mysql.go | 138 ++++++++++++++++++++++------------- state/mysql/mysql_test.go | 14 ++-- tests/config/state/tests.yml | 4 +- 3 files changed, 95 insertions(+), 61 deletions(-) diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index b09915fb5..9d64e4c9a 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -76,8 +76,6 @@ const ( // MySQL state store. type MySQL struct { - state.BulkStore - tableName string metadataTableName string cleanupInterval *time.Duration @@ -111,9 +109,7 @@ func NewMySQLStateStore(logger logger.Logger) state.Store { // Store the provided logger and return the object. The rest of the // properties will be populated in the Init function - s := newMySQLStateStore(logger, factory) - s.BulkStore = state.NewDefaultBulkStore(s) - return s + return newMySQLStateStore(logger, factory) } // Hidden implementation for testing. @@ -311,7 +307,7 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error { cctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() _, err = m.db.ExecContext(cctx, - fmt.Sprintf("CREATE DATABASE %s;", m.schemaName), + "CREATE DATABASE "+m.schemaName, ) if err != nil { return err @@ -504,13 +500,13 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta defer cancel() if req.ETag == nil || *req.ETag == "" { - result, err = querier.ExecContext(execCtx, fmt.Sprintf( - `DELETE FROM %s WHERE id = ?`, - m.tableName), req.Key) + result, err = querier.ExecContext(execCtx, + `DELETE FROM `+m.tableName+` WHERE id = ?`, + req.Key) } else { - result, err = querier.ExecContext(execCtx, fmt.Sprintf( - `DELETE FROM %s WHERE id = ? AND eTag = ?`, - m.tableName), req.Key, *req.ETag) + result, err = querier.ExecContext(execCtx, + `DELETE FROM `+m.tableName+` WHERE id = ? AND eTag = ?`, + req.Key, *req.ETag) } if err != nil { @@ -559,59 +555,26 @@ func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error // Store Interface. func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if req.Key == "" { - return nil, fmt.Errorf("missing key in get operation") + return nil, errors.New("missing key in get operation") } - var ( - eTag string - value []byte - isBinary bool - ) - ctx, cancel := context.WithTimeout(parentCtx, m.timeout) defer cancel() - //nolint:gosec - query := fmt.Sprintf( - `SELECT value, eTag, isbinary FROM %s WHERE id = ? - AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`, - m.tableName, // m.tableName is sanitized - ) - err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary) + // Concatenation is required for table name because sql.DB does not substitute parameters for table names + query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ? + AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)` + row := m.db.QueryRowContext(ctx, query, req.Key) + _, value, etag, err := readRow(row) if err != nil { // If no rows exist, return an empty response, otherwise return an error. if errors.Is(err, sql.ErrNoRows) { return &state.GetResponse{}, nil } - return nil, err } - - if isBinary { - var ( - s string - data []byte - ) - - err = json.Unmarshal(value, &s) - if err != nil { - return nil, err - } - - data, err = base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, err - } - - return &state.GetResponse{ - Data: data, - ETag: &eTag, - Metadata: req.Metadata, - }, nil - } - return &state.GetResponse{ Data: value, - ETag: &eTag, + ETag: &etag, Metadata: req.Metadata, }, nil } @@ -761,6 +724,77 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state. return nil } +func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { + if len(req) == 0 { + return []state.BulkGetResponse{}, nil + } + + // MySQL doesn't support passing an array for an IN clause, so we need to build a custom query + inClause := strings.Repeat("?,", len(req)) + inClause = inClause[:(len(inClause) - 1)] + params := make([]any, len(req)) + for i, r := range req { + params[i] = r.Key + } + + // Concatenation is required for table name because sql.DB does not substitute parameters for table names + stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` + WHERE + id IN (` + inClause + `) + AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)` + ctx, cancel := context.WithTimeout(parentCtx, m.timeout) + defer cancel() + rows, err := m.db.QueryContext(ctx, stmt, params...) + if err != nil { + return nil, err + } + + var ( + n int + etag string + ) + res := make([]state.BulkGetResponse, len(req)) + for ; rows.Next(); n++ { + r := state.BulkGetResponse{} + r.Key, r.Data, etag, err = readRow(rows) + if err != nil { + r.Error = err.Error() + } + r.ETag = &etag + res[n] = r + } + + return res[:n], nil +} + +func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etag string, err error) { + var isBinary bool + err = row.Scan(&key, &value, &etag, &isBinary) + if err != nil { + return key, nil, "", err + } + + if isBinary { + var ( + s string + data []byte + ) + + err = json.Unmarshal(value, &s) + if err != nil { + return key, nil, "", fmt.Errorf("failed to unmarshal JSON binary data: %w", err) + } + + data, err = base64.StdEncoding.DecodeString(s) + if err != nil { + return key, nil, "", fmt.Errorf("failed to decode binary data: %w", err) + } + return key, data, etag, nil + } + + return key, value, etag, nil +} + // BulkSet adds/updates multiple entities on store // Store Interface. func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error { diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 8d173c15f..ebe02a1e4 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -433,7 +433,7 @@ func TestGetHandlesNoRows(t *testing.T) { m, _ := mockDatabase(t) defer m.mySQL.Close() - m.mock1.ExpectQuery("SELECT value").WillReturnRows(sqlmock.NewRows([]string{"value", "eTag"})) + m.mock1.ExpectQuery("SELECT id").WillReturnRows(sqlmock.NewRows([]string{"UnitTest", "value", "eTag"})) request := &state.GetRequest{ Key: "UnitTest", @@ -443,7 +443,7 @@ func TestGetHandlesNoRows(t *testing.T) { response, err := m.mySQL.Get(context.Background(), request) // Assert - assert.Nil(t, err, "returned error") + assert.NoError(t, err, "returned error") assert.NotNil(t, response, "did not return empty response") } @@ -460,7 +460,7 @@ func TestGetHandlesNoKey(t *testing.T) { response, err := m.mySQL.Get(context.Background(), request) // Assert - assert.NotNil(t, err, "returned error") + assert.Error(t, err, "returned error") assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned") assert.Nil(t, response, "returned response") } @@ -490,8 +490,8 @@ func TestGetSucceeds(t *testing.T) { defer m.mySQL.Close() t.Run("has json type", func(t *testing.T) { - rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow("{}", "946af56e", false) - m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) + rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false) + m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) request := &state.GetRequest{ Key: "UnitTest", @@ -508,8 +508,8 @@ func TestGetSucceeds(t *testing.T) { t.Run("has binary type", func(t *testing.T) { value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal) - rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow(value, "946af56e", true) - m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) + rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true) + m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) request := &state.GetRequest{ Key: "UnitTest", diff --git a/tests/config/state/tests.yml b/tests/config/state/tests.yml index 254e03d5b..e5b4b0195 100644 --- a/tests/config/state/tests.yml +++ b/tests/config/state/tests.yml @@ -33,10 +33,10 @@ components: operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] - component: mysql.mysql allOperations: false - operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] + operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] - component: mysql.mariadb allOperations: false - operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] + operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] - component: azure.tablestorage.storage allOperations: false operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]