From 01a3fe76d5830004696b6001ba028721a9236d69 Mon Sep 17 00:00:00 2001 From: Anton Troshin Date: Tue, 29 Apr 2025 12:39:12 -0500 Subject: [PATCH] Add custom BulkGet method to Oracle Statestore (#3804) Signed-off-by: Anton Troshin --- state/oracledatabase/dbaccess.go | 1 + state/oracledatabase/oracledatabase.go | 4 + .../oracledatabase_integration_test.go | 86 +++++++++++++ state/oracledatabase/oracledatabase_test.go | 4 + state/oracledatabase/oracledatabaseaccess.go | 114 ++++++++++++++++++ 5 files changed, 209 insertions(+) diff --git a/state/oracledatabase/dbaccess.go b/state/oracledatabase/dbaccess.go index 15fb4f0d9..98ccb9e28 100644 --- a/state/oracledatabase/dbaccess.go +++ b/state/oracledatabase/dbaccess.go @@ -26,6 +26,7 @@ type dbAccess interface { Ping(ctx context.Context) error Set(ctx context.Context, req *state.SetRequest) error Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) + BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) Delete(ctx context.Context, req *state.DeleteRequest) error ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error Close() error // io.Closer. diff --git a/state/oracledatabase/oracledatabase.go b/state/oracledatabase/oracledatabase.go index 890e9e760..e30357277 100644 --- a/state/oracledatabase/oracledatabase.go +++ b/state/oracledatabase/oracledatabase.go @@ -78,6 +78,10 @@ func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state return o.dbaccess.Get(ctx, req) } +func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) { + return o.dbaccess.BulkGet(ctx, req) +} + // Set adds/updates an entity on store. func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error { return o.dbaccess.Set(ctx, req) diff --git a/state/oracledatabase/oracledatabase_integration_test.go b/state/oracledatabase/oracledatabase_integration_test.go index 3136d5f22..b1dbc9bcf 100644 --- a/state/oracledatabase/oracledatabase_integration_test.go +++ b/state/oracledatabase/oracledatabase_integration_test.go @@ -120,6 +120,10 @@ func TestOracleDatabaseIntegration(t *testing.T) { testBulkSetAndBulkDelete(t, ods) }) + t.Run("Bulk get", func(t *testing.T) { + testBulkGet(t, ods) + }) + t.Run("Update and delete with etag succeeds", func(t *testing.T) { updateAndDeleteWithEtagSucceeds(t, ods) }) @@ -647,6 +651,88 @@ func testBulkSetAndBulkDelete(t *testing.T, ods state.Store) { assert.False(t, storeItemExists(t, db, setReq[1].Key)) } +func testBulkGet(t *testing.T, ods state.Store) { + db := getDB(ods) + + setReq := []state.SetRequest{ + { + Key: randomKey(), + Value: &fakeItem{Color: "red"}, + }, + { + Key: randomKey(), + Value: &fakeItem{Color: "blue"}, + }, + { + Key: randomKey(), + Value: &fakeItem{Color: "green"}, + }, + } + + err := ods.BulkSet(t.Context(), setReq, state.BulkStoreOpts{}) + require.NoError(t, err) + assert.True(t, storeItemExists(t, db, setReq[0].Key)) + assert.True(t, storeItemExists(t, db, setReq[1].Key)) + assert.True(t, storeItemExists(t, db, setReq[2].Key)) + + getReq := []state.GetRequest{ + { + Key: setReq[0].Key, + }, + { + Key: setReq[1].Key, + }, + { + Key: setReq[2].Key, + }, + { + Key: randomKey(), // This key doesn't exist + }, + } + + responses, err := ods.BulkGet(t.Context(), getReq, state.BulkGetOpts{}) + require.NoError(t, err) + require.Len(t, responses, 4) + + // Verify the responses + // First three items should exist + for i := range 3 { + assert.Equal(t, getReq[i].Key, responses[i].Key) + assert.NotNil(t, responses[i].Data) + assert.NotNil(t, responses[i].ETag) + + // Verify the data + var item fakeItem + err = json.Unmarshal(responses[i].Data, &item) + require.NoError(t, err) + + // Check the color matches what we set + originalItem := setReq[i].Value.(*fakeItem) + assert.Equal(t, originalItem.Color, item.Color) + } + + // The fourth item should not exist (empty response) + assert.Equal(t, getReq[3].Key, responses[3].Key) + assert.Nil(t, responses[3].Data) + assert.Nil(t, responses[3].ETag) + + // Clean up + deleteReq := []state.DeleteRequest{ + { + Key: setReq[0].Key, + }, + { + Key: setReq[1].Key, + }, + { + Key: setReq[2].Key, + }, + } + + err = ods.BulkDelete(t.Context(), deleteReq, state.BulkStoreOpts{}) + require.NoError(t, err) +} + // testInitConfiguration tests valid and invalid config settings. func testInitConfiguration(t *testing.T) { logger := logger.NewLogger("test") diff --git a/state/oracledatabase/oracledatabase_test.go b/state/oracledatabase/oracledatabase_test.go index 196742353..ae0fd26c8 100644 --- a/state/oracledatabase/oracledatabase_test.go +++ b/state/oracledatabase/oracledatabase_test.go @@ -64,6 +64,10 @@ func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.G return nil, nil } +func (m *fakeDBaccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) { + return []state.BulkGetResponse{}, nil +} + func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { return nil } diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index f83db74c8..9a62ce6d4 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -312,6 +312,120 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) ( }, nil } +func (o *oracleDatabaseAccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) { + if len(req) == 0 { + return []state.BulkGetResponse{}, nil + } + + // Oracle supports the IN operator for bulk operations + // Build the IN clause with bind variables + // Oracle uses :1, :2, etc. for bind variables in the IN clause + params := make([]any, len(req)) + bindVars := make([]string, len(req)) + for i, r := range req { + if r.Key == "" { + return nil, errors.New("missing key in bulk get operation") + } + params[i] = r.Key + bindVars[i] = ":" + strconv.Itoa(i+1) + } + + inClause := strings.Join(bindVars, ",") + // Concatenation is required for table name because sql.DB does not substitute parameters for table names. + //nolint:gosec + query := "SELECT key, value, binary_yn, etag, expiration_time FROM " + o.metadata.TableName + " WHERE key IN (" + inClause + ") AND (expiration_time IS NULL OR expiration_time > systimestamp)" + + rows, err := o.db.QueryContext(ctx, query, params...) + if err != nil { + return nil, err + } + defer rows.Close() + + var n int + res := make([]state.BulkGetResponse, len(req)) + foundKeys := make(map[string]struct{}, len(req)) + + for rows.Next() { + if n >= len(req) { + // Sanity check to prevent panics, which should never happen + return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) + } + + var ( + key string + value string + binaryYN string + etag string + expireTime sql.NullTime + ) + + err = rows.Scan(&key, &value, &binaryYN, &etag, &expireTime) + if err != nil { + res[n] = state.BulkGetResponse{ + Key: key, + Error: err.Error(), + } + } else { + response := state.BulkGetResponse{ + Key: key, + ETag: &etag, + } + + if expireTime.Valid { + response.Metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.Time.UTC().Format(time.RFC3339), + } + } + + if binaryYN == "Y" { + var ( + s string + data []byte + ) + if err = json.Unmarshal([]byte(value), &s); err != nil { + return nil, err + } + if data, err = base64.StdEncoding.DecodeString(s); err != nil { + return nil, err + } + response.Data = data + } else { + response.Data = []byte(value) + } + + res[n] = response + } + + foundKeys[key] = struct{}{} + n++ + } + + if err = rows.Err(); err != nil { + return nil, err + } + + // Populate missing keys with empty values + // This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method + if len(foundKeys) < len(req) { + var ok bool + for _, r := range req { + _, ok = foundKeys[r.Key] + if !ok { + if n >= len(req) { + // Sanity check to prevent panics, which should never happen + return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) + } + res[n] = state.BulkGetResponse{ + Key: r.Key, + } + n++ + } + } + } + + return res[:n], nil +} + // Delete removes an item from the state store. func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error { return o.doDelete(ctx, o.db, req)