Add native BulkGet to state.mysql (#2762)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2023-04-13 12:33:04 -07:00 committed by GitHub
parent 13193ffdd8
commit ad6a26bf43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 61 deletions

View File

@ -76,8 +76,6 @@ const (
// MySQL state store.
type MySQL struct {
state.BulkStore
tableName string
metadataTableName string
cleanupInterval *time.Duration
@ -111,9 +109,7 @@ func NewMySQLStateStore(logger logger.Logger) state.Store {
// Store the provided logger and return the object. The rest of the
// properties will be populated in the Init function
s := newMySQLStateStore(logger, factory)
s.BulkStore = state.NewDefaultBulkStore(s)
return s
return newMySQLStateStore(logger, factory)
}
// Hidden implementation for testing.
@ -311,7 +307,7 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
cctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
_, err = m.db.ExecContext(cctx,
fmt.Sprintf("CREATE DATABASE %s;", m.schemaName),
"CREATE DATABASE "+m.schemaName,
)
if err != nil {
return err
@ -504,13 +500,13 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
defer cancel()
if req.ETag == nil || *req.ETag == "" {
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ?`,
m.tableName), req.Key)
result, err = querier.ExecContext(execCtx,
`DELETE FROM `+m.tableName+` WHERE id = ?`,
req.Key)
} else {
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ? AND eTag = ?`,
m.tableName), req.Key, *req.ETag)
result, err = querier.ExecContext(execCtx,
`DELETE FROM `+m.tableName+` WHERE id = ? AND eTag = ?`,
req.Key, *req.ETag)
}
if err != nil {
@ -559,59 +555,26 @@ func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error
// Store Interface.
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
return nil, errors.New("missing key in get operation")
}
var (
eTag string
value []byte
isBinary bool
)
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
//nolint:gosec
query := fmt.Sprintf(
`SELECT value, eTag, isbinary FROM %s WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`,
m.tableName, // m.tableName is sanitized
)
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
row := m.db.QueryRowContext(ctx, query, req.Key)
_, value, etag, err := readRow(row)
if err != nil {
// If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) {
return &state.GetResponse{}, nil
}
return nil, err
}
if isBinary {
var (
s string
data []byte
)
err = json.Unmarshal(value, &s)
if err != nil {
return nil, err
}
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, err
}
return &state.GetResponse{
Data: data,
ETag: &eTag,
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: value,
ETag: &eTag,
ETag: &etag,
Metadata: req.Metadata,
}, nil
}
@ -761,6 +724,77 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
return nil
}
func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) {
if len(req) == 0 {
return []state.BulkGetResponse{}, nil
}
// MySQL doesn't support passing an array for an IN clause, so we need to build a custom query
inClause := strings.Repeat("?,", len(req))
inClause = inClause[:(len(inClause) - 1)]
params := make([]any, len(req))
for i, r := range req {
params[i] = r.Key
}
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
WHERE
id IN (` + inClause + `)
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
rows, err := m.db.QueryContext(ctx, stmt, params...)
if err != nil {
return nil, err
}
var (
n int
etag string
)
res := make([]state.BulkGetResponse, len(req))
for ; rows.Next(); n++ {
r := state.BulkGetResponse{}
r.Key, r.Data, etag, err = readRow(rows)
if err != nil {
r.Error = err.Error()
}
r.ETag = &etag
res[n] = r
}
return res[:n], nil
}
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etag string, err error) {
var isBinary bool
err = row.Scan(&key, &value, &etag, &isBinary)
if err != nil {
return key, nil, "", err
}
if isBinary {
var (
s string
data []byte
)
err = json.Unmarshal(value, &s)
if err != nil {
return key, nil, "", fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
}
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return key, nil, "", fmt.Errorf("failed to decode binary data: %w", err)
}
return key, data, etag, nil
}
return key, value, etag, nil
}
// BulkSet adds/updates multiple entities on store
// Store Interface.
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {

View File

@ -433,7 +433,7 @@ func TestGetHandlesNoRows(t *testing.T) {
m, _ := mockDatabase(t)
defer m.mySQL.Close()
m.mock1.ExpectQuery("SELECT value").WillReturnRows(sqlmock.NewRows([]string{"value", "eTag"}))
m.mock1.ExpectQuery("SELECT id").WillReturnRows(sqlmock.NewRows([]string{"UnitTest", "value", "eTag"}))
request := &state.GetRequest{
Key: "UnitTest",
@ -443,7 +443,7 @@ func TestGetHandlesNoRows(t *testing.T) {
response, err := m.mySQL.Get(context.Background(), request)
// Assert
assert.Nil(t, err, "returned error")
assert.NoError(t, err, "returned error")
assert.NotNil(t, response, "did not return empty response")
}
@ -460,7 +460,7 @@ func TestGetHandlesNoKey(t *testing.T) {
response, err := m.mySQL.Get(context.Background(), request)
// Assert
assert.NotNil(t, err, "returned error")
assert.Error(t, err, "returned error")
assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned")
assert.Nil(t, response, "returned response")
}
@ -490,8 +490,8 @@ func TestGetSucceeds(t *testing.T) {
defer m.mySQL.Close()
t.Run("has json type", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow("{}", "946af56e", false)
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
request := &state.GetRequest{
Key: "UnitTest",
@ -508,8 +508,8 @@ func TestGetSucceeds(t *testing.T) {
t.Run("has binary type", func(t *testing.T) {
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow(value, "946af56e", true)
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
request := &state.GetRequest{
Key: "UnitTest",

View File

@ -33,10 +33,10 @@ components:
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mysql
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mariadb
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: azure.tablestorage.storage
allOperations: false
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]