Add custom BulkGet method to Oracle Statestore (#3804)

Signed-off-by: Anton Troshin <anton@diagrid.io>
This commit is contained in:
Anton Troshin 2025-04-29 12:39:12 -05:00 committed by GitHub
parent 397766a23e
commit 01a3fe76d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 209 additions and 0 deletions

View File

@ -26,6 +26,7 @@ type dbAccess interface {
Ping(ctx context.Context) error
Set(ctx context.Context, req *state.SetRequest) error
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error)
Delete(ctx context.Context, req *state.DeleteRequest) error
ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error
Close() error // io.Closer.

View File

@ -78,6 +78,10 @@ func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state
return o.dbaccess.Get(ctx, req)
}
func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) {
return o.dbaccess.BulkGet(ctx, req)
}
// Set adds/updates an entity on store.
func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error {
return o.dbaccess.Set(ctx, req)

View File

@ -120,6 +120,10 @@ func TestOracleDatabaseIntegration(t *testing.T) {
testBulkSetAndBulkDelete(t, ods)
})
t.Run("Bulk get", func(t *testing.T) {
testBulkGet(t, ods)
})
t.Run("Update and delete with etag succeeds", func(t *testing.T) {
updateAndDeleteWithEtagSucceeds(t, ods)
})
@ -647,6 +651,88 @@ func testBulkSetAndBulkDelete(t *testing.T, ods state.Store) {
assert.False(t, storeItemExists(t, db, setReq[1].Key))
}
func testBulkGet(t *testing.T, ods state.Store) {
db := getDB(ods)
setReq := []state.SetRequest{
{
Key: randomKey(),
Value: &fakeItem{Color: "red"},
},
{
Key: randomKey(),
Value: &fakeItem{Color: "blue"},
},
{
Key: randomKey(),
Value: &fakeItem{Color: "green"},
},
}
err := ods.BulkSet(t.Context(), setReq, state.BulkStoreOpts{})
require.NoError(t, err)
assert.True(t, storeItemExists(t, db, setReq[0].Key))
assert.True(t, storeItemExists(t, db, setReq[1].Key))
assert.True(t, storeItemExists(t, db, setReq[2].Key))
getReq := []state.GetRequest{
{
Key: setReq[0].Key,
},
{
Key: setReq[1].Key,
},
{
Key: setReq[2].Key,
},
{
Key: randomKey(), // This key doesn't exist
},
}
responses, err := ods.BulkGet(t.Context(), getReq, state.BulkGetOpts{})
require.NoError(t, err)
require.Len(t, responses, 4)
// Verify the responses
// First three items should exist
for i := range 3 {
assert.Equal(t, getReq[i].Key, responses[i].Key)
assert.NotNil(t, responses[i].Data)
assert.NotNil(t, responses[i].ETag)
// Verify the data
var item fakeItem
err = json.Unmarshal(responses[i].Data, &item)
require.NoError(t, err)
// Check the color matches what we set
originalItem := setReq[i].Value.(*fakeItem)
assert.Equal(t, originalItem.Color, item.Color)
}
// The fourth item should not exist (empty response)
assert.Equal(t, getReq[3].Key, responses[3].Key)
assert.Nil(t, responses[3].Data)
assert.Nil(t, responses[3].ETag)
// Clean up
deleteReq := []state.DeleteRequest{
{
Key: setReq[0].Key,
},
{
Key: setReq[1].Key,
},
{
Key: setReq[2].Key,
},
}
err = ods.BulkDelete(t.Context(), deleteReq, state.BulkStoreOpts{})
require.NoError(t, err)
}
// testInitConfiguration tests valid and invalid config settings.
func testInitConfiguration(t *testing.T) {
logger := logger.NewLogger("test")

View File

@ -64,6 +64,10 @@ func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.G
return nil, nil
}
func (m *fakeDBaccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
return []state.BulkGetResponse{}, nil
}
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return nil
}

View File

@ -312,6 +312,120 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (
}, nil
}
func (o *oracleDatabaseAccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
if len(req) == 0 {
return []state.BulkGetResponse{}, nil
}
// Oracle supports the IN operator for bulk operations
// Build the IN clause with bind variables
// Oracle uses :1, :2, etc. for bind variables in the IN clause
params := make([]any, len(req))
bindVars := make([]string, len(req))
for i, r := range req {
if r.Key == "" {
return nil, errors.New("missing key in bulk get operation")
}
params[i] = r.Key
bindVars[i] = ":" + strconv.Itoa(i+1)
}
inClause := strings.Join(bindVars, ",")
// Concatenation is required for table name because sql.DB does not substitute parameters for table names.
//nolint:gosec
query := "SELECT key, value, binary_yn, etag, expiration_time FROM " + o.metadata.TableName + " WHERE key IN (" + inClause + ") AND (expiration_time IS NULL OR expiration_time > systimestamp)"
rows, err := o.db.QueryContext(ctx, query, params...)
if err != nil {
return nil, err
}
defer rows.Close()
var n int
res := make([]state.BulkGetResponse, len(req))
foundKeys := make(map[string]struct{}, len(req))
for rows.Next() {
if n >= len(req) {
// Sanity check to prevent panics, which should never happen
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
}
var (
key string
value string
binaryYN string
etag string
expireTime sql.NullTime
)
err = rows.Scan(&key, &value, &binaryYN, &etag, &expireTime)
if err != nil {
res[n] = state.BulkGetResponse{
Key: key,
Error: err.Error(),
}
} else {
response := state.BulkGetResponse{
Key: key,
ETag: &etag,
}
if expireTime.Valid {
response.Metadata = map[string]string{
state.GetRespMetaKeyTTLExpireTime: expireTime.Time.UTC().Format(time.RFC3339),
}
}
if binaryYN == "Y" {
var (
s string
data []byte
)
if err = json.Unmarshal([]byte(value), &s); err != nil {
return nil, err
}
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
return nil, err
}
response.Data = data
} else {
response.Data = []byte(value)
}
res[n] = response
}
foundKeys[key] = struct{}{}
n++
}
if err = rows.Err(); err != nil {
return nil, err
}
// Populate missing keys with empty values
// This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method
if len(foundKeys) < len(req) {
var ok bool
for _, r := range req {
_, ok = foundKeys[r.Key]
if !ok {
if n >= len(req) {
// Sanity check to prevent panics, which should never happen
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
}
res[n] = state.BulkGetResponse{
Key: r.Key,
}
n++
}
}
}
return res[:n], nil
}
// Delete removes an item from the state store.
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
return o.doDelete(ctx, o.db, req)