Add native BulkGet to state.mysql (#2762)
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
13193ffdd8
commit
ad6a26bf43
|
@ -76,8 +76,6 @@ const (
|
||||||
|
|
||||||
// MySQL state store.
|
// MySQL state store.
|
||||||
type MySQL struct {
|
type MySQL struct {
|
||||||
state.BulkStore
|
|
||||||
|
|
||||||
tableName string
|
tableName string
|
||||||
metadataTableName string
|
metadataTableName string
|
||||||
cleanupInterval *time.Duration
|
cleanupInterval *time.Duration
|
||||||
|
@ -111,9 +109,7 @@ func NewMySQLStateStore(logger logger.Logger) state.Store {
|
||||||
|
|
||||||
// Store the provided logger and return the object. The rest of the
|
// Store the provided logger and return the object. The rest of the
|
||||||
// properties will be populated in the Init function
|
// properties will be populated in the Init function
|
||||||
s := newMySQLStateStore(logger, factory)
|
return newMySQLStateStore(logger, factory)
|
||||||
s.BulkStore = state.NewDefaultBulkStore(s)
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hidden implementation for testing.
|
// Hidden implementation for testing.
|
||||||
|
@ -311,7 +307,7 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
|
||||||
cctx, cancel := context.WithTimeout(ctx, m.timeout)
|
cctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, err = m.db.ExecContext(cctx,
|
_, err = m.db.ExecContext(cctx,
|
||||||
fmt.Sprintf("CREATE DATABASE %s;", m.schemaName),
|
"CREATE DATABASE "+m.schemaName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -504,13 +500,13 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if req.ETag == nil || *req.ETag == "" {
|
if req.ETag == nil || *req.ETag == "" {
|
||||||
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
|
result, err = querier.ExecContext(execCtx,
|
||||||
`DELETE FROM %s WHERE id = ?`,
|
`DELETE FROM `+m.tableName+` WHERE id = ?`,
|
||||||
m.tableName), req.Key)
|
req.Key)
|
||||||
} else {
|
} else {
|
||||||
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
|
result, err = querier.ExecContext(execCtx,
|
||||||
`DELETE FROM %s WHERE id = ? AND eTag = ?`,
|
`DELETE FROM `+m.tableName+` WHERE id = ? AND eTag = ?`,
|
||||||
m.tableName), req.Key, *req.ETag)
|
req.Key, *req.ETag)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -559,59 +555,26 @@ func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error
|
||||||
// Store Interface.
|
// Store Interface.
|
||||||
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||||
if req.Key == "" {
|
if req.Key == "" {
|
||||||
return nil, fmt.Errorf("missing key in get operation")
|
return nil, errors.New("missing key in get operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
eTag string
|
|
||||||
value []byte
|
|
||||||
isBinary bool
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
//nolint:gosec
|
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||||
query := fmt.Sprintf(
|
query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
|
||||||
`SELECT value, eTag, isbinary FROM %s WHERE id = ?
|
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`,
|
row := m.db.QueryRowContext(ctx, query, req.Key)
|
||||||
m.tableName, // m.tableName is sanitized
|
_, value, etag, err := readRow(row)
|
||||||
)
|
|
||||||
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If no rows exist, return an empty response, otherwise return an error.
|
// If no rows exist, return an empty response, otherwise return an error.
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return &state.GetResponse{}, nil
|
return &state.GetResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isBinary {
|
|
||||||
var (
|
|
||||||
s string
|
|
||||||
data []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
err = json.Unmarshal(value, &s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err = base64.StdEncoding.DecodeString(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &state.GetResponse{
|
|
||||||
Data: data,
|
|
||||||
ETag: &eTag,
|
|
||||||
Metadata: req.Metadata,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &state.GetResponse{
|
return &state.GetResponse{
|
||||||
Data: value,
|
Data: value,
|
||||||
ETag: &eTag,
|
ETag: &etag,
|
||||||
Metadata: req.Metadata,
|
Metadata: req.Metadata,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -761,6 +724,77 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) {
|
||||||
|
if len(req) == 0 {
|
||||||
|
return []state.BulkGetResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MySQL doesn't support passing an array for an IN clause, so we need to build a custom query
|
||||||
|
inClause := strings.Repeat("?,", len(req))
|
||||||
|
inClause = inClause[:(len(inClause) - 1)]
|
||||||
|
params := make([]any, len(req))
|
||||||
|
for i, r := range req {
|
||||||
|
params[i] = r.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||||
|
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
|
||||||
|
WHERE
|
||||||
|
id IN (` + inClause + `)
|
||||||
|
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||||
|
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||||
|
defer cancel()
|
||||||
|
rows, err := m.db.QueryContext(ctx, stmt, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
etag string
|
||||||
|
)
|
||||||
|
res := make([]state.BulkGetResponse, len(req))
|
||||||
|
for ; rows.Next(); n++ {
|
||||||
|
r := state.BulkGetResponse{}
|
||||||
|
r.Key, r.Data, etag, err = readRow(rows)
|
||||||
|
if err != nil {
|
||||||
|
r.Error = err.Error()
|
||||||
|
}
|
||||||
|
r.ETag = &etag
|
||||||
|
res[n] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
return res[:n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etag string, err error) {
|
||||||
|
var isBinary bool
|
||||||
|
err = row.Scan(&key, &value, &etag, &isBinary)
|
||||||
|
if err != nil {
|
||||||
|
return key, nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isBinary {
|
||||||
|
var (
|
||||||
|
s string
|
||||||
|
data []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
err = json.Unmarshal(value, &s)
|
||||||
|
if err != nil {
|
||||||
|
return key, nil, "", fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return key, nil, "", fmt.Errorf("failed to decode binary data: %w", err)
|
||||||
|
}
|
||||||
|
return key, data, etag, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, value, etag, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BulkSet adds/updates multiple entities on store
|
// BulkSet adds/updates multiple entities on store
|
||||||
// Store Interface.
|
// Store Interface.
|
||||||
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
||||||
|
|
|
@ -433,7 +433,7 @@ func TestGetHandlesNoRows(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.mySQL.Close()
|
defer m.mySQL.Close()
|
||||||
|
|
||||||
m.mock1.ExpectQuery("SELECT value").WillReturnRows(sqlmock.NewRows([]string{"value", "eTag"}))
|
m.mock1.ExpectQuery("SELECT id").WillReturnRows(sqlmock.NewRows([]string{"UnitTest", "value", "eTag"}))
|
||||||
|
|
||||||
request := &state.GetRequest{
|
request := &state.GetRequest{
|
||||||
Key: "UnitTest",
|
Key: "UnitTest",
|
||||||
|
@ -443,7 +443,7 @@ func TestGetHandlesNoRows(t *testing.T) {
|
||||||
response, err := m.mySQL.Get(context.Background(), request)
|
response, err := m.mySQL.Get(context.Background(), request)
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err, "returned error")
|
assert.NoError(t, err, "returned error")
|
||||||
assert.NotNil(t, response, "did not return empty response")
|
assert.NotNil(t, response, "did not return empty response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -460,7 +460,7 @@ func TestGetHandlesNoKey(t *testing.T) {
|
||||||
response, err := m.mySQL.Get(context.Background(), request)
|
response, err := m.mySQL.Get(context.Background(), request)
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.NotNil(t, err, "returned error")
|
assert.Error(t, err, "returned error")
|
||||||
assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned")
|
assert.Equal(t, "missing key in get operation", err.Error(), "wrong error returned")
|
||||||
assert.Nil(t, response, "returned response")
|
assert.Nil(t, response, "returned response")
|
||||||
}
|
}
|
||||||
|
@ -490,8 +490,8 @@ func TestGetSucceeds(t *testing.T) {
|
||||||
defer m.mySQL.Close()
|
defer m.mySQL.Close()
|
||||||
|
|
||||||
t.Run("has json type", func(t *testing.T) {
|
t.Run("has json type", func(t *testing.T) {
|
||||||
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow("{}", "946af56e", false)
|
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
|
||||||
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||||
|
|
||||||
request := &state.GetRequest{
|
request := &state.GetRequest{
|
||||||
Key: "UnitTest",
|
Key: "UnitTest",
|
||||||
|
@ -508,8 +508,8 @@ func TestGetSucceeds(t *testing.T) {
|
||||||
|
|
||||||
t.Run("has binary type", func(t *testing.T) {
|
t.Run("has binary type", func(t *testing.T) {
|
||||||
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
|
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
|
||||||
rows := sqlmock.NewRows([]string{"value", "eTag", "isbinary"}).AddRow(value, "946af56e", true)
|
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
|
||||||
m.mock1.ExpectQuery("SELECT value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||||
|
|
||||||
request := &state.GetRequest{
|
request := &state.GetRequest{
|
||||||
Key: "UnitTest",
|
Key: "UnitTest",
|
||||||
|
|
|
@ -33,10 +33,10 @@ components:
|
||||||
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||||
- component: mysql.mysql
|
- component: mysql.mysql
|
||||||
allOperations: false
|
allOperations: false
|
||||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||||
- component: mysql.mariadb
|
- component: mysql.mariadb
|
||||||
allOperations: false
|
allOperations: false
|
||||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||||
- component: azure.tablestorage.storage
|
- component: azure.tablestorage.storage
|
||||||
allOperations: false
|
allOperations: false
|
||||||
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]
|
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]
|
||||||
|
|
Loading…
Reference in New Issue