Add native BulkGet to state.mysql (#2762)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2023-04-13 12:33:04 -07:00 committed by GitHub
parent 13193ffdd8
commit ad6a26bf43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 61 deletions

View File

@ -76,8 +76,6 @@ const (
// MySQL state store. // MySQL state store.
type MySQL struct { type MySQL struct {
state.BulkStore
tableName string tableName string
metadataTableName string metadataTableName string
cleanupInterval *time.Duration cleanupInterval *time.Duration
@ -111,9 +109,7 @@ func NewMySQLStateStore(logger logger.Logger) state.Store {
// Store the provided logger and return the object. The rest of the // Store the provided logger and return the object. The rest of the
// properties will be populated in the Init function // properties will be populated in the Init function
s := newMySQLStateStore(logger, factory) return newMySQLStateStore(logger, factory)
s.BulkStore = state.NewDefaultBulkStore(s)
return s
} }
// Hidden implementation for testing. // Hidden implementation for testing.
@ -311,7 +307,7 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
cctx, cancel := context.WithTimeout(ctx, m.timeout) cctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel() defer cancel()
_, err = m.db.ExecContext(cctx, _, err = m.db.ExecContext(cctx,
fmt.Sprintf("CREATE DATABASE %s;", m.schemaName), "CREATE DATABASE "+m.schemaName,
) )
if err != nil { if err != nil {
return err return err
@ -504,13 +500,13 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
defer cancel() defer cancel()
if req.ETag == nil || *req.ETag == "" { if req.ETag == nil || *req.ETag == "" {
result, err = querier.ExecContext(execCtx, fmt.Sprintf( result, err = querier.ExecContext(execCtx,
`DELETE FROM %s WHERE id = ?`, `DELETE FROM `+m.tableName+` WHERE id = ?`,
m.tableName), req.Key) req.Key)
} else { } else {
result, err = querier.ExecContext(execCtx, fmt.Sprintf( result, err = querier.ExecContext(execCtx,
`DELETE FROM %s WHERE id = ? AND eTag = ?`, `DELETE FROM `+m.tableName+` WHERE id = ? AND eTag = ?`,
m.tableName), req.Key, *req.ETag) req.Key, *req.ETag)
} }
if err != nil { if err != nil {
@ -559,59 +555,26 @@ func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error
// Store Interface. // Store Interface.
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) { func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if req.Key == "" { if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation") return nil, errors.New("missing key in get operation")
} }
var (
eTag string
value []byte
isBinary bool
)
ctx, cancel := context.WithTimeout(parentCtx, m.timeout) ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel() defer cancel()
//nolint:gosec // Concatenation is required for table name because sql.DB does not substitute parameters for table names
query := fmt.Sprintf( query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
`SELECT value, eTag, isbinary FROM %s WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`, row := m.db.QueryRowContext(ctx, query, req.Key)
m.tableName, // m.tableName is sanitized _, value, etag, err := readRow(row)
)
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return an error. // If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
} }
return nil, err return nil, err
} }
if isBinary {
var (
s string
data []byte
)
err = json.Unmarshal(value, &s)
if err != nil {
return nil, err
}
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, err
}
return &state.GetResponse{
Data: data,
ETag: &eTag,
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{ return &state.GetResponse{
Data: value, Data: value,
ETag: &eTag, ETag: &etag,
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
@ -761,6 +724,77 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
return nil return nil
} }
func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) {
if len(req) == 0 {
return []state.BulkGetResponse{}, nil
}
// MySQL doesn't support passing an array for an IN clause, so we need to build a custom query
inClause := strings.Repeat("?,", len(req))
inClause = inClause[:(len(inClause) - 1)]
params := make([]any, len(req))
for i, r := range req {
params[i] = r.Key
}
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
WHERE
id IN (` + inClause + `)
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
rows, err := m.db.QueryContext(ctx, stmt, params...)
if err != nil {
return nil, err
}
var (
n int
etag string
)
res := make([]state.BulkGetResponse, len(req))
for ; rows.Next(); n++ {
r := state.BulkGetResponse{}
r.Key, r.Data, etag, err = readRow(rows)
if err != nil {
r.Error = err.Error()
}
r.ETag = &etag
res[n] = r
}
return res[:n], nil
}
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etag string, err error) {
var isBinary bool
err = row.Scan(&key, &value, &etag, &isBinary)
if err != nil {
return key, nil, "", err
}
if isBinary {
var (
s string
data []byte
)
err = json.Unmarshal(value, &s)
if err != nil {
return key, nil, "", fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
}
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return key, nil, "", fmt.Errorf("failed to decode binary data: %w", err)
}
return key, data, etag, nil
}
return key, value, etag, nil
}
// BulkSet adds/updates multiple entities on store // BulkSet adds/updates multiple entities on store
// Store Interface. // Store Interface.
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error { func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {

View File

@ -433,7 +433,7 @@ func TestGetHandlesNoRows(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.mySQL.Close() defer m.mySQL.Close()
m.mock1.ExpectQuery("SELECT value").WillReturnRows(sqlmock.NewRows([]string{"value", "eTag"})) m.mock1.ExpectQuery("SELECT id").WillReturnRows(sqlmock.NewRows([]string{"UnitTest", "value", "eTag"}))
request := &state.GetRequest{ request := &state.GetRequest{
Key: "UnitTest", Key: "UnitTest",
@ -443,7 +443,7 @@ func TestGetHandlesNoRows(t *testing.T) {
response, err := m.mySQL.Get(context.Background(), request) response, err := m.mySQL.Get(context.Background(), request)
// Assert // Assert
assert.Nil(t, err, "returned error") assert.NoError(t, err, "returned error")
assert.NotNil(t, response, "did not return empty response") assert.NotNil(t, response, "did not return empty response")
} }
@ -460,7 +460,7 @@ func TestGetHandlesNoKey(t *testing.T) {
response, err := m.mySQL.Get(context.Background(), request) response, err := m.mySQL.Get(context.Background(), request)
// Assert // Assert
assert.NotNil(t, err, "returned error") assert.Error(t, err, "returned error")
assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned") assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned")
assert.Nil(t, response, "returned response") assert.Nil(t, response, "returned response")
} }
@ -490,8 +490,8 @@ func TestGetSucceeds(t *testing.T) {
defer m.mySQL.Close() defer m.mySQL.Close()
t.Run("has json type", func(t *testing.T) { t.Run("has json type", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow("{}", "946af56e", false) rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
request := &state.GetRequest{ request := &state.GetRequest{
Key: "UnitTest", Key: "UnitTest",
@ -508,8 +508,8 @@ func TestGetSucceeds(t *testing.T) {
t.Run("has binary type", func(t *testing.T) { t.Run("has binary type", func(t *testing.T) {
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal) value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow(value, "946af56e", true) rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
request := &state.GetRequest{ request := &state.GetRequest{
Key: "UnitTest", Key: "UnitTest",

View File

@ -33,10 +33,10 @@ components:
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mysql - component: mysql.mysql
allOperations: false allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mariadb - component: mysql.mariadb
allOperations: false allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ] operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: azure.tablestorage.storage - component: azure.tablestorage.storage
allOperations: false allOperations: false
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"] operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]