Add custom BulkGet method to Oracle Statestore (#3804)
Signed-off-by: Anton Troshin <anton@diagrid.io>
This commit is contained in:
parent
397766a23e
commit
01a3fe76d5
|
@ -26,6 +26,7 @@ type dbAccess interface {
|
|||
Ping(ctx context.Context) error
|
||||
Set(ctx context.Context, req *state.SetRequest) error
|
||||
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
|
||||
BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error)
|
||||
Delete(ctx context.Context, req *state.DeleteRequest) error
|
||||
ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error
|
||||
Close() error // io.Closer.
|
||||
|
|
|
@ -78,6 +78,10 @@ func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state
|
|||
return o.dbaccess.Get(ctx, req)
|
||||
}
|
||||
|
||||
func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) {
|
||||
return o.dbaccess.BulkGet(ctx, req)
|
||||
}
|
||||
|
||||
// Set adds/updates an entity on store.
|
||||
func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error {
|
||||
return o.dbaccess.Set(ctx, req)
|
||||
|
|
|
@ -120,6 +120,10 @@ func TestOracleDatabaseIntegration(t *testing.T) {
|
|||
testBulkSetAndBulkDelete(t, ods)
|
||||
})
|
||||
|
||||
t.Run("Bulk get", func(t *testing.T) {
|
||||
testBulkGet(t, ods)
|
||||
})
|
||||
|
||||
t.Run("Update and delete with etag succeeds", func(t *testing.T) {
|
||||
updateAndDeleteWithEtagSucceeds(t, ods)
|
||||
})
|
||||
|
@ -647,6 +651,88 @@ func testBulkSetAndBulkDelete(t *testing.T, ods state.Store) {
|
|||
assert.False(t, storeItemExists(t, db, setReq[1].Key))
|
||||
}
|
||||
|
||||
func testBulkGet(t *testing.T, ods state.Store) {
|
||||
db := getDB(ods)
|
||||
|
||||
setReq := []state.SetRequest{
|
||||
{
|
||||
Key: randomKey(),
|
||||
Value: &fakeItem{Color: "red"},
|
||||
},
|
||||
{
|
||||
Key: randomKey(),
|
||||
Value: &fakeItem{Color: "blue"},
|
||||
},
|
||||
{
|
||||
Key: randomKey(),
|
||||
Value: &fakeItem{Color: "green"},
|
||||
},
|
||||
}
|
||||
|
||||
err := ods.BulkSet(t.Context(), setReq, state.BulkStoreOpts{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeItemExists(t, db, setReq[0].Key))
|
||||
assert.True(t, storeItemExists(t, db, setReq[1].Key))
|
||||
assert.True(t, storeItemExists(t, db, setReq[2].Key))
|
||||
|
||||
getReq := []state.GetRequest{
|
||||
{
|
||||
Key: setReq[0].Key,
|
||||
},
|
||||
{
|
||||
Key: setReq[1].Key,
|
||||
},
|
||||
{
|
||||
Key: setReq[2].Key,
|
||||
},
|
||||
{
|
||||
Key: randomKey(), // This key doesn't exist
|
||||
},
|
||||
}
|
||||
|
||||
responses, err := ods.BulkGet(t.Context(), getReq, state.BulkGetOpts{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 4)
|
||||
|
||||
// Verify the responses
|
||||
// First three items should exist
|
||||
for i := range 3 {
|
||||
assert.Equal(t, getReq[i].Key, responses[i].Key)
|
||||
assert.NotNil(t, responses[i].Data)
|
||||
assert.NotNil(t, responses[i].ETag)
|
||||
|
||||
// Verify the data
|
||||
var item fakeItem
|
||||
err = json.Unmarshal(responses[i].Data, &item)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the color matches what we set
|
||||
originalItem := setReq[i].Value.(*fakeItem)
|
||||
assert.Equal(t, originalItem.Color, item.Color)
|
||||
}
|
||||
|
||||
// The fourth item should not exist (empty response)
|
||||
assert.Equal(t, getReq[3].Key, responses[3].Key)
|
||||
assert.Nil(t, responses[3].Data)
|
||||
assert.Nil(t, responses[3].ETag)
|
||||
|
||||
// Clean up
|
||||
deleteReq := []state.DeleteRequest{
|
||||
{
|
||||
Key: setReq[0].Key,
|
||||
},
|
||||
{
|
||||
Key: setReq[1].Key,
|
||||
},
|
||||
{
|
||||
Key: setReq[2].Key,
|
||||
},
|
||||
}
|
||||
|
||||
err = ods.BulkDelete(t.Context(), deleteReq, state.BulkStoreOpts{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// testInitConfiguration tests valid and invalid config settings.
|
||||
func testInitConfiguration(t *testing.T) {
|
||||
logger := logger.NewLogger("test")
|
||||
|
|
|
@ -64,6 +64,10 @@ func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.G
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *fakeDBaccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
|
||||
return []state.BulkGetResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -312,6 +312,120 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (o *oracleDatabaseAccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
|
||||
if len(req) == 0 {
|
||||
return []state.BulkGetResponse{}, nil
|
||||
}
|
||||
|
||||
// Oracle supports the IN operator for bulk operations
|
||||
// Build the IN clause with bind variables
|
||||
// Oracle uses :1, :2, etc. for bind variables in the IN clause
|
||||
params := make([]any, len(req))
|
||||
bindVars := make([]string, len(req))
|
||||
for i, r := range req {
|
||||
if r.Key == "" {
|
||||
return nil, errors.New("missing key in bulk get operation")
|
||||
}
|
||||
params[i] = r.Key
|
||||
bindVars[i] = ":" + strconv.Itoa(i+1)
|
||||
}
|
||||
|
||||
inClause := strings.Join(bindVars, ",")
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names.
|
||||
//nolint:gosec
|
||||
query := "SELECT key, value, binary_yn, etag, expiration_time FROM " + o.metadata.TableName + " WHERE key IN (" + inClause + ") AND (expiration_time IS NULL OR expiration_time > systimestamp)"
|
||||
|
||||
rows, err := o.db.QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var n int
|
||||
res := make([]state.BulkGetResponse, len(req))
|
||||
foundKeys := make(map[string]struct{}, len(req))
|
||||
|
||||
for rows.Next() {
|
||||
if n >= len(req) {
|
||||
// Sanity check to prevent panics, which should never happen
|
||||
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
|
||||
}
|
||||
|
||||
var (
|
||||
key string
|
||||
value string
|
||||
binaryYN string
|
||||
etag string
|
||||
expireTime sql.NullTime
|
||||
)
|
||||
|
||||
err = rows.Scan(&key, &value, &binaryYN, &etag, &expireTime)
|
||||
if err != nil {
|
||||
res[n] = state.BulkGetResponse{
|
||||
Key: key,
|
||||
Error: err.Error(),
|
||||
}
|
||||
} else {
|
||||
response := state.BulkGetResponse{
|
||||
Key: key,
|
||||
ETag: &etag,
|
||||
}
|
||||
|
||||
if expireTime.Valid {
|
||||
response.Metadata = map[string]string{
|
||||
state.GetRespMetaKeyTTLExpireTime: expireTime.Time.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
if binaryYN == "Y" {
|
||||
var (
|
||||
s string
|
||||
data []byte
|
||||
)
|
||||
if err = json.Unmarshal([]byte(value), &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Data = data
|
||||
} else {
|
||||
response.Data = []byte(value)
|
||||
}
|
||||
|
||||
res[n] = response
|
||||
}
|
||||
|
||||
foundKeys[key] = struct{}{}
|
||||
n++
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Populate missing keys with empty values
|
||||
// This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method
|
||||
if len(foundKeys) < len(req) {
|
||||
var ok bool
|
||||
for _, r := range req {
|
||||
_, ok = foundKeys[r.Key]
|
||||
if !ok {
|
||||
if n >= len(req) {
|
||||
// Sanity check to prevent panics, which should never happen
|
||||
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
|
||||
}
|
||||
res[n] = state.BulkGetResponse{
|
||||
Key: r.Key,
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res[:n], nil
|
||||
}
|
||||
|
||||
// Delete removes an item from the state store.
|
||||
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
||||
return o.doDelete(ctx, o.db, req)
|
||||
|
|
Loading…
Reference in New Issue