Add native BulkGet to state.mysql (#2762)
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
13193ffdd8
commit
ad6a26bf43
|
@ -76,8 +76,6 @@ const (
|
|||
|
||||
// MySQL state store.
|
||||
type MySQL struct {
|
||||
state.BulkStore
|
||||
|
||||
tableName string
|
||||
metadataTableName string
|
||||
cleanupInterval *time.Duration
|
||||
|
@ -111,9 +109,7 @@ func NewMySQLStateStore(logger logger.Logger) state.Store {
|
|||
|
||||
// Store the provided logger and return the object. The rest of the
|
||||
// properties will be populated in the Init function
|
||||
s := newMySQLStateStore(logger, factory)
|
||||
s.BulkStore = state.NewDefaultBulkStore(s)
|
||||
return s
|
||||
return newMySQLStateStore(logger, factory)
|
||||
}
|
||||
|
||||
// Hidden implementation for testing.
|
||||
|
@ -311,7 +307,7 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
|
|||
cctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
_, err = m.db.ExecContext(cctx,
|
||||
fmt.Sprintf("CREATE DATABASE %s;", m.schemaName),
|
||||
"CREATE DATABASE "+m.schemaName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -504,13 +500,13 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
|
|||
defer cancel()
|
||||
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE id = ?`,
|
||||
m.tableName), req.Key)
|
||||
result, err = querier.ExecContext(execCtx,
|
||||
`DELETE FROM `+m.tableName+` WHERE id = ?`,
|
||||
req.Key)
|
||||
} else {
|
||||
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE id = ? AND eTag = ?`,
|
||||
m.tableName), req.Key, *req.ETag)
|
||||
result, err = querier.ExecContext(execCtx,
|
||||
`DELETE FROM `+m.tableName+` WHERE id = ? AND eTag = ?`,
|
||||
req.Key, *req.ETag)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -559,59 +555,26 @@ func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error
|
|||
// Store Interface.
|
||||
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||
if req.Key == "" {
|
||||
return nil, fmt.Errorf("missing key in get operation")
|
||||
return nil, errors.New("missing key in get operation")
|
||||
}
|
||||
|
||||
var (
|
||||
eTag string
|
||||
value []byte
|
||||
isBinary bool
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||
defer cancel()
|
||||
//nolint:gosec
|
||||
query := fmt.Sprintf(
|
||||
`SELECT value, eTag, isbinary FROM %s WHERE id = ?
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||
row := m.db.QueryRowContext(ctx, query, req.Key)
|
||||
_, value, etag, err := readRow(row)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return an error.
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &state.GetResponse{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isBinary {
|
||||
var (
|
||||
s string
|
||||
data []byte
|
||||
)
|
||||
|
||||
err = json.Unmarshal(value, &s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err = base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &state.GetResponse{
|
||||
Data: data,
|
||||
ETag: &eTag,
|
||||
Metadata: req.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &state.GetResponse{
|
||||
Data: value,
|
||||
ETag: &eTag,
|
||||
ETag: &etag,
|
||||
Metadata: req.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
@ -761,6 +724,77 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) {
|
||||
if len(req) == 0 {
|
||||
return []state.BulkGetResponse{}, nil
|
||||
}
|
||||
|
||||
// MySQL doesn't support passing an array for an IN clause, so we need to build a custom query
|
||||
inClause := strings.Repeat("?,", len(req))
|
||||
inClause = inClause[:(len(inClause) - 1)]
|
||||
params := make([]any, len(req))
|
||||
for i, r := range req {
|
||||
params[i] = r.Key
|
||||
}
|
||||
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
|
||||
WHERE
|
||||
id IN (` + inClause + `)
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||
defer cancel()
|
||||
rows, err := m.db.QueryContext(ctx, stmt, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
n int
|
||||
etag string
|
||||
)
|
||||
res := make([]state.BulkGetResponse, len(req))
|
||||
for ; rows.Next(); n++ {
|
||||
r := state.BulkGetResponse{}
|
||||
r.Key, r.Data, etag, err = readRow(rows)
|
||||
if err != nil {
|
||||
r.Error = err.Error()
|
||||
}
|
||||
r.ETag = &etag
|
||||
res[n] = r
|
||||
}
|
||||
|
||||
return res[:n], nil
|
||||
}
|
||||
|
||||
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etag string, err error) {
|
||||
var isBinary bool
|
||||
err = row.Scan(&key, &value, &etag, &isBinary)
|
||||
if err != nil {
|
||||
return key, nil, "", err
|
||||
}
|
||||
|
||||
if isBinary {
|
||||
var (
|
||||
s string
|
||||
data []byte
|
||||
)
|
||||
|
||||
err = json.Unmarshal(value, &s)
|
||||
if err != nil {
|
||||
return key, nil, "", fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
|
||||
}
|
||||
|
||||
data, err = base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return key, nil, "", fmt.Errorf("failed to decode binary data: %w", err)
|
||||
}
|
||||
return key, data, etag, nil
|
||||
}
|
||||
|
||||
return key, value, etag, nil
|
||||
}
|
||||
|
||||
// BulkSet adds/updates multiple entities on store
|
||||
// Store Interface.
|
||||
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
||||
|
|
|
@ -433,7 +433,7 @@ func TestGetHandlesNoRows(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.mySQL.Close()
|
||||
|
||||
m.mock1.ExpectQuery("SELECT value").WillReturnRows(sqlmock.NewRows([]string{"value", "eTag"}))
|
||||
m.mock1.ExpectQuery("SELECT id").WillReturnRows(sqlmock.NewRows([]string{"UnitTest", "value", "eTag"}))
|
||||
|
||||
request := &state.GetRequest{
|
||||
Key: "UnitTest",
|
||||
|
@ -443,7 +443,7 @@ func TestGetHandlesNoRows(t *testing.T) {
|
|||
response, err := m.mySQL.Get(context.Background(), request)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err, "returned error")
|
||||
assert.NoError(t, err, "returned error")
|
||||
assert.NotNil(t, response, "did not return empty response")
|
||||
}
|
||||
|
||||
|
@ -460,7 +460,7 @@ func TestGetHandlesNoKey(t *testing.T) {
|
|||
response, err := m.mySQL.Get(context.Background(), request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err, "returned error")
|
||||
assert.Error(t, err, "returned error")
|
||||
assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned")
|
||||
assert.Nil(t, response, "returned response")
|
||||
}
|
||||
|
@ -490,8 +490,8 @@ func TestGetSucceeds(t *testing.T) {
|
|||
defer m.mySQL.Close()
|
||||
|
||||
t.Run("has json type", func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow("{}", "946af56e", false)
|
||||
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
|
||||
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
|
||||
request := &state.GetRequest{
|
||||
Key: "UnitTest",
|
||||
|
@ -508,8 +508,8 @@ func TestGetSucceeds(t *testing.T) {
|
|||
|
||||
t.Run("has binary type", func(t *testing.T) {
|
||||
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
|
||||
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow(value, "946af56e", true)
|
||||
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
|
||||
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
|
||||
request := &state.GetRequest{
|
||||
Key: "UnitTest",
|
||||
|
|
|
@ -33,10 +33,10 @@ components:
|
|||
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: mysql.mysql
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: mysql.mariadb
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: azure.tablestorage.storage
|
||||
allOperations: false
|
||||
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]
|
||||
|
|
Loading…
Reference in New Issue